AdvaitBERT: HS Code AI Explainability Through Mixtral 46.7B
"
+description = """
+AdvaitBERT is a modified version of BERT (Bidirectional Encoder Representations from Transformers), \
+fine-tuned on a text corpus of Indian Customs declarations. It is trained to perform \
+downstream tasks such as automating the tariff classification and validation of Customs \
+declarations in real time. The model can help Customs administrations apply AI-assisted \
+NLP in real-time Customs processes such as Assessment and Post Clearance Audit, highlighting \
+classification inconsistencies and supporting revenue augmentation.
+"""
+
+article="
Powered by NCTC
"
+
+
+css = """
+.gradio-container {
+ width: 100vw !important;
+ min-height: 100vh !important;
+ padding:0 !important;
+ margin:0 !important;
+ max-width: none !important;
+}
+"""
+
+footnote = """Note: All rights, including licensing and acceptable use policies, related to the AI models, can be found on their respective model pages on Hugging Face. Powered by NCTC
+"""
+
+#Powered by NCTC
+
+# input_txt=gr.Textbox(label='Enter Your Product Descrption',lines=3,)
+# textbox = gr.Textbox(container=False,placeholder='Enter text and click the Submit button or press Enter')
+
+textbox = gr.Textbox(label='Enter Your Product Description', lines=3)
+textbox_2=textbox
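+# textbox_2 aliases the same Textbox component, so both submit callbacks below read the same input box.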
+
+print('textbox',textbox)
+print('textbox_2',textbox_2)
+
+chat_prod = gr.Chatbot(label="Product Explanation", layout='panel') #height=300
+#chat_Advait = gr.Chatbot(label="Advaitbert Prediction", layout='panel')
+chat_alpha = gr.Chatbot(label="AI Explainability", layout='panel')
+chat_Advait=gr.Interface(predict_CTH,inputs=textbox,outputs="label",)
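+# chat_Advait wraps predict_CTH in a gr.Interface with a Label output; it is rendered inside the Blocks layout below.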
+
+
+submit = gr.Button('Submit', variant='primary',)
+submit_second = gr.Button('Submit', variant='secondary',)
+#submit2 = gr.Button('Submit', variant='primary',)
+retry = gr.Button('🔄Retry', variant='secondary')
+undo = gr.Button('↩️Undo', variant='secondary')
+
+with gr.Blocks(css=css) as demo:
+    gr.HTML(f'{title}')
+ gr.Markdown(description)
+
+ with gr.Row():
+ with gr.Column(scale=0,min_width=600):
+ chat_Advait.render()
+
+ with gr.Column(scale=1,min_width=600):
+ chat_alpha.render()
+ with gr.Row(equal_height=True):
+ with gr.Column(scale=1):
+ submit.render()
+ with gr.Column(scale=1):
+ undo.render()
+ with gr.Column(scale=1):
+ clear = gr.ClearButton(value='🗑️Clear',components=[chat_alpha,chat_prod,textbox])
+ chat_prod.render()
+ #submit_second.render()
+
+ gr.Markdown(footnote)
+ textbox.submit(llm_model_function, [textbox, chat_alpha], [textbox, chat_alpha])
+ textbox_2.submit(product_explaination, [textbox_2, chat_prod], [textbox_2, chat_prod])
+
+ submit.click(llm_model_function,[textbox, chat_alpha], [textbox, chat_alpha])
+ submit.click(product_explaination,[textbox_2, chat_prod], [textbox_2, chat_prod])
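+    # The two click handlers above run on the same Submit button: one updates the AI explainability
+    # chat (chat_alpha), the other the product explanation chat (chat_prod).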
+
+ undo.click(lambda x:x[:-1], [chat_alpha], [chat_alpha])
+ undo.click(lambda x:x[:-1], [chat_prod], [chat_prod])
+
+ gr.Examples([
+ ['200 SI/SI/SI LPO ALUMINIUM LIDS (QTY: 8820000 PCS/PRICE: 21.'],
+ ],
+ textbox)
+
+demo.launch(debug=True)
\ No newline at end of file
diff --git a/fun_advaitbert.py b/fun_advaitbert.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e66e76f09ef3f55755f92fbdc55e450278d9f13
--- /dev/null
+++ b/fun_advaitbert.py
@@ -0,0 +1,344 @@
+import pandas as pd
+import numpy as np
+import tensorflow as tf
+import tensorflow_hub as hub
+import sys
+import random
+sys.path.append('models')
+from official.nlp.data import classifier_data_lib
+from official.nlp.bert import tokenization
+from official.nlp import optimization
+tf.get_logger().setLevel('ERROR')
+from huggingface_hub import InferenceClient
+import math
+import gradio as gr
+
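+# The optimizer below is only used as a custom object when deserialising the saved checkpoint,
+# so the warmup/train step counts are placeholders.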
+num_warmup_steps=1
+num_train_steps=1
+init_lr = 3e-5
+optimizer = optimization.create_optimizer(init_lr=init_lr,num_train_steps=num_train_steps,num_warmup_steps=num_warmup_steps,optimizer_type='adamw')
+
+### Load Model
+checkpoint_filepath=r'./Checkpoint'
+model = tf.keras.models.load_model(checkpoint_filepath, custom_objects={'KerasLayer':hub.KerasLayer , 'AdamWeightDecay': optimizer})
+
+df_report = pd.read_csv('./CTH_Description.csv')
+df_report['CTH Code'] = df_report['CTH Code'].astype(str).str.zfill(8)
+
+df_report_DUTY = pd.read_csv('./CTH_WISE_DUTY_RATE.csv')
+df_report_DUTY['CTH'] = df_report_DUTY['CTH'].astype(str).str.zfill(8)
+
+df = pd.read_csv("./CTH_CODE_MAP.csv")
+df['CTH'] = df['CTH'].astype(str).str.zfill(8)
+df = df[['CTH', 'code']]
+
+client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
+
+# Shared between predict_CTH and the LLM helpers; initialised here so llm_model_function
+# does not raise a NameError if it is called before a prediction has been made.
+output_str_msg = 'Not an adequate description'
+
+
+class_names=df[['CTH','code']].drop_duplicates(subset='CTH').sort_values(by='code',ignore_index=True)['CTH'].values.tolist()
+label_list=list(range(0,len(class_names)))
+max_seq_length = 200 # maximum length of (token) input sequences
+train_batch_size = 32 # batch size (32 chosen to avoid out-of-memory errors)
+
+# Get BERT layer and tokenizer:
+# More details here: https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4
+bert_layer = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4" , trainable = True)
+vocab_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()
+do_lower_case = bert_layer.resolved_object.do_lower_case.numpy()
+tokenizer = tokenization.FullTokenizer(vocab_file , do_lower_case)
+
+# This provides a function to convert each row to input features and label ( as required by BERT)
+
+max_seq_length = 200 # maximum length of (token) input sequences
+def to_feature(text, label, label_list=label_list, max_seq_length=max_seq_length, tokenizer=tokenizer):
+ example = classifier_data_lib.InputExample(guid = None,
+ text_a = text.numpy(),
+ text_b = None,
+ label = label.numpy())
+ feature = classifier_data_lib.convert_single_example(0 , example , label_list , max_seq_length , tokenizer)
+
+ return (feature.input_ids , feature.input_mask , feature.segment_ids , feature.label_id)
+
+
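+# Wraps to_feature in tf.py_function so it can run inside a tf.data pipeline, then restores
+# the static shapes that py_function drops.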
+def to_feature_map(text, label):
+ input_ids , input_mask , segment_ids , label_id = tf.py_function(to_feature , inp = [text , label],
+ Tout = [tf.int32 , tf.int32 , tf.int32 , tf.int32])
+
+ input_ids.set_shape([max_seq_length])
+ input_mask.set_shape([max_seq_length])
+ segment_ids.set_shape([max_seq_length])
+ label_id.set_shape([])
+
+ x = {
+ "input_word_ids": input_ids,
+ "input_mask": input_mask,
+ "input_type_ids": segment_ids
+ }
+
+ return(x,label_id)
+
+
+def find_max_10_with_position(arr, arr_size):
+    # Keep a descending list of the 10 largest values in arr together with their positions.
+    max_values_with_position = [(-sys.maxsize, -1)] * 10
+
+    for i in range(arr_size):
+        for j in range(10):
+            value, position = max_values_with_position[j]
+            if arr[i] > value:
+                # Shift the smaller entries down one slot and insert the new value at slot j.
+                max_values_with_position[j+1:] = max_values_with_position[j:9]
+                max_values_with_position[j] = (arr[i], i)
+                break
+
+    return max_values_with_position
+
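+# Returns True if the string contains at least one alphabetic character (i.e. it is not made up
+# entirely of digits, spaces, or punctuation); used to reject junk descriptions.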
+def count_special_character(string):
+ special_char= 0
+ for i in range(len(string)):
+ ch = string[i]
+ if (string[i].isalpha()):
+ continue
+ else:
+ special_char += 1
+
+ if len(string)==special_char:
+ return False
+ else:
+ return True
+
+def format_prompt(message, history):
+ prompt = ""
+ for user_prompt, bot_response in history:
+ prompt += f"[INST] {user_prompt} [/INST]"
+ prompt += f" {bot_response} "
+ prompt += f"[INST] {message} [/INST]"
+ return prompt
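+# Illustrative example (hypothetical inputs):
+#   format_prompt("What is the product?", [("Hi", "Hello!")])
+#   -> "[INST] Hi [/INST] Hello! [INST] What is the product? [/INST]"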
+
+
+additional_inputs=[
+ gr.Textbox(
+ label="System Prompt",
+ max_lines=1,
+ interactive=True,
+ ),
+ gr.Slider(
+ label="Temperature",
+ value=0.5,
+ minimum=0.0,
+ maximum=1.0,
+ step=0.05,
+ interactive=True,
+ info="Higher values produce more diverse outputs",
+ ),
+ gr.Slider(
+ label="Max new tokens",
+ value=1024,
+ minimum=0,
+ maximum=4096,
+ step=64,
+ interactive=True,
+ info="The maximum numbers of new tokens",
+ ),
+ gr.Slider(
+ label="Top-p (nucleus sampling)",
+ value=0.90,
+ minimum=0.0,
+ maximum=1,
+ step=0.05,
+ interactive=True,
+ info="Higher values sample more low-probability tokens",
+ ),
+ gr.Slider(
+ label="Repetition penalty",
+ value=1.2,
+ minimum=1.0,
+ maximum=2.0,
+ step=0.05,
+ interactive=True,
+ info="Penalize repeated tokens",
+ )
+]
+
+def predict_CTH(txt):
+ print('Desc: ',txt)
+ global output_str_msg
+ if (txt!='') and len(txt)>=3 and (count_special_character(txt)):
+        valid_data = tf.data.Dataset.from_tensor_slices(([txt] , [1])) # the label value is a dummy; only the text is used at inference time
+ valid_data = (valid_data.map(to_feature_map).batch(1))
+ preds = model.predict(valid_data)
+ predicted_values = tf.nn.softmax(preds)
+ arr = predicted_values.numpy().tolist()[0]
+ n = len(arr)
+
+ pred_value_max=find_max_10_with_position(arr, n)
+
+ sum_all = 0
+ for i in range(10):
+ sum_all += pred_value_max[i][0]
+
+
+ val_1 = pred_value_max[0][0]/sum_all
+ val_2 = pred_value_max[1][0]/sum_all
+ val_3 = pred_value_max[2][0]/sum_all
+ val_4 = pred_value_max[3][0]/sum_all
+ val_5 = pred_value_max[4][0]/sum_all
+ val_6 = pred_value_max[5][0]/sum_all
+ val_7 = pred_value_max[6][0]/sum_all
+ val_8 = pred_value_max[7][0]/sum_all
+ val_9 = pred_value_max[8][0]/sum_all
+ val_10 = pred_value_max[9][0]/sum_all
+
+ if pred_value_max[0][0]<=0.000131:
+ Var_CTH=[]
+ Var_desc=[]
+ Var_duty=[]
+ pred_duty=''
+ pred_desc=''
+ pred_CTH=''
+
+            output_str_msg='Not an adequate description'
+
+            return {'Not an adequate description': float(1.0)}
+ else:
+ Var_CTH=[]
+ Var_desc=[]
+ Var_duty=[]
+ pred_duty=''
+ pred_desc=''
+ pred_CTH=''
+
+ for i in range(len(pred_value_max)):
+ #predicted_code=np.where(predicted_values.numpy()==i)[1][0]
+ predicted_code=pred_value_max[i][1]
+ pred_CTH=df[df['code'] == predicted_code]['CTH'].iloc[0]
+
+ try:
+ pred_duty=df_report_DUTY[df_report_DUTY['CTH']==str(pred_CTH)]['DUTY_RATE'].iloc[0]
+ pred_desc=df_report[df_report['CTH Code']==str(pred_CTH)]['Concat Description'].iloc[0]
+ except:
+ pred_desc=''
+ pred_duty=''
+ pass
+
+ Var_CTH.append(pred_CTH)
+ Var_desc.append(pred_desc)
+ Var_duty.append(pred_duty)
+
+ P1 ='CTH: '+str(Var_CTH[0])+' Duty Rate(%): '+ str(Var_duty[0])
+ P2 ='CTH: '+str(Var_CTH[1])+' Duty Rate(%): '+ str(Var_duty[1])
+ P3 ='CTH: '+str(Var_CTH[2])+' Duty Rate(%): '+ str(Var_duty[2])
+ P4 ='CTH: '+str(Var_CTH[3])+' Duty Rate(%): '+ str(Var_duty[3])
+ P5 ='CTH: '+str(Var_CTH[4])+' Duty Rate(%): '+ str(Var_duty[4])
+ P6 ='CTH: '+str(Var_CTH[5])+' Duty Rate(%): '+ str(Var_duty[5])
+ P7 ='CTH: '+str(Var_CTH[6])+' Duty Rate(%): '+ str(Var_duty[6])
+ P8 ='CTH: '+str(Var_CTH[7])+' Duty Rate(%): '+ str(Var_duty[7])
+ P9 ='CTH: '+str(Var_CTH[8])+' Duty Rate(%): '+ str(Var_duty[8])
+ P10 ='CTH: '+str(Var_CTH[9])+' Duty Rate(%): '+ str(Var_duty[9])
+
+ Q1='Desc: '+str(Var_desc[0])
+ Q2='Desc: '+str(Var_desc[1])
+ Q3='Desc: '+str(Var_desc[2])
+ Q4='Desc: '+str(Var_desc[3])
+ Q5='Desc: '+str(Var_desc[4])
+ Q6='Desc: '+str(Var_desc[5])
+ Q7='Desc: '+str(Var_desc[6])
+ Q8='Desc: '+str(Var_desc[7])
+ Q9='Desc: '+str(Var_desc[8])
+ Q10='Desc: '+str(Var_desc[9])
+
+ output_str_msg = (
+ f'1. {P1} {Q1} '
+ f'2. {P2} {Q2} '
+ f'3. {P3} {Q3} '
+ f'4. {P4} {Q4} '
+ f'5. {P5} {Q5} '
+ f'6. {P6} {Q6} '
+ f'7. {P7} {Q7} '
+ f'8. {P8} {Q8} '
+ f'9. {P9} {Q9} '
+ f'10. {P10} {Q10}')
+
+ print('output_str_msg',output_str_msg)
+
+ return {str(P1):float(val_1),str(Q1):float(val_1),
+ str(P2):float(val_2),str(Q2):float(val_2),
+ str(P3):float(val_3),str(Q3):float(val_3),
+ str(P4):float(val_4),str(Q4):float(val_4),
+ str(P5):float(val_5),str(Q5):float(val_5),
+ str(P6):float(val_6),str(Q6):float(val_6),
+ str(P7):float(val_7),str(Q7):float(val_7),
+ str(P8):float(val_8),str(Q8):float(val_8),
+ str(P9):float(val_9),str(Q9):float(val_9),
+ str(P10):float(val_10),str(Q10):float(val_10),}
+ else:
+        output_str_msg='Not an adequate description'
+ return{'Enter Correct Description':float(1.0)}
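+# Illustrative call (assumes the checkpoint and CSV files loaded above are available):
+#   predict_CTH('200 SI/SI/SI LPO ALUMINIUM LIDS')
+# returns a {label: confidence} dict covering the top-10 CTH candidates (code, duty rate and
+# description strings), in the format expected by the gr.Label output of the Gradio interface.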
+
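+# Streams a Mixtral-8x7B explanation of which predicted CTH code best fits the product
+# description, using the candidate list stored in output_str_msg by predict_CTH.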
+def llm_model_function(txt,history,chatbot=[], temperature=0.9, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,):
+ system_prompt=[]
+ chatbot=[]
+
+ global output_str_msg
+
+ print('output_str_msg',output_str_msg)
+
+    if output_str_msg != 'Not an adequate description':
+
+        prompt=f'First explain what the product is - {txt}. Which is the most appropriate 8-digit classification code out of the candidate classes given below? Explain the reasoning step by step. If none of the classifications applies precisely due to a lack of information, say that you need additional information and specify what that information is. Candidate classes: {output_str_msg}'
+
+ temperature = float(temperature)
+ if temperature < 1e-2:
+ temperature = 1e-2
+ top_p = float(top_p)
+
+ generate_kwargs = dict(
+ temperature=temperature,
+ max_new_tokens=max_new_tokens,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ do_sample=True,
+ seed=42,
+ )
+
+ formatted_prompt = format_prompt(f", {prompt}", history)
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+ output = ""
+ for response in stream:
+ output += response.token.text
+
+ chatbot.append((txt, output))
+ return "", chatbot
+ else:
+ # warning_msg = f"Unexpected response"
+ # raise gr.Error(warning_msg)
+        chatbot.append(('Not an adequate description', 'Not an adequate description'))
+ return "", chatbot
+
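+# Streams a short Mixtral-8x7B explanation of what the product described by `txt` is.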
+def product_explaination(txt,history,chatbot=[], temperature=0.9, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,):
+    print('Input Description is:', txt)
+ chatbot=[]
+ prompt=f'What is the product- {txt}?'
+ #print('prompt',prompt)
+ temperature = float(temperature)
+ if temperature < 1e-2:
+ temperature = 1e-2
+ top_p = float(top_p)
+
+ generate_kwargs = dict(
+ temperature=temperature,
+ max_new_tokens=max_new_tokens,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ do_sample=True,
+ seed=42,
+ )
+
+ formatted_prompt = format_prompt(f", {prompt}", history)
+
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+ output = ""
+
+ for response in stream:
+ output += response.token.text
+
+ chatbot.append((txt, output))
+ return "", chatbot
\ No newline at end of file
diff --git a/models/.github/ISSUE_TEMPLATE/00-official-bug-report-issue.md b/models/.github/ISSUE_TEMPLATE/00-official-bug-report-issue.md
new file mode 100644
index 0000000000000000000000000000000000000000..51e08c26db66114de0b604bf0cc5c461311a0b4f
--- /dev/null
+++ b/models/.github/ISSUE_TEMPLATE/00-official-bug-report-issue.md
@@ -0,0 +1,59 @@
+---
+name: "[Official Model] Bug Report"
+about: Use this template for reporting a bug for the “official” directory
+labels: type:bug,models:official
+
+---
+
+# Prerequisites
+
+Please answer the following questions for yourself before submitting an issue.
+
+- [ ] I am using the latest TensorFlow Model Garden release and TensorFlow 2.
+- [ ] I am reporting the issue to the correct repository. (Model Garden official or research directory)
+- [ ] I checked to make sure that this issue has not been filed already.
+
+## 1. The entire URL of the file you are using
+
+https://github.com/tensorflow/models/tree/master/official/...
+
+## 2. Describe the bug
+
+A clear and concise description of what the bug is.
+
+## 3. Steps to reproduce
+
+Steps to reproduce the behavior.
+
+## 4. Expected behavior
+
+A clear and concise description of what you expected to happen.
+
+## 5. Additional context
+
+Include any logs that would be helpful to diagnose the problem.
+
+## 6. System information
+
+- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
+- Mobile device name if the issue happens on a mobile device:
+- TensorFlow installed from (source or binary):
+- TensorFlow version (use command below):
+- Python version:
+- Bazel version (if compiling from source):
+- GCC/Compiler version (if compiling from source):
+- CUDA/cuDNN version:
+- GPU model and memory:
+
+
diff --git a/models/.github/ISSUE_TEMPLATE/10-official-documentation-issue.md b/models/.github/ISSUE_TEMPLATE/10-official-documentation-issue.md
new file mode 100644
index 0000000000000000000000000000000000000000..00d79a16916c327d2d8a729791db7d7d3d96b735
--- /dev/null
+++ b/models/.github/ISSUE_TEMPLATE/10-official-documentation-issue.md
@@ -0,0 +1,20 @@
+---
+name: "[Official Model] Documentation Issue"
+about: Use this template for reporting a documentation issue for the “official” directory
+labels: type:docs,models:official
+
+---
+
+# Prerequisites
+
+Please answer the following question for yourself before submitting an issue.
+
+- [ ] I checked to make sure that this issue has not been filed already.
+
+## 1. The entire URL of the documentation with the issue
+
+https://github.com/tensorflow/models/tree/master/official/...
+
+## 2. Describe the issue
+
+A clear and concise description of what needs to be changed.
diff --git a/models/.github/ISSUE_TEMPLATE/20-official-feature-request-issue.md b/models/.github/ISSUE_TEMPLATE/20-official-feature-request-issue.md
new file mode 100644
index 0000000000000000000000000000000000000000..02d8cab52218202707646345a4ab2570519660dd
--- /dev/null
+++ b/models/.github/ISSUE_TEMPLATE/20-official-feature-request-issue.md
@@ -0,0 +1,26 @@
+---
+name: "[Official Model] Feature request"
+about: Use this template for raising a feature request for the “official” directory
+labels: type:feature,models:official
+
+---
+
+# Prerequisites
+
+Please answer the following question for yourself before submitting an issue.
+
+- [ ] I checked to make sure that this feature has not been requested already.
+
+## 1. The entire URL of the file you are using
+
+https://github.com/tensorflow/models/tree/master/official/...
+
+## 2. Describe the feature you request
+
+A clear and concise description of what you want to happen.
+
+## 3. Additional context
+
+Add any other context about the feature request here.
+
+## 4. Are you willing to contribute it? (Yes or No)
diff --git a/models/.github/ISSUE_TEMPLATE/30-research-bug-report-issue.md b/models/.github/ISSUE_TEMPLATE/30-research-bug-report-issue.md
new file mode 100644
index 0000000000000000000000000000000000000000..4448ed9e40d6a089b84881635c2ee0f53524ae61
--- /dev/null
+++ b/models/.github/ISSUE_TEMPLATE/30-research-bug-report-issue.md
@@ -0,0 +1,58 @@
+---
+name: "[Research Model] Bug Report"
+about: Use this template for reporting a bug for the “research” directory
+labels: type:bug,models:research
+
+---
+# Prerequisites
+
+Please answer the following questions for yourself before submitting an issue.
+
+- [ ] I am using the latest TensorFlow Model Garden release and TensorFlow 2.
+- [ ] I am reporting the issue to the correct repository. (Model Garden official or research directory)
+- [ ] I checked to make sure that this issue has not already been filed.
+
+## 1. The entire URL of the file you are using
+
+https://github.com/tensorflow/models/tree/master/research/...
+
+## 2. Describe the bug
+
+A clear and concise description of what the bug is.
+
+## 3. Steps to reproduce
+
+Steps to reproduce the behavior.
+
+## 4. Expected behavior
+
+A clear and concise description of what you expected to happen.
+
+## 5. Additional context
+
+Include any logs that would be helpful to diagnose the problem.
+
+## 6. System information
+
+- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
+- Mobile device name if the issue happens on a mobile device:
+- TensorFlow installed from (source or binary):
+- TensorFlow version (use command below):
+- Python version:
+- Bazel version (if compiling from source):
+- GCC/Compiler version (if compiling from source):
+- CUDA/cuDNN version:
+- GPU model and memory:
+
+
diff --git a/models/.github/ISSUE_TEMPLATE/40-research-documentation-issue.md b/models/.github/ISSUE_TEMPLATE/40-research-documentation-issue.md
new file mode 100644
index 0000000000000000000000000000000000000000..26adfd83e1fbe27d045ecd8dfccef91bbd27fcf1
--- /dev/null
+++ b/models/.github/ISSUE_TEMPLATE/40-research-documentation-issue.md
@@ -0,0 +1,20 @@
+---
+name: "[Research Model] Documentation Issue"
+about: Use this template for reporting a documentation issue for the “research” directory
+labels: type:docs,models:research
+
+---
+
+# Prerequisites
+
+Please answer the following question for yourself before submitting an issue.
+
+- [ ] I checked to make sure that this issue has not been filed already.
+
+## 1. The entire URL of the documentation with the issue
+
+https://github.com/tensorflow/models/tree/master/research/...
+
+## 2. Describe the issue
+
+A clear and concise description of what needs to be changed.
diff --git a/models/.github/ISSUE_TEMPLATE/50-research-feature-request-issue.md b/models/.github/ISSUE_TEMPLATE/50-research-feature-request-issue.md
new file mode 100644
index 0000000000000000000000000000000000000000..412942a31be9cc4c2935dcd38ecb059a8a4ec18c
--- /dev/null
+++ b/models/.github/ISSUE_TEMPLATE/50-research-feature-request-issue.md
@@ -0,0 +1,26 @@
+---
+name: "[Research Model] Feature Request"
+about: Use this template for raising a feature request for the “research” directory
+labels: type:feature,models:research
+
+---
+
+# Prerequisites
+
+Please answer the following question for yourself before submitting an issue.
+
+- [ ] I checked to make sure that this feature has not been requested already.
+
+## 1. The entire URL of the file you are using
+
+https://github.com/tensorflow/models/tree/master/research/...
+
+## 2. Describe the feature you request
+
+A clear and concise description of what you want to happen.
+
+## 3. Additional context
+
+Add any other context about the feature request here.
+
+## 4. Are you willing to contribute it? (Yes or No)
diff --git a/models/.github/ISSUE_TEMPLATE/60-questions-help-issue.md b/models/.github/ISSUE_TEMPLATE/60-questions-help-issue.md
new file mode 100644
index 0000000000000000000000000000000000000000..bc85e0bb019fd2d5960b822c18358f906d5264b7
--- /dev/null
+++ b/models/.github/ISSUE_TEMPLATE/60-questions-help-issue.md
@@ -0,0 +1,14 @@
+---
+name: Questions and Help
+about: Use this template for Questions and Help.
+labels: type:support
+
+---
+
diff --git a/models/.github/ISSUE_TEMPLATE/config.yml b/models/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..3ba13e0cec6cbbfd462e9ebf529dd2093148cd69
--- /dev/null
+++ b/models/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1 @@
+blank_issues_enabled: false
diff --git a/models/.github/PULL_REQUEST_TEMPLATE.md b/models/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 0000000000000000000000000000000000000000..379b31c57c118a174d4e787e03099288957f9fe2
--- /dev/null
+++ b/models/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,41 @@
+# Description
+
+> :memo: Please include a summary of the change.
+>
+> * Please also include relevant motivation and context.
+> * List any dependencies that are required for this change.
+
+## Type of change
+
+For a new feature or function, please create an issue first to discuss it
+with us before submitting a pull request.
+
+Note: Please delete options that are not relevant.
+
+- [ ] Bug fix (non-breaking change which fixes an issue)
+- [ ] Documentation update
+- [ ] TensorFlow 2 migration
+- [ ] New feature (non-breaking change which adds functionality)
+- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
+- [ ] A new research paper code implementation
+- [ ] Other (Specify)
+
+## Tests
+
+> :memo: Please describe the tests that you ran to verify your changes.
+>
+> * Provide instructions so we can reproduce.
+> * Please also list any relevant details for your test configuration.
+
+**Test Configuration**:
+
+## Checklist
+
+- [ ] I have signed the [Contributor License Agreement](https://github.com/tensorflow/models/wiki/Contributor-License-Agreements).
+- [ ] I have read [guidelines for pull request](https://github.com/tensorflow/models/wiki/Submitting-a-pull-request).
+- [ ] My code follows the [coding guidelines](https://github.com/tensorflow/models/wiki/Coding-guidelines).
+- [ ] I have performed a self [code review](https://github.com/tensorflow/models/wiki/Code-review) of my own code.
+- [ ] I have commented my code, particularly in hard-to-understand areas.
+- [ ] I have made corresponding changes to the documentation.
+- [ ] My changes generate no new warnings.
+- [ ] I have added tests that prove my fix is effective or that my feature works.
diff --git a/models/.github/README_TEMPLATE.md b/models/.github/README_TEMPLATE.md
new file mode 100644
index 0000000000000000000000000000000000000000..45179d0aeba52caa8d84c102790b7e3fafc2c7fe
--- /dev/null
+++ b/models/.github/README_TEMPLATE.md
@@ -0,0 +1,122 @@
+> :memo: A README.md template for releasing a paper code implementation to a GitHub repository.
+>
+> * Template version: 1.0.2020.170
+> * Please modify sections depending on needs.
+
+# Model name, Paper title, or Project Name
+
+> :memo: Add a badge for the ArXiv identifier of your paper (arXiv:YYMM.NNNNN)
+
+[](https://arxiv.org/abs/...)
+
+This repository is the official or unofficial implementation of the following paper.
+
+* Paper title: [Paper Title](https://arxiv.org/abs/YYMM.NNNNN)
+
+## Description
+
+> :memo: Provide description of the model.
+>
+> * Provide brief information of the algorithms used.
+> * Provide links for demos, blog posts, etc.
+
+## History
+
+> :memo: Provide a changelog.
+
+## Authors or Maintainers
+
+> :memo: Provide maintainer information.
+
+* Full name ([@GitHub username](https://github.com/username))
+* Full name ([@GitHub username](https://github.com/username))
+
+## Table of Contents
+
+> :memo: Provide a table of contents to help readers navigate a lengthy README document.
+
+## Requirements
+
+[](https://github.com/tensorflow/tensorflow/releases/tag/v2.1.0)
+[](https://www.python.org/downloads/release/python-360/)
+
+> :memo: Provide details of the software required.
+>
+> * Add a `requirements.txt` file to the root directory for installing the necessary dependencies.
+> * Describe how to install requirements using pip.
+> * Alternatively, create INSTALL.md.
+
+To install requirements:
+
+```setup
+pip install -r requirements.txt
+```
+
+## Results
+
+> :memo: Provide a table with results. (e.g., accuracy, latency)
+>
+> * Provide links to the pre-trained models (checkpoint, SavedModel files).
+> * Publish TensorFlow SavedModel files on TensorFlow Hub (tfhub.dev) if possible.
+> * Add links to [TensorBoard.dev](https://tensorboard.dev/) for visualizing metrics.
+>
+> An example table for image classification results
+>
+> ### Image Classification
+>
+> | Model name | Download | Top 1 Accuracy | Top 5 Accuracy |
+> |------------|----------|----------------|----------------|
+> | Model name | [Checkpoint](https://drive.google.com/...), [SavedModel](https://tfhub.dev/...) | xx% | xx% |
+
+## Dataset
+
+> :memo: Provide information of the dataset used.
+
+## Training
+
+> :memo: Provide training information.
+>
+> * Provide details for preprocessing, hyperparameters, random seeds, and environment.
+> * Provide a command line example for training.
+
+Please run this command line for training.
+
+```shell
+python3 ...
+```
+
+## Evaluation
+
+> :memo: Provide an evaluation script with details of how to reproduce results.
+>
+> * Describe data preprocessing / postprocessing steps.
+> * Provide a command line example for evaluation.
+
+Please run this command line for evaluation.
+
+```shell
+python3 ...
+```
+
+## References
+
+> :memo: Provide links to references.
+
+## License
+
+[](https://opensource.org/licenses/Apache-2.0)
+
+> :memo: Place your license text in a file named LICENSE in the root of the repository.
+>
+> * Include information about your license.
+> * Reference: [Adding a license to a repository](https://help.github.com/en/github/building-a-strong-community/adding-a-license-to-a-repository)
+
+This project is licensed under the terms of the **Apache License 2.0**.
+
+## Citation
+
+> :memo: Make your repository citable.
+>
+> * Reference: [Making Your Code Citable](https://guides.github.com/activities/citable-code/)
+
+If you want to cite this repository in your research paper, please use the following information.
diff --git a/models/.gitignore b/models/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..cbc8846d64152b8a933f4bd2727877a94f98f92a
--- /dev/null
+++ b/models/.gitignore
@@ -0,0 +1,98 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*,cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# IPython Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# mypy
+.mypy_cache
+
+# celery beat schedule file
+celerybeat-schedule
+
+# dotenv
+.env
+
+# virtualenv
+venv/
+ENV/
+
+# Spyder project settings
+.spyderproject
+
+# Rope project settings
+.ropeproject
+
+# PyCharm
+.idea/
+
+# For mac
+.DS_Store
diff --git a/models/AUTHORS b/models/AUTHORS
new file mode 100644
index 0000000000000000000000000000000000000000..0fa85c98ffeb38c6d6d0ef2bddb790b75b90f3dc
--- /dev/null
+++ b/models/AUTHORS
@@ -0,0 +1,10 @@
+# This is the official list of authors for copyright purposes.
+# This file is distinct from the CONTRIBUTORS files.
+# See the latter for an explanation.
+
+# Names should be added to this file as:
+# Name or Organization
+# The email address is not required for organizations.
+
+Google Inc.
+David Dao
diff --git a/models/CODEOWNERS b/models/CODEOWNERS
new file mode 100644
index 0000000000000000000000000000000000000000..36b7ebd4e779dc110c53d49ed73cf43f519ca211
--- /dev/null
+++ b/models/CODEOWNERS
@@ -0,0 +1,61 @@
+* @tensorflow/tf-garden-team @tensorflow/tf-model-garden-team
+/official/ @rachellj218 @saberkun @jaeyounkim
+/official/nlp/ @saberkun @chenGitHuber @lehougoogle @rachellj218
+/official/vision/ @pengchongjin @xianzhidu @yeqingli @arashwan @saberkun @rachellj218
+/research/adv_imagenet_models/ @alexeykurakin
+/research/adversarial_crypto/ @dave-andersen
+/research/adversarial_logit_pairing/ @alexeykurakin
+/research/adversarial_text/ @rsepassi @a-dai
+/research/attention_ocr/ @xavigibert
+/research/audioset/ @plakal @dpwe
+/research/autoaugment/* @barretzoph
+/research/autoencoders/ @snurkabill
+/research/brain_coder/ @danabo
+/research/cognitive_mapping_and_planning/ @s-gupta
+/research/compression/ @nmjohn
+/research/cvt_text/ @clarkkev @lmthang
+/research/deep_contextual_bandits/ @rikel
+/research/deep_speech/ @yhliang2018
+/research/deeplab/ @aquariusjay @yknzhu @gpapan
+/research/delf/ @andrefaraujo
+/research/domain_adaptation/ @bousmalis @dmrd
+/research/efficient-hrl/ @ofirnachum
+/research/feelvos/ @pvoigtlaender @yuningchai @aquariusjay
+/research/fivo/ @dieterichlawson
+/research/global_objectives/ @mackeya-google
+/research/im2txt/ @cshallue
+/research/inception/ @shlens @vincentvanhoucke
+/research/keypointnet/ @mnorouzi
+/research/learned_optimizer/ @olganw @nirum
+/research/learning_to_remember_rare_events/ @lukaszkaiser @ofirnachum
+/research/learning_unsupervised_learning/ @lukemetz @nirum
+/research/lexnet_nc/ @vered1986 @waterson
+/research/lfads/ @jazcollins @sussillo
+/research/lm_1b/ @oriolvinyals @panyx0718
+/research/lm_commonsense/ @thtrieu
+/research/lstm_object_detection/ @yinxiaoli @yongzhe2160
+/research/marco/ @vincentvanhoucke
+/research/maskgan/ @liamb315 @a-dai
+/research/namignizer/ @knathanieltucker
+/research/neural_gpu/ @lukaszkaiser
+/research/neural_programmer/ @arvind2505
+/research/next_frame_prediction/ @panyx0718
+/research/object_detection/ @jch1 @tombstone @pkulzc
+/research/pcl_rl/ @ofirnachum
+/research/ptn/ @xcyan @arkanath @hellojas @honglaklee
+/research/qa_kg/ @yuyuz
+/research/real_nvp/ @laurent-dinh
+/research/rebar/ @gjtucker
+/research/sentiment_analysis/ @sculd
+/research/seq2species/ @apbusia @depristo
+/research/skip_thoughts/ @cshallue
+/research/slim/ @sguada @marksandler2
+/research/steve/ @buckman-google
+/research/street/ @theraysmith
+/research/struct2depth/ @aneliaangelova
+/research/swivel/ @waterson
+/research/tcn/ @coreylynch @sermanet
+/research/textsum/ @panyx0718 @peterjliu
+/research/transformer/ @daviddao
+/research/vid2depth/ @rezama
+/research/video_prediction/ @cbfinn
diff --git a/models/CONTRIBUTING.md b/models/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..f909461ae7b9c75264e0915ecb37228314933e4a
--- /dev/null
+++ b/models/CONTRIBUTING.md
@@ -0,0 +1,10 @@
+# How to contribute
+
+
+
+We encourage you to contribute to the TensorFlow Model Garden.
+
+Please read our [guidelines](../../wiki/How-to-contribute) for details.
+
+**NOTE**: Only [code owners](./CODEOWNERS) are allowed to merge a pull request.
+Please contact the code owners of each model to merge your pull request.
diff --git a/models/ISSUES.md b/models/ISSUES.md
new file mode 100644
index 0000000000000000000000000000000000000000..b23d6daa1654188d640beb67e6614bd0743f919f
--- /dev/null
+++ b/models/ISSUES.md
@@ -0,0 +1,24 @@
+# If you open a GitHub issue, here is our policy.
+
+* It must be a **bug**, a **feature request**, or a significant problem
+with **documentation**.
+ * Please send a pull request instead for small documentation fixes.
+* The required form must be filled out.
+* The issue should be related to the repository it is created in.
+
+General help and support should be sought on [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow-model-garden) or other non-GitHub channels.
+
+[](https://stackoverflow.com/questions/tagged/tensorflow-model-garden)
+
+TensorFlow developers respond to issues.
+We want to focus on work that benefits the whole community such as fixing bugs
+and adding new features.
+It helps us to address bugs and feature requests in a timely manner.
+
+---
+
+Please understand that research models in the [research directory](https://github.com/tensorflow/models/tree/master/research)
+included in this repository are experimental and research-style code.
+They are not officially supported by the TensorFlow team.
+
+
diff --git a/models/LICENSE b/models/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..43fcf7bf1f1f9f824a1debf05d6ced45bf5810aa
--- /dev/null
+++ b/models/LICENSE
@@ -0,0 +1,203 @@
+Copyright 2016 The TensorFlow Authors. All rights reserved.
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2016, The Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/models/README.md b/models/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5b52e4a5cf41f949c2cf85744ea297ec3c324004
--- /dev/null
+++ b/models/README.md
@@ -0,0 +1,39 @@
+
+
+# Welcome to the Model Garden for TensorFlow
+
+The TensorFlow Model Garden is a repository with a number of different implementations of state-of-the-art (SOTA) models and modeling solutions for TensorFlow users. We aim to demonstrate the best practices for modeling so that TensorFlow users
+can take full advantage of TensorFlow for their research and product development.
+
+| Directory | Description |
+|-----------|-------------|
+| [official](official) | • A collection of example implementations for SOTA models using the latest TensorFlow 2's high-level APIs • Officially maintained, supported, and kept up to date with the latest TensorFlow 2 APIs by TensorFlow • Reasonably optimized for fast performance while still being easy to read |
+| [research](research) | • A collection of research model implementations in TensorFlow 1 or 2 by researchers • Maintained and supported by researchers |
+| [community](community) | • A curated list of the GitHub repositories with machine learning models and implementations powered by TensorFlow 2 |
+
+## [Announcements](https://github.com/tensorflow/models/wiki/Announcements)
+
+| Date | News |
+|------|------|
+| June 17, 2020 | [Context R-CNN: Long Term Temporal Context for Per-Camera Object Detection](https://github.com/tensorflow/models/tree/master/research/object_detection#june-17th-2020) released
+| May 21, 2020 | [Unifying Deep Local and Global Features for Image Search (DELG)](https://github.com/tensorflow/models/tree/master/research/delf#delg) code released
+| May 19, 2020 | [MobileDets: Searching for Object Detection Architectures for Mobile Accelerators](https://github.com/tensorflow/models/tree/master/research/object_detection#may-19th-2020) released
+| May 7, 2020 | [MnasFPN with MobileNet-V2 backbone](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#mobile-models) released for object detection
+| May 1, 2020 | [DELF: DEep Local Features](https://github.com/tensorflow/models/tree/master/research/delf) updated to support TensorFlow 2.1
+| March 31, 2020 | [Introducing the Model Garden for TensorFlow 2](https://blog.tensorflow.org/2020/03/introducing-model-garden-for-tensorflow-2.html) ([Tweet](https://twitter.com/TensorFlow/status/1245029834633297921)) |
+
+## [Milestones](https://github.com/tensorflow/models/milestones)
+
+| Date | Milestone |
+|------|-----------|
+| July 7, 2020 | [](https://github.com/tensorflow/models/milestone/1) |
+
+## Contributions
+
+[](https://github.com/tensorflow/models/labels/help%20wanted%3Apaper%20implementation)
+
+If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
+
+## License
+
+[Apache License 2.0](LICENSE)
diff --git a/models/official/LICENSE b/models/official/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..d3da228420e973edaf4123d5eeb42210f4450b0c
--- /dev/null
+++ b/models/official/LICENSE
@@ -0,0 +1,203 @@
+Copyright 2015 The TensorFlow Authors. All rights reserved.
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2015, The TensorFlow Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/models/official/README-TPU.md b/models/official/README-TPU.md
new file mode 100644
index 0000000000000000000000000000000000000000..8a54f95314abc2bae40d11acdf5439939acf7583
--- /dev/null
+++ b/models/official/README-TPU.md
@@ -0,0 +1,25 @@
+# Officially Supported TensorFlow 2.1+ Models on Cloud TPU
+
+## Natural Language Processing
+
+* [bert](nlp/bert): A powerful pre-trained language representation model:
+ BERT, which stands for Bidirectional Encoder Representations from
+ Transformers.
+  [BERT FineTuning with Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/bert-2.x) provides step-by-step instructions on Cloud TPU training. See the [Bert MNLI Tensorboard.dev metrics](https://tensorboard.dev/experiment/LijZ1IrERxKALQfr76gndA) for the MNLI fine-tuning task.
+* [transformer](nlp/transformer): A transformer model to translate the WMT
+ English to German dataset.
+  See [Training transformer on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/transformer-2.x) for step-by-step instructions on Cloud TPU training.
+
+## Computer Vision
+
+* [efficientnet](vision/image_classification): A family of convolutional
+ neural networks that scale by balancing network depth, width, and
+ resolution and can be used to classify ImageNet's dataset of 1000 classes.
+ See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/KnaWjrq5TXGfv0NW5m7rpg/#scalars).
+* [mnist](vision/image_classification): A basic model to classify digits
+ from the MNIST dataset. See [Running MNIST on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/mnist-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/mIah5lppTASvrHqWrdr6NA).
+* [mask-rcnn](vision/detection): An object detection and instance segmentation model. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/LH7k0fMsRwqUAcE09o9kPA).
+* [resnet](vision/image_classification): A deep residual network that can
+ be used to classify ImageNet's dataset of 1000 classes.
+ See [Training ResNet on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/resnet-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/CxlDK8YMRrSpYEGtBRpOhg).
+* [retinanet](vision/detection): A fast and powerful object detector. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/b8NRnWU3TqG6Rw0UxueU6Q).
diff --git a/models/official/README.md b/models/official/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2b3f2dd768d0b7cf8238136d003aa5cb89070cc3
--- /dev/null
+++ b/models/official/README.md
@@ -0,0 +1,142 @@
+
+
+# TensorFlow Official Models
+
+The TensorFlow official models are a collection of models
+that use TensorFlow’s high-level APIs.
+They are intended to be well-maintained, tested, and kept up to date
+with the latest TensorFlow API.
+
+They should also be reasonably optimized for fast performance while still
+being easy to read.
+These models are used as end-to-end tests, ensuring that the models run
+with the same or improved speed and performance with each new TensorFlow build.
+
+## More models to come!
+
+The team is actively developing new models.
+In the near future, we will add:
+
+* State-of-the-art language understanding models:
+  More members of the Transformer family
+* State-of-the-art image classification models:
+  EfficientNet, MnasNet, and variants
+* A set of excellent object detection models.
+
+## Table of Contents
+
+- [Models and Implementations](#models-and-implementations)
+ * [Computer Vision](#computer-vision)
+ + [Image Classification](#image-classification)
+ + [Object Detection and Segmentation](#object-detection-and-segmentation)
+ * [Natural Language Processing](#natural-language-processing)
+ * [Recommendation](#recommendation)
+- [How to get started with the official models](#how-to-get-started-with-the-official-models)
+
+## Models and Implementations
+
+### Computer Vision
+
+#### Image Classification
+
+| Model | Reference (Paper) |
+|-------|-------------------|
+| [MNIST](vision/image_classification) | A basic model to classify digits from the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) |
+| [ResNet](vision/image_classification) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
+| [EfficientNet](vision/image_classification) | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946) |
+
+#### Object Detection and Segmentation
+
+| Model | Reference (Paper) |
+|-------|-------------------|
+| [RetinaNet](vision/detection) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
+| [Mask R-CNN](vision/detection) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
+| [ShapeMask](vision/detection) | [ShapeMask: Learning to Segment Novel Objects by Refining Shape Priors](https://arxiv.org/abs/1904.03239) |
+
+### Natural Language Processing
+
+| Model | Reference (Paper) |
+|-------|-------------------|
+| [ALBERT (A Lite BERT)](nlp/albert) | [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942) |
+| [BERT (Bidirectional Encoder Representations from Transformers)](nlp/bert) | [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) |
+| [NHNet (News Headline generation model)](nlp/nhnet) | [Generating Representative Headlines for News Stories](https://arxiv.org/abs/2001.09386) |
+| [Transformer](nlp/transformer) | [Attention Is All You Need](https://arxiv.org/abs/1706.03762) |
+| [XLNet](nlp/xlnet) | [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) |
+
+### Recommendation
+
+| Model | Reference (Paper) |
+|-------|-------------------|
+| [NCF](recommendation) | [Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031) |
+
+## How to get started with the official models
+
+* The models in the master branch are developed using TensorFlow 2,
+and they target the TensorFlow [nightly binaries](https://github.com/tensorflow/tensorflow#installation)
+built from the
+[master branch of TensorFlow](https://github.com/tensorflow/tensorflow/tree/master).
+* The stable versions targeting releases of TensorFlow are available
+as tagged branches or [downloadable releases](https://github.com/tensorflow/models/releases).
+* Model repository version numbers match the target TensorFlow release,
+such that
+[release v2.2.0](https://github.com/tensorflow/models/releases/tag/v2.2.0)
+is compatible with
+[TensorFlow v2.2.0](https://github.com/tensorflow/tensorflow/releases/tag/v2.2.0) (see the example below).
+
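+For example, to work against the models that target a specific TensorFlow
+release rather than the nightly builds, you can check out the matching tag
+after cloning (illustrative; any released tag works the same way):
+
+```shell
+git clone https://github.com/tensorflow/models.git
+cd models
+git checkout v2.2.0
+```
+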
+Please follow the steps below before running models in this repository.
+
+### Requirements
+
+* The latest TensorFlow Model Garden release and TensorFlow 2
+ * If you are on a version of TensorFlow earlier than 2.2, please
+upgrade your TensorFlow to [the latest TensorFlow 2](https://www.tensorflow.org/install/).
+
+```shell
+pip3 install tf-nightly
+```
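+
+To confirm which TensorFlow build is active in your environment, a quick
+check such as the following can help (illustrative):
+
+```shell
+python3 -c "import tensorflow as tf; print(tf.__version__)"
+```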
+
+### Installation
+
+#### Method 1: Install the TensorFlow Model Garden pip package
+
+**tf-models-nightly** is the nightly Model Garden package,
+built automatically every day. Installing it with pip pulls in all
+models and their dependencies.
+
+```shell
+pip install tf-models-nightly
+```
+
+Please check out our [example](colab/fine_tuning_bert.ipynb)
+to learn how to use the pip package.
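+
+Once the package is installed, the Model Garden code is importable under the
+`official` namespace. A minimal sketch (module paths as used elsewhere in this
+repository; the config values are illustrative):
+
+```python
+from official.nlp.bert import configs
+
+# Build a BERT-base style configuration object.
+bert_config = configs.BertConfig(vocab_size=30522, hidden_size=768)
+```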
+
+#### Method 2: Clone the source
+
+1. Clone the GitHub repository:
+
+```shell
+git clone https://github.com/tensorflow/models.git
+```
+
+2. Add the top-level ***/models*** folder to the Python path.
+
+```shell
+export PYTHONPATH=$PYTHONPATH:/path/to/models
+```
+
+If you are using a Colab notebook, please set the Python path with `os.environ`.
+
+```python
+import os
+os.environ['PYTHONPATH'] += ":/path/to/models"
+```
+
+3. Install other dependencies
+
+```shell
+pip3 install --user -r official/requirements.txt
+```
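+
+To verify that the `models` folder is visible on the Python path, a quick
+import check can help (illustrative):
+
+```shell
+python3 -c "import official; print(official.__file__)"
+```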
+
+## Contributions
+
+If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
diff --git a/models/official/__init__.py b/models/official/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/__pycache__/__init__.cpython-310.pyc b/models/official/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46eb1d8fb6c544f21ff52681984538792572cd90
Binary files /dev/null and b/models/official/__pycache__/__init__.cpython-310.pyc differ
diff --git a/models/official/__pycache__/__init__.cpython-38.pyc b/models/official/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..556257a7074799181f4d95d751286d1b27bd4e77
Binary files /dev/null and b/models/official/__pycache__/__init__.cpython-38.pyc differ
diff --git a/models/official/__pycache__/__init__.cpython-39.pyc b/models/official/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..745a3d49d03eae81040f19e1b522ab1bf63ced87
Binary files /dev/null and b/models/official/__pycache__/__init__.cpython-39.pyc differ
diff --git a/models/official/benchmark/__init__.py b/models/official/benchmark/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/benchmark/benchmark_wrappers.py b/models/official/benchmark/benchmark_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d38b690c7865e0ab560e59422a2454e44be052d
--- /dev/null
+++ b/models/official/benchmark/benchmark_wrappers.py
@@ -0,0 +1,97 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils to annotate and trace benchmarks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import flags
+from absl import logging
+from absl.testing import flagsaver
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_multi_string(
+ 'benchmark_method_flags', None,
+ 'Optional list of runtime flags of the form key=value. Specify '
+ 'multiple times to specify different flags. These will override the FLAGS '
+ 'object directly after hardcoded settings in individual benchmark methods '
+    'before they call _run_and_report_benchmark. For example, if we set '
+ '--benchmark_method_flags=train_steps=10 and a benchmark method hardcodes '
+ 'FLAGS.train_steps=10000 and later calls _run_and_report_benchmark, '
+ 'it\'ll only run for 10 steps. This is useful for '
+ 'debugging/profiling workflows.')
+
+
+def enable_runtime_flags(decorated_func):
+ """Sets attributes from --benchmark_method_flags for method execution.
+
+ @enable_runtime_flags decorator temporarily adds flags passed in via
+ --benchmark_method_flags and runs the decorated function in that context.
+
+ A user can set --benchmark_method_flags=train_steps=5 to run the benchmark
+ method in the snippet below with FLAGS.train_steps=5 for debugging (without
+ modifying the benchmark code).
+
+ class ModelBenchmark():
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self):
+ # run benchmark ...
+ # report benchmark results ...
+
+ def benchmark_method(self):
+ FLAGS.train_steps = 1000
+ ...
+ self._run_and_report_benchmark()
+
+ Args:
+ decorated_func: The method that runs the benchmark after previous setup
+ execution that set some flags.
+
+ Returns:
+ new_func: The same method which executes in a temporary context where flag
+ overrides from --benchmark_method_flags are active.
+ """
+
+ def runner(*args, **kwargs):
+ """Creates a temporary context to activate --benchmark_method_flags."""
+ if FLAGS.benchmark_method_flags:
+ saved_flag_values = flagsaver.save_flag_values()
+ for key_value in FLAGS.benchmark_method_flags:
+ key, value = key_value.split('=', 1)
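+        # Interpret the value as an int when it is integral, as a float when
+        # it is numeric but not integral, and fall back to the raw string.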
+ try:
+ numeric_float = float(value)
+ numeric_int = int(numeric_float)
+ if abs(numeric_int) == abs(numeric_float):
+ flag_value = numeric_int
+ else:
+ flag_value = numeric_float
+ except ValueError:
+ flag_value = value
+ logging.info('Setting --%s=%s', key, flag_value)
+ setattr(FLAGS, key, flag_value)
+ else:
+ saved_flag_values = None
+ try:
+ result = decorated_func(*args, **kwargs)
+ return result
+ finally:
+ if saved_flag_values:
+ flagsaver.restore_flag_values(saved_flag_values)
+
+ return runner
diff --git a/models/official/benchmark/bert_benchmark.py b/models/official/benchmark/bert_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..35daac672ebe87434e99db8c7c3bbcc67a8061e4
--- /dev/null
+++ b/models/official/benchmark/bert_benchmark.py
@@ -0,0 +1,365 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Executes BERT benchmarks and accuracy tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import json
+import math
+import os
+import time
+
+# pylint: disable=g-bad-import-order
+from absl import flags
+from absl.testing import flagsaver
+import tensorflow as tf
+# pylint: enable=g-bad-import-order
+
+from official.benchmark import bert_benchmark_utils as benchmark_utils
+from official.benchmark import owner_utils
+from official.nlp.bert import configs
+from official.nlp.bert import run_classifier
+from official.utils.misc import distribution_utils
+from official.benchmark import benchmark_wrappers
+
+# pylint: disable=line-too-long
+PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
+CLASSIFIER_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_train.tf_record'
+CLASSIFIER_EVAL_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_eval.tf_record'
+CLASSIFIER_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_meta_data'
+MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_config.json'
+# pylint: enable=line-too-long
+
+TMP_DIR = os.getenv('TMPDIR')
+FLAGS = flags.FLAGS
+
+
+class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
+ """Base class to hold methods common to test classes in the module."""
+
+ def __init__(self, output_dir=None, tpu=None):
+ super(BertClassifyBenchmarkBase, self).__init__(output_dir, tpu=tpu)
+ self.num_epochs = None
+ self.num_steps_per_epoch = None
+ FLAGS.steps_per_loop = 1
+
+ @flagsaver.flagsaver
+ def _run_bert_classifier(self, callbacks=None, use_ds=True):
+ """Starts BERT classification task."""
+ with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
+ input_meta_data = json.loads(reader.read().decode('utf-8'))
+
+ bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+ epochs = self.num_epochs if self.num_epochs else FLAGS.num_train_epochs
+ if self.num_steps_per_epoch:
+ steps_per_epoch = self.num_steps_per_epoch
+ else:
+ train_data_size = input_meta_data['train_data_size']
+ steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
+ warmup_steps = int(epochs * steps_per_epoch * 0.1)
+ eval_steps = int(
+ math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
+ if self.tpu:
+ strategy = distribution_utils.get_distribution_strategy(
+ distribution_strategy='tpu', tpu_address=self.tpu)
+ else:
+ strategy = distribution_utils.get_distribution_strategy(
+ distribution_strategy='mirrored' if use_ds else 'off',
+ num_gpus=self.num_gpus)
+
+ max_seq_length = input_meta_data['max_seq_length']
+ train_input_fn = run_classifier.get_dataset_fn(
+ FLAGS.train_data_path,
+ max_seq_length,
+ FLAGS.train_batch_size,
+ is_training=True)
+ eval_input_fn = run_classifier.get_dataset_fn(
+ FLAGS.eval_data_path,
+ max_seq_length,
+ FLAGS.eval_batch_size,
+ is_training=False)
+ _, summary = run_classifier.run_bert_classifier(
+ strategy,
+ bert_config,
+ input_meta_data,
+ FLAGS.model_dir,
+ epochs,
+ steps_per_epoch,
+ FLAGS.steps_per_loop,
+ eval_steps,
+ warmup_steps,
+ FLAGS.learning_rate,
+ FLAGS.init_checkpoint,
+ train_input_fn,
+ eval_input_fn,
+ training_callbacks=False,
+ custom_callbacks=callbacks)
+ return summary
+
+
+class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
+ """Short benchmark performance tests for BERT model.
+
+  Tests BERT classification performance in different GPU and TPU configurations.
+  The naming convention of the test cases below follows
+ `benchmark_(number of gpus)_gpu_(dataset type)` for GPUs and
+ `benchmark_(topology)_tpu_(dataset type)` for TPUs.
+ """
+
+ def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
+ super(BertClassifyBenchmarkReal, self).__init__(
+ output_dir=output_dir, tpu=tpu)
+
+ self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
+ self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
+ self.bert_config_file = MODEL_CONFIG_FILE_PATH
+ self.input_meta_data_path = CLASSIFIER_INPUT_META_DATA_PATH
+
+ # Since we only care about performance metrics, we limit
+ # the number of training steps and epochs to prevent unnecessarily
+ # long tests.
+ self.num_steps_per_epoch = 100
+ self.num_epochs = 1
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self,
+ training_summary_path,
+ min_accuracy=0,
+ max_accuracy=1,
+ use_ds=True):
+ """Starts BERT performance benchmark test."""
+ start_time_sec = time.time()
+ summary = self._run_bert_classifier(
+ callbacks=[self.timer_callback], use_ds=use_ds)
+ wall_time_sec = time.time() - start_time_sec
+
+ # Since we do not load from any pretrained checkpoints, we ignore all
+ # accuracy metrics.
+ summary.pop('eval_metrics', None)
+ summary['start_time_sec'] = start_time_sec
+
+ super(BertClassifyBenchmarkReal, self)._report_benchmark(
+ stats=summary,
+ wall_time_sec=wall_time_sec,
+ min_accuracy=min_accuracy,
+ max_accuracy=max_accuracy)
+
+ def benchmark_1_gpu_mrpc(self):
+ """Test BERT model performance with 1 GPU."""
+
+ self._setup()
+ self.num_gpus = 1
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc')
+ FLAGS.train_data_path = self.train_data_path
+ FLAGS.eval_data_path = self.eval_data_path
+ FLAGS.input_meta_data_path = self.input_meta_data_path
+ FLAGS.bert_config_file = self.bert_config_file
+ FLAGS.train_batch_size = 4
+ FLAGS.eval_batch_size = 4
+
+ summary_path = os.path.join(FLAGS.model_dir,
+ 'summaries/training_summary.txt')
+ self._run_and_report_benchmark(summary_path)
+
+ def benchmark_1_gpu_mrpc_xla(self):
+    """Test BERT model performance with 1 GPU with XLA."""
+
+ self._setup()
+ self.num_gpus = 1
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc_xla')
+ FLAGS.train_data_path = self.train_data_path
+ FLAGS.eval_data_path = self.eval_data_path
+ FLAGS.input_meta_data_path = self.input_meta_data_path
+ FLAGS.bert_config_file = self.bert_config_file
+ FLAGS.train_batch_size = 4
+ FLAGS.eval_batch_size = 4
+ FLAGS.enable_xla = True
+
+ summary_path = os.path.join(FLAGS.model_dir,
+ 'summaries/training_summary.txt')
+ self._run_and_report_benchmark(summary_path)
+
+ def benchmark_1_gpu_mrpc_no_dist_strat(self):
+ """Test BERT model performance with 1 GPU, no distribution strategy."""
+
+ self._setup()
+ self.num_gpus = 1
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc_no_dist_strat')
+ FLAGS.train_data_path = self.train_data_path
+ FLAGS.eval_data_path = self.eval_data_path
+ FLAGS.input_meta_data_path = self.input_meta_data_path
+ FLAGS.bert_config_file = self.bert_config_file
+ FLAGS.train_batch_size = 4
+ FLAGS.eval_batch_size = 4
+
+ summary_path = os.path.join(FLAGS.model_dir,
+ 'summaries/training_summary.txt')
+ self._run_and_report_benchmark(summary_path, use_ds=False)
+
+ @owner_utils.Owner('tf-model-garden')
+ def benchmark_8_gpu_mrpc(self):
+ """Test BERT model performance with 8 GPUs."""
+
+ self._setup()
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc')
+ FLAGS.train_data_path = self.train_data_path
+ FLAGS.eval_data_path = self.eval_data_path
+ FLAGS.input_meta_data_path = self.input_meta_data_path
+ FLAGS.bert_config_file = self.bert_config_file
+
+ summary_path = os.path.join(FLAGS.model_dir,
+ 'summaries/training_summary.txt')
+ self._run_and_report_benchmark(summary_path)
+
+ def benchmark_1_gpu_amp_mrpc_no_dist_strat(self):
+ """Performance for 1 GPU no DS with automatic mixed precision."""
+ self._setup()
+ self.num_gpus = 1
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_1_gpu_amp_mrpc_no_dist_strat')
+ FLAGS.train_data_path = self.train_data_path
+ FLAGS.eval_data_path = self.eval_data_path
+ FLAGS.input_meta_data_path = self.input_meta_data_path
+ FLAGS.bert_config_file = self.bert_config_file
+ FLAGS.train_batch_size = 4
+ FLAGS.eval_batch_size = 4
+ FLAGS.dtype = 'fp16'
+ FLAGS.fp16_implementation = 'graph_rewrite'
+
+ summary_path = os.path.join(FLAGS.model_dir,
+ 'summaries/training_summary.txt')
+ self._run_and_report_benchmark(summary_path, use_ds=False)
+
+ def benchmark_8_gpu_amp_mrpc(self):
+ """Test BERT model performance with 8 GPUs with automatic mixed precision."""
+
+ self._setup()
+ self.num_gpus = 8
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp_mrpc')
+ FLAGS.train_data_path = self.train_data_path
+ FLAGS.eval_data_path = self.eval_data_path
+ FLAGS.input_meta_data_path = self.input_meta_data_path
+ FLAGS.bert_config_file = self.bert_config_file
+ FLAGS.train_batch_size = 32
+ FLAGS.eval_batch_size = 32
+ FLAGS.dtype = 'fp16'
+ FLAGS.fp16_implementation = 'graph_rewrite'
+
+ summary_path = os.path.join(FLAGS.model_dir,
+ 'summaries/training_summary.txt')
+ self._run_and_report_benchmark(summary_path, use_ds=False)
+
+ @owner_utils.Owner('tf-model-garden')
+ def benchmark_2x2_tpu_mrpc(self):
+ """Test BERT model performance with 2x2 TPU."""
+
+ self._setup()
+ FLAGS.steps_per_loop = 50
+ FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_mrpc')
+ FLAGS.train_data_path = self.train_data_path
+ FLAGS.eval_data_path = self.eval_data_path
+ FLAGS.input_meta_data_path = self.input_meta_data_path
+ FLAGS.bert_config_file = self.bert_config_file
+ FLAGS.train_batch_size = 32
+ FLAGS.eval_batch_size = 32
+
+ summary_path = os.path.join(FLAGS.model_dir,
+ 'summaries/training_summary.txt')
+ self._run_and_report_benchmark(summary_path, use_ds=False)
+
+
+class BertClassifyAccuracy(BertClassifyBenchmarkBase):
+ """Short accuracy test for BERT model.
+
+ Tests BERT classification task model accuracy. The naming
+  convention of the test cases below follows the
+ `benchmark_(number of gpus)_gpu_(dataset type)` format.
+ """
+
+ def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
+ self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
+ self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
+ self.bert_config_file = MODEL_CONFIG_FILE_PATH
+ self.input_meta_data_path = CLASSIFIER_INPUT_META_DATA_PATH
+ self.pretrained_checkpoint_path = PRETRAINED_CHECKPOINT_PATH
+
+ super(BertClassifyAccuracy, self).__init__(output_dir=output_dir, tpu=tpu)
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self,
+ training_summary_path,
+ min_accuracy=0.84,
+ max_accuracy=0.88):
+ """Starts BERT accuracy benchmark test."""
+
+ start_time_sec = time.time()
+ summary = self._run_bert_classifier(callbacks=[self.timer_callback])
+ wall_time_sec = time.time() - start_time_sec
+
+ super(BertClassifyAccuracy, self)._report_benchmark(
+ stats=summary,
+ wall_time_sec=wall_time_sec,
+ min_accuracy=min_accuracy,
+ max_accuracy=max_accuracy)
+
+ def _setup(self):
+ super(BertClassifyAccuracy, self)._setup()
+ FLAGS.train_data_path = self.train_data_path
+ FLAGS.eval_data_path = self.eval_data_path
+ FLAGS.input_meta_data_path = self.input_meta_data_path
+ FLAGS.bert_config_file = self.bert_config_file
+ FLAGS.init_checkpoint = self.pretrained_checkpoint_path
+
+ @owner_utils.Owner('tf-model-garden')
+ def benchmark_8_gpu_mrpc(self):
+ """Run BERT model accuracy test with 8 GPUs.
+
+    Due to the comparatively small cardinality of the MRPC dataset, the
+    training accuracy metric has high variance between runs. Accordingly, we
+    set a wide range of allowed accuracy (84% to 88%).
+ """
+ self._setup()
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc')
+
+ summary_path = os.path.join(FLAGS.model_dir,
+ 'summaries/training_summary.txt')
+ self._run_and_report_benchmark(summary_path)
+
+ def benchmark_8_gpu_mrpc_xla(self):
+ """Run BERT model accuracy test with 8 GPUs with XLA."""
+ self._setup()
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc_xla')
+ FLAGS.enable_xla = True
+ summary_path = os.path.join(FLAGS.model_dir,
+ 'summaries/training_summary.txt')
+ self._run_and_report_benchmark(summary_path)
+
+ @owner_utils.Owner('tf-model-garden')
+ def benchmark_2x2_tpu_mrpc(self):
+ """Run BERT model accuracy test on 2x2 TPU."""
+ self._setup()
+ FLAGS.steps_per_loop = 50
+ FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_mrpc')
+
+ summary_path = os.path.join(FLAGS.model_dir,
+ 'summaries/training_summary.txt')
+ self._run_and_report_benchmark(summary_path)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/benchmark/bert_benchmark_utils.py b/models/official/benchmark/bert_benchmark_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..705a243315616080fe15c70925ed74a905818cdc
--- /dev/null
+++ b/models/official/benchmark/bert_benchmark_utils.py
@@ -0,0 +1,127 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions or classes shared between BERT benchmarks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+# pylint: disable=g-bad-import-order
+import numpy as np
+from absl import flags
+import tensorflow as tf
+# pylint: enable=g-bad-import-order
+
+from official.utils.flags import core as flags_core
+from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
+
+FLAGS = flags.FLAGS
+
+
+class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
+ """Callback that records time it takes to run each batch."""
+
+ def __init__(self, num_batches_to_skip=10):
+ super(BenchmarkTimerCallback, self).__init__()
+ self.batch_start_times = {}
+ self.batch_stop_times = {}
+
+ def on_batch_begin(self, batch, logs=None):
+ self.batch_start_times[batch] = time.time()
+
+ def on_batch_end(self, batch, logs=None):
+ # If there are multiple steps_per_loop, the end batch index will not be the
+ # same as the starting index. Use the last starting index instead.
+ if batch not in self.batch_start_times:
+ batch = max(self.batch_start_times.keys())
+
+ self.batch_stop_times[batch] = time.time()
+
+ def get_examples_per_sec(self, batch_size, num_batches_to_skip=1):
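+    # Throughput is batch_size divided by the mean duration of the measured
+    # batches, skipping the first `num_batches_to_skip` warm-up batches.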
+ batch_durations = []
+ for batch in self.batch_start_times:
+ if batch in self.batch_stop_times and batch >= num_batches_to_skip:
+ batch_durations.append(self.batch_stop_times[batch] -
+ self.batch_start_times[batch])
+ return batch_size / np.mean(batch_durations)
+
+ def get_startup_time(self, program_start_time):
+ return self.batch_start_times[0] - program_start_time
+
+
+class BertBenchmarkBase(PerfZeroBenchmark):
+ """Base class to hold methods common to test classes."""
+ local_flags = None
+
+ def __init__(self, output_dir=None, tpu=None, **kwargs):
+ super(BertBenchmarkBase, self).__init__(
+ output_dir=output_dir, tpu=tpu, **kwargs)
+ self.num_gpus = 8
+ self.timer_callback = None
+
+ def _setup(self):
+ """Sets up and resets flags before each test."""
+ super(BertBenchmarkBase, self)._setup()
+ self.timer_callback = BenchmarkTimerCallback()
+
+ def _report_benchmark(self, stats, wall_time_sec, min_accuracy, max_accuracy):
+ """Report benchmark results by writing to local protobuf file.
+
+ Args:
+ stats: dict returned from BERT models with known entries.
+      wall_time_sec: the duration of the benchmark execution in seconds.
+ min_accuracy: Minimum classification accuracy constraint to verify
+ correctness of the model.
+ max_accuracy: Maximum classification accuracy constraint to verify
+ correctness of the model.
+ """
+ metrics = [{
+ 'name': 'training_loss',
+ 'value': stats['train_loss'],
+ }]
+ if self.timer_callback:
+ metrics.append({
+ 'name':
+ 'exp_per_second',
+ 'value':
+ self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size *
+ FLAGS.steps_per_loop)
+ })
+ else:
+ metrics.append({
+ 'name': 'exp_per_second',
+ 'value': 0.0,
+ })
+ if self.timer_callback and 'start_time_sec' in stats:
+ metrics.append({
+ 'name': 'startup_time',
+ 'value': self.timer_callback.get_startup_time(stats['start_time_sec'])
+ })
+
+ if 'eval_metrics' in stats:
+ metrics.append({
+ 'name': 'eval_accuracy',
+ 'value': stats['eval_metrics'],
+ 'min_value': min_accuracy,
+ 'max_value': max_accuracy,
+ })
+ flags_str = flags_core.get_nondefault_flags_as_str()
+ self.report_benchmark(
+ iters=stats['total_training_steps'],
+ wall_time=wall_time_sec,
+ metrics=metrics,
+ extras={'flags': flags_str})
diff --git a/models/official/benchmark/bert_pretrain_benchmark.py b/models/official/benchmark/bert_pretrain_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..d63c894847d8e9e9308523d3efcb06c162d323c6
--- /dev/null
+++ b/models/official/benchmark/bert_pretrain_benchmark.py
@@ -0,0 +1,179 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Executes benchmark testing for bert pretraining."""
+# pylint: disable=line-too-long
+from __future__ import print_function
+
+import json
+import os
+import time
+from typing import Optional
+
+from absl import flags
+from absl import logging
+import tensorflow as tf # pylint: disable=g-bad-import-order
+
+from official.benchmark import benchmark_wrappers
+from official.benchmark import bert_benchmark_utils
+from official.benchmark import owner_utils
+from official.nlp.bert import run_pretraining
+from official.utils.flags import core as flags_core
+from official.utils.misc import distribution_utils
+
+# Pretrain masked language modeling accuracy range:
+MIN_MLM_ACCURACY = 0.635
+MAX_MLM_ACCURACY = 0.645
+
+# Pretrain next sentence prediction accuracy range:
+MIN_NSP_ACCURACY = 0.94
+MAX_NSP_ACCURACY = 0.96
+
+BERT_PRETRAIN_FILES_SEQ128 = 'gs://mlcompass-data/bert/pretraining_data/seq_128/wikipedia.tfrecord*,gs://mlcompass-data/bert/pretraining_data/seq_128/books.tfrecord*'
+BERT_BASE_CONFIG_FILE = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12/bert_config.json'
+
+FLAGS = flags.FLAGS
+
+
+class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
+ """Benchmark accuracy tests for BERT Pretraining."""
+
+ def __init__(self,
+ output_dir: Optional[str] = None,
+ tpu: Optional[str] = None,
+ **kwargs):
+ """Inits BertPretrainAccuracyBenchmark class.
+
+ Args:
+      output_dir: Directory to which outputs (e.g. log files) are written.
+ tpu: TPU name to use in a TPU benchmark.
+ **kwargs: Additional keyword arguments.
+ """
+ super(BertPretrainAccuracyBenchmark, self).__init__(
+ output_dir=output_dir, tpu=tpu, **kwargs)
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self, summary_path: str, report_accuracy: bool):
+ """Runs and reports the benchmark given the provided configuration."""
+ distribution = distribution_utils.get_distribution_strategy(
+ distribution_strategy='tpu', tpu_address=self.tpu)
+ logging.info('Flags: %s', flags_core.get_nondefault_flags_as_str())
+ start_time_sec = time.time()
+ run_pretraining.run_bert_pretrain(
+ strategy=distribution, custom_callbacks=self.timer_callback)
+ wall_time_sec = time.time() - start_time_sec
+
+ with tf.io.gfile.GFile(summary_path, 'rb') as reader:
+ summary = json.loads(reader.read().decode('utf-8'))
+ self._report_benchmark(summary, start_time_sec, wall_time_sec,
+ report_accuracy)
+
+ def _report_benchmark(self, summary, start_time_sec, wall_time_sec,
+ report_accuracy):
+ metrics = [{
+ 'name': 'train_loss',
+ 'value': summary['train_loss'],
+ }, {
+ 'name':
+ 'exp_per_second',
+ 'value':
+ self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size *
+ FLAGS.steps_per_loop)
+ }, {
+ 'name': 'startup_time',
+ 'value': self.timer_callback.get_startup_time(start_time_sec)
+ }]
+ if report_accuracy:
+ metrics.extend([{
+ 'name': 'masked_lm_accuracy',
+ 'value': summary['masked_lm_accuracy'],
+ 'min_value': MIN_MLM_ACCURACY,
+ 'max_value': MAX_MLM_ACCURACY,
+ }, {
+ 'name': 'next_sentence_accuracy',
+ 'value': summary['next_sentence_accuracy'],
+ 'min_value': MIN_NSP_ACCURACY,
+ 'max_value': MAX_NSP_ACCURACY,
+ }])
+ self.report_benchmark(
+ iters=summary['total_training_steps'],
+ wall_time=wall_time_sec,
+ metrics=metrics,
+ extras={'flags': flags_core.get_nondefault_flags_as_str()})
+
+ def _specify_common_flags(self):
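+    # Shared flags for the BERT-base pretraining benchmarks: sequence length
+    # 128 inputs, global batch size 512, and bfloat16 compute on TPU.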
+ FLAGS.bert_config_file = BERT_BASE_CONFIG_FILE
+ FLAGS.train_batch_size = 512
+ FLAGS.learning_rate = 1e-4
+ FLAGS.warmup_steps = 10000
+ FLAGS.steps_per_loop = 10000
+ FLAGS.distribution_strategy = 'tpu'
+ FLAGS.input_files = BERT_PRETRAIN_FILES_SEQ128
+ FLAGS.max_seq_length = 128
+ FLAGS.max_predictions_per_seq = 20
+ FLAGS.dtype = 'bf16'
+
+ @owner_utils.Owner('tf-model-garden')
+ def benchmark_accuracy_8x8_tpu_bf16_seq128_500k_steps(self):
+ """Test bert pretraining with 8x8 TPU for 500k steps."""
+ # This is used for accuracy test.
+ self._setup()
+ self._specify_common_flags()
+ FLAGS.num_steps_per_epoch = 500000
+ FLAGS.num_train_epochs = 1
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_accuracy_8x8_tpu_bf16_seq128_500k_steps')
+ summary_path = os.path.join(FLAGS.model_dir,
+ 'summaries/training_summary.txt')
+ # Set train_summary_interval to -1 to disable training summary, because
+ # writing summary to gcs may fail and summaries are not needed for this
+ # accuracy benchmark test.
+ FLAGS.train_summary_interval = -1
+ self._run_and_report_benchmark(summary_path=summary_path,
+ report_accuracy=True)
+
+ @owner_utils.Owner('tf-model-garden')
+ def benchmark_perf_4x4_tpu_bf16_seq128_10k_steps(self):
+ """Test bert pretraining with 4x4 TPU for 10000 steps."""
+ self._setup()
+ self._specify_common_flags()
+ FLAGS.num_steps_per_epoch = 5000
+ FLAGS.num_train_epochs = 2
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_perf_4x4_tpu_bf16_seq128_10k_steps')
+ summary_path = os.path.join(FLAGS.model_dir,
+ 'summaries/training_summary.txt')
+ # Disable accuracy check.
+ self._run_and_report_benchmark(
+ summary_path=summary_path, report_accuracy=False)
+
+ @owner_utils.Owner('tf-model-garden')
+ def benchmark_perf_8x8_tpu_bf16_seq128_10k_steps(self):
+ """Test bert pretraining with 8x8 TPU for 10000 steps."""
+ self._setup()
+ self._specify_common_flags()
+ FLAGS.num_steps_per_epoch = 5000
+ FLAGS.num_train_epochs = 2
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_perf_8x8_tpu_bf16_seq128_10k_steps')
+ summary_path = os.path.join(FLAGS.model_dir,
+ 'summaries/training_summary.txt')
+ # Disable accuracy check.
+ self._run_and_report_benchmark(summary_path=summary_path,
+ report_accuracy=False)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/benchmark/bert_squad_benchmark.py b/models/official/benchmark/bert_squad_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..dab90a485b9a2c22d11da82ec1d9c320ea0db114
--- /dev/null
+++ b/models/official/benchmark/bert_squad_benchmark.py
@@ -0,0 +1,608 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Executes BERT SQuAD benchmarks and accuracy tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+import time
+
+# pylint: disable=g-bad-import-order
+from absl import flags
+from absl import logging
+from absl.testing import flagsaver
+import tensorflow as tf
+# pylint: enable=g-bad-import-order
+
+from official.benchmark import bert_benchmark_utils as benchmark_utils
+from official.benchmark import owner_utils
+from official.nlp.bert import run_squad
+from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
+from official.benchmark import benchmark_wrappers
+
+
+# pylint: disable=line-too-long
+PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
+SQUAD_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_train.tf_record'
+SQUAD_PREDICT_FILE = 'gs://tf-perfzero-data/bert/squad/dev-v1.1.json'
+SQUAD_VOCAB_FILE = 'gs://tf-perfzero-data/bert/squad/vocab.txt'
+SQUAD_MEDIUM_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_medium_meta_data'
+SQUAD_LONG_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_long_meta_data'
+SQUAD_FULL_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_full_meta_data'
+MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_config.json'
+# pylint: enable=line-too-long
+
+TMP_DIR = os.getenv('TMPDIR')
+FLAGS = flags.FLAGS
+
+
+class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
+ """Base class to hold methods common to test classes in the module."""
+
+ def __init__(self, output_dir=None, tpu=None):
+ super(BertSquadBenchmarkBase, self).__init__(output_dir=output_dir, tpu=tpu)
+
+ def _read_training_summary_from_file(self):
+ """Reads the training summary from a file."""
+ summary_path = os.path.join(FLAGS.model_dir,
+ 'summaries/training_summary.txt')
+ with tf.io.gfile.GFile(summary_path, 'rb') as reader:
+ return json.loads(reader.read().decode('utf-8'))
+
+ def _read_input_meta_data_from_file(self):
+ """Reads the input metadata from a file."""
+ with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
+ return json.loads(reader.read().decode('utf-8'))
+
+ def _get_distribution_strategy(self, ds_type='mirrored'):
+ """Gets the distribution strategy.
+
+ Args:
+ ds_type: String, the distribution strategy type to be used. Can be
+        'mirrored', 'multi_worker_mirrored', 'tpu', or 'off'.
+
+ Returns:
+      A `tf.distribute.DistributionStrategy` object.
+ """
+ if self.tpu or ds_type == 'tpu':
+ return distribution_utils.get_distribution_strategy(
+ distribution_strategy='tpu', tpu_address=self.tpu)
+ elif ds_type == 'multi_worker_mirrored':
+ # Configures cluster spec for multi-worker distribution strategy.
+ _ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
+ FLAGS.task_index)
+ return distribution_utils.get_distribution_strategy(
+ distribution_strategy=ds_type,
+ num_gpus=self.num_gpus,
+ all_reduce_alg=FLAGS.all_reduce_alg)
+
+ def _init_gpu_and_data_threads(self):
+ """Set env variables before any TF calls."""
+ if FLAGS.tf_gpu_thread_mode:
+ keras_utils.set_gpu_thread_mode_and_count(
+ per_gpu_thread_count=FLAGS.per_gpu_thread_count,
+ gpu_thread_mode=FLAGS.tf_gpu_thread_mode,
+ num_gpus=self.num_gpus,
+ datasets_num_private_threads=FLAGS.datasets_num_private_threads)
+
+ @flagsaver.flagsaver
+ def _train_squad(self, run_eagerly=False, ds_type='mirrored'):
+ """Runs BERT SQuAD training. Uses mirrored strategy by default."""
+ self._init_gpu_and_data_threads()
+ input_meta_data = self._read_input_meta_data_from_file()
+ strategy = self._get_distribution_strategy(ds_type)
+
+ run_squad.train_squad(
+ strategy=strategy,
+ input_meta_data=input_meta_data,
+ run_eagerly=run_eagerly,
+ custom_callbacks=[self.timer_callback])
+
+ @flagsaver.flagsaver
+ def _evaluate_squad(self, ds_type='mirrored'):
+ """Runs BERT SQuAD evaluation. Uses mirrored strategy by default."""
+ self._init_gpu_and_data_threads()
+ input_meta_data = self._read_input_meta_data_from_file()
+ strategy = self._get_distribution_strategy(ds_type)
+
+ if input_meta_data.get('version_2_with_negative', False):
+ logging.error('In memory evaluation result for SQuAD v2 is not accurate')
+ eval_metrics = run_squad.eval_squad(strategy=strategy,
+ input_meta_data=input_meta_data)
+ # Use F1 score as reported evaluation metric.
+ self.eval_metrics = eval_metrics['final_f1']
+
+
+class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
+ """Short benchmark performance tests for BERT SQuAD model.
+
+ Tests BERT SQuAD performance in different GPU configurations.
+  The naming convention of the test cases below follows the
+ `benchmark_(number of gpus)_gpu` format for GPUs and
+ `benchmark_(topology)_tpu` format for TPUs.
+ """
+
+ def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
+ super(BertSquadBenchmarkReal, self).__init__(output_dir=output_dir, tpu=tpu)
+
+ def _setup(self):
+ """Sets up the benchmark and SQuAD flags."""
+ super(BertSquadBenchmarkReal, self)._setup()
+ FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
+ FLAGS.predict_file = SQUAD_PREDICT_FILE
+ FLAGS.vocab_file = SQUAD_VOCAB_FILE
+ FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
+ FLAGS.num_train_epochs = 1
+ FLAGS.steps_per_loop = 100
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self,
+ run_eagerly=False,
+ ds_type='mirrored'):
+ """Runs the benchmark and reports various metrics."""
+ if FLAGS.train_batch_size <= 4 or run_eagerly:
+ FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH
+ else:
+ FLAGS.input_meta_data_path = SQUAD_LONG_INPUT_META_DATA_PATH
+ start_time_sec = time.time()
+ self._train_squad(run_eagerly=run_eagerly, ds_type=ds_type)
+ wall_time_sec = time.time() - start_time_sec
+
+ summary = self._read_training_summary_from_file()
+ summary['start_time_sec'] = start_time_sec
+
+ super(BertSquadBenchmarkReal, self)._report_benchmark(
+ stats=summary,
+ wall_time_sec=wall_time_sec,
+ min_accuracy=0,
+ max_accuracy=1)
+
+ def benchmark_1_gpu(self):
+ """Tests BERT SQuAD model performance with 1 GPU."""
+
+ self._setup()
+ self.num_gpus = 1
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad')
+ FLAGS.train_batch_size = 4
+
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_eager(self):
+    """Tests BERT SQuAD model performance with 1 GPU with eager execution."""
+
+ self._setup()
+ self.num_gpus = 1
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_eager')
+ FLAGS.train_batch_size = 2
+
+ self._run_and_report_benchmark(run_eagerly=True)
+
+ def benchmark_1_gpu_xla(self):
+ """Tests BERT SQuAD model performance with 1 GPU with XLA."""
+
+ self._setup()
+ self.num_gpus = 1
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_xla_squad')
+ # XLA runs out of memory when running with batch size 4.
+ FLAGS.train_batch_size = 3
+ FLAGS.enable_xla = True
+
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_dist_strat(self):
+ """Tests BERT SQuAD model performance with 1 GPU without DS."""
+
+ self._setup()
+ self.num_gpus = 1
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat_squad')
+ FLAGS.train_batch_size = 4
+
+ self._run_and_report_benchmark(ds_type='off')
+
+ def benchmark_1_gpu_eager_no_dist_strat(self):
+ """Tests BERT SQuAD model performance with 1 GPU with eager execution."""
+
+ self._setup()
+ self.num_gpus = 1
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_1_gpu_eager_no_dist_strat_squad')
+ FLAGS.train_batch_size = 4
+
+ self._run_and_report_benchmark(ds_type='off', run_eagerly=True)
+
+ @owner_utils.Owner('tf-model-garden')
+ def benchmark_8_gpu(self):
+ """Tests BERT SQuAD model performance with 8 GPUs."""
+
+ self._setup()
+ self.num_gpus = 8
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad')
+ FLAGS.train_batch_size = 24
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
+
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_fp16_eager(self):
+ """Tests BERT SQuAD model performance with 1 GPU and FP16."""
+
+ self._setup()
+ self.num_gpus = 1
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_fp16_eager')
+ FLAGS.train_batch_size = 4
+ FLAGS.dtype = 'fp16'
+ FLAGS.loss_scale = 'dynamic'
+
+ self._run_and_report_benchmark(run_eagerly=True)
+
+ def benchmark_1_gpu_fp16(self):
+ """Tests BERT SQuAD model performance with 1 GPU and FP16."""
+
+ self._setup()
+ self.num_gpus = 1
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_fp16')
+ FLAGS.train_batch_size = 4
+ FLAGS.dtype = 'fp16'
+ FLAGS.loss_scale = 'dynamic'
+
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_xla_fp16(self):
+ """Tests BERT SQuAD model performance with 1 GPU with XLA and FP16."""
+
+ self._setup()
+ self.num_gpus = 1
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_xla_squad_fp16')
+ FLAGS.train_batch_size = 4
+ FLAGS.enable_xla = True
+ FLAGS.dtype = 'fp16'
+ FLAGS.loss_scale = 'dynamic'
+
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu_fp16(self):
+ """Tests BERT SQuAD model performance with 8 GPUs."""
+
+ self._setup()
+ self.num_gpus = 8
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_fp16')
+ FLAGS.train_batch_size = 32
+ FLAGS.dtype = 'fp16'
+ FLAGS.loss_scale = 'dynamic'
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
+
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu_xla_fp16(self):
+    """Tests BERT SQuAD model performance with 8 GPUs with XLA and FP16."""
+
+ self._setup()
+ self.num_gpus = 8
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_fp16')
+ FLAGS.train_batch_size = 32
+ FLAGS.enable_xla = True
+ FLAGS.dtype = 'fp16'
+ FLAGS.loss_scale = 'dynamic'
+
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_amp(self):
+ """Tests BERT SQuAD model performance with 1 GPU with automatic mixed precision."""
+
+ self._setup()
+ self.num_gpus = 1
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp_squad')
+ FLAGS.train_batch_size = 4
+ FLAGS.dtype = 'fp16'
+ FLAGS.fp16_implementation = 'graph_rewrite'
+
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu_amp(self):
+    """Tests BERT SQuAD model performance with 8 GPUs with automatic mixed precision."""
+
+ self._setup()
+ self.num_gpus = 8
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp_squad')
+ FLAGS.train_batch_size = 32
+ FLAGS.dtype = 'fp16'
+ FLAGS.fp16_implementation = 'graph_rewrite'
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
+
+ self._run_and_report_benchmark()
+
+ @owner_utils.Owner('tf-model-garden')
+ def benchmark_2x2_tpu(self):
+ """Tests BERT SQuAD model performance with 2x2 TPU."""
+
+ self._setup()
+ FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
+ FLAGS.train_batch_size = 48
+ FLAGS.predict_batch_size = 48
+ FLAGS.mode = 'train'
+ FLAGS.learning_rate = 8e-5
+ FLAGS.num_train_epochs = 1
+ FLAGS.steps_per_loop = 100
+ FLAGS.do_lower_case = True
+ FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
+ self._run_and_report_benchmark()
+
+
+class BertSquadAccuracy(BertSquadBenchmarkBase):
+ """Short accuracy test for BERT SQuAD model.
+
+  Tests BERT SQuAD accuracy. The naming convention of the test cases below follows the
+ `benchmark_(number of gpus)_gpu` format for GPUs and
+ `benchmark_(topology)_tpu` format for TPUs.
+ """
+
+ def __init__(self, output_dir=None, tpu=None, **kwargs):
+ super(BertSquadAccuracy, self).__init__(output_dir=output_dir, tpu=tpu)
+
+ def _setup(self):
+ """Sets up the benchmark and SQuAD flags."""
+ super(BertSquadAccuracy, self)._setup()
+ FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
+ FLAGS.predict_file = SQUAD_PREDICT_FILE
+ FLAGS.vocab_file = SQUAD_VOCAB_FILE
+ FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
+ FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
+ FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
+ FLAGS.num_train_epochs = 2
+ FLAGS.steps_per_loop = 100
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self,
+ run_eagerly=False,
+ ds_type='mirrored'):
+ """Runs the benchmark and reports various metrics."""
+ start_time_sec = time.time()
+ self._train_squad(run_eagerly=run_eagerly, ds_type=ds_type)
+ self._evaluate_squad(ds_type=ds_type)
+ wall_time_sec = time.time() - start_time_sec
+
+ summary = self._read_training_summary_from_file()
+ summary['eval_metrics'] = self.eval_metrics
+ summary['start_time_sec'] = start_time_sec
+
+ super(BertSquadAccuracy, self)._report_benchmark(
+ stats=summary,
+ wall_time_sec=wall_time_sec,
+ min_accuracy=0.900,
+ max_accuracy=0.920)
+
+ def benchmark_1_gpu_eager(self):
+ """Tests BERT SQuAD model accuracy with 1 GPU with eager execution."""
+
+ self._setup()
+ self.num_gpus = 1
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_eager')
+ FLAGS.train_batch_size = 4
+
+ self._run_and_report_benchmark(ds_type='off', run_eagerly=True)
+
+ @owner_utils.Owner('tf-model-garden')
+ def benchmark_8_gpu(self):
+ """Tests BERT SQuAD model accuracy with 8 GPUs."""
+
+ self._setup()
+ self.num_gpus = 8
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad')
+ FLAGS.train_batch_size = 24
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
+
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu_fp16(self):
+ """Tests BERT SQuAD model accuracy with 8 GPUs and FP16."""
+
+ self._setup()
+ self.num_gpus = 8
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_fp16')
+ FLAGS.train_batch_size = 32
+ FLAGS.dtype = 'fp16'
+ FLAGS.loss_scale = 'dynamic'
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
+
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu_xla(self):
+    """Tests BERT SQuAD model accuracy with 8 GPUs with XLA."""
+
+ self._setup()
+ self.num_gpus = 8
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_xla')
+ FLAGS.train_batch_size = 32
+ FLAGS.enable_xla = True
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
+
+ self._run_and_report_benchmark()
+
+ @owner_utils.Owner('tf-model-garden')
+ def benchmark_2x2_tpu(self):
+ """Tests BERT SQuAD model accuracy with 2x2 TPU."""
+
+ self._setup()
+ FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
+ FLAGS.train_batch_size = 48
+
+ self._run_and_report_benchmark()
+
+
+class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
+ """BERT SQuAD distributed accuracy tests with multiple workers."""
+
+ def __init__(self, output_dir=None, tpu=None, **kwargs):
+ super(BertSquadMultiWorkerAccuracy, self).__init__(
+ output_dir=output_dir, tpu=tpu)
+
+ def _setup(self):
+ """Sets up the benchmark and SQuAD flags."""
+ super(BertSquadMultiWorkerAccuracy, self)._setup()
+ FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
+ FLAGS.predict_file = SQUAD_PREDICT_FILE
+ FLAGS.vocab_file = SQUAD_VOCAB_FILE
+ FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
+ FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
+ FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
+ FLAGS.num_train_epochs = 2
+ FLAGS.steps_per_loop = 100
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self,
+ use_ds=True,
+ run_eagerly=False):
+ """Runs the benchmark and reports various metrics."""
+ start_time_sec = time.time()
+ self._train_squad(run_eagerly=run_eagerly,
+ ds_type='multi_worker_mirrored')
+ self._evaluate_squad(ds_type='multi_worker_mirrored')
+ wall_time_sec = time.time() - start_time_sec
+
+ summary = self._read_training_summary_from_file()
+ summary['eval_metrics'] = self.eval_metrics
+
+ super(BertSquadMultiWorkerAccuracy, self)._report_benchmark(
+ stats=summary,
+ wall_time_sec=wall_time_sec,
+ min_accuracy=0.900,
+ max_accuracy=0.920)
+
+ def _benchmark_common(self, num_workers, all_reduce_alg):
+ """Common to all benchmarks in this class."""
+ self._setup()
+
+ num_gpus = 8
+ FLAGS.num_gpus = num_gpus
+ FLAGS.dtype = 'fp16'
+ FLAGS.enable_xla = False
+ FLAGS.distribution_strategy = 'multi_worker_mirrored'
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
+ FLAGS.datasets_num_private_threads = 32
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_8_gpu_{}_worker_fp16_{}_tweaked'.format(
+ num_workers, all_reduce_alg))
+ FLAGS.train_batch_size = 4 * num_gpus * num_workers
+ FLAGS.all_reduce_alg = all_reduce_alg
+
+ self._run_and_report_benchmark()
+
+ def benchmark_eager_8_gpu_2_workers_fp16_ring_tweaked(self):
+ """8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
+ self._benchmark_common(num_workers=2, all_reduce_alg='ring')
+
+ def benchmark_eager_8_gpu_2_workers_fp16_nccl_tweaked(self):
+ """8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
+ self._benchmark_common(num_workers=2, all_reduce_alg='nccl')
+
+ def benchmark_8_gpu_8_workers_fp16_ring_tweaked(self):
+ """8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
+ self._benchmark_common(num_workers=8, all_reduce_alg='ring')
+
+ def benchmark_8_gpu_8_workers_fp16_nccl_tweaked(self):
+ """8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
+ self._benchmark_common(num_workers=8, all_reduce_alg='nccl')
+
+
+class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
+ """BERT SQuAD distributed benchmark tests with multiple workers."""
+
+ def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
+ super(BertSquadMultiWorkerBenchmark, self).__init__(
+ output_dir=output_dir, tpu=tpu)
+
+ def _setup(self):
+ """Sets up the benchmark and SQuAD flags."""
+ super(BertSquadMultiWorkerBenchmark, self)._setup()
+ FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
+ FLAGS.predict_file = SQUAD_PREDICT_FILE
+ FLAGS.vocab_file = SQUAD_VOCAB_FILE
+ FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
+ FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
+ FLAGS.num_train_epochs = 1
+ FLAGS.steps_per_loop = 100
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self,
+ use_ds=True,
+ run_eagerly=False):
+ """Runs the benchmark and reports various metrics."""
+ if FLAGS.train_batch_size <= 4 * 8:
+ FLAGS.input_meta_data_path = SQUAD_LONG_INPUT_META_DATA_PATH
+ else:
+ FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
+ start_time_sec = time.time()
+ self._train_squad(run_eagerly=run_eagerly,
+ ds_type='multi_worker_mirrored')
+ wall_time_sec = time.time() - start_time_sec
+
+ summary = self._read_training_summary_from_file()
+ summary['start_time_sec'] = start_time_sec
+
+ super(BertSquadMultiWorkerBenchmark, self)._report_benchmark(
+ stats=summary,
+ wall_time_sec=wall_time_sec,
+ min_accuracy=0,
+ max_accuracy=1)
+
+ def _benchmark_common(self, num_workers, all_reduce_alg):
+ """Common to all benchmarks in this class."""
+ self._setup()
+
+ num_gpus = 8
+ FLAGS.num_gpus = num_gpus
+ FLAGS.dtype = 'fp16'
+ FLAGS.enable_xla = False
+ FLAGS.distribution_strategy = 'multi_worker_mirrored'
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
+ FLAGS.datasets_num_private_threads = 32
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_8_gpu_{}_worker_fp16_{}_tweaked'.format(
+ num_workers, all_reduce_alg))
+ FLAGS.train_batch_size = 4 * num_gpus * num_workers
+ FLAGS.all_reduce_alg = all_reduce_alg
+
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu_1_worker_fp16_ring_tweaked(self):
+ """8 GPUs per worker, 1 worker, fp16, ring all-reduce."""
+ self._benchmark_common(num_workers=1, all_reduce_alg='ring')
+
+ def benchmark_8_gpu_1_worker_fp16_nccl_tweaked(self):
+ """8 GPUs per worker, 1 worker, fp16, nccl all-reduce."""
+ self._benchmark_common(num_workers=1, all_reduce_alg='nccl')
+
+ def benchmark_8_gpu_2_workers_fp16_ring_tweaked(self):
+ """8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
+ self._benchmark_common(num_workers=2, all_reduce_alg='ring')
+
+ def benchmark_8_gpu_2_workers_fp16_nccl_tweaked(self):
+ """8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
+ self._benchmark_common(num_workers=2, all_reduce_alg='nccl')
+
+ def benchmark_8_gpu_8_workers_fp16_ring_tweaked(self):
+ """8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
+ self._benchmark_common(num_workers=8, all_reduce_alg='ring')
+
+ def benchmark_8_gpu_8_workers_fp16_nccl_tweaked(self):
+ """8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
+ self._benchmark_common(num_workers=8, all_reduce_alg='nccl')
+
+
+if __name__ == '__main__':
+ tf.test.main()
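
Note: the SQuAD accuracy classes above all follow the same pattern: _setup() resets the flags, each benchmark_* method pins one hardware configuration, and _run_and_report_benchmark() times training plus evaluation and checks the result against the accuracy bounds. A minimal sketch of driving one of these benchmarks directly, outside the PerfZero harness (the module path and output directory are assumptions, and the SQuAD data/checkpoint paths baked into the flags must resolve on the local machine):

# Minimal sketch (assumed usage): run one SQuAD accuracy benchmark directly.
from official.benchmark import bert_squad_benchmark

bm = bert_squad_benchmark.BertSquadAccuracy(output_dir='/tmp/bert_squad')
bm.benchmark_1_gpu_eager()  # sets flags, trains, evaluates, reports metrics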
diff --git a/models/official/benchmark/datastore/schema/benchmark_metric.json b/models/official/benchmark/datastore/schema/benchmark_metric.json
new file mode 100644
index 0000000000000000000000000000000000000000..cc571d480605241e7c71d2e4cabdaf6ad3da9295
--- /dev/null
+++ b/models/official/benchmark/datastore/schema/benchmark_metric.json
@@ -0,0 +1,56 @@
+[
+ {
+ "description": "The ID of the benchmark run, where this metric should tie to.",
+ "mode": "REQUIRED",
+ "name": "run_id",
+ "type": "STRING"
+ },
+ {
+ "description": "The name of the metric, which should be descriptive. E.g. training_loss, accuracy.",
+ "mode": "REQUIRED",
+ "name": "name",
+ "type": "STRING"
+ },
+ {
+ "description": "The unit of the metric. E.g. MB per sec.",
+ "mode": "NULLABLE",
+ "name": "unit",
+ "type": "STRING"
+ },
+ {
+ "description": "The value of the metric.",
+ "mode": "NULLABLE",
+ "name": "value",
+ "type": "FLOAT"
+ },
+ {
+ "description": "The timestamp when the metric is recorded.",
+ "mode": "REQUIRED",
+ "name": "timestamp",
+ "type": "TIMESTAMP"
+ },
+ {
+ "description": "The global step when this metric is recorded.",
+ "mode": "NULLABLE",
+ "name": "global_step",
+ "type": "INTEGER"
+ },
+ {
+ "description": "Free format metadata for the extra information about the metric.",
+ "mode": "REPEATED",
+ "name": "extras",
+ "type": "RECORD",
+ "fields": [
+ {
+ "mode": "NULLABLE",
+ "name": "name",
+ "type": "STRING"
+ },
+ {
+ "mode": "NULLABLE",
+ "name": "value",
+ "type": "STRING"
+ }
+ ]
+ }
+]
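
For context, each row described by this schema records one metric emitted by a benchmark run. A hypothetical record that satisfies the schema (illustrative values only, not taken from the repository):

# Hypothetical metric row matching benchmark_metric.json (illustrative only).
example_metric = {
    'run_id': 'a1b2c3d4-0000-4000-8000-000000000000',
    'name': 'exp_per_second',                 # REQUIRED, descriptive name
    'unit': 'examples per second',            # NULLABLE
    'value': 5120.0,                          # NULLABLE FLOAT
    'timestamp': '2020-05-01T12:34:56Z',      # REQUIRED TIMESTAMP
    'global_step': 110,                       # NULLABLE INTEGER
    'extras': [{'name': 'dtype', 'value': 'fp16'}],  # REPEATED RECORD
}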
diff --git a/models/official/benchmark/datastore/schema/benchmark_run.json b/models/official/benchmark/datastore/schema/benchmark_run.json
new file mode 100644
index 0000000000000000000000000000000000000000..58e5ddcadeff98b05c328c2798071f9cd73ef9d2
--- /dev/null
+++ b/models/official/benchmark/datastore/schema/benchmark_run.json
@@ -0,0 +1,368 @@
+[
+ {
+ "description": "The UUID of the run for the benchmark.",
+ "mode": "REQUIRED",
+ "name": "model_id",
+ "type": "STRING"
+ },
+ {
+ "description": "The name of the model, E.g ResNet50, LeNet-5 etc.",
+ "mode": "REQUIRED",
+ "name": "model_name",
+ "type": "STRING"
+ },
+ {
+ "description": "The date when the test of the model is started",
+ "mode": "REQUIRED",
+ "name": "run_date",
+ "type": "TIMESTAMP"
+ },
+ {
+ "description": "The unique name for a test by the combination of key parameters, eg batch size, num of GPU, etc. It is hardware independent.",
+ "mode": "NULLABLE",
+ "name": "test_id",
+ "type": "STRING"
+ },
+ {
+ "description": "The tensorflow version information.",
+ "fields": [
+ {
+ "description": "Version of the tensorflow. E.g. 1.7.0-rc0",
+ "mode": "REQUIRED",
+ "name": "version",
+ "type": "STRING"
+ },
+ {
+ "description": "Git Hash of the tensorflow",
+ "mode": "NULLABLE",
+ "name": "git_hash",
+ "type": "STRING"
+ },
+ {
+ "description": "The channel of the tensorflow binary, eg, nightly, RC, final, custom.",
+ "mode": "NULLABLE",
+ "name": "channel",
+ "type": "STRING"
+ },
+ {
+ "description": "Identify anything special about the build, eg CUDA 10, NCCL, MKL, etc.",
+ "mode": "NULLABLE",
+ "name": "build_type",
+ "type": "STRING"
+ }
+ ],
+ "mode": "REQUIRED",
+ "name": "tensorflow_version",
+ "type": "RECORD"
+ },
+ {
+ "description": "The arbitrary attribute of the model.",
+ "fields": [
+ {
+ "description": "The name of the attribute.",
+ "mode": "REQUIRED",
+ "name": "name",
+ "type": "STRING"
+ },
+ {
+ "description": "The value of the attribute.",
+ "mode": "NULLABLE",
+ "name": "value",
+ "type": "STRING"
+ }
+ ],
+ "mode": "REPEATED",
+ "name": "attribute",
+ "type": "RECORD"
+ },
+ {
+ "description": "Environment variables when the benchmark run is executed.",
+ "fields": [
+ {
+ "description": "The name of the variable.",
+ "mode": "REQUIRED",
+ "name": "name",
+ "type": "STRING"
+ },
+ {
+ "description": "The value of the variable.",
+ "mode": "NULLABLE",
+ "name": "value",
+ "type": "STRING"
+ }
+ ],
+ "mode": "REPEATED",
+ "name": "environment_variable",
+ "type": "RECORD"
+ },
+ {
+ "description": "TF Environment variables when the benchmark run is executed.",
+ "fields": [
+ {
+ "description": "The name of the variable.",
+ "mode": "REQUIRED",
+ "name": "name",
+ "type": "STRING"
+ },
+ {
+ "description": "The value of the variable.",
+ "mode": "NULLABLE",
+ "name": "value",
+ "type": "STRING"
+ }
+ ],
+ "mode": "REPEATED",
+ "name": "tensorflow_environment_variables",
+ "type": "RECORD"
+ },
+ {
+ "description": "The list of parameters run with the model. It could contain hyperparameters or others.",
+ "fields": [
+ {
+ "description": "The name of the parameter.",
+ "mode": "REQUIRED",
+ "name": "name",
+ "type": "STRING"
+ },
+ {
+ "description": "The string value of the parameter.",
+ "mode": "NULLABLE",
+ "name": "string_value",
+ "type": "STRING"
+ },
+ {
+ "description": "The bool value of the parameter.",
+ "mode": "NULLABLE",
+ "name": "bool_value",
+ "type": "STRING"
+ },
+ {
+ "description": "The int/long value of the parameter.",
+ "mode": "NULLABLE",
+ "name": "long_value",
+ "type": "INTEGER"
+ },
+ {
+ "description": "The double/float value of parameter.",
+ "mode": "NULLABLE",
+ "name": "float_value",
+ "type": "FLOAT"
+ }
+ ],
+ "mode": "REPEATED",
+ "name": "run_parameters",
+ "type": "RECORD"
+ },
+ {
+ "description": "The dataset that run with the benchmark.",
+ "mode": "NULLABLE",
+ "name": "dataset",
+ "type": "RECORD",
+ "fields": [
+ {
+ "description": "The name of the dataset that the model is trained/validated with. E.g ImageNet, mnist.",
+ "mode": "REQUIRED",
+ "name": "name",
+ "type": "STRING"
+ },
+ {
+ "description": "The arbitrary attribute of the dataset.",
+ "fields": [
+ {
+ "description": "The name of the attribute.",
+ "mode": "REQUIRED",
+ "name": "name",
+ "type": "STRING"
+ },
+ {
+ "description": "The value of the attribute.",
+ "mode": "NULLABLE",
+ "name": "value",
+ "type": "STRING"
+ }
+ ],
+ "mode": "REPEATED",
+ "name": "attribute",
+ "type": "RECORD"
+ }
+ ]
+ },
+ {
+ "description": "Used to differentiate from AWS, GCE or DGX-1 at a high level",
+ "mode": "NULLABLE",
+ "name": "test_environment",
+ "type": "STRING"
+ },
+ {
+ "description": "The machine configuration of the benchmark run.",
+ "mode": "NULLABLE",
+ "name": "machine_config",
+ "type": "RECORD",
+ "fields": [
+ {
+ "description": "The platform information of the benchmark run.",
+ "mode": "NULLABLE",
+ "name": "platform_info",
+ "type": "RECORD",
+ "fields": [
+ {
+ "description": "Eg: 64bit.",
+ "mode": "NULLABLE",
+ "name": "bits",
+ "type": "STRING"
+ },
+ {
+ "description": "Eg: ELF.",
+ "mode": "NULLABLE",
+ "name": "linkage",
+ "type": "STRING"
+ },
+ {
+ "description": "Eg: i386.",
+ "mode": "NULLABLE",
+ "name": "machine",
+ "type": "STRING"
+ },
+ {
+ "description": "Eg: 3.13.0-76-generic.",
+ "mode": "NULLABLE",
+ "name": "release",
+ "type": "STRING"
+ },
+ {
+ "description": "Eg: Linux.",
+ "mode": "NULLABLE",
+ "name": "system",
+ "type": "STRING"
+ },
+ {
+ "description": "Eg: #120-Ubuntu SMP Mon Jan 18 15:59:10 UTC 2016.",
+ "mode": "NULLABLE",
+ "name": "version",
+ "type": "STRING"
+ }
+ ]
+ },
+ {
+ "description": "The CPU information of the benchmark run.",
+ "mode": "NULLABLE",
+ "name": "cpu_info",
+ "type": "RECORD",
+ "fields": [
+ {
+ "mode": "NULLABLE",
+ "name": "num_cores",
+ "type": "INTEGER"
+ },
+ {
+ "mode": "NULLABLE",
+ "name": "num_cores_allowed",
+ "type": "INTEGER"
+ },
+ {
+ "description" : "How fast are those CPUs.",
+ "mode": "NULLABLE",
+ "name": "mhz_per_cpu",
+ "type": "FLOAT"
+ },
+ {
+ "description" : "Additional CPU info, Eg: Intel Ivybridge with HyperThreading (24 cores).",
+ "mode": "NULLABLE",
+ "name": "cpu_info",
+ "type": "STRING"
+ },
+ {
+ "description" : "What kind of cpu scaling is enabled on the host. Eg performance, ondemand, conservative, mixed.",
+ "mode": "NULLABLE",
+ "name": "cpu_governor",
+ "type": "STRING"
+ },
+ {
+ "description": "Cache size of the CPUs.",
+ "mode": "NULLABLE",
+ "name": "cache_size",
+ "type": "RECORD",
+ "fields": [
+ {
+ "mode": "NULLABLE",
+ "name": "level",
+ "type": "STRING"
+ },
+ {
+ "mode": "NULLABLE",
+ "name": "size",
+ "type": "INTEGER"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "mode": "NULLABLE",
+ "name": "gpu_info",
+ "type": "RECORD",
+ "fields": [
+ {
+ "mode": "NULLABLE",
+ "name": "count",
+ "type": "INTEGER"
+ },
+ {
+ "mode": "NULLABLE",
+ "name": "model",
+ "type": "STRING"
+ },
+ {
+ "mode": "NULLABLE",
+ "name": "cuda_version",
+ "type": "STRING"
+ }
+ ]
+ },
+ {
+ "description": "The cloud instance inforation if the benchmark run is executed on cloud",
+ "mode": "NULLABLE",
+ "name": "cloud_info",
+ "type": "RECORD",
+ "fields": [
+ {
+ "description": "The instance type, E.g. n1-standard-4.",
+ "mode": "NULLABLE",
+ "name": "instance_type",
+ "type": "STRING"
+ },
+ {
+ "description": "The arbitrary attribute of the cloud info.",
+ "fields": [
+ {
+ "description": "The name of the attribute.",
+ "mode": "REQUIRED",
+ "name": "name",
+ "type": "STRING"
+ },
+ {
+ "description": "The value of the attribute.",
+ "mode": "NULLABLE",
+ "name": "value",
+ "type": "STRING"
+ }
+ ],
+ "mode": "REPEATED",
+ "name": "attribute",
+ "type": "RECORD"
+ }
+ ]
+ },
+ {
+ "mode": "NULLABLE",
+ "name": "memory_total",
+ "type": "INTEGER"
+ },
+ {
+ "mode": "NULLABLE",
+ "name": "memory_available",
+ "type": "STRING"
+ }
+ ]
+ }
+]
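
Only model_id, model_name, run_date, and tensorflow_version are REQUIRED above; everything else is optional metadata about the run and the machine. A hypothetical minimal record (illustrative values only, not taken from the repository):

# Hypothetical minimal run row matching benchmark_run.json (illustrative only).
example_run = {
    'model_id': 'a1b2c3d4-0000-4000-8000-000000000000',   # UUID of this run
    'model_name': 'bert_squad',
    'run_date': '2020-05-01T12:30:00Z',
    'tensorflow_version': {'version': '2.2.0', 'channel': 'nightly'},
    'run_parameters': [{'name': 'train_batch_size', 'long_value': 32}],
    'machine_config': {'gpu_info': {'count': 8, 'model': 'V100', 'cuda_version': '10.1'}},
}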
diff --git a/models/official/benchmark/datastore/schema/benchmark_run_status.json b/models/official/benchmark/datastore/schema/benchmark_run_status.json
new file mode 100644
index 0000000000000000000000000000000000000000..f7ac59eb8042c181e8996d9e1a0e7ee79f6f0343
--- /dev/null
+++ b/models/official/benchmark/datastore/schema/benchmark_run_status.json
@@ -0,0 +1,14 @@
+[
+ {
+ "description": "The UUID of the run for the benchmark.",
+ "mode": "REQUIRED",
+ "name": "run_id",
+ "type": "STRING"
+ },
+ {
+ "description": "The status of the run for the benchmark. Eg, running, failed, success",
+ "mode": "REQUIRED",
+ "name": "status",
+ "type": "STRING"
+ }
+]
\ No newline at end of file
diff --git a/models/official/benchmark/keras_benchmark.py b/models/official/benchmark/keras_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..770674ac658f213d614f0a3704a0bbb200bb94aa
--- /dev/null
+++ b/models/official/benchmark/keras_benchmark.py
@@ -0,0 +1,98 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Executes Keras benchmarks and accuracy tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
+from official.utils.flags import core as flags_core
+
+
+class KerasBenchmark(PerfZeroBenchmark):
+ """Base benchmark class with methods to simplify testing."""
+
+ def __init__(self,
+ output_dir=None,
+ default_flags=None,
+ flag_methods=None,
+ tpu=None):
+ super(KerasBenchmark, self).__init__(
+ output_dir=output_dir,
+ default_flags=default_flags,
+ flag_methods=flag_methods,
+ tpu=tpu)
+
+ def _report_benchmark(self,
+ stats,
+ wall_time_sec,
+ top_1_max=None,
+ top_1_min=None,
+ log_steps=None,
+ total_batch_size=None,
+ warmup=1,
+ start_time_sec=None):
+ """Report benchmark results by writing to local protobuf file.
+
+ Args:
+ stats: dict returned from keras models with known entries.
+      wall_time_sec: the duration of the benchmark execution in seconds.
+ top_1_max: highest passing level for top_1 accuracy.
+ top_1_min: lowest passing level for top_1 accuracy.
+      log_steps: interval, in steps, between entries in stats['step_timestamp_log'].
+ total_batch_size: Global batch-size.
+ warmup: number of entries in stats['step_timestamp_log'] to ignore.
+ start_time_sec: the start time of the program in seconds since epoch
+ """
+
+ metrics = []
+ if 'accuracy_top_1' in stats:
+ metrics.append({'name': 'accuracy_top_1',
+ 'value': stats['accuracy_top_1'],
+ 'min_value': top_1_min,
+ 'max_value': top_1_max})
+ metrics.append({'name': 'top_1_train_accuracy',
+ 'value': stats['training_accuracy_top_1']})
+
+ if (warmup and 'step_timestamp_log' in stats and
+ len(stats['step_timestamp_log']) > warmup):
+ # first entry in the time_log is start of step 1. The rest of the
+ # entries are the end of each step recorded
+ time_log = stats['step_timestamp_log']
+ elapsed = time_log[-1].timestamp - time_log[warmup].timestamp
+ num_examples = (
+ total_batch_size * log_steps * (len(time_log) - warmup - 1))
+ examples_per_sec = num_examples / elapsed
+ metrics.append({'name': 'exp_per_second',
+ 'value': examples_per_sec})
+
+ if 'avg_exp_per_second' in stats:
+ metrics.append({'name': 'avg_exp_per_second',
+ 'value': stats['avg_exp_per_second']})
+
+ if start_time_sec and 'step_timestamp_log' in stats:
+ time_log = stats['step_timestamp_log']
+ # time_log[0] is recorded at the beginning of the first step.
+ startup_time = time_log[0].timestamp - start_time_sec
+ metrics.append({'name': 'startup_time', 'value': startup_time})
+
+ flags_str = flags_core.get_nondefault_flags_as_str()
+ self.report_benchmark(
+ iters=-1,
+ wall_time=wall_time_sec,
+ metrics=metrics,
+ extras={'flags': flags_str})
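
The throughput metric above is derived from stats['step_timestamp_log']: the entry at index `warmup` marks where measurement starts, and each later entry closes a window of log_steps steps. A small worked sketch of the same arithmetic (the timestamps are made up; real entries come from the TimeHistory callback and expose a .timestamp attribute):

# Worked sketch of the exp_per_second math in KerasBenchmark._report_benchmark.
import collections

BatchTimestamp = collections.namedtuple('BatchTimestamp', ['batch_index', 'timestamp'])

time_log = [BatchTimestamp(i * 10, 100.0 + 2.0 * i) for i in range(12)]
warmup, log_steps, total_batch_size = 1, 10, 1024

elapsed = time_log[-1].timestamp - time_log[warmup].timestamp               # 20.0 s
num_examples = total_batch_size * log_steps * (len(time_log) - warmup - 1)  # 102,400
print(num_examples / elapsed)  # 5120.0 examples/sec -> reported as exp_per_second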
diff --git a/models/official/benchmark/keras_cifar_benchmark.py b/models/official/benchmark/keras_cifar_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..694200f66678a1bc9bc44194377a52489a1b97f3
--- /dev/null
+++ b/models/official/benchmark/keras_cifar_benchmark.py
@@ -0,0 +1,402 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Executes Keras benchmarks and accuracy tests."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+from absl import flags
+import tensorflow as tf # pylint: disable=g-bad-import-order
+
+from official.benchmark import keras_benchmark
+from official.benchmark import benchmark_wrappers
+from official.benchmark.models import resnet_cifar_main
+
+MIN_TOP_1_ACCURACY = 0.929
+MAX_TOP_1_ACCURACY = 0.938
+
+FLAGS = flags.FLAGS
+CIFAR_DATA_DIR_NAME = 'cifar-10-batches-bin'
+
+
+class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
+ """Accuracy tests for ResNet56 Keras CIFAR-10."""
+
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
+ """A benchmark class.
+
+ Args:
+ output_dir: directory where to output e.g. log files
+ root_data_dir: directory under which to look for dataset
+ **kwargs: arbitrary named arguments. This is needed to make the
+ constructor forward compatible in case PerfZero provides more
+ named arguments before updating the constructor.
+ """
+
+ self.data_dir = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
+ flag_methods = [resnet_cifar_main.define_cifar_flags]
+
+ super(Resnet56KerasAccuracy, self).__init__(
+ output_dir=output_dir, flag_methods=flag_methods)
+
+ def _setup(self):
+ super(Resnet56KerasAccuracy, self)._setup()
+ FLAGS.use_tensor_lr = False
+
+ def benchmark_graph_1_gpu(self):
+ """Test keras based model with Keras fit and distribution strategies."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.data_dir = self.data_dir
+ FLAGS.batch_size = 128
+ FLAGS.train_epochs = 182
+ FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu')
+ FLAGS.dtype = 'fp32'
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu(self):
+ """Test keras based model with eager and distribution strategies."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.data_dir = self.data_dir
+ FLAGS.batch_size = 128
+ FLAGS.train_epochs = 182
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
+ FLAGS.dtype = 'fp32'
+ FLAGS.enable_eager = True
+ self._run_and_report_benchmark()
+
+ def benchmark_cpu(self):
+ """Test keras based model on CPU."""
+ self._setup()
+ FLAGS.num_gpus = 0
+ FLAGS.data_dir = self.data_dir
+ FLAGS.batch_size = 128
+ FLAGS.train_epochs = 182
+ FLAGS.model_dir = self._get_model_dir('benchmark_cpu')
+ FLAGS.dtype = 'fp32'
+ FLAGS.enable_eager = True
+ FLAGS.data_format = 'channels_last'
+ self._run_and_report_benchmark()
+
+ def benchmark_cpu_no_dist_strat(self):
+ """Test keras based model on CPU without distribution strategies."""
+ self._setup()
+ FLAGS.num_gpus = 0
+ FLAGS.data_dir = self.data_dir
+ FLAGS.batch_size = 128
+ FLAGS.train_epochs = 182
+ FLAGS.model_dir = self._get_model_dir('benchmark_cpu_no_dist_strat')
+ FLAGS.dtype = 'fp32'
+ FLAGS.enable_eager = True
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.data_format = 'channels_last'
+ self._run_and_report_benchmark()
+
+ def benchmark_cpu_no_dist_strat_run_eagerly(self):
+ """Test keras based model on CPU w/forced eager and no dist_strat."""
+ self._setup()
+ FLAGS.num_gpus = 0
+ FLAGS.data_dir = self.data_dir
+ FLAGS.batch_size = 128
+ FLAGS.train_epochs = 182
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_cpu_no_dist_strat_run_eagerly')
+ FLAGS.dtype = 'fp32'
+ FLAGS.enable_eager = True
+ FLAGS.run_eagerly = True
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.data_format = 'channels_last'
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_dist_strat(self):
+ """Test keras based model with eager and no dist strat."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.data_dir = self.data_dir
+ FLAGS.batch_size = 128
+ FLAGS.train_epochs = 182
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
+ FLAGS.dtype = 'fp32'
+ FLAGS.enable_eager = True
+ FLAGS.distribution_strategy = 'off'
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
+ """Test keras based model w/forced eager and no dist_strat."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.data_dir = self.data_dir
+ FLAGS.batch_size = 128
+ FLAGS.train_epochs = 182
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly')
+ FLAGS.dtype = 'fp32'
+ FLAGS.enable_eager = True
+ FLAGS.run_eagerly = True
+ FLAGS.distribution_strategy = 'off'
+ self._run_and_report_benchmark()
+
+ def benchmark_graph_1_gpu_no_dist_strat(self):
+ """Test keras based model with Keras fit but not distribution strategies."""
+ self._setup()
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.num_gpus = 1
+ FLAGS.data_dir = self.data_dir
+ FLAGS.batch_size = 128
+ FLAGS.train_epochs = 182
+ FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu_no_dist_strat')
+ FLAGS.dtype = 'fp32'
+ self._run_and_report_benchmark()
+
+ def benchmark_2_gpu(self):
+ """Test keras based model with eager and distribution strategies."""
+ self._setup()
+ FLAGS.num_gpus = 2
+ FLAGS.data_dir = self.data_dir
+ FLAGS.batch_size = 128
+ FLAGS.train_epochs = 182
+ FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu')
+ FLAGS.dtype = 'fp32'
+ FLAGS.enable_eager = True
+ self._run_and_report_benchmark()
+
+ def benchmark_graph_2_gpu(self):
+ """Test keras based model with Keras fit and distribution strategies."""
+ self._setup()
+ FLAGS.num_gpus = 2
+ FLAGS.data_dir = self.data_dir
+ FLAGS.batch_size = 128
+ FLAGS.train_epochs = 182
+ FLAGS.model_dir = self._get_model_dir('benchmark_graph_2_gpu')
+ FLAGS.dtype = 'fp32'
+ self._run_and_report_benchmark()
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self):
+ start_time_sec = time.time()
+ stats = resnet_cifar_main.run(FLAGS)
+ wall_time_sec = time.time() - start_time_sec
+
+ super(Resnet56KerasAccuracy, self)._report_benchmark(
+ stats,
+ wall_time_sec,
+ top_1_min=MIN_TOP_1_ACCURACY,
+ top_1_max=MAX_TOP_1_ACCURACY,
+ total_batch_size=FLAGS.batch_size,
+ log_steps=100)
+
+
+class Resnet56KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
+ """Short performance tests for ResNet56 via Keras and CIFAR-10."""
+
+ def __init__(self, output_dir=None, default_flags=None):
+ flag_methods = [resnet_cifar_main.define_cifar_flags]
+
+ super(Resnet56KerasBenchmarkBase, self).__init__(
+ output_dir=output_dir,
+ flag_methods=flag_methods,
+ default_flags=default_flags)
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self):
+ start_time_sec = time.time()
+ stats = resnet_cifar_main.run(FLAGS)
+ wall_time_sec = time.time() - start_time_sec
+
+ super(Resnet56KerasBenchmarkBase, self)._report_benchmark(
+ stats,
+ wall_time_sec,
+ total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_1_gpu(self):
+ """Test 1 gpu."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
+ FLAGS.batch_size = 128
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_xla(self):
+ """Test 1 gpu with xla enabled."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.run_eagerly = False
+ FLAGS.enable_xla = True
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_xla')
+ FLAGS.batch_size = 128
+ self._run_and_report_benchmark()
+
+ def benchmark_graph_1_gpu(self):
+ """Test 1 gpu graph."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = False
+ FLAGS.run_eagerly = False
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu')
+ FLAGS.batch_size = 128
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_dist_strat(self):
+ """Test 1 gpu without distribution strategies."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
+ FLAGS.batch_size = 128
+ self._run_and_report_benchmark()
+
+ def benchmark_graph_1_gpu_no_dist_strat(self):
+ """Test 1 gpu graph mode without distribution strategies."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = False
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu_no_dist_strat')
+ FLAGS.batch_size = 128
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
+ """Test 1 gpu without distribution strategy and forced eager."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.batch_size = 128
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly')
+ FLAGS.dtype = 'fp32'
+ FLAGS.enable_eager = True
+ FLAGS.run_eagerly = True
+ FLAGS.distribution_strategy = 'off'
+ self._run_and_report_benchmark()
+
+ def benchmark_2_gpu(self):
+ """Test 2 gpu."""
+ self._setup()
+ FLAGS.num_gpus = 2
+ FLAGS.enable_eager = True
+ FLAGS.run_eagerly = False
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu')
+ FLAGS.batch_size = 128 * 2 # 2 GPUs
+ self._run_and_report_benchmark()
+
+ def benchmark_graph_2_gpu(self):
+ """Test 2 gpu graph mode."""
+ self._setup()
+ FLAGS.num_gpus = 2
+ FLAGS.enable_eager = False
+ FLAGS.run_eagerly = False
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir('benchmark_graph_2_gpu')
+ FLAGS.batch_size = 128 * 2 # 2 GPUs
+ self._run_and_report_benchmark()
+
+ def benchmark_cpu(self):
+ """Test cpu."""
+ self._setup()
+ FLAGS.num_gpus = 0
+ FLAGS.enable_eager = True
+ FLAGS.model_dir = self._get_model_dir('benchmark_cpu')
+ FLAGS.batch_size = 128
+ FLAGS.data_format = 'channels_last'
+ self._run_and_report_benchmark()
+
+ def benchmark_graph_cpu(self):
+ """Test cpu graph mode."""
+ self._setup()
+ FLAGS.num_gpus = 0
+ FLAGS.enable_eager = False
+ FLAGS.model_dir = self._get_model_dir('benchmark_graph_cpu')
+ FLAGS.batch_size = 128
+ FLAGS.data_format = 'channels_last'
+ self._run_and_report_benchmark()
+
+ def benchmark_cpu_no_dist_strat_run_eagerly(self):
+ """Test cpu without distribution strategy and forced eager."""
+ self._setup()
+ FLAGS.num_gpus = 0
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.enable_eager = True
+ FLAGS.run_eagerly = True
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_cpu_no_dist_strat_run_eagerly')
+ FLAGS.batch_size = 128
+ FLAGS.data_format = 'channels_last'
+ self._run_and_report_benchmark()
+
+ def benchmark_cpu_no_dist_strat(self):
+ """Test cpu without distribution strategies."""
+ self._setup()
+ FLAGS.num_gpus = 0
+ FLAGS.enable_eager = True
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.model_dir = self._get_model_dir('benchmark_cpu_no_dist_strat')
+ FLAGS.batch_size = 128
+ FLAGS.data_format = 'channels_last'
+ self._run_and_report_benchmark()
+
+ def benchmark_graph_cpu_no_dist_strat(self):
+ """Test cpu graph mode without distribution strategies."""
+ self._setup()
+ FLAGS.num_gpus = 0
+ FLAGS.enable_eager = False
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.model_dir = self._get_model_dir('benchmark_graph_cpu_no_dist_strat')
+ FLAGS.batch_size = 128
+ FLAGS.data_format = 'channels_last'
+ self._run_and_report_benchmark()
+
+
+class Resnet56KerasBenchmarkSynth(Resnet56KerasBenchmarkBase):
+ """Synthetic benchmarks for ResNet56 and Keras."""
+
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
+ default_flags = {}
+ default_flags['skip_eval'] = True
+ default_flags['use_synthetic_data'] = True
+ default_flags['train_steps'] = 110
+ default_flags['log_steps'] = 10
+ default_flags['use_tensor_lr'] = False
+
+ super(Resnet56KerasBenchmarkSynth, self).__init__(
+ output_dir=output_dir, default_flags=default_flags)
+
+
+class Resnet56KerasBenchmarkReal(Resnet56KerasBenchmarkBase):
+ """Real data benchmarks for ResNet56 and Keras."""
+
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
+ default_flags = {}
+ default_flags['skip_eval'] = True
+ default_flags['data_dir'] = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
+ default_flags['train_steps'] = 110
+ default_flags['log_steps'] = 10
+ default_flags['use_tensor_lr'] = False
+
+ super(Resnet56KerasBenchmarkReal, self).__init__(
+ output_dir=output_dir, default_flags=default_flags)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/benchmark/keras_imagenet_benchmark.py b/models/official/benchmark/keras_imagenet_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..63a48dfb1222b65311652e3bee4241854a55043e
--- /dev/null
+++ b/models/official/benchmark/keras_imagenet_benchmark.py
@@ -0,0 +1,1724 @@
+# Lint as: python3
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Executes Keras benchmarks and accuracy tests."""
+# pylint: disable=line-too-long
+from __future__ import print_function
+
+import json
+import os
+import time
+
+from typing import Any, MutableMapping, Optional
+
+from absl import flags
+import tensorflow as tf # pylint: disable=g-bad-import-order
+
+from official.benchmark import benchmark_wrappers
+from official.benchmark import keras_benchmark
+from official.benchmark.models import resnet_imagenet_main
+from official.vision.image_classification import classifier_trainer
+
+MIN_TOP_1_ACCURACY = 0.76
+MAX_TOP_1_ACCURACY = 0.77
+
+MOBILENET_V1_MIN_TOP_1_ACCURACY = 0.65
+MOBILENET_V1_MAX_TOP_1_ACCURACY = 0.68
+
+# Range of top-1 accuracies for model optimization techniques.
+# Each item indicates (MIN_TOP_1_ACCURACY, MAX_TOP_1_ACCURACY).
+MODEL_OPTIMIZATION_TOP_1_ACCURACY = {
+ 'RESNET50_FINETUNE_PRUNING': (0.76, 0.77),
+ 'MOBILENET_V1_FINETUNE_PRUNING': (0.67, 0.68),
+}
+
+FLAGS = flags.FLAGS
+
+
+def _get_classifier_parameters(
+ num_gpus: int = 0,
+ builder: str = 'records',
+ skip_eval: bool = False,
+ distribution_strategy: str = 'mirrored',
+ per_replica_batch_size: int = 128,
+ epochs: int = 90,
+ steps: int = 0,
+ epochs_between_evals: int = 1,
+ dtype: str = 'float32',
+ enable_xla: bool = False,
+ run_eagerly: bool = False,
+ gpu_thread_mode: Optional[str] = None,
+ dataset_num_private_threads: Optional[int] = None,
+ loss_scale: Optional[str] = None,
+ report_metrics: bool = True,
+ batchnorm_spatial_persistent: bool = False) -> MutableMapping[str, Any]:
+ """Gets classifier trainer's ResNet parameters."""
+ return {
+ 'runtime': {
+ 'num_gpus': num_gpus,
+ 'distribution_strategy': distribution_strategy,
+ 'run_eagerly': run_eagerly,
+ 'enable_xla': enable_xla,
+ 'dataset_num_private_threads': dataset_num_private_threads,
+ 'gpu_thread_mode': gpu_thread_mode,
+ 'loss_scale': loss_scale,
+ 'batchnorm_spatial_persistent': batchnorm_spatial_persistent,
+ },
+ 'train_dataset': {
+ 'builder': builder,
+ 'use_per_replica_batch_size': True,
+ 'batch_size': per_replica_batch_size,
+ 'image_size': 224,
+ 'dtype': dtype,
+ },
+ 'validation_dataset': {
+ 'builder': builder,
+ 'batch_size': per_replica_batch_size,
+ 'use_per_replica_batch_size': True,
+ 'image_size': 224,
+ 'dtype': dtype,
+ },
+ 'train': {
+ 'epochs': epochs,
+ 'steps': steps,
+ 'callbacks': {
+ 'enable_tensorboard': False,
+ 'enable_checkpoint_and_export': False,
+ 'enable_time_history': True,
+ },
+ 'metrics': ['accuracy'] if report_metrics else [],
+ },
+ 'model': {
+ 'loss': {
+ 'label_smoothing': 0.1,
+ },
+ },
+ 'evaluation': {
+ 'epochs_between_evals': epochs_between_evals,
+ 'skip_eval': skip_eval,
+ },
+ }
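
These nested parameters are not consumed directly by the benchmark methods; as the classes below show, the dict is JSON-encoded into the classifier trainer's params_override flag before classifier_trainer.run() is called. A minimal sketch of that hand-off (the flag values chosen here are arbitrary examples):

# Minimal sketch: serializing the parameter dict into FLAGS.params_override,
# mirroring how the benchmarks below configure classifier_trainer.run().
params = _get_classifier_parameters(
    num_gpus=1,
    distribution_strategy='one_device',
    per_replica_batch_size=128,
    dtype='float16')
FLAGS.params_override = json.dumps(params)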
+
+
+class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
+ """Benchmark accuracy tests for ResNet50 in Keras."""
+
+ def __init__(self,
+ output_dir: Optional[str] = None,
+ root_data_dir: Optional[str] = None,
+ **kwargs):
+ """A benchmark class.
+
+ Args:
+ output_dir: directory where to output e.g. log files
+ root_data_dir: directory under which to look for dataset
+ **kwargs: arbitrary named arguments. This is needed to make the
+ constructor forward compatible in case PerfZero provides more
+ named arguments before updating the constructor.
+ """
+
+ flag_methods = [classifier_trainer.define_classifier_flags]
+
+ self.data_dir = os.path.join(root_data_dir, 'imagenet')
+ super(Resnet50KerasAccuracy, self).__init__(
+ output_dir=output_dir, flag_methods=flag_methods)
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(
+ self,
+ experiment_name: str,
+ top_1_min: float = MIN_TOP_1_ACCURACY,
+ top_1_max: float = MAX_TOP_1_ACCURACY,
+ num_gpus: int = 0,
+ distribution_strategy: str = 'mirrored',
+ per_replica_batch_size: int = 128,
+ epochs: int = 90,
+ steps: int = 0,
+ epochs_between_evals: int = 1,
+ dtype: str = 'float32',
+ enable_xla: bool = False,
+ run_eagerly: bool = False,
+ gpu_thread_mode: Optional[str] = None,
+ dataset_num_private_threads: Optional[int] = None,
+ loss_scale: Optional[str] = None):
+ """Runs and reports the benchmark given the provided configuration."""
+ FLAGS.model_type = 'resnet'
+ FLAGS.dataset = 'imagenet'
+ FLAGS.mode = 'train_and_eval'
+ FLAGS.data_dir = self.data_dir
+ FLAGS.model_dir = self._get_model_dir(experiment_name)
+ parameters = _get_classifier_parameters(
+ num_gpus=num_gpus,
+ distribution_strategy=distribution_strategy,
+ per_replica_batch_size=per_replica_batch_size,
+ epochs=epochs,
+ steps=steps,
+ epochs_between_evals=epochs_between_evals,
+ dtype=dtype,
+ enable_xla=enable_xla,
+ run_eagerly=run_eagerly,
+ gpu_thread_mode=gpu_thread_mode,
+ dataset_num_private_threads=dataset_num_private_threads,
+ report_metrics=True,
+ loss_scale=loss_scale,
+ batchnorm_spatial_persistent=True)
+ FLAGS.params_override = json.dumps(parameters)
+ total_batch_size = num_gpus * per_replica_batch_size
+
+ start_time_sec = time.time()
+ stats = classifier_trainer.run(flags.FLAGS)
+ wall_time_sec = time.time() - start_time_sec
+
+ super(Resnet50KerasAccuracy, self)._report_benchmark(
+ stats,
+ wall_time_sec,
+ top_1_min=top_1_min,
+ top_1_max=top_1_max,
+ total_batch_size=total_batch_size,
+ log_steps=100)
+
+ def benchmark_8_gpu(self):
+ """Tests Keras model with eager, dist_strat and 8 GPUs."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_8_gpu',
+ num_gpus=8,
+ per_replica_batch_size=128,
+ epochs=90,
+ epochs_between_evals=10,
+ dtype='float32')
+
+ def benchmark_8_gpu_fp16(self):
+ """Tests Keras model with eager, dist_strat, 8 GPUs, and fp16."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_8_gpu_fp16',
+ num_gpus=8,
+ per_replica_batch_size=256,
+ epochs=90,
+ epochs_between_evals=10,
+ dtype='float16')
+
+ def benchmark_xla_8_gpu_fp16(self):
+ """Tests Keras model with XLA, eager, dist_strat, 8 GPUs and fp16."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_xla_8_gpu_fp16',
+ num_gpus=8,
+ per_replica_batch_size=256,
+ epochs=90,
+ epochs_between_evals=10,
+ dtype='float16',
+ enable_xla=True)
+
+ def benchmark_xla_8_gpu_fp16_dynamic(self):
+ """Tests Keras model with XLA, eager, dist_strat, 8 GPUs, dynamic fp16."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_xla_8_gpu_fp16_dynamic',
+ top_1_min=0.736,
+ num_gpus=8,
+ per_replica_batch_size=256,
+ epochs=90,
+ epochs_between_evals=10,
+ dtype='float16',
+ loss_scale='dynamic')
+
+ def _get_model_dir(self, folder_name):
+ return os.path.join(self.output_dir, folder_name)
+
+
+class MobilenetV1KerasAccuracy(keras_benchmark.KerasBenchmark):
+ """Benchmark accuracy tests for MobilenetV1 in Keras."""
+
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
+ """A benchmark class.
+
+ Args:
+ output_dir: directory where to output e.g. log files
+ root_data_dir: directory under which to look for dataset
+ **kwargs: arbitrary named arguments. This is needed to make the
+ constructor forward compatible in case PerfZero provides more
+ named arguments before updating the constructor.
+ """
+
+ flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
+
+ self.data_dir = os.path.join(root_data_dir, 'imagenet')
+ super(MobilenetV1KerasAccuracy, self).__init__(
+ output_dir=output_dir,
+ flag_methods=flag_methods,
+ default_flags={
+ 'model': 'mobilenet',
+ 'optimizer': 'mobilenet_default',
+ 'initial_learning_rate_per_sample': 0.00039,
+ })
+
+ def benchmark_8_gpu(self):
+ """Test Keras model with eager, dist_strat and 8 GPUs."""
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.data_dir = self.data_dir
+ FLAGS.batch_size = 128 * 8
+ FLAGS.train_epochs = 90
+ FLAGS.epochs_between_evals = 10
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
+ FLAGS.dtype = 'fp32'
+ FLAGS.enable_eager = True
+ self._run_and_report_benchmark()
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self,
+ top_1_min=MOBILENET_V1_MIN_TOP_1_ACCURACY,
+ top_1_max=MOBILENET_V1_MAX_TOP_1_ACCURACY):
+ start_time_sec = time.time()
+ stats = resnet_imagenet_main.run(flags.FLAGS)
+ wall_time_sec = time.time() - start_time_sec
+
+ super(MobilenetV1KerasAccuracy, self)._report_benchmark(
+ stats,
+ wall_time_sec,
+ top_1_min=top_1_min,
+ top_1_max=top_1_max,
+ total_batch_size=FLAGS.batch_size,
+ log_steps=100)
+
+ def _get_model_dir(self, folder_name):
+ return os.path.join(self.output_dir, folder_name)
+
+
+class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
+ """Resnet50 (classifier_trainer) benchmarks."""
+
+ def __init__(self, output_dir=None, default_flags=None,
+ tpu=None, dataset_builder='records', train_epochs=1,
+ train_steps=110, data_dir=None):
+ flag_methods = [classifier_trainer.define_classifier_flags]
+
+ self.dataset_builder = dataset_builder
+ self.train_epochs = train_epochs
+ self.train_steps = train_steps
+ self.data_dir = data_dir
+
+ super(Resnet50KerasClassifierBenchmarkBase, self).__init__(
+ output_dir=output_dir,
+ flag_methods=flag_methods,
+ default_flags=default_flags,
+ tpu=tpu)
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(
+ self,
+ experiment_name: str,
+ skip_steps: Optional[int] = None,
+ top_1_min: float = MIN_TOP_1_ACCURACY,
+ top_1_max: float = MAX_TOP_1_ACCURACY,
+ num_gpus: int = 0,
+ num_tpus: int = 0,
+ distribution_strategy: str = 'mirrored',
+ per_replica_batch_size: int = 128,
+ epochs_between_evals: int = 1,
+ dtype: str = 'float32',
+ enable_xla: bool = False,
+ run_eagerly: bool = False,
+ gpu_thread_mode: Optional[str] = None,
+ dataset_num_private_threads: Optional[int] = None,
+ loss_scale: Optional[str] = None):
+ """Runs and reports the benchmark given the provided configuration."""
+ FLAGS.model_type = 'resnet'
+ FLAGS.dataset = 'imagenet'
+ FLAGS.mode = 'train_and_eval'
+ FLAGS.data_dir = self.data_dir
+ FLAGS.model_dir = self._get_model_dir(experiment_name)
+ parameters = _get_classifier_parameters(
+ builder=self.dataset_builder,
+ skip_eval=True,
+ num_gpus=num_gpus,
+ distribution_strategy=distribution_strategy,
+ per_replica_batch_size=per_replica_batch_size,
+ epochs=self.train_epochs,
+ steps=self.train_steps,
+ epochs_between_evals=epochs_between_evals,
+ dtype=dtype,
+ enable_xla=enable_xla,
+ gpu_thread_mode=gpu_thread_mode,
+ dataset_num_private_threads=dataset_num_private_threads,
+ loss_scale=loss_scale,
+ report_metrics=False,
+ batchnorm_spatial_persistent=True)
+ FLAGS.params_override = json.dumps(parameters)
+ if distribution_strategy == 'tpu':
+ total_batch_size = num_tpus * per_replica_batch_size
+ else:
+ total_batch_size = num_gpus * per_replica_batch_size
+
+ start_time_sec = time.time()
+ stats = classifier_trainer.run(flags.FLAGS)
+ wall_time_sec = time.time() - start_time_sec
+    # Number of logged step time entries that are excluded from the performance
+    # report. We keep results from the last 100 batches, or skip steps based on
+    # the input skip_steps.
+ warmup = (skip_steps or (self.train_steps - 100)) // FLAGS.log_steps
+
+ super(Resnet50KerasClassifierBenchmarkBase, self)._report_benchmark(
+ stats,
+ wall_time_sec,
+ total_batch_size=total_batch_size,
+ log_steps=FLAGS.log_steps,
+ warmup=warmup,
+ start_time_sec=start_time_sec)
+
+ def benchmark_1_gpu_no_dist_strat(self):
+ """Tests Keras model with 1 GPU, no distribution strategy."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_1_gpu_no_dist_strat',
+ num_gpus=1,
+ distribution_strategy='off',
+ per_replica_batch_size=128)
+
+ def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
+ """Tests Keras model with 1 GPU, no distribution strategy, run eagerly."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_1_gpu_no_dist_strat_run_eagerly',
+ num_gpus=1,
+ run_eagerly=True,
+ distribution_strategy='off',
+ per_replica_batch_size=64)
+
+ def benchmark_1_gpu_no_dist_strat_run_eagerly_fp16(self):
+ """Tests with 1 GPU, no distribution strategy, fp16, run eagerly."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_1_gpu_no_dist_strat_run_eagerly_fp16',
+ num_gpus=1,
+ run_eagerly=True,
+ distribution_strategy='off',
+ dtype='float16',
+ per_replica_batch_size=128)
+
+ def benchmark_1_gpu(self):
+ """Tests Keras model with 1 GPU."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_1_gpu',
+ num_gpus=1,
+ distribution_strategy='one_device',
+ per_replica_batch_size=128)
+
+ def benchmark_xla_1_gpu(self):
+ """Tests Keras model with XLA and 1 GPU."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_xla_1_gpu',
+ num_gpus=1,
+ enable_xla=True,
+ distribution_strategy='one_device',
+ per_replica_batch_size=128)
+
+ def benchmark_1_gpu_fp16(self):
+ """Tests Keras model with 1 GPU and fp16."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_1_gpu_fp16',
+ num_gpus=1,
+ distribution_strategy='one_device',
+ dtype='float16',
+ per_replica_batch_size=256)
+
+ def benchmark_1_gpu_fp16_dynamic(self):
+ """Tests Keras model with 1 GPU, fp16, and dynamic loss scaling."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_1_gpu_fp16_dynamic',
+ num_gpus=1,
+ distribution_strategy='one_device',
+ dtype='float16',
+ per_replica_batch_size=256,
+ loss_scale='dynamic')
+
+ def benchmark_xla_1_gpu_fp16(self):
+ """Tests Keras model with XLA, 1 GPU and fp16."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_xla_1_gpu_fp16',
+ num_gpus=1,
+ enable_xla=True,
+ distribution_strategy='one_device',
+ dtype='float16',
+ per_replica_batch_size=256)
+
+ def benchmark_xla_1_gpu_fp16_tweaked(self):
+ """Tests Keras model with XLA, 1 GPU, fp16, and manual config tuning."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_xla_1_gpu_fp16_tweaked',
+ num_gpus=1,
+ enable_xla=True,
+ distribution_strategy='one_device',
+ dtype='float16',
+ per_replica_batch_size=256,
+ gpu_thread_mode='gpu_private')
+
+ def benchmark_xla_1_gpu_fp16_dynamic(self):
+ """Tests Keras model with XLA, 1 GPU, fp16, and dynamic loss scaling."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_xla_1_gpu_fp16_dynamic',
+ num_gpus=1,
+ enable_xla=True,
+ distribution_strategy='one_device',
+ dtype='float16',
+ per_replica_batch_size=256,
+ loss_scale='dynamic')
+
+ def benchmark_8_gpu(self):
+ """Tests Keras model with 8 GPUs."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_8_gpu',
+ num_gpus=8,
+ distribution_strategy='mirrored',
+ per_replica_batch_size=128)
+
+ def benchmark_8_gpu_tweaked(self):
+ """Tests Keras model with manual config tuning and 8 GPUs."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_8_gpu_tweaked',
+ num_gpus=8,
+ distribution_strategy='mirrored',
+ per_replica_batch_size=128,
+ dataset_num_private_threads=14)
+
+ def benchmark_xla_8_gpu(self):
+ """Tests Keras model with XLA and 8 GPUs."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_xla_8_gpu',
+ num_gpus=8,
+ enable_xla=True,
+ distribution_strategy='mirrored',
+ per_replica_batch_size=128)
+
+ def benchmark_xla_8_gpu_tweaked(self):
+ """Tests Keras model with manual config tuning, 8 GPUs, and XLA."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_xla_8_gpu_tweaked',
+ num_gpus=8,
+ enable_xla=True,
+ distribution_strategy='mirrored',
+ per_replica_batch_size=128,
+ gpu_thread_mode='gpu_private',
+ dataset_num_private_threads=24)
+
+ def benchmark_8_gpu_fp16(self):
+ """Tests Keras model with 8 GPUs and fp16."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_8_gpu_fp16',
+ num_gpus=8,
+ dtype='float16',
+ distribution_strategy='mirrored',
+ per_replica_batch_size=256)
+
+ def benchmark_8_gpu_fp16_tweaked(self):
+ """Tests Keras model with 8 GPUs, fp16, and manual config tuning."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_8_gpu_fp16_tweaked',
+ num_gpus=8,
+ dtype='float16',
+ distribution_strategy='mirrored',
+ per_replica_batch_size=256,
+ gpu_thread_mode='gpu_private',
+ dataset_num_private_threads=40)
+
+ def benchmark_8_gpu_fp16_dynamic_tweaked(self):
+ """Tests Keras model with 8 GPUs, fp16, dynamic loss scaling, and tuned."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_8_gpu_fp16_dynamic_tweaked',
+ num_gpus=8,
+ dtype='float16',
+ distribution_strategy='mirrored',
+ per_replica_batch_size=256,
+ loss_scale='dynamic',
+ gpu_thread_mode='gpu_private',
+ dataset_num_private_threads=40)
+
+ def benchmark_xla_8_gpu_fp16(self):
+ """Tests Keras model with XLA, 8 GPUs and fp16."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_xla_8_gpu_fp16',
+ dtype='float16',
+ num_gpus=8,
+ enable_xla=True,
+ distribution_strategy='mirrored',
+ per_replica_batch_size=256)
+
+ def benchmark_xla_8_gpu_fp16_tweaked(self):
+ """Test Keras model with manual config tuning, XLA, 8 GPUs and fp16."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_xla_8_gpu_fp16_tweaked',
+ dtype='float16',
+ num_gpus=8,
+ enable_xla=True,
+ distribution_strategy='mirrored',
+ per_replica_batch_size=256,
+ gpu_thread_mode='gpu_private',
+ dataset_num_private_threads=48)
+
+ def benchmark_xla_8_gpu_fp16_tweaked_delay_measure(self):
+ """Tests with manual config tuning, XLA, 8 GPUs and fp16.
+
+ Delay performance measurement for stable performance on 96 vCPU platforms.
+ """
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_xla_8_gpu_fp16_tweaked_delay_measure',
+ dtype='float16',
+ num_gpus=8,
+ enable_xla=True,
+ distribution_strategy='mirrored',
+ per_replica_batch_size=256,
+ gpu_thread_mode='gpu_private',
+ dataset_num_private_threads=48,
+ steps=310)
+
+ def benchmark_xla_8_gpu_fp16_dynamic_tweaked(self):
+ """Tests Keras model with config tuning, XLA, 8 GPUs and dynamic fp16."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_xla_8_gpu_fp16_dynamic_tweaked',
+ dtype='float16',
+ num_gpus=8,
+ enable_xla=True,
+ distribution_strategy='mirrored',
+ per_replica_batch_size=256,
+ gpu_thread_mode='gpu_private',
+ loss_scale='dynamic',
+ dataset_num_private_threads=48)
+
+ def benchmark_2x2_tpu_bf16(self):
+ """Test Keras model with 2x2 TPU, bf16."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_2x2_tpu_bf16',
+ dtype='bfloat16',
+ num_tpus=8,
+ distribution_strategy='tpu',
+ per_replica_batch_size=128)
+
+ def benchmark_4x4_tpu_bf16(self):
+ """Test Keras model with 4x4 TPU, bf16."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_4x4_tpu_bf16',
+ dtype='bfloat16',
+ num_tpus=32,
+ distribution_strategy='tpu',
+ per_replica_batch_size=128)
+
+ def benchmark_8x8_tpu_bf16(self):
+ """Test Keras model with 8x8 TPU, bf16."""
+ self._setup()
+ self._run_and_report_benchmark(
+ experiment_name='benchmark_8x8_tpu_bf16',
+ dtype='bfloat16',
+ num_tpus=128,
+ distribution_strategy='tpu',
+ per_replica_batch_size=64)
+
+ def fill_report_object(self, stats):
+ super(Resnet50KerasClassifierBenchmarkBase, self).fill_report_object(
+ stats,
+ total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+
+class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
+ """Resnet50 benchmarks."""
+
+ def __init__(self, output_dir=None, default_flags=None, tpu=None):
+ flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
+
+ super(Resnet50KerasBenchmarkBase, self).__init__(
+ output_dir=output_dir,
+ flag_methods=flag_methods,
+ default_flags=default_flags,
+ tpu=tpu)
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self, skip_steps=None):
+ start_time_sec = time.time()
+ stats = resnet_imagenet_main.run(FLAGS)
+ wall_time_sec = time.time() - start_time_sec
+    # Number of logged step time entries that are excluded from the performance
+    # report. We keep results from the last 100 batches, or skip steps based on
+    # the input skip_steps.
+ warmup = (skip_steps or (FLAGS.train_steps - 100)) // FLAGS.log_steps
+
+ super(Resnet50KerasBenchmarkBase, self)._report_benchmark(
+ stats,
+ wall_time_sec,
+ total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps,
+ warmup=warmup,
+ start_time_sec=start_time_sec)
+
+ def benchmark_1_gpu_no_dist_strat(self):
+ """Test Keras model with 1 GPU, no distribution strategy."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
+ FLAGS.batch_size = 128
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
+ """Test Keras model with 1 GPU, no distribution strategy, run eagerly."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.run_eagerly = True
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly')
+ FLAGS.batch_size = 64
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_dist_strat_run_eagerly_tweaked(self):
+ """Test Keras model with 1 GPU, no distribution strategy, run eagerly."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.run_eagerly = True
+ FLAGS.explicit_gpu_placement = True
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly_tweaked')
+ FLAGS.batch_size = 64
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_dist_strat_run_eagerly_fp16(self):
+ """Test with 1 GPU, no distribution strategy, fp16, run eagerly."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.run_eagerly = True
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly_fp16')
+ FLAGS.dtype = 'fp16'
+ FLAGS.batch_size = 128
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_dist_strat_run_eagerly_fp16_tweaked(self):
+ """Test with 1 GPU, no distribution strategy, fp16, run eagerly."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.run_eagerly = True
+ FLAGS.explicit_gpu_placement = True
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly_fp16_tweaked')
+ FLAGS.dtype = 'fp16'
+ FLAGS.batch_size = 128
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu(self):
+ """Test Keras model with 1 GPU."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
+ FLAGS.batch_size = 128
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_amp(self):
+ """Test Keras model with 1 GPU with automatic mixed precision."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.dtype = 'fp16'
+ FLAGS.fp16_implementation = 'graph_rewrite'
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp')
+ FLAGS.batch_size = 256
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_1_gpu(self):
+ """Test Keras model with XLA and 1 GPU."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.enable_xla = True
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu')
+ FLAGS.batch_size = 128
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_1_gpu_amp(self):
+ """Test Keras model with XLA and 1 GPU with automatic mixed precision."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.dtype = 'fp16'
+ FLAGS.fp16_implementation = 'graph_rewrite'
+ FLAGS.enable_xla = True
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_amp')
+ FLAGS.batch_size = 256
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_fp16(self):
+ """Test Keras model with 1 GPU and fp16."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16')
+ FLAGS.dtype = 'fp16'
+ FLAGS.batch_size = 256
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_fp16_dynamic(self):
+ """Test Keras model with 1 GPU, fp16, and dynamic loss scaling."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16_dynamic')
+ FLAGS.dtype = 'fp16'
+ FLAGS.batch_size = 256
+ FLAGS.loss_scale = 'dynamic'
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_1_gpu_fp16(self):
+ """Test Keras model with XLA, 1 GPU and fp16."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.enable_xla = True
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16')
+ FLAGS.dtype = 'fp16'
+ FLAGS.batch_size = 256
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_1_gpu_fp16_tweaked(self):
+ """Test Keras model with XLA, 1 GPU, fp16, and manual config tuning."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.enable_xla = True
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16_tweaked')
+ FLAGS.dtype = 'fp16'
+ FLAGS.batch_size = 256
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_1_gpu_fp16_dynamic(self):
+ """Test Keras model with XLA, 1 GPU, fp16, and dynamic loss scaling."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.enable_xla = True
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16_dynamic')
+ FLAGS.dtype = 'fp16'
+ FLAGS.batch_size = 256
+ FLAGS.loss_scale = 'dynamic'
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu(self):
+ """Test Keras model with 8 GPUs."""
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.enable_eager = True
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
+ FLAGS.batch_size = 128 * 8 # 8 GPUs
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu_amp(self):
+ """Test Keras model with 8 GPUs with automatic mixed precision."""
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.enable_eager = True
+ FLAGS.dtype = 'fp16'
+ FLAGS.fp16_implementation = 'graph_rewrite'
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
+ FLAGS.batch_size = 256 * 8 # 8 GPUs
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu_tweaked(self):
+ """Test Keras model with manual config tuning and 8 GPUs."""
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.enable_eager = True
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_tweaked')
+ FLAGS.batch_size = 128 * 8 # 8 GPUs
+ FLAGS.datasets_num_private_threads = 14
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_8_gpu(self):
+ """Test Keras model with XLA and 8 GPUs."""
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.enable_eager = True
+ FLAGS.enable_xla = True
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu')
+ FLAGS.batch_size = 128 * 8 # 8 GPUs
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_8_gpu_amp(self):
+ """Test Keras model with XLA and 8 GPUs with automatic mixed precision."""
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.enable_eager = True
+ FLAGS.dtype = 'fp16'
+ FLAGS.fp16_implementation = 'graph_rewrite'
+ FLAGS.enable_xla = True
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_amp')
+ FLAGS.batch_size = 256 * 8 # 8 GPUs
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_8_gpu_tweaked(self):
+ """Test Keras model with manual config tuning, 8 GPUs, and XLA."""
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.enable_eager = True
+ FLAGS.enable_xla = True
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_tweaked')
+ FLAGS.batch_size = 128 * 8
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
+ FLAGS.datasets_num_private_threads = 24
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu_fp16(self):
+ """Test Keras model with 8 GPUs and fp16."""
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.dtype = 'fp16'
+ FLAGS.enable_eager = True
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
+ FLAGS.batch_size = 256 * 8 # 8 GPUs
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu_fp16_tweaked(self):
+ """Test Keras model with 8 GPUs, fp16, and manual config tuning."""
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.dtype = 'fp16'
+ FLAGS.enable_eager = True
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_tweaked')
+ FLAGS.batch_size = 256 * 8 # 8 GPUs
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
+    FLAGS.datasets_num_private_threads = 40
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu_fp16_dynamic_tweaked(self):
+ """Test Keras model with 8 GPUs, fp16, dynamic loss scaling, and tuned."""
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.dtype = 'fp16'
+ FLAGS.enable_eager = True
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_8_gpu_fp16_dynamic_tweaked')
+ FLAGS.batch_size = 256 * 8 # 8 GPUs
+ FLAGS.loss_scale = 'dynamic'
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
+    FLAGS.datasets_num_private_threads = 40
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_8_gpu_fp16(self):
+ """Test Keras model with XLA, 8 GPUs and fp16."""
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.dtype = 'fp16'
+ FLAGS.enable_eager = True
+ FLAGS.enable_xla = True
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16')
+ FLAGS.batch_size = 256 * 8 # 8 GPUs
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_8_gpu_fp16_tweaked(self):
+ """Test Keras model with manual config tuning, XLA, 8 GPUs and fp16."""
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.dtype = 'fp16'
+ FLAGS.enable_eager = True
+ FLAGS.enable_xla = True
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16_tweaked')
+ FLAGS.batch_size = 256 * 8 # 8 GPUs
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
+ FLAGS.datasets_num_private_threads = 48
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_8_gpu_fp16_tweaked_delay_measure(self):
+ """Test with manual config tuning, XLA, 8 GPUs and fp16.
+
+ Delay performance measurement for stable performance on 96 vCPU platforms.
+ """
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.dtype = 'fp16'
+ FLAGS.enable_eager = True
+ FLAGS.enable_xla = True
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_xla_8_gpu_fp16_tweaked_delay_measure')
+ FLAGS.batch_size = 256 * 8
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
+ FLAGS.datasets_num_private_threads = 48
+ FLAGS.train_steps = 310
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_8_gpu_fp16_dynamic_tweaked(self):
+ """Test Keras model with config tuning, XLA, 8 GPUs and dynamic fp16."""
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.dtype = 'fp16'
+ FLAGS.enable_eager = True
+ FLAGS.enable_xla = True
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_xla_8_gpu_fp16_dynamic_tweaked')
+ FLAGS.batch_size = 256 * 8 # 8 GPUs
+ FLAGS.loss_scale = 'dynamic'
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
+ FLAGS.datasets_num_private_threads = 48
+ self._run_and_report_benchmark()
+
+ def benchmark_2x2_tpu_bf16(self):
+ """Test Keras model with 2x2 TPU, bf16."""
+ self._setup()
+
+ FLAGS.dtype = 'bf16'
+ FLAGS.distribution_strategy = 'tpu'
+ FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16')
+ FLAGS.batch_size = 1024
+ self._run_and_report_benchmark()
+
+ def benchmark_4x4_tpu_bf16(self):
+ """Test Keras model with 4x4 TPU, bf16."""
+ self._setup()
+
+ FLAGS.dtype = 'bf16'
+ FLAGS.distribution_strategy = 'tpu'
+ FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16')
+ FLAGS.batch_size = 4096
+ self._run_and_report_benchmark()
+
+ def benchmark_8x8_tpu_bf16(self):
+ """Test Keras model with 8x8 TPU, bf16."""
+ self._setup()
+
+ FLAGS.dtype = 'bf16'
+ FLAGS.distribution_strategy = 'tpu'
+ FLAGS.model_dir = self._get_model_dir('benchmark_8x8_tpu_bf16')
+ FLAGS.batch_size = 8192
+ self._run_and_report_benchmark()
+
+ def fill_report_object(self, stats):
+ super(Resnet50KerasBenchmarkBase, self).fill_report_object(
+ stats,
+ total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+
+class Resnet50KerasBenchmarkSynth(Resnet50KerasClassifierBenchmarkBase):
+ """Resnet50 synthetic benchmark tests."""
+
+ def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
+ def_flags = {}
+ def_flags['log_steps'] = 10
+
+ super(Resnet50KerasBenchmarkSynth, self).__init__(
+ output_dir=output_dir, default_flags=def_flags, tpu=tpu,
+ dataset_builder='synthetic', train_epochs=1, train_steps=110)
+
+
+class Resnet50KerasBenchmarkReal(Resnet50KerasClassifierBenchmarkBase):
+ """Resnet50 real data benchmark tests."""
+
+ def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
+ data_dir = os.path.join(root_data_dir, 'imagenet')
+ def_flags = {}
+ def_flags['log_steps'] = 10
+
+ super(Resnet50KerasBenchmarkReal, self).__init__(
+ output_dir=output_dir, default_flags=def_flags, tpu=tpu,
+ dataset_builder='records', train_epochs=1, train_steps=110,
+ data_dir=data_dir)
+
+
+class Resnet50KerasBenchmarkRemoteData(Resnet50KerasBenchmarkBase):
+ """Resnet50 real data (stored in remote storage) benchmark tests."""
+
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
+ def_flags = {}
+ def_flags['skip_eval'] = True
+ def_flags['report_accuracy_metrics'] = False
+ def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
+ # Defining multiple epochs overrides the train_steps setting in benchmarks.
+ def_flags['train_epochs'] = 2
+ # Cache dataset so performance is stable after the first epoch.
+ def_flags['training_dataset_cache'] = True
+ def_flags['log_steps'] = 100
+    # Note that single GPU and pure eager tests, which are less likely to be
+    # input bound and are more stable, run for a shorter time by overriding
+    # FLAGS.train_epochs, train_steps, and log_steps in the benchmark methods,
+    # and skip_steps in _run_and_report_benchmark().
+
+ super(Resnet50KerasBenchmarkRemoteData, self).__init__(
+ output_dir=output_dir, default_flags=def_flags)
+
+ def _override_flags_to_run_test_shorter(self):
+ FLAGS.train_epochs = 1
+ FLAGS.train_steps = 300
+ FLAGS.log_steps = 10
+
+ def benchmark_1_gpu_no_dist_strat(self):
+ """Test Keras model with 1 GPU, no distribution strategy."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
+ FLAGS.batch_size = 128
+ self._override_flags_to_run_test_shorter()
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
+ """Test Keras model with 1 GPU, no distribution strategy, run eagerly."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.run_eagerly = True
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly')
+ FLAGS.batch_size = 64
+ self._override_flags_to_run_test_shorter()
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_dist_strat_run_eagerly_tweaked(self):
+    """Test 1 GPU, no dist strat, run eagerly, with explicit GPU placement."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.run_eagerly = True
+ FLAGS.explicit_gpu_placement = True
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly_tweaked')
+ FLAGS.batch_size = 64
+ self._override_flags_to_run_test_shorter()
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_dist_strat_run_eagerly_fp16(self):
+ """Test with 1 GPU, no distribution strategy, fp16, run eagerly."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.run_eagerly = True
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly_fp16')
+ FLAGS.dtype = 'fp16'
+ FLAGS.batch_size = 128
+ self._override_flags_to_run_test_shorter()
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_dist_strat_run_eagerly_fp16_tweaked(self):
+    """Test 1 GPU, no dist strat, fp16, run eagerly, explicit GPU placement."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.run_eagerly = True
+ FLAGS.explicit_gpu_placement = True
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly_fp16_tweaked')
+ FLAGS.dtype = 'fp16'
+ FLAGS.batch_size = 128
+ self._override_flags_to_run_test_shorter()
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu(self):
+ """Test Keras model with 1 GPU."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
+ FLAGS.batch_size = 128
+ self._override_flags_to_run_test_shorter()
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_amp(self):
+ """Test Keras model with 1 GPU with automatic mixed precision."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.dtype = 'fp16'
+ FLAGS.fp16_implementation = 'graph_rewrite'
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp')
+ FLAGS.batch_size = 256
+ self._override_flags_to_run_test_shorter()
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_1_gpu(self):
+ """Test Keras model with XLA and 1 GPU."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.enable_xla = True
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu')
+ FLAGS.batch_size = 128
+ self._override_flags_to_run_test_shorter()
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_1_gpu_amp(self):
+ """Test Keras model with XLA and 1 GPU with automatic mixed precision."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.dtype = 'fp16'
+ FLAGS.fp16_implementation = 'graph_rewrite'
+ FLAGS.enable_xla = True
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_amp')
+ FLAGS.batch_size = 256
+ self._override_flags_to_run_test_shorter()
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_fp16(self):
+ """Test Keras model with 1 GPU and fp16."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16')
+ FLAGS.dtype = 'fp16'
+ FLAGS.batch_size = 256
+ self._override_flags_to_run_test_shorter()
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_fp16_dynamic(self):
+ """Test Keras model with 1 GPU, fp16, and dynamic loss scaling."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16_dynamic')
+ FLAGS.dtype = 'fp16'
+ FLAGS.batch_size = 256
+ FLAGS.loss_scale = 'dynamic'
+ self._override_flags_to_run_test_shorter()
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_1_gpu_fp16(self):
+ """Test Keras model with XLA, 1 GPU and fp16."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.enable_xla = True
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16')
+ FLAGS.dtype = 'fp16'
+ FLAGS.batch_size = 256
+ self._override_flags_to_run_test_shorter()
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_1_gpu_fp16_tweaked(self):
+ """Test Keras model with XLA, 1 GPU, fp16, and manual config tuning."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.enable_xla = True
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16_tweaked')
+ FLAGS.dtype = 'fp16'
+ FLAGS.batch_size = 256
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
+ self._override_flags_to_run_test_shorter()
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_1_gpu_fp16_dynamic(self):
+ """Test Keras model with XLA, 1 GPU, fp16, and dynamic loss scaling."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.enable_eager = True
+ FLAGS.enable_xla = True
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16_dynamic')
+ FLAGS.dtype = 'fp16'
+ FLAGS.batch_size = 256
+ FLAGS.loss_scale = 'dynamic'
+ self._override_flags_to_run_test_shorter()
+ self._run_and_report_benchmark()
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self):
+ if FLAGS.num_gpus == 1 or FLAGS.run_eagerly:
+ # For single GPU and pure eager tests which are less likely to be input
+ # bound and more stable, run for shorter time and use the default
+ # skip_steps.
+ skip_steps = None
+ else:
+ # skip the first epoch for performance measurement.
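+      # For the 8-GPU fp16 configs (global batch 256 * 8 = 2048), one ImageNet
+      # epoch is roughly 1,281,167 / 2048, i.e. about 625 steps, so skipping
+      # 600 steps skips approximately the first epoch.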
+ skip_steps = 600
+ super(Resnet50KerasBenchmarkRemoteData,
+ self)._run_and_report_benchmark(skip_steps=skip_steps)
+
+
+class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark):
+ """Trivial model with real data benchmark tests."""
+
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
+ flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
+
+ def_flags = {}
+ def_flags['use_trivial_model'] = True
+ def_flags['skip_eval'] = True
+ def_flags['report_accuracy_metrics'] = False
+ def_flags['dtype'] = 'fp16'
+ def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
+ def_flags['train_steps'] = 600
+ def_flags['log_steps'] = 100
+ def_flags['distribution_strategy'] = 'mirrored'
+
+ super(TrivialKerasBenchmarkReal, self).__init__(
+ output_dir=output_dir,
+ flag_methods=flag_methods,
+ default_flags=def_flags)
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self):
+ start_time_sec = time.time()
+ stats = resnet_imagenet_main.run(FLAGS)
+ wall_time_sec = time.time() - start_time_sec
+
+ super(TrivialKerasBenchmarkReal, self)._report_benchmark(
+ stats,
+ wall_time_sec,
+ total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_8_gpu_warmup(self):
+    """Dummy test that runs over an epoch to warm up the machine."""
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.enable_eager = True
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_warmup')
+ FLAGS.batch_size = 256 * 8
+ FLAGS.train_steps = 700
+ self._run_and_report_benchmark()
+
+ def fill_report_object(self, stats):
+ super(TrivialKerasBenchmarkReal, self).fill_report_object(
+ stats,
+ total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+
+class Resnet50MultiWorkerKerasAccuracy(keras_benchmark.KerasBenchmark):
+ """Resnet50 distributed accuracy tests with multiple workers."""
+
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
+ flag_methods = [classifier_trainer.define_imagenet_keras_flags]
+ self.data_dir = os.path.join(root_data_dir, 'imagenet')
+ super(Resnet50MultiWorkerKerasAccuracy, self).__init__(
+ output_dir=output_dir, flag_methods=flag_methods)
+
+ def _benchmark_common(self, eager, num_workers, all_reduce_alg):
+ """Common to all benchmarks in this class."""
+ self._setup()
+
+ num_gpus = 8
+ FLAGS.num_gpus = num_gpus
+ FLAGS.data_dir = self.data_dir
+ FLAGS.train_epochs = 90
+ FLAGS.epochs_between_evals = 10
+ FLAGS.dtype = 'fp16'
+ FLAGS.enable_eager = eager
+ FLAGS.enable_xla = False
+ FLAGS.distribution_strategy = 'multi_worker_mirrored'
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
+ FLAGS.datasets_num_private_threads = 32
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_{}_8_gpu_{}_worker_fp16_{}_tweaked'.format(
+ 'eager' if eager else 'graph', num_workers, all_reduce_alg))
+ FLAGS.batch_size = 256 * num_gpus * num_workers
+ FLAGS.all_reduce_alg = all_reduce_alg
+
+ self._run_and_report_benchmark()
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self,
+ top_1_min=MIN_TOP_1_ACCURACY,
+ top_1_max=MAX_TOP_1_ACCURACY):
+ start_time_sec = time.time()
+ stats = classifier_trainer.run(flags.FLAGS)
+ wall_time_sec = time.time() - start_time_sec
+
+ super(Resnet50MultiWorkerKerasAccuracy, self)._report_benchmark(
+ stats,
+ wall_time_sec,
+ top_1_min=top_1_min,
+ top_1_max=top_1_max,
+ total_batch_size=FLAGS.batch_size,
+ log_steps=100)
+
+ def _get_model_dir(self, folder_name):
+ return os.path.join(self.output_dir, folder_name)
+
+ def benchmark_eager_8_gpu_2_workers_fp16_ring_tweaked(self):
+ """Eager, 8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
+ self._benchmark_common(eager=True, num_workers=2, all_reduce_alg='ring')
+
+ def benchmark_eager_8_gpu_2_workers_fp16_nccl_tweaked(self):
+ """Eager, 8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
+ self._benchmark_common(eager=True, num_workers=2, all_reduce_alg='nccl')
+
+ def benchmark_eager_8_gpu_8_workers_fp16_ring_tweaked(self):
+ """Eager, 8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
+ self._benchmark_common(eager=True, num_workers=8, all_reduce_alg='ring')
+
+ def benchmark_eager_8_gpu_8_workers_fp16_nccl_tweaked(self):
+ """Eager, 8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
+ self._benchmark_common(eager=True, num_workers=8, all_reduce_alg='nccl')
+
+
+class Resnet50MultiWorkerKerasBenchmark(Resnet50KerasBenchmarkBase):
+ """Resnet50 distributed benchmark tests with multiple workers."""
+
+ def __init__(self, output_dir=None, default_flags=None):
+ super(Resnet50MultiWorkerKerasBenchmark, self).__init__(
+ output_dir=output_dir, default_flags=default_flags)
+
+ def _benchmark_common(self, eager, num_workers, all_reduce_alg):
+ """Common to all benchmarks in this class."""
+ self._setup()
+
+ num_gpus = 8
+ FLAGS.num_gpus = num_gpus
+ FLAGS.dtype = 'fp16'
+ FLAGS.enable_eager = eager
+ FLAGS.enable_xla = False
+ FLAGS.distribution_strategy = 'multi_worker_mirrored'
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
+ FLAGS.datasets_num_private_threads = 32
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_{}_8_gpu_{}_worker_fp16_{}_tweaked'.format(
+ 'eager' if eager else 'graph', num_workers, all_reduce_alg))
+ FLAGS.batch_size = 256 * num_gpus * num_workers
+ FLAGS.all_reduce_alg = all_reduce_alg
+
+ self._run_and_report_benchmark()
+
+ def benchmark_eager_8_gpu_1_worker_fp16_ring_tweaked(self):
+ """Eager, 8 GPUs per worker, 1 worker, fp16, ring all-reduce."""
+ self._benchmark_common(eager=True, num_workers=1, all_reduce_alg='ring')
+
+ def benchmark_eager_8_gpu_1_worker_fp16_nccl_tweaked(self):
+ """Eager, 8 GPUs per worker, 1 worker, fp16, nccl all-reduce."""
+ self._benchmark_common(eager=True, num_workers=1, all_reduce_alg='nccl')
+
+ def benchmark_eager_8_gpu_2_workers_fp16_ring_tweaked(self):
+ """Eager, 8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
+ self._benchmark_common(eager=True, num_workers=2, all_reduce_alg='ring')
+
+ def benchmark_eager_8_gpu_2_workers_fp16_nccl_tweaked(self):
+ """Eager, 8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
+ self._benchmark_common(eager=True, num_workers=2, all_reduce_alg='nccl')
+
+ def benchmark_eager_8_gpu_8_workers_fp16_ring_tweaked(self):
+ """Eager, 8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
+ self._benchmark_common(eager=True, num_workers=8, all_reduce_alg='ring')
+
+ def benchmark_eager_8_gpu_8_workers_fp16_nccl_tweaked(self):
+ """Eager, 8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
+ self._benchmark_common(eager=True, num_workers=8, all_reduce_alg='nccl')
+
+
+class Resnet50MultiWorkerKerasBenchmarkSynth(Resnet50MultiWorkerKerasBenchmark):
+ """Resnet50 multi-worker synthetic data benchmark tests."""
+
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
+ def_flags = {}
+ def_flags['skip_eval'] = True
+ def_flags['report_accuracy_metrics'] = False
+ def_flags['use_synthetic_data'] = True
+ def_flags['train_steps'] = 110
+ def_flags['log_steps'] = 10
+
+ super(Resnet50MultiWorkerKerasBenchmarkSynth, self).__init__(
+ output_dir=output_dir, default_flags=def_flags)
+
+
+class Resnet50MultiWorkerKerasBenchmarkReal(Resnet50MultiWorkerKerasBenchmark):
+ """Resnet50 multi-worker real data benchmark tests."""
+
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
+ def_flags = {}
+ def_flags['skip_eval'] = True
+ def_flags['report_accuracy_metrics'] = False
+ def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
+ def_flags['train_steps'] = 110
+ def_flags['log_steps'] = 10
+
+ super(Resnet50MultiWorkerKerasBenchmarkReal, self).__init__(
+ output_dir=output_dir, default_flags=def_flags)
+
+
+# TODO(kimjaehong): This should also cover other model optimization
+# techniques. At that time, this class will be renamed to something like
+# 'KerasModelOptimizationAccuracyBase'.
+class KerasPruningAccuracyBase(keras_benchmark.KerasBenchmark):
+ """Benchmark accuracy tests for pruning method."""
+
+ def __init__(self,
+ output_dir=None,
+ root_data_dir=None,
+ default_flags=None,
+ **kwargs):
+    """An accuracy benchmark class for the pruning method.
+
+ Args:
+ output_dir: directory where to output e.g. log files
+ root_data_dir: directory under which to look for dataset
+ default_flags: default flags
+ **kwargs: arbitrary named arguments. This is needed to make the
+ constructor forward compatible in case PerfZero provides more
+ named arguments before updating the constructor.
+ """
+ if default_flags is None:
+ default_flags = {}
+ default_flags['pruning_method'] = 'polynomial_decay'
+ default_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
+
+ flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
+
+ super(KerasPruningAccuracyBase, self).__init__(
+ output_dir=output_dir,
+ flag_methods=flag_methods,
+ default_flags=default_flags,
+ **kwargs)
+
+ def benchmark_8_gpu(self):
+ """Test Keras model with eager, dist_strat and 8 GPUs."""
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.batch_size = 32 * 8
+ FLAGS.train_epochs = 90
+ FLAGS.epochs_between_evals = 10
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
+ FLAGS.dtype = 'fp32'
+ FLAGS.enable_eager = True
+ self._run_and_report_benchmark()
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self,
+ top_1_min=MODEL_OPTIMIZATION_TOP_1_ACCURACY[
+ 'RESNET50_FINETUNE_PRUNING'][0],
+ top_1_max=MODEL_OPTIMIZATION_TOP_1_ACCURACY[
+ 'RESNET50_FINETUNE_PRUNING'][1]):
+ start_time_sec = time.time()
+ stats = resnet_imagenet_main.run(flags.FLAGS)
+ wall_time_sec = time.time() - start_time_sec
+
+ super(KerasPruningAccuracyBase, self)._report_benchmark(
+ stats,
+ wall_time_sec,
+ top_1_min=top_1_min,
+ top_1_max=top_1_max,
+ total_batch_size=FLAGS.batch_size,
+ log_steps=100)
+
+
+class MobilenetV1KerasPruningAccuracy(KerasPruningAccuracyBase):
+ """Benchmark accuracy tests for MobilenetV1 with pruning method."""
+
+ def __init__(self, root_data_dir=None, **kwargs):
+ default_flags = {
+ 'model': 'mobilenet',
+ 'optimizer': 'mobilenet_default',
+ 'initial_learning_rate_per_sample': 0.00007,
+ 'pretrained_filepath': tf.train.latest_checkpoint(
+ os.path.join(root_data_dir, 'mobilenet_v1')),
+ 'pruning_begin_step': 0,
+ 'pruning_end_step': 100000,
+ 'pruning_initial_sparsity': 0.0,
+ 'pruning_final_sparsity': 0.5,
+ 'pruning_frequency': 100,
+ }
+ super(MobilenetV1KerasPruningAccuracy, self).__init__(
+ root_data_dir=root_data_dir,
+ default_flags=default_flags,
+ **kwargs)
+
+ def _run_and_report_benchmark(self):
+ super(MobilenetV1KerasPruningAccuracy, self)._run_and_report_benchmark(
+ top_1_min=\
+ MODEL_OPTIMIZATION_TOP_1_ACCURACY['MOBILENET_V1_FINETUNE_PRUNING'][0],
+ top_1_max=\
+ MODEL_OPTIMIZATION_TOP_1_ACCURACY['MOBILENET_V1_FINETUNE_PRUNING'][1])
+
+
+class Resnet50KerasPruningAccuracy(KerasPruningAccuracyBase):
+ """Benchmark accuracy tests for resnet50 with pruning method."""
+
+ def __init__(self, root_data_dir=None, **kwargs):
+ default_flags = {
+ 'model': 'resnet50_v1.5',
+ 'optimizer': 'mobilenet_default',
+ 'initial_learning_rate_per_sample': 0.0000039,
+ 'pretrained_filepath': tf.train.latest_checkpoint(
+ os.path.join(root_data_dir, 'resnet50')),
+ 'pruning_begin_step': 0,
+ 'pruning_end_step': 50000,
+ 'pruning_initial_sparsity': 0.0,
+ 'pruning_final_sparsity': 0.5,
+ 'pruning_frequency': 100,
+ }
+ super(Resnet50KerasPruningAccuracy, self).__init__(
+ root_data_dir=root_data_dir,
+ default_flags=default_flags,
+ **kwargs)
+
+ def _run_and_report_benchmark(self):
+ super(Resnet50KerasPruningAccuracy, self)._run_and_report_benchmark(
+ top_1_min=\
+ MODEL_OPTIMIZATION_TOP_1_ACCURACY['RESNET50_FINETUNE_PRUNING'][0],
+ top_1_max=\
+ MODEL_OPTIMIZATION_TOP_1_ACCURACY['RESNET50_FINETUNE_PRUNING'][1])
+
+
+class KerasPruningBenchmarkRealBase(Resnet50KerasBenchmarkBase):
+ """Pruning method benchmarks."""
+
+ def __init__(self, root_data_dir=None, default_flags=None, **kwargs):
+ if default_flags is None:
+ default_flags = {}
+ default_flags.update({
+ 'skip_eval': True,
+ 'report_accuracy_metrics': False,
+ 'data_dir': os.path.join(root_data_dir, 'imagenet'),
+ 'train_steps': 110,
+ 'log_steps': 10,
+ 'pruning_method': 'polynomial_decay',
+ 'pruning_begin_step': 0,
+ 'pruning_end_step': 50000,
+ 'pruning_initial_sparsity': 0,
+ 'pruning_final_sparsity': 0.5,
+ 'pruning_frequency': 100,
+ })
+ super(KerasPruningBenchmarkRealBase, self).__init__(
+ default_flags=default_flags, **kwargs)
+
+
+class MobilenetV1KerasPruningBenchmarkReal(KerasPruningBenchmarkRealBase):
+ """Pruning method benchmarks for MobilenetV1."""
+
+ def __init__(self, **kwargs):
+ default_flags = {
+ 'model': 'mobilenet',
+ 'optimizer': 'mobilenet_default',
+ }
+ super(MobilenetV1KerasPruningBenchmarkReal, self).__init__(
+ default_flags=default_flags, **kwargs)
+
+
+class Resnet50KerasPruningBenchmarkReal(KerasPruningBenchmarkRealBase):
+ """Pruning method benchmarks for resnet50."""
+
+ def __init__(self, **kwargs):
+ default_flags = {
+ 'model': 'resnet50_v1.5',
+ 'optimizer': 'mobilenet_default',
+ }
+ super(Resnet50KerasPruningBenchmarkReal, self).__init__(
+ default_flags=default_flags, **kwargs)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/benchmark/models/__init__.py b/models/official/benchmark/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/benchmark/models/cifar_preprocessing.py b/models/official/benchmark/models/cifar_preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..18d7fe630e194953c8c5f3f7552c7104c6155c9a
--- /dev/null
+++ b/models/official/benchmark/models/cifar_preprocessing.py
@@ -0,0 +1,159 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Provides utilities for the Cifar-10 dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+from absl import logging
+import tensorflow as tf
+
+from official.vision.image_classification.resnet import imagenet_preprocessing
+
+HEIGHT = 32
+WIDTH = 32
+NUM_CHANNELS = 3
+_DEFAULT_IMAGE_BYTES = HEIGHT * WIDTH * NUM_CHANNELS
+# The record is the image plus a one-byte label
+_RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1
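+# For CIFAR-10 this works out to 32 * 32 * 3 + 1 = 3073 bytes per record.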
+
+# TODO(tobyboyd): Change to best practice 45K(train)/5K(val)/10K(test) splits.
+NUM_IMAGES = {
+ 'train': 50000,
+ 'validation': 10000,
+}
+_NUM_DATA_FILES = 5
+NUM_CLASSES = 10
+
+
+def parse_record(raw_record, is_training, dtype):
+ """Parses a record containing a training example of an image.
+
+ The input record is parsed into a label and image, and the image is passed
+ through preprocessing steps (cropping, flipping, and so on).
+
+  The label is returned as an integer rather than one-hot, to match the sparse
+  categorical crossentropy loss used by the training code.
+
+ Args:
+ raw_record: scalar Tensor tf.string containing a serialized
+ Example protocol buffer.
+ is_training: A boolean denoting whether the input is for training.
+ dtype: Data type to use for input images.
+
+ Returns:
+    Tuple with processed image tensor and integer label tensor.
+ """
+ # Convert bytes to a vector of uint8 that is record_bytes long.
+ record_vector = tf.io.decode_raw(raw_record, tf.uint8)
+
+  # The first byte represents the label, which we convert from uint8 to int32.
+ label = tf.cast(record_vector[0], tf.int32)
+
+ # The remaining bytes after the label represent the image, which we reshape
+ # from [depth * height * width] to [depth, height, width].
+ depth_major = tf.reshape(record_vector[1:_RECORD_BYTES],
+ [NUM_CHANNELS, HEIGHT, WIDTH])
+
+ # Convert from [depth, height, width] to [height, width, depth], and cast as
+ # float32.
+ image = tf.cast(tf.transpose(a=depth_major, perm=[1, 2, 0]), tf.float32)
+
+ image = preprocess_image(image, is_training)
+ image = tf.cast(image, dtype)
+
+ return image, label
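+
+# Illustrative usage sketch: `filenames` would be a list of CIFAR-10 .bin paths
+# (e.g. from get_filenames below); the production wiring lives in input_fn.
+#   raw = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)
+#   ds = raw.map(lambda r: parse_record(r, is_training=True, dtype=tf.float32))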
+
+
+def preprocess_image(image, is_training):
+ """Preprocess a single image of layout [height, width, depth]."""
+ if is_training:
+ # Resize the image to add four extra pixels on each side.
+ image = tf.image.resize_with_crop_or_pad(
+ image, HEIGHT + 8, WIDTH + 8)
+
+ # Randomly crop a [HEIGHT, WIDTH] section of the image.
+ image = tf.image.random_crop(image, [HEIGHT, WIDTH, NUM_CHANNELS])
+
+ # Randomly flip the image horizontally.
+ image = tf.image.random_flip_left_right(image)
+
+ # Subtract off the mean and divide by the variance of the pixels.
+ image = tf.image.per_image_standardization(image)
+ return image
+
+
+def get_filenames(is_training, data_dir):
+ """Returns a list of filenames."""
+ assert tf.io.gfile.exists(data_dir), (
+ 'Run cifar10_download_and_extract.py first to download and extract the '
+ 'CIFAR-10 data.')
+
+ if is_training:
+ return [
+ os.path.join(data_dir, 'data_batch_%d.bin' % i)
+ for i in range(1, _NUM_DATA_FILES + 1)
+ ]
+ else:
+ return [os.path.join(data_dir, 'test_batch.bin')]
+
+
+def input_fn(is_training,
+ data_dir,
+ batch_size,
+ dtype=tf.float32,
+ datasets_num_private_threads=None,
+ parse_record_fn=parse_record,
+ input_context=None,
+ drop_remainder=False):
+ """Input function which provides batches for train or eval.
+
+ Args:
+ is_training: A boolean denoting whether the input is for training.
+ data_dir: The directory containing the input data.
+ batch_size: The number of samples per batch.
+ dtype: Data type to use for images/features
+ datasets_num_private_threads: Number of private threads for tf.data.
+ parse_record_fn: Function to use for parsing the records.
+ input_context: A `tf.distribute.InputContext` object passed in by
+ `tf.distribute.Strategy`.
+ drop_remainder: A boolean indicates whether to drop the remainder of the
+ batches. If True, the batch dimension will be static.
+
+ Returns:
+ A dataset that can be used for iteration.
+ """
+ filenames = get_filenames(is_training, data_dir)
+ dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)
+
+ if input_context:
+ logging.info(
+ 'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d',
+ input_context.input_pipeline_id, input_context.num_input_pipelines)
+ dataset = dataset.shard(input_context.num_input_pipelines,
+ input_context.input_pipeline_id)
+
+ return imagenet_preprocessing.process_record_dataset(
+ dataset=dataset,
+ is_training=is_training,
+ batch_size=batch_size,
+ shuffle_buffer=NUM_IMAGES['train'],
+ parse_record_fn=parse_record_fn,
+ dtype=dtype,
+ datasets_num_private_threads=datasets_num_private_threads,
+ drop_remainder=drop_remainder
+ )
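+
+
+# Example invocation (illustrative; the data_dir shown matches the default set
+# by resnet_cifar_main.define_cifar_flags):
+#   train_ds = input_fn(is_training=True,
+#                       data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
+#                       batch_size=128)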
diff --git a/models/official/benchmark/models/resnet_cifar_main.py b/models/official/benchmark/models/resnet_cifar_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a02fec8b96e25228e6e0467d646c26995f944fc
--- /dev/null
+++ b/models/official/benchmark/models/resnet_cifar_main.py
@@ -0,0 +1,284 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Runs a ResNet model on the Cifar-10 dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import app
+from absl import flags
+from absl import logging
+import numpy as np
+import tensorflow as tf
+from official.benchmark.models import cifar_preprocessing
+from official.benchmark.models import resnet_cifar_model
+from official.benchmark.models import synthetic_util
+from official.utils.flags import core as flags_core
+from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
+from official.vision.image_classification.resnet import common
+
+
+LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
+ (0.1, 91), (0.01, 136), (0.001, 182)
+]
+
+
+def learning_rate_schedule(current_epoch,
+ current_batch,
+ batches_per_epoch,
+ batch_size):
+ """Handles linear scaling rule and LR decay.
+
+ Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
+ provided scaling factor.
+
+ Args:
+ current_epoch: integer, current epoch indexed from 0.
+ current_batch: integer, current batch in the current epoch, indexed from 0.
+ batches_per_epoch: integer, number of steps in an epoch.
+    batch_size: integer, total batch size.
+
+ Returns:
+ Adjusted learning rate.
+ """
+ del current_batch, batches_per_epoch # not used
+ initial_learning_rate = common.BASE_LEARNING_RATE * batch_size / 128
+ learning_rate = initial_learning_rate
+ for mult, start_epoch in LR_SCHEDULE:
+ if current_epoch >= start_epoch:
+ learning_rate = initial_learning_rate * mult
+ else:
+ break
+ return learning_rate
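+
+# Worked example (illustrative, assuming common.BASE_LEARNING_RATE is 0.1):
+# with batch_size=1024 the initial rate is 0.1 * 1024 / 128 = 0.8, and at
+# epoch 100 the (0.1, 91) entry of LR_SCHEDULE applies, so 0.08 is returned.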
+
+
+class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
+ """Callback to update learning rate on every batch (not epoch boundaries).
+
+  N.B. Only supports Keras optimizers, not TF optimizers.
+
+ Attributes:
+ schedule: a function that takes an epoch index and a batch index as input
+ (both integer, indexed from 0) and returns a new learning rate as
+ output (float).
+ """
+
+ def __init__(self, schedule, batch_size, steps_per_epoch):
+ super(LearningRateBatchScheduler, self).__init__()
+ self.schedule = schedule
+ self.steps_per_epoch = steps_per_epoch
+ self.batch_size = batch_size
+ self.epochs = -1
+ self.prev_lr = -1
+
+ def on_epoch_begin(self, epoch, logs=None):
+ if not hasattr(self.model.optimizer, 'learning_rate'):
+ raise ValueError('Optimizer must have a "learning_rate" attribute.')
+ self.epochs += 1
+
+ def on_batch_begin(self, batch, logs=None):
+ """Executes before step begins."""
+ lr = self.schedule(self.epochs,
+ batch,
+ self.steps_per_epoch,
+ self.batch_size)
+ if not isinstance(lr, (float, np.float32, np.float64)):
+      raise ValueError(
+          'The output of the "schedule" function should be a float.')
+ if lr != self.prev_lr:
+ self.model.optimizer.learning_rate = lr # lr should be a float here
+ self.prev_lr = lr
+ logging.debug(
+ 'Epoch %05d Batch %05d: LearningRateBatchScheduler '
+ 'change learning rate to %s.', self.epochs, batch, lr)
+
+
+def run(flags_obj):
+ """Run ResNet Cifar-10 training and eval loop using native Keras APIs.
+
+ Args:
+ flags_obj: An object containing parsed flag values.
+
+ Raises:
+ ValueError: If fp16 is passed as it is not currently supported.
+
+ Returns:
+ Dictionary of training and eval stats.
+ """
+ keras_utils.set_session_config(
+ enable_xla=flags_obj.enable_xla)
+
+ # Execute flag override logic for better model performance
+ if flags_obj.tf_gpu_thread_mode:
+ keras_utils.set_gpu_thread_mode_and_count(
+ per_gpu_thread_count=flags_obj.per_gpu_thread_count,
+ gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
+ num_gpus=flags_obj.num_gpus,
+ datasets_num_private_threads=flags_obj.datasets_num_private_threads)
+ common.set_cudnn_batchnorm_mode()
+
+ dtype = flags_core.get_tf_dtype(flags_obj)
+  if dtype == tf.float16:
+    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
+                     'value (fp32).')
+
+ data_format = flags_obj.data_format
+ if data_format is None:
+ data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
+ else 'channels_last')
+ tf.keras.backend.set_image_data_format(data_format)
+
+ strategy = distribution_utils.get_distribution_strategy(
+ distribution_strategy=flags_obj.distribution_strategy,
+ num_gpus=flags_obj.num_gpus,
+ all_reduce_alg=flags_obj.all_reduce_alg,
+ num_packs=flags_obj.num_packs)
+
+ if strategy:
+ # flags_obj.enable_get_next_as_optional controls whether enabling
+ # get_next_as_optional behavior in DistributedIterator. If true, last
+ # partial batch can be supported.
+ strategy.extended.experimental_enable_get_next_as_optional = (
+ flags_obj.enable_get_next_as_optional
+ )
+
+ strategy_scope = distribution_utils.get_strategy_scope(strategy)
+
+ if flags_obj.use_synthetic_data:
+ synthetic_util.set_up_synthetic_data()
+ input_fn = common.get_synth_input_fn(
+ height=cifar_preprocessing.HEIGHT,
+ width=cifar_preprocessing.WIDTH,
+ num_channels=cifar_preprocessing.NUM_CHANNELS,
+ num_classes=cifar_preprocessing.NUM_CLASSES,
+ dtype=flags_core.get_tf_dtype(flags_obj),
+ drop_remainder=True)
+ else:
+ synthetic_util.undo_set_up_synthetic_data()
+ input_fn = cifar_preprocessing.input_fn
+
+ train_input_dataset = input_fn(
+ is_training=True,
+ data_dir=flags_obj.data_dir,
+ batch_size=flags_obj.batch_size,
+ parse_record_fn=cifar_preprocessing.parse_record,
+ datasets_num_private_threads=flags_obj.datasets_num_private_threads,
+ dtype=dtype,
+ # Setting drop_remainder to avoid the partial batch logic in normalization
+ # layer, which triggers tf.where and leads to extra memory copy of input
+ # sizes between host and GPU.
+ drop_remainder=(not flags_obj.enable_get_next_as_optional))
+
+ eval_input_dataset = None
+ if not flags_obj.skip_eval:
+ eval_input_dataset = input_fn(
+ is_training=False,
+ data_dir=flags_obj.data_dir,
+ batch_size=flags_obj.batch_size,
+ parse_record_fn=cifar_preprocessing.parse_record)
+
+ steps_per_epoch = (
+ cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
+ lr_schedule = 0.1
+ if flags_obj.use_tensor_lr:
+ initial_learning_rate = common.BASE_LEARNING_RATE * flags_obj.batch_size / 128
+ lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
+ boundaries=list(p[1] * steps_per_epoch for p in LR_SCHEDULE),
+ values=[initial_learning_rate] +
+ list(p[0] * initial_learning_rate for p in LR_SCHEDULE))
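+    # Illustrative values, assuming common.BASE_LEARNING_RATE is 0.1 and
+    # batch_size=128: the boundaries are epochs 91, 136 and 182 converted to
+    # steps, and the values are [0.1, 0.01, 0.001, 0.0001], matching
+    # learning_rate_schedule() above.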
+
+ with strategy_scope:
+ optimizer = common.get_optimizer(lr_schedule)
+ model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES)
+ model.compile(
+ loss='sparse_categorical_crossentropy',
+ optimizer=optimizer,
+ metrics=(['sparse_categorical_accuracy']
+ if flags_obj.report_accuracy_metrics else None),
+ run_eagerly=flags_obj.run_eagerly)
+
+ train_epochs = flags_obj.train_epochs
+
+ callbacks = common.get_callbacks()
+
+ if not flags_obj.use_tensor_lr:
+ lr_callback = LearningRateBatchScheduler(
+ schedule=learning_rate_schedule,
+ batch_size=flags_obj.batch_size,
+ steps_per_epoch=steps_per_epoch)
+ callbacks.append(lr_callback)
+
+  # If multiple epochs are requested, ignore the train_steps flag.
+ if train_epochs <= 1 and flags_obj.train_steps:
+ steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
+ train_epochs = 1
+
+ num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
+ flags_obj.batch_size)
+
+ validation_data = eval_input_dataset
+ if flags_obj.skip_eval:
+ if flags_obj.set_learning_phase_to_train:
+ # TODO(haoyuzhang): Understand slowdown of setting learning phase when
+ # not using distribution strategy.
+ tf.keras.backend.set_learning_phase(1)
+ num_eval_steps = None
+ validation_data = None
+
+ if not strategy and flags_obj.explicit_gpu_placement:
+ # TODO(b/135607227): Add device scope automatically in Keras training loop
+    # when not using distribution strategy.
+ no_dist_strat_device = tf.device('/device:GPU:0')
+ no_dist_strat_device.__enter__()
+
+ history = model.fit(train_input_dataset,
+ epochs=train_epochs,
+ steps_per_epoch=steps_per_epoch,
+ callbacks=callbacks,
+ validation_steps=num_eval_steps,
+ validation_data=validation_data,
+ validation_freq=flags_obj.epochs_between_evals,
+ verbose=2)
+ eval_output = None
+ if not flags_obj.skip_eval:
+ eval_output = model.evaluate(eval_input_dataset,
+ steps=num_eval_steps,
+ verbose=2)
+
+ if not strategy and flags_obj.explicit_gpu_placement:
+ no_dist_strat_device.__exit__()
+
+ stats = common.build_stats(history, eval_output, callbacks)
+ return stats
+
+
+def define_cifar_flags():
+ common.define_keras_flags(dynamic_loss_scale=False)
+
+ flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
+ model_dir='/tmp/cifar10_model',
+ epochs_between_evals=10,
+ batch_size=128)
+
+
+def main(_):
+ return run(flags.FLAGS)
+
+
+if __name__ == '__main__':
+ logging.set_verbosity(logging.INFO)
+ define_cifar_flags()
+ app.run(main)
diff --git a/models/official/benchmark/models/resnet_cifar_model.py b/models/official/benchmark/models/resnet_cifar_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b507381f1b6907fdfb078d8316f3621a9e2b8f7
--- /dev/null
+++ b/models/official/benchmark/models/resnet_cifar_model.py
@@ -0,0 +1,262 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ResNet56 model for Keras adapted from tf.keras.applications.ResNet50.
+
+# Reference:
+- [Deep Residual Learning for Image Recognition](
+ https://arxiv.org/abs/1512.03385)
+Adapted from code contributed by BigMoyan.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import tensorflow as tf
+from tensorflow.python.keras import backend
+from tensorflow.python.keras import initializers
+from tensorflow.python.keras import layers
+from tensorflow.python.keras import regularizers
+
+
+BATCH_NORM_DECAY = 0.997
+BATCH_NORM_EPSILON = 1e-5
+L2_WEIGHT_DECAY = 2e-4
+
+
+def identity_building_block(input_tensor,
+ kernel_size,
+ filters,
+ stage,
+ block,
+ training=None):
+ """The identity block is the block that has no conv layer at shortcut.
+
+ Arguments:
+ input_tensor: input tensor
+ kernel_size: default 3, the kernel size of
+ middle conv layer at main path
+ filters: list of integers, the filters of 3 conv layer at main path
+ stage: integer, current stage label, used for generating layer names
+ block: current block label, used for generating layer names
+    training: Only used if training the Keras model with Estimator. In other
+      scenarios it is handled automatically.
+
+ Returns:
+ Output tensor for the block.
+ """
+ filters1, filters2 = filters
+ if backend.image_data_format() == 'channels_last':
+ bn_axis = 3
+ else:
+ bn_axis = 1
+ conv_name_base = 'res' + str(stage) + block + '_branch'
+ bn_name_base = 'bn' + str(stage) + block + '_branch'
+
+ x = layers.Conv2D(filters1, kernel_size,
+ padding='same', use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
+ name=conv_name_base + '2a')(input_tensor)
+ x = layers.BatchNormalization(
+ axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
+ name=bn_name_base + '2a')(x, training=training)
+ x = layers.Activation('relu')(x)
+
+ x = layers.Conv2D(filters2, kernel_size,
+ padding='same', use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
+ name=conv_name_base + '2b')(x)
+ x = layers.BatchNormalization(
+ axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
+ name=bn_name_base + '2b')(x, training=training)
+
+ x = layers.add([x, input_tensor])
+ x = layers.Activation('relu')(x)
+ return x
+
+
+def conv_building_block(input_tensor,
+ kernel_size,
+ filters,
+ stage,
+ block,
+ strides=(2, 2),
+ training=None):
+ """A block that has a conv layer at shortcut.
+
+ Arguments:
+ input_tensor: input tensor
+ kernel_size: default 3, the kernel size of
+ middle conv layer at main path
+ filters: list of integers, the filters of 3 conv layer at main path
+ stage: integer, current stage label, used for generating layer names
+ block: current block label, used for generating layer names
+ strides: Strides for the first conv layer in the block.
+    training: Only used if training the Keras model with Estimator. In other
+      scenarios it is handled automatically.
+
+ Returns:
+ Output tensor for the block.
+
+  Note that from stage 3, the first conv layer at the main path has
+  strides=(2, 2), and the shortcut has strides=(2, 2) as well.
+ """
+ filters1, filters2 = filters
+ if tf.keras.backend.image_data_format() == 'channels_last':
+ bn_axis = 3
+ else:
+ bn_axis = 1
+ conv_name_base = 'res' + str(stage) + block + '_branch'
+ bn_name_base = 'bn' + str(stage) + block + '_branch'
+
+ x = layers.Conv2D(filters1, kernel_size, strides=strides,
+ padding='same', use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
+ name=conv_name_base + '2a')(input_tensor)
+ x = layers.BatchNormalization(
+ axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
+ name=bn_name_base + '2a')(x, training=training)
+ x = layers.Activation('relu')(x)
+
+ x = layers.Conv2D(filters2, kernel_size, padding='same', use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
+ name=conv_name_base + '2b')(x)
+ x = layers.BatchNormalization(
+ axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
+ name=bn_name_base + '2b')(x, training=training)
+
+ shortcut = layers.Conv2D(filters2, (1, 1), strides=strides, use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
+ name=conv_name_base + '1')(input_tensor)
+ shortcut = layers.BatchNormalization(
+ axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
+ name=bn_name_base + '1')(shortcut, training=training)
+
+ x = layers.add([x, shortcut])
+ x = layers.Activation('relu')(x)
+ return x
+
+
+def resnet_block(input_tensor,
+ size,
+ kernel_size,
+ filters,
+ stage,
+ conv_strides=(2, 2),
+ training=None):
+ """A block which applies conv followed by multiple identity blocks.
+
+ Arguments:
+ input_tensor: input tensor
+ size: integer, number of constituent conv/identity building blocks.
+ A conv block is applied once, followed by (size - 1) identity blocks.
+ kernel_size: default 3, the kernel size of
+ middle conv layer at main path
+ filters: list of integers, the filters of 3 conv layer at main path
+ stage: integer, current stage label, used for generating layer names
+ conv_strides: Strides for the first conv layer in the block.
+    training: Only used if training the Keras model with Estimator. In other
+      scenarios it is handled automatically.
+
+ Returns:
+ Output tensor after applying conv and identity blocks.
+ """
+
+ x = conv_building_block(input_tensor, kernel_size, filters, stage=stage,
+ strides=conv_strides, block='block_0',
+ training=training)
+ for i in range(size - 1):
+ x = identity_building_block(x, kernel_size, filters, stage=stage,
+ block='block_%d' % (i + 1), training=training)
+ return x
+
+
+def resnet(num_blocks, classes=10, training=None):
+ """Instantiates the ResNet architecture.
+
+ Arguments:
+    num_blocks: integer, the number of conv/identity blocks in each stage.
+      The ResNet contains 3 stages, each consisting of one conv block
+      followed by (num_blocks - 1) identity blocks. Each conv/identity
+      block has 2 convolutional layers. With the input convolutional layer
+      and the final dense layer, this brings the total depth of the
+      network to (6 * num_blocks + 2).
+ classes: optional number of classes to classify images into
+    training: Only used if training the Keras model with Estimator. In other
+      scenarios it is handled automatically.
+
+ Returns:
+ A Keras model instance.
+ """
+
+ input_shape = (32, 32, 3)
+ img_input = layers.Input(shape=input_shape)
+
+ if backend.image_data_format() == 'channels_first':
+ x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
+ name='transpose')(img_input)
+ bn_axis = 1
+  else:  # channels_last
+ x = img_input
+ bn_axis = 3
+
+ x = layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x)
+ x = layers.Conv2D(16, (3, 3),
+ strides=(1, 1),
+ padding='valid', use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
+ name='conv1')(x)
+ x = layers.BatchNormalization(axis=bn_axis,
+ momentum=BATCH_NORM_DECAY,
+ epsilon=BATCH_NORM_EPSILON,
+ name='bn_conv1',)(x, training=training)
+ x = layers.Activation('relu')(x)
+
+ x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[16, 16],
+ stage=2, conv_strides=(1, 1), training=training)
+
+ x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[32, 32],
+ stage=3, conv_strides=(2, 2), training=training)
+
+ x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[64, 64],
+ stage=4, conv_strides=(2, 2), training=training)
+
+ rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3]
+ x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x)
+ x = layers.Dense(classes,
+ activation='softmax',
+ kernel_initializer=initializers.RandomNormal(stddev=0.01),
+ kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
+ bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
+ name='fc10')(x)
+
+ inputs = img_input
+ # Create model.
+ model = tf.keras.models.Model(inputs, x, name='resnet56')
+
+ return model
+
+
+resnet20 = functools.partial(resnet, num_blocks=3)
+resnet32 = functools.partial(resnet, num_blocks=5)
+resnet56 = functools.partial(resnet, num_blocks=9)
+resnet110 = functools.partial(resnet, num_blocks=18)
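+
+# Depth check (per the (6 * num_blocks + 2) formula in the resnet() docstring):
+# resnet20 -> 6*3+2 = 20, resnet32 -> 6*5+2 = 32, resnet56 -> 6*9+2 = 56,
+# resnet110 -> 6*18+2 = 110.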
diff --git a/models/official/benchmark/models/resnet_cifar_test.py b/models/official/benchmark/models/resnet_cifar_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c160f44eca1b6faf9def08860ebbdc6403d352e3
--- /dev/null
+++ b/models/official/benchmark/models/resnet_cifar_test.py
@@ -0,0 +1,180 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test the keras ResNet model with Cifar data."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tempfile
+
+import tensorflow as tf
+
+from tensorflow.python.eager import context
+from tensorflow.python.platform import googletest
+from official.benchmark.models import cifar_preprocessing
+from official.benchmark.models import resnet_cifar_main
+from official.utils.testing import integration
+
+
+class KerasCifarTest(googletest.TestCase):
+ """Unit tests for Keras ResNet with Cifar."""
+
+ _extra_flags = [
+ "-batch_size", "4",
+ "-train_steps", "1",
+ "-use_synthetic_data", "true"
+ ]
+ _tempdir = None
+
+ def get_temp_dir(self):
+ if not self._tempdir:
+ self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
+ return self._tempdir
+
+ @classmethod
+ def setUpClass(cls): # pylint: disable=invalid-name
+ super(KerasCifarTest, cls).setUpClass()
+ resnet_cifar_main.define_cifar_flags()
+
+ def setUp(self):
+ super(KerasCifarTest, self).setUp()
+ cifar_preprocessing.NUM_IMAGES["validation"] = 4
+
+ def tearDown(self):
+ super(KerasCifarTest, self).tearDown()
+ tf.io.gfile.rmtree(self.get_temp_dir())
+
+ def test_end_to_end_no_dist_strat(self):
+ """Test Keras model with 1 GPU, no distribution strategy."""
+
+ extra_flags = [
+ "-distribution_strategy", "off",
+ "-model_dir", "keras_cifar_no_dist_strat",
+ "-data_format", "channels_last",
+ ]
+ extra_flags = extra_flags + self._extra_flags
+
+ integration.run_synthetic(
+ main=resnet_cifar_main.run,
+ tmp_root=self.get_temp_dir(),
+ extra_flags=extra_flags
+ )
+
+ def test_end_to_end_graph_no_dist_strat(self):
+ """Test Keras model in legacy graph mode with 1 GPU, no dist strat."""
+ extra_flags = [
+ "-enable_eager", "false",
+ "-distribution_strategy", "off",
+ "-model_dir", "keras_cifar_graph_no_dist_strat",
+ "-data_format", "channels_last",
+ ]
+ extra_flags = extra_flags + self._extra_flags
+
+ integration.run_synthetic(
+ main=resnet_cifar_main.run,
+ tmp_root=self.get_temp_dir(),
+ extra_flags=extra_flags
+ )
+
+ def test_end_to_end_1_gpu(self):
+ """Test Keras model with 1 GPU."""
+
+ if context.num_gpus() < 1:
+ self.skipTest(
+ "{} GPUs are not available for this test. {} GPUs are available".
+ format(1, context.num_gpus()))
+
+ extra_flags = [
+ "-num_gpus", "1",
+ "-distribution_strategy", "mirrored",
+ "-model_dir", "keras_cifar_1_gpu",
+ "-data_format", "channels_last",
+ ]
+ extra_flags = extra_flags + self._extra_flags
+
+ integration.run_synthetic(
+ main=resnet_cifar_main.run,
+ tmp_root=self.get_temp_dir(),
+ extra_flags=extra_flags
+ )
+
+ def test_end_to_end_graph_1_gpu(self):
+ """Test Keras model in legacy graph mode with 1 GPU."""
+ if context.num_gpus() < 1:
+ self.skipTest(
+ "{} GPUs are not available for this test. {} GPUs are available".
+ format(1, context.num_gpus()))
+
+ extra_flags = [
+ "-num_gpus", "1",
+ "-noenable_eager",
+ "-distribution_strategy", "mirrored",
+ "-model_dir", "keras_cifar_graph_1_gpu",
+ "-data_format", "channels_last",
+ ]
+ extra_flags = extra_flags + self._extra_flags
+
+ integration.run_synthetic(
+ main=resnet_cifar_main.run,
+ tmp_root=self.get_temp_dir(),
+ extra_flags=extra_flags
+ )
+
+ def test_end_to_end_2_gpu(self):
+ """Test Keras model with 2 GPUs."""
+
+ if context.num_gpus() < 2:
+ self.skipTest(
+ "{} GPUs are not available for this test. {} GPUs are available".
+ format(2, context.num_gpus()))
+
+ extra_flags = [
+ "-num_gpus", "2",
+ "-distribution_strategy", "mirrored",
+ "-model_dir", "keras_cifar_2_gpu",
+ ]
+ extra_flags = extra_flags + self._extra_flags
+
+ integration.run_synthetic(
+ main=resnet_cifar_main.run,
+ tmp_root=self.get_temp_dir(),
+ extra_flags=extra_flags
+ )
+
+ def test_end_to_end_graph_2_gpu(self):
+ """Test Keras model in legacy graph mode with 2 GPUs."""
+ if context.num_gpus() < 2:
+ self.skipTest(
+ "{} GPUs are not available for this test. {} GPUs are available".
+ format(2, context.num_gpus()))
+
+ extra_flags = [
+ "-num_gpus", "2",
+ "-enable_eager", "false",
+ "-distribution_strategy", "mirrored",
+ "-model_dir", "keras_cifar_graph_2_gpu",
+ ]
+ extra_flags = extra_flags + self._extra_flags
+
+ integration.run_synthetic(
+ main=resnet_cifar_main.run,
+ tmp_root=self.get_temp_dir(),
+ extra_flags=extra_flags
+ )
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/models/official/benchmark/models/resnet_imagenet_main.py b/models/official/benchmark/models/resnet_imagenet_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a3cd503126e8796aed8a59164e9dcd6bef9c1dc
--- /dev/null
+++ b/models/official/benchmark/models/resnet_imagenet_main.py
@@ -0,0 +1,301 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Runs a ResNet model on the ImageNet dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf
+
+import tensorflow_model_optimization as tfmot
+from official.modeling import performance
+from official.utils.flags import core as flags_core
+from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
+from official.utils.misc import model_helpers
+from official.vision.image_classification import test_utils
+from official.vision.image_classification.resnet import common
+from official.vision.image_classification.resnet import imagenet_preprocessing
+from official.vision.image_classification.resnet import resnet_model
+
+
+def run(flags_obj):
+ """Run ResNet ImageNet training and eval loop using native Keras APIs.
+
+ Args:
+ flags_obj: An object containing parsed flag values.
+
+ Raises:
+ ValueError: If fp16 is passed as it is not currently supported.
+ NotImplementedError: If some features are not currently supported.
+
+ Returns:
+ Dictionary of training and eval stats.
+ """
+ keras_utils.set_session_config(
+ enable_xla=flags_obj.enable_xla)
+
+ # Execute flag override logic for better model performance
+ if flags_obj.tf_gpu_thread_mode:
+ keras_utils.set_gpu_thread_mode_and_count(
+ per_gpu_thread_count=flags_obj.per_gpu_thread_count,
+ gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
+ num_gpus=flags_obj.num_gpus,
+ datasets_num_private_threads=flags_obj.datasets_num_private_threads)
+ common.set_cudnn_batchnorm_mode()
+
+ dtype = flags_core.get_tf_dtype(flags_obj)
+ performance.set_mixed_precision_policy(
+ flags_core.get_tf_dtype(flags_obj),
+ flags_core.get_loss_scale(flags_obj, default_for_fp16=128))
+
+ data_format = flags_obj.data_format
+ if data_format is None:
+ data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
+ else 'channels_last')
+ tf.keras.backend.set_image_data_format(data_format)
+
+ # Configures cluster spec for distribution strategy.
+ _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
+ flags_obj.task_index)
+
+ strategy = distribution_utils.get_distribution_strategy(
+ distribution_strategy=flags_obj.distribution_strategy,
+ num_gpus=flags_obj.num_gpus,
+ all_reduce_alg=flags_obj.all_reduce_alg,
+ num_packs=flags_obj.num_packs,
+ tpu_address=flags_obj.tpu)
+
+ if strategy:
+    # flags_obj.enable_get_next_as_optional controls whether to enable the
+    # get_next_as_optional behavior in DistributedIterator. If true, the last
+    # partial batch can be supported.
+ strategy.extended.experimental_enable_get_next_as_optional = (
+ flags_obj.enable_get_next_as_optional
+ )
+
+ strategy_scope = distribution_utils.get_strategy_scope(strategy)
+
+ # pylint: disable=protected-access
+ if flags_obj.use_synthetic_data:
+ input_fn = common.get_synth_input_fn(
+ height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
+ width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
+ num_channels=imagenet_preprocessing.NUM_CHANNELS,
+ num_classes=imagenet_preprocessing.NUM_CLASSES,
+ dtype=dtype,
+ drop_remainder=True)
+ else:
+ input_fn = imagenet_preprocessing.input_fn
+
+ # When `enable_xla` is True, we always drop the remainder of the batches
+ # in the dataset, as XLA-GPU doesn't support dynamic shapes.
+ drop_remainder = flags_obj.enable_xla
+
+  # The current resnet_model.resnet50 input format is always channels-last.
+  # We use the Keras applications MobileNet model, whose input format depends
+  # on the Keras backend image data format.
+  # The use_keras_image_data_format flag indicates whether the image
+  # preprocessor output format should match the Keras backend image data
+  # format or simply be channels-last.
+ use_keras_image_data_format = (flags_obj.model == 'mobilenet')
+ train_input_dataset = input_fn(
+ is_training=True,
+ data_dir=flags_obj.data_dir,
+ batch_size=flags_obj.batch_size,
+ parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
+ use_keras_image_data_format=use_keras_image_data_format),
+ datasets_num_private_threads=flags_obj.datasets_num_private_threads,
+ dtype=dtype,
+ drop_remainder=drop_remainder,
+ tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
+ training_dataset_cache=flags_obj.training_dataset_cache,
+ )
+
+ eval_input_dataset = None
+ if not flags_obj.skip_eval:
+ eval_input_dataset = input_fn(
+ is_training=False,
+ data_dir=flags_obj.data_dir,
+ batch_size=flags_obj.batch_size,
+ parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
+ use_keras_image_data_format=use_keras_image_data_format),
+ dtype=dtype,
+ drop_remainder=drop_remainder)
+
+ lr_schedule = common.PiecewiseConstantDecayWithWarmup(
+ batch_size=flags_obj.batch_size,
+ epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
+ warmup_epochs=common.LR_SCHEDULE[0][1],
+ boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
+ multipliers=list(p[0] for p in common.LR_SCHEDULE),
+ compute_lr_on_cpu=True)
+ steps_per_epoch = (
+ imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
+
+ with strategy_scope:
+ if flags_obj.optimizer == 'resnet50_default':
+ optimizer = common.get_optimizer(lr_schedule)
+ elif flags_obj.optimizer == 'mobilenet_default':
+ initial_learning_rate = \
+ flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size
+ optimizer = tf.keras.optimizers.SGD(
+ learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
+ initial_learning_rate,
+ decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay,
+ decay_rate=flags_obj.lr_decay_factor,
+ staircase=True),
+ momentum=0.9)
+ if flags_obj.fp16_implementation == 'graph_rewrite':
+ # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
+ # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
+ # which will ensure tf.compat.v2.keras.mixed_precision and
+ # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
+ # up.
+ optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
+ optimizer)
+
+ # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
+ if flags_obj.use_trivial_model:
+ model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES)
+ elif flags_obj.model == 'resnet50_v1.5':
+ model = resnet_model.resnet50(
+ num_classes=imagenet_preprocessing.NUM_CLASSES)
+ elif flags_obj.model == 'mobilenet':
+      # TODO(kimjaehong): Remove the layers attribute when the minimum TF
+      # version supports 2.0 layers by default.
+ model = tf.keras.applications.mobilenet.MobileNet(
+ weights=None,
+ classes=imagenet_preprocessing.NUM_CLASSES,
+ layers=tf.keras.layers)
+ if flags_obj.pretrained_filepath:
+ model.load_weights(flags_obj.pretrained_filepath)
+
+ if flags_obj.pruning_method == 'polynomial_decay':
+ if dtype != tf.float32:
+ raise NotImplementedError(
+ 'Pruning is currently only supported on dtype=tf.float32.')
+ pruning_params = {
+ 'pruning_schedule':
+ tfmot.sparsity.keras.PolynomialDecay(
+ initial_sparsity=flags_obj.pruning_initial_sparsity,
+ final_sparsity=flags_obj.pruning_final_sparsity,
+ begin_step=flags_obj.pruning_begin_step,
+ end_step=flags_obj.pruning_end_step,
+ frequency=flags_obj.pruning_frequency),
+ }
+ model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
+ elif flags_obj.pruning_method:
+ raise NotImplementedError(
+ 'Only polynomial_decay is currently supported.')
+
+ model.compile(
+ loss='sparse_categorical_crossentropy',
+ optimizer=optimizer,
+ metrics=(['sparse_categorical_accuracy']
+ if flags_obj.report_accuracy_metrics else None),
+ run_eagerly=flags_obj.run_eagerly)
+
+ train_epochs = flags_obj.train_epochs
+
+ callbacks = common.get_callbacks(
+ pruning_method=flags_obj.pruning_method,
+ enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
+ model_dir=flags_obj.model_dir)
+
+  # If training for multiple epochs, ignore the train_steps flag.
+ if train_epochs <= 1 and flags_obj.train_steps:
+ steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
+ train_epochs = 1
+
+ num_eval_steps = (
+ imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)
+
+ validation_data = eval_input_dataset
+ if flags_obj.skip_eval:
+ # Only build the training graph. This reduces memory usage introduced by
+ # control flow ops in layers that have different implementations for
+ # training and inference (e.g., batch norm).
+ if flags_obj.set_learning_phase_to_train:
+ # TODO(haoyuzhang): Understand slowdown of setting learning phase when
+ # not using distribution strategy.
+ tf.keras.backend.set_learning_phase(1)
+ num_eval_steps = None
+ validation_data = None
+
+ if not strategy and flags_obj.explicit_gpu_placement:
+ # TODO(b/135607227): Add device scope automatically in Keras training loop
+    # when not using distribution strategy.
+ no_dist_strat_device = tf.device('/device:GPU:0')
+ no_dist_strat_device.__enter__()
+
+ history = model.fit(train_input_dataset,
+ epochs=train_epochs,
+ steps_per_epoch=steps_per_epoch,
+ callbacks=callbacks,
+ validation_steps=num_eval_steps,
+ validation_data=validation_data,
+ validation_freq=flags_obj.epochs_between_evals,
+ verbose=2)
+
+ eval_output = None
+ if not flags_obj.skip_eval:
+ eval_output = model.evaluate(eval_input_dataset,
+ steps=num_eval_steps,
+ verbose=2)
+
+ if flags_obj.pruning_method:
+ model = tfmot.sparsity.keras.strip_pruning(model)
+ if flags_obj.enable_checkpoint_and_export:
+ if dtype == tf.bfloat16:
+ logging.warning('Keras model.save does not support bfloat16 dtype.')
+ else:
+      # Keras model.save assumes a float32 input signature.
+ export_path = os.path.join(flags_obj.model_dir, 'saved_model')
+ model.save(export_path, include_optimizer=False)
+
+ if not strategy and flags_obj.explicit_gpu_placement:
+ no_dist_strat_device.__exit__()
+
+ stats = common.build_stats(history, eval_output, callbacks)
+ return stats
+
+
+def define_imagenet_keras_flags():
+ common.define_keras_flags(
+ model=True,
+ optimizer=True,
+ pretrained_filepath=True)
+ common.define_pruning_flags()
+ flags_core.set_defaults()
+ flags.adopt_module_key_flags(common)
+
+
+def main(_):
+ model_helpers.apply_clean(flags.FLAGS)
+ stats = run(flags.FLAGS)
+ logging.info('Run stats:\n%s', stats)
+
+
+if __name__ == '__main__':
+ logging.set_verbosity(logging.INFO)
+ define_imagenet_keras_flags()
+ app.run(main)
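The `drop_remainder = flags_obj.enable_xla` logic in `run()` above exists because XLA requires statically known shapes. A standalone illustration (not part of this file) of how `drop_remainder` changes the static batch dimension of a `tf.data` pipeline:

```python
import tensorflow as tf

# Without drop_remainder the batch dimension is dynamic (None)...
ds = tf.data.Dataset.range(10).batch(4)
print(ds.element_spec)  # TensorSpec(shape=(None,), dtype=tf.int64, ...)

# ...with drop_remainder=True it is static (4), which XLA can compile.
ds = tf.data.Dataset.range(10).batch(4, drop_remainder=True)
print(ds.element_spec)  # TensorSpec(shape=(4,), dtype=tf.int64, ...)
```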
diff --git a/models/official/benchmark/models/resnet_imagenet_test.py b/models/official/benchmark/models/resnet_imagenet_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..45c35d539ce2d7fcd0df30ed1d520e47e51312fa
--- /dev/null
+++ b/models/official/benchmark/models/resnet_imagenet_test.py
@@ -0,0 +1,249 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test the keras ResNet model with ImageNet data."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import tensorflow as tf
+
+from tensorflow.python.eager import context
+from official.benchmark.models import resnet_imagenet_main
+from official.utils.testing import integration
+from official.vision.image_classification.resnet import imagenet_preprocessing
+
+
+@parameterized.parameters(
+ "resnet",
+ # "resnet_polynomial_decay", b/151854314
+ "mobilenet",
+ # "mobilenet_polynomial_decay" b/151854314
+)
+class KerasImagenetTest(tf.test.TestCase):
+ """Unit tests for Keras Models with ImageNet."""
+ _default_flags_dict = [
+ "-batch_size", "4",
+ "-train_steps", "1",
+ "-use_synthetic_data", "true",
+ "-data_format", "channels_last",
+ ]
+ _extra_flags_dict = {
+ "resnet": [
+ "-model", "resnet50_v1.5",
+ "-optimizer", "resnet50_default",
+ ],
+ "resnet_polynomial_decay": [
+ "-model", "resnet50_v1.5",
+ "-optimizer", "resnet50_default",
+ "-pruning_method", "polynomial_decay",
+ ],
+ "mobilenet": [
+ "-model", "mobilenet",
+ "-optimizer", "mobilenet_default",
+ ],
+ "mobilenet_polynomial_decay": [
+ "-model", "mobilenet",
+ "-optimizer", "mobilenet_default",
+ "-pruning_method", "polynomial_decay",
+ ],
+ }
+ _tempdir = None
+
+ @classmethod
+ def setUpClass(cls): # pylint: disable=invalid-name
+ super(KerasImagenetTest, cls).setUpClass()
+ resnet_imagenet_main.define_imagenet_keras_flags()
+
+ def setUp(self):
+ super(KerasImagenetTest, self).setUp()
+ imagenet_preprocessing.NUM_IMAGES["validation"] = 4
+ self.policy = \
+ tf.keras.mixed_precision.experimental.global_policy()
+
+ def tearDown(self):
+ super(KerasImagenetTest, self).tearDown()
+ tf.io.gfile.rmtree(self.get_temp_dir())
+ tf.keras.mixed_precision.experimental.set_policy(self.policy)
+
+ def get_extra_flags_dict(self, flags_key):
+ return self._extra_flags_dict[flags_key] + self._default_flags_dict
+
+ def test_end_to_end_no_dist_strat(self, flags_key):
+ """Test Keras model with 1 GPU, no distribution strategy."""
+
+ extra_flags = [
+ "-distribution_strategy", "off",
+ ]
+ extra_flags = extra_flags + self.get_extra_flags_dict(flags_key)
+
+ integration.run_synthetic(
+ main=resnet_imagenet_main.run,
+ tmp_root=self.get_temp_dir(),
+ extra_flags=extra_flags
+ )
+
+ def test_end_to_end_graph_no_dist_strat(self, flags_key):
+ """Test Keras model in legacy graph mode with 1 GPU, no dist strat."""
+ extra_flags = [
+ "-enable_eager", "false",
+ "-distribution_strategy", "off",
+ ]
+ extra_flags = extra_flags + self.get_extra_flags_dict(flags_key)
+
+ integration.run_synthetic(
+ main=resnet_imagenet_main.run,
+ tmp_root=self.get_temp_dir(),
+ extra_flags=extra_flags
+ )
+
+ def test_end_to_end_1_gpu(self, flags_key):
+ """Test Keras model with 1 GPU."""
+
+ if context.num_gpus() < 1:
+ self.skipTest(
+ "{} GPUs are not available for this test. {} GPUs are available".
+ format(1, context.num_gpus()))
+
+ extra_flags = [
+ "-num_gpus", "1",
+ "-distribution_strategy", "mirrored",
+ "-enable_checkpoint_and_export", "1",
+ ]
+ extra_flags = extra_flags + self.get_extra_flags_dict(flags_key)
+
+ integration.run_synthetic(
+ main=resnet_imagenet_main.run,
+ tmp_root=self.get_temp_dir(),
+ extra_flags=extra_flags
+ )
+
+ def test_end_to_end_1_gpu_fp16(self, flags_key):
+ """Test Keras model with 1 GPU and fp16."""
+
+ if context.num_gpus() < 1:
+ self.skipTest(
+ "{} GPUs are not available for this test. {} GPUs are available"
+ .format(1, context.num_gpus()))
+
+ extra_flags = [
+ "-num_gpus", "1",
+ "-dtype", "fp16",
+ "-distribution_strategy", "mirrored",
+ ]
+ extra_flags = extra_flags + self.get_extra_flags_dict(flags_key)
+
+ if "polynomial_decay" in extra_flags:
+ self.skipTest("Pruning with fp16 is not currently supported.")
+
+ integration.run_synthetic(
+ main=resnet_imagenet_main.run,
+ tmp_root=self.get_temp_dir(),
+ extra_flags=extra_flags
+ )
+
+ def test_end_to_end_2_gpu(self, flags_key):
+ """Test Keras model with 2 GPUs."""
+
+ if context.num_gpus() < 2:
+ self.skipTest(
+ "{} GPUs are not available for this test. {} GPUs are available".
+ format(2, context.num_gpus()))
+
+ extra_flags = [
+ "-num_gpus", "2",
+ "-distribution_strategy", "mirrored",
+ ]
+ extra_flags = extra_flags + self.get_extra_flags_dict(flags_key)
+
+ integration.run_synthetic(
+ main=resnet_imagenet_main.run,
+ tmp_root=self.get_temp_dir(),
+ extra_flags=extra_flags
+ )
+
+ def test_end_to_end_xla_2_gpu(self, flags_key):
+ """Test Keras model with XLA and 2 GPUs."""
+
+ if context.num_gpus() < 2:
+ self.skipTest(
+ "{} GPUs are not available for this test. {} GPUs are available".
+ format(2, context.num_gpus()))
+
+ extra_flags = [
+ "-num_gpus", "2",
+ "-enable_xla", "true",
+ "-distribution_strategy", "mirrored",
+ ]
+ extra_flags = extra_flags + self.get_extra_flags_dict(flags_key)
+
+ integration.run_synthetic(
+ main=resnet_imagenet_main.run,
+ tmp_root=self.get_temp_dir(),
+ extra_flags=extra_flags
+ )
+
+ def test_end_to_end_2_gpu_fp16(self, flags_key):
+ """Test Keras model with 2 GPUs and fp16."""
+
+ if context.num_gpus() < 2:
+ self.skipTest(
+ "{} GPUs are not available for this test. {} GPUs are available".
+ format(2, context.num_gpus()))
+
+ extra_flags = [
+ "-num_gpus", "2",
+ "-dtype", "fp16",
+ "-distribution_strategy", "mirrored",
+ ]
+ extra_flags = extra_flags + self.get_extra_flags_dict(flags_key)
+
+ if "polynomial_decay" in extra_flags:
+ self.skipTest("Pruning with fp16 is not currently supported.")
+
+ integration.run_synthetic(
+ main=resnet_imagenet_main.run,
+ tmp_root=self.get_temp_dir(),
+ extra_flags=extra_flags
+ )
+
+ def test_end_to_end_xla_2_gpu_fp16(self, flags_key):
+ """Test Keras model with XLA, 2 GPUs and fp16."""
+ if context.num_gpus() < 2:
+ self.skipTest(
+ "{} GPUs are not available for this test. {} GPUs are available".
+ format(2, context.num_gpus()))
+
+ extra_flags = [
+ "-num_gpus", "2",
+ "-dtype", "fp16",
+ "-enable_xla", "true",
+ "-distribution_strategy", "mirrored",
+ ]
+ extra_flags = extra_flags + self.get_extra_flags_dict(flags_key)
+
+ if "polynomial_decay" in extra_flags:
+ self.skipTest("Pruning with fp16 is not currently supported.")
+
+ integration.run_synthetic(
+ main=resnet_imagenet_main.run,
+ tmp_root=self.get_temp_dir(),
+ extra_flags=extra_flags
+ )
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/benchmark/models/resnet_imagenet_test_tpu.py b/models/official/benchmark/models/resnet_imagenet_test_tpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fd72c404139b723407cc9a68c8afddd158ed691
--- /dev/null
+++ b/models/official/benchmark/models/resnet_imagenet_test_tpu.py
@@ -0,0 +1,105 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test the keras ResNet model with ImageNet data on TPU."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import tensorflow as tf
+from official.benchmark.models import resnet_imagenet_main
+from official.utils.testing import integration
+from official.vision.image_classification.resnet import imagenet_preprocessing
+
+
+class KerasImagenetTest(tf.test.TestCase, parameterized.TestCase):
+ """Unit tests for Keras Models with ImageNet."""
+
+ _extra_flags_dict = {
+ "resnet": [
+ "-batch_size", "4",
+ "-train_steps", "1",
+ "-use_synthetic_data", "true"
+ "-model", "resnet50_v1.5",
+ "-optimizer", "resnet50_default",
+ ],
+ "resnet_polynomial_decay": [
+ "-batch_size", "4",
+ "-train_steps", "1",
+ "-use_synthetic_data", "true",
+ "-model", "resnet50_v1.5",
+ "-optimizer", "resnet50_default",
+ "-pruning_method", "polynomial_decay",
+ ],
+ }
+ _tempdir = None
+
+ @classmethod
+ def setUpClass(cls): # pylint: disable=invalid-name
+ super(KerasImagenetTest, cls).setUpClass()
+ resnet_imagenet_main.define_imagenet_keras_flags()
+
+ def setUp(self):
+ super(KerasImagenetTest, self).setUp()
+ imagenet_preprocessing.NUM_IMAGES["validation"] = 4
+ self.policy = \
+ tf.keras.mixed_precision.experimental.global_policy()
+
+ def tearDown(self):
+ super(KerasImagenetTest, self).tearDown()
+ tf.io.gfile.rmtree(self.get_temp_dir())
+ tf.keras.mixed_precision.experimental.set_policy(self.policy)
+
+ @parameterized.parameters([
+ "resnet",
+ # "resnet_polynomial_decay" b/151854314
+ ])
+ def test_end_to_end_tpu(self, flags_key):
+ """Test Keras model with TPU distribution strategy."""
+
+ extra_flags = [
+ "-distribution_strategy", "tpu",
+ "-data_format", "channels_last",
+ "-enable_checkpoint_and_export", "1",
+ ]
+ extra_flags = extra_flags + self._extra_flags_dict[flags_key]
+
+ integration.run_synthetic(
+ main=resnet_imagenet_main.run,
+ tmp_root=self.get_temp_dir(),
+ extra_flags=extra_flags
+ )
+
+ @parameterized.parameters(["resnet"])
+ def test_end_to_end_tpu_bf16(self, flags_key):
+ """Test Keras model with TPU and bfloat16 activation."""
+
+ extra_flags = [
+ "-distribution_strategy", "tpu",
+ "-data_format", "channels_last",
+ "-dtype", "bf16",
+ ]
+ extra_flags = extra_flags + self._extra_flags_dict[flags_key]
+
+ integration.run_synthetic(
+ main=resnet_imagenet_main.run,
+ tmp_root=self.get_temp_dir(),
+ extra_flags=extra_flags
+ )
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/benchmark/models/shakespeare/README.md b/models/official/benchmark/models/shakespeare/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5395cc9642845ffb8bf36fdbc4f93bb450ba557f
--- /dev/null
+++ b/models/official/benchmark/models/shakespeare/README.md
@@ -0,0 +1,31 @@
+# Shakespeare character LSTM model
+
+This is an implementation of a simple character LSTM used to generate text.
+
+## Instructions
+
+First download the source data:
+
+```
+wget https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt
+```
+
+Note that files other than shakespeare.txt can also be used to train the model to generate other text.
+
+Then train the model:
+
+```
+python3 shakespeare_main.py --training_data shakespeare.txt \
+ --model_dir /tmp/shakespeare
+```
+
+This will place model checkpoints in `/tmp/shakespeare`, so that we can use them to make predictions.
+
+Then generate predictions:
+
+```
+python3 shakespeare_main.py --training_data shakespeare.txt \
+ --model_dir /tmp/shakespeare --notrain --predict_context=ROMEO:
+```
+
+Change `--predict_context` and `--predict_length` to suit your needs.
diff --git a/models/official/benchmark/models/shakespeare/__init__.py b/models/official/benchmark/models/shakespeare/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/models/official/benchmark/models/shakespeare/__init__.py
@@ -0,0 +1 @@
+
diff --git a/models/official/benchmark/models/shakespeare/shakespeare_main.py b/models/official/benchmark/models/shakespeare/shakespeare_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..6928dd1d61491acf84b969a52c7f0693617ac7f0
--- /dev/null
+++ b/models/official/benchmark/models/shakespeare/shakespeare_main.py
@@ -0,0 +1,313 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Runs a character LSTM model trained on Shakespeare."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import os
+
+# pylint: disable=wrong-import-order
+from absl import app
+from absl import flags
+import numpy as np
+import tensorflow as tf
+# pylint: enable=wrong-import-order
+
+from official.utils.flags import core as flags_core
+from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
+
+EMBEDDING_DIM = 256
+RNN_UNITS = 1024
+SEQ_LENGTH = 100
+# Calculated by running batch_size=1
+BATCHES_PER_EPOCH = 11043
+
+
+def define_flags():
+ """Define the flags for the Shakespeare character LSTM."""
+ flags_core.define_base(data_dir=False,
+ clean=False,
+ train_epochs=True,
+ epochs_between_evals=False,
+ stop_threshold=False,
+ num_gpu=True,
+ export_dir=False,
+ run_eagerly=True,
+ distribution_strategy=True)
+
+ flags_core.define_performance(num_parallel_calls=False,
+ inter_op=False,
+ intra_op=False,
+ synthetic_data=False,
+ max_train_steps=False,
+ dtype=True,
+ loss_scale=True,
+ enable_xla=True)
+
+ flags_core.set_defaults(train_epochs=43,
+ batch_size=64)
+
+ flags.DEFINE_boolean(name='enable_eager', default=True, help='Enable eager?')
+ flags.DEFINE_boolean(
+ name='train', default=True,
+ help='If true trains the model.')
+ flags.DEFINE_string(
+ name='predict_context', default=None,
+ help='If set, makes a prediction with the given context.')
+ flags.DEFINE_integer(
+ name='predict_length', default=1000,
+ help='Length of the predicted text including the context.')
+ flags.DEFINE_integer(name='train_steps', default=None,
+ help='Overrides train_steps per epoch if not None.')
+ flags.DEFINE_integer(
+ name='log_steps', default=100,
+ help='For every log_steps, we log the timing information such as '
+ 'examples per second.')
+ flags.DEFINE_string(
+ name='training_data', default=None,
+ help='Path to file containing the training data.')
+ flags.DEFINE_boolean(name='cudnn', default=True, help='Use CuDNN LSTM.')
+
+
+def get_dataset(path_to_file, batch_size=None, seq_length=SEQ_LENGTH):
+ """Creates a dataset from a given text file.
+
+ Args:
+ path_to_file: The path to the training data.
+ batch_size: Batch size to use.
+ seq_length: The length of the LSTM sequence.
+
+ Returns:
+ A tuple, consisting of the Dataset and the class to character mapping
+ and character to class mapping.
+ """
+ with tf.io.gfile.GFile(path_to_file, 'rb') as train_data:
+ text = train_data.read().decode(encoding='utf-8')
+
+ # Create vocab
+ vocab = sorted(set(text))
+ char2idx = {u: i for i, u in enumerate(vocab)}
+ idx2char = np.array(vocab)
+
+  # Split text into (sequence length + 1)-sized chunks to create examples.
+ text_as_int = np.array([char2idx[c] for c in text])
+ char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
+ sequences = char_dataset.batch(seq_length+1, drop_remainder=True)
+
+ def split_input_target(chunk):
+ input_text = chunk[:-1]
+ target_text = chunk[1:]
+ return input_text, tf.one_hot(target_text, len(vocab))
+ dataset = sequences.map(split_input_target)
+ dataset = dataset.shuffle(10000).repeat()
+ dataset = dataset.batch(batch_size, drop_remainder=True)
+
+ return dataset, idx2char, char2idx
+
+
+def build_model(vocab_size,
+ embedding_dim=EMBEDDING_DIM,
+ rnn_units=RNN_UNITS,
+ batch_size=None,
+ stateful=False,
+ use_cudnn=True):
+ """Builds the Shakespeare model.
+
+ Args:
+ vocab_size: The number of character classes in the input.
+ embedding_dim: The dimension of the embedding space for each class.
+ rnn_units: The number of RNN units in the layer.
+ batch_size: When predicting, the batch size of the predictions.
+    stateful: If true, the LSTM is stateful.
+    use_cudnn: If true, the LSTM layer may dispatch to the CuDNN kernel;
+      otherwise a lambda activation forces the non-CuDNN implementation.
+
+ Returns:
+ A Keras Model.
+ """
+ LSTM = functools.partial(tf.keras.layers.LSTM, implementation=2)
+
+ # By indirecting the activation through a lambda layer, the logic to dispatch
+ # to CuDNN in V2 doesn't trigger and we force the LSTM to run in non-CuDNN
+ # mode.
+ lstm_activation = ('tanh' if use_cudnn else
+ lambda x: tf.math.tanh(x))
+
+ batch_shape = [batch_size if stateful else None, None]
+ return tf.keras.Sequential([
+ tf.keras.layers.Embedding(vocab_size, embedding_dim,
+ batch_input_shape=batch_shape),
+ LSTM(rnn_units,
+ activation=lstm_activation,
+ return_sequences=True,
+ stateful=stateful,
+ recurrent_initializer='glorot_uniform'),
+ tf.keras.layers.Dense(vocab_size),
+ tf.keras.layers.Softmax(dtype=tf.float32)])
+
+
+def train_model(flags_obj, dataset, vocab_size, strategy, checkpoint_dir=None):
+ """Trains a Shakespeare model.
+
+ Args:
+    flags_obj: An object containing parsed flag values.
+ dataset: the training data set.
+ vocab_size: the number of unique character classes.
+ strategy: distribution strategy to use.
+ checkpoint_dir: if not None, the directory in which to make checkpoints.
+
+ Returns:
+ The training history and callbacks.
+ """
+ if flags_obj.train_steps:
+ train_steps = flags_obj.train_steps
+ else:
+ train_steps = BATCHES_PER_EPOCH // flags_obj.batch_size
+ strategy_scope = distribution_utils.get_strategy_scope(strategy)
+
+ with strategy_scope:
+ model = build_model(vocab_size=vocab_size, batch_size=flags_obj.batch_size,
+ use_cudnn=flags_obj.cudnn)
+
+ # When keras_use_ctl is False, Model.fit() automatically applies
+ # loss scaling so we don't need to create a LossScaleOptimizer.
+ model.compile(
+ optimizer=tf.keras.optimizers.Adam(),
+ loss=tf.keras.losses.CategoricalCrossentropy(),
+ metrics=[tf.keras.metrics.Recall(top_k=1, name='RecallAt1'),
+ tf.keras.metrics.Recall(top_k=5, name='RecallAt5')],
+ run_eagerly=flags_obj.run_eagerly)
+
+ callbacks = []
+ if checkpoint_dir:
+ checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt_{epoch}')
+ checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
+ filepath=checkpoint_prefix,
+ save_weights_only=True)
+ callbacks.append(checkpoint_callback)
+ time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
+ flags_obj.log_steps)
+ callbacks.append(time_callback)
+ history = model.fit(dataset,
+ epochs=flags_obj.train_epochs,
+ steps_per_epoch=train_steps,
+ callbacks=callbacks,
+ verbose=2)
+ return history, callbacks
+
+
+def make_prediction(checkpoint_dir, length, context, idx2char, char2idx):
+ """Make predictions from a Shakespeare model.
+
+ Args:
+ checkpoint_dir: the directory from which to load checkpoints
+ length: the total length of the generated text (including the context).
+ context: the initial text with which the LSTM is primed.
+ idx2char: the character class to character mapping.
+ char2idx: the character to character class mapping.
+
+ Returns:
+ A generated string of text of the given length.
+ """
+ prediction_model = build_model(
+ vocab_size=len(idx2char), batch_size=1, stateful=True)
+ prediction_model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
+ prediction_model.build(tf.TensorShape([1, None]))
+
+ input_eval = [char2idx[s] for s in context]
+ input_eval = tf.expand_dims(input_eval, 0)
+
+ text_generated = []
+
+ prediction_model.reset_states()
+ for _ in range(length - len(context)):
+ predictions = prediction_model(input_eval)
+ predictions = tf.squeeze(predictions, 0)
+
+ # We applied a softmax to the output of the model so that
+ # tf.keras.metrics.Recall would work. We need logits for
+ # tf.random.categorical, so we convert the probabilities back to log odds
+ predictions = tf.math.log(predictions / (1 - predictions))
+
+ random_output = tf.random.categorical(predictions, num_samples=1)
+ selected_id = random_output[-1, 0].numpy()
+ input_eval = tf.expand_dims([selected_id], 0)
+ text_generated.append(idx2char[selected_id])
+
+ return context + ''.join(text_generated)
+
+
+def run(flags_obj):
+ """Run Shakespeare training and predict.
+
+ Args:
+ flags_obj: An object containing parsed flag values.
+
+ Returns:
+ Dictionary with status from the run.
+ """
+ if not flags_obj.training_data:
+ raise ValueError(
+        'Must set the path to a training data file, e.g. download '
+ 'https://storage.googleapis.com/download.tensorflow.org/data/'
+ 'shakespeare.txt')
+
+ if flags_obj.dtype == 'fp16':
+ policy = tf.keras.mixed_precision.experimental.Policy(
+ 'mixed_float16',
+ loss_scale=flags_core.get_loss_scale(flags_obj,
+ default_for_fp16='dynamic'))
+ tf.keras.mixed_precision.experimental.set_policy(policy)
+
+ keras_utils.set_session_config(
+ enable_xla=flags_obj.enable_xla)
+
+ strategy = distribution_utils.get_distribution_strategy(
+ distribution_strategy=flags_obj.distribution_strategy,
+ num_gpus=flags_obj.num_gpus)
+
+ dataset, idx2char, char2idx = get_dataset(flags_obj.training_data,
+ batch_size=flags_obj.batch_size)
+ stats = {}
+ if flags_obj.train:
+ history, callbacks = train_model(flags_obj, dataset,
+ len(idx2char), strategy,
+ checkpoint_dir=flags_obj.model_dir)
+
+ stats['history'] = history.history
+ stats['callbacks'] = callbacks
+
+ if flags_obj.predict_context:
+ if not flags_obj.model_dir:
+ raise ValueError('Must set model_dir to get predictions.')
+ print(make_prediction(flags_obj.model_dir,
+ flags_obj.predict_length,
+ flags_obj.predict_context,
+ idx2char,
+ char2idx))
+
+ return stats
+
+
+def main(_):
+ flags_obj = flags.FLAGS
+ run(flags_obj)
+
+
+if __name__ == '__main__':
+ define_flags()
+ app.run(main)
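A small end-to-end sketch of how `get_dataset` and `build_model` compose outside the flag-driven `run()` path. The import path mirrors the sibling modules in this directory, and the data path, batch size, and step counts are illustrative.

```python
# Hedged sketch; assumes a local copy of shakespeare.txt.
from official.benchmark.models.shakespeare import shakespeare_main

# Build the character dataset and vocabulary mappings from the corpus.
dataset, idx2char, char2idx = shakespeare_main.get_dataset(
    'shakespeare.txt', batch_size=64)

# vocab_size is the number of distinct characters seen in the text.
model = shakespeare_main.build_model(vocab_size=len(idx2char), batch_size=64)

# get_dataset one-hot encodes the targets, hence categorical cross-entropy.
model.compile(optimizer='adam', loss='categorical_crossentropy')
model.fit(dataset, epochs=1, steps_per_epoch=10)
```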
diff --git a/models/official/benchmark/models/synthetic_util.py b/models/official/benchmark/models/synthetic_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..c14d0223dc417e6b0bd220f65dc3db0291bb773c
--- /dev/null
+++ b/models/official/benchmark/models/synthetic_util.py
@@ -0,0 +1,129 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Helper functions to generate data directly on devices."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+import string
+
+from absl import logging
+import tensorflow as tf
+
+
+# The `SyntheticDataset` is a temporary solution for generating synthetic data
+# directly on devices. It is only useful for Keras with Distribution
+# Strategies. We will have better support in `tf.data` or Distribution Strategy
+# later.
+class SyntheticDataset(object):
+ """A dataset that generates synthetic data on each device."""
+
+ def __init__(self, dataset, split_by=1):
+ # dataset.take(1) doesn't have GPU kernel.
+ with tf.device('device:CPU:0'):
+ tensor = tf.data.experimental.get_single_element(dataset.take(1))
+ flat_tensor = tf.nest.flatten(tensor)
+ variable_data = []
+ initializers = []
+ for t in flat_tensor:
+ rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0]
+ assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape
+ v = tf.compat.v1.get_local_variable(self._random_name(),
+ initializer=rebatched_t)
+ variable_data.append(v)
+ initializers.append(v.initializer)
+ input_data = tf.nest.pack_sequence_as(tensor, variable_data)
+ self._iterator = SyntheticIterator(input_data, initializers)
+
+ def _random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
+ return ''.join(random.choice(chars) for _ in range(size))
+
+ def __iter__(self):
+ return self._iterator
+
+ def make_one_shot_iterator(self):
+ return self._iterator
+
+ def make_initializable_iterator(self):
+ return self._iterator
+
+
+class SyntheticIterator(object):
+ """A dataset that generates synthetic data on each device."""
+
+ def __init__(self, input_data, initializers):
+ self._input_data = input_data
+ self._initializers = initializers
+
+ def get_next(self):
+ return self._input_data
+
+ def next(self):
+ return self.__next__()
+
+ def __next__(self):
+ try:
+ return self.get_next()
+ except tf.errors.OutOfRangeError:
+ raise StopIteration
+
+ def initialize(self):
+ if tf.executing_eagerly():
+ return tf.no_op()
+ else:
+ return self._initializers
+
+
+def _monkey_patch_dataset_method(strategy):
+ """Monkey-patch `strategy`'s `make_dataset_iterator` method."""
+ def make_dataset(self, dataset):
+ logging.info('Using pure synthetic data.')
+ with self.scope():
+ if self.extended._global_batch_size: # pylint: disable=protected-access
+ return SyntheticDataset(dataset, self.num_replicas_in_sync)
+ else:
+ return SyntheticDataset(dataset)
+
+ def make_iterator(self, dataset):
+ dist_dataset = make_dataset(self, dataset)
+ return iter(dist_dataset)
+
+ strategy.orig_make_dataset_iterator = strategy.make_dataset_iterator
+ strategy.make_dataset_iterator = make_iterator
+ strategy.orig_distribute_dataset = strategy.experimental_distribute_dataset
+ strategy.experimental_distribute_dataset = make_dataset
+
+
+def _undo_monkey_patch_dataset_method(strategy):
+ if hasattr(strategy, 'orig_make_dataset_iterator'):
+ strategy.make_dataset_iterator = strategy.orig_make_dataset_iterator
+ if hasattr(strategy, 'orig_distribute_dataset'):
+    strategy.experimental_distribute_dataset = strategy.orig_distribute_dataset
+
+
+def set_up_synthetic_data():
+ _monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
+ _monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
+ _monkey_patch_dataset_method(
+ tf.distribute.experimental.MultiWorkerMirroredStrategy)
+
+
+def undo_set_up_synthetic_data():
+ _undo_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
+ _undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
+ _undo_monkey_patch_dataset_method(
+ tf.distribute.experimental.MultiWorkerMirroredStrategy)
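A hedged sketch of the intended call order for these helpers. The `run_benchmark` wrapper and its argument are illustrative; only the `synthetic_util` calls are real API from the module above.

```python
from official.benchmark.models import synthetic_util


def run_benchmark(use_pure_synthetic_data):
    """Illustrative wrapper showing where the patching calls belong."""
    if use_pure_synthetic_data:
        # Patch the distribution strategies so that
        # strategy.experimental_distribute_dataset(...) returns a
        # SyntheticDataset replaying a single batch from device memory.
        synthetic_util.set_up_synthetic_data()
    try:
        pass  # build the strategy, input pipeline and model, then train
    finally:
        if use_pure_synthetic_data:
            # Restore the original strategy methods.
            synthetic_util.undo_set_up_synthetic_data()
```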
diff --git a/models/official/benchmark/ncf_keras_benchmark.py b/models/official/benchmark/ncf_keras_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..170c99a33f46f14f977182c4e8a6d7ffbf96682d
--- /dev/null
+++ b/models/official/benchmark/ncf_keras_benchmark.py
@@ -0,0 +1,488 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Executes Keras benchmarks and accuracy tests."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+
+from absl import flags
+from absl import logging
+from absl.testing import flagsaver
+import tensorflow as tf
+from official.benchmark import benchmark_wrappers
+from official.benchmark import owner_utils
+from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
+from official.recommendation import ncf_common
+from official.recommendation import ncf_keras_main
+from official.utils.flags import core
+
+FLAGS = flags.FLAGS
+NCF_DATA_DIR_NAME = 'movielens_data'
+NCF_TF_REGRESSION_DATA_DIR_NAME = 'gs://tf-regression/ncf/data'
+
+
+class NCFKerasBenchmarkBase(PerfZeroBenchmark):
+ """Base class for NCF model benchmark."""
+
+ def __init__(self, output_dir=None, default_flags=None, **kwargs):
+ super(NCFKerasBenchmarkBase, self).__init__(output_dir, default_flags,
+ **kwargs)
+
+ # Run all benchmarks with ml_perf flag.
+ self.default_flags['ml_perf'] = True
+
+ def _setup(self):
+ """Sets up and resets flags before each test."""
+ logging.set_verbosity(logging.INFO)
+ if NCFKerasBenchmarkBase.local_flags is None:
+ ncf_common.define_ncf_flags()
+      # Parse flags to load defaults before overriding. Argv cannot be empty.
+ flags.FLAGS(['foo'])
+ core.set_defaults(**self.default_flags)
+ saved_flag_values = flagsaver.save_flag_values()
+ NCFKerasBenchmarkBase.local_flags = saved_flag_values
+ else:
+ flagsaver.restore_flag_values(NCFKerasBenchmarkBase.local_flags)
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self, hr_at_10_min=0, hr_at_10_max=0):
+ start_time_sec = time.time()
+ stats = ncf_keras_main.run_ncf(FLAGS)
+ wall_time_sec = time.time() - start_time_sec
+
+ metrics = []
+ metrics.append({'name': 'exp_per_second',
+ 'value': stats['avg_exp_per_second']})
+
+ if hr_at_10_min > 0:
+ metrics.append({'name': 'hr_at_10',
+ 'value': stats['eval_hit_rate'],
+ 'min_value': hr_at_10_min,
+ 'max_value': hr_at_10_max})
+
+ metrics.append({'name': 'train_loss',
+ 'value': stats['loss']})
+
+ self.report_benchmark(iters=-1, wall_time=wall_time_sec, metrics=metrics)
+
+
+class NCFKerasAccuracy(NCFKerasBenchmarkBase):
+ """Benchmark NCF model using real data."""
+
+ def __init__(self,
+ output_dir=None,
+ root_data_dir=None,
+ default_flags=None,
+ **kwargs):
+ root_data_dir = root_data_dir if root_data_dir else ''
+ default_flags = {}
+ default_flags['dataset'] = 'ml-20m'
+ default_flags['num_gpus'] = 1
+ default_flags['train_epochs'] = 10
+ default_flags['clean'] = True
+ default_flags['batch_size'] = 99000
+ default_flags['learning_rate'] = 0.00382059
+ default_flags['beta1'] = 0.783529
+ default_flags['beta2'] = 0.909003
+ default_flags['epsilon'] = 1.45439e-07
+ default_flags['layers'] = [256, 256, 128, 64]
+ default_flags['num_factors'] = 64
+ default_flags['hr_threshold'] = 0.635
+ default_flags['ml_perf'] = True
+ default_flags['use_synthetic_data'] = False
+ default_flags['data_dir'] = os.path.join(root_data_dir, NCF_DATA_DIR_NAME)
+
+ super(NCFKerasAccuracy, self).__init__(
+ output_dir=output_dir,
+ default_flags=default_flags,
+ **kwargs)
+
+ def _run_and_report_benchmark_mlperf_like(self):
+ """Run test and report results.
+
+    Note: MLPerf-like tests are not tuned to hit a specific hr@10 value, but
+    we still want it recorded.
+ """
+ self._run_and_report_benchmark(hr_at_10_min=0.61)
+
+ def _run_and_report_benchmark(self, hr_at_10_min=0.630, hr_at_10_max=0.645):
+ """Run test and report results.
+
+ Note: Target is 0.635, but some runs are below that level. Until we have
+ multi-run tests, we have to accept a lower target.
+
+ Args:
+ hr_at_10_min: Minimum acceptable hr@10 value.
+ hr_at_10_max: Maximum acceptable hr@10 value.
+ """
+ super(NCFKerasAccuracy, self)._run_and_report_benchmark(
+ hr_at_10_min=hr_at_10_min,
+ hr_at_10_max=hr_at_10_max)
+
+ def _set_8_gpu_defaults(self):
+ FLAGS.num_gpus = 8
+ FLAGS.learning_rate = 0.0045
+ FLAGS.beta1 = 0.25
+ FLAGS.beta2 = 0.5
+ FLAGS.epsilon = 1e-8
+ FLAGS.train_epochs = 14
+ FLAGS.batch_size = 99000
+ FLAGS.eval_batch_size = 160000
+ FLAGS.train_dataset_path = os.path.join(NCF_TF_REGRESSION_DATA_DIR_NAME,
+ 'training_cycle_*/*')
+ FLAGS.eval_dataset_path = os.path.join(NCF_TF_REGRESSION_DATA_DIR_NAME,
+ 'eval_data/*')
+ FLAGS.input_meta_data_path = os.path.join(NCF_TF_REGRESSION_DATA_DIR_NAME,
+ 'metadata')
+ FLAGS.data_dir = NCF_TF_REGRESSION_DATA_DIR_NAME
+
+ def benchmark_1_gpu_early_stop(self):
+ self._setup()
+ FLAGS.early_stopping = True
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_dist_strat_early_stop(self):
+ self._setup()
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.early_stopping = True
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_dist_strat_run_eagerly_early_stop(self):
+ self._setup()
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.early_stopping = True
+ FLAGS.run_eagerly = True
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_1_gpu_early_stop(self):
+ self._setup()
+ FLAGS.early_stopping = True
+ FLAGS.enable_xla = True
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_ctl_early_stop(self):
+ self._setup()
+ FLAGS.keras_use_ctl = True
+ FLAGS.early_stopping = True
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_ctl_run_eagerly_early_stop(self):
+ self._setup()
+ FLAGS.keras_use_ctl = True
+ FLAGS.early_stopping = True
+ FLAGS.run_eagerly = True
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_1_gpu_ctl_early_stop(self):
+ self._setup()
+ FLAGS.keras_use_ctl = True
+ FLAGS.early_stopping = True
+ FLAGS.enable_xla = True
+ self._run_and_report_benchmark()
+
+ def benchmark_2_gpus_early_stop(self):
+ self._setup()
+ FLAGS.early_stopping = True
+ FLAGS.num_gpus = 2
+ FLAGS.eval_batch_size = 160000
+ self._run_and_report_benchmark()
+
+ def benchmark_2_gpus_ctl_early_stop(self):
+ """NCF with custom training loop. Works only in TF 2.0."""
+ self._setup()
+ FLAGS.keras_use_ctl = True
+ FLAGS.early_stopping = True
+ FLAGS.num_gpus = 2
+ FLAGS.eval_batch_size = 160000
+ self._run_and_report_benchmark()
+
+#############################################
+# Tests below with mlperf in the test name are of two types:
+# 1) 1 GPU tests are based on MLPerf 0.5 and the TensorFlow pulled submission.
+# 2) 8 GPU tests are based on MLPerf 0.5 and use NVIDIA's hyper parameters.
+#
+# The purpose of both is to get a number to compare to existing results. To do
+# this, the number of epochs is held constant rather than racing to a given
+# accuracy. Accuracy validation is done by the "early_stop" tests.
+#############################################
+
+ def benchmark_1_gpu_mlperf_like(self):
+ """1 GPU using keras fit/compile."""
+ self._setup()
+ FLAGS.train_epochs = 7
+ self._run_and_report_benchmark_mlperf_like()
+
+ def benchmark_1_gpu_no_dist_strat_mlperf_like(self):
+ """1 GPU using compile/fit without dist_strat."""
+ self._setup()
+ FLAGS.train_epochs = 7
+ FLAGS.distribution_strategy = 'off'
+ self._run_and_report_benchmark_mlperf_like()
+
+ def benchmark_1_gpu_no_dist_strat_run_eagerly_mlperf_like(self):
+ self._setup()
+ FLAGS.train_epochs = 7
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.run_eagerly = True
+ self._run_and_report_benchmark_mlperf_like()
+
+ def benchmark_xla_1_gpu_mlperf_like(self):
+ """1 GPU using compile/fit with XLA."""
+ self._setup()
+ FLAGS.train_epochs = 7
+ FLAGS.enable_xla = True
+ self._run_and_report_benchmark_mlperf_like()
+
+ def benchmark_1_gpu_ctl_mlperf_like(self):
+ """1 GPU using CTL."""
+ self._setup()
+ FLAGS.keras_use_ctl = True
+ FLAGS.train_epochs = 7
+ self._run_and_report_benchmark_mlperf_like()
+
+ def benchmark_1_gpu_ctl_fp16_mlperf_like(self):
+ """1 GPU using CTL and FP16."""
+ self._setup()
+ FLAGS.keras_use_ctl = True
+ FLAGS.train_epochs = 7
+ FLAGS.dtype = 'fp16'
+ FLAGS.loss_scale = 8192
+ self._run_and_report_benchmark_mlperf_like()
+
+ def benchmark_1_gpu_fp16_mlperf_like(self):
+ """1 GPU using FP16."""
+ self._setup()
+ FLAGS.train_epochs = 7
+ FLAGS.dtype = 'fp16'
+ FLAGS.loss_scale = 8192
+ self._run_and_report_benchmark_mlperf_like()
+
+ def benchmark_1_gpu_ctl_fp16_graph_rewrite_mlperf_like(self):
+ """1 GPU using CTL and FP16 graph rewrite."""
+ self._setup()
+ FLAGS.keras_use_ctl = True
+ FLAGS.train_epochs = 7
+ FLAGS.dtype = 'fp16'
+ FLAGS.fp16_implementation = 'graph_rewrite'
+ FLAGS.loss_scale = 8192
+ self._run_and_report_benchmark_mlperf_like()
+
+ def benchmark_1_gpu_fp16_graph_rewrite_mlperf_like(self):
+ """1 GPU using FP16 graph rewrite."""
+ self._setup()
+ FLAGS.train_epochs = 7
+ FLAGS.dtype = 'fp16'
+ FLAGS.fp16_implementation = 'graph_rewrite'
+ FLAGS.loss_scale = 8192
+ self._run_and_report_benchmark_mlperf_like()
+
+ def benchmark_1_gpu_ctl_run_eagerly_mlperf_like(self):
+ """1 GPU using CTL with eager and distribution strategy."""
+ self._setup()
+ FLAGS.keras_use_ctl = True
+ FLAGS.run_eagerly = True
+ FLAGS.train_epochs = 7
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_1_gpu_ctl_mlperf_like(self):
+ """1 GPU using CTL with XLA."""
+ self._setup()
+ FLAGS.keras_use_ctl = True
+ FLAGS.enable_xla = True
+ FLAGS.train_epochs = 7
+ self._run_and_report_benchmark_mlperf_like()
+
+ def benchmark_xla_1_gpu_fp16_mlperf_like(self):
+ """1 GPU using with XLA and FP16."""
+ self._setup()
+ FLAGS.enable_xla = True
+ FLAGS.train_epochs = 7
+ FLAGS.dtype = 'fp16'
+ FLAGS.loss_scale = 8192
+ self._run_and_report_benchmark_mlperf_like()
+
+ def benchmark_xla_1_gpu_ctl_fp16_mlperf_like(self):
+ """1 GPU using CTL with XLA and FP16."""
+ self._setup()
+ FLAGS.keras_use_ctl = True
+ FLAGS.enable_xla = True
+ FLAGS.train_epochs = 7
+ FLAGS.dtype = 'fp16'
+ FLAGS.loss_scale = 8192
+ self._run_and_report_benchmark_mlperf_like()
+
+ def benchmark_8_gpu_mlperf_like(self):
+ """8 GPU using keras fit/compile."""
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.train_epochs = 17
+ FLAGS.batch_size = 1048576
+ FLAGS.eval_batch_size = 160000
+ FLAGS.learning_rate = 0.0045
+ FLAGS.beta1 = 0.25
+ FLAGS.beta2 = 0.5
+ FLAGS.epsilon = 1e-8
+ self._run_and_report_benchmark_mlperf_like()
+
+ def benchmark_8_gpu_ctl_mlperf_like(self):
+ """8 GPU using CTL."""
+ self._setup()
+ FLAGS.keras_use_ctl = True
+ FLAGS.num_gpus = 8
+ FLAGS.train_epochs = 17
+ FLAGS.batch_size = 1048576
+ FLAGS.eval_batch_size = 160000
+ FLAGS.learning_rate = 0.0045
+ FLAGS.beta1 = 0.25
+ FLAGS.beta2 = 0.5
+ FLAGS.epsilon = 1e-8
+ self._run_and_report_benchmark_mlperf_like()
+
+ def benchmark_8_gpu_tf_data_ctl_mlperf_like(self):
+ """8 GPU using CTL."""
+ self._setup()
+ self._set_8_gpu_defaults()
+ FLAGS.keras_use_ctl = True
+ self._run_and_report_benchmark_mlperf_like()
+
+ def benchmark_8_gpu_tf_data_fp16_mlperf_like(self):
+ """8 GPU FP16."""
+ self._setup()
+ self._set_8_gpu_defaults()
+ FLAGS.dtype = 'fp16'
+ FLAGS.loss_scale = 8192
+ self._run_and_report_benchmark_mlperf_like()
+
+ def benchmark_8_gpu_tf_data_ctl_fp16_mlperf_like(self):
+ """8 GPU FP16 using CTL."""
+ self._setup()
+ self._set_8_gpu_defaults()
+ FLAGS.keras_use_ctl = True
+ FLAGS.dtype = 'fp16'
+ FLAGS.loss_scale = 8192
+ self._run_and_report_benchmark_mlperf_like()
+
+ def benchmark_8_gpu_tf_data_ctl_fp16_graph_rewrite_mlperf_like(self):
+ """8 GPU FP16 graph rewrite using CTL."""
+ self._setup()
+ self._set_8_gpu_defaults()
+ FLAGS.keras_use_ctl = True
+ FLAGS.dtype = 'fp16'
+ FLAGS.fp16_implementation = 'graph_rewrite'
+ FLAGS.loss_scale = 8192
+ self._run_and_report_benchmark_mlperf_like()
+
+
+class NCFKerasBenchmarkReal(NCFKerasBenchmarkBase):
+ """NCF Keras throughput benchmarks."""
+
+ def __init__(self,
+ output_dir=None,
+ root_data_dir=None,
+ default_flags=None,
+ **kwargs):
+
+ root_data_dir = root_data_dir if root_data_dir else ''
+ default_flags = {}
+ default_flags['dataset'] = 'ml-20m'
+ default_flags['num_gpus'] = 1
+ default_flags['train_epochs'] = 14
+ default_flags['clean'] = True
+ default_flags['batch_size'] = 99000
+ default_flags['eval_batch_size'] = 160000
+ default_flags['learning_rate'] = 0.00382059
+ default_flags['beta1'] = 0.783529
+ default_flags['beta2'] = 0.909003
+ default_flags['epsilon'] = 1.45439e-07
+ default_flags['layers'] = [256, 256, 128, 64]
+ default_flags['num_factors'] = 64
+ default_flags['hr_threshold'] = 0.635
+ default_flags['ml_perf'] = True
+ default_flags['use_synthetic_data'] = False
+ default_flags['train_dataset_path'] = os.path.join(
+ NCF_TF_REGRESSION_DATA_DIR_NAME, 'training_cycle_*/*')
+ default_flags['eval_dataset_path'] = os.path.join(
+ NCF_TF_REGRESSION_DATA_DIR_NAME, 'eval_data/*')
+ default_flags['input_meta_data_path'] = os.path.join(
+ NCF_TF_REGRESSION_DATA_DIR_NAME, 'metadata')
+ default_flags['data_dir'] = NCF_TF_REGRESSION_DATA_DIR_NAME
+
+ super(NCFKerasBenchmarkReal, self).__init__(
+ output_dir=output_dir, default_flags=default_flags, **kwargs)
+
+ def benchmark_2x2_tpu(self):
+ """2x2 TPU using CTL with distribution strategy."""
+ self._setup()
+ FLAGS.distribution_strategy = 'tpu'
+ FLAGS.keras_use_ctl = True
+ FLAGS.num_gpus = 0
+ FLAGS.train_epochs = 1
+ self._run_and_report_benchmark()
+
+ @owner_utils.Owner('tf-graph-compiler')
+ def benchmark_2x2_tpu_mlir(self):
+ """2x2 TPU using CTL with distribution strategy using the MLIR bridge."""
+ self._setup()
+ FLAGS.distribution_strategy = 'tpu'
+ FLAGS.keras_use_ctl = True
+ FLAGS.num_gpus = 0
+ FLAGS.train_epochs = 1
+ tf.config.experimental.enable_mlir_bridge()
+ self._run_and_report_benchmark()
+
+
+class NCFKerasSynth(NCFKerasBenchmarkBase):
+ """Benchmark NCF model using synthetic data."""
+
+ def __init__(self,
+ output_dir=None,
+ default_flags=None,
+ **kwargs):
+
+ default_flags = {}
+ default_flags['dataset'] = 'ml-20m'
+ default_flags['num_gpus'] = 1
+ default_flags['train_epochs'] = 8
+ default_flags['batch_size'] = 99000
+ default_flags['eval_batch_size'] = 160000
+ default_flags['learning_rate'] = 0.00382059
+ default_flags['beta1'] = 0.783529
+ default_flags['beta2'] = 0.909003
+ default_flags['epsilon'] = 1.45439e-07
+ default_flags['layers'] = [256, 256, 128, 64]
+ default_flags['num_factors'] = 64
+ default_flags['hr_threshold'] = 0.635
+ default_flags['use_synthetic_data'] = True
+
+ super(NCFKerasSynth, self).__init__(
+ output_dir=output_dir,
+ default_flags=default_flags,
+ **kwargs)
+
+ def benchmark_1_gpu(self):
+ self._setup()
+ self._run_and_report_benchmark()
+
+ def benchmark_2_gpus(self):
+ self._setup()
+ FLAGS.num_gpus = 2
+ self._run_and_report_benchmark()
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/benchmark/nhnet_benchmark.py b/models/official/benchmark/nhnet_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eac36b204a4f064216fb4c81effff06d8c7e6f0
--- /dev/null
+++ b/models/official/benchmark/nhnet_benchmark.py
@@ -0,0 +1,148 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Executes benchmark testing for bert pretraining."""
+# pylint: disable=line-too-long
+from __future__ import print_function
+
+import time
+from typing import Optional
+
+from absl import flags
+import tensorflow as tf
+
+from official.benchmark import benchmark_wrappers
+from official.benchmark import owner_utils
+from official.benchmark import perfzero_benchmark
+from official.nlp.nhnet import trainer
+from official.utils.flags import core as flags_core
+
+MIN_LOSS = 0.40
+MAX_LOSS = 0.55
+NHNET_DATA = 'gs://tf-perfzero-data/nhnet/v1/processed/train.tfrecord*'
+PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12/bert_model.ckpt'
+
+FLAGS = flags.FLAGS
+
+
+class NHNetBenchmark(perfzero_benchmark.PerfZeroBenchmark):
+ """Base benchmark class for NHNet."""
+
+ def __init__(self, output_dir=None, default_flags=None, tpu=None, **kwargs):
+ self.default_flags = default_flags or {}
+ flag_methods = trainer.define_flags()
+ super(NHNetBenchmark, self).__init__(
+ output_dir=output_dir,
+ default_flags=default_flags,
+ flag_methods=flag_methods,
+ tpu=tpu,
+ **kwargs)
+
+ def _report_benchmark(self,
+ stats,
+ wall_time_sec,
+ max_value=None,
+ min_value=None):
+ """Report benchmark results by writing to local protobuf file.
+
+ Args:
+ stats: dict returned from keras models with known entries.
+ wall_time_sec: the duration of the benchmark execution in seconds
+ max_value: highest passing level.
+ min_value: lowest passing level.
+ """
+
+ metrics = []
+ metrics.append({
+ 'name': 'training_loss',
+ 'value': stats['training_loss'],
+ 'min_value': min_value,
+ 'max_value': max_value
+ })
+ # These metrics are placeholders to avoid PerfZero failure.
+ metrics.append({
+ 'name': 'exp_per_second',
+ 'value': 0.0,
+ })
+ metrics.append({
+ 'name': 'startup_time',
+ 'value': 9999.,
+ })
+ flags_str = flags_core.get_nondefault_flags_as_str()
+ self.report_benchmark(
+ iters=-1,
+ wall_time=wall_time_sec,
+ metrics=metrics,
+ extras={'flags': flags_str})
+
+
+class NHNetAccuracyBenchmark(NHNetBenchmark):
+ """Benchmark accuracy tests for NHNet."""
+
+ def __init__(self,
+ output_dir: Optional[str] = None,
+ tpu: Optional[str] = None,
+ **kwargs):
+ default_flags = dict(
+ mode='train',
+ train_file_pattern=NHNET_DATA,
+ train_batch_size=1024,
+ model_type='nhnet',
+ len_title=15,
+ len_passage=200,
+ num_encoder_layers=12,
+ num_decoder_layers=12,
+ num_nhnet_articles=5,
+ steps_per_loop=1000,
+ params_override='init_from_bert2bert=false')
+ super(NHNetAccuracyBenchmark, self).__init__(
+ output_dir=output_dir, default_flags=default_flags, tpu=tpu, **kwargs)
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self, max_value=MAX_LOSS, min_value=MIN_LOSS):
+ """Runs and reports the benchmark given the provided configuration."""
+ start_time_sec = time.time()
+ stats = trainer.run()
+ wall_time_sec = time.time() - start_time_sec
+ self._report_benchmark(
+ stats, wall_time_sec, max_value=max_value, min_value=min_value)
+
+ @owner_utils.Owner('tf-model-garden')
+ def benchmark_accuracy_4x4_tpu_f32_50k_steps(self):
+ """Test bert pretraining with 4x4 TPU for 50k steps."""
+ # This is used for accuracy test.
+ self._setup()
+ FLAGS.train_steps = 50000
+ FLAGS.checkpoint_interval = FLAGS.train_steps
+ FLAGS.distribution_strategy = 'tpu'
+ FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_accuracy_4x4_tpu_bf32_50k_steps')
+ self._run_and_report_benchmark()
+
+ @owner_utils.Owner('tf-model-garden')
+ def benchmark_accuracy_4x4_tpu_f32_1k_steps(self):
+ """Test bert pretraining with 4x4 TPU for 1k steps."""
+ self._setup()
+ FLAGS.train_steps = 1000
+ FLAGS.checkpoint_interval = FLAGS.train_steps
+ FLAGS.distribution_strategy = 'tpu'
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_accuracy_4x4_tpu_bf32_1k_steps')
+ self._run_and_report_benchmark()
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/benchmark/owner_utils.py b/models/official/benchmark/owner_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7d189d7b9a2ba05a0bd3af8cb970d52cc85f5a0
--- /dev/null
+++ b/models/official/benchmark/owner_utils.py
@@ -0,0 +1,67 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils to set Owner annotations on benchmarks.
+
+@owner_utils.Owner('owner_team/user') can be set at the benchmark class level,
+at the benchmark method level, or at both.
+
+Runner frameworks can use owner_utils.GetOwner(benchmark_method) to get the
+actual owner. Python inheritance for the owner attribute is respected (e.g. a
+method-level owner takes precedence over a class-level owner).
+
+See owner_utils_test for associated tests and more examples.
+
+The decorator can be applied both at the method level and at the class level.
+
+Simple example:
+===============
+
+class MLBenchmark:
+
+ @Owner('example_id')
+ def benchmark_method_1_gpu(self):
+ return True
+"""
+
+
+def Owner(owner_name):
+ """Sets the owner attribute on a decorated method or class."""
+
+ def _Wrapper(func_or_class):
+ """Sets the benchmark owner attribute."""
+ func_or_class.__benchmark__owner__ = owner_name
+ return func_or_class
+
+ return _Wrapper
+
+
+def GetOwner(benchmark_method_or_class):
+ """Gets the inherited owner attribute for this benchmark.
+
+ Checks for existence of __benchmark__owner__ on the method itself. If it's
+ not present, looks for it on the method's bound instance (and thus its
+ class).
+
+ Args:
+ benchmark_method_or_class: A benchmark method or class.
+
+ Returns:
+ The associated owner string if present, otherwise None.
+ """
+ if hasattr(benchmark_method_or_class, '__benchmark__owner__'):
+ return benchmark_method_or_class.__benchmark__owner__
+ elif hasattr(benchmark_method_or_class, '__self__'):
+ if hasattr(benchmark_method_or_class.__self__, '__benchmark__owner__'):
+ return benchmark_method_or_class.__self__.__benchmark__owner__
+ return None
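+# Hypothetical runner-side usage (illustrative only, not part of any runner
+# framework): owner resolution prefers a method-level Owner over the
+# class-level one, and returns None when neither is set.
+#
+#   @Owner('team-a')
+#   class MyBenchmark:
+#
+#     def benchmark_1_gpu(self):
+#       pass
+#
+#     @Owner('user-b')
+#     def benchmark_8_gpu(self):
+#       pass
+#
+#   bench = MyBenchmark()
+#   GetOwner(bench.benchmark_1_gpu)  # -> 'team-a' (inherited from the class)
+#   GetOwner(bench.benchmark_8_gpu)  # -> 'user-b' (method-level override)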
diff --git a/models/official/benchmark/owner_utils_test.py b/models/official/benchmark/owner_utils_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..588bb80378fbf7ba5a6aec470f24fc1c4ad995b2
--- /dev/null
+++ b/models/official/benchmark/owner_utils_test.py
@@ -0,0 +1,104 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for official.benchmark.owner_utils."""
+
+from absl.testing import absltest
+
+from official.benchmark import owner_utils
+
+
+@owner_utils.Owner('static_owner')
+def static_function(foo=5):
+ return foo
+
+
+def static_function_without_owner(foo=5):
+ return foo
+
+
+class BenchmarkClassWithoutOwner:
+
+ def method_without_owner(self):
+ return 100
+
+ @owner_utils.Owner('method_owner')
+ def method_with_owner(self):
+ return 200
+
+
+@owner_utils.Owner('class_owner')
+class SomeBenchmarkClass:
+
+ def method_inherited_owner(self):
+ return 123
+
+ @owner_utils.Owner('method_owner')
+ def method_override_owner(self):
+ return 345
+
+
+@owner_utils.Owner('new_class_owner')
+class InheritedClass(SomeBenchmarkClass):
+
+ def method_inherited_owner(self):
+ return 456
+
+ @owner_utils.Owner('new_method_owner')
+ def method_override_owner(self):
+ return 567
+
+
+class OwnerUtilsTest(absltest.TestCase):
+ """Tests to assert for owner decorator functionality."""
+
+ def test_owner_tag_missing(self):
+ self.assertEqual(None, owner_utils.GetOwner(static_function_without_owner))
+
+ benchmark_class = BenchmarkClassWithoutOwner()
+ self.assertEqual(None,
+ owner_utils.GetOwner(benchmark_class.method_without_owner))
+ self.assertEqual(100, benchmark_class.method_without_owner())
+
+ self.assertEqual('method_owner',
+ owner_utils.GetOwner(benchmark_class.method_with_owner))
+ self.assertEqual(200, benchmark_class.method_with_owner())
+
+ def test_owner_attributes_static(self):
+ self.assertEqual('static_owner', owner_utils.GetOwner(static_function))
+ self.assertEqual(5, static_function(5))
+
+ def test_owner_attributes_per_class(self):
+ level1 = SomeBenchmarkClass()
+ self.assertEqual('class_owner',
+ owner_utils.GetOwner(level1.method_inherited_owner))
+ self.assertEqual(123, level1.method_inherited_owner())
+
+ self.assertEqual('method_owner',
+ owner_utils.GetOwner(level1.method_override_owner))
+ self.assertEqual(345, level1.method_override_owner())
+
+ def test_owner_attributes_inherited_class(self):
+ level2 = InheritedClass()
+ self.assertEqual('new_class_owner',
+ owner_utils.GetOwner(level2.method_inherited_owner))
+ self.assertEqual(456, level2.method_inherited_owner())
+
+ self.assertEqual('new_method_owner',
+ owner_utils.GetOwner(level2.method_override_owner))
+ self.assertEqual(567, level2.method_override_owner())
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/models/official/benchmark/perfzero_benchmark.py b/models/official/benchmark/perfzero_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..bedc1320217d1b9469333a8cdfdf70c56de34f77
--- /dev/null
+++ b/models/official/benchmark/perfzero_benchmark.py
@@ -0,0 +1,100 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for creating PerfZero benchmarks."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import flags
+from absl import logging
+from absl.testing import flagsaver
+import tensorflow as tf
+
+FLAGS = flags.FLAGS
+
+
+class PerfZeroBenchmark(tf.test.Benchmark):
+ """Common methods used in PerfZero Benchmarks.
+
+ Handles resetting flags between tests, loading default_flags, and
+ overriding defaults. PerfZero (OSS) runs each test in a separate process,
+ which reduces some of the need to reset flags.
+ """
+ local_flags = None
+
+ def __init__(self,
+ output_dir=None,
+ default_flags=None,
+ root_data_dir=None,
+ flag_methods=None,
+ tpu=None):
+ """Initialize class.
+
+ Args:
+ output_dir: Base directory to store all output for the test.
+ default_flags: Set of flags to pass to model.
+ root_data_dir: Optional param used by child classes to look for the
+ dataset.
+ flag_methods: Set of flag methods to run during setup.
+ tpu: (optional) TPU name to use in a TPU benchmark.
+ """
+ if os.getenv('BENCHMARK_OUTPUT_DIR'):
+ self.output_dir = os.getenv('BENCHMARK_OUTPUT_DIR')
+ elif output_dir:
+ self.output_dir = output_dir
+ else:
+ self.output_dir = '/tmp'
+ self.default_flags = default_flags or {}
+ self.flag_methods = flag_methods or {}
+
+ if os.getenv('BENCHMARK_TPU'):
+ resolved_tpu = os.getenv('BENCHMARK_TPU')
+ elif tpu:
+ resolved_tpu = tpu
+ else:
+ resolved_tpu = None
+
+ if resolved_tpu:
+ # TPU models are expected to accept a --tpu=name flag. PerfZero creates
+ # the TPU at runtime and passes the TPU's name to this flag.
+ self.default_flags['tpu'] = resolved_tpu
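+ # Example: running with BENCHMARK_TPU=my-tpu (a hypothetical name) resolves
+ # to default_flags['tpu'] = 'my-tpu', taking precedence over the `tpu`
+ # constructor argument.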
+
+ logging.info('root_data_dir: %s', root_data_dir)
+
+ @property
+ def tpu(self):
+ return self.default_flags.get('tpu', None)
+
+ def _get_model_dir(self, folder_name):
+ """Returns directory to store info, e.g. saved model and event log."""
+ return os.path.join(self.output_dir, folder_name)
+
+ def _setup(self):
+ """Sets up and resets flags before each test."""
+ logging.set_verbosity(logging.INFO)
+ if PerfZeroBenchmark.local_flags is None:
+ for flag_method in self.flag_methods:
+ flag_method()
+ # Parses a non-empty argv list so flag defaults are loaded and can then be
+ # overridden.
+ flags.FLAGS(['foo'])
+ # Overrides flag values with defaults for the class of tests.
+ for k, v in self.default_flags.items():
+ setattr(FLAGS, k, v)
+ saved_flag_values = flagsaver.save_flag_values()
+ PerfZeroBenchmark.local_flags = saved_flag_values
+ else:
+ flagsaver.restore_flag_values(PerfZeroBenchmark.local_flags)
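+# Minimal subclassing sketch (illustrative only; `my_model` and its flags are
+# hypothetical): flag_methods define the flags, default_flags pre-populate
+# them, and each benchmark method calls _setup() before applying per-test
+# overrides.
+#
+#   class MyModelBenchmark(PerfZeroBenchmark):
+#
+#     def __init__(self, output_dir=None, **kwargs):
+#       super(MyModelBenchmark, self).__init__(
+#           output_dir=output_dir,
+#           default_flags={'train_epochs': 1},
+#           flag_methods=[my_model.define_flags])
+#
+#     def benchmark_1_gpu(self):
+#       self._setup()        # resets flags and applies default_flags
+#       FLAGS.num_gpus = 1   # per-test override after _setup()
+#       stats = my_model.run(FLAGS)
+#       self.report_benchmark(iters=-1, wall_time=0.0, metrics=[])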
diff --git a/models/official/benchmark/resnet_ctl_imagenet_benchmark.py b/models/official/benchmark/resnet_ctl_imagenet_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e70e8da969ec9b02a2de00d1973bdd2aa5f2b51
--- /dev/null
+++ b/models/official/benchmark/resnet_ctl_imagenet_benchmark.py
@@ -0,0 +1,452 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Executes CTL benchmarks and accuracy tests."""
+# pylint: disable=line-too-long,g-bad-import-order
+from __future__ import print_function
+
+import os
+import time
+
+from absl import flags
+import tensorflow as tf
+
+from official.benchmark import owner_utils
+from official.vision.image_classification.resnet import common
+from official.vision.image_classification.resnet import resnet_ctl_imagenet_main
+from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
+from official.benchmark import benchmark_wrappers
+from official.utils.flags import core as flags_core
+
+MIN_TOP_1_ACCURACY = 0.76
+MAX_TOP_1_ACCURACY = 0.77
+
+FLAGS = flags.FLAGS
+
+
+class CtlBenchmark(PerfZeroBenchmark):
+ """Base benchmark class with methods to simplify testing."""
+
+ def __init__(self, output_dir=None, default_flags=None, flag_methods=None):
+ self.default_flags = default_flags or {}
+ self.flag_methods = flag_methods or {}
+ super(CtlBenchmark, self).__init__(
+ output_dir=output_dir,
+ default_flags=self.default_flags,
+ flag_methods=self.flag_methods)
+
+ def _report_benchmark(self,
+ stats,
+ wall_time_sec,
+ top_1_max=None,
+ top_1_min=None,
+ total_batch_size=None,
+ log_steps=None,
+ warmup=1,
+ start_time_sec=None):
+ """Report benchmark results by writing to local protobuf file.
+
+ Args:
+ stats: dict returned from keras models with known entries.
+ wall_time_sec: the duration of the benchmark execution in seconds
+ top_1_max: highest passing level for top_1 accuracy.
+ top_1_min: lowest passing level for top_1 accuracy.
+ total_batch_size: Global batch-size.
+ log_steps: How often the log was created for stats['step_timestamp_log'].
+ warmup: number of entries in stats['step_timestamp_log'] to ignore.
+ start_time_sec: the start time of the program in seconds since epoch.
+ """
+
+ metrics = []
+ if 'eval_acc' in stats:
+ metrics.append({
+ 'name': 'accuracy_top_1',
+ 'value': stats['eval_acc'],
+ 'min_value': top_1_min,
+ 'max_value': top_1_max
+ })
+ metrics.append({'name': 'eval_loss', 'value': stats['eval_loss']})
+
+ metrics.append({
+ 'name': 'top_1_train_accuracy',
+ 'value': stats['train_acc']
+ })
+ metrics.append({'name': 'train_loss', 'value': stats['train_loss']})
+
+ if (warmup and 'step_timestamp_log' in stats and
+ len(stats['step_timestamp_log']) > warmup + 1):
+ # The first entry in the time_log is the start of step 0; the remaining
+ # entries record the end of each logged step.
+ time_log = stats['step_timestamp_log']
+ steps_elapsed = time_log[-1].batch_index - time_log[warmup].batch_index
+ time_elapsed = time_log[-1].timestamp - time_log[warmup].timestamp
+ examples_per_sec = total_batch_size * (steps_elapsed / time_elapsed)
+ metrics.append({'name': 'exp_per_second', 'value': examples_per_sec})
+
+ if 'avg_exp_per_second' in stats:
+ metrics.append({
+ 'name': 'avg_exp_per_second',
+ 'value': stats['avg_exp_per_second']
+ })
+
+ if start_time_sec and 'step_timestamp_log' in stats:
+ time_log = stats['step_timestamp_log']
+ # time_log[0] is recorded at the beginning of the first step.
+ startup_time = time_log[0].timestamp - start_time_sec
+ metrics.append({'name': 'startup_time', 'value': startup_time})
+
+ flags_str = flags_core.get_nondefault_flags_as_str()
+ self.report_benchmark(
+ iters=-1,
+ wall_time=wall_time_sec,
+ metrics=metrics,
+ extras={'flags': flags_str})
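+ # Worked example (hypothetical numbers): with total_batch_size=1024,
+ # warmup=1, and step_timestamp_log entries at batch_index 0, 100, 200, 300,
+ # steps_elapsed = 300 - 100 = 200; if the corresponding timestamps span
+ # 10 seconds, exp_per_second = 1024 * (200 / 10) = 20480.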
+
+
+class Resnet50CtlAccuracy(CtlBenchmark):
+ """Benchmark accuracy tests for ResNet50 in CTL."""
+
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
+ """A benchmark class.
+
+ Args:
+ output_dir: directory where to output e.g. log files
+ root_data_dir: directory under which to look for dataset
+ **kwargs: arbitrary named arguments. This is needed to make the
+ constructor forward compatible in case PerfZero provides more named
+ arguments before updating the constructor.
+ """
+
+ flag_methods = [common.define_keras_flags]
+
+ self.data_dir = os.path.join(root_data_dir, 'imagenet')
+ super(Resnet50CtlAccuracy, self).__init__(
+ output_dir=output_dir, flag_methods=flag_methods)
+
+ def benchmark_8_gpu(self):
+ """Test Keras model with eager, dist_strat and 8 GPUs."""
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.data_dir = self.data_dir
+ FLAGS.batch_size = 128 * 8
+ FLAGS.train_epochs = 90
+ FLAGS.epochs_between_evals = 10
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
+ FLAGS.dtype = 'fp32'
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu_fp16(self):
+ """Test Keras model with eager, 8 GPUs with tf.keras mixed precision."""
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.data_dir = self.data_dir
+ FLAGS.batch_size = 256 * 8
+ FLAGS.train_epochs = 90
+ FLAGS.epochs_between_evals = 10
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
+ FLAGS.dtype = 'fp16'
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu_amp(self):
+ """Test Keras model with 8 GPUs and mixed precision via graph rewrite."""
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.data_dir = self.data_dir
+ FLAGS.batch_size = 256 * 8
+ FLAGS.train_epochs = 90
+ FLAGS.epochs_between_evals = 10
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
+ FLAGS.dtype = 'fp16'
+ FLAGS.fp16_implementation = 'graph_rewrite'
+ self._run_and_report_benchmark()
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self):
+ start_time_sec = time.time()
+ stats = resnet_ctl_imagenet_main.run(flags.FLAGS)
+ wall_time_sec = time.time() - start_time_sec
+
+ super(Resnet50CtlAccuracy, self)._report_benchmark(
+ stats,
+ wall_time_sec,
+ top_1_min=MIN_TOP_1_ACCURACY,
+ top_1_max=MAX_TOP_1_ACCURACY,
+ total_batch_size=FLAGS.batch_size,
+ log_steps=100,
+ start_time_sec=start_time_sec)
+
+
+class Resnet50CtlBenchmarkBase(CtlBenchmark):
+ """Resnet50 benchmarks."""
+
+ def __init__(self, output_dir=None, default_flags=None):
+ flag_methods = [common.define_keras_flags]
+
+ super(Resnet50CtlBenchmarkBase, self).__init__(
+ output_dir=output_dir,
+ flag_methods=flag_methods,
+ default_flags=default_flags)
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self):
+ start_time_sec = time.time()
+ stats = resnet_ctl_imagenet_main.run(FLAGS)
+ wall_time_sec = time.time() - start_time_sec
+
+ # Warmup is the number of logged step-time entries excluded from the
+ # performance report. By default, one FLAGS.log_steps interval is excluded.
+ super(Resnet50CtlBenchmarkBase, self)._report_benchmark(
+ stats,
+ wall_time_sec,
+ total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps,
+ warmup=1,
+ start_time_sec=start_time_sec)
+
+ def benchmark_1_gpu_no_dist_strat(self):
+ """Test Keras model with 1 GPU, no distribution strategy."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
+ FLAGS.batch_size = 128
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu(self):
+ """Test Keras model with 1 GPU."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
+ FLAGS.batch_size = 128
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_fp16(self):
+ """Test Keras model with 1 GPU with tf.keras mixed precision."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16')
+ FLAGS.batch_size = 256
+ FLAGS.dtype = 'fp16'
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_amp(self):
+ """Test Keras model with 1 GPU with automatic mixed precision."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp')
+ FLAGS.batch_size = 256
+ FLAGS.dtype = 'fp16'
+ FLAGS.fp16_implementation = 'graph_rewrite'
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_1_gpu_amp(self):
+ """Test Keras model with XLA and 1 GPU with automatic mixed precision."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_amp')
+ FLAGS.batch_size = 256
+ FLAGS.dtype = 'fp16'
+ FLAGS.fp16_implementation = 'graph_rewrite'
+ FLAGS.enable_xla = True
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_eager(self):
+ """Test Keras model with 1 GPU in pure eager mode."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_eager')
+ FLAGS.batch_size = 120
+ FLAGS.use_tf_function = False
+ FLAGS.use_tf_while_loop = False
+ FLAGS.single_l2_loss_op = True
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_fp16_eager(self):
+ """Test Keras model with 1 GPU with fp16 and pure eager mode."""
+ self._setup()
+
+ FLAGS.num_gpus = 1
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16_eager')
+ FLAGS.batch_size = 240
+ FLAGS.dtype = 'fp16'
+ FLAGS.use_tf_function = False
+ FLAGS.use_tf_while_loop = False
+ FLAGS.single_l2_loss_op = True
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu(self):
+ """Test Keras model with 8 GPUs."""
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
+ FLAGS.batch_size = 128 * 8 # 8 GPUs
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu_fp16(self):
+ """Test Keras model with 8 GPUs with tf.keras mixed precision."""
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
+ FLAGS.batch_size = 256 * 8 # 8 GPUs
+ FLAGS.dtype = 'fp16'
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu_eager(self):
+ """Test Keras model with 8 GPUs, eager, fp32."""
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.use_tf_function = False
+ FLAGS.use_tf_while_loop = False
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_eager')
+ FLAGS.batch_size = 128
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu_eager_fp16(self):
+ """Test Keras model with 8 GPUs, eager, fp16."""
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.dtype = 'fp16'
+ FLAGS.use_tf_function = False
+ FLAGS.use_tf_while_loop = False
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_eager_fp16')
+ FLAGS.batch_size = 128
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu_amp(self):
+ """Test Keras model with 8 GPUs with automatic mixed precision."""
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
+ FLAGS.batch_size = 256 * 8 # 8 GPUs
+ FLAGS.dtype = 'fp16'
+ FLAGS.fp16_implementation = 'graph_rewrite'
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_8_gpu_amp(self):
+ """Test Keras model with XLA and 8 GPUs with automatic mixed precision."""
+ self._setup()
+
+ FLAGS.num_gpus = 8
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_amp')
+ FLAGS.batch_size = 256 * 8 # 8 GPUs
+ FLAGS.dtype = 'fp16'
+ FLAGS.fp16_implementation = 'graph_rewrite'
+ FLAGS.enable_xla = True
+ self._run_and_report_benchmark()
+
+ def _set_df_common(self):
+ FLAGS.steps_per_loop = 500
+ FLAGS.train_epochs = 2
+ FLAGS.train_steps = None
+ FLAGS.skip_eval = True
+ FLAGS.enable_eager = True
+ FLAGS.enable_tensorboard = False
+ FLAGS.distribution_strategy = 'tpu'
+ FLAGS.report_accuracy_metrics = False
+ FLAGS.log_steps = 50
+ FLAGS.single_l2_loss_op = True
+ FLAGS.use_tf_function = True
+ FLAGS.enable_checkpoint_and_export = False
+
+ def benchmark_2x2_tpu_bf16(self):
+ self._setup()
+ self._set_df_common()
+ FLAGS.batch_size = 1024
+ FLAGS.dtype = 'bf16'
+ self._run_and_report_benchmark()
+
+ def benchmark_4x4_tpu_bf16(self):
+ self._setup()
+ self._set_df_common()
+ FLAGS.batch_size = 4096
+ FLAGS.dtype = 'bf16'
+ self._run_and_report_benchmark()
+
+ @owner_utils.Owner('tf-graph-compiler')
+ def benchmark_4x4_tpu_bf16_mlir(self):
+ """Run resnet model on 4x4 with the MLIR Bridge enabled."""
+ self._setup()
+ self._set_df_common()
+ FLAGS.batch_size = 4096
+ FLAGS.dtype = 'bf16'
+ tf.config.experimental.enable_mlir_bridge()
+ self._run_and_report_benchmark()
+
+ def benchmark_8x16_tpu_bf16(self):
+ self._setup()
+ self._set_df_common()
+ FLAGS.batch_size = 8192
+ FLAGS.dtype = 'bf16'
+ self._run_and_report_benchmark()
+
+ def fill_report_object(self, stats):
+ super(Resnet50CtlBenchmarkBase, self).fill_report_object(
+ stats, total_batch_size=FLAGS.batch_size, log_steps=FLAGS.log_steps)
+
+
+class Resnet50CtlBenchmarkSynth(Resnet50CtlBenchmarkBase):
+ """Resnet50 synthetic benchmark tests."""
+
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
+ def_flags = {}
+ def_flags['skip_eval'] = True
+ def_flags['use_synthetic_data'] = True
+ def_flags['train_steps'] = 110
+ def_flags['steps_per_loop'] = 20
+ def_flags['log_steps'] = 10
+
+ super(Resnet50CtlBenchmarkSynth, self).__init__(
+ output_dir=output_dir, default_flags=def_flags)
+
+
+class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
+ """Resnet50 real data benchmark tests."""
+
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
+ def_flags = {}
+ def_flags['skip_eval'] = True
+ def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
+ def_flags['train_steps'] = 110
+ def_flags['steps_per_loop'] = 20
+ def_flags['log_steps'] = 10
+
+ super(Resnet50CtlBenchmarkReal, self).__init__(
+ output_dir=output_dir, default_flags=def_flags)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/benchmark/retinanet_benchmark.py b/models/official/benchmark/retinanet_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..62bc80eef1fd00d5087af5522561ff7cf7863f5e
--- /dev/null
+++ b/models/official/benchmark/retinanet_benchmark.py
@@ -0,0 +1,276 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Executes RetinaNet benchmarks and accuracy tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=g-bad-import-order
+import json
+import time
+
+from absl import flags
+from absl.testing import flagsaver
+import tensorflow as tf
+# pylint: enable=g-bad-import-order
+
+from official.benchmark import benchmark_wrappers
+from official.benchmark import perfzero_benchmark
+from official.utils.flags import core as flags_core
+from official.utils.misc import keras_utils
+from official.vision.detection import main as detection
+from official.vision.detection.configs import base_config
+
+FLAGS = flags.FLAGS
+
+# pylint: disable=line-too-long
+COCO_TRAIN_DATA = 'gs://tf-perfzero-data/coco/train*'
+COCO_EVAL_DATA = 'gs://tf-perfzero-data/coco/val*'
+COCO_EVAL_JSON = 'gs://tf-perfzero-data/coco/instances_val2017.json'
+RESNET_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/retinanet/resnet50-checkpoint-2018-02-07'
+# pylint: enable=line-too-long
+
+
+class DetectionBenchmarkBase(perfzero_benchmark.PerfZeroBenchmark):
+ """Base class to hold methods common to test classes."""
+
+ def __init__(self, **kwargs):
+ super(DetectionBenchmarkBase, self).__init__(**kwargs)
+ self.timer_callback = None
+
+ def _report_benchmark(self, stats, start_time_sec, wall_time_sec, min_ap,
+ max_ap, warmup):
+ """Report benchmark results by writing to local protobuf file.
+
+ Args:
+ stats: dict returned from Detection models with known entries.
+ start_time_sec: the start of the benchmark execution in seconds
+ wall_time_sec: the duration of the benchmark execution in seconds
+ min_ap: Minimum detection AP constraint to verify correctness of the
+ model.
+ max_ap: Maximum detection AP accuracy constraint to verify correctness of
+ the model.
+ warmup: Number of time log entries to ignore when computing examples/sec.
+ """
+ metrics = [{
+ 'name': 'total_loss',
+ 'value': stats['total_loss'],
+ }]
+ if self.timer_callback:
+ metrics.append({
+ 'name': 'exp_per_second',
+ 'value': self.timer_callback.get_examples_per_sec(warmup)
+ })
+ metrics.append({
+ 'name': 'startup_time',
+ 'value': self.timer_callback.get_startup_time(start_time_sec)
+ })
+ else:
+ metrics.append({
+ 'name': 'exp_per_second',
+ 'value': 0.0,
+ })
+
+ if 'eval_metrics' in stats:
+ metrics.append({
+ 'name': 'AP',
+ 'value': stats['AP'],
+ 'min_value': min_ap,
+ 'max_value': max_ap,
+ })
+ flags_str = flags_core.get_nondefault_flags_as_str()
+ self.report_benchmark(
+ iters=stats['total_steps'],
+ wall_time=wall_time_sec,
+ metrics=metrics,
+ extras={'flags': flags_str})
+
+
+class RetinanetBenchmarkBase(DetectionBenchmarkBase):
+ """Base class to hold methods common to test classes in the module."""
+
+ def __init__(self, **kwargs):
+ self.train_data_path = COCO_TRAIN_DATA
+ self.eval_data_path = COCO_EVAL_DATA
+ self.eval_json_path = COCO_EVAL_JSON
+ self.resnet_checkpoint_path = RESNET_CHECKPOINT_PATH
+ super(RetinanetBenchmarkBase, self).__init__(**kwargs)
+
+ def _run_detection_main(self):
+ """Starts detection job."""
+ if self.timer_callback:
+ FLAGS.log_steps = 0 # prevent detection.run from adding the same callback
+ return detection.run(callbacks=[self.timer_callback])
+ else:
+ return detection.run()
+
+
+class RetinanetAccuracy(RetinanetBenchmarkBase):
+ """Accuracy test for RetinaNet model.
+
+ Tests RetinaNet detection model accuracy. The test cases below follow the
+ `benchmark_(number of gpus)_gpu_(dataset type)` naming convention.
+ """
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self,
+ params,
+ min_ap=0.325,
+ max_ap=0.35,
+ do_eval=True,
+ warmup=1):
+ """Starts RetinaNet accuracy benchmark test."""
+ FLAGS.params_override = json.dumps(params)
+ # Need timer callback to measure performance
+ self.timer_callback = keras_utils.TimeHistory(
+ batch_size=params['train']['batch_size'],
+ log_steps=FLAGS.log_steps,
+ )
+
+ start_time_sec = time.time()
+ FLAGS.mode = 'train'
+ summary, _ = self._run_detection_main()
+ wall_time_sec = time.time() - start_time_sec
+
+ if do_eval:
+ FLAGS.mode = 'eval'
+ eval_metrics = self._run_detection_main()
+ summary.update(eval_metrics)
+
+ summary['total_steps'] = params['train']['total_steps']
+ self._report_benchmark(summary, start_time_sec, wall_time_sec, min_ap,
+ max_ap, warmup)
+
+ def _setup(self):
+ super(RetinanetAccuracy, self)._setup()
+ FLAGS.model = 'retinanet'
+
+ def _params(self):
+ return {
+ 'architecture': {
+ 'use_bfloat16': True,
+ },
+ 'train': {
+ 'batch_size': 64,
+ 'iterations_per_loop': 100,
+ 'total_steps': 22500,
+ 'train_file_pattern': self.train_data_path,
+ 'checkpoint': {
+ 'path': self.resnet_checkpoint_path,
+ 'prefix': 'resnet50/'
+ },
+ # Speed up ResNet training when loading from the checkpoint.
+ 'frozen_variable_prefix': base_config.RESNET_FROZEN_VAR_PREFIX,
+ },
+ 'eval': {
+ 'batch_size': 8,
+ 'eval_samples': 5000,
+ 'val_json_file': self.eval_json_path,
+ 'eval_file_pattern': self.eval_data_path,
+ },
+ }
+
+ @flagsaver.flagsaver
+ def benchmark_8_gpu_coco(self):
+ """Run RetinaNet model accuracy test with 8 GPUs."""
+ self._setup()
+ params = self._params()
+ FLAGS.num_gpus = 8
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_coco')
+ FLAGS.strategy_type = 'mirrored'
+ self._run_and_report_benchmark(params)
+
+
+class RetinanetBenchmarkReal(RetinanetAccuracy):
+ """Short benchmark performance tests for RetinaNet model.
+
+ Tests RetinaNet performance in different GPU configurations. The test cases
+ below follow the `benchmark_(number of gpus)_gpu` naming convention.
+ """
+
+ def _setup(self):
+ super(RetinanetBenchmarkReal, self)._setup()
+ # Use negative value to avoid saving checkpoints.
+ FLAGS.save_checkpoint_freq = -1
+
+ @flagsaver.flagsaver
+ def benchmark_8_gpu_coco(self):
+ """Run RetinaNet model accuracy test with 8 GPUs."""
+ self._setup()
+ params = self._params()
+ params['architecture']['use_bfloat16'] = False
+ params['train']['total_steps'] = 1875 # One epoch.
+ # iterations_per_loop must be one; otherwise the examples-per-second figure
+ # would be wrong. The timing callback can currently only be invoked per
+ # batch when each loop runs a single batch, i.e. the host loop executes one
+ # step. Performance in this configuration may be lower than with
+ # iterations_per_loop > 1.
+ # Related bug: b/135933080
+ params['train']['iterations_per_loop'] = 1
+ params['eval']['eval_samples'] = 8
+ FLAGS.num_gpus = 8
+ FLAGS.model_dir = self._get_model_dir('real_benchmark_8_gpu_coco')
+ FLAGS.strategy_type = 'mirrored'
+ self._run_and_report_benchmark(params)
+
+ @flagsaver.flagsaver
+ def benchmark_1_gpu_coco(self):
+ """Run RetinaNet model accuracy test with 1 GPU."""
+ self._setup()
+ params = self._params()
+ params['architecture']['use_bfloat16'] = False
+ params['train']['batch_size'] = 8
+ params['train']['total_steps'] = 200
+ params['train']['iterations_per_loop'] = 1
+ params['eval']['eval_samples'] = 8
+ FLAGS.num_gpus = 1
+ FLAGS.model_dir = self._get_model_dir('real_benchmark_1_gpu_coco')
+ FLAGS.strategy_type = 'one_device'
+ self._run_and_report_benchmark(params)
+
+ @flagsaver.flagsaver
+ def benchmark_xla_1_gpu_coco(self):
+ """Run RetinaNet model accuracy test with 1 GPU and XLA enabled."""
+ self._setup()
+ params = self._params()
+ params['architecture']['use_bfloat16'] = False
+ params['train']['batch_size'] = 8
+ params['train']['total_steps'] = 200
+ params['train']['iterations_per_loop'] = 1
+ params['eval']['eval_samples'] = 8
+ FLAGS.num_gpus = 1
+ FLAGS.model_dir = self._get_model_dir('real_benchmark_xla_1_gpu_coco')
+ FLAGS.strategy_type = 'one_device'
+ FLAGS.enable_xla = True
+ self._run_and_report_benchmark(params)
+
+ @flagsaver.flagsaver
+ def benchmark_2x2_tpu_coco(self):
+ """Run RetinaNet model accuracy test with 4 TPUs."""
+ self._setup()
+ params = self._params()
+ params['train']['batch_size'] = 64
+ params['train']['total_steps'] = 1875 # One epoch.
+ params['train']['iterations_per_loop'] = 500
+ FLAGS.model_dir = self._get_model_dir('real_benchmark_2x2_tpu_coco')
+ FLAGS.strategy_type = 'tpu'
+ self._run_and_report_benchmark(params, do_eval=False, warmup=0)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/benchmark/shakespeare_benchmark.py b/models/official/benchmark/shakespeare_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..430ab75da5300e3c374bbe56c2c02befb4dc2dff
--- /dev/null
+++ b/models/official/benchmark/shakespeare_benchmark.py
@@ -0,0 +1,355 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Executes Shakespeare (LSTM) benchmark and accuracy tests."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+
+from absl import flags
+import tensorflow as tf # pylint: disable=g-bad-import-order
+
+from official.benchmark.models.shakespeare import shakespeare_main
+from official.utils.flags import core as flags_core
+from official.utils.misc import keras_utils
+from official.benchmark import benchmark_wrappers
+from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
+
+SHAKESPEARE_TRAIN_DATA = 'shakespeare/shakespeare.txt'
+TMP_DIR = os.getenv('TMPDIR')
+FLAGS = flags.FLAGS
+
+
+class ShakespeareBenchmarkBase(PerfZeroBenchmark):
+ """Base class for Shakespeare (LSTM) benchmark and accuracy tests."""
+
+ def __init__(self, output_dir=None, default_flags=None, root_data_dir=None):
+ super(ShakespeareBenchmarkBase, self).__init__(
+ output_dir=output_dir,
+ default_flags=default_flags,
+ flag_methods=[shakespeare_main.define_flags])
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self,
+ top_1_train_min=0.91,
+ top_1_train_max=0.94,
+ warmup=1,
+ log_steps=100):
+ """Report benchmark results by writing to local protobuf file.
+
+ Average epoch time is calculated by skipping the first epoch. This average
+ ignores time spent between epochs and is based on the recorded epoch begin
+ and end times. To skip the accuracy check, set `top_1_train_min=None`.
+
+ Args:
+ top_1_train_min: lowest passing value.
+ top_1_train_max: highest passing value.
+ warmup: number of entries in `timestamp_log` to ignore.
+ log_steps: How often the log was created for `timestamp_log`.
+ """
+ total_batch_size = FLAGS.batch_size
+ metrics = []
+ start_time_sec = time.time()
+ stats = shakespeare_main.run(FLAGS)
+ wall_time_sec = time.time() - start_time_sec
+
+ if top_1_train_min:
+ metrics.append({'name': 'accuracy_top_1_train',
+ 'value': stats['history']['RecallAt1'][-1],
+ 'min_value': top_1_train_min,
+ 'max_value': top_1_train_max})
+
+ # Look for the time history callback which was used during keras.fit
+ for callback in stats['callbacks']:
+ if isinstance(callback, keras_utils.TimeHistory):
+ epoch_timings = callback.epoch_runtime_log
+ if len(epoch_timings) > 1:
+ average_time = sum(epoch_timings[1:]) / len(epoch_timings[1:])
+ metrics.append({'name': 'avg_epoch_time',
+ 'value': average_time})
+
+ # The first entry in timestamp_log is the start of step 1; the remaining
+ # entries record the end of each logged step.
+ time_log = callback.timestamp_log
+ elapsed = time_log[-1].timestamp - time_log[warmup].timestamp
+ num_examples = (
+ total_batch_size * log_steps * (len(time_log) - warmup - 1))
+ if elapsed > 0:
+ examples_per_sec = num_examples / elapsed
+ metrics.append({'name': 'exp_per_second',
+ 'value': examples_per_sec})
+
+ flags_str = flags_core.get_nondefault_flags_as_str()
+ self.report_benchmark(iters=-1, wall_time=wall_time_sec,
+ metrics=metrics,
+ extras={'flags': flags_str})
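+ # Worked example (hypothetical numbers): with epoch_runtime_log =
+ # [120.0, 95.0, 94.0, 93.0] seconds, the first epoch is skipped, so
+ # avg_epoch_time = (95.0 + 94.0 + 93.0) / 3 = 94.0 seconds.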
+
+
+class ShakespeareAccuracy(ShakespeareBenchmarkBase):
+ """Shakespeare accuracy tests.
+
+ This is not an ideal test. The best available accuracy check is to validate
+ top_1 on the training set. At batch size 64, training top_1 stabilizes at
+ ~0.92 after roughly 40-45 epochs.
+ """
+
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
+ """Shakespeare accuracy tests.
+
+ Args:
+ output_dir: directory where to output e.g. log files
+ root_data_dir: directory under which to look for dataset
+ **kwargs: arbitrary named arguments. This is needed to make the
+ constructor forward compatible in case PerfZero provides more
+ named arguments before updating the constructor.
+ """
+ self.train_data = os.path.join(root_data_dir, SHAKESPEARE_TRAIN_DATA)
+ super(ShakespeareAccuracy, self).__init__(
+ output_dir=output_dir, root_data_dir=root_data_dir)
+
+ def benchmark_cpu(self):
+ """Benchmark cpu."""
+ self._setup()
+ FLAGS.num_gpus = 0
+ FLAGS.training_data = self.train_data
+ FLAGS.batch_size = 64
+ FLAGS.train_epochs = 43
+ FLAGS.model_dir = ''
+ self._run_and_report_benchmark()
+
+ def benchmark_cpu_no_ds_run_eagerly(self):
+ """Benchmark cpu without distribution strategies and run eagerly."""
+ self._setup()
+ FLAGS.num_gpus = 0
+ FLAGS.training_data = self.train_data
+ FLAGS.batch_size = 64
+ FLAGS.train_epochs = 43
+ FLAGS.model_dir = ''
+ FLAGS.run_eagerly = True
+ FLAGS.distribution_strategy = 'off'
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu(self):
+ """Benchmark 1 gpu."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.training_data = self.train_data
+ FLAGS.batch_size = 64
+ FLAGS.train_epochs = 43
+ FLAGS.model_dir = ''
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_ds(self):
+ """Benchmark 1 gpu without distribution strategies."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.training_data = self.train_data
+ FLAGS.batch_size = 64
+ FLAGS.train_epochs = 43
+ FLAGS.model_dir = ''
+ FLAGS.distribution_strategy = 'off'
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_ds_run_eagerly(self):
+ """Benchmark 1 gpu without distribution strategies and run eagerly."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.training_data = self.train_data
+ FLAGS.batch_size = 64
+ FLAGS.train_epochs = 43
+ FLAGS.model_dir = ''
+ FLAGS.run_eagerly = True
+ FLAGS.distribution_strategy = 'off'
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_1_gpu(self):
+ """Benchmark 1 gpu w/xla."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.training_data = self.train_data
+ FLAGS.batch_size = 64
+ FLAGS.train_epochs = 43
+ FLAGS.model_dir = ''
+ FLAGS.enable_xla = True
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu(self):
+ """Benchmark 8 gpu.
+
+ This test is for accuracy, not scaling. The batch size is not scaled to
+ the number of gpus.
+ """
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.training_data = self.train_data
+ FLAGS.batch_size = 64
+ FLAGS.train_epochs = 43
+ FLAGS.model_dir = ''
+ self._run_and_report_benchmark()
+
+
+class ShakespeareKerasBenchmarkReal(ShakespeareBenchmarkBase):
+ """Benchmark accuracy tests."""
+
+ def __init__(self, output_dir=None, root_data_dir=TMP_DIR, **kwargs):
+ """Benchmark tests w/Keras.
+
+ Args:
+ output_dir: directory where to output e.g. log files
+ root_data_dir: directory under which to look for dataset
+ **kwargs: arbitrary named arguments. This is needed to make the
+ constructor forward compatible in case PerfZero provides more
+ named arguments before updating the constructor.
+ """
+ self.train_data = os.path.join(root_data_dir, SHAKESPEARE_TRAIN_DATA)
+
+ def_flags = {}
+ def_flags['training_data'] = self.train_data
+ def_flags['model_dir'] = ''
+ def_flags['train_epochs'] = 4
+ def_flags['log_steps'] = 50
+
+ super(ShakespeareKerasBenchmarkReal, self).__init__(
+ output_dir=output_dir,
+ root_data_dir=root_data_dir,
+ default_flags=def_flags)
+
+ def benchmark_cpu(self):
+ """Benchmark cpu."""
+ self._setup()
+ FLAGS.num_gpus = 0
+ FLAGS.batch_size = 64
+ self._run_and_report_benchmark()
+
+ def benchmark_cpu_no_ds_run_eagerly(self):
+ """Benchmark cpu without distribution strategy and run eagerly."""
+ self._setup()
+ FLAGS.num_gpus = 0
+ FLAGS.batch_size = 64
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.run_eagerly = True
+ self._run_and_report_benchmark()
+
+ def benchmark_cpu_no_ds(self):
+ """Benchmark cpu without distribution strategy."""
+ self._setup()
+ FLAGS.num_gpus = 0
+ FLAGS.batch_size = 64
+ FLAGS.distribution_strategy = 'off'
+ self._run_and_report_benchmark()
+
+ def benchmark_cpu_no_ds_force_v2(self):
+ """Benchmark cpu no ds, and force v2."""
+ self._setup()
+ FLAGS.num_gpus = 0
+ FLAGS.batch_size = 64
+ FLAGS.distribution_strategy = 'off'
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu(self):
+ """Benchmark 1 gpu."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.batch_size = 64
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_cudnn(self):
+ """Benchmark 1 gpu with CuDNN disabled."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.batch_size = 64
+ FLAGS.cudnn = False
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_ds(self):
+ """Benchmark 1 gpu without distribution strategies."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.batch_size = 64
+ FLAGS.distribution_strategy = 'off'
+ self._run_and_report_benchmark()
+
+ def benchmark_1_gpu_no_ds_run_eagerly(self):
+ """Benchmark 1 gpu."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.batch_size = 64
+ FLAGS.run_eagerly = True
+ FLAGS.distribution_strategy = 'off'
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_1_gpu(self):
+ """Benchmark 1 gpu."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.batch_size = 64
+ FLAGS.enable_xla = True
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_1_gpu_no_cudnn(self):
+ """Benchmark 1 gpu w/xla and CuDNN disabled."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.batch_size = 64
+ FLAGS.cudnn = False
+ FLAGS.enable_xla = True
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu(self):
+ """Benchmark 8 gpu."""
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.batch_size = 64 * 8
+ FLAGS.log_steps = 10
+ self._run_and_report_benchmark()
+
+ def benchmark_8_gpu_no_cudnn(self):
+ """Benchmark 8 gpu with CuDNN disabled."""
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.batch_size = 64 * 8
+ FLAGS.log_steps = 10
+ FLAGS.cudnn = False
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_8_gpu(self):
+ """Benchmark 8 gpu w/xla."""
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.batch_size = 64 * 8
+ FLAGS.log_steps = 10
+ FLAGS.enable_xla = True
+ self._run_and_report_benchmark()
+
+ def benchmark_xla_8_gpu_no_cudnn(self):
+ """Benchmark 8 gpu w/xla and CuDNN disabled."""
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.batch_size = 64 * 8
+ FLAGS.log_steps = 10
+ FLAGS.cudnn = False
+ FLAGS.enable_xla = True
+ self._run_and_report_benchmark()
+
+ def _run_and_report_benchmark(self):
+ """Run and report benchmark."""
+ super(ShakespeareKerasBenchmarkReal, self)._run_and_report_benchmark(
+ top_1_train_min=None, log_steps=FLAGS.log_steps)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/benchmark/tfhub_memory_usage_benchmark.py b/models/official/benchmark/tfhub_memory_usage_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f50ecf6b3e0c95c78c0ac574131321a1e41fceb
--- /dev/null
+++ b/models/official/benchmark/tfhub_memory_usage_benchmark.py
@@ -0,0 +1,69 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Runs a memory usage benchmark for a Tensorflow Hub model.
+
+Loads a SavedModel and records memory usage.
+"""
+import functools
+import time
+
+from absl import flags
+import tensorflow as tf
+import tensorflow_hub as hub
+
+from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
+
+FLAGS = flags.FLAGS
+
+
+class TfHubMemoryUsageBenchmark(PerfZeroBenchmark):
+ """A benchmark measuring memory usage for a given TF Hub SavedModel."""
+
+ def __init__(self,
+ hub_model_handle_list=None,
+ output_dir=None,
+ default_flags=None,
+ root_data_dir=None,
+ **kwargs):
+ super(TfHubMemoryUsageBenchmark, self).__init__(
+ output_dir=output_dir, default_flags=default_flags, **kwargs)
+ if hub_model_handle_list:
+ for hub_model_handle in hub_model_handle_list.split(';'):
+ # Converts a model handle of the form
+ # https://tfhub.dev/google/nnlm-en-dim128/1 to valid python method name
+ # like google_nnlm_en_dim128_1.
+ hub_model_method_name = hub_model_handle.replace(
+ 'https://tfhub.dev',
+ '').replace('/', '_').replace('-', '_').strip('_')
+ setattr(
+ self, 'benchmark_' + hub_model_method_name,
+ functools.partial(self.benchmark_memory_usage, hub_model_handle))
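+ # Example: hub_model_handle_list='https://tfhub.dev/google/nnlm-en-dim128/1'
+ # yields the method name 'google_nnlm_en_dim128_1' and therefore exposes a
+ # generated test called 'benchmark_google_nnlm_en_dim128_1'.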
+
+ def benchmark_memory_usage(
+ self, hub_model_handle='https://tfhub.dev/google/nnlm-en-dim128/1'):
+ start_time_sec = time.time()
+ self.load_model(hub_model_handle)
+ wall_time_sec = time.time() - start_time_sec
+
+ metrics = []
+ self.report_benchmark(iters=-1, wall_time=wall_time_sec, metrics=metrics)
+
+ def load_model(self, hub_model_handle):
+ """Loads a TF Hub module."""
+ hub.load(hub_model_handle)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/benchmark/transformer_benchmark.py b/models/official/benchmark/transformer_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..e61201aa174af4882c6dbab28e10fe64d8cc1377
--- /dev/null
+++ b/models/official/benchmark/transformer_benchmark.py
@@ -0,0 +1,757 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Executes Transformer w/Keras benchmark and accuracy tests."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+
+from absl import flags
+import tensorflow as tf
+from official.benchmark import benchmark_wrappers
+from official.benchmark import owner_utils
+from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
+from official.nlp.transformer import misc
+from official.nlp.transformer import transformer_main as transformer_main
+from official.utils.flags import core as flags_core
+
+TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
+EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014'
+FLAGS = flags.FLAGS
+TMP_DIR = os.getenv('TMPDIR')
+
+
+class TransformerBenchmark(PerfZeroBenchmark):
+ """Methods common to executing transformer w/keras tests.
+
+  The code under test for the Transformer Keras models reports the same data
+  and requires the same FLAG setup.
+ """
+
+ def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
+ flag_methods=None, tpu=None):
+ root_data_dir = root_data_dir if root_data_dir else ''
+
+ self.train_data_dir = os.path.join(root_data_dir,
+ TRANSFORMER_EN2DE_DATA_DIR_NAME)
+
+ self.vocab_file = os.path.join(root_data_dir,
+ TRANSFORMER_EN2DE_DATA_DIR_NAME,
+ 'vocab.ende.32768')
+
+ self.bleu_source = os.path.join(root_data_dir,
+ EN2DE_2014_BLEU_DATA_DIR_NAME,
+ 'newstest2014.en')
+
+ self.bleu_ref = os.path.join(root_data_dir,
+ EN2DE_2014_BLEU_DATA_DIR_NAME,
+ 'newstest2014.de')
+
+ if default_flags is None:
+ default_flags = {}
+ default_flags['data_dir'] = self.train_data_dir
+ default_flags['vocab_file'] = self.vocab_file
+
+ super(TransformerBenchmark, self).__init__(
+ output_dir=output_dir,
+ default_flags=default_flags,
+ flag_methods=flag_methods,
+ tpu=tpu)
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self,
+ bleu_max=None,
+ bleu_min=None,
+ log_steps=None,
+ total_batch_size=None,
+ warmup=1):
+ """Report benchmark results by writing to local protobuf file.
+
+ Args:
+ bleu_max: highest passing level for bleu score.
+ bleu_min: lowest passing level for bleu score.
+ log_steps: How often the log was created for stats['step_timestamp_log'].
+ total_batch_size: Global batch-size.
+ warmup: number of entries in stats['step_timestamp_log'] to ignore.
+ """
+ start_time_sec = time.time()
+ task = transformer_main.TransformerTask(FLAGS)
+ stats = task.train()
+ wall_time_sec = time.time() - start_time_sec
+
+ metrics = []
+ if 'bleu_uncased' in stats:
+ if 'bleu_uncased_history' in stats:
+ bleu_uncased_best = max(stats['bleu_uncased_history'],
+ key=lambda x: x[1])
+ metrics.append({'name': 'bleu_uncased',
+ 'value': bleu_uncased_best[1],
+ 'min_value': bleu_min,
+ 'max_value': bleu_max})
+ metrics.append({'name': 'bleu_best_score_iteration',
+ 'value': bleu_uncased_best[0]})
+ metrics.append({'name': 'bleu_uncased_last',
+ 'value': stats['bleu_uncased']})
+ else:
+ metrics.append({'name': 'bleu_uncased',
+ 'value': stats['bleu_uncased'],
+ 'min_value': bleu_min,
+ 'max_value': bleu_max})
+
+ if (warmup and 'step_timestamp_log' in stats and
+ len(stats['step_timestamp_log']) > warmup + 1):
+      # The first entry in the time_log is the start of step 1. The rest of
+      # the entries are the end of each recorded step.
+ time_log = stats['step_timestamp_log']
+ elapsed = time_log[-1].timestamp - time_log[warmup].timestamp
+ num_examples = (
+ total_batch_size * log_steps * (len(time_log) - warmup - 1))
+ examples_per_sec = num_examples / elapsed
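+      # Worked example (hypothetical numbers): with total_batch_size=4096,
+      # log_steps=100, warmup=1 and 11 recorded timestamps, num_examples is
+      # 4096 * 100 * (11 - 1 - 1) and throughput is num_examples / elapsed
+      # examples per second.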
+ metrics.append({'name': 'exp_per_second',
+ 'value': examples_per_sec})
+
+ if 'avg_exp_per_second' in stats:
+ metrics.append({'name': 'avg_exp_per_second',
+ 'value': stats['avg_exp_per_second']})
+
+ if 'step_timestamp_log' in stats:
+ time_log = stats['step_timestamp_log']
+ metrics.append({'name': 'startup_time',
+ 'value': time_log[0].timestamp - start_time_sec})
+
+ flags_str = flags_core.get_nondefault_flags_as_str()
+ self.report_benchmark(iters=-1, wall_time=wall_time_sec, metrics=metrics,
+ extras={'flags': flags_str})
+
+
+class TransformerBaseKerasAccuracy(TransformerBenchmark):
+ """Benchmark accuracy tests for Transformer Base model w/ Keras."""
+
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
+ """Benchmark accuracy tests for Transformer Base model w/ Keras.
+
+ Args:
+      output_dir: directory where outputs (e.g. log files) are written.
+      root_data_dir: directory under which to look for the dataset.
+ **kwargs: arbitrary named arguments. This is needed to make the
+ constructor forward compatible in case PerfZero provides more
+ named arguments before updating the constructor.
+ """
+ flag_methods = [misc.define_transformer_flags]
+
+ super(TransformerBaseKerasAccuracy, self).__init__(
+ output_dir=output_dir, root_data_dir=root_data_dir,
+ flag_methods=flag_methods)
+
+ def benchmark_1_gpu(self):
+ """Benchmark 1 gpu.
+
+    The paper uses 8 GPUs and a much larger effective batch size; this will
+    not converge to the 27.3 BLEU (uncased) SOTA.
+ """
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.data_dir = self.train_data_dir
+ FLAGS.vocab_file = self.vocab_file
+ # Sets values directly to avoid validation check.
+ FLAGS['bleu_source'].value = self.bleu_source
+ FLAGS['bleu_ref'].value = self.bleu_ref
+ FLAGS.param_set = 'base'
+ FLAGS.batch_size = 2048
+ FLAGS.train_steps = 1000
+ FLAGS.steps_between_evals = 500
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
+    # These BLEU scores are based on test runs at this limited number of
+    # steps and batch size, after verifying SOTA at 8x V100s.
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps,
+ bleu_min=25.3,
+ bleu_max=26)
+
+ def benchmark_1_gpu_static_batch(self):
+ """Benchmark 1 gpu with static_batch.
+
+    The paper uses 8 GPUs and a much larger effective batch size; this will
+    not converge to the 27.3 BLEU (uncased) SOTA.
+ """
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.data_dir = self.train_data_dir
+ FLAGS.vocab_file = self.vocab_file
+ # Sets values directly to avoid validation check.
+ FLAGS['bleu_source'].value = self.bleu_source
+ FLAGS['bleu_ref'].value = self.bleu_ref
+ FLAGS.param_set = 'base'
+ FLAGS.batch_size = 4096
+ FLAGS.train_steps = 100000
+ FLAGS.steps_between_evals = 5000
+ FLAGS.static_batch = True
+ FLAGS.max_length = 64
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_static_batch')
+    # These BLEU scores are based on test runs at this limited number of
+    # steps and batch size, after verifying SOTA at 8x V100s.
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps,
+ bleu_min=25.3,
+ bleu_max=26)
+
+ def benchmark_8_gpu(self):
+ """Benchmark 8 gpu.
+
+ Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
+ """
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.data_dir = self.train_data_dir
+ FLAGS.vocab_file = self.vocab_file
+ # Sets values directly to avoid validation check.
+ FLAGS['bleu_source'].value = self.bleu_source
+ FLAGS['bleu_ref'].value = self.bleu_ref
+ FLAGS.param_set = 'base'
+ FLAGS.batch_size = 4096*8
+ FLAGS.train_steps = 100000
+ FLAGS.steps_between_evals = 20000
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps,
+ bleu_min=27,
+ bleu_max=28)
+
+ def benchmark_8_gpu_static_batch(self):
+ """Benchmark 8 gpu.
+
+ Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
+ """
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.data_dir = self.train_data_dir
+ FLAGS.vocab_file = self.vocab_file
+ # Sets values directly to avoid validation check.
+ FLAGS['bleu_source'].value = self.bleu_source
+ FLAGS['bleu_ref'].value = self.bleu_ref
+ FLAGS.param_set = 'base'
+ FLAGS.batch_size = 4096*8
+ FLAGS.train_steps = 100000
+ FLAGS.static_batch = True
+ FLAGS.max_length = 64
+ FLAGS.steps_between_evals = 5000
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch')
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps,
+ bleu_min=27,
+ bleu_max=28)
+
+
+class TransformerBigKerasAccuracy(TransformerBenchmark):
+ """Benchmark accuracy tests for Transformer Big model w/ Keras."""
+
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
+ """Benchmark accuracy tests for Transformer Big model w/ Keras.
+
+ Args:
+      output_dir: directory where outputs (e.g. log files) are written.
+      root_data_dir: directory under which to look for the dataset.
+ **kwargs: arbitrary named arguments. This is needed to make the
+ constructor forward compatible in case PerfZero provides more
+ named arguments before updating the constructor.
+ """
+ flag_methods = [misc.define_transformer_flags]
+
+ super(TransformerBigKerasAccuracy, self).__init__(
+ output_dir=output_dir, root_data_dir=root_data_dir,
+ flag_methods=flag_methods)
+
+ def benchmark_8_gpu(self):
+ """Benchmark 8 gpu.
+
+    Over 6 runs with eval every 20K steps, the average best value was 28.195
+    (BLEU uncased); 28.424 was the highest and 27.96 the lowest. Each value is
+    the best seen during a run and occurred at a median of iteration 9.
+    Iterations are not epochs; an iteration is the number of steps between
+    evals.
+ """
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.data_dir = self.train_data_dir
+ FLAGS.vocab_file = self.vocab_file
+ # Sets values directly to avoid validation check.
+ FLAGS['bleu_source'].value = self.bleu_source
+ FLAGS['bleu_ref'].value = self.bleu_ref
+ FLAGS.param_set = 'big'
+ FLAGS.batch_size = 3072*8
+ FLAGS.train_steps = 20000 * 12
+ FLAGS.steps_between_evals = 20000
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps,
+ bleu_min=27.9,
+ bleu_max=29.2)
+
+ def benchmark_8_gpu_static_batch(self):
+ """Benchmark 8 gpu.
+
+    Should converge to 28.4 BLEU (uncased). This has not been verified yet.
+ """
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.data_dir = self.train_data_dir
+ FLAGS.vocab_file = self.vocab_file
+ # Sets values directly to avoid validation check.
+ FLAGS['bleu_source'].value = self.bleu_source
+ FLAGS['bleu_ref'].value = self.bleu_ref
+ FLAGS.param_set = 'big'
+ FLAGS.batch_size = 3072*8
+ FLAGS.static_batch = True
+ FLAGS.max_length = 64
+ FLAGS.train_steps = 20000 * 12
+ FLAGS.steps_between_evals = 20000
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch')
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps,
+ bleu_min=28,
+ bleu_max=29.2)
+
+ def benchmark_8_gpu_fp16(self):
+ """Benchmark 8 gpu with dynamic batch and fp16.
+
+    Over 6 runs with eval every 20K steps, the average best value was 28.247
+    (BLEU uncased); 28.424 was the highest and 28.09 the lowest. Each value is
+    the best seen during a run and occurred at a median of iteration 11. While
+    this could be read as worse than FP32, FP16 performs equally well and
+    possibly better if judged by the first iteration at which 28 is passed.
+    Although not part of the initial test runs, the highest value recorded
+    with the arguments below was 28.9 at iteration 12. Iterations are not
+    epochs; an iteration is the number of steps between evals.
+ """
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.dtype = 'fp16'
+ FLAGS.data_dir = self.train_data_dir
+ FLAGS.vocab_file = self.vocab_file
+ # Sets values directly to avoid validation check.
+ FLAGS['bleu_source'].value = self.bleu_source
+ FLAGS['bleu_ref'].value = self.bleu_ref
+ FLAGS.param_set = 'big'
+ FLAGS.batch_size = 3072*8
+ FLAGS.train_steps = 20000 * 12
+ FLAGS.steps_between_evals = 20000
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps,
+ bleu_min=28,
+ bleu_max=29.2)
+
+ def benchmark_8_gpu_fp16_amp(self):
+ """Benchmark 8 gpu with dynamic batch and fp16 with automatic mixed precision.
+
+    Should converge to 28.4 BLEU (uncased). This has not been verified yet.
+ """
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.dtype = 'fp16'
+ FLAGS.fp16_implementation = 'graph_rewrite'
+ FLAGS.data_dir = self.train_data_dir
+ FLAGS.vocab_file = self.vocab_file
+ # Sets values directly to avoid validation check.
+ FLAGS['bleu_source'].value = self.bleu_source
+ FLAGS['bleu_ref'].value = self.bleu_ref
+ FLAGS.param_set = 'big'
+ FLAGS.batch_size = 3072*8
+ FLAGS.train_steps = 20000 * 12
+ FLAGS.steps_between_evals = 20000
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_amp')
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps,
+ bleu_min=28,
+ bleu_max=29)
+
+ def benchmark_8_gpu_static_batch_fp16(self):
+ """Benchmark 8 gpu with static batch and fp16.
+
+    Should converge to 28.4 BLEU (uncased). This has not been verified yet.
+ """
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.dtype = 'fp16'
+ FLAGS.data_dir = self.train_data_dir
+ FLAGS.vocab_file = self.vocab_file
+ # Sets values directly to avoid validation check.
+ FLAGS['bleu_source'].value = self.bleu_source
+ FLAGS['bleu_ref'].value = self.bleu_ref
+ FLAGS.param_set = 'big'
+ FLAGS.batch_size = 3072*8
+ FLAGS.static_batch = True
+ FLAGS.max_length = 64
+ FLAGS.train_steps = 400000
+ FLAGS.steps_between_evals = 20000
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch_fp16')
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps,
+ bleu_min=28,
+ bleu_max=29.2)
+
+ def benchmark_xla_8_gpu_static_batch_fp16(self):
+ """Benchmark 8 gpu with static batch, XLA, and FP16.
+
+    Should converge to 28.4 BLEU (uncased). This has not been verified yet.
+ """
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.dtype = 'fp16'
+ FLAGS.enable_xla = True
+ FLAGS.data_dir = self.train_data_dir
+ FLAGS.vocab_file = self.vocab_file
+ # Sets values directly to avoid validation check.
+ FLAGS['bleu_source'].value = self.bleu_source
+ FLAGS['bleu_ref'].value = self.bleu_ref
+ FLAGS.param_set = 'big'
+ FLAGS.batch_size = 3072*8
+ FLAGS.static_batch = True
+ FLAGS.max_length = 64
+ FLAGS.train_steps = 400000
+ FLAGS.steps_between_evals = 20000
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_xla_8_gpu_static_batch_fp16')
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps,
+ bleu_min=28,
+ bleu_max=29.2)
+
+
+class TransformerKerasBenchmark(TransformerBenchmark):
+ """Benchmarks for Transformer (Base and Big) using Keras."""
+
+ def __init__(self, output_dir=None, default_flags=None,
+ root_data_dir=None, batch_per_gpu=4096, tpu=None):
+ """Initialize.
+
+ Args:
+      output_dir: Base directory for saving artifacts, e.g. checkpoints.
+ default_flags: default flags to use for all tests.
+ root_data_dir: root directory for data, e.g. training.
+ batch_per_gpu: batch size to use per gpu.
+ tpu: Target TPU to use.
+ """
+ flag_methods = [misc.define_transformer_flags]
+ self.batch_per_gpu = batch_per_gpu
+
+ super(TransformerKerasBenchmark, self).__init__(
+ output_dir=output_dir,
+ default_flags=default_flags,
+ root_data_dir=root_data_dir,
+ flag_methods=flag_methods,
+ tpu=tpu)
+
+ def benchmark_1_gpu_no_dist_strat(self):
+ """Benchmark 1 gpu without distribution strategy."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.batch_size = self.batch_per_gpu
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_1_gpu_no_dist_strat_static_batch(self):
+ """Benchmark 1 gpu without distribution strategy with static batch."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.batch_size = self.batch_per_gpu
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_ds_sb')
+ FLAGS.static_batch = True
+ FLAGS.max_length = 64
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_1_gpu(self):
+ """Benchmark 1 gpu."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.batch_size = self.batch_per_gpu
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_1_gpu_fp16(self):
+ """Benchmark 1 gpu FP16."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.batch_size = self.batch_per_gpu
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16')
+ FLAGS.dtype = 'fp16'
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_xla_1_gpu(self):
+ """Benchmark 1 gpu w/xla."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.batch_size = self.batch_per_gpu
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu')
+ FLAGS.enable_xla = True
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_xla_1_gpu_fp16(self):
+ """Benchmark 1 gpu w/xla and FP16."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.batch_size = self.batch_per_gpu
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16')
+ FLAGS.enable_xla = True
+ FLAGS.dtype = 'fp16'
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_1_gpu_static_batch(self):
+ """Benchmark 1 gpu with static batch."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.batch_size = self.batch_per_gpu
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_static_batch')
+ FLAGS.static_batch = True
+ FLAGS.max_length = 64
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_xla_1_gpu_static_batch(self):
+ """Benchmark 1 gpu with static batch w/xla."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.batch_size = self.batch_per_gpu
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_static_batch')
+ FLAGS.static_batch = True
+ FLAGS.max_length = 64
+ FLAGS.enable_xla = True
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_1_gpu_static_batch_fp16(self):
+ """Benchmark 1 gpu with static batch FP16."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.batch_size = self.batch_per_gpu
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_1_gpu_static_batch_fp16')
+ FLAGS.static_batch = True
+ FLAGS.max_length = 64
+ FLAGS.dtype = 'fp16'
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_xla_1_gpu_static_batch_fp16(self):
+ """Benchmark 1 gpu with static batch w/xla and FP16."""
+ self._setup()
+ FLAGS.num_gpus = 1
+ FLAGS.batch_size = self.batch_per_gpu
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_xla_1_gpu_static_batch_fp16')
+ FLAGS.static_batch = True
+ FLAGS.max_length = 64
+ FLAGS.enable_xla = True
+ FLAGS.dtype = 'fp16'
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_8_gpu(self):
+ """Benchmark 8 gpu."""
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.batch_size = self.batch_per_gpu * 8
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_8_gpu_fp16(self):
+ """Benchmark 8 gpu FP16."""
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.dtype = 'fp16'
+ FLAGS.batch_size = self.batch_per_gpu * 8
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_xla_8_gpu(self):
+ """Benchmark 8 gpu w/xla."""
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.enable_xla = True
+ FLAGS.batch_size = self.batch_per_gpu * 8
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu')
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_xla_8_gpu_fp16(self):
+ """Benchmark 8 gpu w/xla and FP16."""
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.enable_xla = True
+ FLAGS.dtype = 'fp16'
+ FLAGS.batch_size = self.batch_per_gpu * 8
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16')
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_8_gpu_static_batch(self):
+ """Benchmark 8 gpu with static batch."""
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.batch_size = self.batch_per_gpu * 8
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch')
+ FLAGS.static_batch = True
+ FLAGS.max_length = 64
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_8_gpu_static_batch_fp16(self):
+ """Benchmark 8 gpu with static batch FP16."""
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.dtype = 'fp16'
+ FLAGS.batch_size = self.batch_per_gpu * 8
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_8_gpu_static_batch_fp16')
+ FLAGS.static_batch = True
+ FLAGS.max_length = 64
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_xla_8_gpu_static_batch(self):
+ """Benchmark 8 gpu with static batch w/xla."""
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.enable_xla = True
+ FLAGS.batch_size = self.batch_per_gpu * 8
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_static_batch')
+ FLAGS.static_batch = True
+ FLAGS.max_length = 64
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_xla_8_gpu_static_batch_fp16(self):
+ """Benchmark 8 gpu with static batch w/xla and FP16."""
+ self._setup()
+ FLAGS.num_gpus = 8
+ FLAGS.enable_xla = True
+ FLAGS.dtype = 'fp16'
+ FLAGS.batch_size = self.batch_per_gpu * 8
+ FLAGS.model_dir = self._get_model_dir(
+ 'benchmark_xla_8_gpu_static_batch_fp16')
+ FLAGS.static_batch = True
+ FLAGS.max_length = 64
+ self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+
+class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
+ """Transformer based version real data benchmark tests."""
+
+ def __init__(self, output_dir=TMP_DIR, root_data_dir=TMP_DIR, **kwargs):
+ def_flags = {}
+ def_flags['param_set'] = 'base'
+ def_flags['train_steps'] = 50
+ def_flags['log_steps'] = 10
+
+ super(TransformerBaseKerasBenchmarkReal, self).__init__(
+ output_dir=output_dir, default_flags=def_flags,
+ root_data_dir=root_data_dir, batch_per_gpu=4096)
+
+
+class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
+ """Transformer based version real data benchmark tests."""
+
+ def __init__(self, output_dir=TMP_DIR, root_data_dir=TMP_DIR,
+ tpu=None, **kwargs):
+ def_flags = {}
+ def_flags['param_set'] = 'big'
+ def_flags['train_steps'] = 50
+ def_flags['log_steps'] = 10
+
+ super(TransformerBigKerasBenchmarkReal, self).__init__(
+ output_dir=output_dir, default_flags=def_flags,
+ root_data_dir=root_data_dir, batch_per_gpu=3072,
+ tpu=tpu)
+
+ def benchmark_2x2_tpu(self):
+ """Port of former snaggletooth transformer_big model on 2x2."""
+ self._setup()
+ FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
+ FLAGS.train_steps = 300
+ FLAGS.log_steps = 150
+ FLAGS.steps_between_evals = 150
+ FLAGS.distribution_strategy = 'tpu'
+ FLAGS.static_batch = True
+ FLAGS.use_ctl = True
+ FLAGS.batch_size = 6144
+ FLAGS.max_length = 64
+ FLAGS.decode_batch_size = 32
+ FLAGS.decode_max_length = 97
+ FLAGS.padded_decode = True
+ FLAGS.enable_checkpointing = False
+
+ self._run_and_report_benchmark(
+ total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ def benchmark_4x4_tpu(self):
+ """Port of former GCP transformer_big model on 4x4."""
+ self._setup()
+ FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu')
+ FLAGS.train_steps = 300
+ FLAGS.log_steps = 150
+ FLAGS.steps_between_evals = 150
+ FLAGS.distribution_strategy = 'tpu'
+ FLAGS.static_batch = True
+ FLAGS.use_ctl = True
+ FLAGS.batch_size = 24576
+ FLAGS.max_length = 64
+ FLAGS.decode_batch_size = 32
+ FLAGS.decode_max_length = 97
+ FLAGS.padded_decode = True
+ FLAGS.enable_checkpointing = False
+
+ self._run_and_report_benchmark(
+ total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+ @owner_utils.Owner('tf-graph-compiler')
+ def benchmark_4x4_tpu_mlir(self):
+ """Run transformer_big model on 4x4 with the MLIR Bridge enabled."""
+ self._setup()
+ FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu')
+ FLAGS.train_steps = 300
+ FLAGS.log_steps = 150
+ FLAGS.steps_between_evals = 150
+ FLAGS.distribution_strategy = 'tpu'
+ FLAGS.static_batch = True
+ FLAGS.use_ctl = True
+ FLAGS.batch_size = 24576
+ FLAGS.max_length = 64
+ FLAGS.decode_batch_size = 32
+ FLAGS.decode_max_length = 97
+ FLAGS.padded_decode = True
+ FLAGS.enable_checkpointing = False
+ tf.config.experimental.enable_mlir_bridge()
+
+ self._run_and_report_benchmark(
+ total_batch_size=FLAGS.batch_size,
+ log_steps=FLAGS.log_steps)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/benchmark/unet3d_benchmark.py b/models/official/benchmark/unet3d_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..2614b29259dcf4c85d609abca94706c95570b7ec
--- /dev/null
+++ b/models/official/benchmark/unet3d_benchmark.py
@@ -0,0 +1,148 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Executes benchmark testing for 3D Unet model."""
+# pylint: disable=line-too-long
+from __future__ import print_function
+
+import functools
+import os
+import time
+from typing import Optional
+from absl import flags
+import tensorflow as tf # pylint: disable=g-bad-import-order
+
+from official.benchmark import benchmark_wrappers
+from official.benchmark import keras_benchmark
+from official.benchmark import owner_utils
+from official.vision.segmentation import unet_main as unet_training_lib
+from official.vision.segmentation import unet_model as unet_model_lib
+
+UNET3D_MIN_ACCURACY = 0.90
+UNET3D_MAX_ACCURACY = 0.98
+UNET_TRAINING_FILES = 'gs://mlcompass-data/unet3d/train_data/*'
+UNET_EVAL_FILES = 'gs://mlcompass-data/unet3d/eval_data/*'
+UNET_MODEL_CONFIG_FILE = 'gs://mlcompass-data/unet3d/config/unet_config.yaml'
+
+FLAGS = flags.FLAGS
+
+
+class Unet3DAccuracyBenchmark(keras_benchmark.KerasBenchmark):
+ """Benchmark accuracy tests for UNet3D model in Keras."""
+
+ def __init__(self,
+ output_dir: Optional[str] = None,
+ root_data_dir: Optional[str] = None,
+ **kwargs):
+ """A benchmark class.
+
+ Args:
+      output_dir: directory where outputs (e.g. log files) are written.
+      root_data_dir: directory under which to look for the dataset.
+ **kwargs: arbitrary named arguments. This is needed to make the
+ constructor forward compatible in case PerfZero provides more named
+ arguments before updating the constructor.
+ """
+
+ flag_methods = [unet_training_lib.define_unet3d_flags]
+
+ self.training_file_pattern = UNET_TRAINING_FILES
+ self.eval_file_pattern = UNET_EVAL_FILES
+
+ # TODO(hongjunchoi): Create and use shared config file instead.
+ self.config_file = UNET_MODEL_CONFIG_FILE
+ super(Unet3DAccuracyBenchmark, self).__init__(
+ output_dir=output_dir, flag_methods=flag_methods)
+
+ def _set_benchmark_parameters(self, experiment_name):
+ """Overrides training parameters for benchmark tests."""
+ FLAGS.model_dir = self._get_model_dir(experiment_name)
+ FLAGS.mode = 'train'
+ FLAGS.training_file_pattern = self.training_file_pattern
+ FLAGS.eval_file_pattern = self.eval_file_pattern
+ FLAGS.config_file = self.config_file
+ FLAGS.lr_init_value = 0.00005
+ FLAGS.lr_decay_rate = 0.5
+ FLAGS.epochs = 3
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self,
+ experiment_name: str,
+ min_accuracy: float = UNET3D_MIN_ACCURACY,
+ max_accuracy: float = UNET3D_MAX_ACCURACY,
+ distribution_strategy: str = 'tpu',
+ epochs: int = 10,
+ steps: int = 0,
+ epochs_between_evals: int = 1,
+ dtype: str = 'float32',
+ enable_xla: bool = False,
+ run_eagerly: bool = False):
+ """Runs and reports the benchmark given the provided configuration."""
+ params = unet_training_lib.extract_params(FLAGS)
+ strategy = unet_training_lib.create_distribution_strategy(params)
+ if params.use_bfloat16:
+ policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
+ tf.keras.mixed_precision.experimental.set_policy(policy)
+
+ stats = {}
+ start_time_sec = time.time()
+ with strategy.scope():
+ unet_model = unet_model_lib.build_unet_model(params)
+ history = unet_training_lib.train(
+ params, strategy, unet_model,
+ functools.partial(unet_training_lib.get_train_dataset, params),
+ functools.partial(unet_training_lib.get_eval_dataset, params))
+
+ stats['accuracy_top_1'] = history.history['val_metric_accuracy'][-1]
+ stats['training_accuracy_top_1'] = history.history['metric_accuracy'][-1]
+ wall_time_sec = time.time() - start_time_sec
+
+ super(Unet3DAccuracyBenchmark, self)._report_benchmark(
+ stats,
+ wall_time_sec,
+ top_1_min=min_accuracy,
+ top_1_max=max_accuracy,
+ total_batch_size=params.train_batch_size)
+
+ def _get_model_dir(self, folder_name):
+ return os.path.join(self.output_dir, folder_name)
+
+ @owner_utils.Owner('tf-model-garden')
+ def benchmark_4x4_tpu_bf16(self):
+ """Test Keras model with 4x4 TPU, fp16."""
+ experiment_name = 'benchmark_4x4_tpu_fp16'
+ self._setup()
+ self._set_benchmark_parameters(experiment_name)
+ self._run_and_report_benchmark(
+ experiment_name=experiment_name,
+ dtype='bfloat16',
+ distribution_strategy='tpu')
+
+ @owner_utils.Owner('tf-graph-compiler')
+ def benchmark_4x4_tpu_bf16_mlir(self):
+ """Test Keras model with 4x4 TPU, fp16 and MLIR enabled."""
+ experiment_name = 'benchmark_4x4_tpu_fp16_mlir'
+ tf.config.experimental.enable_mlir_bridge()
+ self._setup()
+ self._set_benchmark_parameters(experiment_name)
+ self._run_and_report_benchmark(
+ experiment_name=experiment_name,
+ dtype='bfloat16',
+ distribution_strategy='tpu')
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/benchmark/xlnet_benchmark.py b/models/official/benchmark/xlnet_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4df69cf081a4a06000ed46ea66ac742cb1c39e02
--- /dev/null
+++ b/models/official/benchmark/xlnet_benchmark.py
@@ -0,0 +1,246 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Executes XLNet benchmarks and accuracy tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+import time
+
+# pylint: disable=g-bad-import-order
+from absl import flags
+from absl.testing import flagsaver
+import tensorflow as tf
+# pylint: enable=g-bad-import-order
+
+from official.benchmark import bert_benchmark_utils as benchmark_utils
+from official.benchmark import owner_utils
+from official.nlp.xlnet import run_classifier
+from official.nlp.xlnet import run_squad
+from official.benchmark import benchmark_wrappers
+
+
+# pylint: disable=line-too-long
+PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/xlnet/large/xlnet_model-1'
+CLASSIFIER_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/xlnet/imdb/spiece.model.len-512.train.tf_record'
+CLASSIFIER_EVAL_DATA_PATH = 'gs://tf-perfzero-data/xlnet/imdb/spiece.model.len-512.dev.eval.tf_record'
+SQUAD_DATA_PATH = 'gs://tf-perfzero-data/xlnet/squadv2_cased/'
+# pylint: enable=line-too-long
+
+FLAGS = flags.FLAGS
+
+
+class XLNetBenchmarkBase(benchmark_utils.BertBenchmarkBase):
+ """Base class to hold methods common to test classes in the module."""
+
+ def __init__(self, output_dir=None, tpu=None):
+ super(XLNetBenchmarkBase, self).__init__(output_dir=output_dir, tpu=tpu)
+ self.num_epochs = None
+ self.num_steps_per_epoch = None
+
+ @flagsaver.flagsaver
+ def _run_xlnet_classifier(self):
+ """Starts XLNet classification task."""
+ run_classifier.main(unused_argv=None)
+
+ @flagsaver.flagsaver
+ def _run_xlnet_squad(self):
+ """Starts XLNet classification task."""
+ run_squad.main(unused_argv=None)
+
+
+class XLNetClassifyAccuracy(XLNetBenchmarkBase):
+ """Short accuracy test for XLNet classifier model.
+
+  Tests XLNet classification task model accuracy. The naming
+  convention of the test cases below follows the
+  `benchmark_(number of gpus)_gpu_(dataset type)` format.
+ """
+
+ def __init__(self, output_dir=None, tpu=None, **kwargs):
+ self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
+ self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
+ self.pretrained_checkpoint_path = PRETRAINED_CHECKPOINT_PATH
+
+ super(XLNetClassifyAccuracy, self).__init__(output_dir=output_dir, tpu=tpu)
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self,
+ training_summary_path,
+ min_accuracy=0.95,
+ max_accuracy=0.97):
+ """Starts XLNet accuracy benchmark test."""
+
+ start_time_sec = time.time()
+ self._run_xlnet_classifier()
+ wall_time_sec = time.time() - start_time_sec
+
+ with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
+ summary = json.loads(reader.read().decode('utf-8'))
+
+ super(XLNetClassifyAccuracy, self)._report_benchmark(
+ stats=summary,
+ wall_time_sec=wall_time_sec,
+ min_accuracy=min_accuracy,
+ max_accuracy=max_accuracy)
+
+ def _setup(self):
+ super(XLNetClassifyAccuracy, self)._setup()
+ FLAGS.test_data_size = 25024
+ FLAGS.train_batch_size = 16
+ FLAGS.seq_len = 512
+ FLAGS.mem_len = 0
+ FLAGS.n_layer = 24
+ FLAGS.d_model = 1024
+ FLAGS.d_embed = 1024
+ FLAGS.n_head = 16
+ FLAGS.d_head = 64
+ FLAGS.d_inner = 4096
+ FLAGS.untie_r = True
+ FLAGS.n_class = 2
+ FLAGS.ff_activation = 'gelu'
+ FLAGS.strategy_type = 'mirror'
+ FLAGS.learning_rate = 2e-5
+ FLAGS.train_steps = 4000
+ FLAGS.warmup_steps = 500
+ FLAGS.iterations = 200
+ FLAGS.bi_data = False
+ FLAGS.init_checkpoint = self.pretrained_checkpoint_path
+ FLAGS.train_tfrecord_path = self.train_data_path
+ FLAGS.test_tfrecord_path = self.eval_data_path
+
+ @owner_utils.Owner('tf-model-garden')
+ def benchmark_8_gpu_imdb(self):
+ """Run XLNet model accuracy test with 8 GPUs."""
+ self._setup()
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_imdb')
+ # Sets timer_callback to None as we do not use it now.
+ self.timer_callback = None
+
+ summary_path = os.path.join(FLAGS.model_dir,
+ 'summaries/training_summary.txt')
+ self._run_and_report_benchmark(summary_path)
+
+ @owner_utils.Owner('tf-model-garden')
+ def benchmark_2x2_tpu_imdb(self):
+ """Run XLNet model accuracy test on 2x2 tpu."""
+ self._setup()
+ FLAGS.strategy_type = 'tpu'
+ FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_imdb')
+ # Sets timer_callback to None as we do not use it now.
+ self.timer_callback = None
+
+ summary_path = os.path.join(FLAGS.model_dir,
+ 'summaries/training_summary.txt')
+ self._run_and_report_benchmark(summary_path)
+
+
+class XLNetSquadAccuracy(XLNetBenchmarkBase):
+ """Short accuracy test for XLNet squad model.
+
+  Tests XLNet squad task model accuracy. The naming
+  convention of the test cases below follows the
+  `benchmark_(number of gpus)_gpu_(dataset type)` format.
+ """
+
+ def __init__(self, output_dir=None, tpu=None, **kwargs):
+ self.train_data_path = SQUAD_DATA_PATH
+ self.predict_file = os.path.join(SQUAD_DATA_PATH, "dev-v2.0.json")
+ self.test_data_path = os.path.join(SQUAD_DATA_PATH, "12048.eval.tf_record")
+ self.spiece_model_file = os.path.join(SQUAD_DATA_PATH, "spiece.cased.model")
+ self.pretrained_checkpoint_path = PRETRAINED_CHECKPOINT_PATH
+
+ super(XLNetSquadAccuracy, self).__init__(output_dir=output_dir, tpu=tpu)
+
+ @benchmark_wrappers.enable_runtime_flags
+ def _run_and_report_benchmark(self,
+ training_summary_path,
+ min_accuracy=87.0,
+ max_accuracy=89.0):
+ """Starts XLNet accuracy benchmark test."""
+
+ start_time_sec = time.time()
+ self._run_xlnet_squad()
+ wall_time_sec = time.time() - start_time_sec
+
+ with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
+ summary = json.loads(reader.read().decode('utf-8'))
+
+ super(XLNetSquadAccuracy, self)._report_benchmark(
+ stats=summary,
+ wall_time_sec=wall_time_sec,
+ min_accuracy=min_accuracy,
+ max_accuracy=max_accuracy)
+
+ def _setup(self):
+ super(XLNetSquadAccuracy, self)._setup()
+ FLAGS.train_batch_size = 16
+ FLAGS.seq_len = 512
+ FLAGS.mem_len = 0
+ FLAGS.n_layer = 24
+ FLAGS.d_model = 1024
+ FLAGS.d_embed = 1024
+ FLAGS.n_head = 16
+ FLAGS.d_head = 64
+ FLAGS.d_inner = 4096
+ FLAGS.untie_r = True
+ FLAGS.ff_activation = 'gelu'
+ FLAGS.strategy_type = 'mirror'
+ FLAGS.learning_rate = 3e-5
+ FLAGS.train_steps = 8000
+ FLAGS.warmup_steps = 1000
+ FLAGS.iterations = 1000
+ FLAGS.bi_data = False
+ FLAGS.init_checkpoint = self.pretrained_checkpoint_path
+ FLAGS.train_tfrecord_path = self.train_data_path
+ FLAGS.test_tfrecord_path = self.test_data_path
+ FLAGS.spiece_model_file = self.spiece_model_file
+ FLAGS.predict_file = self.predict_file
+ FLAGS.adam_epsilon = 1e-6
+ FLAGS.lr_layer_decay_rate = 0.75
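+    # Note: lr_layer_decay_rate applies layer-wise learning-rate decay during
+    # fine-tuning, so layers closer to the input receive progressively smaller
+    # learning rates, as in the XLNet paper; the exact scaling is defined by
+    # the XLNet optimizer code (explanatory comment, not verified here).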
+
+ @owner_utils.Owner('tf-model-garden')
+ def benchmark_8_gpu_squadv2(self):
+ """Run XLNet model squad v2 accuracy test with 8 GPUs."""
+ self._setup()
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squadv2')
+ FLAGS.predict_dir = FLAGS.model_dir
+ # Sets timer_callback to None as we do not use it now.
+ self.timer_callback = None
+
+ summary_path = os.path.join(FLAGS.model_dir,
+ 'summaries/training_summary.txt')
+ self._run_and_report_benchmark(summary_path)
+
+ @owner_utils.Owner('tf-model-garden')
+ def benchmark_2x2_tpu_squadv2(self):
+ """Run XLNet model squad v2 accuracy test on 2x2 tpu."""
+ self._setup()
+ FLAGS.strategy_type = 'tpu'
+ FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_squadv2')
+ FLAGS.predict_dir = FLAGS.model_dir
+ # Sets timer_callback to None as we do not use it now.
+ self.timer_callback = None
+
+ summary_path = os.path.join(FLAGS.model_dir,
+ 'summaries/training_summary.txt')
+ self._run_and_report_benchmark(summary_path)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/colab/fine_tuning_bert.ipynb b/models/official/colab/fine_tuning_bert.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..443674b6b9f1292d25f26cc06e3359506763bfce
--- /dev/null
+++ b/models/official/colab/fine_tuning_bert.ipynb
@@ -0,0 +1,1830 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "vXLA5InzXydn"
+ },
+ "source": [
+ "##### Copyright 2019 The TensorFlow Authors."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "cellView": "form",
+ "colab": {},
+ "colab_type": "code",
+ "id": "RuRlpLL-X0R_"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "1mLJmVotXs64"
+ },
+ "source": [
+ "# Fine-tuning a BERT model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "hYEwGTeCXnnX"
+ },
+ "source": [
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
+ " \u003ctd\u003e\n",
+ " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/official_models/tutorials/fine_tune_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
+ " \u003c/td\u003e\n",
+ " \u003ctd\u003e\n",
+ " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/fine_tuning_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
+ " \u003c/td\u003e\n",
+ " \u003ctd\u003e\n",
+ " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/fine_tuning_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
+ " \u003c/td\u003e\n",
+ " \u003ctd\u003e\n",
+ " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/fine_tuning_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
+ " \u003c/td\u003e\n",
+ "\u003c/table\u003e"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "YN2ACivEPxgD"
+ },
+ "source": [
+ "In this example, we will work through fine-tuning a BERT model using the tensorflow-models PIP package.\n",
+ "\n",
+ "The pretrained BERT model this tutorial is based on is also available on [TensorFlow Hub](https://tensorflow.org/hub), to see how to use it refer to the [Hub Appendix](#hub_bert)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "s2d9S2CSSO1z"
+ },
+ "source": [
+ "## Setup"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "fsACVQpVSifi"
+ },
+ "source": [
+ "### Install the TensorFlow Model Garden pip package\n",
+ "\n",
+ "* `tf-models-nightly` is the nightly Model Garden package created daily automatically.\n",
+ "* pip will install all models and dependencies automatically."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "NvNr2svBM-p3"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -q tf-nightly\n",
+ "!pip install -q tf-models-nightly"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "U-7qPCjWUAyy"
+ },
+ "source": [
+ "### Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "lXsXev5MNr20"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "import tensorflow as tf\n",
+ "\n",
+ "import tensorflow_hub as hub\n",
+ "import tensorflow_datasets as tfds\n",
+ "tfds.disable_progress_bar()\n",
+ "\n",
+ "from official.modeling import tf_utils\n",
+ "from official import nlp\n",
+ "from official.nlp import bert\n",
+ "\n",
+ "# Load the required submodules\n",
+ "import official.nlp.optimization\n",
+ "import official.nlp.bert.bert_models\n",
+ "import official.nlp.bert.configs\n",
+ "import official.nlp.bert.run_classifier\n",
+ "import official.nlp.bert.tokenization\n",
+ "import official.nlp.data.classifier_data_lib\n",
+ "import official.nlp.modeling.losses\n",
+ "import official.nlp.modeling.models\n",
+ "import official.nlp.modeling.networks"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "mbanlzTvJBsz"
+ },
+ "source": [
+ "### Resources"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "PpW0x8TpR8DT"
+ },
+ "source": [
+ "This directory contains the configuration, vocabulary, and a pre-trained checkpoint used in this tutorial:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "vzRHOLciR8eq"
+ },
+ "outputs": [],
+ "source": [
+ "gs_folder_bert = \"gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12\"\n",
+ "tf.io.gfile.listdir(gs_folder_bert)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "9uFskufsR2LT"
+ },
+ "source": [
+ "You can get a pre-trained BERT encoder from TensorFlow Hub here:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "e0dAkUttJAzj"
+ },
+ "outputs": [],
+ "source": [
+ "hub_url_bert = \"https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "Qv6abtRvH4xO"
+ },
+ "source": [
+ "## The data\n",
+ "For this example we used the [GLUE MRPC dataset from TFDS](https://www.tensorflow.org/datasets/catalog/glue#gluemrpc).\n",
+ "\n",
+ "This dataset is not set up so that it can be directly fed into the BERT model, so this section also handles the necessary preprocessing."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "28DvUhC1YUiB"
+ },
+ "source": [
+ "### Get the dataset from TensorFlow Datasets\n",
+ "\n",
+ "The Microsoft Research Paraphrase Corpus (Dolan \u0026 Brockett, 2005) is a corpus of sentence pairs automatically extracted from online news sources, with human annotations for whether the sentences in the pair are semantically equivalent.\n",
+ "\n",
+ "* Number of labels: 2.\n",
+ "* Size of training dataset: 3668.\n",
+ "* Size of evaluation dataset: 408.\n",
+ "* Maximum sequence length of training and evaluation dataset: 128.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "Ijikx5OsH9AT"
+ },
+ "outputs": [],
+ "source": [
+ "glue, info = tfds.load('glue/mrpc', with_info=True,\n",
+ " # It's small, load the whole dataset\n",
+ " batch_size=-1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "xf9zz4vLYXjr"
+ },
+ "outputs": [],
+ "source": [
+ "list(glue.keys())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "ZgBg2r2nYT-K"
+ },
+ "source": [
+ "The `info` object describes the dataset and it's features:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "IQrHxv7W7jH5"
+ },
+ "outputs": [],
+ "source": [
+ "info.features"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "vhsVWYNxazz5"
+ },
+ "source": [
+ "The two classes are:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "n0gfc_VTayfQ"
+ },
+ "outputs": [],
+ "source": [
+ "info.features['label'].names"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "38zJcap6xkbC"
+ },
+ "source": [
+ "Here is one example from the training set:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "xON_i6SkwApW"
+ },
+ "outputs": [],
+ "source": [
+ "glue_train = glue['train']\n",
+ "\n",
+ "for key, value in glue_train.items():\n",
+ " print(f\"{key:9s}: {value[0].numpy()}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "9fbTyfJpNr7x"
+ },
+ "source": [
+ "### The BERT tokenizer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "wqeN54S61ZKQ"
+ },
+ "source": [
+ "To fine tune a pre-trained model you need to be sure that you're using exactly the same tokenization, vocabulary, and index mapping as you used during training.\n",
+ "\n",
+ "The BERT tokenizer used in this tutorial is written in pure Python (It's not built out of TensorFlow ops). So you can't just plug it into your model as a `keras.layer` like you can with `preprocessing.TextVectorization`.\n",
+ "\n",
+ "The following code rebuilds the tokenizer that was used by the base model:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "idxyhmrCQcw5"
+ },
+ "outputs": [],
+ "source": [
+ "# Set up tokenizer to generate Tensorflow dataset\n",
+ "tokenizer = bert.tokenization.FullTokenizer(\n",
+ " vocab_file=os.path.join(gs_folder_bert, \"vocab.txt\"),\n",
+ " do_lower_case=True)\n",
+ "\n",
+ "print(\"Vocab size:\", len(tokenizer.vocab))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "zYHDSquU2lDU"
+ },
+ "source": [
+ "Tokenize a sentence:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "L_OfOYPg853R"
+ },
+ "outputs": [],
+ "source": [
+ "tokens = tokenizer.tokenize(\"Hello TensorFlow!\")\n",
+ "print(tokens)\n",
+ "ids = tokenizer.convert_tokens_to_ids(tokens)\n",
+ "print(ids)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "kkAXLtuyWWDI"
+ },
+ "source": [
+ "### Preprocess the data\n",
+ "\n",
+ "The section manually preprocessed the dataset into the format expected by the model.\n",
+ "\n",
+ "This dataset is small, so preprocessing can be done quickly and easily in memory. For larger datasets the `tf_models` library includes some tools for preprocessing and re-serializing a dataset. See [Appendix: Re-encoding a large dataset](#re_encoding_tools) for details."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "62UTWLQd9-LB"
+ },
+ "source": [
+ "#### Encode the sentences\n",
+ "\n",
+ "The model expects its two inputs sentences to be concatenated together. This input is expected to start with a `[CLS]` \"This is a classification problem\" token, and each sentence should end with a `[SEP]` \"Separator\" token:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "bdL-dRNRBRJT"
+ },
+ "outputs": [],
+ "source": [
+ "tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "UrPktnqpwqie"
+ },
+ "source": [
+ "Start by encoding all the sentences while appending a `[SEP]` token, and packing them into ragged-tensors:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "BR7BmtU498Bh"
+ },
+ "outputs": [],
+ "source": [
+ "def encode_sentence(s):\n",
+ " tokens = list(tokenizer.tokenize(s.numpy()))\n",
+ " tokens.append('[SEP]')\n",
+ " return tokenizer.convert_tokens_to_ids(tokens)\n",
+ "\n",
+ "sentence1 = tf.ragged.constant([\n",
+ " encode_sentence(s) for s in glue_train[\"sentence1\"]])\n",
+ "sentence2 = tf.ragged.constant([\n",
+ " encode_sentence(s) for s in glue_train[\"sentence2\"]])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "has42aUdfky-"
+ },
+ "outputs": [],
+ "source": [
+ "print(\"Sentence1 shape:\", sentence1.shape.as_list())\n",
+ "print(\"Sentence2 shape:\", sentence2.shape.as_list())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "MU9lTWy_xXbb"
+ },
+ "source": [
+ "Now prepend a `[CLS]` token, and concatenate the ragged tensors to form a single `input_word_ids` tensor for each example. `RaggedTensor.to_tensor()` zero pads to the longest sequence."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "USD8uihw-g4J"
+ },
+ "outputs": [],
+ "source": [
+ "cls = [tokenizer.convert_tokens_to_ids(['[CLS]'])]*sentence1.shape[0]\n",
+ "input_word_ids = tf.concat([cls, sentence1, sentence2], axis=-1)\n",
+ "_ = plt.pcolormesh(input_word_ids.to_tensor())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "xmNv4l4k-dBZ"
+ },
+ "source": [
+ "#### Mask and input type"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "DIWjNIKq-ldh"
+ },
+ "source": [
+ "The model expects two additional inputs:\n",
+ "\n",
+ "* The input mask\n",
+ "* The input type"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "ulNZ4U96-8JZ"
+ },
+ "source": [
+ "The mask allows the model to cleanly differentiate between the content and the padding. The mask has the same shape as the `input_word_ids`, and contains a `1` anywhere the `input_word_ids` is not padding."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "EezOO9qj91kP"
+ },
+ "outputs": [],
+ "source": [
+ "input_mask = tf.ones_like(input_word_ids).to_tensor()\n",
+ "\n",
+ "plt.pcolormesh(input_mask)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "rxLenwAvCkBf"
+ },
+ "source": [
+ "The \"input type\" also has the same shape, but inside the non-padded region, contains a `0` or a `1` indicating which sentence the token is a part of. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "2CetH_5C9P2m"
+ },
+ "outputs": [],
+ "source": [
+ "type_cls = tf.zeros_like(cls)\n",
+ "type_s1 = tf.zeros_like(sentence1)\n",
+ "type_s2 = tf.ones_like(sentence2)\n",
+ "input_type_ids = tf.concat([type_cls, type_s1, type_s2], axis=-1).to_tensor()\n",
+ "\n",
+ "plt.pcolormesh(input_type_ids)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "P5UBnCn8Ii6s"
+ },
+ "source": [
+ "#### Put it all together\n",
+ "\n",
+ "Collect the above text parsing code into a single function, and apply it to each split of the `glue/mrpc` dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "sDGiWYPLEd5a"
+ },
+ "outputs": [],
+ "source": [
+ "def encode_sentence(s, tokenizer):\n",
+ " tokens = list(tokenizer.tokenize(s))\n",
+ " tokens.append('[SEP]')\n",
+ " return tokenizer.convert_tokens_to_ids(tokens)\n",
+ "\n",
+ "def bert_encode(glue_dict, tokenizer):\n",
+ " num_examples = len(glue_dict[\"sentence1\"])\n",
+ " \n",
+ " sentence1 = tf.ragged.constant([\n",
+ " encode_sentence(s, tokenizer)\n",
+ " for s in np.array(glue_dict[\"sentence1\"])])\n",
+ " sentence2 = tf.ragged.constant([\n",
+ " encode_sentence(s, tokenizer)\n",
+ " for s in np.array(glue_dict[\"sentence2\"])])\n",
+ "\n",
+ " cls = [tokenizer.convert_tokens_to_ids(['[CLS]'])]*sentence1.shape[0]\n",
+ " input_word_ids = tf.concat([cls, sentence1, sentence2], axis=-1)\n",
+ "\n",
+ " input_mask = tf.ones_like(input_word_ids).to_tensor()\n",
+ "\n",
+ " type_cls = tf.zeros_like(cls)\n",
+ " type_s1 = tf.zeros_like(sentence1)\n",
+ " type_s2 = tf.ones_like(sentence2)\n",
+ " input_type_ids = tf.concat(\n",
+ " [type_cls, type_s1, type_s2], axis=-1).to_tensor()\n",
+ "\n",
+ " inputs = {\n",
+ " 'input_word_ids': input_word_ids.to_tensor(),\n",
+ " 'input_mask': input_mask,\n",
+ " 'input_type_ids': input_type_ids}\n",
+ "\n",
+ " return inputs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "yuLKxf6zHxw-"
+ },
+ "outputs": [],
+ "source": [
+ "glue_train = bert_encode(glue['train'], tokenizer)\n",
+ "glue_train_labels = glue['train']['label']\n",
+ "\n",
+ "glue_validation = bert_encode(glue['validation'], tokenizer)\n",
+ "glue_validation_labels = glue['validation']['label']\n",
+ "\n",
+ "glue_test = bert_encode(glue['test'], tokenizer)\n",
+ "glue_test_labels = glue['test']['label']"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "7FC5aLVxKVKK"
+ },
+ "source": [
+ "Each subset of the data has been converted to a dictionary of features, and a set of labels. Each feature in the input dictionary has the same shape, and the number of labels should match:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "jyjTdGpFhO_1"
+ },
+ "outputs": [],
+ "source": [
+ "for key, value in glue_train.items():\n",
+ " print(f'{key:15s} shape: {value.shape}')\n",
+ "\n",
+ "print(f'glue_train_labels shape: {glue_train_labels.shape}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "FSwymsbkbLDA"
+ },
+ "source": [
+ "## The model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "Efrj3Cn1kLAp"
+ },
+ "source": [
+ "### Build the model\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "xxpOY5r2Ayq6"
+ },
+ "source": [
+ "The first step is to download the configuration for the pre-trained model.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "ujapVfZ_AKW7"
+ },
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "\n",
+ "bert_config_file = os.path.join(gs_folder_bert, \"bert_config.json\")\n",
+ "config_dict = json.loads(tf.io.gfile.GFile(bert_config_file).read())\n",
+ "\n",
+ "bert_config = bert.configs.BertConfig.from_dict(config_dict)\n",
+ "\n",
+ "config_dict"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "96ldxDSwkVkj"
+ },
+ "source": [
+ "The `config` defines the core BERT Model, which is a Keras model to predict the outputs of `num_classes` from the inputs with maximum sequence length `max_seq_length`.\n",
+ "\n",
+ "This function returns both the encoder and the classifier."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "cH682__U0FBv"
+ },
+ "outputs": [],
+ "source": [
+ "bert_classifier, bert_encoder = bert.bert_models.classifier_model(\n",
+ " bert_config, num_labels=2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "XqKp3-5GIZlw"
+ },
+ "source": [
+ "The classifier has three inputs and one output:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "bAQblMIjwkvx"
+ },
+ "outputs": [],
+ "source": [
+ "tf.keras.utils.plot_model(bert_classifier, show_shapes=True, dpi=48)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "sFmVG4SKZAw8"
+ },
+ "source": [
+ "Run it on a test batch of data 10 examples from the training set. The output is the logits for the two classes:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "VTjgPbp4ZDKo"
+ },
+ "outputs": [],
+ "source": [
+ "glue_batch = {key: val[:10] for key, val in glue_train.items()}\n",
+ "\n",
+ "bert_classifier(\n",
+ " glue_batch, training=True\n",
+ ").numpy()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "Q0NTdwZsQK8n"
+ },
+ "source": [
+ "The `TransformerEncoder` in the center of the classifier above **is** the `bert_encoder`.\n",
+ "\n",
+ "Inspecting the encoder, we see its stack of `Transformer` layers connected to those same three inputs:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "8L__-erBwLIQ"
+ },
+ "outputs": [],
+ "source": [
+ "tf.keras.utils.plot_model(bert_encoder, show_shapes=True, dpi=48)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "mKAvkQc3heSy"
+ },
+ "source": [
+ "### Restore the encoder weights\n",
+ "\n",
+ "When built the encoder is randomly initialized. Restore the encoder's weights from the checkpoint:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "97Ll2Gichd_Y"
+ },
+ "outputs": [],
+ "source": [
+ "checkpoint = tf.train.Checkpoint(model=bert_encoder)\n",
+ "checkpoint.restore(\n",
+ " os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "2oHOql35k3Dd"
+ },
+ "source": [
+ "Note: The pretrained `TransformerEncoder` is also available on [TensorFlow Hub](https://tensorflow.org/hub). See the [Hub appendix](#hub_bert) for details. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "115caFLMk-_l"
+ },
+ "source": [
+ "### Set up the optimizer\n",
+ "\n",
+ "BERT adopts the Adam optimizer with weight decay (aka \"[AdamW](https://arxiv.org/abs/1711.05101)\").\n",
+ "It also employs a learning rate schedule that firstly warms up from 0 and then decays to 0."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "w8qXKRZuCwW4"
+ },
+ "outputs": [],
+ "source": [
+ "# Set up epochs and steps\n",
+ "epochs = 3\n",
+ "batch_size = 32\n",
+ "eval_batch_size = 32\n",
+ "\n",
+ "train_data_size = len(glue_train_labels)\n",
+ "steps_per_epoch = int(train_data_size / batch_size)\n",
+ "num_train_steps = steps_per_epoch * epochs\n",
+ "warmup_steps = int(epochs * train_data_size * 0.1 / batch_size)\n",
+ "\n",
+ "# creates an optimizer with learning rate schedule\n",
+ "optimizer = nlp.optimization.create_optimizer(\n",
+ " 2e-5, num_train_steps=num_train_steps, num_warmup_steps=warmup_steps)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "pXRGxiRNEHS2"
+ },
+ "source": [
+ "This returns an `AdamWeightDecay` optimizer with the learning rate schedule set:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "eQNA16bhDpky"
+ },
+ "outputs": [],
+ "source": [
+ "type(optimizer)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "xqu_K71fJQB8"
+ },
+ "source": [
+ "To see an example of how to customize the optimizer and it's schedule, see the [Optimizer schedule appendix](#optiizer_schedule)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "78FEUOOEkoP0"
+ },
+ "source": [
+ "### Train the model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "OTNcA0O0nSq9"
+ },
+ "source": [
+ "The metric is accuracy and we use sparse categorical cross-entropy as loss."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "nzi8hjeTQTRs"
+ },
+ "outputs": [],
+ "source": [
+ "metrics = [tf.keras.metrics.SparseCategoricalAccuracy('accuracy', dtype=tf.float32)]\n",
+ "loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
+ "\n",
+ "bert_classifier.compile(\n",
+ " optimizer=optimizer,\n",
+ " loss=loss,\n",
+ " metrics=metrics)\n",
+ "\n",
+ "bert_classifier.fit(\n",
+ " glue_train, glue_train_labels,\n",
+ " validation_data=(glue_validation, glue_validation_labels),\n",
+ " batch_size=32,\n",
+ " epochs=epochs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "IFtKFWbNKb0u"
+ },
+ "source": [
+ "Now run the fine-tuned model on a custom example to see that it works.\n",
+ "\n",
+ "Start by encoding some sentence pairs:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "9ZoUgDUNJPz3"
+ },
+ "outputs": [],
+ "source": [
+ "my_examples = bert_encode(\n",
+ " glue_dict = {\n",
+ " 'sentence1':[\n",
+ " 'The rain in Spain falls mainly on the plain.',\n",
+ " 'Look I fine tuned BERT.'],\n",
+ " 'sentence2':[\n",
+ " 'It mostly rains on the flat lands of Spain.',\n",
+ " 'Is it working? This does not match.']\n",
+ " },\n",
+ " tokenizer=tokenizer)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "7ynJibkBRTJF"
+ },
+ "source": [
+ "The model should report class `1` \"match\" for the first example and class `0` \"no-match\" for the second:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "umo0ttrgRYIM"
+ },
+ "outputs": [],
+ "source": [
+ "result = bert_classifier(my_examples, training=False)\n",
+ "\n",
+ "result = tf.argmax(result).numpy()\n",
+ "result"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "utGl0M3aZCE4"
+ },
+ "outputs": [],
+ "source": [
+ "np.array(info.features['label'].names)[result]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "fVo_AnT0l26j"
+ },
+ "source": [
+ "### Save the model\n",
+ "\n",
+ "Often the goal of training a model is to _use_ it for something, so export the model and then restore it to be sure that it works."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "Nl5x6nElZqkP"
+ },
+ "outputs": [],
+ "source": [
+ "export_dir='./saved_model'\n",
+ "tf.saved_model.save(bert_classifier, export_dir=export_dir)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "y_ACvKPsVUXC"
+ },
+ "outputs": [],
+ "source": [
+ "reloaded = tf.saved_model.load(export_dir)\n",
+ "reloaded_result = reloaded([my_examples['input_word_ids'],\n",
+ " my_examples['input_mask'],\n",
+ " my_examples['input_type_ids']], training=False)\n",
+ "\n",
+ "original_result = bert_classifier(my_examples, training=False)\n",
+ "\n",
+ "# The results are (nearly) identical:\n",
+ "print(original_result.numpy())\n",
+ "print()\n",
+ "print(reloaded_result.numpy())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "eQceYqRFT_Eg"
+ },
+ "source": [
+ "## Appendix"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "SaC1RlFawUpc"
+ },
+ "source": [
+ "\u003ca id=re_encoding_tools\u003e\u003c/a\u003e\n",
+ "### Re-encoding a large dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "CwUdjFBkzUgh"
+ },
+ "source": [
+ "This tutorial you re-encoded the dataset in memory, for clarity.\n",
+ "\n",
+ "This was only possible because `glue/mrpc` is a very small dataset. To deal with larger datasets `tf_models` library includes some tools for processing and re-encoding a dataset for efficient training."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "2UTQrkyOT5wD"
+ },
+ "source": [
+ "The first step is to describe which features of the dataset should be transformed:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "XQeDFOzYR9Z9"
+ },
+ "outputs": [],
+ "source": [
+ "processor = nlp.data.classifier_data_lib.TfdsProcessor(\n",
+ " tfds_params=\"dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2\",\n",
+ " process_text_fn=bert.tokenization.convert_to_unicode)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "XrFQbfErUWxa"
+ },
+ "source": [
+ "Then apply the transformation to generate new TFRecord files."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "ymw7GOHpSHKU"
+ },
+ "outputs": [],
+ "source": [
+ "# Set up output of training and evaluation Tensorflow dataset\n",
+ "train_data_output_path=\"./mrpc_train.tf_record\"\n",
+ "eval_data_output_path=\"./mrpc_eval.tf_record\"\n",
+ "\n",
+ "max_seq_length = 128\n",
+ "batch_size = 32\n",
+ "eval_batch_size = 32\n",
+ "\n",
+ "# Generate and save training data into a tf record file\n",
+ "input_meta_data = (\n",
+ " nlp.data.classifier_data_lib.generate_tf_record_from_data_file(\n",
+ " processor=processor,\n",
+ " data_dir=None, # It is `None` because data is from tfds, not local dir.\n",
+ " tokenizer=tokenizer,\n",
+ " train_data_output_path=train_data_output_path,\n",
+ " eval_data_output_path=eval_data_output_path,\n",
+ " max_seq_length=max_seq_length))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "uX_Sp-wTUoRm"
+ },
+ "source": [
+ "Finally create `tf.data` input pipelines from those TFRecord files:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "rkHxIK57SQ_r"
+ },
+ "outputs": [],
+ "source": [
+ "training_dataset = bert.run_classifier.get_dataset_fn(\n",
+ " train_data_output_path,\n",
+ " max_seq_length,\n",
+ " batch_size,\n",
+ " is_training=True)()\n",
+ "\n",
+ "evaluation_dataset = bert.run_classifier.get_dataset_fn(\n",
+ " eval_data_output_path,\n",
+ " max_seq_length,\n",
+ " eval_batch_size,\n",
+ " is_training=False)()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "stbaVouogvzS"
+ },
+ "source": [
+ "The resulting `tf.data.Datasets` return `(features, labels)` pairs, as expected by `keras.Model.fit`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "gwhrlQl4gxVF"
+ },
+ "outputs": [],
+ "source": [
+ "training_dataset.element_spec"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "dbJ76vSJj77j"
+ },
+ "source": [
+ "#### Create tf.data.Dataset for training and evaluation\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "9J95LFRohiYw"
+ },
+ "source": [
+ "If you need to modify the data loading here is some code to get you started:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "gCvaLLAxPuMc"
+ },
+ "outputs": [],
+ "source": [
+ "def create_classifier_dataset(file_path, seq_length, batch_size, is_training):\n",
+ " \"\"\"Creates input dataset from (tf)records files for train/eval.\"\"\"\n",
+ " dataset = tf.data.TFRecordDataset(file_path)\n",
+ " if is_training:\n",
+ " dataset = dataset.shuffle(100)\n",
+ " dataset = dataset.repeat()\n",
+ "\n",
+ " def decode_record(record):\n",
+ " name_to_features = {\n",
+ " 'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),\n",
+ " 'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),\n",
+ " 'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),\n",
+ " 'label_ids': tf.io.FixedLenFeature([], tf.int64),\n",
+ " }\n",
+ " return tf.io.parse_single_example(record, name_to_features)\n",
+ "\n",
+ " def _select_data_from_record(record):\n",
+ " x = {\n",
+ " 'input_word_ids': record['input_ids'],\n",
+ " 'input_mask': record['input_mask'],\n",
+ " 'input_type_ids': record['segment_ids']\n",
+ " }\n",
+ " y = record['label_ids']\n",
+ " return (x, y)\n",
+ "\n",
+ " dataset = dataset.map(decode_record,\n",
+ " num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
+ " dataset = dataset.map(\n",
+ " _select_data_from_record,\n",
+ " num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
+ " dataset = dataset.batch(batch_size, drop_remainder=is_training)\n",
+ " dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)\n",
+ " return dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "rutkBadrhzdR"
+ },
+ "outputs": [],
+ "source": [
+ "# Set up batch sizes\n",
+ "batch_size = 32\n",
+ "eval_batch_size = 32\n",
+ "\n",
+ "# Return Tensorflow dataset\n",
+ "training_dataset = create_classifier_dataset(\n",
+ " train_data_output_path,\n",
+ " input_meta_data['max_seq_length'],\n",
+ " batch_size,\n",
+ " is_training=True)\n",
+ "\n",
+ "evaluation_dataset = create_classifier_dataset(\n",
+ " eval_data_output_path,\n",
+ " input_meta_data['max_seq_length'],\n",
+ " eval_batch_size,\n",
+ " is_training=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "59TVgt4Z7fuU"
+ },
+ "outputs": [],
+ "source": [
+ "training_dataset.element_spec"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "QbklKt-w_CiI"
+ },
+ "source": [
+ "\u003ca id=\"hub_bert\"\u003e\u003c/a\u003e\n",
+ "\n",
+ "### TFModels BERT on TFHub\n",
+ "\n",
+ "You can get [the BERT model](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2) off the shelf from [TFHub](https://tensorflow.org/hub). It would not be hard to add a classification head on top of this `hub.KerasLayer`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "lo6479At4sP1"
+ },
+ "outputs": [],
+ "source": [
+ "# Note: 350MB download.\n",
+ "import tensorflow_hub as hub\n",
+ "hub_encoder = hub.KerasLayer(hub_url_bert, trainable=True)\n",
+ "\n",
+ "print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "iTzF574wivQv"
+ },
+ "source": [
+ "Test run it on a batch of data:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "XEcYrCR45Uwo"
+ },
+ "outputs": [],
+ "source": [
+ "result = hub_encoder(\n",
+ " inputs=[glue_train['input_word_ids'][:10],\n",
+ " glue_train['input_mask'][:10],\n",
+ " glue_train['input_type_ids'][:10],],\n",
+ " training=False,\n",
+ ")\n",
+ "\n",
+ "print(\"Pooled output shape:\", result[0].shape)\n",
+ "print(\"Sequence output shape:\", result[1].shape)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "cjojn8SmLSRI"
+ },
+ "source": [
+ "At this point it would be simple to add a classification head yourself.\n",
+ "\n",
+ "The `bert_models.classifier_model` function can also build a classifier onto the encoder from TensorFlow Hub:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "9nTDaApyLR70"
+ },
+ "outputs": [],
+ "source": [
+ "hub_classifier, hub_encoder = bert.bert_models.classifier_model(\n",
+ " # Caution: Most of `bert_config` is ignored if you pass a hub url.\n",
+ " bert_config=bert_config, hub_module_url=hub_url_bert, num_labels=2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "xMJX3wV0_v7I"
+ },
+ "source": [
+ "The one downside to loading this model from TFHub is that the structure of internal keras layers is not restored. So it's more difficult to inspect or modify the model. The `TransformerEncoder` model is now a single layer:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "pD71dnvhM2QS"
+ },
+ "outputs": [],
+ "source": [
+ "tf.keras.utils.plot_model(hub_classifier, show_shapes=True, dpi=64)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "nLZD-isBzNKi"
+ },
+ "outputs": [],
+ "source": [
+ "try:\n",
+ " tf.keras.utils.plot_model(hub_encoder, show_shapes=True, dpi=64)\n",
+ " assert False\n",
+ "except Exception as e:\n",
+ " print(f\"{type(e).__name__}: {e}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "ZxSqH0dNAgXV"
+ },
+ "source": [
+ "\u003ca id=\"model_builder_functions\"\u003e\u003c/a\u003e\n",
+ "\n",
+ "### Low level model building\n",
+ "\n",
+ "If you need a more control over the construction of the model it's worth noting that the `classifier_model` function used earlier is really just a thin wrapper over the `nlp.modeling.networks.TransformerEncoder` and `nlp.modeling.models.BertClassifier` classes. Just remember that if you start modifying the architecture it may not be correct or possible to reload the pre-trained checkpoint so you'll need to retrain from scratch."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "0cgABEwDj06P"
+ },
+ "source": [
+ "Build the encoder:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "5r_yqhBFSVEM"
+ },
+ "outputs": [],
+ "source": [
+ "transformer_config = config_dict.copy()\n",
+ "\n",
+ "# You need to rename a few fields to make this work:\n",
+ "transformer_config['attention_dropout_rate'] = transformer_config.pop('attention_probs_dropout_prob')\n",
+ "transformer_config['activation'] = tf_utils.get_activation(transformer_config.pop('hidden_act'))\n",
+ "transformer_config['dropout_rate'] = transformer_config.pop('hidden_dropout_prob')\n",
+ "transformer_config['initializer'] = tf.keras.initializers.TruncatedNormal(\n",
+ " stddev=transformer_config.pop('initializer_range'))\n",
+ "transformer_config['max_sequence_length'] = transformer_config.pop('max_position_embeddings')\n",
+ "transformer_config['num_layers'] = transformer_config.pop('num_hidden_layers')\n",
+ "\n",
+ "transformer_config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "rIO8MI7LLijh"
+ },
+ "outputs": [],
+ "source": [
+ "manual_encoder = nlp.modeling.networks.TransformerEncoder(**transformer_config)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "4a4tFSg9krRi"
+ },
+ "source": [
+ "Restore the weights:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "X6N9NEqfXJCx"
+ },
+ "outputs": [],
+ "source": [
+ "checkpoint = tf.train.Checkpoint(model=manual_encoder)\n",
+ "checkpoint.restore(\n",
+ " os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "1BPiPO4ykuwM"
+ },
+ "source": [
+ "Test run it:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "hlVdgJKmj389"
+ },
+ "outputs": [],
+ "source": [
+ "result = manual_encoder(my_examples, training=True)\n",
+ "\n",
+ "print(\"Sequence output shape:\", result[0].shape)\n",
+ "print(\"Pooled output shape:\", result[1].shape)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "nJMXvVgJkyBv"
+ },
+ "source": [
+ "Wrap it in a classifier:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "tQX57GJ6wkAb"
+ },
+ "outputs": [],
+ "source": [
+ "manual_classifier = nlp.modeling.models.BertClassifier(\n",
+ " bert_encoder,\n",
+ " num_classes=2,\n",
+ " dropout_rate=transformer_config['dropout_rate'],\n",
+ " initializer=tf.keras.initializers.TruncatedNormal(\n",
+ " stddev=bert_config.initializer_range))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "kB-nBWhQk0dS"
+ },
+ "outputs": [],
+ "source": [
+ "manual_classifier(my_examples, training=True).numpy()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "E6AJlOSyIO1L"
+ },
+ "source": [
+ "\u003ca id=\"optiizer_schedule\"\u003e\u003c/a\u003e\n",
+ "\n",
+ "### Optimizers and schedules\n",
+ "\n",
+ "The optimizer used to train the model was created using the `nlp.optimization.create_optimizer` function:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "28Dv3BPRlFTD"
+ },
+ "outputs": [],
+ "source": [
+ "optimizer = nlp.optimization.create_optimizer(\n",
+ " 2e-5, num_train_steps=num_train_steps, num_warmup_steps=warmup_steps)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "LRjcHr0UlT8c"
+ },
+ "source": [
+ "That high level wrapper sets up the learning rate schedules and the optimizer.\n",
+ "\n",
+ "The base learning rate schedule used here is a linear decay to zero over the training run:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "MHY8K6kDngQn"
+ },
+ "outputs": [],
+ "source": [
+ "epochs = 3\n",
+ "batch_size = 32\n",
+ "eval_batch_size = 32\n",
+ "\n",
+ "train_data_size = len(glue_train_labels)\n",
+ "steps_per_epoch = int(train_data_size / batch_size)\n",
+ "num_train_steps = steps_per_epoch * epochs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "wKIcSprulu3P"
+ },
+ "outputs": [],
+ "source": [
+ "decay_schedule = tf.keras.optimizers.schedules.PolynomialDecay(\n",
+ " initial_learning_rate=2e-5,\n",
+ " decay_steps=num_train_steps,\n",
+ " end_learning_rate=0)\n",
+ "\n",
+ "plt.plot([decay_schedule(n) for n in range(num_train_steps)])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "IMTC_gfAl_PZ"
+ },
+ "source": [
+ "This, in turn is wrapped in a `WarmUp` schedule that linearly increases the learning rate to the target value over the first 10% of training:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "YRt3VTmBmCBY"
+ },
+ "outputs": [],
+ "source": [
+ "warmup_steps = num_train_steps * 0.1\n",
+ "\n",
+ "warmup_schedule = nlp.optimization.WarmUp(\n",
+ " initial_learning_rate=2e-5,\n",
+ " decay_schedule_fn=decay_schedule,\n",
+ " warmup_steps=warmup_steps)\n",
+ "\n",
+ "# The warmup overshoots, because it warms up to the `initial_learning_rate`\n",
+ "# following the original implementation. You can set\n",
+ "# `initial_learning_rate=decay_schedule(warmup_steps)` if you don't like the\n",
+ "# overshoot.\n",
+ "plt.plot([warmup_schedule(n) for n in range(num_train_steps)])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "l8D9Lv3Bn740"
+ },
+ "source": [
+ "Then create the `nlp.optimization.AdamWeightDecay` using that schedule, configured for the BERT model:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "2Hf2rpRXk89N"
+ },
+ "outputs": [],
+ "source": [
+ "optimizer = nlp.optimization.AdamWeightDecay(\n",
+ " learning_rate=warmup_schedule,\n",
+ " weight_decay_rate=0.01,\n",
+ " epsilon=1e-6,\n",
+ " exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "fine_tuning_bert.ipynb",
+ "private_outputs": true,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/models/official/core/__init__.py b/models/official/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/core/base_task.py b/models/official/core/base_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..31811cbe6606fac61b664973717f4c75b6b4b37b
--- /dev/null
+++ b/models/official/core/base_task.py
@@ -0,0 +1,303 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines the base task abstraction."""
+import abc
+import functools
+from typing import Any, Callable, Optional
+
+import six
+import tensorflow as tf
+
+from official.modeling.hyperparams import config_definitions as cfg
+from official.utils import registry
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Task(tf.Module):
+ """A single-replica view of training procedure.
+
+  Tasks provide artifacts for training/evaluation procedures, including
+ loading/iterating over Datasets, initializing the model, calculating the loss
+ and customized metrics with reduction.
+ """
+
+ # Special keys in train/validate step returned logs.
+ loss = "loss"
+
+ def __init__(self, params: cfg.TaskConfig):
+ self._task_config = params
+
+ @property
+ def task_config(self) -> cfg.TaskConfig:
+ return self._task_config
+
+ def initialize(self, model: tf.keras.Model):
+ """A callback function used as CheckpointManager's init_fn.
+
+ This function will be called when no checkpoint found for the model.
+ If there is a checkpoint, the checkpoint will be loaded and this function
+ will not be called. You can use this callback function to load a pretrained
+ checkpoint, saved under a directory other than the model_dir.
+
+ Args:
+ model: The keras.Model built or used by this task.
+ """
+ pass
+
+ @abc.abstractmethod
+ def build_model(self) -> tf.keras.Model:
+ """Creates the model architecture.
+
+ Returns:
+ A model instance.
+ """
+
+ def compile_model(self,
+ model: tf.keras.Model,
+ optimizer: tf.keras.optimizers.Optimizer,
+ loss=None,
+ train_step: Optional[Callable[..., Any]] = None,
+ validation_step: Optional[Callable[..., Any]] = None,
+ **kwargs) -> tf.keras.Model:
+ """Compiles the model with objects created by the task.
+
+ The method should not be used in any customized training implementation.
+
+ Args:
+ model: a keras.Model.
+ optimizer: the keras optimizer.
+ loss: a callable/list of losses.
+ train_step: optional train step function defined by the task.
+ validation_step: optional validation_step step function defined by the
+ task.
+ **kwargs: other kwargs consumed by keras.Model compile().
+
+ Returns:
+ a compiled keras.Model.
+ """
+ if bool(loss is None) == bool(train_step is None):
+ raise ValueError("`loss` and `train_step` should be exclusive to "
+ "each other.")
+ model.compile(optimizer=optimizer, loss=loss, **kwargs)
+
+ if train_step:
+ model.train_step = functools.partial(
+ train_step, model=model, optimizer=model.optimizer)
+ if validation_step:
+ model.test_step = functools.partial(validation_step, model=model)
+ return model
+
+ @abc.abstractmethod
+ def build_inputs(self,
+ params: cfg.DataConfig,
+ input_context: Optional[tf.distribute.InputContext] = None):
+ """Returns a dataset or a nested structure of dataset functions.
+
+ Dataset functions define per-host datasets with the per-replica batch size.
+
+ Args:
+ params: hyperparams to create input pipelines.
+ input_context: optional distribution input pipeline context.
+
+ Returns:
+ A nested structure of per-replica input functions.
+ """
+
+ def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
+ """Standard interface to compute losses.
+
+ Args:
+ labels: optional label tensors.
+ model_outputs: a nested structure of output tensors.
+      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
+
+ Returns:
+ The total loss tensor.
+ """
+ del model_outputs, labels
+
+ if aux_losses is None:
+ losses = [tf.constant(0.0, dtype=tf.float32)]
+ else:
+ losses = aux_losses
+ total_loss = tf.add_n(losses)
+ return total_loss
+
+ def build_metrics(self, training: bool = True):
+ """Gets streaming metrics for training/validation."""
+ del training
+ return []
+
+ def process_metrics(self, metrics, labels, model_outputs):
+ """Process and update metrics. Called when using custom training loop API.
+
+ Args:
+ metrics: a nested structure of metrics objects.
+ The return of function self.build_metrics.
+ labels: a tensor or a nested structure of tensors.
+ model_outputs: a tensor or a nested structure of tensors.
+ For example, output of the keras model built by self.build_model.
+ """
+ for metric in metrics:
+ metric.update_state(labels, model_outputs)
+
+ def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
+ """Process and update compiled_metrics. call when using compile/fit API.
+
+ Args:
+ compiled_metrics: the compiled metrics (model.compiled_metrics).
+ labels: a tensor or a nested structure of tensors.
+ model_outputs: a tensor or a nested structure of tensors.
+ For example, output of the keras model built by self.build_model.
+ """
+ compiled_metrics.update_state(labels, model_outputs)
+
+ def train_step(self,
+ inputs,
+ model: tf.keras.Model,
+ optimizer: tf.keras.optimizers.Optimizer,
+ metrics=None):
+ """Does forward and backward.
+
+ Args:
+ inputs: a dictionary of input tensors.
+ model: the model, forward pass definition.
+ optimizer: the optimizer for this training step.
+ metrics: a nested structure of metrics objects.
+
+ Returns:
+ A dictionary of logs.
+ """
+ if isinstance(inputs, tuple) and len(inputs) == 2:
+ features, labels = inputs
+ else:
+ features, labels = inputs, inputs
+ with tf.GradientTape() as tape:
+ outputs = model(features, training=True)
+ # Computes per-replica loss.
+ loss = self.build_losses(
+ labels=labels, model_outputs=outputs, aux_losses=model.losses)
+ # Scales loss as the default gradients allreduce performs sum inside the
+ # optimizer.
+ scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
+
+ # For mixed precision, when a LossScaleOptimizer is used, the loss is
+ # scaled to avoid numeric underflow.
+ if isinstance(optimizer,
+ tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+ scaled_loss = optimizer.get_scaled_loss(scaled_loss)
+
+ tvars = model.trainable_variables
+ grads = tape.gradient(scaled_loss, tvars)
+
+ if isinstance(optimizer,
+ tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+ grads = optimizer.get_unscaled_gradients(grads)
+ optimizer.apply_gradients(list(zip(grads, tvars)))
+ logs = {self.loss: loss}
+ if metrics:
+ self.process_metrics(metrics, labels, outputs)
+ logs.update({m.name: m.result() for m in metrics})
+ elif model.compiled_metrics:
+ self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
+ logs.update({m.name: m.result() for m in model.metrics})
+ return logs
+
+ def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
+ """Validatation step.
+
+ Args:
+ inputs: a dictionary of input tensors.
+ model: the keras.Model.
+ metrics: a nested structure of metrics objects.
+
+ Returns:
+ A dictionary of logs.
+ """
+ if isinstance(inputs, tuple) and len(inputs) == 2:
+ features, labels = inputs
+ else:
+ features, labels = inputs, inputs
+ outputs = self.inference_step(features, model)
+ loss = self.build_losses(
+ labels=labels, model_outputs=outputs, aux_losses=model.losses)
+ logs = {self.loss: loss}
+ if metrics:
+ self.process_metrics(metrics, labels, outputs)
+ logs.update({m.name: m.result() for m in metrics})
+ elif model.compiled_metrics:
+ self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
+ logs.update({m.name: m.result() for m in model.metrics})
+ return logs
+
+ def inference_step(self, inputs, model: tf.keras.Model):
+ """Performs the forward step."""
+ return model(inputs, training=False)
+
+ def aggregate_logs(self, state, step_logs):
+ """Optional aggregation over logs returned from a validation step."""
+ pass
+
+ def reduce_aggregated_logs(self, aggregated_logs):
+ """Optional reduce of aggregated logs over validation steps."""
+ return {}
+
+
+_REGISTERED_TASK_CLS = {}
+
+
+# TODO(b/158268740): Move these outside the base class file.
+# TODO(b/158741360): Add type annotations once pytype checks across modules.
+def register_task_cls(task_config_cls):
+ """Decorates a factory of Tasks for lookup by a subclass of TaskConfig.
+
+ This decorator supports registration of tasks as follows:
+
+ ```
+ @dataclasses.dataclass
+ class MyTaskConfig(TaskConfig):
+ # Add fields here.
+ pass
+
+ @register_task_cls(MyTaskConfig)
+ class MyTask(Task):
+ # Inherits def __init__(self, task_config).
+ pass
+
+ my_task_config = MyTaskConfig()
+ my_task = get_task(my_task_config) # Returns MyTask(my_task_config).
+ ```
+
+  Besides a class itself, other callables that create a Task from a TaskConfig
+ can be decorated by the result of this function, as long as there is at most
+ one registration for each config class.
+
+ Args:
+ task_config_cls: a subclass of TaskConfig (*not* an instance of TaskConfig).
+ Each task_config_cls can only be used for a single registration.
+
+ Returns:
+ A callable for use as class decorator that registers the decorated class
+ for creation from an instance of task_config_cls.
+ """
+ return registry.register(_REGISTERED_TASK_CLS, task_config_cls)
+
+
+# The user-visible get_task() is defined after classes have been registered.
+# TODO(b/158741360): Add type annotations once pytype checks across modules.
+def get_task_cls(task_config_cls):
+ task_cls = registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
+ return task_cls
diff --git a/models/official/core/input_reader.py b/models/official/core/input_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..52f6e84e4bd02d4178586556ca191912de18fc18
--- /dev/null
+++ b/models/official/core/input_reader.py
@@ -0,0 +1,223 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A common dataset reader."""
+
+from typing import Any, Callable, List, Optional
+
+import tensorflow as tf
+import tensorflow_datasets as tfds
+
+from official.modeling.hyperparams import config_definitions as cfg
+
+
+class InputReader:
+ """Input reader that returns a tf.data.Dataset instance."""
+
+ def __init__(self,
+ params: cfg.DataConfig,
+ shards: Optional[List[str]] = None,
+ dataset_fn=tf.data.TFRecordDataset,
+ decoder_fn: Optional[Callable[..., Any]] = None,
+ parser_fn: Optional[Callable[..., Any]] = None,
+ dataset_transform_fn: Optional[Callable[[tf.data.Dataset],
+ tf.data.Dataset]] = None,
+ postprocess_fn: Optional[Callable[..., Any]] = None):
+ """Initializes an InputReader instance.
+
+ Args:
+ params: A config_definitions.DataConfig object.
+ shards: A list of files to be read. If given, read from these files.
+ Otherwise, read from params.input_path.
+      dataset_fn: A dataset constructor or callable that consumes the input
+        files, for example `tf.data.TFRecordDataset`.
+ decoder_fn: An optional `callable` that takes the serialized data string
+ and decodes them into the raw tensor dictionary.
+ parser_fn: An optional `callable` that takes the decoded raw tensors dict
+        and parses them into a dictionary of tensors that can be consumed by the
+ model. It will be executed after decoder_fn.
+ dataset_transform_fn: An optional `callable` that takes a
+ `tf.data.Dataset` object and returns a `tf.data.Dataset`. It will be
+ executed after parser_fn.
+      postprocess_fn: An optional `callable` that processes batched tensors. It
+ will be executed after batching.
+ """
+ if params.input_path and params.tfds_name:
+ raise ValueError('At most one of `input_path` and `tfds_name` can be '
+ 'specified, but got %s and %s.' % (
+ params.input_path, params.tfds_name))
+ self._shards = shards
+ self._tfds_builder = None
+ if self._shards:
+ self._num_files = len(self._shards)
+ elif not params.tfds_name:
+ self._input_patterns = params.input_path.strip().split(',')
+ self._num_files = 0
+ for input_pattern in self._input_patterns:
+ input_pattern = input_pattern.strip()
+ if not input_pattern:
+ continue
+ matched_files = tf.io.gfile.glob(input_pattern)
+ if not matched_files:
+ raise ValueError('%s does not match any files.' % input_pattern)
+ else:
+ self._num_files += len(matched_files)
+ if self._num_files == 0:
+ raise ValueError('%s does not match any files.' % params.input_path)
+ else:
+ if not params.tfds_split:
+ raise ValueError(
+ '`tfds_name` is %s, but `tfds_split` is not specified.' %
+ params.tfds_name)
+ self._tfds_builder = tfds.builder(
+ params.tfds_name, data_dir=params.tfds_data_dir)
+
+ self._global_batch_size = params.global_batch_size
+ self._is_training = params.is_training
+ self._drop_remainder = params.drop_remainder
+ self._shuffle_buffer_size = params.shuffle_buffer_size
+ self._cache = params.cache
+ self._cycle_length = params.cycle_length
+ self._block_length = params.block_length
+ self._sharding = params.sharding
+ self._examples_consume = params.examples_consume
+ self._tfds_split = params.tfds_split
+ self._tfds_download = params.tfds_download
+ self._tfds_as_supervised = params.tfds_as_supervised
+ self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature
+
+ self._dataset_fn = dataset_fn
+ self._decoder_fn = decoder_fn
+ self._parser_fn = parser_fn
+ self._dataset_transform_fn = dataset_transform_fn
+ self._postprocess_fn = postprocess_fn
+
+ def _read_sharded_files(
+ self,
+ input_context: Optional[tf.distribute.InputContext] = None):
+ """Reads a dataset from sharded files."""
+ # Read from `self._shards` if it is provided.
+ if self._shards:
+ dataset = tf.data.Dataset.from_tensor_slices(self._shards)
+ else:
+ dataset = tf.data.Dataset.list_files(
+ self._input_patterns, shuffle=self._is_training)
+ if self._sharding and input_context and (
+ input_context.num_input_pipelines > 1):
+ dataset = dataset.shard(input_context.num_input_pipelines,
+ input_context.input_pipeline_id)
+ if self._is_training:
+ dataset = dataset.repeat()
+
+ dataset = dataset.interleave(
+ map_func=self._dataset_fn,
+ cycle_length=self._cycle_length,
+ block_length=self._block_length,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+ def _read_single_file(
+ self,
+ input_context: Optional[tf.distribute.InputContext] = None):
+ """Reads a dataset from a single file."""
+ # Read from `self._shards` if it is provided.
+ dataset = self._dataset_fn(self._shards or self._input_patterns)
+
+ # When `input_file` is a path to a single file, disable auto sharding
+ # so that same input file is sent to all workers.
+ options = tf.data.Options()
+ options.experimental_distribute.auto_shard_policy = (
+ tf.data.experimental.AutoShardPolicy.OFF)
+ dataset = dataset.with_options(options)
+ if self._sharding and input_context and (
+ input_context.num_input_pipelines > 1):
+ dataset = dataset.shard(input_context.num_input_pipelines,
+ input_context.input_pipeline_id)
+ if self._is_training:
+ dataset = dataset.repeat()
+ return dataset
+
+ def _read_tfds(
+ self,
+ input_context: Optional[tf.distribute.InputContext] = None
+ ) -> tf.data.Dataset:
+ """Reads a dataset from tfds."""
+ if self._tfds_download:
+ self._tfds_builder.download_and_prepare()
+
+ read_config = tfds.ReadConfig(
+ interleave_cycle_length=self._cycle_length,
+ interleave_block_length=self._block_length,
+ input_context=input_context)
+ decoders = {}
+ if self._tfds_skip_decoding_feature:
+ for skip_feature in self._tfds_skip_decoding_feature.split(','):
+ decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
+ dataset = self._tfds_builder.as_dataset(
+ split=self._tfds_split,
+ shuffle_files=self._is_training,
+ as_supervised=self._tfds_as_supervised,
+ decoders=decoders,
+ read_config=read_config)
+ return dataset
+
+ @property
+ def tfds_info(self) -> tfds.core.DatasetInfo:
+ """Returns TFDS dataset info, if available."""
+ if self._tfds_builder:
+ return self._tfds_builder.info
+ else:
+ raise ValueError('tfds_info is not available, because the dataset '
+ 'is not loaded from tfds.')
+
+ def read(
+ self,
+ input_context: Optional[tf.distribute.InputContext] = None
+ ) -> tf.data.Dataset:
+ """Generates a tf.data.Dataset object."""
+ if self._tfds_builder:
+ dataset = self._read_tfds(input_context)
+ elif self._num_files > 1:
+ dataset = self._read_sharded_files(input_context)
+ else:
+ assert self._num_files == 1
+ dataset = self._read_single_file(input_context)
+
+ if self._cache:
+ dataset = dataset.cache()
+
+ if self._is_training:
+ dataset = dataset.shuffle(self._shuffle_buffer_size)
+
+ if self._examples_consume > 0:
+ dataset = dataset.take(self._examples_consume)
+
+ def maybe_map_fn(dataset, fn):
+ return dataset if fn is None else dataset.map(
+ fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ dataset = maybe_map_fn(dataset, self._decoder_fn)
+ dataset = maybe_map_fn(dataset, self._parser_fn)
+
+ if self._dataset_transform_fn is not None:
+ dataset = self._dataset_transform_fn(dataset)
+
+ per_replica_batch_size = input_context.get_per_replica_batch_size(
+ self._global_batch_size) if input_context else self._global_batch_size
+
+ dataset = dataset.batch(
+ per_replica_batch_size, drop_remainder=self._drop_remainder)
+ dataset = maybe_map_fn(dataset, self._postprocess_fn)
+ return dataset.prefetch(tf.data.experimental.AUTOTUNE)
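+
+
+# Example usage (illustrative only; the DataConfig field values and the
+# `my_decode_fn` / `my_parse_fn` helpers below are placeholders):
+#
+#   params = cfg.DataConfig(input_path='/path/to/train-*.tfrecord',
+#                           global_batch_size=32,
+#                           is_training=True)
+#   dataset = InputReader(params,
+#                         decoder_fn=my_decode_fn,
+#                         parser_fn=my_parse_fn).read()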
diff --git a/models/official/modeling/__init__.py b/models/official/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/modeling/activations/__init__.py b/models/official/modeling/activations/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b558fef3cb276c61e58d93c219db6a899c107ef
--- /dev/null
+++ b/models/official/modeling/activations/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Activations package definition."""
+from official.modeling.activations.gelu import gelu
+from official.modeling.activations.swish import hard_swish
+from official.modeling.activations.swish import identity
+from official.modeling.activations.swish import simple_swish
diff --git a/models/official/modeling/activations/gelu.py b/models/official/modeling/activations/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c045bffa95b29e069831b548701b76d1b8e76c0d
--- /dev/null
+++ b/models/official/modeling/activations/gelu.py
@@ -0,0 +1,40 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Gaussian error linear unit."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import tensorflow as tf
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+def gelu(x):
+ """Gaussian Error Linear Unit.
+
+  This is a smoother version of the ReLU.
+ Original paper: https://arxiv.org/abs/1606.08415
+ Args:
+ x: float Tensor to perform activation.
+
+ Returns:
+ `x` with the GELU activation applied.
+ """
+ cdf = 0.5 * (1.0 + tf.tanh(
+ (math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3)))))
+ return x * cdf
diff --git a/models/official/modeling/activations/gelu_test.py b/models/official/modeling/activations/gelu_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc3b95ca8be16c058c592247684e45d419b50cc5
--- /dev/null
+++ b/models/official/modeling/activations/gelu_test.py
@@ -0,0 +1,38 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the Gaussian error linear unit."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.modeling import activations
+
+
+@keras_parameterized.run_all_keras_modes
+class GeluTest(keras_parameterized.TestCase):
+
+ def test_gelu(self):
+ expected_data = [[0.14967535, 0., -0.10032465],
+ [-0.15880796, -0.04540223, 2.9963627]]
+ gelu_data = activations.gelu([[.25, 0, -.25], [-1, -2, 3]])
+ self.assertAllClose(expected_data, gelu_data)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/modeling/activations/swish.py b/models/official/modeling/activations/swish.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d799613095efe1a16dade9673adddee05f2679d
--- /dev/null
+++ b/models/official/modeling/activations/swish.py
@@ -0,0 +1,75 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Customized Swish activation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+def simple_swish(features):
+ """Computes the Swish activation function.
+
+ The tf.nn.swish operation uses a custom gradient to reduce memory usage.
+ Since saving custom gradients in SavedModel is currently not supported, and
+ one would not be able to use an exported TF-Hub module for fine-tuning, we
+  provide this wrapper that allows selecting whether to use the native
+  TensorFlow swish operation or a customized operation that uses the default
+  TensorFlow gradient computation.
+
+ Args:
+ features: A `Tensor` representing preactivation values.
+
+ Returns:
+ The activation value.
+ """
+ features = tf.convert_to_tensor(features)
+ return features * tf.nn.sigmoid(features)
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+def hard_swish(features):
+ """Computes a hard version of the swish function.
+
+ This operation can be used to reduce computational cost and improve
+ quantization for edge devices.
+
+ Args:
+ features: A `Tensor` representing preactivation values.
+
+ Returns:
+ The activation value.
+ """
+ features = tf.convert_to_tensor(features)
+ return features * tf.nn.relu6(features + tf.constant(3.)) * (1. / 6.)
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+def identity(features):
+ """Computes the identity function.
+
+ Useful for helping in quantization.
+
+ Args:
+ features: A `Tensor` representing preactivation values.
+
+ Returns:
+ The activation value.
+ """
+ features = tf.convert_to_tensor(features)
+ return tf.identity(features)
diff --git a/models/official/modeling/activations/swish_test.py b/models/official/modeling/activations/swish_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..22042e9a290a420805fc75bbfca6ded6e917d9eb
--- /dev/null
+++ b/models/official/modeling/activations/swish_test.py
@@ -0,0 +1,49 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the customized Swish activation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.modeling import activations
+
+
+@keras_parameterized.run_all_keras_modes
+class CustomizedSwishTest(keras_parameterized.TestCase):
+
+ def _hard_swish_np(self, x):
+ x = np.float32(x)
+ return x * np.clip(x + 3, 0, 6) / 6
+
+ def test_simple_swish(self):
+ features = [[.25, 0, -.25], [-1, -2, 3]]
+ customized_swish_data = activations.simple_swish(features)
+ swish_data = tf.nn.swish(features)
+ self.assertAllClose(customized_swish_data, swish_data)
+
+ def test_hard_swish(self):
+ features = [[.25, 0, -.25], [-1, -2, 3]]
+ customized_swish_data = activations.hard_swish(features)
+ swish_data = self._hard_swish_np(features)
+ self.assertAllClose(customized_swish_data, swish_data)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/modeling/hyperparams/__init__.py b/models/official/modeling/hyperparams/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..87c00e7f2a1934800cac21405aa924f2ddc1f241
--- /dev/null
+++ b/models/official/modeling/hyperparams/__init__.py
@@ -0,0 +1,21 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Hyperparams package definition."""
+# pylint: disable=g-multiple-import
+from official.modeling.hyperparams.base_config import *
+from official.modeling.hyperparams.oneof import *
+from official.modeling.hyperparams.params_dict import *
+
diff --git a/models/official/modeling/hyperparams/base_config.py b/models/official/modeling/hyperparams/base_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ce5ce2d55016dce0c985a0e6f9fe3893a25f644
--- /dev/null
+++ b/models/official/modeling/hyperparams/base_config.py
@@ -0,0 +1,248 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base configurations to standardize experiments."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import copy
+import functools
+from typing import Any, List, Mapping, Optional, Type
+
+import dataclasses
+import tensorflow as tf
+import yaml
+
+from official.modeling.hyperparams import params_dict
+
+
+@dataclasses.dataclass
+class Config(params_dict.ParamsDict):
+ """The base configuration class that supports YAML/JSON based overrides.
+
+ * It recursively enforces a whitelist of basic types and container types, so
+ it avoids surprises with copy and reuse caused by unanticipated types.
+ * It converts dict to Config even within sequences,
+    e.g. for config = Config({'key': [([{'a': 42}],)]}),
+ type(config.key[0][0][0]) is Config rather than dict.
+ """
+
+ # It's safe to add bytes and other immutable types here.
+ IMMUTABLE_TYPES = (str, int, float, bool, type(None))
+ # It's safe to add set, frozenset and other collections here.
+ SEQUENCE_TYPES = (list, tuple)
+
+ default_params: dataclasses.InitVar[Optional[Mapping[str, Any]]] = None
+ restrictions: dataclasses.InitVar[Optional[List[str]]] = None
+
+ @classmethod
+ def _isvalidsequence(cls, v):
+ """Check if the input values are valid sequences.
+
+ Args:
+ v: Input sequence.
+
+ Returns:
+      True if the sequence is valid: its type is in cls.SEQUENCE_TYPES and its
+      elements are all in cls.IMMUTABLE_TYPES, all dicts, or all ParamsDicts.
+ """
+ if not isinstance(v, cls.SEQUENCE_TYPES):
+ return False
+ return (all(isinstance(e, cls.IMMUTABLE_TYPES) for e in v) or
+ all(isinstance(e, dict) for e in v) or
+ all(isinstance(e, params_dict.ParamsDict) for e in v))
+
+ @classmethod
+ def _import_config(cls, v, subconfig_type):
+ """Returns v with dicts converted to Configs, recursively."""
+ if not issubclass(subconfig_type, params_dict.ParamsDict):
+ raise TypeError(
+ 'Subconfig_type should be subclass of ParamsDict, found {!r}'.format(
+ subconfig_type))
+ if isinstance(v, cls.IMMUTABLE_TYPES):
+ return v
+ elif isinstance(v, cls.SEQUENCE_TYPES):
+ # Only support one layer of sequence.
+ if not cls._isvalidsequence(v):
+ raise TypeError(
+ 'Invalid sequence: only supports single level {!r} of {!r} or '
+ 'dict or ParamsDict found: {!r}'.format(cls.SEQUENCE_TYPES,
+ cls.IMMUTABLE_TYPES, v))
+ import_fn = functools.partial(
+ cls._import_config, subconfig_type=subconfig_type)
+ return type(v)(map(import_fn, v))
+ elif isinstance(v, params_dict.ParamsDict):
+ # Deepcopy here is a temporary solution for preserving type in nested
+ # Config object.
+ return copy.deepcopy(v)
+ elif isinstance(v, dict):
+ return subconfig_type(v)
+ else:
+ raise TypeError('Unknown type: {!r}'.format(type(v)))
+
+ @classmethod
+ def _export_config(cls, v):
+ """Returns v with Configs converted to dicts, recursively."""
+ if isinstance(v, cls.IMMUTABLE_TYPES):
+ return v
+ elif isinstance(v, cls.SEQUENCE_TYPES):
+ return type(v)(map(cls._export_config, v))
+ elif isinstance(v, params_dict.ParamsDict):
+ return v.as_dict()
+ elif isinstance(v, dict):
+ raise TypeError('dict value not supported in converting.')
+ else:
+ raise TypeError('Unknown type: {!r}'.format(type(v)))
+
+ @classmethod
+ def _get_subconfig_type(cls, k) -> Type[params_dict.ParamsDict]:
+ """Get element type by the field name.
+
+ Args:
+ k: the key/name of the field.
+
+ Returns:
+ Config as default. If a type annotation is found for `k`,
+      1) returns the type of the annotation if it is a subtype of ParamsDict;
+ 2) returns the element type if the annotation of `k` is List[SubType]
+ or Tuple[SubType].
+ """
+ subconfig_type = Config
+ if k in cls.__annotations__:
+ # Directly Config subtype.
+ type_annotation = cls.__annotations__[k]
+ if (isinstance(type_annotation, type) and
+ issubclass(type_annotation, Config)):
+ subconfig_type = cls.__annotations__[k]
+ else:
+ # Check if the field is a sequence of subtypes.
+ field_type = getattr(type_annotation, '__origin__', type(None))
+ if (isinstance(field_type, type) and
+ issubclass(field_type, cls.SEQUENCE_TYPES)):
+ element_type = getattr(type_annotation, '__args__', [type(None)])[0]
+ subconfig_type = (
+ element_type if issubclass(element_type, params_dict.ParamsDict)
+ else subconfig_type)
+ return subconfig_type
+
+ def __post_init__(self, default_params, restrictions, *args, **kwargs):
+ super().__init__(default_params=default_params,
+ restrictions=restrictions,
+ *args,
+ **kwargs)
+
+ def _set(self, k, v):
+ """Overrides same method in ParamsDict.
+
+ Also called by ParamsDict methods.
+
+ Args:
+ k: key to set.
+ v: value.
+
+ Raises:
+ RuntimeError
+ """
+ subconfig_type = self._get_subconfig_type(k)
+ if isinstance(v, dict):
+ if k not in self.__dict__ or not self.__dict__[k]:
+        # If the key does not exist or the value is None, a new Config-family
+        # object should be created for the key.
+ self.__dict__[k] = subconfig_type(v)
+ else:
+ self.__dict__[k].override(v)
+ else:
+ self.__dict__[k] = self._import_config(v, subconfig_type)
+
+ def __setattr__(self, k, v):
+ if k not in self.RESERVED_ATTR:
+ if getattr(self, '_locked', False):
+ raise ValueError('The Config has been locked. ' 'No change is allowed.')
+ self._set(k, v)
+
+ def _override(self, override_dict, is_strict=True):
+ """Overrides same method in ParamsDict.
+
+ Also called by ParamsDict methods.
+
+ Args:
+      override_dict: dictionary of overriding values to apply.
+      is_strict: If True, adding new keys is not allowed.
+
+ Raises:
+      KeyError: when overriding a reserved key or a key that does not exist
+        (with is_strict=True).
+ """
+ for k, v in sorted(override_dict.items()):
+ if k in self.RESERVED_ATTR:
+ raise KeyError('The key {!r} is internally reserved. '
+ 'Can not be overridden.'.format(k))
+ if k not in self.__dict__:
+ if is_strict:
+ raise KeyError('The key {!r} does not exist in {!r}. '
+ 'To extend the existing keys, use '
+ '`override` with `is_strict` = False.'.format(
+ k, type(self)))
+ else:
+ self._set(k, v)
+ else:
+ if isinstance(v, dict) and self.__dict__[k]:
+ self.__dict__[k]._override(v, is_strict) # pylint: disable=protected-access
+ elif isinstance(v, params_dict.ParamsDict) and self.__dict__[k]:
+ self.__dict__[k]._override(v.as_dict(), is_strict) # pylint: disable=protected-access
+ else:
+ self._set(k, v)
+
+ def as_dict(self):
+ """Returns a dict representation of params_dict.ParamsDict.
+
+ For the nested params_dict.ParamsDict, a nested dict will be returned.
+ """
+ return {
+ k: self._export_config(v)
+ for k, v in self.__dict__.items()
+ if k not in self.RESERVED_ATTR
+ }
+
+ def replace(self, **kwargs):
+ """Like `override`, but returns a copy with the current config unchanged."""
+ params = self.__class__(self)
+ params.override(kwargs, is_strict=True)
+ return params
+
+ @classmethod
+ def from_yaml(cls, file_path: str):
+ # Note: This only works if the Config has all default values.
+ with tf.io.gfile.GFile(file_path, 'r') as f:
+ loaded = yaml.load(f)
+ config = cls()
+ config.override(loaded)
+ return config
+
+ @classmethod
+ def from_json(cls, file_path: str):
+ """Wrapper for `from_yaml`."""
+ return cls.from_yaml(file_path)
+
+ @classmethod
+ def from_args(cls, *args, **kwargs):
+ """Builds a config from the given list of arguments."""
+ attributes = list(cls.__annotations__.keys())
+ default_params = {a: p for a, p in zip(attributes, args)}
+ default_params.update(kwargs)
+ return cls(default_params)
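+
+
+# A minimal usage sketch (illustrative only; the class and field names are
+# hypothetical and not part of this module):
+#
+#   @dataclasses.dataclass
+#   class MyTaskConfig(Config):
+#     learning_rate: float = 1e-3
+#     data: Config = Config()
+#
+#   cfg = MyTaskConfig()
+#   cfg.override({'learning_rate': 5e-4, 'data': {'batch_size': 32}},
+#                is_strict=False)
+#   cfg.as_dict()  # -> {'learning_rate': 0.0005, 'data': {'batch_size': 32}}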
diff --git a/models/official/modeling/hyperparams/base_config_test.py b/models/official/modeling/hyperparams/base_config_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..501f95899f526c8eab7cbfaaafb65433389ce0d8
--- /dev/null
+++ b/models/official/modeling/hyperparams/base_config_test.py
@@ -0,0 +1,299 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import pprint
+from typing import List, Tuple
+
+from absl.testing import parameterized
+import dataclasses
+import tensorflow as tf
+from official.modeling.hyperparams import base_config
+
+
+@dataclasses.dataclass
+class DumpConfig1(base_config.Config):
+ a: int = 1
+ b: str = 'text'
+
+
+@dataclasses.dataclass
+class DumpConfig2(base_config.Config):
+ c: int = 2
+ d: str = 'text'
+ e: DumpConfig1 = DumpConfig1()
+
+
+@dataclasses.dataclass
+class DumpConfig3(DumpConfig2):
+ f: int = 2
+ g: str = 'text'
+ h: List[DumpConfig1] = dataclasses.field(
+ default_factory=lambda: [DumpConfig1(), DumpConfig1()])
+ g: Tuple[DumpConfig1, ...] = (DumpConfig1(),)
+
+
+class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
+
+ def assertHasSameTypes(self, c, d, msg=''):
+ """Checks if a Config has the same structure as a given dict.
+
+ Args:
+      c: the Config object to be checked.
+      d: the reference dict object.
+      msg: The error message to show when a type mismatch is found.
+ """
+ # Make sure d is not a Config. Assume d is either
+ # dictionary or primitive type and c is the Config or primitive types.
+ self.assertNotIsInstance(d, base_config.Config)
+ if isinstance(d, base_config.Config.IMMUTABLE_TYPES):
+ self.assertEqual(pprint.pformat(c), pprint.pformat(d), msg=msg)
+ elif isinstance(d, base_config.Config.SEQUENCE_TYPES):
+ self.assertEqual(type(c), type(d), msg=msg)
+ for i, v in enumerate(d):
+ self.assertHasSameTypes(c[i], v, msg='{}[{!r}]'.format(msg, i))
+ elif isinstance(d, dict):
+ self.assertIsInstance(c, base_config.Config, msg=msg)
+ for k, v in sorted(d.items()):
+ self.assertHasSameTypes(getattr(c, k), v, msg='{}[{!r}]'.format(msg, k))
+ else:
+ raise TypeError('Unknown type: %r' % type(d))
+
+ def assertImportExport(self, v):
+ config = base_config.Config({'key': v})
+ back = config.as_dict()['key']
+ self.assertEqual(pprint.pformat(back), pprint.pformat(v))
+ self.assertHasSameTypes(config.key, v, msg='=%s v' % pprint.pformat(v))
+
+ def test_invalid_keys(self):
+ params = base_config.Config()
+ with self.assertRaises(AttributeError):
+ _ = params.a
+
+ def test_nested_config_types(self):
+ config = DumpConfig3()
+ self.assertIsInstance(config.e, DumpConfig1)
+ self.assertIsInstance(config.h[0], DumpConfig1)
+ self.assertIsInstance(config.h[1], DumpConfig1)
+ self.assertIsInstance(config.g[0], DumpConfig1)
+
+ config.override({'e': {'a': 2, 'b': 'new text'}})
+ self.assertIsInstance(config.e, DumpConfig1)
+ self.assertEqual(config.e.a, 2)
+ self.assertEqual(config.e.b, 'new text')
+
+ config.override({'h': [{'a': 3, 'b': 'new text 2'}]})
+ self.assertIsInstance(config.h[0], DumpConfig1)
+ self.assertLen(config.h, 1)
+ self.assertEqual(config.h[0].a, 3)
+ self.assertEqual(config.h[0].b, 'new text 2')
+
+ config.override({'g': [{'a': 4, 'b': 'new text 3'}]})
+ self.assertIsInstance(config.g[0], DumpConfig1)
+ self.assertLen(config.g, 1)
+ self.assertEqual(config.g[0].a, 4)
+ self.assertEqual(config.g[0].b, 'new text 3')
+
+ @parameterized.parameters(
+ ('_locked', "The key '_locked' is internally reserved."),
+ ('_restrictions', "The key '_restrictions' is internally reserved."),
+ ('aa', "The key 'aa' does not exist."),
+ )
+ def test_key_error(self, key, msg):
+ params = base_config.Config()
+ with self.assertRaisesRegex(KeyError, msg):
+ params.override({key: True})
+
+ @parameterized.parameters(
+ ('str data',),
+ (123,),
+ (1.23,),
+ (None,),
+ (['str', 1, 2.3, None],),
+ (('str', 1, 2.3, None),),
+ )
+ def test_import_export_immutable_types(self, v):
+ self.assertImportExport(v)
+ out = base_config.Config({'key': v})
+ self.assertEqual(pprint.pformat(v), pprint.pformat(out.key))
+
+ def test_override_is_strict_true(self):
+ params = base_config.Config({
+ 'a': 'aa',
+ 'b': 2,
+ 'c': {
+ 'c1': 'cc',
+ 'c2': 20
+ }
+ })
+ params.override({'a': 2, 'c': {'c1': 'ccc'}}, is_strict=True)
+ self.assertEqual(params.a, 2)
+ self.assertEqual(params.c.c1, 'ccc')
+ with self.assertRaises(KeyError):
+ params.override({'d': 'ddd'}, is_strict=True)
+ with self.assertRaises(KeyError):
+ params.override({'c': {'c3': 30}}, is_strict=True)
+
+ config = base_config.Config({'key': [{'a': 42}]})
+ config.override({'key': [{'b': 43}]})
+ self.assertEqual(config.key[0].b, 43)
+ with self.assertRaisesRegex(AttributeError, 'The key `a` does not exist'):
+ _ = config.key[0].a
+
+ @parameterized.parameters(
+ (lambda x: x, 'Unknown type'),
+ (object(), 'Unknown type'),
+ (set(), 'Unknown type'),
+ (frozenset(), 'Unknown type'),
+ )
+ def test_import_unsupport_types(self, v, msg):
+ with self.assertRaisesRegex(TypeError, msg):
+ _ = base_config.Config({'key': v})
+
+ @parameterized.parameters(
+ ({
+ 'a': [{
+ 'b': 2,
+ }, {
+ 'c': 3,
+ }]
+ },),
+ ({
+ 'c': [{
+ 'f': 1.1,
+ }, {
+ 'h': [1, 2],
+ }]
+ },),
+ (({
+ 'a': 'aa',
+ 'b': 2,
+ 'c': {
+ 'c1': 10,
+ 'c2': 20,
+ }
+ },),),
+ )
+ def test_import_export_nested_structure(self, d):
+ self.assertImportExport(d)
+
+ @parameterized.parameters(
+ ([{
+ 'a': 42,
+ 'b': 'hello',
+ 'c': 1.2
+ }],),
+ (({
+ 'a': 42,
+ 'b': 'hello',
+ 'c': 1.2
+ },),),
+ )
+ def test_import_export_nested_sequences(self, v):
+ self.assertImportExport(v)
+
+ @parameterized.parameters(
+ ([([{}],)],),
+ ([['str', 1, 2.3, None]],),
+ ((('str', 1, 2.3, None),),),
+ ([
+ ('str', 1, 2.3, None),
+ ],),
+ ([
+ ('str', 1, 2.3, None),
+ ],),
+ ([[{
+ 'a': 42,
+ 'b': 'hello',
+ 'c': 1.2
+ }]],),
+ ([[[{
+ 'a': 42,
+ 'b': 'hello',
+ 'c': 1.2
+ }]]],),
+ ((({
+ 'a': 42,
+ 'b': 'hello',
+ 'c': 1.2
+ },),),),
+ (((({
+ 'a': 42,
+ 'b': 'hello',
+ 'c': 1.2
+ },),),),),
+ ([({
+ 'a': 42,
+ 'b': 'hello',
+ 'c': 1.2
+ },)],),
+ (([{
+ 'a': 42,
+ 'b': 'hello',
+ 'c': 1.2
+ }],),),
+ )
+ def test_import_export_unsupport_sequence(self, v):
+ with self.assertRaisesRegex(TypeError,
+ 'Invalid sequence: only supports single level'):
+ _ = base_config.Config({'key': v})
+
+ def test_construct_subtype(self):
+ pass
+
+ def test_import_config(self):
+ params = base_config.Config({'a': [{'b': 2}, {'c': {'d': 3}}]})
+ self.assertLen(params.a, 2)
+ self.assertEqual(params.a[0].b, 2)
+ self.assertEqual(type(params.a[0]), base_config.Config)
+ self.assertEqual(pprint.pformat(params.a[0].b), '2')
+ self.assertEqual(type(params.a[1]), base_config.Config)
+ self.assertEqual(type(params.a[1].c), base_config.Config)
+ self.assertEqual(pprint.pformat(params.a[1].c.d), '3')
+
+ def test_override(self):
+ params = base_config.Config({'a': [{'b': 2}, {'c': {'d': 3}}]})
+ params.override({'a': [{'b': 4}, {'c': {'d': 5}}]}, is_strict=False)
+ self.assertEqual(type(params.a), list)
+ self.assertEqual(type(params.a[0]), base_config.Config)
+ self.assertEqual(pprint.pformat(params.a[0].b), '4')
+ self.assertEqual(type(params.a[1]), base_config.Config)
+ self.assertEqual(type(params.a[1].c), base_config.Config)
+ self.assertEqual(pprint.pformat(params.a[1].c.d), '5')
+
+ @parameterized.parameters(
+ ([{}],),
+ (({},),),
+ )
+ def test_config_vs_params_dict(self, v):
+ d = {'key': v}
+ self.assertEqual(type(base_config.Config(d).key[0]), base_config.Config)
+ self.assertEqual(type(base_config.params_dict.ParamsDict(d).key[0]), dict)
+
+ def test_ppformat(self):
+ self.assertEqual(
+ pprint.pformat([
+ 's', 1, 1.0, True, None, {}, [], (), {
+ (2,): (3, [4], {
+ 6: 7,
+ }),
+ 8: 9,
+ }
+ ]),
+ "['s', 1, 1.0, True, None, {}, [], (), {8: 9, (2,): (3, [4], {6: 7})}]")
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/modeling/hyperparams/config_definitions.py b/models/official/modeling/hyperparams/config_definitions.py
new file mode 100644
index 0000000000000000000000000000000000000000..78180cd8a01a09a9d646c04eb05742bafce5bf42
--- /dev/null
+++ b/models/official/modeling/hyperparams/config_definitions.py
@@ -0,0 +1,220 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Common configuration settings."""
+from typing import Optional, Union
+
+import dataclasses
+
+from official.modeling.hyperparams import base_config
+from official.modeling.optimization.configs import optimization_config
+from official.utils import registry
+
+OptimizationConfig = optimization_config.OptimizationConfig
+
+
+@dataclasses.dataclass
+class DataConfig(base_config.Config):
+ """The base configuration for building datasets.
+
+ Attributes:
+ input_path: The path to the input. It can be either (1) a file pattern, or
+ (2) multiple file patterns separated by comma. It should not be specified
+ when the following `tfds_name` is specified.
+ tfds_name: The name of the tensorflow dataset (TFDS). It should not be
+ specified when the above `input_path` is specified.
+ tfds_split: A str indicating which split of the data to load from TFDS. It
+ is required when above `tfds_name` is specified.
+ global_batch_size: The global batch size across all replicas.
+ is_training: Whether this data is used for training or not.
+ drop_remainder: Whether the last batch should be dropped in the case it has
+ fewer than `global_batch_size` elements.
+ shuffle_buffer_size: The buffer size used for shuffling training data.
+ cache: Whether to cache dataset examples. Can be used to avoid re-reading
+ from disk on the second epoch. Requires significant memory overhead.
+ cycle_length: The number of files that will be processed concurrently when
+ interleaving files.
+ block_length: The number of consecutive elements to produce from each input
+ element before cycling to another input element when interleaving files.
+ sharding: Whether sharding is used in the input pipeline.
+ examples_consume: An `integer` specifying the number of examples it will
+ produce. If positive, it only takes this number of examples and raises
+      tf.errors.OutOfRangeError after that. Default is -1, meaning it will
+ exhaust all the examples in the dataset.
+ tfds_data_dir: A str specifying the directory to read/write TFDS data.
+ tfds_download: A bool to indicate whether to download data using TFDS.
+ tfds_as_supervised: A bool. When loading dataset from TFDS, if True,
+ the returned tf.data.Dataset will have a 2-tuple structure (input, label)
+ according to builder.info.supervised_keys; if False, the default,
+ the returned tf.data.Dataset will have a dictionary with all the features.
+ tfds_skip_decoding_feature: A str to indicate which features are skipped
+ for decoding when loading dataset from TFDS. Use comma to separate
+ multiple features. The main use case is to skip the image/video decoding
+ for better performance.
+ """
+ input_path: str = ""
+ tfds_name: str = ""
+ tfds_split: str = ""
+ global_batch_size: int = 0
+ is_training: bool = None
+ drop_remainder: bool = True
+ shuffle_buffer_size: int = 100
+ cache: bool = False
+ cycle_length: int = 8
+ block_length: int = 1
+ sharding: bool = True
+ examples_consume: int = -1
+ tfds_data_dir: str = ""
+ tfds_download: bool = False
+ tfds_as_supervised: bool = False
+ tfds_skip_decoding_feature: str = ""
+
+
+@dataclasses.dataclass
+class RuntimeConfig(base_config.Config):
+ """High-level configurations for Runtime.
+
+ These include parameters that are not directly related to the experiment,
+ e.g. directories, accelerator type, etc.
+
+ Attributes:
+ distribution_strategy: e.g. 'mirrored', 'tpu', etc.
+ enable_xla: Whether or not to enable XLA.
+ per_gpu_thread_count: thread count per GPU.
+ gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
+ dataset_num_private_threads: Number of threads for a private threadpool
+ created for all datasets computation.
+ tpu: The address of the TPU to use, if any.
+ num_gpus: The number of GPUs to use, if any.
+ worker_hosts: comma-separated list of worker ip:port pairs for running
+ multi-worker models with DistributionStrategy.
+ task_index: If multi-worker training, the task index of this worker.
+ all_reduce_alg: Defines the algorithm for performing all-reduce.
+ num_packs: Sets `num_packs` in the cross device ops used in
+ MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
+ mixed_precision_dtype: dtype of mixed precision policy. It can be 'float32',
+ 'float16', or 'bfloat16'.
+ loss_scale: The type of loss scale, or 'float' value. This is used when
+ setting the mixed precision policy.
+ run_eagerly: Whether or not to run the experiment eagerly.
+ batchnorm_spatial_persistent: Whether or not to enable the spatial
+ persistent mode for CuDNN batch norm kernel for improved GPU performance.
+ """
+ distribution_strategy: str = "mirrored"
+ enable_xla: bool = False
+ gpu_thread_mode: Optional[str] = None
+ dataset_num_private_threads: Optional[int] = None
+ per_gpu_thread_count: int = 0
+ tpu: Optional[str] = None
+ num_gpus: int = 0
+ worker_hosts: Optional[str] = None
+ task_index: int = -1
+ all_reduce_alg: Optional[str] = None
+ num_packs: int = 1
+ loss_scale: Optional[Union[str, float]] = None
+ mixed_precision_dtype: Optional[str] = None
+ run_eagerly: bool = False
+ batchnorm_spatial_persistent: bool = False
+
+
+@dataclasses.dataclass
+class TensorboardConfig(base_config.Config):
+ """Configuration for Tensorboard.
+
+ Attributes:
+ track_lr: Whether or not to track the learning rate in Tensorboard. Defaults
+ to True.
+ write_model_weights: Whether or not to write the model weights as images in
+ Tensorboard. Defaults to False.
+ """
+ track_lr: bool = True
+ write_model_weights: bool = False
+
+
+@dataclasses.dataclass
+class CallbacksConfig(base_config.Config):
+ """Configuration for Callbacks.
+
+ Attributes:
+ enable_checkpoint_and_export: Whether or not to enable checkpoints as a
+ Callback. Defaults to True.
+ enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
+ Defaults to True.
+ enable_time_history: Whether or not to enable TimeHistory Callbacks.
+ Defaults to True.
+ """
+ enable_checkpoint_and_export: bool = True
+ enable_tensorboard: bool = True
+ enable_time_history: bool = True
+
+
+@dataclasses.dataclass
+class TrainerConfig(base_config.Config):
+ """Configuration for trainer.
+
+ Attributes:
+ optimizer_config: optimizer config, it includes optimizer, learning rate,
+ and warmup schedule configs.
+ train_tf_while_loop: whether or not to use tf while loop.
+ train_tf_function: whether or not to use tf_function for training loop.
+ eval_tf_function: whether or not to use tf_function for eval.
+ steps_per_loop: number of steps per loop.
+ summary_interval: number of steps between each summary.
+    checkpoint_interval: number of steps between checkpoints.
+ max_to_keep: max checkpoints to keep.
+ continuous_eval_timeout: maximum number of seconds to wait between
+      checkpoints; if set to None, continuous eval will wait indefinitely.
+ """
+ optimizer_config: OptimizationConfig = OptimizationConfig()
+ train_tf_while_loop: bool = True
+ train_tf_function: bool = True
+ eval_tf_function: bool = True
+ steps_per_loop: int = 1000
+ summary_interval: int = 1000
+ checkpoint_interval: int = 1000
+ max_to_keep: int = 5
+ continuous_eval_timeout: Optional[int] = None
+
+
+@dataclasses.dataclass
+class TaskConfig(base_config.Config):
+ network: base_config.Config = None
+ train_data: DataConfig = DataConfig()
+ validation_data: DataConfig = DataConfig()
+
+
+@dataclasses.dataclass
+class ExperimentConfig(base_config.Config):
+ """Top-level configuration."""
+ task: TaskConfig = TaskConfig()
+ trainer: TrainerConfig = TrainerConfig()
+ runtime: RuntimeConfig = RuntimeConfig()
+ train_steps: int = 0
+ validation_steps: Optional[int] = None
+ validation_interval: int = 100
+
+
+_REGISTERED_CONFIGS = {}
+
+
+def register_config_factory(name):
+ """Register ExperimentConfig factory method."""
+ return registry.register(_REGISTERED_CONFIGS, name)
+
+
+def get_exp_config_creater(exp_name: str):
+ """Looks up ExperimentConfig factory methods."""
+ exp_creater = registry.lookup(_REGISTERED_CONFIGS, exp_name)
+ return exp_creater
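+
+
+# Typical registration pattern (illustrative sketch; the experiment name and
+# factory function below are hypothetical):
+#
+#   @register_config_factory('my_experiment')
+#   def my_experiment() -> ExperimentConfig:
+#     return ExperimentConfig(train_steps=10000)
+#
+#   config = get_exp_config_creater('my_experiment')()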
diff --git a/models/official/modeling/hyperparams/oneof.py b/models/official/modeling/hyperparams/oneof.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd49218c137d17e7917e16f8f2eb0b73a8a6a392
--- /dev/null
+++ b/models/official/modeling/hyperparams/oneof.py
@@ -0,0 +1,62 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Config class that supports oneof functionality."""
+
+from typing import Optional
+
+import dataclasses
+from official.modeling.hyperparams import base_config
+
+
+@dataclasses.dataclass
+class OneOfConfig(base_config.Config):
+ """Configuration for configs with one of feature.
+
+ Attributes:
+ type: 'str', name of the field to select.
+ """
+ type: Optional[str] = None
+
+ def as_dict(self):
+ """Returns a dict representation of OneOfConfig.
+
+ For the nested base_config.Config, a nested dict will be returned.
+ """
+ if self.type is None:
+ return {'type': None}
+ elif self.__dict__['type'] not in self.__dict__:
+ raise ValueError(
+ 'type: {!r} is not a valid key!'.format(self.__dict__['type']))
+ else:
+ chosen_type = self.type
+ chosen_value = self.__dict__[chosen_type]
+ return {
+ 'type': self.type,
+ chosen_type: self._export_config(chosen_value)
+ }
+
+ def get(self):
+ """Returns selected config based on the value of type.
+
+ If type is not set (None), None is returned.
+ """
+ chosen_type = self.type
+ if chosen_type is None:
+ return None
+ if chosen_type not in self.__dict__:
+ raise ValueError(
+ 'type: {!r} is not a valid key!'.format(self.type))
+ return self.__dict__[chosen_type]
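+
+
+# For example (illustrative): for an OneOfConfig with type='resnet' and a
+# `resnet` field, get() returns the nested resnet config, and as_dict() emits
+# only {'type': 'resnet', 'resnet': {...}}, dropping unselected alternatives.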
diff --git a/models/official/modeling/hyperparams/oneof_test.py b/models/official/modeling/hyperparams/oneof_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..abd6564c17efbad9987eea6a0e91261afc34f3f3
--- /dev/null
+++ b/models/official/modeling/hyperparams/oneof_test.py
@@ -0,0 +1,67 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import dataclasses
+import tensorflow as tf
+from official.modeling.hyperparams import base_config
+from official.modeling.hyperparams import oneof
+
+
+@dataclasses.dataclass
+class ResNet(base_config.Config):
+ model_depth: int = 50
+
+
+@dataclasses.dataclass
+class Backbone(oneof.OneOfConfig):
+ type: str = 'resnet'
+ resnet: ResNet = ResNet()
+ not_resnet: int = 2
+
+
+@dataclasses.dataclass
+class OutputLayer(oneof.OneOfConfig):
+ type: str = 'single'
+ single: int = 1
+ multi_head: int = 2
+
+
+@dataclasses.dataclass
+class Network(base_config.Config):
+ backbone: Backbone = Backbone()
+ output_layer: OutputLayer = OutputLayer()
+
+
+class OneOfTest(tf.test.TestCase):
+
+ def test_to_dict(self):
+ network_params = {'backbone': {'type': 'resnet',
+ 'resnet': {'model_depth': 50}
+ },
+ 'output_layer': {'type': 'single',
+ 'single': 1000}
+ }
+ network_config = Network(network_params)
+ self.assertEqual(network_config.as_dict(), network_params)
+
+ def test_get_oneof(self):
+ backbone = Backbone()
+ self.assertIsInstance(backbone.get(), ResNet)
+ self.assertEqual(backbone.get().as_dict(), {'model_depth': 50})
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/modeling/hyperparams/params_dict.py b/models/official/modeling/hyperparams/params_dict.py
new file mode 100644
index 0000000000000000000000000000000000000000..88510e770a2021f0e7f65bfbf2ae6a2d3480de17
--- /dev/null
+++ b/models/official/modeling/hyperparams/params_dict.py
@@ -0,0 +1,439 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A parameter dictionary class which supports the nest structure."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import copy
+import re
+
+import six
+import tensorflow as tf
+import yaml
+
+# regex pattern that matches on key-value pairs in a comma-separated
+# key-value pair string. It splits each k-v pair on the = sign, and
+# matches on values that are within single quotes, double quotes, single
+# values (e.g. floats, ints, etc.), and lists within brackets.
+_PARAM_RE = re.compile(r"""
+  (?P<name>[a-zA-Z][\w\.]*)    # variable name: "var" or "x"
+  \s*=\s*
+  ((?P<val>\'(.*?)\'           # single quote
+    |
+    \"(.*?)\"                  # double quote
+    |
+    [^,\[]*                    # single value
+    |
+    \[[^\]]*\]))               # list of values
+  ($|,\s*)""", re.VERBOSE)
+
+_CONST_VALUE_RE = re.compile(r'(\d.*|-\d.*|None)')
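+# For example, "a=1, b.c='hi'" yields the (name, val) pairs ('a', '1') and
+# ('b.c', "'hi'"); _CONST_VALUE_RE is used by ParamsDict.validate() to spot
+# bare numeric constants and None inside a restriction string.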
+
+
+class ParamsDict(object):
+ """A hyperparameter container class."""
+
+ RESERVED_ATTR = ['_locked', '_restrictions']
+
+ def __init__(self, default_params=None, restrictions=None):
+ """Instantiate a ParamsDict.
+
+ Instantiate a ParamsDict given a set of default parameters and a list of
+ restrictions. Upon initialization, it validates itself by checking all the
+    defined restrictions, and raises an error if it finds any inconsistency.
+
+ Args:
+ default_params: a Python dict or another ParamsDict object including the
+ default parameters to initialize.
+ restrictions: a list of strings, which define a list of restrictions to
+ ensure the consistency of different parameters internally. Each
+ restriction string is defined as a binary relation with a set of
+ operators, including {'==', '!=', '<', '<=', '>', '>='}.
+ """
+ self._locked = False
+ self._restrictions = []
+ if restrictions:
+ self._restrictions = restrictions
+ if default_params is None:
+ default_params = {}
+ self.override(default_params, is_strict=False)
+ self.validate()
+
+ def _set(self, k, v):
+ if isinstance(v, dict):
+ self.__dict__[k] = ParamsDict(v)
+ else:
+ self.__dict__[k] = copy.deepcopy(v)
+
+ def __setattr__(self, k, v):
+ """Sets the value of the existing key.
+
+ Note that this does not allow directly defining a new key. Use the
+ `override` method with `is_strict=False` instead.
+
+ Args:
+ k: the key string.
+ v: the value to be used to set the key `k`.
+
+ Raises:
+ KeyError: if k is not defined in the ParamsDict.
+ """
+ if k not in ParamsDict.RESERVED_ATTR:
+ if k not in self.__dict__.keys():
+        raise KeyError('The key `{}` does not exist. '
+                       'To extend the existing keys, use '
+                       '`override` with `is_strict` = False.'.format(k))
+ if self._locked:
+ raise ValueError('The ParamsDict has been locked. '
+ 'No change is allowed.')
+ self._set(k, v)
+
+ def __getattr__(self, k):
+ """Gets the value of the existing key.
+
+ Args:
+ k: the key string.
+
+ Returns:
+ the value of the key.
+
+ Raises:
+ AttributeError: if k is not defined in the ParamsDict.
+ """
+ if k not in self.__dict__.keys():
+ raise AttributeError('The key `{}` does not exist. '.format(k))
+ return self.__dict__[k]
+
+ def __contains__(self, key):
+ """Implements the membership test operator."""
+ return key in self.__dict__
+
+ def get(self, key, value=None):
+ """Accesses through built-in dictionary get method."""
+ return self.__dict__.get(key, value)
+
+ def __delattr__(self, k):
+ """Deletes the key and removes its values.
+
+ Args:
+ k: the key string.
+
+ Raises:
+      AttributeError: if k is reserved or not defined in the ParamsDict.
+ ValueError: if the ParamsDict instance has been locked.
+ """
+ if k in ParamsDict.RESERVED_ATTR:
+      raise AttributeError('The key `{}` is reserved. No change is allowed. '
+ .format(k))
+ if k not in self.__dict__.keys():
+ raise AttributeError('The key `{}` does not exist. '.format(k))
+ if self._locked:
+ raise ValueError('The ParamsDict has been locked. No change is allowed.')
+ del self.__dict__[k]
+
+ def override(self, override_params, is_strict=True):
+ """Override the ParamsDict with a set of given params.
+
+ Args:
+ override_params: a dict or a ParamsDict specifying the parameters to
+ be overridden.
+ is_strict: a boolean specifying whether override is strict or not. If
+ True, keys in `override_params` must be present in the ParamsDict.
+ If False, keys in `override_params` can be different from what is
+ currently defined in the ParamsDict. In this case, the ParamsDict will
+ be extended to include the new keys.
+ """
+ if self._locked:
+ raise ValueError('The ParamsDict has been locked. No change is allowed.')
+ if isinstance(override_params, ParamsDict):
+ override_params = override_params.as_dict()
+ self._override(override_params, is_strict) # pylint: disable=protected-access
+
+ def _override(self, override_dict, is_strict=True):
+ """The implementation of `override`."""
+ for k, v in six.iteritems(override_dict):
+ if k in ParamsDict.RESERVED_ATTR:
+        raise KeyError('The key `{}` is internally reserved. '
+                       'Can not be overridden.'.format(k))
+ if k not in self.__dict__.keys():
+ if is_strict:
+ raise KeyError('The key `{}` does not exist. '
+ 'To extend the existing keys, use '
+ '`override` with `is_strict` = False.'.format(k))
+ else:
+ self._set(k, v)
+ else:
+ if isinstance(v, dict):
+ self.__dict__[k]._override(v, is_strict) # pylint: disable=protected-access
+ elif isinstance(v, ParamsDict):
+ self.__dict__[k]._override(v.as_dict(), is_strict) # pylint: disable=protected-access
+ else:
+ self.__dict__[k] = copy.deepcopy(v)
+
+ def lock(self):
+ """Makes the ParamsDict immutable."""
+ self._locked = True
+
+ def as_dict(self):
+ """Returns a dict representation of ParamsDict.
+
+ For the nested ParamsDict, a nested dict will be returned.
+ """
+ params_dict = {}
+ for k, v in six.iteritems(self.__dict__):
+ if k not in ParamsDict.RESERVED_ATTR:
+ if isinstance(v, ParamsDict):
+ params_dict[k] = v.as_dict()
+ else:
+ params_dict[k] = copy.deepcopy(v)
+ return params_dict
+
+ def validate(self):
+ """Validate the parameters consistency based on the restrictions.
+
+ This method validates the internal consistency using the pre-defined list of
+    restrictions. A restriction is defined as a string which specifies a binary
+    operation. The supported binary operations are {'==', '!=', '<', '<=', '>',
+    '>='}. Note that the meaning of these operators is consistent with the
+    underlying Python implementation. Users should make sure the defined
+    restrictions on their types make sense.
+
+ For example, for a ParamsDict like the following
+ ```
+ a:
+ a1: 1
+ a2: 2
+ b:
+ bb:
+ bb1: 10
+ bb2: 20
+ ccc:
+ a1: 1
+ a3: 3
+ ```
+ one can define two restrictions like this
+ ['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']
+
+ What it enforces are:
+    - a.a1 = 1 == b.ccc.a1 = 1
+ - a.a2 = 2 <= b.bb.bb2 = 20
+
+ Raises:
+ KeyError: if any of the following happens
+ (1) any of parameters in any of restrictions is not defined in
+ ParamsDict,
+ (2) any inconsistency violating the restriction is found.
+ ValueError: if the restriction defined in the string is not supported.
+ """
+ def _get_kv(dotted_string, params_dict):
+ """Get keys and values indicated by dotted_string."""
+ if _CONST_VALUE_RE.match(dotted_string) is not None:
+ const_str = dotted_string
+ if const_str == 'None':
+ constant = None
+ else:
+ constant = float(const_str)
+ return None, constant
+ else:
+ tokenized_params = dotted_string.split('.')
+ v = params_dict
+ for t in tokenized_params:
+ v = v[t]
+ return tokenized_params[-1], v
+
+ def _get_kvs(tokens, params_dict):
+ if len(tokens) != 2:
+ raise ValueError('Only support binary relation in restriction.')
+ stripped_tokens = [t.strip() for t in tokens]
+ left_k, left_v = _get_kv(stripped_tokens[0], params_dict)
+ right_k, right_v = _get_kv(stripped_tokens[1], params_dict)
+ return left_k, left_v, right_k, right_v
+
+ params_dict = self.as_dict()
+ for restriction in self._restrictions:
+      if '==' in restriction:
+        tokens = restriction.split('==')
+        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
+        if left_v != right_v:
+          raise KeyError('Found inconsistency between key `{}` and key `{}`.'
+                         .format(tokens[0], tokens[1]))
+      elif '!=' in restriction:
+        tokens = restriction.split('!=')
+        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
+        if left_v == right_v:
+          raise KeyError('Found inconsistency between key `{}` and key `{}`.'
+                         .format(tokens[0], tokens[1]))
+      # Check '<=' and '>=' before '<' and '>' so that the two-character
+      # operators are not mistakenly split on their one-character prefixes.
+      elif '<=' in restriction:
+        tokens = restriction.split('<=')
+        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
+        if left_v > right_v:
+          raise KeyError('Found inconsistency between key `{}` and key `{}`.'
+                         .format(tokens[0], tokens[1]))
+      elif '>=' in restriction:
+        tokens = restriction.split('>=')
+        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
+        if left_v < right_v:
+          raise KeyError('Found inconsistency between key `{}` and key `{}`.'
+                         .format(tokens[0], tokens[1]))
+      elif '<' in restriction:
+        tokens = restriction.split('<')
+        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
+        if left_v >= right_v:
+          raise KeyError('Found inconsistency between key `{}` and key `{}`.'
+                         .format(tokens[0], tokens[1]))
+      elif '>' in restriction:
+        tokens = restriction.split('>')
+        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
+        if left_v <= right_v:
+          raise KeyError('Found inconsistency between key `{}` and key `{}`.'
+                         .format(tokens[0], tokens[1]))
+      else:
+        raise ValueError('Unsupported relation in restriction.')
+
+
+def read_yaml_to_params_dict(file_path):
+ """Reads a YAML file to a ParamsDict."""
+ with tf.io.gfile.GFile(file_path, 'r') as f:
+ params_dict = yaml.load(f)
+ return ParamsDict(params_dict)
+
+
+def save_params_dict_to_yaml(params, file_path):
+ """Saves the input ParamsDict to a YAML file."""
+ with tf.io.gfile.GFile(file_path, 'w') as f:
+ def _my_list_rep(dumper, data):
+ # u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
+ return dumper.represent_sequence(
+ u'tag:yaml.org,2002:seq', data, flow_style=True)
+ yaml.add_representer(list, _my_list_rep)
+ yaml.dump(params.as_dict(), f, default_flow_style=False)
+
+
+def nested_csv_str_to_json_str(csv_str):
+ """Converts a nested (using '.') comma-separated k=v string to a JSON string.
+
+ Converts a comma-separated string of key/value pairs that supports
+ nesting of keys to a JSON string. Nesting is implemented using
+ '.' between levels for a given key.
+
+ Spacing between commas and = is supported (e.g. there is no difference between
+ "a=1,b=2", "a = 1, b = 2", or "a=1, b=2") but there should be no spaces before
+ keys or after values (e.g. " a=1,b=2" and "a=1,b=2 " are not supported).
+
+ Note that this will only support values supported by CSV, meaning
+ values such as nested lists (e.g. "a=[[1,2,3],[4,5,6]]") are not
+ supported. Strings are supported as well, e.g. "a='hello'".
+
+ An example conversion would be:
+
+ "a=1, b=2, c.a=2, c.b=3, d.a.a=5"
+
+ to
+
+ "{ a: 1, b : 2, c: {a : 2, b : 3}, d: {a: {a : 5}}}"
+
+ Args:
+ csv_str: the comma separated string.
+
+ Returns:
+ the converted JSON string.
+
+ Raises:
+ ValueError: If csv_str is not in a comma separated string or
+ if the string is formatted incorrectly.
+ """
+ if not csv_str:
+ return ''
+
+ formatted_entries = []
+ nested_map = collections.defaultdict(list)
+ pos = 0
+ while pos < len(csv_str):
+ m = _PARAM_RE.match(csv_str, pos)
+ if not m:
+ raise ValueError('Malformed hyperparameter value while parsing '
+ 'CSV string: %s' % csv_str[pos:])
+ pos = m.end()
+ # Parse the values.
+ m_dict = m.groupdict()
+ name = m_dict['name']
+ v = m_dict['val']
+
+ # If a GCS path (e.g. gs://...) is provided, wrap this in quotes
+ # as yaml.load would otherwise throw an exception
+ if re.match(r'(?=[^\"\'])(?=[gs://])', v):
+ v = '\'{}\''.format(v)
+
+ name_nested = name.split('.')
+ if len(name_nested) > 1:
+ grouping = name_nested[0]
+ value = '.'.join(name_nested[1:]) + '=' + v
+ nested_map[grouping].append(value)
+ else:
+ formatted_entries.append('%s : %s' % (name, v))
+
+ for grouping, value in nested_map.items():
+ value = ','.join(value)
+ value = nested_csv_str_to_json_str(value)
+ formatted_entries.append('%s : %s' % (grouping, value))
+ return '{' + ', '.join(formatted_entries) + '}'
+
+
+def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
+ """Override a given ParamsDict using a dict, JSON/YAML/CSV string or YAML file.
+
+ The logic of the function is outlined below:
+ 1. Test that the input is a dict. If not, proceed to 2.
+  2. Test that the input is a string. If not, raise a ValueError for the
+     unknown input type.
+ 2.1. Test if the string is in a CSV format. If so, parse.
+ If not, proceed to 2.2.
+ 2.2. Try loading the string as a YAML/JSON. If successful, parse to
+ dict and use it to override. If not, proceed to 2.3.
+ 2.3. Try using the string as a file path and load the YAML file.
+
+ Args:
+ params: a ParamsDict object to be overridden.
+ dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or
+ path to a YAML file specifying the parameters to be overridden.
+ is_strict: a boolean specifying whether override is strict or not.
+
+ Returns:
+ params: the overridden ParamsDict object.
+
+ Raises:
+ ValueError: if failed to override the parameters.
+ """
+ if not dict_or_string_or_yaml_file:
+ return params
+ if isinstance(dict_or_string_or_yaml_file, dict):
+ params.override(dict_or_string_or_yaml_file, is_strict)
+ elif isinstance(dict_or_string_or_yaml_file, six.string_types):
+ try:
+ dict_or_string_or_yaml_file = (
+ nested_csv_str_to_json_str(dict_or_string_or_yaml_file))
+ except ValueError:
+ pass
+ params_dict = yaml.load(dict_or_string_or_yaml_file)
+ if isinstance(params_dict, dict):
+ params.override(params_dict, is_strict)
+ else:
+ with tf.io.gfile.GFile(dict_or_string_or_yaml_file) as f:
+ params.override(yaml.load(f), is_strict)
+ else:
+ raise ValueError('Unknown input type to parse.')
+ return params
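+
+
+# Illustrative call sequence (values and the YAML path are hypothetical):
+#   params = ParamsDict({'train': {'batch_size': 64}})
+#   override_params_dict(params, 'train.batch_size=128', is_strict=True)
+#   override_params_dict(params, '/tmp/overrides.yaml', is_strict=True)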
diff --git a/models/official/modeling/hyperparams/params_dict_test.py b/models/official/modeling/hyperparams/params_dict_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..169ffa47ceff5717c2ae375f7e1114c5b05f3ea1
--- /dev/null
+++ b/models/official/modeling/hyperparams/params_dict_test.py
@@ -0,0 +1,346 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for params_dict.py."""
+
+import os
+
+import tensorflow as tf
+import yaml
+
+from official.modeling.hyperparams import params_dict
+
+
+class ParamsDictTest(tf.test.TestCase):
+
+ def test_init_from_an_empty_dict(self):
+ params = params_dict.ParamsDict()
+ with self.assertRaises(AttributeError):
+ _ = params.a
+
+ with self.assertRaises(KeyError):
+ params.a = 'aa'
+
+ def test_init_from_a_dict(self):
+ params = params_dict.ParamsDict({'a': 'aa', 'b': 2})
+ self.assertEqual(params.a, 'aa')
+ self.assertEqual(params.b, 2)
+
+ def test_init_from_a_param_dict(self):
+ params_init = params_dict.ParamsDict({'a': 'aa', 'b': 2})
+ params = params_dict.ParamsDict(params_init)
+ self.assertEqual(params.a, 'aa')
+ self.assertEqual(params.b, 2)
+
+ def test_lock(self):
+ params = params_dict.ParamsDict({'a': 1, 'b': 2, 'c': 3})
+ params.lock()
+ with self.assertRaises(ValueError):
+ params.a = 10
+ with self.assertRaises(ValueError):
+ params.override({'b': 20})
+ with self.assertRaises(ValueError):
+ del params.c
+
+ def test_setattr(self):
+ params = params_dict.ParamsDict()
+ params.override(
+ {'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
+ params.c = 'ccc'
+ self.assertEqual(params.a, 'aa')
+ self.assertEqual(params.b, 2)
+ self.assertEqual(params.c, 'ccc')
+
+ def test_getattr(self):
+ params = params_dict.ParamsDict()
+ params.override(
+ {'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
+ self.assertEqual(params.a, 'aa')
+ self.assertEqual(params.b, 2)
+ self.assertEqual(params.c, None)
+
+ def test_delattr(self):
+ params = params_dict.ParamsDict()
+ params.override(
+ {'a': 'aa', 'b': 2, 'c': None, 'd': {'d1': 1, 'd2': 10}},
+ is_strict=False)
+ del params.c
+ self.assertEqual(params.a, 'aa')
+ self.assertEqual(params.b, 2)
+ with self.assertRaises(AttributeError):
+ _ = params.c
+ del params.d
+ with self.assertRaises(AttributeError):
+ _ = params.d.d1
+
+ def test_contains(self):
+ params = params_dict.ParamsDict()
+ params.override(
+ {'a': 'aa'}, is_strict=False)
+ self.assertIn('a', params)
+ self.assertNotIn('b', params)
+
+ def test_get(self):
+ params = params_dict.ParamsDict()
+ params.override(
+ {'a': 'aa'}, is_strict=False)
+ self.assertEqual(params.get('a'), 'aa')
+ self.assertEqual(params.get('b', 2), 2)
+ self.assertEqual(params.get('b'), None)
+
+ def test_override_is_strict_true(self):
+ params = params_dict.ParamsDict(
+ {'a': 'aa', 'b': 2, 'c': {'c1': 'cc', 'c2': 20}})
+ params.override({'a': 2, 'c': {'c1': 'ccc'}}, is_strict=True)
+ self.assertEqual(params.a, 2)
+ self.assertEqual(params.c.c1, 'ccc')
+ with self.assertRaises(KeyError):
+ params.override({'d': 'ddd'}, is_strict=True)
+ with self.assertRaises(KeyError):
+ params.override({'c': {'c3': 30}}, is_strict=True)
+
+ def test_override_is_strict_false(self):
+ params = params_dict.ParamsDict(
+ {'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
+ params.override({'a': 2, 'c': {'c3': 3000}}, is_strict=False)
+ self.assertEqual(params.a, 2)
+ self.assertEqual(params.c.c3, 3000)
+ params.override({'d': 'ddd'}, is_strict=False)
+ self.assertEqual(params.d, 'ddd')
+ params.override({'c': {'c4': 4444}}, is_strict=False)
+ self.assertEqual(params.c.c4, 4444)
+
+ def test_as_dict(self):
+ params = params_dict.ParamsDict(
+ {'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
+ params_d = params.as_dict()
+ self.assertEqual(params_d['a'], 'aa')
+ self.assertEqual(params_d['b'], 2)
+ self.assertEqual(params_d['c']['c1'], 10)
+ self.assertEqual(params_d['c']['c2'], 20)
+
+ def test_validate(self):
+ # Raise error due to the unknown parameter.
+ with self.assertRaises(KeyError):
+ params = params_dict.ParamsDict(
+ {'a': 1, 'b': {'a': 11}}, ['a == c'])
+
+ # OK to check equality of two nested dicts.
+ params = params_dict.ParamsDict(
+ {'a': 1, 'b': {'a': 10}, 'c': {'a': 10}}, ['b == c'])
+
+ # Raise error due to inconsistency
+ with self.assertRaises(KeyError):
+ params = params_dict.ParamsDict(
+ {'a': 1, 'c': {'a': 10}}, ['a == c.a'])
+
+ # Valid rule.
+ params = params_dict.ParamsDict(
+ {'a': 1, 'c': {'a': 1}}, ['a == c.a'])
+
+ # Overridding violates the existing rule, raise error upon validate.
+ params.override({'a': 11})
+ with self.assertRaises(KeyError):
+ params.validate()
+
+ # Valid restrictions with constant.
+ params = params_dict.ParamsDict(
+ {'a': None, 'c': {'a': 1}}, ['a == None', 'c.a == 1'])
+ params.validate()
+ with self.assertRaises(KeyError):
+ params = params_dict.ParamsDict(
+ {'a': 4, 'c': {'a': 1}}, ['a == None', 'c.a == 1'])
+
+
+class ParamsDictIOTest(tf.test.TestCase):
+
+ def write_temp_file(self, filename, text):
+ temp_file = os.path.join(self.get_temp_dir(), filename)
+ with tf.io.gfile.GFile(temp_file, 'w') as writer:
+ writer.write(text)
+ return temp_file
+
+ def test_save_params_dict_to_yaml(self):
+ params = params_dict.ParamsDict(
+ {'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
+ output_yaml_file = os.path.join(self.get_temp_dir(), 'params.yaml')
+ params_dict.save_params_dict_to_yaml(params, output_yaml_file)
+
+ with tf.io.gfile.GFile(output_yaml_file, 'r') as f:
+ params_d = yaml.load(f)
+ self.assertEqual(params.a, params_d['a'])
+ self.assertEqual(params.b, params_d['b'])
+ self.assertEqual(params.c.c1, params_d['c']['c1'])
+ self.assertEqual(params.c.c2, params_d['c']['c2'])
+
+ def test_read_yaml_to_params_dict(self):
+ input_yaml_file = self.write_temp_file(
+ 'params.yaml', r"""
+ a: 'aa'
+ b: 2
+ c:
+ c1: 10
+ c2: 20
+ """)
+ params = params_dict.read_yaml_to_params_dict(input_yaml_file)
+
+ self.assertEqual(params.a, 'aa')
+ self.assertEqual(params.b, 2)
+ self.assertEqual(params.c.c1, 10)
+ self.assertEqual(params.c.c2, 20)
+
+ def test_override_params_dict_using_dict(self):
+ params = params_dict.ParamsDict({
+ 'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
+ override_dict = {'b': 5.2, 'c': [30, 40]}
+ params = params_dict.override_params_dict(
+ params, override_dict, is_strict=True)
+ self.assertEqual(1, params.a)
+ self.assertEqual(5.2, params.b)
+ self.assertEqual([30, 40], params.c)
+ self.assertEqual('hello', params.d)
+ self.assertEqual(False, params.e)
+
+ def test_override_params_dict_using_yaml_string(self):
+ params = params_dict.ParamsDict({
+ 'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
+ override_yaml_string = "'b': 5.2\n'c': [30, 40]"
+ params = params_dict.override_params_dict(
+ params, override_yaml_string, is_strict=True)
+ self.assertEqual(1, params.a)
+ self.assertEqual(5.2, params.b)
+ self.assertEqual([30, 40], params.c)
+ self.assertEqual('hello', params.d)
+ self.assertEqual(False, params.e)
+
+ def test_override_params_dict_using_json_string(self):
+ params = params_dict.ParamsDict({
+ 'a': 1, 'b': {'b1': 2, 'b2': [2, 3],},
+ 'd': {'d1': {'d2': 'hello'}}, 'e': False})
+ override_json_string = "{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
+ params = params_dict.override_params_dict(
+ params, override_json_string, is_strict=True)
+ self.assertEqual(1, params.a)
+ self.assertEqual(2, params.b.b1)
+ self.assertEqual([3, 4], params.b.b2)
+ self.assertEqual('hi', params.d.d1.d2)
+ self.assertEqual(False, params.e)
+
+ def test_override_params_dict_using_csv_string(self):
+ params = params_dict.ParamsDict({
+ 'a': 1, 'b': {'b1': 2, 'b2': [2, 3],},
+ 'd': {'d1': {'d2': 'hello'}}, 'e': False})
+ override_csv_string = "b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
+ params = params_dict.override_params_dict(
+ params, override_csv_string, is_strict=True)
+ self.assertEqual(1, params.a)
+ self.assertEqual(2, params.b.b1)
+ self.assertEqual([3, 4], params.b.b2)
+ self.assertEqual('hi, world', params.d.d1.d2)
+ self.assertEqual('gs://test', params.e)
+
+ def test_override_params_dict_using_yaml_file(self):
+ params = params_dict.ParamsDict({
+ 'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
+ override_yaml_file = self.write_temp_file(
+ 'params.yaml', r"""
+ b: 5.2
+ c: [30, 40]
+ """)
+ params = params_dict.override_params_dict(
+ params, override_yaml_file, is_strict=True)
+ self.assertEqual(1, params.a)
+ self.assertEqual(5.2, params.b)
+ self.assertEqual([30, 40], params.c)
+ self.assertEqual('hello', params.d)
+ self.assertEqual(False, params.e)
+
+
+class IOTest(tf.test.TestCase):
+
+ def test_basic_csv_str_to_json_str(self):
+ csv_str = 'a=1,b=2,c=3'
+ json_str = '{a : 1, b : 2, c : 3}'
+ converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
+ self.assertEqual(converted_csv_str, json_str)
+
+ def test_basic_csv_str_load(self):
+ csv_str = 'a=1,b=2,c=3'
+ expected_output = {'a': 1, 'b': 2, 'c': 3}
+ converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
+ converted_dict = yaml.load(converted_csv_str)
+ self.assertDictEqual(converted_dict, expected_output)
+
+ def test_basic_nested_csv_str_to_json_str(self):
+ csv_str = 'a=1,b.b1=2'
+ json_str = '{a : 1, b : {b1 : 2}}'
+ converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
+ self.assertEqual(converted_csv_str, json_str)
+
+ def test_basic_nested_csv_str_load(self):
+ csv_str = 'a=1,b.b1=2,c.c1=3'
+ expected_output = {'a': 1, 'b': {'b1': 2}, 'c': {'c1': 3}}
+ converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
+ converted_dict = yaml.load(converted_csv_str)
+ self.assertDictEqual(converted_dict, expected_output)
+
+ def test_complex_nested_csv_str_to_json_str(self):
+ csv_str = 'a.aa.aaa.aaaaa.a=1'
+ json_str = '{a : {aa : {aaa : {aaaaa : {a : 1}}}}}'
+ converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
+ self.assertEqual(converted_csv_str, json_str)
+
+ def test_complex_nested_csv_str_load(self):
+ csv_str = 'a.aa.aaa.aaaaa.a=1,a.a=2'
+ expected_output = {'a': {'aa': {'aaa': {'aaaaa': {'a': 1}}}, 'a': 2}}
+ converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
+ converted_dict = yaml.load(converted_csv_str)
+ self.assertDictEqual(converted_dict, expected_output)
+
+ def test_csv_str_load_supported_datatypes(self):
+ csv_str = 'a=1,b=2.,c=[1,2,3],d=\'hello, there\',e=\"Hi.\"'
+ converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
+ converted_dict = yaml.load(converted_csv_str)
+ self.assertEqual(converted_dict['a'], 1)
+ self.assertEqual(converted_dict['b'], 2.)
+ self.assertEqual(converted_dict['c'], [1, 2, 3])
+ self.assertEqual(converted_dict['d'], 'hello, there')
+ self.assertEqual(converted_dict['e'], 'Hi.')
+
+ def test_csv_str_load_unsupported_datatypes(self):
+ csv_str = 'a=[[1,2,3],[4,5,6]]'
+ self.assertRaises(ValueError,
+ params_dict.nested_csv_str_to_json_str,
+ csv_str)
+
+ def test_csv_str_to_json_str_spacing(self):
+ csv_str1 = 'a=1,b=2,c=3'
+ csv_str2 = 'a = 1, b = 2, c = 3'
+ json_str = '{a : 1, b : 2, c : 3}'
+ converted_csv_str1 = params_dict.nested_csv_str_to_json_str(csv_str1)
+ converted_csv_str2 = params_dict.nested_csv_str_to_json_str(csv_str2)
+ self.assertEqual(converted_csv_str1, converted_csv_str2)
+ self.assertEqual(converted_csv_str1, json_str)
+ self.assertEqual(converted_csv_str2, json_str)
+
+ def test_gcs_added_quotes(self):
+ csv_str = 'a=gs://abc, b=gs://def'
+ expected_output = '{a : \'gs://abc\', b : \'gs://def\'}'
+ converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
+ self.assertEqual(converted_csv_str, expected_output)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/modeling/optimization/__init__.py b/models/official/modeling/optimization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5c6292b64d922144b7ced18c8c9460617e05492
--- /dev/null
+++ b/models/official/modeling/optimization/__init__.py
@@ -0,0 +1,7 @@
+"""Optimization package definition."""
+
+# pylint: disable=wildcard-import
+from official.modeling.optimization.configs.learning_rate_config import *
+from official.modeling.optimization.configs.optimization_config import *
+from official.modeling.optimization.configs.optimizer_config import *
+from official.modeling.optimization.optimizer_factory import OptimizerFactory
diff --git a/models/official/modeling/optimization/configs/__init__.py b/models/official/modeling/optimization/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/modeling/optimization/configs/learning_rate_config.py b/models/official/modeling/optimization/configs/learning_rate_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b55c713f1905cf9aaa52f87a6663d3385628d5a5
--- /dev/null
+++ b/models/official/modeling/optimization/configs/learning_rate_config.py
@@ -0,0 +1,162 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Dataclasses for learning rate schedule config."""
+from typing import List, Optional
+
+import dataclasses
+from official.modeling.hyperparams import base_config
+
+
+@dataclasses.dataclass
+class StepwiseLrConfig(base_config.Config):
+ """Configuration for stepwise learning rate decay.
+
+ This class is a container for the piecewise constant learning rate scheduling
+ configs. It will configure an instance of PiecewiseConstantDecay keras
+ learning rate schedule.
+
+  An example (from the Keras docs): use a learning rate that's 1.0 for the first
+  100001 steps, 0.5 for the next 10000 steps, and 0.1 for any additional steps.
+  ```python
+  boundaries: [100000, 110000]
+  values: [1.0, 0.5, 0.1]
+  ```
+
+ Attributes:
+    name: The name of the learning rate schedule. Defaults to PiecewiseConstantDecay.
+ boundaries: A list of ints of strictly increasing entries.
+ Defaults to None.
+ values: A list of floats that specifies the values for the intervals defined
+ by `boundaries`. It should have one more element than `boundaries`.
+ The learning rate is computed as follows:
+ [0, boundaries[0]] -> values[0]
+ [boundaries[0], boundaries[1]] -> values[1]
+ [boundaries[n-1], boundaries[n]] -> values[n]
+ [boundaries[n], end] -> values[n+1]
+ Defaults to None.
+ """
+ name: str = 'PiecewiseConstantDecay'
+ boundaries: Optional[List[int]] = None
+ values: Optional[List[float]] = None
+
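+# Illustrative sketch (assumes `tf` is imported; not exercised by this module):
+# a StepwiseLrConfig maps directly onto the Keras PiecewiseConstantDecay
+# schedule, e.g.
+#   _cfg = StepwiseLrConfig(boundaries=[100000, 110000], values=[1.0, 0.5, 0.1])
+#   _sched = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
+#       boundaries=_cfg.boundaries, values=_cfg.values)
+#   _sched(50000)   # -> 1.0
+#   _sched(105000)  # -> 0.5
+#   _sched(200000)  # -> 0.1
+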
+
+@dataclasses.dataclass
+class ExponentialLrConfig(base_config.Config):
+ """Configuration for exponential learning rate decay.
+
+  This class is a container for the exponential learning rate decay configs.
+
+ Attributes:
+ name: The name of the learning rate schedule. Defaults to ExponentialDecay.
+ initial_learning_rate: A float. The initial learning rate. Defaults to
+ None.
+ decay_steps: A positive integer that is used for decay computation.
+ Defaults to None.
+ decay_rate: A float. Defaults to None.
+    staircase: A boolean; if true, the learning rate is decayed at discrete
+      intervals. Defaults to False.
+ """
+ name: str = 'ExponentialDecay'
+ initial_learning_rate: Optional[float] = None
+ decay_steps: Optional[int] = None
+ decay_rate: Optional[float] = None
+ staircase: Optional[bool] = None
+
+
+@dataclasses.dataclass
+class PolynomialLrConfig(base_config.Config):
+ """Configuration for polynomial learning rate decay.
+
+  This class is a container for the polynomial learning rate decay configs.
+
+ Attributes:
+ name: The name of the learning rate schedule. Defaults to PolynomialDecay.
+ initial_learning_rate: A float. The initial learning rate. Defaults to
+ None.
+ decay_steps: A positive integer that is used for decay computation.
+ Defaults to None.
+ end_learning_rate: A float. The minimal end learning rate.
+ power: A float. The power of the polynomial. Defaults to linear, 1.0.
+ cycle: A boolean, whether or not it should cycle beyond decay_steps.
+ Defaults to False.
+ """
+ name: str = 'PolynomialDecay'
+ initial_learning_rate: Optional[float] = None
+ decay_steps: Optional[int] = None
+ end_learning_rate: float = 0.0001
+ power: float = 1.0
+ cycle: bool = False
+
+
+@dataclasses.dataclass
+class CosineLrConfig(base_config.Config):
+ """Configuration for Cosine learning rate decay.
+
+  This class is a container for the cosine learning rate decay configs,
+ tf.keras.experimental.CosineDecay.
+
+ Attributes:
+ name: The name of the learning rate schedule. Defaults to CosineDecay.
+ initial_learning_rate: A float. The initial learning rate. Defaults to
+ None.
+ decay_steps: A positive integer that is used for decay computation.
+ Defaults to None.
+ alpha: A float. Minimum learning rate value as a fraction of
+ initial_learning_rate.
+ """
+ name: str = 'CosineDecay'
+ initial_learning_rate: Optional[float] = None
+ decay_steps: Optional[int] = None
+ alpha: float = 0.0
+
+
+@dataclasses.dataclass
+class LinearWarmupConfig(base_config.Config):
+ """Configuration for linear warmup schedule config.
+
+ This class is a container for the linear warmup schedule configs.
+ Warmup_learning_rate is the initial learning rate, the final learning rate of
+ the warmup period is the learning_rate of the optimizer in use. The learning
+ rate at each step linearly increased according to the following formula:
+ warmup_learning_rate = warmup_learning_rate +
+ step / warmup_steps * (final_learning_rate - warmup_learning_rate).
+ Using warmup overrides the learning rate schedule by the number of warmup
+ steps.
+
+ Attributes:
+ name: The name of warmup schedule. Defaults to linear.
+ warmup_learning_rate: Initial learning rate for the warmup. Defaults to 0.
+ warmup_steps: Warmup steps. Defaults to None.
+ """
+ name: str = 'linear'
+ warmup_learning_rate: float = 0
+ warmup_steps: Optional[int] = None
+
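+# Illustrative arithmetic for the warmup formula above (the numbers are
+# arbitrary example values): with warmup_learning_rate=0.01, warmup_steps=500
+# and an optimizer learning rate of 0.1, the warmup learning rate at step 250 is
+#   0.01 + 250 / 500 * (0.1 - 0.01) = 0.055,
+# and from step 500 onward the optimizer's own schedule takes over.
+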
+
+@dataclasses.dataclass
+class PolynomialWarmupConfig(base_config.Config):
+ """Configuration for linear warmup schedule config.
+
+ This class is a container for the polynomial warmup schedule configs.
+
+ Attributes:
+ name: The name of warmup schedule. Defaults to Polynomial.
+ power: Polynomial power. Defaults to 1.
+ warmup_steps: Warmup steps. Defaults to None.
+ """
+ name: str = 'polynomial'
+ power: float = 1
+ warmup_steps: Optional[int] = None
+
diff --git a/models/official/modeling/optimization/configs/optimization_config.py b/models/official/modeling/optimization/configs/optimization_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cf3616c75bec20c2560747561530f332cd2466c
--- /dev/null
+++ b/models/official/modeling/optimization/configs/optimization_config.py
@@ -0,0 +1,95 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Dataclasses for optimization configs.
+
+This file defines the dataclasses for optimization configs (OptimizationConfig
+and its optimizer, learning rate, and warmup oneof sub-configs).
+"""
+from typing import Optional
+
+import dataclasses
+
+from official.modeling.hyperparams import base_config
+from official.modeling.hyperparams import oneof
+from official.modeling.optimization.configs import learning_rate_config as lr_cfg
+from official.modeling.optimization.configs import optimizer_config as opt_cfg
+
+
+@dataclasses.dataclass
+class OptimizerConfig(oneof.OneOfConfig):
+ """Configuration for optimizer.
+
+ Attributes:
+    type: 'str', type of optimizer to be used, one of the fields below.
+ sgd: sgd optimizer config.
+ adam: adam optimizer config.
+ adamw: adam with weight decay.
+ lamb: lamb optimizer.
+ rmsprop: rmsprop optimizer.
+ """
+ type: Optional[str] = None
+ sgd: opt_cfg.SGDConfig = opt_cfg.SGDConfig()
+ adam: opt_cfg.AdamConfig = opt_cfg.AdamConfig()
+ adamw: opt_cfg.AdamWeightDecayConfig = opt_cfg.AdamWeightDecayConfig()
+ lamb: opt_cfg.LAMBConfig = opt_cfg.LAMBConfig()
+ rmsprop: opt_cfg.RMSPropConfig = opt_cfg.RMSPropConfig()
+
+
+@dataclasses.dataclass
+class LrConfig(oneof.OneOfConfig):
+ """Configuration for lr schedule.
+
+ Attributes:
+    type: 'str', type of lr schedule to be used, one of the fields below.
+ stepwise: stepwise learning rate config.
+ exponential: exponential learning rate config.
+ polynomial: polynomial learning rate config.
+ cosine: cosine learning rate config.
+ """
+ type: Optional[str] = None
+ stepwise: lr_cfg.StepwiseLrConfig = lr_cfg.StepwiseLrConfig()
+ exponential: lr_cfg.ExponentialLrConfig = lr_cfg.ExponentialLrConfig()
+ polynomial: lr_cfg.PolynomialLrConfig = lr_cfg.PolynomialLrConfig()
+ cosine: lr_cfg.CosineLrConfig = lr_cfg.CosineLrConfig()
+
+
+@dataclasses.dataclass
+class WarmupConfig(oneof.OneOfConfig):
+ """Configuration for lr schedule.
+
+ Attributes:
+    type: 'str', type of warmup schedule to be used, one of the fields below.
+ linear: linear warmup config.
+ polynomial: polynomial warmup config.
+ """
+ type: Optional[str] = None
+ linear: lr_cfg.LinearWarmupConfig = lr_cfg.LinearWarmupConfig()
+ polynomial: lr_cfg.PolynomialWarmupConfig = lr_cfg.PolynomialWarmupConfig()
+
+
+@dataclasses.dataclass
+class OptimizationConfig(base_config.Config):
+ """Configuration for optimizer and learning rate schedule.
+
+ Attributes:
+ optimizer: optimizer oneof config.
+ learning_rate: learning rate oneof config.
+ warmup: warmup oneof config.
+ """
+ optimizer: OptimizerConfig = OptimizerConfig()
+ learning_rate: LrConfig = LrConfig()
+ warmup: WarmupConfig = WarmupConfig()
diff --git a/models/official/modeling/optimization/configs/optimization_config_test.py b/models/official/modeling/optimization/configs/optimization_config_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..6dcd55e0e2071a23cae1494ae29c5efa282d052a
--- /dev/null
+++ b/models/official/modeling/optimization/configs/optimization_config_test.py
@@ -0,0 +1,61 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for optimization_config.py."""
+
+import tensorflow as tf
+
+from official.modeling.optimization.configs import learning_rate_config as lr_cfg
+from official.modeling.optimization.configs import optimization_config
+from official.modeling.optimization.configs import optimizer_config as opt_cfg
+
+
+class OptimizerConfigTest(tf.test.TestCase):
+
+ def test_no_optimizer(self):
+ optimizer = optimization_config.OptimizationConfig({}).optimizer.get()
+ self.assertEqual(optimizer, None)
+
+ def test_no_lr_schedule(self):
+ lr = optimization_config.OptimizationConfig({}).learning_rate.get()
+ self.assertEqual(lr, None)
+
+ def test_no_warmup_schedule(self):
+ warmup = optimization_config.OptimizationConfig({}).warmup.get()
+ self.assertEqual(warmup, None)
+
+ def test_config(self):
+ opt_config = optimization_config.OptimizationConfig({
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {} # default config
+ },
+ 'learning_rate': {
+ 'type': 'polynomial',
+ 'polynomial': {}
+ },
+ 'warmup': {
+ 'type': 'linear'
+ }
+ })
+ self.assertEqual(opt_config.optimizer.get(),
+ opt_cfg.SGDConfig())
+ self.assertEqual(opt_config.learning_rate.get(),
+ lr_cfg.PolynomialLrConfig())
+ self.assertEqual(opt_config.warmup.get(),
+ lr_cfg.LinearWarmupConfig())
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/modeling/optimization/configs/optimizer_config.py b/models/official/modeling/optimization/configs/optimizer_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e295777481957c9e965fbe7408dbb55ba063fc9
--- /dev/null
+++ b/models/official/modeling/optimization/configs/optimizer_config.py
@@ -0,0 +1,148 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Dataclasses for optimizer configs."""
+from typing import List, Optional
+
+import dataclasses
+from official.modeling.hyperparams import base_config
+
+
+@dataclasses.dataclass
+class SGDConfig(base_config.Config):
+ """Configuration for SGD optimizer.
+
+  The attributes of this class match the arguments of tf.keras.optimizers.SGD.
+
+ Attributes:
+ name: name of the optimizer.
+ learning_rate: learning_rate for SGD optimizer.
+ decay: decay rate for SGD optimizer.
+ nesterov: nesterov for SGD optimizer.
+ momentum: momentum for SGD optimizer.
+ """
+ name: str = "SGD"
+ learning_rate: float = 0.01
+ decay: float = 0.0
+ nesterov: bool = False
+ momentum: float = 0.0
+
+
+@dataclasses.dataclass
+class RMSPropConfig(base_config.Config):
+ """Configuration for RMSProp optimizer.
+
+  The attributes of this class match the arguments of
+ tf.keras.optimizers.RMSprop.
+
+ Attributes:
+ name: name of the optimizer.
+ learning_rate: learning_rate for RMSprop optimizer.
+ rho: discounting factor for RMSprop optimizer.
+ momentum: momentum for RMSprop optimizer.
+    epsilon: epsilon value for RMSprop optimizer; helps with numerical stability.
+ centered: Whether to normalize gradients or not.
+ """
+ name: str = "RMSprop"
+ learning_rate: float = 0.001
+ rho: float = 0.9
+ momentum: float = 0.0
+ epsilon: float = 1e-7
+ centered: bool = False
+
+
+@dataclasses.dataclass
+class AdamConfig(base_config.Config):
+ """Configuration for Adam optimizer.
+
+  The attributes of this class match the arguments of
+  tf.keras.optimizers.Adam.
+
+ Attributes:
+ name: name of the optimizer.
+ learning_rate: learning_rate for Adam optimizer.
+ beta_1: decay rate for 1st order moments.
+    beta_2: decay rate for 2nd order moments.
+ epsilon: epsilon value used for numerical stability in Adam optimizer.
+ amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
+ the paper "On the Convergence of Adam and beyond".
+ """
+ name: str = "Adam"
+ learning_rate: float = 0.001
+ beta_1: float = 0.9
+ beta_2: float = 0.999
+ epsilon: float = 1e-07
+ amsgrad: bool = False
+
+
+@dataclasses.dataclass
+class AdamWeightDecayConfig(base_config.Config):
+ """Configuration for Adam optimizer with weight decay.
+
+ Attributes:
+ name: name of the optimizer.
+ learning_rate: learning_rate for the optimizer.
+ beta_1: decay rate for 1st order moments.
+    beta_2: decay rate for 2nd order moments.
+ epsilon: epsilon value used for numerical stability in the optimizer.
+ amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
+ the paper "On the Convergence of Adam and beyond".
+    weight_decay_rate: float. Weight decay rate. Defaults to 0.
+    include_in_weight_decay: list[str], or None. List of weight names to include
+      in weight decay.
+    exclude_from_weight_decay: list[str], or None. List of weight names to
+      exclude from weight decay.
+ """
+ name: str = "AdamWeightDecay"
+ learning_rate: float = 0.001
+ beta_1: float = 0.9
+ beta_2: float = 0.999
+ epsilon: float = 1e-07
+ amsgrad: bool = False
+ weight_decay_rate: float = 0.0
+ include_in_weight_decay: Optional[List[str]] = None
+ exclude_from_weight_decay: Optional[List[str]] = None
+
+
+@dataclasses.dataclass
+class LAMBConfig(base_config.Config):
+ """Configuration for LAMB optimizer.
+
+  The attributes of this class match the arguments of
+ tensorflow_addons.optimizers.LAMB.
+
+ Attributes:
+ name: name of the optimizer.
+ learning_rate: learning_rate for Adam optimizer.
+ beta_1: decay rate for 1st order moments.
+    beta_2: decay rate for 2nd order moments.
+ epsilon: epsilon value used for numerical stability in LAMB optimizer.
+    weight_decay_rate: float. Weight decay rate. Defaults to 0.
+ exclude_from_weight_decay: List of regex patterns of variables excluded from
+ weight decay. Variables whose name contain a
+ substring matching the pattern will be excluded.
+ exclude_from_layer_adaptation: List of regex patterns of variables excluded
+ from layer adaptation. Variables whose name
+ contain a substring matching the pattern will
+ be excluded.
+ """
+ name: str = "LAMB"
+ learning_rate: float = 0.001
+ beta_1: float = 0.9
+ beta_2: float = 0.999
+ epsilon: float = 1e-6
+ weight_decay_rate: float = 0.0
+ exclude_from_weight_decay: Optional[List[str]] = None
+ exclude_from_layer_adaptation: Optional[List[str]] = None
diff --git a/models/official/modeling/optimization/lr_schedule.py b/models/official/modeling/optimization/lr_schedule.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5dd6fb6fb1478297e579a4be5b87ab5ae25f40e
--- /dev/null
+++ b/models/official/modeling/optimization/lr_schedule.py
@@ -0,0 +1,155 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Learning rate schedule classes."""
+
+from typing import Mapping, Any, Union, Optional
+
+import tensorflow as tf
+
+
+class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
+ """Linear warmup schedule."""
+
+ def __init__(self, after_warmup_lr_sched: Union[
+ tf.keras.optimizers.schedules.LearningRateSchedule, float],
+ warmup_steps: int, warmup_learning_rate: float,
+ name: Optional[str] = None):
+ """Add linear warmup schedule to a learning rate schedule.
+
+    warmup_lr is the initial learning rate; the final learning rate of the
+    warmup period is the initial learning rate of the lr_schedule in use.
+    The learning rate at each step is linearly increased according to the
+    following formula:
+      learning_rate = warmup_lr + step / warmup_steps
+                      * (final_warmup_lr - warmup_lr).
+    Using warmup overrides the learning rate schedule for the first
+    warmup_steps steps.
+
+ Args:
+ after_warmup_lr_sched: tf.keras.optimizers.schedules
+ .LearningRateSchedule or a constant.
+ warmup_steps: int. number of the warmup steps.
+ warmup_learning_rate: floating point number. Initial learning rate for the
+ warmup.
+ name: Optional, name of warmup schedule.
+ """
+ super(LinearWarmup, self).__init__()
+ self._name = name
+ self._after_warmup_lr_sched = after_warmup_lr_sched
+ self._warmup_steps = warmup_steps
+ self._init_warmup_lr = warmup_learning_rate
+ if isinstance(after_warmup_lr_sched,
+ tf.keras.optimizers.schedules.LearningRateSchedule):
+ self._final_warmup_lr = after_warmup_lr_sched(warmup_steps)
+ else:
+ self._final_warmup_lr = tf.cast(
+ after_warmup_lr_sched, dtype=tf.float32)
+
+ def __call__(self, step: int):
+
+ global_step = tf.cast(step, dtype=tf.float32)
+
+ linear_warmup_lr = (
+ self._init_warmup_lr + global_step / self._warmup_steps *
+ (self._final_warmup_lr - self._init_warmup_lr))
+
+ if isinstance(self._after_warmup_lr_sched,
+ tf.keras.optimizers.schedules.LearningRateSchedule):
+ after_warmup_lr = self._after_warmup_lr_sched(step)
+ else:
+ after_warmup_lr = tf.cast(self._after_warmup_lr_sched, dtype=tf.float32)
+
+ lr = tf.cond(global_step < self._warmup_steps,
+ lambda: linear_warmup_lr,
+ lambda: after_warmup_lr)
+ return lr
+
+ def get_config(self) -> Mapping[str, Any]:
+ if isinstance(self._after_warmup_lr_sched,
+ tf.keras.optimizers.schedules.LearningRateSchedule):
+ config = {
+ "after_warmup_lr_sched": self._after_warmup_lr_sched.get_config()} # pytype: disable=attribute-error
+ else:
+ config = {"after_warmup_lr_sched": self._after_warmup_lr_sched} # pytype: disable=attribute-error
+
+ config.update({
+ "warmup_steps": self._warmup_steps,
+ "warmup_learning_rate": self._init_warmup_lr,
+ "name": self._name
+ })
+ return config
+
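+# Illustrative usage sketch (not exercised here): wrapping a constant learning
+# rate of 0.1 with a 500-step linear warmup starting at 0.01.
+#
+#   _lr = LinearWarmup(after_warmup_lr_sched=0.1,
+#                      warmup_steps=500,
+#                      warmup_learning_rate=0.01)
+#   _lr(250)   # -> 0.055 (halfway through the warmup ramp)
+#   _lr(1000)  # -> 0.1   (after warmup, the wrapped value/schedule is used)
+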
+
+class PolynomialWarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
+ """Applies polynomial warmup schedule on a given learning rate decay schedule.
+ """
+
+ def __init__(self,
+ after_warmup_lr_sched: Union[
+ tf.keras.optimizers.schedules.LearningRateSchedule, float],
+ warmup_steps: int,
+ power: float = 1.0,
+ name: str = "PolynomialWarmup"):
+ super(PolynomialWarmUp, self).__init__()
+ if isinstance(after_warmup_lr_sched,
+ tf.keras.optimizers.schedules.LearningRateSchedule):
+ self._initial_learning_rate = after_warmup_lr_sched(warmup_steps)
+ else:
+ self._initial_learning_rate = tf.cast(
+ after_warmup_lr_sched, dtype=tf.float32)
+
+ self._warmup_steps = warmup_steps
+ self._power = power
+ self._after_warmup_lr_sched = after_warmup_lr_sched
+ self._name = name
+
+ def __call__(self, step):
+ with tf.name_scope(self._name or "PolynomialWarmUp") as name:
+ # Implements polynomial warmup. i.e., if global_step < warmup_steps, the
+ # learning rate will be `global_step/num_warmup_steps * init_lr`.
+ global_step_float = tf.cast(step, tf.float32)
+ warmup_steps_float = tf.cast(self._warmup_steps, tf.float32)
+ warmup_percent_done = global_step_float / warmup_steps_float
+ warmup_learning_rate = (
+ self._initial_learning_rate *
+ tf.math.pow(warmup_percent_done, self._power))
+
+ if isinstance(self._after_warmup_lr_sched,
+ tf.keras.optimizers.schedules.LearningRateSchedule):
+ after_warmup_lr = self._after_warmup_lr_sched(step)
+ else:
+ after_warmup_lr = tf.cast(self._after_warmup_lr_sched, dtype=tf.float32)
+
+ return tf.cond(
+ global_step_float < warmup_steps_float,
+ lambda: warmup_learning_rate,
+ lambda: after_warmup_lr,
+ name=name)
+
+ def get_config(self) -> Mapping[str, Any]:
+ if isinstance(self._after_warmup_lr_sched,
+ tf.keras.optimizers.schedules.LearningRateSchedule):
+ config = {
+ "after_warmup_lr_sched": self._after_warmup_lr_sched.get_config()} # pytype: disable=attribute-error
+ else:
+ config = {"after_warmup_lr_sched": self._after_warmup_lr_sched} # pytype: disable=attribute-error
+
+ config.update({
+ "warmup_steps": self._warmup_setps,
+ "power": self._power,
+ "name": self._name
+ })
+ return config
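+
+# Illustrative usage sketch (not exercised here): a quadratic (power=2.0)
+# warmup over 500 steps toward a constant learning rate of 0.1.
+#
+#   _lr = PolynomialWarmUp(after_warmup_lr_sched=0.1, warmup_steps=500, power=2.0)
+#   _lr(250)  # -> 0.1 * (250 / 500) ** 2 = 0.025
+#   _lr(500)  # -> 0.1 (the wrapped value/schedule is used from here on)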
diff --git a/models/official/modeling/optimization/optimizer_factory.py b/models/official/modeling/optimization/optimizer_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccb03d50ee8a5b74cda84cbe261cfdbecce60d23
--- /dev/null
+++ b/models/official/modeling/optimization/optimizer_factory.py
@@ -0,0 +1,145 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Optimizer factory class."""
+from typing import Union
+
+import tensorflow as tf
+
+import tensorflow_addons.optimizers as tfa_optimizers
+
+from official.modeling.optimization import lr_schedule
+from official.modeling.optimization.configs import optimization_config as opt_cfg
+from official.nlp import optimization as nlp_optimization
+
+OPTIMIZERS_CLS = {
+ 'sgd': tf.keras.optimizers.SGD,
+ 'adam': tf.keras.optimizers.Adam,
+ 'adamw': nlp_optimization.AdamWeightDecay,
+ 'lamb': tfa_optimizers.LAMB,
+ 'rmsprop': tf.keras.optimizers.RMSprop
+}
+
+LR_CLS = {
+ 'stepwise': tf.keras.optimizers.schedules.PiecewiseConstantDecay,
+ 'polynomial': tf.keras.optimizers.schedules.PolynomialDecay,
+ 'exponential': tf.keras.optimizers.schedules.ExponentialDecay,
+ 'cosine': tf.keras.experimental.CosineDecay
+}
+
+WARMUP_CLS = {
+ 'linear': lr_schedule.LinearWarmup,
+ 'polynomial': lr_schedule.PolynomialWarmUp
+}
+
+
+class OptimizerFactory(object):
+ """Optimizer factory class.
+
+ This class builds learning rate and optimizer based on an optimization config.
+ To use this class, you need to do the following:
+ (1) Define optimization config, this includes optimizer, and learning rate
+ schedule.
+ (2) Initialize the class using the optimization config.
+ (3) Build learning rate.
+ (4) Build optimizer.
+
+ This is a typical example for using this class:
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
+ },
+ 'learning_rate': {
+ 'type': 'stepwise',
+ 'stepwise': {'boundaries': [10000, 20000],
+ 'values': [0.1, 0.01, 0.001]}
+ },
+ 'warmup': {
+ 'type': 'linear',
+ 'linear': {'warmup_steps': 500, 'warmup_learning_rate': 0.01}
+ }
+ }
+ opt_config = OptimizationConfig(params)
+ opt_factory = OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+ optimizer = opt_factory.build_optimizer(lr)
+ """
+
+ def __init__(self, config: opt_cfg.OptimizationConfig):
+ """Initializing OptimizerFactory.
+
+ Args:
+      config: OptimizationConfig instance containing the optimization config.
+ """
+ self._config = config
+ self._optimizer_config = config.optimizer.get()
+ self._optimizer_type = config.optimizer.type
+
+ if self._optimizer_config is None:
+ raise ValueError('Optimizer type must be specified')
+
+ self._lr_config = config.learning_rate.get()
+ self._lr_type = config.learning_rate.type
+
+ self._warmup_config = config.warmup.get()
+ self._warmup_type = config.warmup.type
+
+ def build_learning_rate(self):
+ """Build learning rate.
+
+    Builds the learning rate from config. The learning rate schedule is built
+    according to the learning rate config. If there is no learning rate config,
+    the optimizer's learning rate is returned.
+
+ Returns:
+ tf.keras.optimizers.schedules.LearningRateSchedule instance. If no
+ learning rate schedule defined, optimizer_config.learning_rate is
+ returned.
+ """
+
+ # TODO(arashwan): Explore if we want to only allow explicit const lr sched.
+ if not self._lr_config:
+ lr = self._optimizer_config.learning_rate
+ else:
+ lr = LR_CLS[self._lr_type](**self._lr_config.as_dict())
+
+ if self._warmup_config:
+ lr = WARMUP_CLS[self._warmup_type](lr, **self._warmup_config.as_dict())
+
+ return lr
+
+ def build_optimizer(
+ self, lr: Union[tf.keras.optimizers.schedules.LearningRateSchedule,
+ float]):
+ """Build optimizer.
+
+    Builds the optimizer from config. It takes the learning rate as input, and
+    builds the optimizer according to the optimizer config. Typically, the
+    learning rate built using self.build_learning_rate() is passed as an
+    argument to this method.
+
+ Args:
+ lr: A floating point value, or
+ a tf.keras.optimizers.schedules.LearningRateSchedule instance.
+ Returns:
+ tf.keras.optimizers.Optimizer instance.
+ """
+
+ optimizer_dict = self._optimizer_config.as_dict()
+ optimizer_dict['learning_rate'] = lr
+
+ optimizer = OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
+ return optimizer
+
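+# Illustrative end-to-end sketch (mirrors the class docstring; not exercised
+# here): build an SGD optimizer with a stepwise schedule and linear warmup.
+#
+#   _params = {
+#       'optimizer': {'type': 'sgd',
+#                     'sgd': {'learning_rate': 0.1, 'momentum': 0.9}},
+#       'learning_rate': {'type': 'stepwise',
+#                         'stepwise': {'boundaries': [10000, 20000],
+#                                      'values': [0.1, 0.01, 0.001]}},
+#       'warmup': {'type': 'linear',
+#                  'linear': {'warmup_steps': 500,
+#                             'warmup_learning_rate': 0.01}},
+#   }
+#   _opt_config = opt_cfg.OptimizationConfig(_params)
+#   _factory = OptimizerFactory(_opt_config)
+#   _optimizer = _factory.build_optimizer(_factory.build_learning_rate())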
diff --git a/models/official/modeling/optimization/optimizer_factory_test.py b/models/official/modeling/optimization/optimizer_factory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..6da76fec93303813df4e44a0a5dddff4073db3fa
--- /dev/null
+++ b/models/official/modeling/optimization/optimizer_factory_test.py
@@ -0,0 +1,249 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for optimizer_factory.py."""
+
+from absl.testing import parameterized
+
+import tensorflow as tf
+
+from official.modeling.optimization import optimizer_factory
+from official.modeling.optimization.configs import optimization_config
+
+
+class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
+
+ @parameterized.parameters(
+ ('sgd'),
+ ('rmsprop'),
+ ('adam'),
+ ('adamw'),
+ ('lamb'))
+ def test_optimizers(self, optimizer_type):
+ params = {
+ 'optimizer': {
+ 'type': optimizer_type
+ }
+ }
+ optimizer_cls = optimizer_factory.OPTIMIZERS_CLS[optimizer_type]
+ expected_optimizer_config = optimizer_cls().get_config()
+
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+ optimizer = opt_factory.build_optimizer(lr)
+
+ self.assertIsInstance(optimizer, optimizer_cls)
+ self.assertEqual(expected_optimizer_config, optimizer.get_config())
+
+ def test_stepwise_lr_schedule(self):
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
+ },
+ 'learning_rate': {
+ 'type': 'stepwise',
+ 'stepwise': {'boundaries': [10000, 20000],
+ 'values': [0.1, 0.01, 0.001]}
+ }
+ }
+ expected_lr_step_values = [
+ [0, 0.1],
+ [5000, 0.1],
+ [10000, 0.1],
+ [10001, 0.01],
+ [20000, 0.01],
+ [20001, 0.001]
+ ]
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+
+ for step, value in expected_lr_step_values:
+ self.assertAlmostEqual(lr(step).numpy(), value)
+
+ def test_stepwise_lr_with_warmup_schedule(self):
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
+ },
+ 'learning_rate': {
+ 'type': 'stepwise',
+ 'stepwise': {'boundaries': [10000, 20000],
+ 'values': [0.1, 0.01, 0.001]}
+ },
+ 'warmup': {
+ 'type': 'linear',
+ 'linear': {'warmup_steps': 500, 'warmup_learning_rate': 0.01}
+ }
+ }
+ expected_lr_step_values = [
+ [0, 0.01],
+ [250, 0.055],
+ [500, 0.1],
+ [5500, 0.1],
+ [10000, 0.1],
+ [10001, 0.01],
+ [20000, 0.01],
+ [20001, 0.001]
+ ]
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+
+ for step, value in expected_lr_step_values:
+ self.assertAlmostEqual(lr(step).numpy(), value)
+
+ def test_exponential_lr_schedule(self):
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
+ },
+ 'learning_rate': {
+ 'type': 'exponential',
+ 'exponential': {
+ 'initial_learning_rate': 0.1,
+ 'decay_steps': 1000,
+ 'decay_rate': 0.96,
+ 'staircase': True
+ }
+ }
+ }
+ expected_lr_step_values = [
+ [0, 0.1],
+ [999, 0.1],
+ [1000, 0.096],
+ [1999, 0.096],
+ [2000, 0.09216],
+ ]
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+
+ for step, value in expected_lr_step_values:
+ self.assertAlmostEqual(lr(step).numpy(), value)
+
+ def test_polynomial_lr_schedule(self):
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
+ },
+ 'learning_rate': {
+ 'type': 'polynomial',
+ 'polynomial': {
+ 'initial_learning_rate': 0.1,
+ 'decay_steps': 1000,
+ 'end_learning_rate': 0.001
+ }
+ }
+ }
+
+ expected_lr_step_values = [[0, 0.1], [500, 0.0505], [1000, 0.001]]
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+
+ for step, value in expected_lr_step_values:
+ self.assertAlmostEqual(lr(step).numpy(), value)
+
+ def test_cosine_lr_schedule(self):
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
+ },
+ 'learning_rate': {
+ 'type': 'cosine',
+ 'cosine': {
+ 'initial_learning_rate': 0.1,
+ 'decay_steps': 1000
+ }
+ }
+ }
+ expected_lr_step_values = [[0, 0.1],
+ [250, 0.08535534],
+ [500, 0.04999999],
+ [750, 0.01464466],
+ [1000, 0]]
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+
+ for step, value in expected_lr_step_values:
+ self.assertAlmostEqual(lr(step).numpy(), value)
+
+ def test_constant_lr_with_warmup_schedule(self):
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
+ },
+ 'warmup': {
+ 'type': 'linear',
+ 'linear': {
+ 'warmup_steps': 500,
+ 'warmup_learning_rate': 0.01
+ }
+ }
+ }
+
+ expected_lr_step_values = [[0, 0.01], [250, 0.055], [500, 0.1], [5000, 0.1],
+ [10000, 0.1], [20000, 0.1]]
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+
+ for step, value in expected_lr_step_values:
+ self.assertAlmostEqual(lr(step).numpy(), value)
+
+ def test_stepwise_lr_with_polynomial_warmup_schedule(self):
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
+ },
+ 'learning_rate': {
+ 'type': 'stepwise',
+ 'stepwise': {'boundaries': [10000, 20000],
+ 'values': [0.1, 0.01, 0.001]}
+ },
+ 'warmup': {
+ 'type': 'polynomial',
+ 'polynomial': {'warmup_steps': 500, 'power': 2.}
+ }
+ }
+ expected_lr_step_values = [
+ [0, 0.0],
+ [250, 0.025],
+ [500, 0.1],
+ [5500, 0.1],
+ [10000, 0.1],
+ [10001, 0.01],
+ [20000, 0.01],
+ [20001, 0.001]
+ ]
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+
+ for step, value in expected_lr_step_values:
+ self.assertAlmostEqual(lr(step).numpy(), value)
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/modeling/performance.py b/models/official/modeling/performance.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b264f53256db66326ee4e51c5a29676e273eca9
--- /dev/null
+++ b/models/official/modeling/performance.py
@@ -0,0 +1,56 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions and classes related to training performance."""
+
+import tensorflow as tf
+
+
+def configure_optimizer(optimizer,
+ use_float16=False,
+ use_graph_rewrite=False,
+ loss_scale="dynamic"):
+ """Configures optimizer object with performance options."""
+ if use_float16:
+ # Wraps optimizer with a LossScaleOptimizer. This is done automatically
+ # in compile() with the "mixed_float16" policy, but since we do not call
+ # compile(), we must wrap the optimizer manually.
+ optimizer = (
+ tf.keras.mixed_precision.experimental.LossScaleOptimizer(
+ optimizer, loss_scale=loss_scale))
+ if use_graph_rewrite:
+ # Note: the model dtype must be 'float32', which will ensure
+    # tf.keras.mixed_precision and
+ # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
+ # up.
+ optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
+ optimizer)
+ return optimizer
+
+
+def set_mixed_precision_policy(dtype, loss_scale=None):
+ """Sets mix precision policy."""
+ if dtype == tf.float16:
+ policy = tf.keras.mixed_precision.experimental.Policy(
+ 'mixed_float16', loss_scale=loss_scale)
+ tf.keras.mixed_precision.experimental.set_policy(policy)
+ elif dtype == tf.bfloat16:
+ policy = tf.keras.mixed_precision.experimental.Policy(
+ 'mixed_bfloat16')
+ tf.keras.mixed_precision.experimental.set_policy(policy)
+ elif dtype == tf.float32:
+ tf.keras.mixed_precision.experimental.set_policy('float32')
+ else:
+ raise ValueError("Unexpected dtype: %s" % dtype)
diff --git a/models/official/modeling/tf_utils.py b/models/official/modeling/tf_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..279208239349e143ed59d9c6d5dbc418d25fe0fa
--- /dev/null
+++ b/models/official/modeling/tf_utils.py
@@ -0,0 +1,190 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Common TF utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+import tensorflow as tf
+
+from tensorflow.python.util import deprecation
+from official.modeling import activations
+
+
+@deprecation.deprecated(
+ None,
+ "tf.keras.layers.Layer supports multiple positional args and kwargs as "
+ "input tensors. pack/unpack inputs to override __call__ is no longer "
+ "needed."
+)
+def pack_inputs(inputs):
+ """Pack a list of `inputs` tensors to a tuple.
+
+ Args:
+ inputs: a list of tensors.
+
+ Returns:
+    A tuple of tensors. If any input is None, it is replaced with a special
+    constant tensor.
+ """
+ inputs = tf.nest.flatten(inputs)
+ outputs = []
+ for x in inputs:
+ if x is None:
+ outputs.append(tf.constant(0, shape=[], dtype=tf.int32))
+ else:
+ outputs.append(x)
+ return tuple(outputs)
+
+
+@deprecation.deprecated(
+ None,
+ "tf.keras.layers.Layer supports multiple positional args and kwargs as "
+ "input tensors. pack/unpack inputs to override __call__ is no longer "
+ "needed."
+)
+def unpack_inputs(inputs):
+ """unpack a tuple of `inputs` tensors to a tuple.
+
+ Args:
+ inputs: a list of tensors.
+
+ Returns:
+    A tuple of tensors. If any input is a special constant tensor, it is
+    replaced with None.
+ """
+ inputs = tf.nest.flatten(inputs)
+ outputs = []
+ for x in inputs:
+ if is_special_none_tensor(x):
+ outputs.append(None)
+ else:
+ outputs.append(x)
+ x = tuple(outputs)
+
+ # To trick the very pointless 'unbalanced-tuple-unpacking' pylint check
+ # from triggering.
+ if len(x) == 1:
+ return x[0]
+ return tuple(outputs)
+
+
+def is_special_none_tensor(tensor):
+ """Checks if a tensor is a special None Tensor."""
+ return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
+
+
+# TODO(hongkuny): consider moving custom string-map lookup to keras api.
+def get_activation(identifier):
+ """Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
+
+  It checks the string first; if it is one of the customized activations not in
+  TF, the corresponding activation will be returned. For non-customized
+  activation names and callable identifiers, it always falls back to
+  tf.keras.activations.get.
+
+ Args:
+ identifier: String name of the activation function or callable.
+
+ Returns:
+ A Python function corresponding to the activation function.
+ """
+ if isinstance(identifier, six.string_types):
+ name_to_fn = {
+ "gelu": activations.gelu,
+ "simple_swish": activations.simple_swish,
+ "hard_swish": activations.hard_swish,
+ "identity": activations.identity,
+ }
+ identifier = str(identifier).lower()
+ if identifier in name_to_fn:
+ return tf.keras.activations.get(name_to_fn[identifier])
+ return tf.keras.activations.get(identifier)
+
+
+def get_shape_list(tensor, expected_rank=None, name=None):
+ """Returns a list of the shape of tensor, preferring static dimensions.
+
+ Args:
+ tensor: A tf.Tensor object to find the shape of.
+ expected_rank: (optional) int. The expected rank of `tensor`. If this is
+      specified and the `tensor` has a different rank, an exception will be
+      raised.
+ name: Optional name of the tensor for the error message.
+
+ Returns:
+ A list of dimensions of the shape of tensor. All static dimensions will
+ be returned as python integers, and dynamic dimensions will be returned
+ as tf.Tensor scalars.
+ """
+ if expected_rank is not None:
+ assert_rank(tensor, expected_rank, name)
+
+ shape = tensor.shape.as_list()
+
+ non_static_indexes = []
+ for (index, dim) in enumerate(shape):
+ if dim is None:
+ non_static_indexes.append(index)
+
+ if not non_static_indexes:
+ return shape
+
+ dyn_shape = tf.shape(tensor)
+ for index in non_static_indexes:
+ shape[index] = dyn_shape[index]
+ return shape
+
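+# Illustrative sketch (not exercised here): for a tensor `t` whose static shape
+# is [None, 128], get_shape_list returns the dynamic batch dimension as a scalar
+# tf.Tensor and the static feature dimension as a Python int, e.g.
+#   get_shape_list(t)  # -> [tf.shape(t)[0], 128]
+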
+
+def assert_rank(tensor, expected_rank, name=None):
+ """Raises an exception if the tensor rank is not of the expected rank.
+
+ Args:
+ tensor: A tf.Tensor to check the rank of.
+ expected_rank: Python integer or list of integers, expected rank.
+ name: Optional name of the tensor for the error message.
+
+ Raises:
+ ValueError: If the expected shape doesn't match the actual shape.
+ """
+ expected_rank_dict = {}
+ if isinstance(expected_rank, six.integer_types):
+ expected_rank_dict[expected_rank] = True
+ else:
+ for x in expected_rank:
+ expected_rank_dict[x] = True
+
+ actual_rank = tensor.shape.ndims
+ if actual_rank not in expected_rank_dict:
+ raise ValueError(
+ "For the tensor `%s`, the actual tensor rank `%d` (shape = %s) is not "
+ "equal to the expected tensor rank `%s`" %
+ (name, actual_rank, str(tensor.shape), str(expected_rank)))
+
+
+def safe_mean(losses):
+ """Computes a safe mean of the losses.
+
+ Args:
+ losses: `Tensor` whose elements contain individual loss measurements.
+
+ Returns:
+    A scalar representing the mean of `losses`. If `losses` has no elements,
+    then zero is returned.
+ """
+ total = tf.reduce_sum(losses)
+ num_elements = tf.cast(tf.size(losses), dtype=losses.dtype)
+ return tf.math.divide_no_nan(total, num_elements)
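+
+# Illustrative behaviour sketch (not exercised here):
+#   safe_mean(tf.constant([1.0, 2.0, 3.0]))  # -> 2.0
+#   safe_mean(tf.zeros([0]))  # -> 0.0 (divide_no_nan guards the empty case)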
diff --git a/models/official/modeling/training/__init__.py b/models/official/modeling/training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/modeling/training/distributed_executor.py b/models/official/modeling/training/distributed_executor.py
new file mode 100644
index 0000000000000000000000000000000000000000..11451260cdca52a9c9f4019010123c4d2b40e99e
--- /dev/null
+++ b/models/official/modeling/training/distributed_executor.py
@@ -0,0 +1,815 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Custom training loop for running TensorFlow 2.0 models."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import os
+
+from absl import flags
+from absl import logging
+
+import numpy as np
+import tensorflow as tf
+
+# pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
+from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any
+from official.modeling.hyperparams import params_dict
+from official.utils import hyperparams_flags
+from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
+
+FLAGS = flags.FLAGS
+
+strategy_flags_dict = hyperparams_flags.strategy_flags_dict
+hparam_flags_dict = hyperparams_flags.hparam_flags_dict
+
+
+def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
+ """Saves model to model_dir with provided checkpoint prefix."""
+
+ checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
+ saved_path = checkpoint.save(checkpoint_path)
+ logging.info('Saving model as TF checkpoint: %s', saved_path)
+
+
+def _steps_to_run(current_step, total_steps, steps_per_loop):
+ """Calculates steps to run on device."""
+ if steps_per_loop <= 0:
+ raise ValueError('steps_per_loop should be positive integer.')
+ return min(total_steps - current_step, steps_per_loop)
+
+
+def _no_metric():
+ return None
+
+
+def metrics_as_dict(metric):
+ """Puts input metric(s) into a list.
+
+ Args:
+ metric: metric(s) to be put into the list. `metric` could be a object, a
+ list or a dict of tf.keras.metrics.Metric or has the `required_method`.
+
+ Returns:
+ A dictionary of valid metrics.
+ """
+ if isinstance(metric, tf.keras.metrics.Metric):
+ metrics = {metric.name: metric}
+ elif isinstance(metric, list):
+ metrics = {m.name: m for m in metric}
+ elif isinstance(metric, dict):
+ metrics = metric
+ elif not metric:
+ return {}
+ else:
+ metrics = {'metric': metric}
+ return metrics
+
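+# Illustrative behaviour sketch (not exercised here):
+#   _m = tf.keras.metrics.Mean(name='loss')
+#   metrics_as_dict(_m)            # -> {'loss': _m}
+#   metrics_as_dict([_m])          # -> {'loss': _m}
+#   metrics_as_dict({'loss': _m})  # -> {'loss': _m}
+#   metrics_as_dict(None)          # -> {}
+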
+
+def metric_results(metric):
+ """Collects results from the given metric(s)."""
+ metrics = metrics_as_dict(metric)
+ metric_result = {
+ name: m.result().numpy().astype(float) for name, m in metrics.items()
+ }
+ return metric_result
+
+
+def reset_states(metric):
+ """Resets states of the given metric(s)."""
+ metrics = metrics_as_dict(metric)
+ for m in metrics.values():
+ m.reset_states()
+
+
+class SummaryWriter(object):
+ """Simple SummaryWriter for writing dictionary of metrics.
+
+ Attributes:
+ writer: The tf.SummaryWriter.
+ """
+
+ def __init__(self, model_dir: Text, name: Text):
+ """Inits SummaryWriter with paths.
+
+    Args:
+ model_dir: the model folder path.
+ name: the summary subfolder name.
+ """
+ self.writer = tf.summary.create_file_writer(os.path.join(model_dir, name))
+
+ def __call__(self, metrics: Union[Dict[Text, float], float], step: int):
+ """Write metrics to summary with the given writer.
+
+ Args:
+      metrics: a dictionary of metric values. A bare float is also accepted,
+        but a dictionary is preferred.
+ step: integer. The training step.
+ """
+ if not isinstance(metrics, dict):
+ # Support scalar metric without name.
+      logging.warning('Warning: summary writer prefers metrics as a dictionary.')
+ metrics = {'metric': metrics}
+
+ with self.writer.as_default():
+ for k, v in metrics.items():
+ tf.summary.scalar(k, v, step=step)
+ self.writer.flush()
+
+
+class DistributedExecutor(object):
+ """Interface to train and eval models with tf.distribute.Strategy.
+ """
+
+ def __init__(self,
+ strategy,
+ params,
+ model_fn,
+ loss_fn,
+ is_multi_host=False):
+ """Constructor.
+
+ Args:
+ strategy: an instance of tf.distribute.Strategy.
+ params: Model configuration needed to run distribution strategy.
+ model_fn: Keras model function. Signature:
+ (params: ParamsDict) -> tf.keras.models.Model.
+ loss_fn: loss function. Signature:
+ (y_true: Tensor, y_pred: Tensor) -> Tensor
+      is_multi_host: Set to True when using multiple hosts for training, such
+        as multi-worker GPU or a TPU pod (slice). Otherwise, False.
+ """
+
+ self._params = params
+ self._model_fn = model_fn
+ self._loss_fn = loss_fn
+ self._strategy = strategy
+ self._checkpoint_name = 'ctl_step_{step}.ckpt'
+ self._is_multi_host = is_multi_host
+ self.train_summary_writer = None
+ self.eval_summary_writer = None
+ self.global_train_step = None
+
+ @property
+ def checkpoint_name(self):
+ """Returns default checkpoint name."""
+ return self._checkpoint_name
+
+ @checkpoint_name.setter
+ def checkpoint_name(self, name):
+ """Sets default summary writer for the current thread."""
+ self._checkpoint_name = name
+
+ def loss_fn(self):
+ return self._loss_fn()
+
+ def model_fn(self, params):
+ return self._model_fn(params)
+
+ def _save_config(self, model_dir):
+ """Save parameters to config files if model_dir is defined."""
+
+ logging.info('Save config to model_dir %s.', model_dir)
+ if model_dir:
+ if not tf.io.gfile.exists(model_dir):
+ tf.io.gfile.makedirs(model_dir)
+ self._params.lock()
+ params_dict.save_params_dict_to_yaml(self._params,
+ model_dir + '/params.yaml')
+ else:
+ logging.warning('model_dir is empty, so skip the save config.')
+
+ def _get_input_iterator(
+ self, input_fn: Callable[..., tf.data.Dataset],
+ strategy: tf.distribute.Strategy) -> Optional[Iterator[Any]]:
+ """Returns distributed dataset iterator.
+
+ Args:
+ input_fn: (params: dict) -> tf.data.Dataset.
+ strategy: an instance of tf.distribute.Strategy.
+
+ Returns:
+ An iterator that yields input tensors.
+ """
+
+ if input_fn is None:
+ return None
+    # When training with multiple TPU workers, datasets need to be cloned
+    # across workers. Since a Dataset instance cannot be cloned in eager mode,
+    # we instead pass a callable that returns a dataset.
+ if self._is_multi_host:
+ return iter(
+ strategy.experimental_distribute_datasets_from_function(input_fn))
+ else:
+ input_data = input_fn()
+ return iter(strategy.experimental_distribute_dataset(input_data))
+
+ def _create_replicated_step(self,
+ strategy,
+ model,
+ loss_fn,
+ optimizer,
+ metric=None):
+ """Creates a single training step.
+
+ Args:
+ strategy: an instance of tf.distribute.Strategy.
+ model: (Tensor, bool) -> Tensor. model function.
+ loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor.
+ optimizer: tf.keras.optimizers.Optimizer.
+ metric: tf.keras.metrics.Metric subclass.
+
+ Returns:
+ The training step callable.
+ """
+ metrics = metrics_as_dict(metric)
+
+ def _replicated_step(inputs):
+ """Replicated training step."""
+ inputs, labels = inputs
+
+ with tf.GradientTape() as tape:
+ outputs = model(inputs, training=True)
+ prediction_loss = loss_fn(labels, outputs)
+ loss = tf.reduce_mean(prediction_loss)
+ loss = loss / strategy.num_replicas_in_sync
+ for m in metrics.values():
+ m.update_state(labels, outputs)
+
+ grads = tape.gradient(loss, model.trainable_variables)
+ optimizer.apply_gradients(zip(grads, model.trainable_variables))
+ return loss
+
+ return _replicated_step
+
+ def _create_train_step(self,
+ strategy,
+ model,
+ loss_fn,
+ optimizer,
+ metric=None):
+ """Creates a distributed training step.
+
+ Args:
+ strategy: an instance of tf.distribute.Strategy.
+ model: (Tensor, bool) -> Tensor. model function.
+ loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor.
+ optimizer: tf.keras.optimizers.Optimizer.
+ metric: tf.keras.metrics.Metric subclass.
+
+ Returns:
+ The training step callable.
+ """
+ replicated_step = self._create_replicated_step(strategy, model, loss_fn,
+ optimizer, metric)
+
+ @tf.function
+ def train_step(iterator, num_steps):
+ """Performs a distributed training step.
+
+ Args:
+ iterator: an iterator that yields input tensors.
+ num_steps: the number of steps in the loop.
+
+ Returns:
+ The loss tensor.
+ """
+ if not isinstance(num_steps, tf.Tensor):
+        raise ValueError('steps should be a Tensor. A Python object may cause '
+ 'retracing.')
+
+ per_replica_losses = strategy.run(
+ replicated_step, args=(next(iterator),))
+ for _ in tf.range(num_steps - 1):
+ per_replica_losses = strategy.run(
+ replicated_step, args=(next(iterator),))
+
+      # For reporting, we return the mean of the losses.
+ losses = tf.nest.map_structure(
+ lambda x: strategy.reduce(tf.distribute.ReduceOp.MEAN, x, axis=None),
+ per_replica_losses)
+ return losses
+
+ return train_step
+
+ def _create_test_step(self, strategy, model, metric):
+ """Creates a distributed test step."""
+ metrics = metrics_as_dict(metric)
+
+ @tf.function
+ def test_step(iterator):
+ """Calculates evaluation metrics on distributed devices."""
+ if not metric:
+ logging.info('Skip test_step because metric is None (%s)', metric)
+ return None, None
+
+ def _test_step_fn(inputs):
+ """Replicated accuracy calculation."""
+ inputs, labels = inputs
+ model_outputs = model(inputs, training=False)
+ for m in metrics.values():
+ m.update_state(labels, model_outputs)
+ return labels, model_outputs
+
+ return strategy.run(_test_step_fn, args=(next(iterator),))
+
+ return test_step
+
+ def train(self,
+ train_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
+ eval_input_fn: Callable[[params_dict.ParamsDict],
+ tf.data.Dataset] = None,
+ model_dir: Text = None,
+ total_steps: int = 1,
+ iterations_per_loop: int = 1,
+ train_metric_fn: Callable[[], Any] = None,
+ eval_metric_fn: Callable[[], Any] = None,
+ summary_writer_fn: Callable[[Text, Text],
+ SummaryWriter] = SummaryWriter,
+ init_checkpoint: Callable[[tf.keras.Model], Any] = None,
+ custom_callbacks: List[tf.keras.callbacks.Callback] = None,
+ continuous_eval: bool = False,
+ save_config: bool = True):
+ """Runs distributed training.
+
+ Args:
+ train_input_fn: (params: dict) -> tf.data.Dataset training data input
+ function.
+ eval_input_fn: (Optional) same type as train_input_fn. If not None, will
+        trigger evaluating metrics on eval data. If None, will not run eval step.
+ model_dir: the folder path for model checkpoints.
+ total_steps: total training steps.
+ iterations_per_loop: train steps per loop. After each loop, this job will
+ update metrics like loss and save checkpoint.
+ train_metric_fn: metric_fn for evaluation in train_step.
+ eval_metric_fn: metric_fn for evaluation in test_step.
+ summary_writer_fn: function to create summary writer.
+ init_checkpoint: function to load checkpoint.
+ custom_callbacks: A list of Keras Callback objects to run during
+ training. More specifically, the `on_batch_begin()` and `on_batch_end()`
+ methods are invoked during training.
+ continuous_eval: If `True`, will continuously run evaluation on every
+ available checkpoint. If `False`, will do the evaluation once after the
+ final step.
+ save_config: bool. Whether to save params to model_dir.
+ Returns:
+ The training loss and eval metrics.
+ """
+ assert train_input_fn is not None
+ if train_metric_fn and not callable(train_metric_fn):
+ raise ValueError('if `train_metric_fn` is specified, '
+ 'train_metric_fn must be a callable.')
+ if eval_metric_fn and not callable(eval_metric_fn):
+ raise ValueError('if `eval_metric_fn` is specified, '
+ 'eval_metric_fn must be a callable.')
+ train_metric_fn = train_metric_fn or _no_metric
+ eval_metric_fn = eval_metric_fn or _no_metric
+
+ if custom_callbacks and iterations_per_loop != 1:
+ logging.warning(
+ 'It is semantically wrong to run callbacks when '
+ 'iterations_per_loop is not one (%s)', iterations_per_loop)
+
+ custom_callbacks = custom_callbacks or []
+
+ def _run_callbacks_on_batch_begin(batch):
+ """Runs custom callbacks at the start of every step."""
+ if not custom_callbacks:
+ return
+ for callback in custom_callbacks:
+ if callback:
+ callback.on_batch_begin(batch)
+
+ def _run_callbacks_on_batch_end(batch):
+ """Runs custom callbacks at the end of every step."""
+ if not custom_callbacks:
+ return
+ for callback in custom_callbacks:
+ if callback:
+ callback.on_batch_end(batch)
+
+ if save_config:
+ self._save_config(model_dir)
+
+ if FLAGS.save_checkpoint_freq:
+ save_freq = FLAGS.save_checkpoint_freq
+ else:
+ save_freq = iterations_per_loop
+
+ params = self._params
+ strategy = self._strategy
+ # To reduce unnecessary send/receive input pipeline operation, we place
+ # input pipeline ops in worker task.
+ train_iterator = self._get_input_iterator(train_input_fn, strategy)
+ train_loss = None
+ train_metric_result = None
+ eval_metric_result = None
+ tf.keras.backend.set_learning_phase(1)
+ with strategy.scope():
+ # To correctly place the model weights on accelerators,
+ # model and optimizer should be created in scope.
+ model = self.model_fn(params.as_dict())
+ if not hasattr(model, 'optimizer'):
+ raise ValueError('User should set optimizer attribute to model '
+ 'inside `model_fn`.')
+ optimizer = model.optimizer
+
+ # Training loop starts here.
+ checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
+ latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
+ initial_step = 0
+ if latest_checkpoint_file:
+ logging.info(
+ 'Checkpoint file %s found and restoring from '
+ 'checkpoint', latest_checkpoint_file)
+ checkpoint.restore(latest_checkpoint_file)
+ initial_step = optimizer.iterations.numpy()
+ logging.info('Loading from checkpoint file completed. Init step %d',
+ initial_step)
+ elif init_checkpoint:
+ logging.info('Restoring from init checkpoint function')
+ init_checkpoint(model)
+ logging.info('Loading from init checkpoint file completed')
+
+ current_step = optimizer.iterations.numpy()
+ checkpoint_name = self.checkpoint_name
+
+ eval_metric = eval_metric_fn()
+ train_metric = train_metric_fn()
+ train_summary_writer = summary_writer_fn(model_dir, 'eval_train')
+ self.train_summary_writer = train_summary_writer.writer
+
+ test_summary_writer = summary_writer_fn(model_dir, 'eval_test')
+ self.eval_summary_writer = test_summary_writer.writer
+
+ # Use training summary writer in TimeHistory if it's in use
+ for cb in custom_callbacks:
+ if isinstance(cb, keras_utils.TimeHistory):
+ cb.summary_writer = self.train_summary_writer
+
+ # Continue training loop.
+ train_step = self._create_train_step(
+ strategy=strategy,
+ model=model,
+ loss_fn=self.loss_fn(),
+ optimizer=optimizer,
+ metric=train_metric)
+ test_step = None
+ if eval_input_fn and eval_metric:
+ self.global_train_step = model.optimizer.iterations
+ test_step = self._create_test_step(strategy, model, metric=eval_metric)
+
+ # Step-0 operations
+ if current_step == 0 and not latest_checkpoint_file:
+ _save_checkpoint(
+ checkpoint, model_dir, checkpoint_name.format(step=current_step))
+ if test_step:
+ eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
+ eval_metric_result = self._run_evaluation(
+ test_step, current_step, eval_metric, eval_iterator)
+ logging.info(
+ 'Step: %s evaluation metric = %s.', current_step, eval_metric_result)
+ test_summary_writer(
+ metrics=eval_metric_result, step=optimizer.iterations)
+ reset_states(eval_metric)
+
+ logging.info('Training started')
+ last_save_checkpoint_step = current_step
+ while current_step < total_steps:
+
+ num_steps = _steps_to_run(current_step, total_steps, iterations_per_loop)
+ _run_callbacks_on_batch_begin(current_step)
+ train_loss = train_step(train_iterator,
+ tf.convert_to_tensor(num_steps, dtype=tf.int32))
+ current_step += num_steps
+
+ train_loss = tf.nest.map_structure(lambda x: x.numpy().astype(float),
+ train_loss)
+
+ _run_callbacks_on_batch_end(current_step - 1)
+ if not isinstance(train_loss, dict):
+ train_loss = {'total_loss': train_loss}
+ if np.isnan(train_loss['total_loss']):
+ raise ValueError('total loss is NaN.')
+
+ if train_metric:
+ train_metric_result = metric_results(train_metric)
+ train_metric_result.update(train_loss)
+ else:
+ train_metric_result = train_loss
+ if callable(optimizer.lr):
+ train_metric_result.update(
+ {'learning_rate': optimizer.lr(current_step).numpy()})
+ else:
+ train_metric_result.update({'learning_rate': optimizer.lr.numpy()})
+ logging.info('Train Step: %d/%d / loss = %s / training metric = %s',
+ current_step, total_steps, train_loss,
+ train_metric_result)
+
+ train_summary_writer(
+ metrics=train_metric_result, step=optimizer.iterations)
+
+ # Saves model checkpoints and run validation steps at every
+ # iterations_per_loop steps.
+ # To avoid repeated model saving, we do not save after the last
+ # step of training.
+ if save_freq > 0 and current_step < total_steps and (
+ current_step - last_save_checkpoint_step) >= save_freq:
+ _save_checkpoint(checkpoint, model_dir,
+ checkpoint_name.format(step=current_step))
+ last_save_checkpoint_step = current_step
+
+ if continuous_eval and current_step < total_steps and test_step:
+ eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
+ eval_metric_result = self._run_evaluation(test_step, current_step,
+ eval_metric, eval_iterator)
+ logging.info('Step: %s evaluation metric = %s.', current_step,
+ eval_metric_result)
+ test_summary_writer(
+ metrics=eval_metric_result, step=optimizer.iterations)
+
+ # Re-initialize evaluation metric, except the last step.
+ if eval_metric and current_step < total_steps:
+ reset_states(eval_metric)
+ if train_metric and current_step < total_steps:
+ reset_states(train_metric)
+
+ # Reaches the end of training and saves the last checkpoint.
+ if last_save_checkpoint_step < total_steps:
+ _save_checkpoint(checkpoint, model_dir,
+ checkpoint_name.format(step=current_step))
+
+ if test_step:
+ logging.info('Running final evaluation after training is complete.')
+ eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
+ eval_metric_result = self._run_evaluation(test_step, current_step,
+ eval_metric, eval_iterator)
+ logging.info('Final evaluation metric = %s.', eval_metric_result)
+ test_summary_writer(
+ metrics=eval_metric_result, step=optimizer.iterations)
+
+ self.train_summary_writer.close()
+ self.eval_summary_writer.close()
+
+ return train_metric_result, eval_metric_result
+
+ def _run_evaluation(self, test_step, current_training_step, metric,
+ test_iterator):
+ """Runs validation steps and aggregate metrics."""
+ if not test_iterator or not metric:
+ logging.warning(
+ 'Both test_iterator (%s) and metrics (%s) must not be None.',
+ test_iterator, metric)
+ return None
+ logging.info('Running evaluation after step: %s.', current_training_step)
+ eval_step = 0
+ while True:
+ try:
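+ # async_scope lets the end-of-dataset (OutOfRange) error from asynchronous
+ # execution surface here; it is cleared below before exiting the loop.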
+ with tf.experimental.async_scope():
+ test_step(test_iterator)
+ eval_step += 1
+ except (StopIteration, tf.errors.OutOfRangeError):
+ tf.experimental.async_clear_error()
+ break
+
+ metric_result = metric_results(metric)
+ logging.info('Total eval steps: [%d]', eval_step)
+ logging.info('At training step: [%r] Validation metric = %r',
+ current_training_step, metric_result)
+ return metric_result
+
+ def evaluate_from_model_dir(
+ self,
+ model_dir: Text,
+ eval_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
+ eval_metric_fn: Callable[[], Any],
+ total_steps: int = -1,
+ eval_timeout: int = None,
+ min_eval_interval: int = 180,
+ summary_writer_fn: Callable[[Text, Text], SummaryWriter] = SummaryWriter):
+ """Runs distributed evaluation on model folder.
+
+ Args:
+ model_dir: the folder for storing model checkpoints.
+ eval_input_fn: (Optional) same type as train_input_fn. If not None, will
+ trigger evaluating metrics on eval data. If None, will not run eval step.
+ eval_metric_fn: metric_fn for evaluation in test_step.
+ total_steps: total training steps. If the current step reaches the
+ total_steps, the evaluation loop will stop.
+ eval_timeout: The maximum number of seconds to wait between checkpoints.
+ If left as None, then the process will wait indefinitely. Used by
+ tf.train.checkpoints_iterator.
+ min_eval_interval: The minimum number of seconds between yielding
+ checkpoints. Used by tf.train.checkpoints_iterator.
+ summary_writer_fn: function to create summary writer.
+
+ Returns:
+ Eval metrics dictionary of the last checkpoint.
+ """
+
+ if not model_dir:
+ raise ValueError('model_dir must be set.')
+
+ def terminate_eval():
+ logging.info('Terminating eval after %d seconds of no checkpoints',
+ eval_timeout)
+ return True
+
+ summary_writer = summary_writer_fn(model_dir, 'eval')
+ self.eval_summary_writer = summary_writer.writer
+
+ # Read checkpoints from the given model directory
+ # until `eval_timeout` seconds elapses.
+ for checkpoint_path in tf.train.checkpoints_iterator(
+ model_dir,
+ min_interval_secs=min_eval_interval,
+ timeout=eval_timeout,
+ timeout_fn=terminate_eval):
+ eval_metric_result, current_step = self.evaluate_checkpoint(
+ checkpoint_path=checkpoint_path,
+ eval_input_fn=eval_input_fn,
+ eval_metric_fn=eval_metric_fn,
+ summary_writer=summary_writer)
+ if total_steps > 0 and current_step >= total_steps:
+ logging.info('Evaluation finished after training step %d', current_step)
+ break
+ return eval_metric_result
+
+ def evaluate_checkpoint(self,
+ checkpoint_path: Text,
+ eval_input_fn: Callable[[params_dict.ParamsDict],
+ tf.data.Dataset],
+ eval_metric_fn: Callable[[], Any],
+ summary_writer: SummaryWriter = None):
+ """Runs distributed evaluation on the one checkpoint.
+
+ Args:
+ checkpoint_path: the checkpoint to evaluate.
+ eval_input_fn: (Optional) same type as train_input_fn. If not None, will
+ trigger evaluating metrics on eval data. If None, will not run eval step.
+ eval_metric_fn: metric_fn for evaluation in test_step.
+ summary_writer: a SummaryWriter instance used to write evaluation metrics.
+
+ Returns:
+ A tuple of the eval metrics dictionary and the current training step.
+ """
+ if not callable(eval_metric_fn):
+ raise ValueError('if `eval_metric_fn` is specified, '
+ 'eval_metric_fn must be a callable.')
+
+ old_phase = tf.keras.backend.learning_phase()
+ tf.keras.backend.set_learning_phase(0)
+ params = self._params
+ strategy = self._strategy
+ # To reduce unnecessary send/receive input pipeline operation, we place
+ # input pipeline ops in worker task.
+ with strategy.scope():
+
+ # To correctly place the model weights on accelerators,
+ # model and optimizer should be created in scope.
+ model = self.model_fn(params.as_dict())
+ checkpoint = tf.train.Checkpoint(model=model)
+
+ eval_metric = eval_metric_fn()
+ assert eval_metric, 'eval_metric does not exist'
+ test_step = self._create_test_step(strategy, model, metric=eval_metric)
+
+ logging.info('Starting to evaluate.')
+ if not checkpoint_path:
+ raise ValueError('checkpoint path is empty')
+ reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
+ current_step = reader.get_tensor(
+ 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE')
+ logging.info(
+ 'Checkpoint file %s found and restoring from '
+ 'checkpoint', checkpoint_path)
+ checkpoint.restore(checkpoint_path)
+
+ self.global_train_step = model.optimizer.iterations
+ eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
+ eval_metric_result = self._run_evaluation(test_step, current_step,
+ eval_metric, eval_iterator)
+ logging.info('Step: %s evaluation metric = %s.', current_step,
+ eval_metric_result)
+ summary_writer(metrics=eval_metric_result, step=current_step)
+ reset_states(eval_metric)
+
+ tf.keras.backend.set_learning_phase(old_phase)
+ return eval_metric_result, current_step
+
+ def predict(self):
+ raise NotImplementedError('Unimplemented function.')
+
+
+class ExecutorBuilder(object):
+ """Builder of DistributedExecutor.
+
+ Example 1: Builds an executor with supported Strategy.
+ builder = ExecutorBuilder(
+ strategy_type='tpu',
+ strategy_config={'tpu': '/bns/xxx'})
+ dist_executor = builder.build_executor(
+ params=params,
+ model_fn=my_model_fn,
+ loss_fn=my_loss_fn,
+ metric_fn=my_metric_fn)
+
+ Example 2: Builds an executor with customized Strategy.
+ builder = ExecutorBuilder()
+ builder.strategy = <some customized tf.distribute.Strategy instance>
+ dist_executor = builder.build_executor(
+ params=params,
+ model_fn=my_model_fn,
+ loss_fn=my_loss_fn,
+ metric_fn=my_metric_fn)
+
+ Example 3: Builds a customized executor with customized Strategy.
+ class MyDistributedExecutor(DistributedExecutor):
+ # implementation ...
+
+ builder = ExecutorBuilder()
+ builder.strategy = <some customized tf.distribute.Strategy instance>
+ dist_executor = builder.build_executor(
+ class_ctor=MyDistributedExecutor,
+ params=params,
+ model_fn=my_model_fn,
+ loss_fn=my_loss_fn,
+ metric_fn=my_metric_fn)
+ """
+
+ def __init__(self, strategy_type=None, strategy_config=None):
+ """Constructor.
+
+ Args:
+ strategy_type: string. One of 'tpu', 'mirrored', 'multi_worker_mirrored'.
+ If None, the user is responsible for setting the strategy before calling
+ build_executor(...).
+ strategy_config: necessary config for constructing the proper Strategy.
+ Check strategy_flags_dict() for examples of the structure.
+ """
+ _ = distribution_utils.configure_cluster(
+ strategy_config.worker_hosts, strategy_config.task_index)
+ self._strategy = distribution_utils.get_distribution_strategy(
+ distribution_strategy=strategy_type,
+ num_gpus=strategy_config.num_gpus,
+ all_reduce_alg=strategy_config.all_reduce_alg,
+ num_packs=strategy_config.num_packs,
+ tpu_address=strategy_config.tpu)
+
+ @property
+ def strategy(self):
+ """Returns default checkpoint name."""
+ return self._strategy
+
+ @strategy.setter
+ def strategy(self, new_strategy):
+ """Sets default summary writer for the current thread."""
+ self._strategy = new_strategy
+
+ def build_executor(self,
+ class_ctor=DistributedExecutor,
+ params=None,
+ model_fn=None,
+ loss_fn=None,
+ **kwargs):
+ """Creates an executor according to strategy type.
+
+ See the docstring of DistributedExecutor.__init__ for more information on
+ the input arguments.
+
+ Args:
+ class_ctor: A constructor of executor (default: DistributedExecutor).
+ params: ParamsDict, all the model parameters and runtime parameters.
+ model_fn: Keras model function.
+ loss_fn: loss function.
+ **kwargs: other arguments to the executor constructor.
+
+ Returns:
+ An instance of DistributedExecutor or its subclass.
+ """
+ if self._strategy is None:
+ raise ValueError('`strategy` should not be None. You need to specify '
+ '`strategy_type` in the builder constructor or directly '
+ 'set the `strategy` property of the builder.')
+ return class_ctor(
+ strategy=self._strategy,
+ params=params,
+ model_fn=model_fn,
+ loss_fn=loss_fn,
+ **kwargs)
diff --git a/models/official/nlp/README.md b/models/official/nlp/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..156f5c42858be92f20ed6bc157ddd8593cbc4329
--- /dev/null
+++ b/models/official/nlp/README.md
@@ -0,0 +1,37 @@
+# TensorFlow NLP Modelling Toolkit
+
+This codebase provides a Natural Language Processing modeling toolkit written in
+[TF2](https://www.tensorflow.org/guide/effective_tf2). It allows researchers and
+developers to reproduce state-of-the-art model results and train custom models
+to experiment with new research ideas.
+
+## Features
+
+* Reusable and modularized modeling building blocks
+* Reproducible state-of-the-art results
+* Easy to customize and extend
+* End-to-end training
+* Distributed training on both GPUs and TPUs
+
+## Major components
+
+### Libraries
+
+We provide a modeling library to allow users to train custom models for new
+research ideas. Detailed instructions can be found in the READMEs in each folder.
+
+* [modeling/](modeling): modeling library that provides building blocks (e.g., Layers, Networks, and Models) that can be assembled into transformer-based architectures (see the sketch below).
+* [data/](data): binaries and utils for input preprocessing, tokenization, etc.
+
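+As a minimal, illustrative sketch (the hyperparameter values below are
+placeholders, not a recommended configuration), an ALBERT-style encoder can be
+assembled from the library using the same constructor call that appears in the
+ALBERT checkpoint converter elsewhere in this repository:
+
+```python
+import tensorflow as tf
+
+from official.modeling import activations
+from official.nlp.modeling import networks
+
+# Assemble a small ALBERT-style Transformer encoder from modeling-library blocks.
+encoder = networks.AlbertTransformerEncoder(
+    vocab_size=30000,
+    hidden_size=768,
+    embedding_width=128,
+    num_layers=12,
+    num_attention_heads=12,
+    intermediate_size=3072,
+    activation=activations.gelu,
+    dropout_rate=0.1,
+    attention_dropout_rate=0.1,
+    sequence_length=512,
+    type_vocab_size=2,
+    initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
+```
+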
+### State-of-the-Art models and examples
+
+We provide SoTA model implementations, pre-trained models, training and
+evaluation examples, and command lines. Detail instructions can be found in the
+READMEs for specific papers.
+
+1. [BERT](bert): [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Devlin et al., 2018
+2. [ALBERT](albert): [A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942) by Lan et al., 2019
+3. [XLNet](xlnet): [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Yang et al., 2019
+4. [Transformer for translation](transformer): [Attention Is All You Need](https://arxiv.org/abs/1706.03762) by Vaswani et al., 2017
+5. [NHNet](nhnet): [Generating Representative Headlines for News Stories](https://arxiv.org/abs/2001.09386) by Gu et al, 2020
+
diff --git a/models/official/nlp/__init__.py b/models/official/nlp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/nlp/__pycache__/__init__.cpython-310.pyc b/models/official/nlp/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..12138212531ff436903dd5fa8e73277b898b610c
Binary files /dev/null and b/models/official/nlp/__pycache__/__init__.cpython-310.pyc differ
diff --git a/models/official/nlp/__pycache__/__init__.cpython-38.pyc b/models/official/nlp/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..189ad70a8ac6b38cec6a6639e0790bdecf6b6ec5
Binary files /dev/null and b/models/official/nlp/__pycache__/__init__.cpython-38.pyc differ
diff --git a/models/official/nlp/__pycache__/__init__.cpython-39.pyc b/models/official/nlp/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8cea2816038525fbca46e8d0575126c9fa59b6a2
Binary files /dev/null and b/models/official/nlp/__pycache__/__init__.cpython-39.pyc differ
diff --git a/models/official/nlp/__pycache__/optimization.cpython-38.pyc b/models/official/nlp/__pycache__/optimization.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..737c69c61222d76eb2a4c8aa7b070f8f89a4df69
Binary files /dev/null and b/models/official/nlp/__pycache__/optimization.cpython-38.pyc differ
diff --git a/models/official/nlp/__pycache__/optimization.cpython-39.pyc b/models/official/nlp/__pycache__/optimization.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b8ca526bdcc669126e2467acd6ee6f3cfd417674
Binary files /dev/null and b/models/official/nlp/__pycache__/optimization.cpython-39.pyc differ
diff --git a/models/official/nlp/albert/README.md b/models/official/nlp/albert/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cfb726c90ef9a638d5fd0485e341c232a86bdac2
--- /dev/null
+++ b/models/official/nlp/albert/README.md
@@ -0,0 +1,332 @@
+# ALBERT (ALBERT: A Lite BERT for Self-supervised Learning of Language Representations)
+
+The academic paper which describes ALBERT in detail and provides full results on
+a number of tasks can be found here: https://arxiv.org/abs/1909.11942.
+
+This repository contains TensorFlow 2.x implementation for ALBERT.
+
+## Contents
+ * [Contents](#contents)
+ * [Pre-trained Models](#pre-trained-models)
+ * [Restoring from Checkpoints](#restoring-from-checkpoints)
+ * [Set Up](#set-up)
+ * [Process Datasets](#process-datasets)
+ * [Fine-tuning with ALBERT](#fine-tuning-with-albert)
+ * [Cloud GPUs and TPUs](#cloud-gpus-and-tpus)
+ * [Sentence and Sentence-pair Classification Tasks](#sentence-and-sentence-pair-classification-tasks)
+ * [SQuAD 1.1](#squad-1.1)
+
+
+## Pre-trained Models
+
+We released both checkpoints and tf.hub modules as the pretrained models for
+fine-tuning. They are TF 2.x compatible and are converted from the ALBERT v2
+checkpoints released in the TF 1.x official ALBERT repository
+[google-research/albert](https://github.com/google-research/albert)
+in order to stay consistent with the ALBERT paper.
+
+Our currently released checkpoints are exactly the same as those in the TF 1.x
+official ALBERT repository.
+
+### Access to Pretrained Checkpoints
+
+Pretrained checkpoints can be found in the following links:
+
+**Note: We implemented ALBERT using Keras functional-style networks in [nlp/modeling](../modeling).
+The TF 2.x compatible ALBERT V2 checkpoints are:**
+
+* **[`ALBERT V2 Base`](https://storage.googleapis.com/cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base.tar.gz)**:
+ 12-layer, 768-hidden, 12-heads, 12M parameters
+* **[`ALBERT V2 Large`](https://storage.googleapis.com/cloud-tpu-checkpoints/albert/checkpoints/albert_v2_large.tar.gz)**:
+ 24-layer, 1024-hidden, 16-heads, 18M parameters
+* **[`ALBERT V2 XLarge`](https://storage.googleapis.com/cloud-tpu-checkpoints/albert/checkpoints/albert_v2_xlarge.tar.gz)**:
+ 24-layer, 2048-hidden, 32-heads, 60M parameters
+* **[`ALBERT V2 XXLarge`](https://storage.googleapis.com/cloud-tpu-checkpoints/albert/checkpoints/albert_v2_xxlarge.tar.gz)**:
+ 12-layer, 4096-hidden, 64-heads, 235M parameters
+
+We recommend hosting checkpoints in Google Cloud Storage buckets when you use
+Cloud GPUs/TPUs.
+
+### Restoring from Checkpoints
+
+`tf.train.Checkpoint` is used to manage model checkpoints in TF 2. To restore
+weights from provided pre-trained checkpoints, you can use the following code:
+
+```python
+init_checkpoint='the pretrained model checkpoint path.'
+model=tf.keras.Model() # Bert pre-trained model as feature extractor.
+checkpoint = tf.train.Checkpoint(model=model)
+checkpoint.restore(init_checkpoint)
+```
+
+Checkpoints featuring native serialized Keras models
+(i.e. model.load()/load_weights()) will be available soon.
+
+### Access to Pretrained hub modules.
+
+Pretrained tf.hub modules in TF 2.x SavedModel format can be found in the
+following links:
+
+* **[`ALBERT V2 Base`](https://tfhub.dev/tensorflow/albert_en_base/1)**:
+ 12-layer, 768-hidden, 12-heads, 12M parameters
+* **[`ALBERT V2 Large`](https://tfhub.dev/tensorflow/albert_en_large/1)**:
+ 24-layer, 1024-hidden, 16-heads, 18M parameters
+* **[`ALBERT V2 XLarge`](https://tfhub.dev/tensorflow/albert_en_xlarge/1)**:
+ 24-layer, 2048-hidden, 32-heads, 60M parameters
+* **[`ALBERT V2 XXLarge`](https://tfhub.dev/tensorflow/albert_en_xxlarge/1)**:
+ 12-layer, 4096-hidden, 64-heads, 235M parameters
+
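+As a minimal usage sketch (the URL below is just one of the modules listed
+above, and the zero-valued inputs are placeholders), a hub module can be loaded
+as a `hub.KerasLayer` and called on `[input_word_ids, input_mask,
+input_type_ids]`, returning `(pooled_output, sequence_output)`:
+
+```python
+import numpy as np
+import tensorflow_hub as hub
+
+albert_layer = hub.KerasLayer(
+    "https://tfhub.dev/tensorflow/albert_en_base/1", trainable=True)
+
+# Dummy token ids of shape (batch_size, seq_length), just to illustrate the call.
+dummy_ids = np.zeros((2, 10), dtype=np.int32)
+pooled_output, sequence_output = albert_layer([dummy_ids, dummy_ids, dummy_ids])
+print(pooled_output.shape, sequence_output.shape)  # (2, 768) and (2, 10, 768)
+```
+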
+## Set Up
+
+```shell
+export PYTHONPATH="$PYTHONPATH:/path/to/models"
+```
+
+Install `tf-nightly` to get latest updates:
+
+```shell
+pip install tf-nightly-gpu
+```
+
+With a TPU, GPU support is not necessary. First, you need to create a `tf-nightly`
+TPU with [ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu):
+
+```shell
+ctpu up -name <instance-name> --tf-version="nightly"
+```
+
+Second, you need to install TF 2 `tf-nightly` on your VM:
+
+```shell
+pip install tf-nightly
+```
+
+Warning: More detailed TPU-specific set-up instructions and tutorials should come
+along with the official TF 2.x release for TPU. Note that this repo is not
+officially supported by the Google Cloud TPU team until TF 2.1 is released.
+
+## Process Datasets
+
+### Pre-training
+
+Pre-training ALBERT using TF 2.x will come soon.
+For now, please use the [ALBERT research repo](https://github.com/google-research/ALBERT)
+to pretrain the model and convert the checkpoint to a TF 2.x compatible one using
+[tf2_albert_encoder_checkpoint_converter.py](tf2_albert_encoder_checkpoint_converter.py).
+
+
+
+### Fine-tuning
+
+To prepare the fine-tuning data for final model training, use the
+[`../data/create_finetuning_data.py`](../data/create_finetuning_data.py) script.
+Note that, unlike BERT models, which use a word-piece tokenizer, ALBERT models
+employ a sentence-piece tokenizer, so the flag `tokenizer_impl` has to be set
+to 'sentence_piece'.
+Resulting datasets in `tf_record` format and training meta data should later be
+passed to the training or evaluation scripts. The task-specific arguments are
+described in the following sections:
+
+* GLUE
+
+Users can download the
+[GLUE data](https://gluebenchmark.com/tasks) by running
+[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
+and unpack it to some directory `$GLUE_DIR`.
+
+```shell
+export GLUE_DIR=~/glue
+export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
+
+export TASK_NAME=MNLI
+export OUTPUT_DIR=gs://some_bucket/datasets
+python ../data/create_finetuning_data.py \
+ --input_data_dir=${GLUE_DIR}/${TASK_NAME}/ \
+ --sp_model_file=${ALBERT_DIR}/30k-clean.model \
+ --train_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_train.tf_record \
+ --eval_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_eval.tf_record \
+ --meta_data_file_path=${OUTPUT_DIR}/${TASK_NAME}_meta_data \
+ --fine_tuning_task_type=classification --max_seq_length=128 \
+ --classification_task_name=${TASK_NAME} \
+ --tokenizer_impl=sentence_piece
+```
+
+* SQUAD
+
+The [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/) contains
+detailed information about the SQuAD datasets and evaluation.
+
+The necessary files can be found here:
+
+* [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
+* [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
+* [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)
+* [train-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json)
+* [dev-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json)
+* [evaluate-v2.0.py](https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/)
+
+```shell
+export SQUAD_DIR=~/squad
+export SQUAD_VERSION=v1.1
+export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
+export OUTPUT_DIR=gs://some_bucket/datasets
+
+python ../data/create_finetuning_data.py \
+ --squad_data_file=${SQUAD_DIR}/train-${SQUAD_VERSION}.json \
+ --sp_model_file=${ALBERT_DIR}/30k-clean.model \
+ --train_data_output_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
+ --meta_data_file_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_meta_data \
+ --fine_tuning_task_type=squad --max_seq_length=384 \
+ --tokenizer_impl=sentence_piece
+```
+
+## Fine-tuning with ALBERT
+
+### Cloud GPUs and TPUs
+
+* Cloud Storage
+
+The unzipped pre-trained model files can also be found in the Google Cloud
+Storage folder `gs://cloud-tpu-checkpoints/albert/checkpoints`. For example:
+
+```shell
+export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
+export MODEL_DIR=gs://some_bucket/my_output_dir
+```
+
+Currently, users are able to access `tf-nightly` TPUs, and the following TPU
+script should run with `tf-nightly`.
+
+* GPU -> TPU
+
+Just add the following flags to `run_classifier.py` or `run_squad.py`:
+
+```shell
+ --distribution_strategy=tpu
+ --tpu=grpc://${TPU_IP_ADDRESS}:8470
+```
+
+### Sentence and Sentence-pair Classification Tasks
+
+This example code fine-tunes `albert_v2_base` on the Microsoft Research
+Paraphrase Corpus (MRPC), which only contains 3,600 examples and can be
+fine-tuned in a few minutes on most GPUs.
+
+We use the `albert_v2_base` as an example throughout the
+workflow.
+
+
+```shell
+export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
+export MODEL_DIR=gs://some_bucket/my_output_dir
+export GLUE_DIR=gs://some_bucket/datasets
+export TASK=MRPC
+
+python run_classifier.py \
+ --mode='train_and_eval' \
+ --input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
+ --train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
+ --eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
+ --bert_config_file=${ALBERT_DIR}/albert_config.json \
+ --init_checkpoint=${ALBERT_DIR}/bert_model.ckpt \
+ --train_batch_size=4 \
+ --eval_batch_size=4 \
+ --steps_per_loop=1 \
+ --learning_rate=2e-5 \
+ --num_train_epochs=3 \
+ --model_dir=${MODEL_DIR} \
+ --distribution_strategy=mirrored
+```
+
+Alternatively, instead of specifying `init_checkpoint`, you can specify
+`hub_module_url` to employ a pretrained ALBERT hub module, e.g.,
+`--hub_module_url=https://tfhub.dev/tensorflow/albert_en_base/1`.
+
+To use a TPU, you only need to switch the distribution strategy type to `tpu`,
+provide the TPU information, and use remote storage for model checkpoints.
+
+```shell
+export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
+export TPU_IP_ADDRESS='???'
+export MODEL_DIR=gs://some_bucket/my_output_dir
+export GLUE_DIR=gs://some_bucket/datasets
+
+python run_classifier.py \
+ --mode='train_and_eval' \
+ --input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
+ --train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
+ --eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
+ --bert_config_file=$ALBERT_DIR/albert_config.json \
+ --init_checkpoint=$ALBERT_DIR/bert_model.ckpt \
+ --train_batch_size=32 \
+ --eval_batch_size=32 \
+ --learning_rate=2e-5 \
+ --num_train_epochs=3 \
+ --model_dir=${MODEL_DIR} \
+ --distribution_strategy=tpu \
+ --tpu=grpc://${TPU_IP_ADDRESS}:8470
+```
+
+### SQuAD 1.1
+
+The Stanford Question Answering Dataset (SQuAD) is a popular question answering
+benchmark dataset. See more in [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/).
+
+We use the `albert_v2_base` as an example throughout the
+workflow.
+
+```shell
+export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
+export SQUAD_DIR=gs://some_bucket/datasets
+export MODEL_DIR=gs://some_bucket/my_output_dir
+export SQUAD_VERSION=v1.1
+
+python run_squad.py \
+ --input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
+ --train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
+ --predict_file=${SQUAD_DIR}/dev-v1.1.json \
+ --sp_model_file=${ALBERT_DIR}/30k-clean.model \
+ --bert_config_file=$ALBERT_DIR/albert_config.json \
+ --init_checkpoint=$ALBERT_DIR/bert_model.ckpt \
+ --train_batch_size=4 \
+ --predict_batch_size=4 \
+ --learning_rate=8e-5 \
+ --num_train_epochs=2 \
+ --model_dir=${MODEL_DIR} \
+ --distribution_strategy=mirrored
+```
+
+Similarly, you can replace the `init_checkpoint` flag with `hub_module_url` to
+specify a hub module path.
+
+To use a TPU, you need to switch the distribution strategy type to `tpu` and
+provide the TPU information.
+
+```shell
+export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
+export TPU_IP_ADDRESS='???'
+export MODEL_DIR=gs://some_bucket/my_output_dir
+export SQUAD_DIR=gs://some_bucket/datasets
+export SQUAD_VERSION=v1.1
+
+python run_squad.py \
+ --input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
+ --train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
+ --predict_file=${SQUAD_DIR}/dev-v1.1.json \
+ --sp_model_file=${ALBERT_DIR}/30k-clean.model \
+ --bert_config_file=$ALBERT_DIR/albert_config.json \
+ --init_checkpoint=$ALBERT_DIR/bert_model.ckpt \
+ --train_batch_size=32 \
+ --learning_rate=8e-5 \
+ --num_train_epochs=2 \
+ --model_dir=${MODEL_DIR} \
+ --distribution_strategy=tpu \
+ --tpu=grpc://${TPU_IP_ADDRESS}:8470
+```
+
+The dev set predictions will be saved into a file called `predictions.json` in
+the `model_dir`:
+
+```shell
+python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json ./squad/predictions.json
+```
diff --git a/models/official/nlp/albert/__init__.py b/models/official/nlp/albert/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/nlp/albert/configs.py b/models/official/nlp/albert/configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..10fbb79bd50cc224f4192819bfb428cde357ef3c
--- /dev/null
+++ b/models/official/nlp/albert/configs.py
@@ -0,0 +1,58 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The ALBERT configurations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from official.nlp.bert import configs
+
+
+class AlbertConfig(configs.BertConfig):
+ """Configuration for `ALBERT`."""
+
+ def __init__(self,
+ num_hidden_groups=1,
+ inner_group_num=1,
+ **kwargs):
+ """Constructs AlbertConfig.
+
+ Args:
+ num_hidden_groups: Number of groups for the hidden layers; parameters in
+ the same group are shared. Note that this value, and also the following
+ 'inner_group_num', has to be 1 for now, because all released ALBERT
+ models set them to 1. We may support arbitrary valid values in the future.
+ inner_group_num: Number of inner repetition of attention and ffn.
+ **kwargs: The remaining arguments are the same as above 'BertConfig'.
+ """
+ super(AlbertConfig, self).__init__(**kwargs)
+
+ # TODO(chendouble): 'inner_group_num' and 'num_hidden_groups' are always 1
+ # in the released ALBERT. Support other values in AlbertTransformerEncoder
+ # if needed.
+ if inner_group_num != 1 or num_hidden_groups != 1:
+ raise ValueError("We only support 'inner_group_num' and "
+ "'num_hidden_groups' as 1.")
+
+ @classmethod
+ def from_dict(cls, json_object):
+ """Constructs a `AlbertConfig` from a Python dictionary of parameters."""
+ config = AlbertConfig(vocab_size=None)
+ for (key, value) in six.iteritems(json_object):
+ config.__dict__[key] = value
+ return config
diff --git a/models/official/nlp/albert/export_albert_tfhub.py b/models/official/nlp/albert/export_albert_tfhub.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a1af1a17735c5f0b995bb5e431fe143ffffa1d1
--- /dev/null
+++ b/models/official/nlp/albert/export_albert_tfhub.py
@@ -0,0 +1,88 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A script to export the ALBERT core model as a TF-Hub SavedModel."""
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+from absl import app
+from absl import flags
+import tensorflow as tf
+from typing import Text, Tuple
+
+from official.nlp.albert import configs
+from official.nlp.bert import bert_models
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("albert_config_file", None,
+ "Albert configuration file to define core albert layers.")
+flags.DEFINE_string("model_checkpoint_path", None,
+ "File path to TF model checkpoint.")
+flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
+flags.DEFINE_string(
+ "sp_model_file", None,
+ "The sentence piece model file that the ALBERT model was trained on.")
+
+
+def create_albert_model(
+ albert_config: configs.AlbertConfig) -> Tuple[tf.keras.Model, tf.keras.Model]:
+ """Creates an ALBERT keras core model from ALBERT configuration.
+
+ Args:
+ albert_config: An `AlbertConfig` to create the core model.
+
+ Returns:
+ A keras model.
+ """
+ # Adds input layers just as placeholders.
+ input_word_ids = tf.keras.layers.Input(
+ shape=(None,), dtype=tf.int32, name="input_word_ids")
+ input_mask = tf.keras.layers.Input(
+ shape=(None,), dtype=tf.int32, name="input_mask")
+ input_type_ids = tf.keras.layers.Input(
+ shape=(None,), dtype=tf.int32, name="input_type_ids")
+ transformer_encoder = bert_models.get_transformer_encoder(
+ albert_config, sequence_length=None)
+ sequence_output, pooled_output = transformer_encoder(
+ [input_word_ids, input_mask, input_type_ids])
+ # To keep consistent with legacy hub modules, the outputs are
+ # "pooled_output" and "sequence_output".
+ return tf.keras.Model(
+ inputs=[input_word_ids, input_mask, input_type_ids],
+ outputs=[pooled_output, sequence_output]), transformer_encoder
+
+
+def export_albert_tfhub(albert_config: configs.AlbertConfig,
+ model_checkpoint_path: Text, hub_destination: Text,
+ sp_model_file: Text):
+ """Restores a tf.keras.Model and saves for TF-Hub."""
+ core_model, encoder = create_albert_model(albert_config)
+ checkpoint = tf.train.Checkpoint(model=encoder)
+ checkpoint.restore(model_checkpoint_path).assert_consumed()
+ core_model.sp_model_file = tf.saved_model.Asset(sp_model_file)
+ core_model.save(hub_destination, include_optimizer=False, save_format="tf")
+
+
+def main(_):
+ albert_config = configs.AlbertConfig.from_json_file(
+ FLAGS.albert_config_file)
+ export_albert_tfhub(albert_config, FLAGS.model_checkpoint_path,
+ FLAGS.export_path, FLAGS.sp_model_file)
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/models/official/nlp/albert/export_albert_tfhub_test.py b/models/official/nlp/albert/export_albert_tfhub_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..4973090365b7ce6527ef1e4458e3f334ea1a5d1b
--- /dev/null
+++ b/models/official/nlp/albert/export_albert_tfhub_test.py
@@ -0,0 +1,89 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests official.nlp.albert.export_albert_tfhub."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+
+import tensorflow as tf
+import tensorflow_hub as hub
+
+from official.nlp.albert import configs
+from official.nlp.albert import export_albert_tfhub
+
+
+class ExportAlbertTfhubTest(tf.test.TestCase):
+
+ def test_export_albert_tfhub(self):
+ # Exports a savedmodel for TF-Hub
+ albert_config = configs.AlbertConfig(
+ vocab_size=100,
+ embedding_size=8,
+ hidden_size=16,
+ intermediate_size=32,
+ max_position_embeddings=128,
+ num_attention_heads=2,
+ num_hidden_layers=1)
+ bert_model, encoder = export_albert_tfhub.create_albert_model(albert_config)
+ model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
+ checkpoint = tf.train.Checkpoint(model=encoder)
+ checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
+ model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
+
+ sp_model_file = os.path.join(self.get_temp_dir(), "sp_tokenizer.model")
+ with tf.io.gfile.GFile(sp_model_file, "w") as f:
+ f.write("dummy content")
+
+ hub_destination = os.path.join(self.get_temp_dir(), "hub")
+ export_albert_tfhub.export_albert_tfhub(
+ albert_config,
+ model_checkpoint_path,
+ hub_destination,
+ sp_model_file=sp_model_file)
+
+ # Restores a hub KerasLayer.
+ hub_layer = hub.KerasLayer(hub_destination, trainable=True)
+
+ if hasattr(hub_layer, "resolved_object"):
+ with tf.io.gfile.GFile(
+ hub_layer.resolved_object.sp_model_file.asset_path.numpy()) as f:
+ self.assertEqual("dummy content", f.read())
+ # Checks the hub KerasLayer.
+ for source_weight, hub_weight in zip(bert_model.trainable_weights,
+ hub_layer.trainable_weights):
+ self.assertAllClose(source_weight.numpy(), hub_weight.numpy())
+
+ dummy_ids = np.zeros((2, 10), dtype=np.int32)
+ hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
+ source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])
+
+ # The outputs of hub module are "pooled_output" and "sequence_output",
+ # while the outputs of encoder is in reversed order, i.e.,
+ # "sequence_output" and "pooled_output".
+ encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
+ self.assertEqual(hub_outputs[0].shape, (2, 16))
+ self.assertEqual(hub_outputs[1].shape, (2, 10, 16))
+ for source_output, hub_output, encoder_output in zip(
+ source_outputs, hub_outputs, encoder_outputs):
+ self.assertAllClose(source_output.numpy(), hub_output.numpy())
+ self.assertAllClose(source_output.numpy(), encoder_output.numpy())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/albert/run_classifier.py b/models/official/nlp/albert/run_classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe72ff880f61c99e304bf089ef4ed0d75bfc349b
--- /dev/null
+++ b/models/official/nlp/albert/run_classifier.py
@@ -0,0 +1,67 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ALBERT classification finetuning runner in tf2.x."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+
+from absl import app
+from absl import flags
+import tensorflow as tf
+
+from official.nlp.albert import configs as albert_configs
+from official.nlp.bert import run_classifier as run_classifier_bert
+from official.utils.misc import distribution_utils
+
+FLAGS = flags.FLAGS
+
+
+def main(_):
+ with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
+ input_meta_data = json.loads(reader.read().decode('utf-8'))
+
+ if not FLAGS.model_dir:
+ FLAGS.model_dir = '/tmp/bert20/'
+
+ strategy = distribution_utils.get_distribution_strategy(
+ distribution_strategy=FLAGS.distribution_strategy,
+ num_gpus=FLAGS.num_gpus,
+ tpu_address=FLAGS.tpu)
+ max_seq_length = input_meta_data['max_seq_length']
+ train_input_fn = run_classifier_bert.get_dataset_fn(
+ FLAGS.train_data_path,
+ max_seq_length,
+ FLAGS.train_batch_size,
+ is_training=True)
+ eval_input_fn = run_classifier_bert.get_dataset_fn(
+ FLAGS.eval_data_path,
+ max_seq_length,
+ FLAGS.eval_batch_size,
+ is_training=False)
+
+ albert_config = albert_configs.AlbertConfig.from_json_file(
+ FLAGS.bert_config_file)
+ run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
+ train_input_fn, eval_input_fn)
+
+
+if __name__ == '__main__':
+ flags.mark_flag_as_required('bert_config_file')
+ flags.mark_flag_as_required('input_meta_data_path')
+ flags.mark_flag_as_required('model_dir')
+ app.run(main)
diff --git a/models/official/nlp/albert/run_squad.py b/models/official/nlp/albert/run_squad.py
new file mode 100644
index 0000000000000000000000000000000000000000..28a171a3f4a377ab174418c3b466b22680ad5734
--- /dev/null
+++ b/models/official/nlp/albert/run_squad.py
@@ -0,0 +1,137 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Run ALBERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+import time
+
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf
+
+from official.nlp.albert import configs as albert_configs
+from official.nlp.bert import run_squad_helper
+from official.nlp.bert import tokenization
+from official.nlp.data import squad_lib_sp
+from official.utils.misc import distribution_utils
+
+flags.DEFINE_string(
+ 'sp_model_file', None,
+ 'The path to the sentence piece model. Used by sentence piece tokenizer '
+ 'employed by ALBERT.')
+
+# More flags can be found in run_squad_helper.
+run_squad_helper.define_common_squad_flags()
+
+FLAGS = flags.FLAGS
+
+
+def train_squad(strategy,
+ input_meta_data,
+ custom_callbacks=None,
+ run_eagerly=False):
+ """Runs bert squad training."""
+ bert_config = albert_configs.AlbertConfig.from_json_file(
+ FLAGS.bert_config_file)
+ run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
+ custom_callbacks, run_eagerly)
+
+
+def predict_squad(strategy, input_meta_data):
+ """Makes predictions for the squad dataset."""
+ bert_config = albert_configs.AlbertConfig.from_json_file(
+ FLAGS.bert_config_file)
+ tokenizer = tokenization.FullSentencePieceTokenizer(
+ sp_model_file=FLAGS.sp_model_file)
+
+ run_squad_helper.predict_squad(strategy, input_meta_data, tokenizer,
+ bert_config, squad_lib_sp)
+
+
+def eval_squad(strategy, input_meta_data):
+ """Evaluate on the squad dataset."""
+ bert_config = albert_configs.AlbertConfig.from_json_file(
+ FLAGS.bert_config_file)
+ tokenizer = tokenization.FullSentencePieceTokenizer(
+ sp_model_file=FLAGS.sp_model_file)
+
+ eval_metrics = run_squad_helper.eval_squad(
+ strategy, input_meta_data, tokenizer, bert_config, squad_lib_sp)
+ return eval_metrics
+
+
+def export_squad(model_export_path, input_meta_data):
+ """Exports a trained model as a `SavedModel` for inference.
+
+ Args:
+ model_export_path: a string specifying the path to the SavedModel directory.
+ input_meta_data: dictionary containing meta data about input and model.
+
+ Raises:
+ An error if the export path is not specified (an empty string or None).
+ """
+ bert_config = albert_configs.AlbertConfig.from_json_file(
+ FLAGS.bert_config_file)
+ run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)
+
+
+def main(_):
+ with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
+ input_meta_data = json.loads(reader.read().decode('utf-8'))
+
+ if FLAGS.mode == 'export_only':
+ export_squad(FLAGS.model_export_path, input_meta_data)
+ return
+
+ # Configures cluster spec for multi-worker distribution strategy.
+ if FLAGS.num_gpus > 0:
+ _ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
+ FLAGS.task_index)
+ strategy = distribution_utils.get_distribution_strategy(
+ distribution_strategy=FLAGS.distribution_strategy,
+ num_gpus=FLAGS.num_gpus,
+ all_reduce_alg=FLAGS.all_reduce_alg,
+ tpu_address=FLAGS.tpu)
+
+ if 'train' in FLAGS.mode:
+ train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
+ if 'predict' in FLAGS.mode:
+ predict_squad(strategy, input_meta_data)
+ if 'eval' in FLAGS.mode:
+ eval_metrics = eval_squad(strategy, input_meta_data)
+ f1_score = eval_metrics['final_f1']
+ logging.info('SQuAD eval F1-score: %f', f1_score)
+ summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
+ summary_writer = tf.summary.create_file_writer(summary_dir)
+ with summary_writer.as_default():
+ # TODO(lehou): write to the correct step number.
+ tf.summary.scalar('F1-score', f1_score, step=0)
+ summary_writer.flush()
+ # Also write eval_metrics to json file.
+ squad_lib_sp.write_to_json_files(
+ eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
+ time.sleep(60)
+
+
+if __name__ == '__main__':
+ flags.mark_flag_as_required('bert_config_file')
+ flags.mark_flag_as_required('model_dir')
+ app.run(main)
diff --git a/models/official/nlp/albert/tf2_albert_encoder_checkpoint_converter.py b/models/official/nlp/albert/tf2_albert_encoder_checkpoint_converter.py
new file mode 100644
index 0000000000000000000000000000000000000000..402bc1445bed575362598d09212d14d03b629179
--- /dev/null
+++ b/models/official/nlp/albert/tf2_albert_encoder_checkpoint_converter.py
@@ -0,0 +1,132 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A converter from a tf1 ALBERT encoder checkpoint to a tf2 encoder checkpoint.
+
+The conversion will yield an object-oriented checkpoint that can be used
+to restore an AlbertTransformerEncoder object.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import app
+from absl import flags
+
+import tensorflow as tf
+from official.modeling import activations
+from official.nlp.albert import configs
+from official.nlp.bert import tf1_checkpoint_converter_lib
+from official.nlp.modeling import networks
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("albert_config_file", None,
+ "Albert configuration file to define core bert layers.")
+flags.DEFINE_string(
+ "checkpoint_to_convert", None,
+ "Initial checkpoint from a pretrained BERT model core (that is, only the "
+ "BertModel, with no task heads.)")
+flags.DEFINE_string("converted_checkpoint_path", None,
+ "Name for the created object-based V2 checkpoint.")
+
+
+ALBERT_NAME_REPLACEMENTS = (
+ ("bert/encoder/", ""),
+ ("bert/", ""),
+ ("embeddings/word_embeddings", "word_embeddings/embeddings"),
+ ("embeddings/position_embeddings", "position_embedding/embeddings"),
+ ("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
+ ("embeddings/LayerNorm", "embeddings/layer_norm"),
+ ("embedding_hidden_mapping_in", "embedding_projection"),
+ ("group_0/inner_group_0/", ""),
+ ("attention_1/self", "self_attention"),
+ ("attention_1/output/dense", "self_attention/attention_output"),
+ ("LayerNorm/", "self_attention_layer_norm/"),
+ ("ffn_1/intermediate/dense", "intermediate"),
+ ("ffn_1/intermediate/output/dense", "output"),
+ ("LayerNorm_1/", "output_layer_norm/"),
+ ("pooler/dense", "pooler_transform"),
+ ("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
+ ("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
+ ("cls/seq_relationship/output_weights",
+ "predictions/transform/logits/kernel"),
+)
+
+
+def _create_albert_model(cfg):
+ """Creates a BERT keras core model from BERT configuration.
+
+ Args:
+ cfg: An `AlbertConfig` to create the core model.
+
+ Returns:
+ A keras model.
+ """
+ albert_encoder = networks.AlbertTransformerEncoder(
+ vocab_size=cfg.vocab_size,
+ hidden_size=cfg.hidden_size,
+ embedding_width=cfg.embedding_size,
+ num_layers=cfg.num_hidden_layers,
+ num_attention_heads=cfg.num_attention_heads,
+ intermediate_size=cfg.intermediate_size,
+ activation=activations.gelu,
+ dropout_rate=cfg.hidden_dropout_prob,
+ attention_dropout_rate=cfg.attention_probs_dropout_prob,
+ sequence_length=cfg.max_position_embeddings,
+ type_vocab_size=cfg.type_vocab_size,
+ initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=cfg.initializer_range))
+ return albert_encoder
+
+
+def convert_checkpoint(bert_config, output_path, v1_checkpoint):
+ """Converts a V1 checkpoint into an OO V2 checkpoint."""
+ output_dir, _ = os.path.split(output_path)
+
+ # Create a temporary V1 name-converted checkpoint in the output directory.
+ temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
+ temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
+ tf1_checkpoint_converter_lib.convert(
+ checkpoint_from_path=v1_checkpoint,
+ checkpoint_to_path=temporary_checkpoint,
+ num_heads=bert_config.num_attention_heads,
+ name_replacements=ALBERT_NAME_REPLACEMENTS,
+ permutations=tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
+ exclude_patterns=["adam", "Adam"])
+
+ # Create a V2 checkpoint from the temporary checkpoint.
+ model = _create_albert_model(bert_config)
+ tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
+ output_path)
+
+ # Clean up the temporary checkpoint, if it exists.
+ try:
+ tf.io.gfile.rmtree(temporary_checkpoint_dir)
+ except tf.errors.OpError:
+ # If it doesn't exist, we don't need to clean it up; continue.
+ pass
+
+
+def main(_):
+ output_path = FLAGS.converted_checkpoint_path
+ v1_checkpoint = FLAGS.checkpoint_to_convert
+ albert_config = configs.AlbertConfig.from_json_file(FLAGS.albert_config_file)
+ convert_checkpoint(albert_config, output_path, v1_checkpoint)
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/models/official/nlp/bert/README.md b/models/official/nlp/bert/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c26a87df520b9d9bb9cccefd515abc0bf4a399c7
--- /dev/null
+++ b/models/official/nlp/bert/README.md
@@ -0,0 +1,368 @@
+# BERT (Bidirectional Encoder Representations from Transformers)
+
+The academic paper which describes BERT in detail and provides full results on a
+number of tasks can be found here: https://arxiv.org/abs/1810.04805.
+
+This repository contains TensorFlow 2.x implementation for BERT.
+
+## Contents
+ * [Contents](#contents)
+ * [Pre-trained Models](#pre-trained-models)
+ * [Restoring from Checkpoints](#restoring-from-checkpoints)
+ * [Set Up](#set-up)
+ * [Process Datasets](#process-datasets)
+ * [Fine-tuning with BERT](#fine-tuning-with-bert)
+ * [Cloud GPUs and TPUs](#cloud-gpus-and-tpus)
+ * [Sentence and Sentence-pair Classification Tasks](#sentence-and-sentence-pair-classification-tasks)
+ * [SQuAD 1.1](#squad-1.1)
+
+
+## Pre-trained Models
+
+We released both checkpoints and tf.hub modules as the pretrained models for
+fine-tuning. They are TF 2.x compatible and are converted from the checkpoints
+released in the TF 1.x official BERT repository
+[google-research/bert](https://github.com/google-research/bert)
+in order to stay consistent with the BERT paper.
+
+
+### Access to Pretrained Checkpoints
+
+Pretrained checkpoints can be found in the following links:
+
+**Note: We have switched the BERT implementation
+to use Keras functional-style networks in [nlp/modeling](../modeling).
+The new checkpoints are:**
+
+* **[`BERT-Large, Uncased (Whole Word Masking)`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/wwm_uncased_L-24_H-1024_A-16.tar.gz)**:
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
+* **[`BERT-Large, Cased (Whole Word Masking)`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/wwm_cased_L-24_H-1024_A-16.tar.gz)**:
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
+* **[`BERT-Base, Uncased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12.tar.gz)**:
+ 12-layer, 768-hidden, 12-heads, 110M parameters
+* **[`BERT-Large, Uncased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16.tar.gz)**:
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
+* **[`BERT-Base, Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/cased_L-12_H-768_A-12.tar.gz)**:
+ 12-layer, 768-hidden, 12-heads , 110M parameters
+* **[`BERT-Large, Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/cased_L-24_H-1024_A-16.tar.gz)**:
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
+
+We recommend hosting checkpoints in Google Cloud Storage buckets when you use
+Cloud GPUs/TPUs.
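+
+For example (a minimal sketch; `gs://your-bucket` is a placeholder, and the archive
+is assumed to unpack into a directory of the same name):
+
+```shell
+# Download and unpack a pretrained checkpoint locally (URL from the list above).
+wget https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12.tar.gz
+tar -xvzf uncased_L-12_H-768_A-12.tar.gz
+
+# Copy the unpacked checkpoint into your own Cloud Storage bucket.
+gsutil cp -r uncased_L-12_H-768_A-12 gs://your-bucket/bert/uncased_L-12_H-768_A-12
+```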
+
+### Restoring from Checkpoints
+
+`tf.train.Checkpoint` is used to manage model checkpoints in TF 2. To restore
+weights from provided pre-trained checkpoints, you can use the following code:
+
+```python
+init_checkpoint='the pretrained model checkpoint path.'
+model=tf.keras.Model() # Bert pre-trained model as feature extractor.
+checkpoint = tf.train.Checkpoint(model=model)
+checkpoint.restore(init_checkpoint)
+```
+
+Checkpoints featuring natively serialized Keras models
+(i.e., loadable via `tf.keras.models.load_model()` / `model.load_weights()`) will be available soon.
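+
+When restoring only the encoder sub-model from a full pre-training checkpoint, it
+can be useful to check how much of the checkpoint was actually matched. A small
+optional sketch building on the snippet above (the status-assertion calls are
+standard `tf.train.Checkpoint` APIs):
+
+```python
+status = checkpoint.restore(init_checkpoint)
+# Asserts that every variable created so far was matched to a checkpoint value.
+status.assert_existing_objects_matched()
+# Alternatively, status.expect_partial() silences warnings about checkpoint
+# values that are intentionally left unrestored.
+```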
+
+### Access to Pretrained Hub Modules
+
+Pretrained tf.hub modules in TF 2.x SavedModel format can be found in the
+following links:
+
+* **[`BERT-Large, Uncased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_uncased_L-24_H-1024_A-16/1)**:
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
+* **[`BERT-Large, Cased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_cased_L-24_H-1024_A-16/1)**:
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
+* **[`BERT-Base, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1)**:
+ 12-layer, 768-hidden, 12-heads, 110M parameters
+* **[`BERT-Large, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/1)**:
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
+* **[`BERT-Base, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/1)**:
+ 12-layer, 768-hidden, 12-heads , 110M parameters
+* **[`BERT-Large, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-24_H-1024_A-16/1)**:
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
+* **[`BERT-Base, Multilingual Cased`](https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/1)**:
+ 104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
+* **[`BERT-Base, Chinese`](https://tfhub.dev/tensorflow/bert_zh_L-12_H-768_A-12/1)**:
+ Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads,
+ 110M parameters
+
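+These modules can be loaded directly as a Keras layer. A minimal sketch using the
+`BERT-Base, Uncased` module listed above (mirroring how `hub.KerasLayer` is used
+elsewhere in this repository):
+
+```python
+import tensorflow as tf
+import tensorflow_hub as hub
+
+max_seq_length = 128
+input_word_ids = tf.keras.layers.Input(
+    shape=(max_seq_length,), dtype=tf.int32, name="input_word_ids")
+input_mask = tf.keras.layers.Input(
+    shape=(max_seq_length,), dtype=tf.int32, name="input_mask")
+input_type_ids = tf.keras.layers.Input(
+    shape=(max_seq_length,), dtype=tf.int32, name="input_type_ids")
+
+bert_layer = hub.KerasLayer(
+    "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1",
+    trainable=True)
+pooled_output, sequence_output = bert_layer(
+    [input_word_ids, input_mask, input_type_ids])
+```
+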
+## Set Up
+
+```shell
+export PYTHONPATH="$PYTHONPATH:/path/to/models"
+```
+
+Install `tf-nightly` to get the latest updates:
+
+```shell
+pip install tf-nightly-gpu
+```
+
+With a TPU, GPU support is not necessary. First, create a `tf-nightly`
+TPU with the [ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu):
+
+```shell
+ctpu up -name <tpu-name> --tf-version="nightly"
+```
+
+Second, you need to install TF 2 `tf-nightly` on your VM:
+
+```shell
+pip install tf-nightly
+```
+
+## Process Datasets
+
+### Pre-training
+
+There is no change in how pre-training data is generated. Please use the script
+[`../data/create_pretraining_data.py`](../data/create_pretraining_data.py),
+which is essentially branched from the [BERT research repo](https://github.com/google-research/bert)
+and adapted to TF2 symbols and Python 3 compatibility, to produce the processed
+pre-training data.
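+
+For reference, a typical invocation looks like the following (a sketch only; the
+flag set is assumed to match the upstream BERT script, and `corpus.txt` is a
+placeholder for your raw text input):
+
+```shell
+python ../data/create_pretraining_data.py \
+  --input_file=./corpus.txt \
+  --output_file=${OUTPUT_DIR}/tf_examples.tfrecord \
+  --vocab_file=${BERT_DIR}/vocab.txt \
+  --do_lower_case=True \
+  --max_seq_length=128 \
+  --max_predictions_per_seq=20 \
+  --masked_lm_prob=0.15 \
+  --dupe_factor=5
+```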
+
+
+### Fine-tuning
+
+To prepare the fine-tuning data for final model training, use the
+[`../data/create_finetuning_data.py`](../data/create_finetuning_data.py) script.
+The resulting datasets in `tf_record` format and the training metadata should
+later be passed to the training or evaluation scripts. The task-specific
+arguments are described in the following sections:
+
+* GLUE
+
+Users can download the
+[GLUE data](https://gluebenchmark.com/tasks) by running
+[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
+and unpack it to some directory `$GLUE_DIR`.
+Also, users can download a [Pretrained Checkpoint](#access-to-pretrained-checkpoints) and place it in some directory `$BERT_DIR` instead of using the checkpoints on Google Cloud Storage.
+
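+For example, a sketch of fetching the MNLI data with that script (assuming the
+gist script's `--data_dir` and `--tasks` flags):
+
+```shell
+python download_glue_data.py --data_dir ~/glue --tasks MNLI
+```
+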
+```shell
+export GLUE_DIR=~/glue
+export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
+
+export TASK_NAME=MNLI
+export OUTPUT_DIR=gs://some_bucket/datasets
+python ../data/create_finetuning_data.py \
+ --input_data_dir=${GLUE_DIR}/${TASK_NAME}/ \
+ --vocab_file=${BERT_DIR}/vocab.txt \
+ --train_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_train.tf_record \
+ --eval_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_eval.tf_record \
+ --meta_data_file_path=${OUTPUT_DIR}/${TASK_NAME}_meta_data \
+ --fine_tuning_task_type=classification --max_seq_length=128 \
+ --classification_task_name=${TASK_NAME}
+```
+
+* SQUAD
+
+The [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/) contains
+detailed information about the SQuAD datasets and evaluation.
+
+The necessary files can be found here:
+
+* [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
+* [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
+* [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)
+* [train-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json)
+* [dev-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json)
+* [evaluate-v2.0.py](https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/)
+
+```shell
+export SQUAD_DIR=~/squad
+export SQUAD_VERSION=v1.1
+export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
+export OUTPUT_DIR=gs://some_bucket/datasets
+
+python ../data/create_finetuning_data.py \
+ --squad_data_file=${SQUAD_DIR}/train-${SQUAD_VERSION}.json \
+ --vocab_file=${BERT_DIR}/vocab.txt \
+ --train_data_output_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
+ --meta_data_file_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_meta_data \
+ --fine_tuning_task_type=squad --max_seq_length=384
+```
+
+## Fine-tuning with BERT
+
+### Cloud GPUs and TPUs
+
+* Cloud Storage
+
+The unzipped pre-trained model files can also be found in the Google Cloud
+Storage folder `gs://cloud-tpu-checkpoints/bert/keras_bert`. For example:
+
+```shell
+export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
+export MODEL_DIR=gs://some_bucket/my_output_dir
+```
+
+Currently, users can access `tf-nightly` TPUs, and the following TPU
+script should be run with `tf-nightly`.
+
+* GPU -> TPU
+
+Just add the following flags to `run_classifier.py` or `run_squad.py`:
+
+```shell
+ --distribution_strategy=tpu
+ --tpu=grpc://${TPU_IP_ADDRESS}:8470
+```
+
+### Sentence and Sentence-pair Classification Tasks
+
+This example code fine-tunes `BERT-Large` on the Microsoft Research Paraphrase
+Corpus (MRPC), which only contains 3,600 examples and can be fine-tuned in a
+few minutes on most GPUs.
+
+We use `BERT-Large` (uncased_L-24_H-1024_A-16) as an example throughout the
+workflow.
+For GPUs with 16GB of memory or less, you may try `BERT-Base`
+(uncased_L-12_H-768_A-12).
+
+```shell
+export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
+export MODEL_DIR=gs://some_bucket/my_output_dir
+export GLUE_DIR=gs://some_bucket/datasets
+export TASK=MRPC
+
+python run_classifier.py \
+ --mode='train_and_eval' \
+ --input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
+ --train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
+ --eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
+ --bert_config_file=${BERT_DIR}/bert_config.json \
+ --init_checkpoint=${BERT_DIR}/bert_model.ckpt \
+ --train_batch_size=4 \
+ --eval_batch_size=4 \
+ --steps_per_loop=1 \
+ --learning_rate=2e-5 \
+ --num_train_epochs=3 \
+ --model_dir=${MODEL_DIR} \
+ --distribution_strategy=mirrored
+```
+
+Alternatively, instead of specifying `init_checkpoint`, you can specify
+`hub_module_url` to employ a pretrained BERT hub module, e.g.,
+` --hub_module_url=https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/1`.
+
+After training a model, you can get predictions from the classifier by setting
+`--mode=predict` and passing the test set tfrecords to `--eval_data_path`.
+Output will be written to a file called `test_results.tsv` in the output folder.
+Each line contains the output for one sample, and the columns are the class
+probabilities.
+
+```shell
+python run_classifier.py \
+ --mode='predict' \
+ --input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
+ --eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
+ --bert_config_file=${BERT_DIR}/bert_config.json \
+ --eval_batch_size=4 \
+ --model_dir=${MODEL_DIR} \
+ --distribution_strategy=mirrored
+```
+
+To use a TPU, you only need to switch the distribution strategy type to `tpu`,
+provide the TPU information, and use remote storage for model checkpoints.
+
+```shell
+export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
+export TPU_IP_ADDRESS='???'
+export MODEL_DIR=gs://some_bucket/my_output_dir
+export GLUE_DIR=gs://some_bucket/datasets
+export TASK=MRPC
+
+python run_classifier.py \
+ --mode='train_and_eval' \
+ --input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
+ --train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
+ --eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
+ --bert_config_file=${BERT_DIR}/bert_config.json \
+ --init_checkpoint=${BERT_DIR}/bert_model.ckpt \
+ --train_batch_size=32 \
+ --eval_batch_size=32 \
+ --steps_per_loop=1000 \
+ --learning_rate=2e-5 \
+ --num_train_epochs=3 \
+ --model_dir=${MODEL_DIR} \
+ --distribution_strategy=tpu \
+ --tpu=grpc://${TPU_IP_ADDRESS}:8470
+```
+
+Note that we specify `steps_per_loop=1000` for TPU because running a loop of
+training steps inside a `tf.function` can significantly increase TPU utilization;
+callbacks will not be called inside the loop.
+
+### SQuAD 1.1
+
+The Stanford Question Answering Dataset (SQuAD) is a popular question answering
+benchmark dataset. See more in [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/).
+
+We use `BERT-Large` (uncased_L-24_H-1024_A-16) as an example throughout the
+workflow.
+For GPUs with 16GB of memory or less, you may try `BERT-Base`
+(uncased_L-12_H-768_A-12).
+
+```shell
+export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
+export SQUAD_DIR=gs://some_bucket/datasets
+export MODEL_DIR=gs://some_bucket/my_output_dir
+export SQUAD_VERSION=v1.1
+
+python run_squad.py \
+ --input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
+ --train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
+ --predict_file=${SQUAD_DIR}/dev-v1.1.json \
+ --vocab_file=${BERT_DIR}/vocab.txt \
+ --bert_config_file=${BERT_DIR}/bert_config.json \
+ --init_checkpoint=${BERT_DIR}/bert_model.ckpt \
+ --train_batch_size=4 \
+ --predict_batch_size=4 \
+ --learning_rate=8e-5 \
+ --num_train_epochs=2 \
+ --model_dir=${MODEL_DIR} \
+ --distribution_strategy=mirrored
+```
+
+Similarly, you can replace the `init_checkpoint` flag with `hub_module_url` to
+specify a hub module path.
+
+`run_squad.py` writes the predictions for `--predict_file` by default. If you set
+`--mode=predict` and provide the SQuAD test data, the script will generate
+the prediction JSON file.
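+
+A sketch of such a predict invocation, mirroring the flags from the training
+command above (the exact set of flags required in predict mode may vary):
+
+```shell
+python run_squad.py \
+  --mode='predict' \
+  --input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
+  --predict_file=${SQUAD_DIR}/dev-v1.1.json \
+  --vocab_file=${BERT_DIR}/vocab.txt \
+  --bert_config_file=${BERT_DIR}/bert_config.json \
+  --predict_batch_size=4 \
+  --model_dir=${MODEL_DIR} \
+  --distribution_strategy=mirrored
+```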
+
+To use a TPU, you need to switch the distribution strategy type to `tpu` and
+provide the TPU information.
+
+```shell
+export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
+export TPU_IP_ADDRESS='???'
+export MODEL_DIR=gs://some_bucket/my_output_dir
+export SQUAD_DIR=gs://some_bucket/datasets
+export SQUAD_VERSION=v1.1
+
+python run_squad.py \
+ --input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
+ --train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
+ --predict_file=${SQUAD_DIR}/dev-v1.1.json \
+ --vocab_file=${BERT_DIR}/vocab.txt \
+ --bert_config_file=${BERT_DIR}/bert_config.json \
+ --init_checkpoint=${BERT_DIR}/bert_model.ckpt \
+ --train_batch_size=32 \
+ --learning_rate=8e-5 \
+ --num_train_epochs=2 \
+ --model_dir=${MODEL_DIR} \
+ --distribution_strategy=tpu \
+ --tpu=grpc://${TPU_IP_ADDRESS}:8470
+```
+
+The dev set predictions will be saved into a file called `predictions.json` in
+the `model_dir`:
+
+```shell
+python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json ./squad/predictions.json
+```
+
+
diff --git a/models/official/nlp/bert/__init__.py b/models/official/nlp/bert/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/models/official/nlp/bert/__init__.py
@@ -0,0 +1 @@
+
diff --git a/models/official/nlp/bert/__pycache__/__init__.cpython-38.pyc b/models/official/nlp/bert/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6db7d3b5135d245ab0dcc000b240d54ffdf3a938
Binary files /dev/null and b/models/official/nlp/bert/__pycache__/__init__.cpython-38.pyc differ
diff --git a/models/official/nlp/bert/__pycache__/__init__.cpython-39.pyc b/models/official/nlp/bert/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8be831873bc9909e371fdc816ebe9760a643e71e
Binary files /dev/null and b/models/official/nlp/bert/__pycache__/__init__.cpython-39.pyc differ
diff --git a/models/official/nlp/bert/__pycache__/tokenization.cpython-38.pyc b/models/official/nlp/bert/__pycache__/tokenization.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c334a3013725666cc3cc808cb6ccbbc4003582a
Binary files /dev/null and b/models/official/nlp/bert/__pycache__/tokenization.cpython-38.pyc differ
diff --git a/models/official/nlp/bert/__pycache__/tokenization.cpython-39.pyc b/models/official/nlp/bert/__pycache__/tokenization.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d4a853456ee0666c019aeb5fdd3ab2b21f929e5
Binary files /dev/null and b/models/official/nlp/bert/__pycache__/tokenization.cpython-39.pyc differ
diff --git a/models/official/nlp/bert/bert_cloud_tpu.md b/models/official/nlp/bert/bert_cloud_tpu.md
new file mode 100644
index 0000000000000000000000000000000000000000..e5e6758a8bdc216744b7770d7eb8b5ff47408493
--- /dev/null
+++ b/models/official/nlp/bert/bert_cloud_tpu.md
@@ -0,0 +1,110 @@
+# BERT Fine-Tuning with Cloud TPU: Sentence and Sentence-Pair Classification Tasks (TF 2.1)
+This tutorial shows you how to train the Bidirectional Encoder Representations from Transformers (BERT) model on Cloud TPU.
+
+
+## Set up Cloud Storage and Compute Engine VM
+1. [Open a cloud shell window](https://console.cloud.google.com/?cloudshell=true&_ga=2.11844148.-1612541229.1552429951)
+2. Create a variable for the project's name:
+```
+export PROJECT_NAME=your-project-name
+```
+3. Configure the `gcloud` command-line tool to use the project in which you want to create the Cloud TPU.
+```
+gcloud config set project ${PROJECT_NAME}
+```
+4. Create a Cloud Storage bucket using the following command:
+```
+gsutil mb -p ${PROJECT_NAME} -c standard -l europe-west4 -b on gs://your-bucket-name
+```
+This Cloud Storage bucket stores the data you use to train your model and the training results.
+5. Launch a Compute Engine VM and Cloud TPU using the `ctpu up` command.
+```
+ctpu up --tpu-size=v3-8 \
+ --machine-type=n1-standard-8 \
+ --zone=europe-west4-a \
+ --tf-version=2.1 [optional flags: --project, --name]
+```
+6. The configuration you specified appears. Enter y to approve or n to cancel.
+7. When the `ctpu up` command has finished executing, verify that your shell prompt has changed from `username@project` to `username@tpuname`. This change shows that you are now logged into your Compute Engine VM.
+```
+gcloud compute ssh vm-name --zone=europe-west4-a
+(vm)$ export TPU_NAME=vm-name
+```
+As you continue these instructions, run each command that begins with `(vm)$` in your VM session window.
+
+## Prepare the Dataset
+1. From your Compute Engine virtual machine (VM), install the packages listed in `requirements.txt`.
+```
+(vm)$ cd /usr/share/models
+(vm)$ sudo pip3 install -r official/requirements.txt
+```
+2. Optional: download `download_glue_data.py`.
+
+This tutorial uses the General Language Understanding Evaluation (GLUE) benchmark to evaluate and analyze the performance of the model. The GLUE data is provided for this tutorial at gs://cloud-tpu-checkpoints/bert/classification.
+
+## Define parameter values
+Next, define several parameter values that are required when you train and evaluate your model:
+
+```
+(vm)$ export PYTHONPATH="$PYTHONPATH:/usr/share/tpu/models"
+(vm)$ export STORAGE_BUCKET=gs://your-bucket-name
+(vm)$ export BERT_BASE_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
+(vm)$ export MODEL_DIR=${STORAGE_BUCKET}/bert-output
+(vm)$ export GLUE_DIR=gs://cloud-tpu-checkpoints/bert/classification
+(vm)$ export TASK=mnli
+```
+
+## Train the model
+From your Compute Engine VM, run the following command.
+
+```
+(vm)$ python3 official/nlp/bert/run_classifier.py \
+ --mode='train_and_eval' \
+ --input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
+ --train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
+ --eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
+ --bert_config_file=$BERT_BASE_DIR/bert_config.json \
+ --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
+ --train_batch_size=32 \
+ --eval_batch_size=32 \
+ --learning_rate=2e-5 \
+ --num_train_epochs=3 \
+ --model_dir=${MODEL_DIR} \
+ --distribution_strategy=tpu \
+ --tpu=${TPU_NAME}
+```
+
+## Verify your results
+The training takes approximately 1 hour on a v3-8 TPU. When the script completes, you should see results similar to the following:
+```
+Training Summary:
+{'train_loss': 0.28142181038856506,
+'last_train_metrics': 0.9467429518699646,
+'eval_metrics': 0.8599063158035278,
+'total_training_steps': 36813}
+```
+
+## Clean up
+To avoid incurring charges to your GCP account for the resources used in this topic:
+1. Disconnect from the Compute Engine VM:
+```
+(vm)$ exit
+```
+2. In your Cloud Shell, run `ctpu delete` with the `--zone` flag you used when you set up the Cloud TPU to delete your Compute Engine VM and your Cloud TPU:
+```
+$ ctpu delete --zone=your-zone
+```
+3. Run `ctpu status`, specifying your zone, to make sure you have no instances allocated and to avoid unnecessary charges for TPU usage. The deletion might take several minutes. A response like the one below indicates there are no more allocated instances:
+```
+$ ctpu status --zone=your-zone
+```
+4. Run `gsutil` as shown, replacing `your-bucket` with the name of the Cloud Storage bucket you created for this tutorial:
+```
+$ gsutil rm -r gs://your-bucket
+```
+
+
+
+
+
+
diff --git a/models/official/nlp/bert/bert_models.py b/models/official/nlp/bert/bert_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d16150d0c353e6626b911b32c9961c4712c8aed
--- /dev/null
+++ b/models/official/nlp/bert/bert_models.py
@@ -0,0 +1,371 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""BERT models that are compatible with TF 2.0."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gin
+import tensorflow as tf
+import tensorflow_hub as hub
+
+from official.modeling import tf_utils
+from official.nlp.albert import configs as albert_configs
+from official.nlp.bert import configs
+from official.nlp.modeling import models
+from official.nlp.modeling import networks
+
+
+class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
+ """Returns layer that computes custom loss and metrics for pretraining."""
+
+ def __init__(self, vocab_size, **kwargs):
+ super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
+ self._vocab_size = vocab_size
+ self.config = {
+ 'vocab_size': vocab_size,
+ }
+
+ def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
+ lm_example_loss, sentence_output, sentence_labels,
+ next_sentence_loss):
+ """Adds metrics."""
+ masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
+ lm_labels, lm_output)
+ numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
+ denominator = tf.reduce_sum(lm_label_weights) + 1e-5
+ masked_lm_accuracy = numerator / denominator
+ self.add_metric(
+ masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')
+
+ self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')
+
+ if sentence_labels is not None:
+ next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
+ sentence_labels, sentence_output)
+ self.add_metric(
+ next_sentence_accuracy,
+ name='next_sentence_accuracy',
+ aggregation='mean')
+
+ if next_sentence_loss is not None:
+ self.add_metric(
+ next_sentence_loss, name='next_sentence_loss', aggregation='mean')
+
+ def call(self,
+ lm_output_logits,
+ sentence_output_logits,
+ lm_label_ids,
+ lm_label_weights,
+ sentence_labels=None):
+ """Implements call() for the layer."""
+ lm_label_weights = tf.cast(lm_label_weights, tf.float32)
+ lm_output_logits = tf.cast(lm_output_logits, tf.float32)
+
+ lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
+ lm_label_ids, lm_output_logits, from_logits=True)
+ lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
+ lm_denominator_loss = tf.reduce_sum(lm_label_weights)
+ mask_label_loss = tf.math.divide_no_nan(lm_numerator_loss,
+ lm_denominator_loss)
+
+ if sentence_labels is not None:
+ sentence_output_logits = tf.cast(sentence_output_logits, tf.float32)
+ sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
+ sentence_labels, sentence_output_logits, from_logits=True)
+ sentence_loss = tf.reduce_mean(sentence_loss)
+ loss = mask_label_loss + sentence_loss
+ else:
+ sentence_loss = None
+ loss = mask_label_loss
+
+ batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
+ # TODO(hongkuny): Avoids the hack and switches add_loss.
+ final_loss = tf.fill(batch_shape, loss)
+
+ self._add_metrics(lm_output_logits, lm_label_ids, lm_label_weights,
+ mask_label_loss, sentence_output_logits, sentence_labels,
+ sentence_loss)
+ return final_loss
+
+
+@gin.configurable
+def get_transformer_encoder(bert_config,
+ sequence_length,
+ transformer_encoder_cls=None,
+ output_range=None):
+ """Gets a 'TransformerEncoder' object.
+
+ Args:
+ bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
+ sequence_length: Maximum sequence length of the training data.
+ transformer_encoder_cls: A EncoderScaffold class. If it is None, uses the
+ default BERT encoder implementation.
+ output_range: the sequence output range, [0, output_range). Default setting
+ is to return the entire sequence output.
+
+ Returns:
+ A networks.TransformerEncoder object.
+ """
+ if transformer_encoder_cls is not None:
+ # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
+ embedding_cfg = dict(
+ vocab_size=bert_config.vocab_size,
+ type_vocab_size=bert_config.type_vocab_size,
+ hidden_size=bert_config.hidden_size,
+ seq_length=sequence_length,
+ max_seq_length=bert_config.max_position_embeddings,
+ initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=bert_config.initializer_range),
+ dropout_rate=bert_config.hidden_dropout_prob,
+ )
+ hidden_cfg = dict(
+ num_attention_heads=bert_config.num_attention_heads,
+ intermediate_size=bert_config.intermediate_size,
+ intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
+ dropout_rate=bert_config.hidden_dropout_prob,
+ attention_dropout_rate=bert_config.attention_probs_dropout_prob,
+ kernel_initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=bert_config.initializer_range),
+ )
+ kwargs = dict(
+ embedding_cfg=embedding_cfg,
+ hidden_cfg=hidden_cfg,
+ num_hidden_instances=bert_config.num_hidden_layers,
+ pooled_output_dim=bert_config.hidden_size,
+ pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=bert_config.initializer_range))
+
+ # Relies on gin configuration to define the Transformer encoder arguments.
+ return transformer_encoder_cls(**kwargs)
+
+ kwargs = dict(
+ vocab_size=bert_config.vocab_size,
+ hidden_size=bert_config.hidden_size,
+ num_layers=bert_config.num_hidden_layers,
+ num_attention_heads=bert_config.num_attention_heads,
+ intermediate_size=bert_config.intermediate_size,
+ activation=tf_utils.get_activation(bert_config.hidden_act),
+ dropout_rate=bert_config.hidden_dropout_prob,
+ attention_dropout_rate=bert_config.attention_probs_dropout_prob,
+ sequence_length=sequence_length,
+ max_sequence_length=bert_config.max_position_embeddings,
+ type_vocab_size=bert_config.type_vocab_size,
+ embedding_width=bert_config.embedding_size,
+ initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=bert_config.initializer_range))
+ if isinstance(bert_config, albert_configs.AlbertConfig):
+ return networks.AlbertTransformerEncoder(**kwargs)
+ else:
+ assert isinstance(bert_config, configs.BertConfig)
+ kwargs['output_range'] = output_range
+ return networks.TransformerEncoder(**kwargs)
+
+
+def pretrain_model(bert_config,
+ seq_length,
+ max_predictions_per_seq,
+ initializer=None,
+ use_next_sentence_label=True,
+ return_core_pretrainer_model=False):
+ """Returns model to be used for pre-training.
+
+ Args:
+ bert_config: Configuration that defines the core BERT model.
+ seq_length: Maximum sequence length of the training data.
+ max_predictions_per_seq: Maximum number of tokens in sequence to mask out
+ and use for pretraining.
+ initializer: Initializer for weights in BertPretrainer.
+ use_next_sentence_label: Whether to use the next sentence label.
+ return_core_pretrainer_model: Whether to also return the `BertPretrainer`
+ object.
+
+ Returns:
+ A Tuple of (1) Pretraining model, (2) core BERT submodel from which to
+ save weights after pretraining, and (3) optional core `BertPretrainer`
+ object if argument `return_core_pretrainer_model` is True.
+ """
+ input_word_ids = tf.keras.layers.Input(
+ shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
+ input_mask = tf.keras.layers.Input(
+ shape=(seq_length,), name='input_mask', dtype=tf.int32)
+ input_type_ids = tf.keras.layers.Input(
+ shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
+ masked_lm_positions = tf.keras.layers.Input(
+ shape=(max_predictions_per_seq,),
+ name='masked_lm_positions',
+ dtype=tf.int32)
+ masked_lm_ids = tf.keras.layers.Input(
+ shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
+ masked_lm_weights = tf.keras.layers.Input(
+ shape=(max_predictions_per_seq,),
+ name='masked_lm_weights',
+ dtype=tf.int32)
+
+ if use_next_sentence_label:
+ next_sentence_labels = tf.keras.layers.Input(
+ shape=(1,), name='next_sentence_labels', dtype=tf.int32)
+ else:
+ next_sentence_labels = None
+
+ transformer_encoder = get_transformer_encoder(bert_config, seq_length)
+ if initializer is None:
+ initializer = tf.keras.initializers.TruncatedNormal(
+ stddev=bert_config.initializer_range)
+ pretrainer_model = models.BertPretrainer(
+ network=transformer_encoder,
+ embedding_table=transformer_encoder.get_embedding_table(),
+ num_classes=2, # The next sentence prediction label has two classes.
+ activation=tf_utils.get_activation(bert_config.hidden_act),
+ num_token_predictions=max_predictions_per_seq,
+ initializer=initializer,
+ output='logits')
+
+ outputs = pretrainer_model(
+ [input_word_ids, input_mask, input_type_ids, masked_lm_positions])
+ lm_output = outputs['masked_lm']
+ sentence_output = outputs['classification']
+ pretrain_loss_layer = BertPretrainLossAndMetricLayer(
+ vocab_size=bert_config.vocab_size)
+ output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
+ masked_lm_weights, next_sentence_labels)
+ inputs = {
+ 'input_word_ids': input_word_ids,
+ 'input_mask': input_mask,
+ 'input_type_ids': input_type_ids,
+ 'masked_lm_positions': masked_lm_positions,
+ 'masked_lm_ids': masked_lm_ids,
+ 'masked_lm_weights': masked_lm_weights,
+ }
+ if use_next_sentence_label:
+ inputs['next_sentence_labels'] = next_sentence_labels
+
+ keras_model = tf.keras.Model(inputs=inputs, outputs=output_loss)
+ if return_core_pretrainer_model:
+ return keras_model, transformer_encoder, pretrainer_model
+ else:
+ return keras_model, transformer_encoder
+
+
+def squad_model(bert_config,
+ max_seq_length,
+ initializer=None,
+ hub_module_url=None,
+ hub_module_trainable=True):
+ """Returns BERT Squad model along with core BERT model to import weights.
+
+ Args:
+ bert_config: BertConfig, the config defines the core Bert model.
+ max_seq_length: integer, the maximum input sequence length.
+ initializer: Initializer for the final dense layer in the span labeler.
+ Defaulted to TruncatedNormal initializer.
+ hub_module_url: TF-Hub path/url to Bert module.
+ hub_module_trainable: True to finetune layers in the hub module.
+
+ Returns:
+ A tuple of (1) keras model that outputs start logits and end logits and
+ (2) the core BERT transformer encoder.
+ """
+ if initializer is None:
+ initializer = tf.keras.initializers.TruncatedNormal(
+ stddev=bert_config.initializer_range)
+ if not hub_module_url:
+ bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
+ return models.BertSpanLabeler(
+ network=bert_encoder, initializer=initializer), bert_encoder
+
+ input_word_ids = tf.keras.layers.Input(
+ shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
+ input_mask = tf.keras.layers.Input(
+ shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
+ input_type_ids = tf.keras.layers.Input(
+ shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
+ core_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
+ pooled_output, sequence_output = core_model(
+ [input_word_ids, input_mask, input_type_ids])
+ bert_encoder = tf.keras.Model(
+ inputs={
+ 'input_word_ids': input_word_ids,
+ 'input_mask': input_mask,
+ 'input_type_ids': input_type_ids,
+ },
+ outputs=[sequence_output, pooled_output],
+ name='core_model')
+ return models.BertSpanLabeler(
+ network=bert_encoder, initializer=initializer), bert_encoder
+
+
+def classifier_model(bert_config,
+ num_labels,
+ max_seq_length=None,
+ final_layer_initializer=None,
+ hub_module_url=None,
+ hub_module_trainable=True):
+ """BERT classifier model in functional API style.
+
+ Construct a Keras model for predicting `num_labels` outputs from an input with
+ maximum sequence length `max_seq_length`.
+
+ Args:
+ bert_config: BertConfig or AlbertConfig, the config defines the core BERT or
+ ALBERT model.
+ num_labels: integer, the number of classes.
+ max_seq_length: integer, the maximum input sequence length.
+ final_layer_initializer: Initializer for final dense layer. Defaulted
+ TruncatedNormal initializer.
+ hub_module_url: TF-Hub path/url to Bert module.
+ hub_module_trainable: True to finetune layers in the hub module.
+
+ Returns:
+ Combined prediction model (words, mask, type) -> (one-hot labels)
+ BERT sub-model (words, mask, type) -> (bert_outputs)
+ """
+ if final_layer_initializer is not None:
+ initializer = final_layer_initializer
+ else:
+ initializer = tf.keras.initializers.TruncatedNormal(
+ stddev=bert_config.initializer_range)
+
+ if not hub_module_url:
+ bert_encoder = get_transformer_encoder(
+ bert_config, max_seq_length, output_range=1)
+ return models.BertClassifier(
+ bert_encoder,
+ num_classes=num_labels,
+ dropout_rate=bert_config.hidden_dropout_prob,
+ initializer=initializer), bert_encoder
+
+ input_word_ids = tf.keras.layers.Input(
+ shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
+ input_mask = tf.keras.layers.Input(
+ shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
+ input_type_ids = tf.keras.layers.Input(
+ shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
+ bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
+ pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
+ output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
+ pooled_output)
+
+ output = tf.keras.layers.Dense(
+ num_labels, kernel_initializer=initializer, name='output')(
+ output)
+ return tf.keras.Model(
+ inputs={
+ 'input_word_ids': input_word_ids,
+ 'input_mask': input_mask,
+ 'input_type_ids': input_type_ids
+ },
+ outputs=output), bert_model
diff --git a/models/official/nlp/bert/bert_models_test.py b/models/official/nlp/bert/bert_models_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..93763b45bfc53c5d32de2df7f7f0f72894e9556f
--- /dev/null
+++ b/models/official/nlp/bert/bert_models_test.py
@@ -0,0 +1,114 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from official.nlp.bert import bert_models
+from official.nlp.bert import configs as bert_configs
+from official.nlp.modeling import networks
+
+
+class BertModelsTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(BertModelsTest, self).setUp()
+ self._bert_test_config = bert_configs.BertConfig(
+ attention_probs_dropout_prob=0.0,
+ hidden_act='gelu',
+ hidden_dropout_prob=0.0,
+ hidden_size=16,
+ initializer_range=0.02,
+ intermediate_size=32,
+ max_position_embeddings=128,
+ num_attention_heads=2,
+ num_hidden_layers=2,
+ type_vocab_size=2,
+ vocab_size=30522)
+
+ def test_pretrain_model(self):
+ model, encoder = bert_models.pretrain_model(
+ self._bert_test_config,
+ seq_length=5,
+ max_predictions_per_seq=2,
+ initializer=None,
+ use_next_sentence_label=True)
+ self.assertIsInstance(model, tf.keras.Model)
+ self.assertIsInstance(encoder, networks.TransformerEncoder)
+
+ # model has one scalar output: loss value.
+ self.assertEqual(model.output.shape.as_list(), [None,])
+
+ # Expect two output from encoder: sequence and classification output.
+ self.assertIsInstance(encoder.output, list)
+ self.assertLen(encoder.output, 2)
+ # shape should be [batch size, seq_length, hidden_size]
+ self.assertEqual(encoder.output[0].shape.as_list(), [None, 5, 16])
+ # shape should be [batch size, hidden_size]
+ self.assertEqual(encoder.output[1].shape.as_list(), [None, 16])
+
+ def test_squad_model(self):
+ model, core_model = bert_models.squad_model(
+ self._bert_test_config,
+ max_seq_length=5,
+ initializer=None,
+ hub_module_url=None,
+ hub_module_trainable=None)
+ self.assertIsInstance(model, tf.keras.Model)
+ self.assertIsInstance(core_model, tf.keras.Model)
+
+ # Expect two output from model: start positions and end positions
+ self.assertIsInstance(model.output, list)
+ self.assertLen(model.output, 2)
+ # shape should be [batch size, seq_length]
+ self.assertEqual(model.output[0].shape.as_list(), [None, 5])
+ # shape should be [batch size, seq_length]
+ self.assertEqual(model.output[1].shape.as_list(), [None, 5])
+
+ # Expect two output from core_model: sequence and classification output.
+ self.assertIsInstance(core_model.output, list)
+ self.assertLen(core_model.output, 2)
+ # shape should be [batch size, seq_length, hidden_size]
+ self.assertEqual(core_model.output[0].shape.as_list(), [None, 5, 16])
+ # shape should be [batch size, hidden_size]
+ self.assertEqual(core_model.output[1].shape.as_list(), [None, 16])
+
+ def test_classifier_model(self):
+ model, core_model = bert_models.classifier_model(
+ self._bert_test_config,
+ num_labels=3,
+ max_seq_length=5,
+ final_layer_initializer=None,
+ hub_module_url=None,
+ hub_module_trainable=None)
+ self.assertIsInstance(model, tf.keras.Model)
+ self.assertIsInstance(core_model, tf.keras.Model)
+
+ # model has one classification output with num_labels=3.
+ self.assertEqual(model.output.shape.as_list(), [None, 3])
+
+ # Expect two output from core_model: sequence and classification output.
+ self.assertIsInstance(core_model.output, list)
+ self.assertLen(core_model.output, 2)
+ # shape should be [batch size, 1, hidden_size]
+ self.assertEqual(core_model.output[0].shape.as_list(), [None, 1, 16])
+ # shape should be [batch size, hidden_size]
+ self.assertEqual(core_model.output[1].shape.as_list(), [None, 16])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/nlp/bert/common_flags.py b/models/official/nlp/bert/common_flags.py
new file mode 100644
index 0000000000000000000000000000000000000000..06a376d63de5447ddd67810f2cf6be3399f2a958
--- /dev/null
+++ b/models/official/nlp/bert/common_flags.py
@@ -0,0 +1,117 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defining common flags used across all BERT models/applications."""
+
+from absl import flags
+import tensorflow as tf
+
+from official.utils import hyperparams_flags
+from official.utils.flags import core as flags_core
+
+
+def define_common_bert_flags():
+ """Define common flags for BERT tasks."""
+ flags_core.define_base(
+ data_dir=False,
+ model_dir=True,
+ clean=False,
+ train_epochs=False,
+ epochs_between_evals=False,
+ stop_threshold=False,
+ batch_size=False,
+ num_gpu=True,
+ export_dir=False,
+ distribution_strategy=True,
+ run_eagerly=True)
+ flags_core.define_distribution()
+ flags.DEFINE_string('bert_config_file', None,
+ 'Bert configuration file to define core bert layers.')
+ flags.DEFINE_string(
+ 'model_export_path', None,
+ 'Path to the directory, where trainined model will be '
+ 'exported.')
+ flags.DEFINE_string('tpu', '', 'TPU address to connect to.')
+ flags.DEFINE_string(
+ 'init_checkpoint', None,
+ 'Initial checkpoint (usually from a pre-trained BERT model).')
+ flags.DEFINE_integer('num_train_epochs', 3,
+ 'Total number of training epochs to perform.')
+ flags.DEFINE_integer(
+ 'steps_per_loop', None,
+ 'Number of steps per graph-mode loop. Only training step '
+ 'happens inside the loop. Callbacks will not be called '
+ 'inside. If not set the value will be configured depending on the '
+ 'devices available.')
+ flags.DEFINE_float('learning_rate', 5e-5,
+ 'The initial learning rate for Adam.')
+ flags.DEFINE_float('end_lr', 0.0,
+ 'The end learning rate for learning rate decay.')
+ flags.DEFINE_string('optimizer_type', 'adamw',
+ 'The type of optimizer to use for training (adamw|lamb)')
+ flags.DEFINE_boolean(
+ 'scale_loss', False,
+ 'Whether to divide the loss by number of replica inside the per-replica '
+ 'loss function.')
+ flags.DEFINE_boolean(
+ 'use_keras_compile_fit', False,
+ 'If True, uses Keras compile/fit() API for training logic. Otherwise '
+ 'use custom training loop.')
+ flags.DEFINE_string(
+ 'hub_module_url', None, 'TF-Hub path/url to Bert module. '
+ 'If specified, init_checkpoint flag should not be used.')
+ flags.DEFINE_bool('hub_module_trainable', True,
+ 'True to make keras layers in the hub module trainable.')
+ flags.DEFINE_string('sub_model_export_name', None,
+ 'If set, `sub_model` checkpoints are exported into '
+ 'FLAGS.model_dir/FLAGS.sub_model_export_name.')
+
+ flags_core.define_log_steps()
+
+ # Adds flags for mixed precision and multi-worker training.
+ flags_core.define_performance(
+ num_parallel_calls=False,
+ inter_op=False,
+ intra_op=False,
+ synthetic_data=False,
+ max_train_steps=False,
+ dtype=True,
+ dynamic_loss_scale=True,
+ loss_scale=True,
+ all_reduce_alg=True,
+ num_packs=False,
+ tf_gpu_thread_mode=True,
+ datasets_num_private_threads=True,
+ enable_xla=True,
+ fp16_implementation=True,
+ )
+
+ # Adds gin configuration flags.
+ hyperparams_flags.define_gin_flags()
+
+
+def dtype():
+ return flags_core.get_tf_dtype(flags.FLAGS)
+
+
+def use_float16():
+ return flags_core.get_tf_dtype(flags.FLAGS) == tf.float16
+
+
+def use_graph_rewrite():
+ return flags.FLAGS.fp16_implementation == 'graph_rewrite'
+
+
+def get_loss_scale():
+ return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
diff --git a/models/official/nlp/bert/configs.py b/models/official/nlp/bert/configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3f9082655f490e010ff2a341c40d488eb1097c1
--- /dev/null
+++ b/models/official/nlp/bert/configs.py
@@ -0,0 +1,108 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The main BERT model and related functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import json
+import six
+import tensorflow as tf
+
+
+class BertConfig(object):
+ """Configuration for `BertModel`."""
+
+ def __init__(self,
+ vocab_size,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=16,
+ initializer_range=0.02,
+ embedding_size=None,
+ backward_compatible=True):
+ """Constructs BertConfig.
+
+ Args:
+ vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
+ hidden_size: Size of the encoder layers and the pooler layer.
+ num_hidden_layers: Number of hidden layers in the Transformer encoder.
+ num_attention_heads: Number of attention heads for each attention layer in
+ the Transformer encoder.
+ intermediate_size: The size of the "intermediate" (i.e., feed-forward)
+ layer in the Transformer encoder.
+ hidden_act: The non-linear activation function (function or string) in the
+ encoder and pooler.
+ hidden_dropout_prob: The dropout probability for all fully connected
+ layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob: The dropout ratio for the attention
+ probabilities.
+ max_position_embeddings: The maximum sequence length that this model might
+ ever be used with. Typically set this to something large just in case
+ (e.g., 512 or 1024 or 2048).
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed into
+ `BertModel`.
+ initializer_range: The stdev of the truncated_normal_initializer for
+ initializing all weight matrices.
+ embedding_size: (Optional) width of the factorized word embeddings.
+ backward_compatible: Boolean, whether the variables shape are compatible
+ with checkpoints converted from TF 1.x BERT.
+ """
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ self.embedding_size = embedding_size
+ self.backward_compatible = backward_compatible
+
+ @classmethod
+ def from_dict(cls, json_object):
+ """Constructs a `BertConfig` from a Python dictionary of parameters."""
+ config = BertConfig(vocab_size=None)
+ for (key, value) in six.iteritems(json_object):
+ config.__dict__[key] = value
+ return config
+
+ @classmethod
+ def from_json_file(cls, json_file):
+ """Constructs a `BertConfig` from a json file of parameters."""
+ with tf.io.gfile.GFile(json_file, "r") as reader:
+ text = reader.read()
+ return cls.from_dict(json.loads(text))
+
+ def to_dict(self):
+ """Serializes this instance to a Python dictionary."""
+ output = copy.deepcopy(self.__dict__)
+ return output
+
+ def to_json_string(self):
+ """Serializes this instance to a JSON string."""
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
+
diff --git a/models/official/nlp/bert/export_tfhub.py b/models/official/nlp/bert/export_tfhub.py
new file mode 100644
index 0000000000000000000000000000000000000000..5923309d1fa36a16d4cccda11650d9c3d0fcc616
--- /dev/null
+++ b/models/official/nlp/bert/export_tfhub.py
@@ -0,0 +1,95 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A script to export the BERT core model as a TF-Hub SavedModel."""
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf
+from typing import Text
+from official.nlp.bert import bert_models
+from official.nlp.bert import configs
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("bert_config_file", None,
+ "Bert configuration file to define core bert layers.")
+flags.DEFINE_string("model_checkpoint_path", None,
+ "File path to TF model checkpoint.")
+flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
+flags.DEFINE_string("vocab_file", None,
+ "The vocabulary file that the BERT model was trained on.")
+flags.DEFINE_bool("do_lower_case", None, "Whether to lowercase. If None, "
+ "do_lower_case will be enabled if 'uncased' appears in the "
+ "name of --vocab_file")
+
+
+def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
+ """Creates a BERT keras core model from BERT configuration.
+
+ Args:
+ bert_config: A `BertConfig` to create the core model.
+
+ Returns:
+ A keras model.
+ """
+ # Adds input layers just as placeholders.
+ input_word_ids = tf.keras.layers.Input(
+ shape=(None,), dtype=tf.int32, name="input_word_ids")
+ input_mask = tf.keras.layers.Input(
+ shape=(None,), dtype=tf.int32, name="input_mask")
+ input_type_ids = tf.keras.layers.Input(
+ shape=(None,), dtype=tf.int32, name="input_type_ids")
+ transformer_encoder = bert_models.get_transformer_encoder(
+ bert_config, sequence_length=None)
+ sequence_output, pooled_output = transformer_encoder(
+ [input_word_ids, input_mask, input_type_ids])
+ # To keep consistent with legacy hub modules, the outputs are
+ # "pooled_output" and "sequence_output".
+ return tf.keras.Model(
+ inputs=[input_word_ids, input_mask, input_type_ids],
+ outputs=[pooled_output, sequence_output]), transformer_encoder
+
+
+def export_bert_tfhub(bert_config: configs.BertConfig,
+ model_checkpoint_path: Text, hub_destination: Text,
+ vocab_file: Text, do_lower_case: bool = None):
+ """Restores a tf.keras.Model and saves for TF-Hub."""
+ # If do_lower_case is not explicit, default to checking whether "uncased" is
+ # in the vocab file name
+ if do_lower_case is None:
+ do_lower_case = "uncased" in vocab_file
+ logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
+ do_lower_case, vocab_file)
+ core_model, encoder = create_bert_model(bert_config)
+ checkpoint = tf.train.Checkpoint(model=encoder)
+ checkpoint.restore(model_checkpoint_path).assert_consumed()
+ core_model.vocab_file = tf.saved_model.Asset(vocab_file)
+ core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
+ core_model.save(hub_destination, include_optimizer=False, save_format="tf")
+
+
+def main(_):
+ bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+ export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path,
+ FLAGS.vocab_file, FLAGS.do_lower_case)
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/models/official/nlp/bert/export_tfhub_test.py b/models/official/nlp/bert/export_tfhub_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b6fd40f5e1be5d5e8d4699d54c048add7435523
--- /dev/null
+++ b/models/official/nlp/bert/export_tfhub_test.py
@@ -0,0 +1,109 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests official.nlp.bert.export_tfhub."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+
+import tensorflow as tf
+import tensorflow_hub as hub
+from official.nlp.bert import configs
+from official.nlp.bert import export_tfhub
+
+
+class ExportTfhubTest(tf.test.TestCase):
+
+ def test_export_tfhub(self):
+ # Exports a savedmodel for TF-Hub
+ hidden_size = 16
+ bert_config = configs.BertConfig(
+ vocab_size=100,
+ hidden_size=hidden_size,
+ intermediate_size=32,
+ max_position_embeddings=128,
+ num_attention_heads=2,
+ num_hidden_layers=1)
+ bert_model, encoder = export_tfhub.create_bert_model(bert_config)
+ model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
+ checkpoint = tf.train.Checkpoint(model=encoder)
+ checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
+ model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
+
+ vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
+ with tf.io.gfile.GFile(vocab_file, "w") as f:
+ f.write("dummy content")
+
+ hub_destination = os.path.join(self.get_temp_dir(), "hub")
+ export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
+ hub_destination, vocab_file)
+
+ # Restores a hub KerasLayer.
+ hub_layer = hub.KerasLayer(hub_destination, trainable=True)
+
+ if hasattr(hub_layer, "resolved_object"):
+ # Checks meta attributes.
+ self.assertTrue(hub_layer.resolved_object.do_lower_case.numpy())
+ with tf.io.gfile.GFile(
+ hub_layer.resolved_object.vocab_file.asset_path.numpy()) as f:
+ self.assertEqual("dummy content", f.read())
+ # Checks the hub KerasLayer.
+ for source_weight, hub_weight in zip(bert_model.trainable_weights,
+ hub_layer.trainable_weights):
+ self.assertAllClose(source_weight.numpy(), hub_weight.numpy())
+
+ seq_length = 10
+ dummy_ids = np.zeros((2, seq_length), dtype=np.int32)
+ hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
+ source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])
+
+ # The outputs of hub module are "pooled_output" and "sequence_output",
+ # while the outputs of encoder is in reversed order, i.e.,
+ # "sequence_output" and "pooled_output".
+ encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
+ self.assertEqual(hub_outputs[0].shape, (2, hidden_size))
+ self.assertEqual(hub_outputs[1].shape, (2, seq_length, hidden_size))
+ for source_output, hub_output, encoder_output in zip(
+ source_outputs, hub_outputs, encoder_outputs):
+ self.assertAllClose(source_output.numpy(), hub_output.numpy())
+ self.assertAllClose(source_output.numpy(), encoder_output.numpy())
+
+ # Test that training=True makes a difference (activates dropout).
+ def _dropout_mean_stddev(training, num_runs=20):
+ input_ids = np.array([[14, 12, 42, 95, 99]], np.int32)
+ inputs = [input_ids, np.ones_like(input_ids), np.zeros_like(input_ids)]
+ outputs = np.concatenate(
+ [hub_layer(inputs, training=training)[0] for _ in range(num_runs)])
+ return np.mean(np.std(outputs, axis=0))
+ self.assertLess(_dropout_mean_stddev(training=False), 1e-6)
+ self.assertGreater(_dropout_mean_stddev(training=True), 1e-3)
+
+ # Test propagation of seq_length in shape inference.
+ input_word_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
+ input_mask = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
+ input_type_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
+ pooled_output, sequence_output = hub_layer(
+ [input_word_ids, input_mask, input_type_ids])
+ self.assertEqual(pooled_output.shape.as_list(), [None, hidden_size])
+ self.assertEqual(sequence_output.shape.as_list(),
+ [None, seq_length, hidden_size])
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/bert/input_pipeline.py b/models/official/nlp/bert/input_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed3fd173d4379a75ab1e2e5a9ba0bbdcbaa0be42
--- /dev/null
+++ b/models/official/nlp/bert/input_pipeline.py
@@ -0,0 +1,285 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""BERT model input pipelines."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+def decode_record(record, name_to_features):
+ """Decodes a record to a TensorFlow example."""
+ example = tf.io.parse_single_example(record, name_to_features)
+
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+ # So cast all int64 to int32.
+ for name in list(example.keys()):
+ t = example[name]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ example[name] = t
+
+ return example
+
+
+def single_file_dataset(input_file, name_to_features):
+ """Creates a single-file dataset to be passed for BERT custom training."""
+ # For training, we want a lot of parallel reading and shuffling.
+ # For eval, we want no shuffling and parallel reading doesn't matter.
+ d = tf.data.TFRecordDataset(input_file)
+ d = d.map(
+ lambda record: decode_record(record, name_to_features),
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ # When `input_file` is a path to a single file or a list
+ # containing a single path, disable auto sharding so that
+  # the same input file is sent to all workers.
+ if isinstance(input_file, str) or len(input_file) == 1:
+ options = tf.data.Options()
+ options.experimental_distribute.auto_shard_policy = (
+ tf.data.experimental.AutoShardPolicy.OFF)
+ d = d.with_options(options)
+ return d
+
+
+def create_pretrain_dataset(input_patterns,
+ seq_length,
+ max_predictions_per_seq,
+ batch_size,
+ is_training=True,
+ input_pipeline_context=None,
+ use_next_sentence_label=True,
+ use_position_id=False,
+ output_fake_labels=True):
+ """Creates input dataset from (tf)records files for pretraining."""
+ name_to_features = {
+ 'input_ids':
+ tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'input_mask':
+ tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'segment_ids':
+ tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'masked_lm_positions':
+ tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
+ 'masked_lm_ids':
+ tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
+ 'masked_lm_weights':
+ tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
+ }
+ if use_next_sentence_label:
+ name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
+ tf.int64)
+ if use_position_id:
+ name_to_features['position_ids'] = tf.io.FixedLenFeature([seq_length],
+ tf.int64)
+ for input_pattern in input_patterns:
+ if not tf.io.gfile.glob(input_pattern):
+ raise ValueError('%s does not match any files.' % input_pattern)
+
+ dataset = tf.data.Dataset.list_files(input_patterns, shuffle=is_training)
+
+ if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
+ dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
+ input_pipeline_context.input_pipeline_id)
+ if is_training:
+ dataset = dataset.repeat()
+
+    # We set the shuffle buffer to exactly match the total number of
+    # training files to ensure that the training data is well shuffled.
+ input_files = []
+ for input_pattern in input_patterns:
+ input_files.extend(tf.io.gfile.glob(input_pattern))
+ dataset = dataset.shuffle(len(input_files))
+
+  # In parallel, create a TFRecord dataset for each training file.
+  # cycle_length = 8 means that up to 8 files will be read and deserialized in
+  # parallel. You may want to increase this number if you have a large number
+  # of CPU cores.
+ dataset = dataset.interleave(
+ tf.data.TFRecordDataset,
+ cycle_length=8,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ if is_training:
+ dataset = dataset.shuffle(100)
+
+ decode_fn = lambda record: decode_record(record, name_to_features)
+ dataset = dataset.map(
+ decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ def _select_data_from_record(record):
+ """Filter out features to use for pretraining."""
+ x = {
+ 'input_word_ids': record['input_ids'],
+ 'input_mask': record['input_mask'],
+ 'input_type_ids': record['segment_ids'],
+ 'masked_lm_positions': record['masked_lm_positions'],
+ 'masked_lm_ids': record['masked_lm_ids'],
+ 'masked_lm_weights': record['masked_lm_weights'],
+ }
+ if use_next_sentence_label:
+ x['next_sentence_labels'] = record['next_sentence_labels']
+ if use_position_id:
+ x['position_ids'] = record['position_ids']
+
+ # TODO(hongkuny): Remove the fake labels after migrating bert pretraining.
+ if output_fake_labels:
+ return (x, record['masked_lm_weights'])
+ else:
+ return x
+
+ dataset = dataset.map(
+ _select_data_from_record,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ dataset = dataset.batch(batch_size, drop_remainder=is_training)
+ dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+ return dataset
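+
+
+# Editor's illustrative sketch, not part of the upstream TF Models file: a
+# hypothetical helper showing how `create_pretrain_dataset` might be wired up.
+# The shard pattern, sequence length, prediction count and batch size below
+# are placeholder assumptions.
+def _example_pretrain_dataset():
+  """Builds a pretraining dataset from a hypothetical TFRecord shard pattern."""
+  return create_pretrain_dataset(
+      input_patterns=['/tmp/pretrain_data/shard-*.tfrecord'],  # placeholder
+      seq_length=128,
+      max_predictions_per_seq=20,
+      batch_size=32,
+      is_training=True)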
+
+
+def create_classifier_dataset(file_path,
+ seq_length,
+ batch_size,
+ is_training=True,
+ input_pipeline_context=None,
+ label_type=tf.int64,
+ include_sample_weights=False):
+ """Creates input dataset from (tf)records files for train/eval."""
+ name_to_features = {
+ 'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'label_ids': tf.io.FixedLenFeature([], label_type),
+ }
+ if include_sample_weights:
+ name_to_features['weight'] = tf.io.FixedLenFeature([], tf.float32)
+ dataset = single_file_dataset(file_path, name_to_features)
+
+ # The dataset is always sharded by number of hosts.
+ # num_input_pipelines is the number of hosts rather than number of cores.
+ if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
+ dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
+ input_pipeline_context.input_pipeline_id)
+
+ def _select_data_from_record(record):
+ x = {
+ 'input_word_ids': record['input_ids'],
+ 'input_mask': record['input_mask'],
+ 'input_type_ids': record['segment_ids']
+ }
+ y = record['label_ids']
+ if include_sample_weights:
+ w = record['weight']
+ return (x, y, w)
+ return (x, y)
+
+ if is_training:
+ dataset = dataset.shuffle(100)
+ dataset = dataset.repeat()
+
+ dataset = dataset.map(
+ _select_data_from_record,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ dataset = dataset.batch(batch_size, drop_remainder=is_training)
+ dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+ return dataset
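+
+
+# Editor's illustrative sketch, not part of the upstream TF Models file: a
+# hypothetical closure showing how `create_classifier_dataset` is typically
+# handed to a distribution strategy as an input_fn. The file path and sizes
+# are placeholder assumptions.
+def _example_classifier_input_fn(ctx=None):
+  """Returns a 128-token classifier dataset from a hypothetical TFRecord file."""
+  return create_classifier_dataset(
+      '/tmp/classifier_train.tf_record',  # placeholder path
+      seq_length=128,
+      batch_size=32,
+      is_training=True,
+      input_pipeline_context=ctx)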
+
+
+def create_squad_dataset(file_path,
+ seq_length,
+ batch_size,
+ is_training=True,
+ input_pipeline_context=None):
+ """Creates input dataset from (tf)records files for train/eval."""
+ name_to_features = {
+ 'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
+ }
+ if is_training:
+ name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
+ name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
+ else:
+ name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)
+
+ dataset = single_file_dataset(file_path, name_to_features)
+
+ # The dataset is always sharded by number of hosts.
+ # num_input_pipelines is the number of hosts rather than number of cores.
+ if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
+ dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
+ input_pipeline_context.input_pipeline_id)
+
+ def _select_data_from_record(record):
+ """Dispatches record to features and labels."""
+ x, y = {}, {}
+ for name, tensor in record.items():
+ if name in ('start_positions', 'end_positions'):
+ y[name] = tensor
+ elif name == 'input_ids':
+ x['input_word_ids'] = tensor
+ elif name == 'segment_ids':
+ x['input_type_ids'] = tensor
+ else:
+ x[name] = tensor
+ return (x, y)
+
+ if is_training:
+ dataset = dataset.shuffle(100)
+ dataset = dataset.repeat()
+
+ dataset = dataset.map(
+ _select_data_from_record,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ dataset = dataset.batch(batch_size, drop_remainder=True)
+ dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+ return dataset
+
+
+def create_retrieval_dataset(file_path,
+ seq_length,
+ batch_size,
+ input_pipeline_context=None):
+ """Creates input dataset from (tf)records files for scoring."""
+ name_to_features = {
+ 'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'int_iden': tf.io.FixedLenFeature([1], tf.int64),
+ }
+ dataset = single_file_dataset(file_path, name_to_features)
+
+ # The dataset is always sharded by number of hosts.
+ # num_input_pipelines is the number of hosts rather than number of cores.
+ if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
+ dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
+ input_pipeline_context.input_pipeline_id)
+
+ def _select_data_from_record(record):
+ x = {
+ 'input_word_ids': record['input_ids'],
+ 'input_mask': record['input_mask'],
+ 'input_type_ids': record['segment_ids']
+ }
+ y = record['int_iden']
+ return (x, y)
+
+ dataset = dataset.map(
+ _select_data_from_record,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ dataset = dataset.batch(batch_size, drop_remainder=False)
+ dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+ return dataset
diff --git a/models/official/nlp/bert/model_saving_utils.py b/models/official/nlp/bert/model_saving_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..13d2c9ed02f9a98d9dcbb2a60c46fa5cd13bb666
--- /dev/null
+++ b/models/official/nlp/bert/model_saving_utils.py
@@ -0,0 +1,77 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities to save models."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import os
+
+from absl import logging
+import tensorflow as tf
+import typing
+
+
+def export_bert_model(model_export_path: typing.Text,
+ model: tf.keras.Model,
+ checkpoint_dir: typing.Optional[typing.Text] = None,
+ restore_model_using_load_weights: bool = False) -> None:
+ """Export BERT model for serving which does not include the optimizer.
+
+ Arguments:
+ model_export_path: Path to which exported model will be saved.
+ model: Keras model object to export.
+ checkpoint_dir: Path from which model weights will be loaded, if
+ specified.
+    restore_model_using_load_weights: Whether to use the checkpoint.restore()
+      API for a custom checkpoint or the model.load_weights() API.
+      There are two different ways to save checkpoints: one uses
+      tf.train.Checkpoint and the other uses Keras model.save_weights().
+      The custom training loop implementation uses the tf.train.Checkpoint API,
+      while the Keras ModelCheckpoint callback internally uses
+      model.save_weights(). Since these two APIs cannot be used together, the
+      model loading logic must take into account how the checkpoint was saved.
+
+ Raises:
+ ValueError when either model_export_path or model is not specified.
+ """
+ if not model_export_path:
+ raise ValueError('model_export_path must be specified.')
+ if not isinstance(model, tf.keras.Model):
+ raise ValueError('model must be a tf.keras.Model object.')
+
+ if checkpoint_dir:
+ # Keras compile/fit() was used to save checkpoint using
+ # model.save_weights().
+ if restore_model_using_load_weights:
+ model_weight_path = os.path.join(checkpoint_dir, 'checkpoint')
+ assert tf.io.gfile.exists(model_weight_path)
+ model.load_weights(model_weight_path)
+
+ # tf.train.Checkpoint API was used via custom training loop logic.
+ else:
+ checkpoint = tf.train.Checkpoint(model=model)
+
+ # Restores the model from latest checkpoint.
+ latest_checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
+ assert latest_checkpoint_file
+ logging.info('Checkpoint file %s found and restoring from '
+ 'checkpoint', latest_checkpoint_file)
+ checkpoint.restore(
+ latest_checkpoint_file).assert_existing_objects_matched()
+
+ model.save(model_export_path, include_optimizer=False, save_format='tf')
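+
+
+# Editor's illustrative sketch, not part of the upstream TF Models file: a
+# hypothetical call to `export_bert_model`. `classifier_model` is assumed to
+# be a tf.keras.Model; the export and checkpoint paths are placeholders.
+def _example_export(classifier_model):
+  """Exports a hypothetical trained classifier model as a SavedModel."""
+  export_bert_model(
+      '/tmp/exported_bert',  # placeholder export path
+      model=classifier_model,
+      checkpoint_dir='/tmp/model_dir',  # placeholder checkpoint directory
+      restore_model_using_load_weights=False)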
diff --git a/models/official/nlp/bert/model_training_utils.py b/models/official/nlp/bert/model_training_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0fe67615726906a6b1d3ef38a5ca9acfe8502de
--- /dev/null
+++ b/models/official/nlp/bert/model_training_utils.py
@@ -0,0 +1,572 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A light weight utilities to train NLP models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+import tempfile
+
+from absl import logging
+import tensorflow as tf
+from tensorflow.python.util import deprecation
+from official.staging.training import grad_utils
+from official.utils.misc import distribution_utils
+
+_SUMMARY_TXT = 'training_summary.txt'
+_MIN_SUMMARY_STEPS = 10
+
+
+def _should_export_checkpoint(strategy):
+ return (not strategy) or strategy.extended.should_checkpoint
+
+
+def _should_export_summary(strategy):
+ return (not strategy) or strategy.extended.should_save_summary
+
+
+def _save_checkpoint(strategy, checkpoint, model_dir, checkpoint_prefix):
+ """Saves model to with provided checkpoint prefix."""
+
+ if _should_export_checkpoint(strategy):
+ checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
+ saved_path = checkpoint.save(checkpoint_path)
+ logging.info('Saving model as TF checkpoint: %s', saved_path)
+ else:
+    # In multi-worker training we need every worker to save a checkpoint,
+    # because variables can trigger synchronization on read, and
+    # synchronization needs all workers to participate. To avoid workers
+    # overwriting each other, we save to a temporary directory on non-chief
+    # workers.
+ tmp_dir = tempfile.mkdtemp()
+ checkpoint.save(os.path.join(tmp_dir, 'ckpt'))
+ tf.io.gfile.rmtree(tmp_dir)
+ return
+
+
+def _get_input_iterator(input_fn, strategy):
+ """Returns distributed dataset iterator."""
+  # When training with TPU pods, datasets need to be cloned across
+  # workers. Since a Dataset instance cannot be cloned in eager mode, we
+  # instead pass a callable that returns a dataset.
+ if not callable(input_fn):
+ raise ValueError('`input_fn` should be a closure that returns a dataset.')
+ iterator = iter(
+ strategy.experimental_distribute_datasets_from_function(input_fn))
+ return iterator
+
+
+def _float_metric_value(metric):
+ """Gets the value of a float-value keras metric."""
+ return metric.result().numpy().astype(float)
+
+
+def steps_to_run(current_step, steps_per_epoch, steps_per_loop):
+ """Calculates steps to run on device."""
+ if steps_per_loop <= 0:
+    raise ValueError('steps_per_loop should be a positive integer.')
+ if steps_per_loop == 1:
+ return steps_per_loop
+ remainder_in_epoch = current_step % steps_per_epoch
+ if remainder_in_epoch != 0:
+ return min(steps_per_epoch - remainder_in_epoch, steps_per_loop)
+ else:
+ return steps_per_loop
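+
+
+# Editor's note (illustrative, not part of the upstream file): for example,
+# steps_to_run(current_step=95, steps_per_epoch=100, steps_per_loop=10)
+# returns 5, so the inner loop stops exactly at the epoch boundary.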
+
+
+def write_txt_summary(training_summary, summary_dir):
+ """Writes a summary text file to record stats."""
+ if not tf.io.gfile.exists(summary_dir):
+ tf.io.gfile.mkdir(summary_dir)
+ summary_path = os.path.join(summary_dir, _SUMMARY_TXT)
+ with tf.io.gfile.GFile(summary_path, 'wb') as f:
+ logging.info('Training Summary: \n%s', str(training_summary))
+ f.write(json.dumps(training_summary, indent=4))
+
+
+@deprecation.deprecated(
+ None, 'This function is deprecated. Please use Keras compile/fit instead.')
+def run_customized_training_loop(
+ # pylint: disable=invalid-name
+ _sentinel=None,
+ # pylint: enable=invalid-name
+ strategy=None,
+ model_fn=None,
+ loss_fn=None,
+ scale_loss=True,
+ model_dir=None,
+ train_input_fn=None,
+ steps_per_epoch=None,
+ num_eval_per_epoch=1,
+ steps_per_loop=None,
+ epochs=1,
+ eval_input_fn=None,
+ eval_steps=None,
+ metric_fn=None,
+ init_checkpoint=None,
+ custom_callbacks=None,
+ run_eagerly=False,
+ sub_model_export_name=None,
+ explicit_allreduce=False,
+ pre_allreduce_callbacks=None,
+ post_allreduce_callbacks=None,
+ train_summary_interval=0):
+ """Run BERT pretrain model training using low-level API.
+
+ Arguments:
+ _sentinel: Used to prevent positional parameters. Internal, do not use.
+ strategy: Distribution strategy on which to run low level training loop.
+ model_fn: Function that returns a tuple (model, sub_model). Caller of this
+ function should add optimizer to the `model` via calling
+ `model.compile()` API or manually setting `model.optimizer` attribute.
+ Second element of the returned tuple(sub_model) is an optional sub model
+ to be used for initial checkpoint -- if provided.
+ loss_fn: Function with signature func(labels, logits) and returns a loss
+ tensor.
+ scale_loss: Whether to divide the raw loss by number of replicas before
+ gradients calculation.
+ model_dir: Model directory used during training for restoring/saving model
+ weights.
+ train_input_fn: Function that returns a tf.data.Dataset used for training.
+ steps_per_epoch: Number of steps to run per epoch. At the end of each
+ epoch, model checkpoint will be saved and evaluation will be conducted
+ if evaluation dataset is provided.
+ num_eval_per_epoch: Number of evaluations per epoch.
+ steps_per_loop: Number of steps per graph-mode loop. In order to reduce
+ communication in eager context, training logs are printed every
+ steps_per_loop.
+ epochs: Number of epochs to train.
+ eval_input_fn: Function that returns evaluation dataset. If none,
+ evaluation is skipped.
+ eval_steps: Number of steps to run evaluation. Required if `eval_input_fn`
+ is not none.
+ metric_fn: A metrics function that returns a Keras Metric object to record
+ evaluation result using evaluation dataset or with training dataset
+ after every epoch.
+ init_checkpoint: Optional checkpoint to load to `sub_model` returned by
+ `model_fn`.
+ custom_callbacks: A list of Keras Callbacks objects to run during
+ training. More specifically, `on_train_begin(), on_train_end(),
+ on_batch_begin()`, `on_batch_end()`, `on_epoch_begin()`,
+ `on_epoch_end()` methods are invoked during training.
+ Note that some metrics may be missing from `logs`.
+ run_eagerly: Whether to run model training in pure eager execution. This
+      should be disabled for TPUStrategy.
+ sub_model_export_name: If not None, will export `sub_model` returned by
+ `model_fn` into checkpoint files. The name of intermediate checkpoint
+ file is {sub_model_export_name}_step_{step}.ckpt and the last
+      checkpoint's name is {sub_model_export_name}.ckpt; if None, `sub_model`
+      will not be exported as a checkpoint.
+ explicit_allreduce: Whether to explicitly perform gradient allreduce,
+ instead of relying on implicit allreduce in optimizer.apply_gradients().
+ default is False. For now, if training using FP16 mixed precision,
+ explicit allreduce will aggregate gradients in FP16 format. For TPU and
+ GPU training using FP32, explicit allreduce will aggregate gradients in
+ FP32 format.
+    pre_allreduce_callbacks: A list of callback functions that take gradient
+      and model-variable pairs as input, manipulate them, and return new
+      gradient and model-variable pairs. The callback functions will be
+      invoked in list order, before gradients are allreduced. With
+      mixed precision training, the pre_allreduce_callbacks will be applied to
+      scaled_gradients. Default is no callbacks. Only used when
+      explicit_allreduce=True.
+    post_allreduce_callbacks: A list of callback functions that take
+      gradient and model-variable pairs as input, manipulate them, and
+      return new gradient and model-variable pairs. The callback
+      functions will be invoked in list order, right before gradients
+      are applied to variables for updates. Default is no callbacks. Only used
+      when explicit_allreduce=True.
+ train_summary_interval: Step interval for training summaries. If the value
+ is a negative number, then training summaries are not enabled.
+
+ Returns:
+ Trained model.
+
+ Raises:
+ ValueError: (1) When model returned by `model_fn` does not have optimizer
+ attribute or when required parameters are set to none. (2) eval args are
+ not specified correctly. (3) metric_fn must be a callable if specified.
+      (4) sub_model_export_name is specified, but `sub_model` returned
+ by `model_fn` is None.
+ """
+
+ if _sentinel is not None:
+ raise ValueError('only call `run_customized_training_loop()` '
+ 'with named arguments.')
+
+ required_arguments = [
+ strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
+ ]
+
+ steps_between_evals = int(steps_per_epoch / num_eval_per_epoch)
+ if [arg for arg in required_arguments if arg is None]:
+ raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
+ '`steps_per_epoch` and `train_input_fn` are required '
+ 'parameters.')
+ if not steps_per_loop:
+ if tf.config.list_logical_devices('TPU'):
+ # One can't fully utilize a TPU with steps_per_loop=1, so in this case
+ # default users to a more useful value.
+ steps_per_loop = min(1000, steps_between_evals)
+ else:
+ steps_per_loop = 1
+ logging.info('steps_per_loop not specified. Using steps_per_loop=%d',
+ steps_per_loop)
+ if steps_per_loop > steps_between_evals:
+ logging.warning(
+ 'steps_per_loop: %d is specified to be greater than '
+ ' steps_between_evals: %d, we will use steps_between_evals as'
+ ' steps_per_loop.', steps_per_loop, steps_between_evals)
+ steps_per_loop = steps_between_evals
+ assert tf.executing_eagerly()
+
+ if run_eagerly:
+ if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
+ raise ValueError(
+ 'TPUStrategy should not run eagerly as it heavily relies on graph'
+ ' optimization for the distributed system.')
+
+ if eval_input_fn and eval_steps is None:
+ raise ValueError(
+        '`eval_steps` is required when `eval_input_fn` is not None.')
+ if metric_fn and not callable(metric_fn):
+ raise ValueError(
+ 'if `metric_fn` is specified, metric_fn must be a callable.')
+
+ total_training_steps = steps_per_epoch * epochs
+ train_iterator = _get_input_iterator(train_input_fn, strategy)
+ eval_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
+
+ with distribution_utils.get_strategy_scope(strategy):
+ # To correctly place the model weights on accelerators,
+ # model and optimizer should be created in scope.
+ model, sub_model = model_fn()
+ if not hasattr(model, 'optimizer'):
+ raise ValueError('User should set optimizer attribute to model '
+ 'inside `model_fn`.')
+ if sub_model_export_name and sub_model is None:
+ raise ValueError('sub_model_export_name is specified as %s, but '
+ 'sub_model is None.' % sub_model_export_name)
+
+ callback_list = tf.keras.callbacks.CallbackList(
+ callbacks=custom_callbacks, model=model)
+
+ optimizer = model.optimizer
+
+ if init_checkpoint:
+ logging.info(
+ 'Checkpoint file %s found and restoring from '
+ 'initial checkpoint for core model.', init_checkpoint)
+ checkpoint = tf.train.Checkpoint(model=sub_model)
+ checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
+ logging.info('Loading from checkpoint file completed')
+
+ train_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
+ eval_metrics = [metric_fn()] if metric_fn else []
+ # If evaluation is required, make a copy of metric as it will be used by
+ # both train and evaluation.
+ train_metrics = [
+ metric.__class__.from_config(metric.get_config())
+ for metric in eval_metrics
+ ]
+
+ # Create summary writers
+ if _should_export_summary(strategy):
+ summary_dir = os.path.join(model_dir, 'summaries')
+ else:
+      # In multi-worker training we need every worker to write a summary,
+      # because variables can trigger synchronization on read, and
+      # synchronization needs all workers to participate.
+ summary_dir = tempfile.mkdtemp()
+ eval_summary_writer = tf.summary.create_file_writer(
+ os.path.join(summary_dir, 'eval'))
+ last_summary_step = 0
+ if steps_per_loop >= _MIN_SUMMARY_STEPS and train_summary_interval >= 0:
+      # Only writes summaries when stats are collected over a sufficient
+      # number of steps.
+ train_summary_writer = tf.summary.create_file_writer(
+ os.path.join(summary_dir, 'train'))
+ else:
+ train_summary_writer = tf.summary.create_noop_writer()
+
+ # Collects training variables.
+ training_vars = model.trainable_variables
+
+ def _replicated_step(inputs):
+ """Replicated training step."""
+
+ inputs, labels = inputs
+ with tf.GradientTape() as tape:
+ model_outputs = model(inputs, training=True)
+ loss = loss_fn(labels, model_outputs)
+ # Raw loss is used for reporting in metrics/logs.
+ raw_loss = loss
+ if scale_loss:
+        # Scales down the loss so gradients are invariant to the number of
+        # replicas.
+ loss = loss / strategy.num_replicas_in_sync
+
+ if explicit_allreduce:
+ grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
+ training_vars,
+ pre_allreduce_callbacks,
+ post_allreduce_callbacks)
+ else:
+ if isinstance(optimizer,
+ tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+ with tape:
+ scaled_loss = optimizer.get_scaled_loss(loss)
+ scaled_grads = tape.gradient(scaled_loss, training_vars)
+ grads = optimizer.get_unscaled_gradients(scaled_grads)
+ else:
+ grads = tape.gradient(loss, training_vars)
+ optimizer.apply_gradients(zip(grads, training_vars))
+ # For reporting, the metric takes the mean of losses.
+ train_loss_metric.update_state(raw_loss)
+ for metric in train_metrics:
+ metric.update_state(labels, model_outputs)
+
+ @tf.function
+ def train_steps(iterator, steps):
+ """Performs distributed training steps in a loop.
+
+ Args:
+ iterator: the distributed iterator of training datasets.
+      steps: a tf.int32 integer tensor specifying the number of steps to run
+ inside host training loop.
+
+ Raises:
+ ValueError: Any of the arguments or tensor shapes are invalid.
+ """
+ if not isinstance(steps, tf.Tensor):
+      raise ValueError('steps should be a Tensor. A Python object may cause '
+ 'retracing.')
+
+ for _ in tf.range(steps):
+ strategy.run(_replicated_step, args=(next(iterator),))
+
+ def train_single_step(iterator):
+ """Performs a distributed training step.
+
+ Args:
+ iterator: the distributed iterator of training datasets.
+
+ Raises:
+ ValueError: Any of the arguments or tensor shapes are invalid.
+ """
+ strategy.run(_replicated_step, args=(next(iterator),))
+
+ def test_step(iterator):
+ """Calculates evaluation metrics on distributed devices."""
+
+ def _test_step_fn(inputs):
+ """Replicated accuracy calculation."""
+
+ inputs, labels = inputs
+ model_outputs = model(inputs, training=False)
+ for metric in eval_metrics:
+ metric.update_state(labels, model_outputs)
+ return model_outputs, labels
+
+ outputs, labels = strategy.run(_test_step_fn, args=(next(iterator),))
+ outputs = tf.nest.map_structure(strategy.experimental_local_results,
+ outputs)
+ labels = tf.nest.map_structure(strategy.experimental_local_results,
+ labels)
+ return outputs, labels
+
+ if not run_eagerly:
+ train_single_step = tf.function(train_single_step)
+ test_step = tf.function(test_step)
+
+ def _run_evaluation(current_training_step, test_iterator):
+ """Runs validation steps and aggregate metrics.
+
+ Args:
+ current_training_step: tf.int32 tensor containing the current step.
+ test_iterator: distributed iterator of test datasets.
+
+ Returns:
+      A dict of metric names and values.
+ """
+    # The last batch of the evaluation is often smaller than previous ones.
+    # Moreover, on some replicas it might even be empty. Therefore, unlike the
+    # way training_loss is calculated, we need to gather all the logits and
+    # labels here and calculate the evaluation loss outside the step function.
+ loss_list, loss_weights = list(), list()
+ for _ in range(eval_steps):
+ outputs, labels = test_step(test_iterator)
+ for cur_logits, cur_labels in zip(outputs, labels):
+ # This is to handle cases when cur_labels is not a single tensor,
+ # but a dict of tensors.
+ cur_weight = tf.shape(tf.nest.flatten(cur_labels)[0])[0]
+ if cur_weight != 0:
+ loss_list.append(loss_fn(cur_labels, cur_logits).numpy())
+ loss_weights.append(cur_weight)
+      # The sample_weights are the actual numbers of examples in each batch,
+      # i.e. the sum of the per-replica example counts when using
+      # distributed training.
+ eval_loss_metric.update_state(loss_list, sample_weight=loss_weights)
+
+ logs = {}
+ with eval_summary_writer.as_default():
+ for metric in [eval_loss_metric] + eval_metrics + model.metrics:
+ metric_value = _float_metric_value(metric)
+ logs[metric.name] = metric_value
+ logging.info('Step: [%d] Validation %s = %f', current_training_step,
+ metric.name, metric_value)
+ tf.summary.scalar(
+ metric.name, metric_value, step=current_training_step)
+ eval_summary_writer.flush()
+
+ return logs
+
+ # Training loop starts here.
+ checkpoint = tf.train.Checkpoint(
+ model=model, optimizer=optimizer, global_step=optimizer.iterations)
+ sub_model_checkpoint = tf.train.Checkpoint(
+ model=sub_model,
+ global_step=optimizer.iterations) if sub_model_export_name else None
+
+ latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
+ if latest_checkpoint_file:
+ logging.info('Checkpoint file %s found and restoring from '
+ 'checkpoint', latest_checkpoint_file)
+ checkpoint.restore(latest_checkpoint_file)
+ logging.info('Loading from checkpoint file completed')
+
+ current_step = optimizer.iterations.numpy()
+ checkpoint_name = 'ctl_step_{step}.ckpt'
+
+ logs = {}
+ callback_list.on_train_begin()
+ while current_step < total_training_steps and not model.stop_training:
+ if current_step % steps_per_epoch == 0:
+ callback_list.on_epoch_begin(
+ int(current_step / steps_per_epoch) + 1)
+
+      # Training loss/metrics take the average over steps inside the micro
+      # training loop. We reset their values before each round.
+ train_loss_metric.reset_states()
+ for metric in train_metrics + model.metrics:
+ metric.reset_states()
+
+ callback_list.on_batch_begin(current_step)
+ # Runs several steps in the host while loop.
+ steps = steps_to_run(current_step, steps_between_evals, steps_per_loop)
+
+ if tf.config.list_physical_devices('GPU'):
+ # TODO(zongweiz): merge with train_steps once tf.while_loop
+ # GPU performance bugs are fixed.
+ for _ in range(steps):
+ train_single_step(train_iterator)
+ else:
+ # Converts steps to a Tensor to avoid tf.function retracing.
+ train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
+ train_loss = _float_metric_value(train_loss_metric)
+ current_step += steps
+
+ # Updates training logging.
+ training_status = 'Train Step: %d/%d / loss = %s' % (
+ current_step, total_training_steps, train_loss)
+
+ if current_step >= last_summary_step + train_summary_interval:
+ summary_writer = train_summary_writer
+ last_summary_step = current_step
+ else:
+ summary_writer = tf.summary.create_noop_writer()
+
+ with summary_writer.as_default():
+ if callable(optimizer.learning_rate):
+ tf.summary.scalar(
+ 'learning_rate',
+ optimizer.learning_rate(current_step),
+ step=current_step)
+ tf.summary.scalar(train_loss_metric.name, train_loss, step=current_step)
+ for metric in train_metrics + model.metrics:
+ metric_value = _float_metric_value(metric)
+ training_status += ' %s = %f' % (metric.name, metric_value)
+ tf.summary.scalar(metric.name, metric_value, step=current_step)
+ summary_writer.flush()
+ logging.info(training_status)
+
+      # If there is no need for evaluation, we only call on_batch_end with
+      # train_loss; this ensures we get granular global_step/sec on
+      # TensorBoard.
+ if current_step % steps_between_evals:
+ callback_list.on_batch_end(current_step - 1, {'loss': train_loss})
+ else:
+ # Save a submodel with the step in the file name after each epoch.
+ if sub_model_export_name:
+ _save_checkpoint(
+ strategy, sub_model_checkpoint, model_dir,
+ '%s_step_%d.ckpt' % (sub_model_export_name, current_step))
+
+ # Save model checkpoints and run validation steps after each epoch
+ # (with the exception of the final epoch which is handled after the
+ # training loop).
+ if current_step < total_training_steps:
+ _save_checkpoint(strategy, checkpoint, model_dir,
+ checkpoint_name.format(step=current_step))
+ if eval_input_fn:
+ logging.info('Running evaluation after step: %s.', current_step)
+ logs = _run_evaluation(current_step,
+ _get_input_iterator(eval_input_fn, strategy))
+ # Re-initialize evaluation metric.
+ eval_loss_metric.reset_states()
+ for metric in eval_metrics + model.metrics:
+ metric.reset_states()
+ # We add train_loss here rather than call on_batch_end twice to make
+ # sure that no duplicated values are generated.
+ logs['loss'] = train_loss
+ callback_list.on_batch_end(current_step - 1, logs)
+
+    # Calls on_epoch_end after each real epoch ends to prevent miscalculation
+ # of training steps.
+ if current_step % steps_per_epoch == 0:
+ callback_list.on_epoch_end(int(current_step / steps_per_epoch), logs)
+
+ if sub_model_export_name:
+ _save_checkpoint(strategy, sub_model_checkpoint, model_dir,
+ '%s.ckpt' % sub_model_export_name)
+
+ _save_checkpoint(strategy, checkpoint, model_dir,
+ checkpoint_name.format(step=current_step))
+ if eval_input_fn:
+ logging.info('Running final evaluation after training is complete.')
+ logs = _run_evaluation(current_step,
+ _get_input_iterator(eval_input_fn, strategy))
+ callback_list.on_epoch_end(int(current_step / steps_per_epoch), logs)
+ training_summary = {
+ 'total_training_steps': total_training_steps,
+ 'train_loss': _float_metric_value(train_loss_metric),
+ }
+ for metric in model.metrics:
+ training_summary[metric.name] = _float_metric_value(metric)
+ if eval_metrics:
+ # TODO(hongkuny): Cleans up summary reporting in text.
+ training_summary['last_train_metrics'] = _float_metric_value(
+ train_metrics[0])
+ training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])
+
+ write_txt_summary(training_summary, summary_dir)
+
+ if not _should_export_summary(strategy):
+ tf.io.gfile.rmtree(summary_dir)
+
+ callback_list.on_train_end()
+
+ return model
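+
+
+# Editor's illustrative sketch, not part of the upstream TF Models file: a
+# hypothetical, minimal invocation of `run_customized_training_loop`.
+# `my_model_fn` and `my_train_input_fn` are placeholder assumptions;
+# `my_model_fn` must return (model, sub_model) with `model.optimizer` set, and
+# the model directory below is a placeholder.
+def _example_training_run(my_model_fn, my_train_input_fn):
+  strategy = tf.distribute.MirroredStrategy()
+  return run_customized_training_loop(
+      strategy=strategy,
+      model_fn=my_model_fn,
+      loss_fn=tf.keras.losses.sparse_categorical_crossentropy,
+      model_dir='/tmp/model_dir',  # placeholder
+      train_input_fn=my_train_input_fn,
+      steps_per_epoch=100,
+      steps_per_loop=10,
+      epochs=2)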
diff --git a/models/official/nlp/bert/model_training_utils_test.py b/models/official/nlp/bert/model_training_utils_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c85a6c9b520a1b4e39e6abdfde503b35034d29e
--- /dev/null
+++ b/models/official/nlp/bert/model_training_utils_test.py
@@ -0,0 +1,308 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for official.modeling.training.model_training_utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import logging
+from absl.testing import parameterized
+from absl.testing.absltest import mock
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from official.nlp.bert import model_training_utils
+
+
+def eager_strategy_combinations():
+ return combinations.combine(
+ distribution=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.mirrored_strategy_with_two_gpus,
+ ],
+ mode='eager',
+ )
+
+
+def eager_gpu_strategy_combinations():
+ return combinations.combine(
+ distribution=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.mirrored_strategy_with_two_gpus,
+ ],
+ mode='eager',
+ )
+
+
+def create_fake_data_input_fn(batch_size, features_shape, num_classes):
+ """Creates a dummy input function with the given feature and label shapes.
+
+ Args:
+ batch_size: integer.
+ features_shape: list[int]. Feature shape for an individual example.
+ num_classes: integer. Number of labels.
+
+ Returns:
+ An input function that is usable in the executor.
+ """
+
+ def _dataset_fn(input_context=None):
+ """An input function for generating fake data."""
+ local_batch_size = input_context.get_per_replica_batch_size(batch_size)
+ features = np.random.rand(64, *features_shape)
+ labels = np.random.randint(2, size=[64, num_classes])
+ # Convert the inputs to a Dataset.
+ dataset = tf.data.Dataset.from_tensor_slices((features, labels))
+ dataset = dataset.shard(input_context.num_input_pipelines,
+ input_context.input_pipeline_id)
+
+ def _assign_dtype(features, labels):
+ features = tf.cast(features, tf.float32)
+ labels = tf.cast(labels, tf.float32)
+ return features, labels
+
+ # Shuffle, repeat, and batch the examples.
+ dataset = dataset.map(_assign_dtype)
+ dataset = dataset.shuffle(64).repeat()
+ dataset = dataset.batch(local_batch_size, drop_remainder=True)
+ dataset = dataset.prefetch(buffer_size=64)
+ return dataset
+
+ return _dataset_fn
+
+
+def create_model_fn(input_shape, num_classes, use_float16=False):
+
+ def _model_fn():
+ """A one-layer softmax model suitable for testing."""
+ input_layer = tf.keras.layers.Input(shape=input_shape)
+ x = tf.keras.layers.Dense(num_classes, activation='relu')(input_layer)
+ output_layer = tf.keras.layers.Dense(num_classes, activation='softmax')(x)
+ sub_model = tf.keras.models.Model(input_layer, x, name='sub_model')
+ model = tf.keras.models.Model(input_layer, output_layer, name='model')
+ model.add_metric(
+ tf.reduce_mean(input_layer), name='mean_input', aggregation='mean')
+ model.optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
+ if use_float16:
+ model.optimizer = (
+ tf.keras.mixed_precision.experimental.LossScaleOptimizer(
+ model.optimizer, loss_scale='dynamic'))
+ return model, sub_model
+
+ return _model_fn
+
+
+def metric_fn():
+ """Gets a tf.keras metric object."""
+ return tf.keras.metrics.CategoricalAccuracy(name='accuracy', dtype=tf.float32)
+
+
+def summaries_with_matching_keyword(keyword, summary_dir):
+ """Yields summary protos matching given keyword from event file."""
+ event_paths = tf.io.gfile.glob(os.path.join(summary_dir, 'events*'))
+ for event in tf.compat.v1.train.summary_iterator(event_paths[-1]):
+ if event.summary is not None:
+ for value in event.summary.value:
+ if keyword in value.tag:
+ logging.error(event)
+ yield event.summary
+
+
+def check_eventfile_for_keyword(keyword, summary_dir):
+ """Checks event files for the keyword."""
+ return any(summaries_with_matching_keyword(keyword, summary_dir))
+
+
+class RecordingCallback(tf.keras.callbacks.Callback):
+
+ def __init__(self):
+ self.batch_begin = [] # (batch, logs)
+ self.batch_end = [] # (batch, logs)
+ self.epoch_begin = [] # (epoch, logs)
+ self.epoch_end = [] # (epoch, logs)
+
+ def on_batch_begin(self, batch, logs=None):
+ self.batch_begin.append((batch, logs))
+
+ def on_batch_end(self, batch, logs=None):
+ self.batch_end.append((batch, logs))
+
+ def on_epoch_begin(self, epoch, logs=None):
+ self.epoch_begin.append((epoch, logs))
+
+ def on_epoch_end(self, epoch, logs=None):
+ self.epoch_end.append((epoch, logs))
+
+
+class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ super(ModelTrainingUtilsTest, self).setUp()
+ self._model_fn = create_model_fn(input_shape=[128], num_classes=3)
+
+ def run_training(self, strategy, model_dir, steps_per_loop, run_eagerly):
+ input_fn = create_fake_data_input_fn(
+ batch_size=8, features_shape=[128], num_classes=3)
+ model_training_utils.run_customized_training_loop(
+ strategy=strategy,
+ model_fn=self._model_fn,
+ loss_fn=tf.keras.losses.categorical_crossentropy,
+ model_dir=model_dir,
+ steps_per_epoch=20,
+ steps_per_loop=steps_per_loop,
+ epochs=2,
+ train_input_fn=input_fn,
+ eval_input_fn=input_fn,
+ eval_steps=10,
+ init_checkpoint=None,
+ sub_model_export_name='my_submodel_name',
+ metric_fn=metric_fn,
+ custom_callbacks=None,
+ run_eagerly=run_eagerly)
+
+ @combinations.generate(eager_strategy_combinations())
+ def test_train_eager_single_step(self, distribution):
+ model_dir = self.get_temp_dir()
+ if isinstance(distribution, tf.distribute.experimental.TPUStrategy):
+ with self.assertRaises(ValueError):
+ self.run_training(
+ distribution, model_dir, steps_per_loop=1, run_eagerly=True)
+ else:
+ self.run_training(
+ distribution, model_dir, steps_per_loop=1, run_eagerly=True)
+
+ @combinations.generate(eager_gpu_strategy_combinations())
+ def test_train_eager_mixed_precision(self, distribution):
+ model_dir = self.get_temp_dir()
+ policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
+ tf.keras.mixed_precision.experimental.set_policy(policy)
+ self._model_fn = create_model_fn(
+ input_shape=[128], num_classes=3, use_float16=True)
+ self.run_training(
+ distribution, model_dir, steps_per_loop=1, run_eagerly=True)
+
+ @combinations.generate(eager_strategy_combinations())
+ def test_train_check_artifacts(self, distribution):
+ model_dir = self.get_temp_dir()
+ self.run_training(
+ distribution, model_dir, steps_per_loop=10, run_eagerly=False)
+
+ # Two checkpoints should be saved after two epochs.
+ files = map(os.path.basename,
+ tf.io.gfile.glob(os.path.join(model_dir, 'ctl_step_*index')))
+ self.assertCountEqual(['ctl_step_20.ckpt-1.index',
+ 'ctl_step_40.ckpt-2.index'], files)
+
+ # Three submodel checkpoints should be saved after two epochs (one after
+ # each epoch plus one final).
+ files = map(os.path.basename,
+ tf.io.gfile.glob(os.path.join(model_dir,
+ 'my_submodel_name*index')))
+ self.assertCountEqual(['my_submodel_name.ckpt-3.index',
+ 'my_submodel_name_step_20.ckpt-1.index',
+ 'my_submodel_name_step_40.ckpt-2.index'], files)
+
+ self.assertNotEmpty(
+ tf.io.gfile.glob(
+ os.path.join(model_dir, 'summaries/training_summary*')))
+
+ # Loss and accuracy values should be written into summaries.
+ self.assertTrue(
+ check_eventfile_for_keyword('loss',
+ os.path.join(model_dir, 'summaries/train')))
+ self.assertTrue(
+ check_eventfile_for_keyword('accuracy',
+ os.path.join(model_dir, 'summaries/train')))
+ self.assertTrue(
+ check_eventfile_for_keyword('mean_input',
+ os.path.join(model_dir, 'summaries/train')))
+ self.assertTrue(
+ check_eventfile_for_keyword('accuracy',
+ os.path.join(model_dir, 'summaries/eval')))
+ self.assertTrue(
+ check_eventfile_for_keyword('mean_input',
+ os.path.join(model_dir, 'summaries/eval')))
+
+ @combinations.generate(eager_strategy_combinations())
+ def test_train_check_callbacks(self, distribution):
+ model_dir = self.get_temp_dir()
+ callback = RecordingCallback()
+ callbacks = [callback]
+ input_fn = create_fake_data_input_fn(
+ batch_size=8, features_shape=[128], num_classes=3)
+ model_training_utils.run_customized_training_loop(
+ strategy=distribution,
+ model_fn=self._model_fn,
+ loss_fn=tf.keras.losses.categorical_crossentropy,
+ model_dir=model_dir,
+ steps_per_epoch=20,
+ num_eval_per_epoch=4,
+ steps_per_loop=10,
+ epochs=2,
+ train_input_fn=input_fn,
+ eval_input_fn=input_fn,
+ eval_steps=10,
+ init_checkpoint=None,
+ metric_fn=metric_fn,
+ custom_callbacks=callbacks,
+ run_eagerly=False)
+ self.assertEqual(callback.epoch_begin, [(1, {}), (2, {})])
+ epoch_ends, epoch_end_infos = zip(*callback.epoch_end)
+ self.assertEqual(list(epoch_ends), [1, 2, 2])
+ for info in epoch_end_infos:
+ self.assertIn('accuracy', info)
+
+ self.assertEqual(callback.batch_begin, [(0, {}), (5, {}), (10, {}),
+ (15, {}), (20, {}), (25, {}),
+ (30, {}), (35, {})])
+ batch_ends, batch_end_infos = zip(*callback.batch_end)
+ self.assertEqual(list(batch_ends), [4, 9, 14, 19, 24, 29, 34, 39])
+ for info in batch_end_infos:
+ self.assertIn('loss', info)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.one_device_strategy_gpu,
+ ],
+ mode='eager',
+ ))
+ def test_train_check_artifacts_non_chief(self, distribution):
+ # We shouldn't export artifacts on non-chief workers. Since there's no easy
+ # way to test with real MultiWorkerMirroredStrategy, we patch the strategy
+ # to make it as if it's MultiWorkerMirroredStrategy on non-chief workers.
+ extended = distribution.extended
+ with mock.patch.object(extended.__class__, 'should_checkpoint',
+ new_callable=mock.PropertyMock, return_value=False), \
+ mock.patch.object(extended.__class__, 'should_save_summary',
+ new_callable=mock.PropertyMock, return_value=False):
+ model_dir = self.get_temp_dir()
+ self.run_training(
+ distribution, model_dir, steps_per_loop=10, run_eagerly=False)
+ self.assertEmpty(tf.io.gfile.listdir(model_dir))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/nlp/bert/run_classifier.py b/models/official/nlp/bert/run_classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2eb525ae4335091c78eb4ead72494f8021a7f89
--- /dev/null
+++ b/models/official/nlp/bert/run_classifier.py
@@ -0,0 +1,497 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""BERT classification or regression finetuning runner in TF 2.x."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import json
+import math
+import os
+
+from absl import app
+from absl import flags
+from absl import logging
+import gin
+import tensorflow as tf
+from official.modeling import performance
+from official.nlp import optimization
+from official.nlp.bert import bert_models
+from official.nlp.bert import common_flags
+from official.nlp.bert import configs as bert_configs
+from official.nlp.bert import input_pipeline
+from official.nlp.bert import model_saving_utils
+from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
+
+flags.DEFINE_enum(
+ 'mode', 'train_and_eval', ['train_and_eval', 'export_only', 'predict'],
+ 'One of {"train_and_eval", "export_only", "predict"}. `train_and_eval`: '
+ 'trains the model and evaluates in the meantime. '
+ '`export_only`: will take the latest checkpoint inside '
+ 'model_dir and export a `SavedModel`. `predict`: takes a checkpoint and '
+ 'restores the model to output predictions on the test set.')
+flags.DEFINE_string('train_data_path', None,
+ 'Path to training data for BERT classifier.')
+flags.DEFINE_string('eval_data_path', None,
+ 'Path to evaluation data for BERT classifier.')
+flags.DEFINE_string(
+ 'input_meta_data_path', None,
+ 'Path to file that contains meta data about input '
+ 'to be used for training and evaluation.')
+flags.DEFINE_string('predict_checkpoint_path', None,
+ 'Path to the checkpoint for predictions.')
+flags.DEFINE_integer(
+ 'num_eval_per_epoch', 1,
+ 'Number of evaluations per epoch. The purpose of this flag is to provide '
+ 'more granular evaluation scores and checkpoints. For example, if original '
+ 'data has N samples and num_eval_per_epoch is n, then each epoch will be '
+ 'evaluated every N/n samples.')
+flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
+flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')
+
+common_flags.define_common_bert_flags()
+
+FLAGS = flags.FLAGS
+
+LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
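+
+
+# Editor's note (illustrative, not part of the upstream file): a hypothetical
+# invocation of this runner. The paths are placeholders; only flags defined or
+# referenced in this file are shown, and the remaining BERT flags (model
+# config, learning rate, etc.) come from common_flags.define_common_bert_flags().
+#
+#   python run_classifier.py \
+#     --mode=train_and_eval \
+#     --input_meta_data_path=/tmp/input_meta_data \
+#     --train_data_path=/tmp/train.tf_record \
+#     --eval_data_path=/tmp/eval.tf_record \
+#     --model_dir=/tmp/model_dir \
+#     --train_batch_size=32 \
+#     --eval_batch_size=32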
+
+
+def get_loss_fn(num_classes):
+ """Gets the classification loss function."""
+
+ def classification_loss_fn(labels, logits):
+ """Classification loss."""
+ labels = tf.squeeze(labels)
+ log_probs = tf.nn.log_softmax(logits, axis=-1)
+ one_hot_labels = tf.one_hot(
+ tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)
+ per_example_loss = -tf.reduce_sum(
+ tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)
+ return tf.reduce_mean(per_example_loss)
+
+ return classification_loss_fn
+
+
+def get_dataset_fn(input_file_pattern,
+ max_seq_length,
+ global_batch_size,
+ is_training,
+ label_type=tf.int64,
+ include_sample_weights=False):
+ """Gets a closure to create a dataset."""
+
+ def _dataset_fn(ctx=None):
+ """Returns tf.data.Dataset for distributed BERT pretraining."""
+ batch_size = ctx.get_per_replica_batch_size(
+ global_batch_size) if ctx else global_batch_size
+ dataset = input_pipeline.create_classifier_dataset(
+ tf.io.gfile.glob(input_file_pattern),
+ max_seq_length,
+ batch_size,
+ is_training=is_training,
+ input_pipeline_context=ctx,
+ label_type=label_type,
+ include_sample_weights=include_sample_weights)
+ return dataset
+
+ return _dataset_fn
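+
+
+# Editor's illustrative sketch, not part of the upstream TF Models file: a
+# hypothetical use of `get_dataset_fn`; the closure it returns is what gets
+# passed as a train/eval input_fn further down. The file pattern and sizes are
+# placeholder assumptions.
+def _example_make_train_input_fn():
+  """Returns a dataset closure for a hypothetical training TFRecord pattern."""
+  return get_dataset_fn(
+      '/tmp/train.tf_record',  # placeholder pattern
+      max_seq_length=128,
+      global_batch_size=32,
+      is_training=True)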
+
+
+def run_bert_classifier(strategy,
+ bert_config,
+ input_meta_data,
+ model_dir,
+ epochs,
+ steps_per_epoch,
+ steps_per_loop,
+ eval_steps,
+ warmup_steps,
+ initial_lr,
+ init_checkpoint,
+ train_input_fn,
+ eval_input_fn,
+ training_callbacks=True,
+ custom_callbacks=None,
+ custom_metrics=None):
+ """Run BERT classifier training using low-level API."""
+ max_seq_length = input_meta_data['max_seq_length']
+ num_classes = input_meta_data.get('num_labels', 1)
+ is_regression = num_classes == 1
+
+ def _get_classifier_model():
+ """Gets a classifier model."""
+ classifier_model, core_model = (
+ bert_models.classifier_model(
+ bert_config,
+ num_classes,
+ max_seq_length,
+ hub_module_url=FLAGS.hub_module_url,
+ hub_module_trainable=FLAGS.hub_module_trainable))
+ optimizer = optimization.create_optimizer(initial_lr,
+ steps_per_epoch * epochs,
+ warmup_steps, FLAGS.end_lr,
+ FLAGS.optimizer_type)
+ classifier_model.optimizer = performance.configure_optimizer(
+ optimizer,
+ use_float16=common_flags.use_float16(),
+ use_graph_rewrite=common_flags.use_graph_rewrite())
+ return classifier_model, core_model
+
+  # tf.keras.losses objects accept optional sample_weight arguments (e.g.
+  # coming from the dataset) to compute a weighted loss, as used for the
+  # regression tasks. The classification tasks, which use the custom
+  # get_loss_fn, don't accept sample weights though.
+ loss_fn = (tf.keras.losses.MeanSquaredError() if is_regression
+ else get_loss_fn(num_classes))
+
+ # Defines evaluation metrics function, which will create metrics in the
+ # correct device and strategy scope.
+ if custom_metrics:
+ metric_fn = custom_metrics
+ elif is_regression:
+ metric_fn = functools.partial(
+ tf.keras.metrics.MeanSquaredError,
+ 'mean_squared_error',
+ dtype=tf.float32)
+ else:
+ metric_fn = functools.partial(
+ tf.keras.metrics.SparseCategoricalAccuracy,
+ 'accuracy',
+ dtype=tf.float32)
+
+ # Start training using Keras compile/fit API.
+ logging.info('Training using TF 2.x Keras compile/fit API with '
+ 'distribution strategy.')
+ return run_keras_compile_fit(
+ model_dir,
+ strategy,
+ _get_classifier_model,
+ train_input_fn,
+ eval_input_fn,
+ loss_fn,
+ metric_fn,
+ init_checkpoint,
+ epochs,
+ steps_per_epoch,
+ steps_per_loop,
+ eval_steps,
+ training_callbacks=training_callbacks,
+ custom_callbacks=custom_callbacks)
+
+
+def run_keras_compile_fit(model_dir,
+ strategy,
+ model_fn,
+ train_input_fn,
+ eval_input_fn,
+ loss_fn,
+ metric_fn,
+ init_checkpoint,
+ epochs,
+ steps_per_epoch,
+ steps_per_loop,
+ eval_steps,
+ training_callbacks=True,
+ custom_callbacks=None):
+ """Runs BERT classifier model using Keras compile/fit API."""
+
+ with strategy.scope():
+ training_dataset = train_input_fn()
+ evaluation_dataset = eval_input_fn() if eval_input_fn else None
+ bert_model, sub_model = model_fn()
+ optimizer = bert_model.optimizer
+
+ if init_checkpoint:
+ checkpoint = tf.train.Checkpoint(model=sub_model)
+ checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
+
+ if not isinstance(metric_fn, (list, tuple)):
+ metric_fn = [metric_fn]
+ bert_model.compile(
+ optimizer=optimizer,
+ loss=loss_fn,
+ metrics=[fn() for fn in metric_fn],
+ experimental_steps_per_execution=steps_per_loop)
+
+ summary_dir = os.path.join(model_dir, 'summaries')
+ summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
+ checkpoint = tf.train.Checkpoint(model=bert_model, optimizer=optimizer)
+ checkpoint_manager = tf.train.CheckpointManager(
+ checkpoint,
+ directory=model_dir,
+ max_to_keep=None,
+ step_counter=optimizer.iterations,
+ checkpoint_interval=0)
+ checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)
+
+ if training_callbacks:
+ if custom_callbacks is not None:
+ custom_callbacks += [summary_callback, checkpoint_callback]
+ else:
+ custom_callbacks = [summary_callback, checkpoint_callback]
+
+ history = bert_model.fit(
+ x=training_dataset,
+ validation_data=evaluation_dataset,
+ steps_per_epoch=steps_per_epoch,
+ epochs=epochs,
+ validation_steps=eval_steps,
+ callbacks=custom_callbacks)
+ stats = {'total_training_steps': steps_per_epoch * epochs}
+ if 'loss' in history.history:
+ stats['train_loss'] = history.history['loss'][-1]
+ if 'val_accuracy' in history.history:
+ stats['eval_metrics'] = history.history['val_accuracy'][-1]
+ return bert_model, stats
+
+
+def get_predictions_and_labels(strategy,
+ trained_model,
+ eval_input_fn,
+ return_probs=False):
+ """Obtains predictions of trained model on evaluation data.
+
+  Note that the list of labels is returned along with the predictions because
+  the order changes when distributing the dataset over TPU pods.
+
+ Args:
+ strategy: Distribution strategy.
+ trained_model: Trained model with preloaded weights.
+ eval_input_fn: Input function for evaluation data.
+ return_probs: Whether to return probabilities of classes.
+
+ Returns:
+ predictions: List of predictions.
+ labels: List of gold labels corresponding to predictions.
+ """
+
+ @tf.function
+ def test_step(iterator):
+ """Computes predictions on distributed devices."""
+
+ def _test_step_fn(inputs):
+ """Replicated predictions."""
+ inputs, labels = inputs
+ logits = trained_model(inputs, training=False)
+ probabilities = tf.nn.softmax(logits)
+ return probabilities, labels
+
+ outputs, labels = strategy.run(_test_step_fn, args=(next(iterator),))
+ # outputs: current batch logits as a tuple of shard logits
+ outputs = tf.nest.map_structure(strategy.experimental_local_results,
+ outputs)
+ labels = tf.nest.map_structure(strategy.experimental_local_results, labels)
+ return outputs, labels
+
+ def _run_evaluation(test_iterator):
+ """Runs evaluation steps."""
+ preds, golds = list(), list()
+ try:
+ with tf.experimental.async_scope():
+ while True:
+ probabilities, labels = test_step(test_iterator)
+ for cur_probs, cur_labels in zip(probabilities, labels):
+ if return_probs:
+ preds.extend(cur_probs.numpy().tolist())
+ else:
+ preds.extend(tf.math.argmax(cur_probs, axis=1).numpy())
+ golds.extend(cur_labels.numpy().tolist())
+ except (StopIteration, tf.errors.OutOfRangeError):
+ tf.experimental.async_clear_error()
+ return preds, golds
+
+ test_iter = iter(
+ strategy.experimental_distribute_datasets_from_function(eval_input_fn))
+ predictions, labels = _run_evaluation(test_iter)
+
+ return predictions, labels
+
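+# Usage note (added; not in the upstream file): with `return_probs=True` the
+# first returned list holds per-class probabilities, so a quick accuracy check
+# might look like the sketch below (`np` assumes `import numpy as np`).
+#
+#   preds, golds = get_predictions_and_labels(
+#       strategy, trained_model, eval_input_fn, return_probs=True)
+#   top1 = [int(np.argmax(p)) for p in preds]
+#   accuracy = sum(int(p == g) for p, g in zip(top1, golds)) / len(golds)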
+
+def export_classifier(model_export_path, input_meta_data, bert_config,
+ model_dir):
+ """Exports a trained model as a `SavedModel` for inference.
+
+ Args:
+ model_export_path: a string specifying the path to the SavedModel directory.
+ input_meta_data: dictionary containing meta data about input and model.
+ bert_config: Bert configuration file to define core bert layers.
+ model_dir: The directory where the model weights and training/evaluation
+ summaries are stored.
+
+ Raises:
+    ValueError: If the export path or the model directory is not specified
+      (got an empty string or None).
+ """
+ if not model_export_path:
+ raise ValueError('Export path is not specified: %s' % model_export_path)
+ if not model_dir:
+    raise ValueError('Model directory is not specified: %s' % model_dir)
+
+ # Export uses float32 for now, even if training uses mixed precision.
+ tf.keras.mixed_precision.experimental.set_policy('float32')
+ classifier_model = bert_models.classifier_model(
+ bert_config, input_meta_data.get('num_labels', 1))[0]
+
+ model_saving_utils.export_bert_model(
+ model_export_path, model=classifier_model, checkpoint_dir=model_dir)
+
+
+def run_bert(strategy,
+ input_meta_data,
+ model_config,
+ train_input_fn=None,
+ eval_input_fn=None,
+ init_checkpoint=None,
+ custom_callbacks=None,
+ custom_metrics=None):
+ """Run BERT training."""
+ # Enables XLA in Session Config. Should not be set for TPU.
+ keras_utils.set_session_config(FLAGS.enable_xla)
+ performance.set_mixed_precision_policy(common_flags.dtype())
+
+ epochs = FLAGS.num_train_epochs * FLAGS.num_eval_per_epoch
+ train_data_size = (
+ input_meta_data['train_data_size'] // FLAGS.num_eval_per_epoch)
+ steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
+ warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
+ eval_steps = int(
+ math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
+
+ if not strategy:
+ raise ValueError('Distribution strategy has not been specified.')
+
+ if not custom_callbacks:
+ custom_callbacks = []
+
+ if FLAGS.log_steps:
+ custom_callbacks.append(
+ keras_utils.TimeHistory(
+ batch_size=FLAGS.train_batch_size,
+ log_steps=FLAGS.log_steps,
+ logdir=FLAGS.model_dir))
+
+ trained_model, _ = run_bert_classifier(
+ strategy,
+ model_config,
+ input_meta_data,
+ FLAGS.model_dir,
+ epochs,
+ steps_per_epoch,
+ FLAGS.steps_per_loop,
+ eval_steps,
+ warmup_steps,
+ FLAGS.learning_rate,
+ init_checkpoint or FLAGS.init_checkpoint,
+ train_input_fn,
+ eval_input_fn,
+ custom_callbacks=custom_callbacks,
+ custom_metrics=custom_metrics)
+
+ if FLAGS.model_export_path:
+ model_saving_utils.export_bert_model(
+ FLAGS.model_export_path, model=trained_model)
+ return trained_model
+
+
+def custom_main(custom_callbacks=None, custom_metrics=None):
+ """Run classification or regression.
+
+ Args:
+ custom_callbacks: list of tf.keras.Callbacks passed to training loop.
+ custom_metrics: list of metrics passed to the training loop.
+ """
+ gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
+
+ with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
+ input_meta_data = json.loads(reader.read().decode('utf-8'))
+ label_type = LABEL_TYPES_MAP[input_meta_data.get('label_type', 'int')]
+ include_sample_weights = input_meta_data.get('has_sample_weights', False)
+
+ if not FLAGS.model_dir:
+ FLAGS.model_dir = '/tmp/bert20/'
+
+ bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+
+ if FLAGS.mode == 'export_only':
+ export_classifier(FLAGS.model_export_path, input_meta_data, bert_config,
+ FLAGS.model_dir)
+ return
+
+ strategy = distribution_utils.get_distribution_strategy(
+ distribution_strategy=FLAGS.distribution_strategy,
+ num_gpus=FLAGS.num_gpus,
+ tpu_address=FLAGS.tpu)
+ eval_input_fn = get_dataset_fn(
+ FLAGS.eval_data_path,
+ input_meta_data['max_seq_length'],
+ FLAGS.eval_batch_size,
+ is_training=False,
+ label_type=label_type,
+ include_sample_weights=include_sample_weights)
+
+ if FLAGS.mode == 'predict':
+ with strategy.scope():
+ classifier_model = bert_models.classifier_model(
+ bert_config, input_meta_data['num_labels'])[0]
+ checkpoint = tf.train.Checkpoint(model=classifier_model)
+ latest_checkpoint_file = (
+ FLAGS.predict_checkpoint_path or
+ tf.train.latest_checkpoint(FLAGS.model_dir))
+ assert latest_checkpoint_file
+ logging.info('Checkpoint file %s found and restoring from '
+ 'checkpoint', latest_checkpoint_file)
+ checkpoint.restore(
+ latest_checkpoint_file).assert_existing_objects_matched()
+ preds, _ = get_predictions_and_labels(
+ strategy, classifier_model, eval_input_fn, return_probs=True)
+ output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
+ with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
+ logging.info('***** Predict results *****')
+ for probabilities in preds:
+ output_line = '\t'.join(
+ str(class_probability)
+ for class_probability in probabilities) + '\n'
+ writer.write(output_line)
+ return
+
+ if FLAGS.mode != 'train_and_eval':
+ raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
+ train_input_fn = get_dataset_fn(
+ FLAGS.train_data_path,
+ input_meta_data['max_seq_length'],
+ FLAGS.train_batch_size,
+ is_training=True,
+ label_type=label_type,
+ include_sample_weights=include_sample_weights)
+ run_bert(
+ strategy,
+ input_meta_data,
+ bert_config,
+ train_input_fn,
+ eval_input_fn,
+ custom_callbacks=custom_callbacks,
+ custom_metrics=custom_metrics)
+
+
+def main(_):
+ custom_main(custom_callbacks=None, custom_metrics=None)
+
+
+if __name__ == '__main__':
+ flags.mark_flag_as_required('bert_config_file')
+ flags.mark_flag_as_required('input_meta_data_path')
+ flags.mark_flag_as_required('model_dir')
+ app.run(main)
diff --git a/models/official/nlp/bert/run_pretraining.py b/models/official/nlp/bert/run_pretraining.py
new file mode 100644
index 0000000000000000000000000000000000000000..44a18fea0ce9d79bea61294e91f0ac00c2ea45e6
--- /dev/null
+++ b/models/official/nlp/bert/run_pretraining.py
@@ -0,0 +1,197 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Run masked LM/next sentence pre-training for BERT in TF 2.x."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import app
+from absl import flags
+from absl import logging
+import gin
+import tensorflow as tf
+from official.modeling import performance
+from official.nlp import optimization
+from official.nlp.bert import bert_models
+from official.nlp.bert import common_flags
+from official.nlp.bert import configs
+from official.nlp.bert import input_pipeline
+from official.nlp.bert import model_training_utils
+from official.utils.misc import distribution_utils
+
+
+flags.DEFINE_string('input_files', None,
+ 'File path to retrieve training data for pre-training.')
+# Model training specific flags.
+flags.DEFINE_integer(
+ 'max_seq_length', 128,
+ 'The maximum total input sequence length after WordPiece tokenization. '
+ 'Sequences longer than this will be truncated, and sequences shorter '
+ 'than this will be padded.')
+flags.DEFINE_integer('max_predictions_per_seq', 20,
+ 'Maximum predictions per sequence_output.')
+flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
+flags.DEFINE_integer('num_steps_per_epoch', 1000,
+ 'Total number of training steps to run per epoch.')
+flags.DEFINE_float('warmup_steps', 10000,
+ 'Warmup steps for Adam weight decay optimizer.')
+flags.DEFINE_bool('use_next_sentence_label', True,
+ 'Whether to use next sentence label to compute final loss.')
+flags.DEFINE_integer('train_summary_interval', 0, 'Step interval for training '
+ 'summaries. If the value is a negative number, '
+ 'then training summaries are not enabled.')
+
+common_flags.define_common_bert_flags()
+
+FLAGS = flags.FLAGS
+
+
+def get_pretrain_dataset_fn(input_file_pattern, seq_length,
+ max_predictions_per_seq, global_batch_size,
+ use_next_sentence_label=True):
+ """Returns input dataset from input file string."""
+ def _dataset_fn(ctx=None):
+ """Returns tf.data.Dataset for distributed BERT pretraining."""
+ input_patterns = input_file_pattern.split(',')
+ batch_size = ctx.get_per_replica_batch_size(global_batch_size)
+ train_dataset = input_pipeline.create_pretrain_dataset(
+ input_patterns,
+ seq_length,
+ max_predictions_per_seq,
+ batch_size,
+ is_training=True,
+ input_pipeline_context=ctx,
+ use_next_sentence_label=use_next_sentence_label)
+ return train_dataset
+
+ return _dataset_fn
+
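+# Added note (not in the upstream file): the returned closure expects a
+# tf.distribute.InputContext, so it is normally consumed through a
+# distribution strategy rather than called directly. A minimal sketch, with a
+# placeholder file pattern:
+#
+#   dataset_fn = get_pretrain_dataset_fn(
+#       'gs://my-bucket/pretrain_examples*.tfrecord',  # hypothetical path
+#       seq_length=128,
+#       max_predictions_per_seq=20,
+#       global_batch_size=256)
+#   dist_dataset = strategy.experimental_distribute_datasets_from_function(
+#       dataset_fn)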
+
+def get_loss_fn():
+ """Returns loss function for BERT pretraining."""
+
+ def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
+ return tf.reduce_mean(losses)
+
+ return _bert_pretrain_loss_fn
+
+
+def run_customized_training(strategy,
+ bert_config,
+ init_checkpoint,
+ max_seq_length,
+ max_predictions_per_seq,
+ model_dir,
+ steps_per_epoch,
+ steps_per_loop,
+ epochs,
+ initial_lr,
+ warmup_steps,
+ end_lr,
+ optimizer_type,
+ input_files,
+ train_batch_size,
+ use_next_sentence_label=True,
+ train_summary_interval=0,
+ custom_callbacks=None):
+ """Run BERT pretrain model training using low-level API."""
+
+ train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
+ max_predictions_per_seq,
+ train_batch_size,
+ use_next_sentence_label)
+
+ def _get_pretrain_model():
+ """Gets a pretraining model."""
+ pretrain_model, core_model = bert_models.pretrain_model(
+ bert_config, max_seq_length, max_predictions_per_seq,
+ use_next_sentence_label=use_next_sentence_label)
+ optimizer = optimization.create_optimizer(
+ initial_lr, steps_per_epoch * epochs, warmup_steps,
+ end_lr, optimizer_type)
+ pretrain_model.optimizer = performance.configure_optimizer(
+ optimizer,
+ use_float16=common_flags.use_float16(),
+ use_graph_rewrite=common_flags.use_graph_rewrite())
+ return pretrain_model, core_model
+
+ trained_model = model_training_utils.run_customized_training_loop(
+ strategy=strategy,
+ model_fn=_get_pretrain_model,
+ loss_fn=get_loss_fn(),
+ scale_loss=FLAGS.scale_loss,
+ model_dir=model_dir,
+ init_checkpoint=init_checkpoint,
+ train_input_fn=train_input_fn,
+ steps_per_epoch=steps_per_epoch,
+ steps_per_loop=steps_per_loop,
+ epochs=epochs,
+ sub_model_export_name='pretrained/bert_model',
+ train_summary_interval=train_summary_interval,
+ custom_callbacks=custom_callbacks)
+
+ return trained_model
+
+
+def run_bert_pretrain(strategy, custom_callbacks=None):
+ """Runs BERT pre-training."""
+
+ bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+ if not strategy:
+ raise ValueError('Distribution strategy is not specified.')
+
+ # Runs customized training loop.
+  logging.info('Training using customized training loop TF 2.0 with distributed '
+ 'strategy.')
+
+ performance.set_mixed_precision_policy(common_flags.dtype())
+
+ return run_customized_training(
+ strategy,
+ bert_config,
+ FLAGS.init_checkpoint, # Used to initialize only the BERT submodel.
+ FLAGS.max_seq_length,
+ FLAGS.max_predictions_per_seq,
+ FLAGS.model_dir,
+ FLAGS.num_steps_per_epoch,
+ FLAGS.steps_per_loop,
+ FLAGS.num_train_epochs,
+ FLAGS.learning_rate,
+ FLAGS.warmup_steps,
+ FLAGS.end_lr,
+ FLAGS.optimizer_type,
+ FLAGS.input_files,
+ FLAGS.train_batch_size,
+ FLAGS.use_next_sentence_label,
+ FLAGS.train_summary_interval,
+ custom_callbacks=custom_callbacks)
+
+
+def main(_):
+ gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
+ if not FLAGS.model_dir:
+ FLAGS.model_dir = '/tmp/bert20/'
+ strategy = distribution_utils.get_distribution_strategy(
+ distribution_strategy=FLAGS.distribution_strategy,
+ num_gpus=FLAGS.num_gpus,
+ tpu_address=FLAGS.tpu)
+ if strategy:
+ print('***** Number of cores used : ', strategy.num_replicas_in_sync)
+
+ run_bert_pretrain(strategy)
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/models/official/nlp/bert/run_squad.py b/models/official/nlp/bert/run_squad.py
new file mode 100644
index 0000000000000000000000000000000000000000..b12925cfaad2337c28483325c5f942df651add62
--- /dev/null
+++ b/models/official/nlp/bert/run_squad.py
@@ -0,0 +1,153 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Run BERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+import time
+
+from absl import app
+from absl import flags
+from absl import logging
+import gin
+import tensorflow as tf
+
+from official.nlp.bert import configs as bert_configs
+from official.nlp.bert import run_squad_helper
+from official.nlp.bert import tokenization
+from official.nlp.data import squad_lib as squad_lib_wp
+from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
+
+
+flags.DEFINE_string('vocab_file', None,
+ 'The vocabulary file that the BERT model was trained on.')
+
+# More flags can be found in run_squad_helper.
+run_squad_helper.define_common_squad_flags()
+
+FLAGS = flags.FLAGS
+
+
+def train_squad(strategy,
+ input_meta_data,
+ custom_callbacks=None,
+ run_eagerly=False,
+ init_checkpoint=None,
+ sub_model_export_name=None):
+ """Run bert squad training."""
+ bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+ init_checkpoint = init_checkpoint or FLAGS.init_checkpoint
+ run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
+ custom_callbacks, run_eagerly, init_checkpoint,
+ sub_model_export_name=sub_model_export_name)
+
+
+def predict_squad(strategy, input_meta_data):
+ """Makes predictions for the squad dataset."""
+ bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+ tokenizer = tokenization.FullTokenizer(
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+ run_squad_helper.predict_squad(
+ strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
+
+
+def eval_squad(strategy, input_meta_data):
+ """Evaluate on the squad dataset."""
+ bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+ tokenizer = tokenization.FullTokenizer(
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+ eval_metrics = run_squad_helper.eval_squad(
+ strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
+ return eval_metrics
+
+
+def export_squad(model_export_path, input_meta_data):
+ """Exports a trained model as a `SavedModel` for inference.
+
+ Args:
+ model_export_path: a string specifying the path to the SavedModel directory.
+ input_meta_data: dictionary containing meta data about input and model.
+
+ Raises:
+    ValueError: If the export path is not specified (empty string or None).
+ """
+ bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+ run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)
+
+
+def main(_):
+ gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
+
+ with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
+ input_meta_data = json.loads(reader.read().decode('utf-8'))
+
+ if FLAGS.mode == 'export_only':
+ export_squad(FLAGS.model_export_path, input_meta_data)
+ return
+
+ # Configures cluster spec for multi-worker distribution strategy.
+ if FLAGS.num_gpus > 0:
+ _ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
+ FLAGS.task_index)
+ strategy = distribution_utils.get_distribution_strategy(
+ distribution_strategy=FLAGS.distribution_strategy,
+ num_gpus=FLAGS.num_gpus,
+ all_reduce_alg=FLAGS.all_reduce_alg,
+ tpu_address=FLAGS.tpu)
+
+ if 'train' in FLAGS.mode:
+ if FLAGS.log_steps:
+ custom_callbacks = [keras_utils.TimeHistory(
+ batch_size=FLAGS.train_batch_size,
+ log_steps=FLAGS.log_steps,
+ logdir=FLAGS.model_dir,
+ )]
+ else:
+ custom_callbacks = None
+
+ train_squad(
+ strategy,
+ input_meta_data,
+ custom_callbacks=custom_callbacks,
+ run_eagerly=FLAGS.run_eagerly,
+ sub_model_export_name=FLAGS.sub_model_export_name,
+ )
+ if 'predict' in FLAGS.mode:
+ predict_squad(strategy, input_meta_data)
+ if 'eval' in FLAGS.mode:
+ eval_metrics = eval_squad(strategy, input_meta_data)
+ f1_score = eval_metrics['final_f1']
+ logging.info('SQuAD eval F1-score: %f', f1_score)
+ summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
+ summary_writer = tf.summary.create_file_writer(summary_dir)
+ with summary_writer.as_default():
+ # TODO(lehou): write to the correct step number.
+ tf.summary.scalar('F1-score', f1_score, step=0)
+ summary_writer.flush()
+ # Also write eval_metrics to json file.
+ squad_lib_wp.write_to_json_files(
+ eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
+ time.sleep(60)
+
+
+if __name__ == '__main__':
+ flags.mark_flag_as_required('bert_config_file')
+ flags.mark_flag_as_required('model_dir')
+ app.run(main)
diff --git a/models/official/nlp/bert/run_squad_helper.py b/models/official/nlp/bert/run_squad_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..b03e356d91bdf6a9edf9486f505526852c6c7ef6
--- /dev/null
+++ b/models/official/nlp/bert/run_squad_helper.py
@@ -0,0 +1,481 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Library for running BERT family models on SQuAD 1.1/2.0 in TF 2.x."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import json
+import os
+from absl import flags
+from absl import logging
+import tensorflow as tf
+from official.modeling import performance
+from official.nlp import optimization
+from official.nlp.bert import bert_models
+from official.nlp.bert import common_flags
+from official.nlp.bert import input_pipeline
+from official.nlp.bert import model_saving_utils
+from official.nlp.bert import model_training_utils
+from official.nlp.bert import squad_evaluate_v1_1
+from official.nlp.bert import squad_evaluate_v2_0
+from official.nlp.data import squad_lib_sp
+from official.utils.misc import keras_utils
+
+
+def define_common_squad_flags():
+ """Defines common flags used by SQuAD tasks."""
+ flags.DEFINE_enum(
+ 'mode', 'train_and_eval',
+ ['train_and_eval', 'train_and_predict',
+ 'train', 'eval', 'predict', 'export_only'],
+ 'One of {"train_and_eval", "train_and_predict", '
+ '"train", "eval", "predict", "export_only"}. '
+ '`train_and_eval`: train & predict to json files & compute eval metrics. '
+ '`train_and_predict`: train & predict to json files. '
+ '`train`: only trains the model. '
+ '`eval`: predict answers from squad json file & compute eval metrics. '
+ '`predict`: predict answers from the squad json file. '
+ '`export_only`: will take the latest checkpoint inside '
+ 'model_dir and export a `SavedModel`.')
+ flags.DEFINE_string('train_data_path', '',
+ 'Training data path with train tfrecords.')
+ flags.DEFINE_string(
+ 'input_meta_data_path', None,
+ 'Path to file that contains meta data about input '
+ 'to be used for training and evaluation.')
+ # Model training specific flags.
+ flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
+ # Predict processing related.
+ flags.DEFINE_string('predict_file', None,
+ 'SQuAD prediction json file path. '
+ '`predict` mode supports multiple files: one can use '
+ 'wildcard to specify multiple files and it can also be '
+ 'multiple file patterns separated by comma. Note that '
+ '`eval` mode only supports a single predict file.')
+ flags.DEFINE_bool(
+ 'do_lower_case', True,
+ 'Whether to lower case the input text. Should be True for uncased '
+ 'models and False for cased models.')
+ flags.DEFINE_float(
+ 'null_score_diff_threshold', 0.0,
+ 'If null_score - best_non_null is greater than the threshold, '
+ 'predict null. This is only used for SQuAD v2.')
+ flags.DEFINE_bool(
+ 'verbose_logging', False,
+ 'If true, all of the warnings related to data processing will be '
+ 'printed. A number of warnings are expected for a normal SQuAD '
+ 'evaluation.')
+ flags.DEFINE_integer('predict_batch_size', 8,
+ 'Total batch size for prediction.')
+ flags.DEFINE_integer(
+ 'n_best_size', 20,
+ 'The total number of n-best predictions to generate in the '
+ 'nbest_predictions.json output file.')
+ flags.DEFINE_integer(
+ 'max_answer_length', 30,
+ 'The maximum length of an answer that can be generated. This is needed '
+ 'because the start and end predictions are not conditioned on one '
+ 'another.')
+
+ common_flags.define_common_bert_flags()
+
+
+FLAGS = flags.FLAGS
+
+
+def squad_loss_fn(start_positions,
+ end_positions,
+ start_logits,
+ end_logits):
+ """Returns sparse categorical crossentropy for start/end logits."""
+ start_loss = tf.keras.losses.sparse_categorical_crossentropy(
+ start_positions, start_logits, from_logits=True)
+ end_loss = tf.keras.losses.sparse_categorical_crossentropy(
+ end_positions, end_logits, from_logits=True)
+
+ total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
+ return total_loss
+
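+# Clarifying example (added; not in the upstream file): the span loss is the
+# mean of two sparse cross-entropies, one over start positions and one over
+# end positions. For a single example where both predicted spans are correct,
+# the loss is small:
+#
+#   start_logits = tf.constant([[2.0, 0.5, 0.1]])
+#   end_logits = tf.constant([[0.1, 0.2, 3.0]])
+#   loss = squad_loss_fn(tf.constant([0]), tf.constant([2]),
+#                        start_logits, end_logits)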
+
+def get_loss_fn():
+ """Gets a loss function for squad task."""
+
+ def _loss_fn(labels, model_outputs):
+ start_positions = labels['start_positions']
+ end_positions = labels['end_positions']
+ start_logits, end_logits = model_outputs
+ return squad_loss_fn(
+ start_positions,
+ end_positions,
+ start_logits,
+ end_logits)
+
+ return _loss_fn
+
+
+RawResult = collections.namedtuple('RawResult',
+ ['unique_id', 'start_logits', 'end_logits'])
+
+
+def get_raw_results(predictions):
+ """Converts multi-replica predictions to RawResult."""
+ for unique_ids, start_logits, end_logits in zip(predictions['unique_ids'],
+ predictions['start_logits'],
+ predictions['end_logits']):
+ for values in zip(unique_ids.numpy(), start_logits.numpy(),
+ end_logits.numpy()):
+ yield RawResult(
+ unique_id=values[0],
+ start_logits=values[1].tolist(),
+ end_logits=values[2].tolist())
+
+
+def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
+ is_training):
+ """Gets a closure to create a dataset.."""
+
+ def _dataset_fn(ctx=None):
+ """Returns tf.data.Dataset for distributed BERT pretraining."""
+ batch_size = ctx.get_per_replica_batch_size(
+ global_batch_size) if ctx else global_batch_size
+ dataset = input_pipeline.create_squad_dataset(
+ input_file_pattern,
+ max_seq_length,
+ batch_size,
+ is_training=is_training,
+ input_pipeline_context=ctx)
+ return dataset
+
+ return _dataset_fn
+
+
+def get_squad_model_to_predict(strategy, bert_config, checkpoint_path,
+ input_meta_data):
+ """Gets a squad model to make predictions."""
+ with strategy.scope():
+ # Prediction always uses float32, even if training uses mixed precision.
+ tf.keras.mixed_precision.experimental.set_policy('float32')
+ squad_model, _ = bert_models.squad_model(
+ bert_config,
+ input_meta_data['max_seq_length'],
+ hub_module_url=FLAGS.hub_module_url)
+
+ if checkpoint_path is None:
+ checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
+ logging.info('Restoring checkpoints from %s', checkpoint_path)
+ checkpoint = tf.train.Checkpoint(model=squad_model)
+ checkpoint.restore(checkpoint_path).expect_partial()
+ return squad_model
+
+
+def predict_squad_customized(strategy,
+ input_meta_data,
+ predict_tfrecord_path,
+ num_steps,
+ squad_model):
+ """Make predictions using a Bert-based squad model."""
+ predict_dataset_fn = get_dataset_fn(
+ predict_tfrecord_path,
+ input_meta_data['max_seq_length'],
+ FLAGS.predict_batch_size,
+ is_training=False)
+ predict_iterator = iter(
+ strategy.experimental_distribute_datasets_from_function(
+ predict_dataset_fn))
+
+ @tf.function
+ def predict_step(iterator):
+ """Predicts on distributed devices."""
+
+ def _replicated_step(inputs):
+ """Replicated prediction calculation."""
+ x, _ = inputs
+ unique_ids = x.pop('unique_ids')
+ start_logits, end_logits = squad_model(x, training=False)
+ return dict(
+ unique_ids=unique_ids,
+ start_logits=start_logits,
+ end_logits=end_logits)
+
+ outputs = strategy.run(_replicated_step, args=(next(iterator),))
+ return tf.nest.map_structure(strategy.experimental_local_results, outputs)
+
+ all_results = []
+ for _ in range(num_steps):
+ predictions = predict_step(predict_iterator)
+ for result in get_raw_results(predictions):
+ all_results.append(result)
+ if len(all_results) % 100 == 0:
+ logging.info('Made predictions for %d records.', len(all_results))
+ return all_results
+
+
+def train_squad(strategy,
+ input_meta_data,
+ bert_config,
+ custom_callbacks=None,
+ run_eagerly=False,
+ init_checkpoint=None,
+ sub_model_export_name=None):
+ """Run bert squad training."""
+ if strategy:
+ logging.info('Training using customized training loop with distribution'
+ ' strategy.')
+ # Enables XLA in Session Config. Should not be set for TPU.
+ keras_utils.set_session_config(FLAGS.enable_xla)
+ performance.set_mixed_precision_policy(common_flags.dtype())
+
+ epochs = FLAGS.num_train_epochs
+ num_train_examples = input_meta_data['train_data_size']
+ max_seq_length = input_meta_data['max_seq_length']
+ steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
+ warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
+ train_input_fn = get_dataset_fn(
+ FLAGS.train_data_path,
+ max_seq_length,
+ FLAGS.train_batch_size,
+ is_training=True)
+
+ def _get_squad_model():
+ """Get Squad model and optimizer."""
+ squad_model, core_model = bert_models.squad_model(
+ bert_config,
+ max_seq_length,
+ hub_module_url=FLAGS.hub_module_url,
+ hub_module_trainable=FLAGS.hub_module_trainable)
+ optimizer = optimization.create_optimizer(FLAGS.learning_rate,
+ steps_per_epoch * epochs,
+ warmup_steps,
+ FLAGS.end_lr,
+ FLAGS.optimizer_type)
+
+ squad_model.optimizer = performance.configure_optimizer(
+ optimizer,
+ use_float16=common_flags.use_float16(),
+ use_graph_rewrite=common_flags.use_graph_rewrite())
+ return squad_model, core_model
+
+ # If explicit_allreduce = True, apply_gradients() no longer implicitly
+ # allreduce gradients, users manually allreduce gradient and pass the
+ # allreduced grads_and_vars to apply_gradients(). clip_by_global_norm will be
+ # applied to allreduced gradients.
+ def clip_by_global_norm_callback(grads_and_vars):
+ grads, variables = zip(*grads_and_vars)
+ (clipped_grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
+ return zip(clipped_grads, variables)
+
+ model_training_utils.run_customized_training_loop(
+ strategy=strategy,
+ model_fn=_get_squad_model,
+ loss_fn=get_loss_fn(),
+ model_dir=FLAGS.model_dir,
+ steps_per_epoch=steps_per_epoch,
+ steps_per_loop=FLAGS.steps_per_loop,
+ epochs=epochs,
+ train_input_fn=train_input_fn,
+ init_checkpoint=init_checkpoint or FLAGS.init_checkpoint,
+ sub_model_export_name=sub_model_export_name,
+ run_eagerly=run_eagerly,
+ custom_callbacks=custom_callbacks,
+ explicit_allreduce=False,
+ post_allreduce_callbacks=[clip_by_global_norm_callback])
+
+
+def prediction_output_squad(strategy, input_meta_data, tokenizer, squad_lib,
+ predict_file, squad_model):
+ """Makes predictions for a squad dataset."""
+ doc_stride = input_meta_data['doc_stride']
+ max_query_length = input_meta_data['max_query_length']
+ # Whether data should be in Ver 2.0 format.
+ version_2_with_negative = input_meta_data.get('version_2_with_negative',
+ False)
+ eval_examples = squad_lib.read_squad_examples(
+ input_file=predict_file,
+ is_training=False,
+ version_2_with_negative=version_2_with_negative)
+
+ eval_writer = squad_lib.FeatureWriter(
+ filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
+ is_training=False)
+ eval_features = []
+
+ def _append_feature(feature, is_padding):
+ if not is_padding:
+ eval_features.append(feature)
+ eval_writer.process_feature(feature)
+
+ # TPU requires a fixed batch size for all batches, therefore the number
+ # of examples must be a multiple of the batch size, or else examples
+ # will get dropped. So we pad with fake examples which are ignored
+ # later on.
+ kwargs = dict(
+ examples=eval_examples,
+ tokenizer=tokenizer,
+ max_seq_length=input_meta_data['max_seq_length'],
+ doc_stride=doc_stride,
+ max_query_length=max_query_length,
+ is_training=False,
+ output_fn=_append_feature,
+ batch_size=FLAGS.predict_batch_size)
+
+ # squad_lib_sp requires one more argument 'do_lower_case'.
+ if squad_lib == squad_lib_sp:
+ kwargs['do_lower_case'] = FLAGS.do_lower_case
+ dataset_size = squad_lib.convert_examples_to_features(**kwargs)
+ eval_writer.close()
+
+ logging.info('***** Running predictions *****')
+ logging.info(' Num orig examples = %d', len(eval_examples))
+ logging.info(' Num split examples = %d', len(eval_features))
+ logging.info(' Batch size = %d', FLAGS.predict_batch_size)
+
+ num_steps = int(dataset_size / FLAGS.predict_batch_size)
+ all_results = predict_squad_customized(
+ strategy, input_meta_data, eval_writer.filename, num_steps, squad_model)
+
+ all_predictions, all_nbest_json, scores_diff_json = (
+ squad_lib.postprocess_output(
+ eval_examples,
+ eval_features,
+ all_results,
+ FLAGS.n_best_size,
+ FLAGS.max_answer_length,
+ FLAGS.do_lower_case,
+ version_2_with_negative=version_2_with_negative,
+ null_score_diff_threshold=FLAGS.null_score_diff_threshold,
+ verbose=FLAGS.verbose_logging))
+
+ return all_predictions, all_nbest_json, scores_diff_json
+
+
+def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
+ squad_lib, version_2_with_negative, file_prefix=''):
+ """Save output to json files."""
+ output_prediction_file = os.path.join(FLAGS.model_dir,
+ '%spredictions.json' % file_prefix)
+ output_nbest_file = os.path.join(FLAGS.model_dir,
+ '%snbest_predictions.json' % file_prefix)
+ output_null_log_odds_file = os.path.join(FLAGS.model_dir, file_prefix,
+ '%snull_odds.json' % file_prefix)
+ logging.info('Writing predictions to: %s', (output_prediction_file))
+ logging.info('Writing nbest to: %s', (output_nbest_file))
+
+ squad_lib.write_to_json_files(all_predictions, output_prediction_file)
+ squad_lib.write_to_json_files(all_nbest_json, output_nbest_file)
+ if version_2_with_negative:
+ squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file)
+
+
+def _get_matched_files(input_path):
+ """Returns all files that matches the input_path."""
+ input_patterns = input_path.strip().split(',')
+ all_matched_files = []
+ for input_pattern in input_patterns:
+ input_pattern = input_pattern.strip()
+ if not input_pattern:
+ continue
+ matched_files = tf.io.gfile.glob(input_pattern)
+ if not matched_files:
+ raise ValueError('%s does not match any files.' % input_pattern)
+ else:
+ all_matched_files.extend(matched_files)
+ return sorted(all_matched_files)
+
+
+def predict_squad(strategy,
+ input_meta_data,
+ tokenizer,
+ bert_config,
+ squad_lib,
+ init_checkpoint=None):
+ """Get prediction results and evaluate them to hard drive."""
+ if init_checkpoint is None:
+ init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
+
+ all_predict_files = _get_matched_files(FLAGS.predict_file)
+ squad_model = get_squad_model_to_predict(strategy, bert_config,
+ init_checkpoint, input_meta_data)
+ for idx, predict_file in enumerate(all_predict_files):
+ all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
+ strategy, input_meta_data, tokenizer, squad_lib, predict_file,
+ squad_model)
+ if len(all_predict_files) == 1:
+ file_prefix = ''
+ else:
+      # If predict_file is /path/xquad.ar.json, `file_prefix` becomes
+      # "xquad.ar-".
+ file_prefix = '%s-' % os.path.splitext(
+ os.path.basename(all_predict_files[idx]))[0]
+ dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
+ input_meta_data.get('version_2_with_negative', False),
+ file_prefix)
+
+
+def eval_squad(strategy,
+ input_meta_data,
+ tokenizer,
+ bert_config,
+ squad_lib,
+ init_checkpoint=None):
+ """Get prediction results and evaluate them against ground truth."""
+ if init_checkpoint is None:
+ init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
+
+ all_predict_files = _get_matched_files(FLAGS.predict_file)
+ if len(all_predict_files) != 1:
+ raise ValueError('`eval_squad` only supports one predict file, '
+ 'but got %s' % all_predict_files)
+
+ squad_model = get_squad_model_to_predict(strategy, bert_config,
+ init_checkpoint, input_meta_data)
+ all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
+ strategy, input_meta_data, tokenizer, squad_lib, all_predict_files[0],
+ squad_model)
+ dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
+ input_meta_data.get('version_2_with_negative', False))
+
+ with tf.io.gfile.GFile(FLAGS.predict_file, 'r') as reader:
+ dataset_json = json.load(reader)
+ pred_dataset = dataset_json['data']
+ if input_meta_data.get('version_2_with_negative', False):
+ eval_metrics = squad_evaluate_v2_0.evaluate(pred_dataset,
+ all_predictions,
+ scores_diff_json)
+ else:
+ eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
+ return eval_metrics
+
+
+def export_squad(model_export_path, input_meta_data, bert_config):
+ """Exports a trained model as a `SavedModel` for inference.
+
+ Args:
+ model_export_path: a string specifying the path to the SavedModel directory.
+ input_meta_data: dictionary containing meta data about input and model.
+ bert_config: Bert configuration file to define core bert layers.
+
+ Raises:
+    ValueError: If the export path is not specified (empty string or None).
+ """
+ if not model_export_path:
+ raise ValueError('Export path is not specified: %s' % model_export_path)
+ # Export uses float32 for now, even if training uses mixed precision.
+ tf.keras.mixed_precision.experimental.set_policy('float32')
+ squad_model, _ = bert_models.squad_model(bert_config,
+ input_meta_data['max_seq_length'])
+ model_saving_utils.export_bert_model(
+ model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
diff --git a/models/official/nlp/bert/serving.py b/models/official/nlp/bert/serving.py
new file mode 100644
index 0000000000000000000000000000000000000000..895f61dc37adf40d93ea347817abbb18966e157e
--- /dev/null
+++ b/models/official/nlp/bert/serving.py
@@ -0,0 +1,134 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Examples of SavedModel export for tf-serving."""
+
+from absl import app
+from absl import flags
+import tensorflow as tf
+
+from official.nlp.bert import bert_models
+from official.nlp.bert import configs
+
+flags.DEFINE_integer("sequence_length", None,
+ "Sequence length to parse the tf.Example. If "
+ "sequence_length > 0, add a signature for serialized "
+ "tf.Example and define the parsing specification by the "
+ "sequence_length.")
+flags.DEFINE_string("bert_config_file", None,
+ "Bert configuration file to define core bert layers.")
+flags.DEFINE_string("model_checkpoint_path", None,
+ "File path to TF model checkpoint.")
+flags.DEFINE_string("export_path", None,
+ "Destination folder to export the serving SavedModel.")
+
+FLAGS = flags.FLAGS
+
+
+class BertServing(tf.keras.Model):
+ """Bert transformer encoder model for serving."""
+
+ def __init__(self, bert_config, name_to_features=None, name="serving_model"):
+ super(BertServing, self).__init__(name=name)
+ self.encoder = bert_models.get_transformer_encoder(
+ bert_config, sequence_length=None)
+ self.name_to_features = name_to_features
+
+ def call(self, inputs):
+ input_word_ids = inputs["input_ids"]
+ input_mask = inputs["input_mask"]
+ input_type_ids = inputs["segment_ids"]
+
+ encoder_outputs, _ = self.encoder(
+ [input_word_ids, input_mask, input_type_ids])
+ return encoder_outputs
+
+ def serve_body(self, input_ids, input_mask=None, segment_ids=None):
+ if segment_ids is None:
+      # Requires that the CLS token is the first token of the inputs.
+ segment_ids = tf.zeros_like(input_ids)
+ if input_mask is None:
+ # The mask has 1 for real tokens and 0 for padding tokens.
+ input_mask = tf.where(
+ tf.equal(input_ids, 0), tf.zeros_like(input_ids),
+ tf.ones_like(input_ids))
+
+ inputs = dict(
+ input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids)
+ return self.call(inputs)
+
+ @tf.function
+ def serve(self, input_ids, input_mask=None, segment_ids=None):
+ outputs = self.serve_body(input_ids, input_mask, segment_ids)
+ # Returns a dictionary to control SignatureDef output signature.
+ return {"outputs": outputs[-1]}
+
+ @tf.function
+ def serve_examples(self, inputs):
+ features = tf.io.parse_example(inputs, self.name_to_features)
+ for key in list(features.keys()):
+ t = features[key]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ features[key] = t
+ return self.serve(
+ features["input_ids"],
+ input_mask=features["input_mask"] if "input_mask" in features else None,
+ segment_ids=features["segment_ids"]
+ if "segment_ids" in features else None)
+
+ @classmethod
+ def export(cls, model, export_dir):
+ if not isinstance(model, cls):
+ raise ValueError("Invalid model instance: %s, it should be a %s" %
+ (model, cls))
+
+ signatures = {
+ "serving_default":
+ model.serve.get_concrete_function(
+ input_ids=tf.TensorSpec(
+ shape=[None, None], dtype=tf.int32, name="inputs")),
+ }
+ if model.name_to_features:
+ signatures[
+ "serving_examples"] = model.serve_examples.get_concrete_function(
+ tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
+ tf.saved_model.save(model, export_dir=export_dir, signatures=signatures)
+
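+# Illustrative follow-up (added; not part of the upstream file): once exported,
+# the SavedModel can be loaded and queried through its signatures. The path
+# and token ids below are placeholders; the input keyword follows the
+# TensorSpec name ('inputs') used when building the signature above.
+#
+#   loaded = tf.saved_model.load('/tmp/bert_serving')  # hypothetical path
+#   serve = loaded.signatures['serving_default']
+#   encoder_output = serve(
+#       inputs=tf.constant([[101, 2023, 2003, 1037, 3231, 102]],
+#                          dtype=tf.int32))['outputs']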
+
+def main(_):
+ sequence_length = FLAGS.sequence_length
+ if sequence_length is not None and sequence_length > 0:
+ name_to_features = {
+ "input_ids": tf.io.FixedLenFeature([sequence_length], tf.int64),
+ "input_mask": tf.io.FixedLenFeature([sequence_length], tf.int64),
+ "segment_ids": tf.io.FixedLenFeature([sequence_length], tf.int64),
+ }
+ else:
+ name_to_features = None
+ bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+ serving_model = BertServing(
+ bert_config=bert_config, name_to_features=name_to_features)
+ checkpoint = tf.train.Checkpoint(model=serving_model.encoder)
+ checkpoint.restore(FLAGS.model_checkpoint_path
+ ).assert_existing_objects_matched().run_restore_ops()
+ BertServing.export(serving_model, FLAGS.export_path)
+
+
+if __name__ == "__main__":
+ flags.mark_flag_as_required("bert_config_file")
+ flags.mark_flag_as_required("model_checkpoint_path")
+ flags.mark_flag_as_required("export_path")
+ app.run(main)
diff --git a/models/official/nlp/bert/squad_evaluate_v1_1.py b/models/official/nlp/bert/squad_evaluate_v1_1.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7f4f4de66813cb4fbdc59cc716911fac064f0c9
--- /dev/null
+++ b/models/official/nlp/bert/squad_evaluate_v1_1.py
@@ -0,0 +1,108 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Evaluation of SQuAD predictions (version 1.1).
+
+The functions are copied from
+https://worksheets.codalab.org/rest/bundles/0xbcd57bee090b421c982906709c8c27e1/contents/blob/.
+
+The SQuAD dataset is described in this paper:
+SQuAD: 100,000+ Questions for Machine Comprehension of Text
+Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang
+https://nlp.stanford.edu/pubs/rajpurkar2016squad.pdf
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import re
+import string
+
+# pylint: disable=g-bad-import-order
+from absl import logging
+# pylint: enable=g-bad-import-order
+
+
+def _normalize_answer(s):
+ """Lowers text and remove punctuation, articles and extra whitespace."""
+
+ def remove_articles(text):
+ return re.sub(r"\b(a|an|the)\b", " ", text)
+
+ def white_space_fix(text):
+ return " ".join(text.split())
+
+ def remove_punc(text):
+ exclude = set(string.punctuation)
+ return "".join(ch for ch in text if ch not in exclude)
+
+ def lower(text):
+ return text.lower()
+
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
+
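+# Worked example (added for clarity): _normalize_answer("The  Eiffel Tower!")
+# returns "eiffel tower" -- lowercasing, punctuation and article removal, and
+# whitespace collapsing all happen before token-level comparison.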
+
+def _f1_score(prediction, ground_truth):
+ """Computes F1 score by comparing prediction to ground truth."""
+ prediction_tokens = _normalize_answer(prediction).split()
+ ground_truth_tokens = _normalize_answer(ground_truth).split()
+ prediction_counter = collections.Counter(prediction_tokens)
+ ground_truth_counter = collections.Counter(ground_truth_tokens)
+ common = prediction_counter & ground_truth_counter
+ num_same = sum(common.values())
+ if num_same == 0:
+ return 0
+ precision = 1.0 * num_same / len(prediction_tokens)
+ recall = 1.0 * num_same / len(ground_truth_tokens)
+ f1 = (2 * precision * recall) / (precision + recall)
+ return f1
+
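+# Worked example (added for clarity): prediction "eiffel tower paris" vs.
+# ground truth "eiffel tower" shares two tokens, so precision = 2/3,
+# recall = 2/2 and f1 = 2 * (2/3 * 1) / (2/3 + 1) = 0.8.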
+
+def _exact_match_score(prediction, ground_truth):
+ """Checks if predicted answer exactly matches ground truth answer."""
+ return _normalize_answer(prediction) == _normalize_answer(ground_truth)
+
+
+def _metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
+ """Computes the max over all metric scores."""
+ scores_for_ground_truths = []
+ for ground_truth in ground_truths:
+ score = metric_fn(prediction, ground_truth)
+ scores_for_ground_truths.append(score)
+ return max(scores_for_ground_truths)
+
+
+def evaluate(dataset, predictions):
+ """Evaluates predictions for a dataset."""
+ f1 = exact_match = total = 0
+ for article in dataset:
+ for paragraph in article["paragraphs"]:
+ for qa in paragraph["qas"]:
+ total += 1
+ if qa["id"] not in predictions:
+ message = "Unanswered question " + qa["id"] + " will receive score 0."
+ logging.error(message)
+ continue
+ ground_truths = [entry["text"] for entry in qa["answers"]]
+ prediction = predictions[qa["id"]]
+ exact_match += _metric_max_over_ground_truths(_exact_match_score,
+ prediction, ground_truths)
+ f1 += _metric_max_over_ground_truths(_f1_score, prediction,
+ ground_truths)
+
+ exact_match = exact_match / total
+ f1 = f1 / total
+
+ return {"exact_match": exact_match, "final_f1": f1}
diff --git a/models/official/nlp/bert/squad_evaluate_v2_0.py b/models/official/nlp/bert/squad_evaluate_v2_0.py
new file mode 100644
index 0000000000000000000000000000000000000000..54fb84e993c3459ffdd2b3d90f870e4d178ab54f
--- /dev/null
+++ b/models/official/nlp/bert/squad_evaluate_v2_0.py
@@ -0,0 +1,252 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Evaluation script for SQuAD version 2.0.
+
+The functions are copied and modified from
+https://raw.githubusercontent.com/white127/SQUAD-2.0-bidaf/master/evaluate-v2.0.py
+
+In addition to basic functionality, we also compute additional statistics and
+plot precision-recall curves if an additional na_prob.json file is provided.
+This file is expected to map question ID's to the model's predicted probability
+that a question is unanswerable.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import re
+import string
+
+from absl import logging
+
+
+def _make_qid_to_has_ans(dataset):
+ qid_to_has_ans = {}
+ for article in dataset:
+ for p in article['paragraphs']:
+ for qa in p['qas']:
+ qid_to_has_ans[qa['id']] = bool(qa['answers'])
+ return qid_to_has_ans
+
+
+def _normalize_answer(s):
+ """Lower text and remove punctuation, articles and extra whitespace."""
+ def remove_articles(text):
+ regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
+ return re.sub(regex, ' ', text)
+ def white_space_fix(text):
+ return ' '.join(text.split())
+ def remove_punc(text):
+ exclude = set(string.punctuation)
+ return ''.join(ch for ch in text if ch not in exclude)
+ def lower(text):
+ return text.lower()
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def _get_tokens(s):
+ if not s: return []
+ return _normalize_answer(s).split()
+
+
+def _compute_exact(a_gold, a_pred):
+ return int(_normalize_answer(a_gold) == _normalize_answer(a_pred))
+
+
+def _compute_f1(a_gold, a_pred):
+ """Compute F1-score."""
+ gold_toks = _get_tokens(a_gold)
+ pred_toks = _get_tokens(a_pred)
+ common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
+ num_same = sum(common.values())
+ if not gold_toks or not pred_toks:
+ # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
+ return int(gold_toks == pred_toks)
+ if num_same == 0:
+ return 0
+ precision = 1.0 * num_same / len(pred_toks)
+ recall = 1.0 * num_same / len(gold_toks)
+ f1 = (2 * precision * recall) / (precision + recall)
+ return f1
+
+
+def _get_raw_scores(dataset, predictions):
+ """Compute raw scores."""
+ exact_scores = {}
+ f1_scores = {}
+ for article in dataset:
+ for p in article['paragraphs']:
+ for qa in p['qas']:
+ qid = qa['id']
+ gold_answers = [a['text'] for a in qa['answers']
+ if _normalize_answer(a['text'])]
+ if not gold_answers:
+ # For unanswerable questions, only correct answer is empty string
+ gold_answers = ['']
+ if qid not in predictions:
+ logging.error('Missing prediction for %s', qid)
+ continue
+ a_pred = predictions[qid]
+ # Take max over all gold answers
+ exact_scores[qid] = max(_compute_exact(a, a_pred) for a in gold_answers)
+ f1_scores[qid] = max(_compute_f1(a, a_pred) for a in gold_answers)
+ return exact_scores, f1_scores
+
+
+def _apply_no_ans_threshold(
+ scores, na_probs, qid_to_has_ans, na_prob_thresh=1.0):
+ new_scores = {}
+ for qid, s in scores.items():
+ pred_na = na_probs[qid] > na_prob_thresh
+ if pred_na:
+ new_scores[qid] = float(not qid_to_has_ans[qid])
+ else:
+ new_scores[qid] = s
+ return new_scores
+
+
+def _make_eval_dict(exact_scores, f1_scores, qid_list=None):
+ """Make evaluation result dictionary."""
+ if not qid_list:
+ total = len(exact_scores)
+ return collections.OrderedDict([
+ ('exact', 100.0 * sum(exact_scores.values()) / total),
+ ('f1', 100.0 * sum(f1_scores.values()) / total),
+ ('total', total),
+ ])
+ else:
+ total = len(qid_list)
+ return collections.OrderedDict([
+ ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
+ ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
+ ('total', total),
+ ])
+
+
+def _merge_eval(main_eval, new_eval, prefix):
+ for k in new_eval:
+ main_eval['%s_%s' % (prefix, k)] = new_eval[k]
+
+
+def _make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans):
+ """Make evaluation dictionary containing average recision recall."""
+ qid_list = sorted(na_probs, key=lambda k: na_probs[k])
+ true_pos = 0.0
+ cur_p = 1.0
+ cur_r = 0.0
+ precisions = [1.0]
+ recalls = [0.0]
+ avg_prec = 0.0
+ for i, qid in enumerate(qid_list):
+ if qid_to_has_ans[qid]:
+ true_pos += scores[qid]
+ cur_p = true_pos / float(i+1)
+ cur_r = true_pos / float(num_true_pos)
+ if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]:
+ # i.e., if we can put a threshold after this point
+ avg_prec += cur_p * (cur_r - recalls[-1])
+ precisions.append(cur_p)
+ recalls.append(cur_r)
+ return {'ap': 100.0 * avg_prec}
+
+
+def _run_precision_recall_analysis(
+ main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans):
+ """Run precision recall analysis and return result dictionary."""
+ num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
+ if num_true_pos == 0:
+ return
+ pr_exact = _make_precision_recall_eval(
+ exact_raw, na_probs, num_true_pos, qid_to_has_ans)
+ pr_f1 = _make_precision_recall_eval(
+ f1_raw, na_probs, num_true_pos, qid_to_has_ans)
+ oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
+ pr_oracle = _make_precision_recall_eval(
+ oracle_scores, na_probs, num_true_pos, qid_to_has_ans)
+ _merge_eval(main_eval, pr_exact, 'pr_exact')
+ _merge_eval(main_eval, pr_f1, 'pr_f1')
+ _merge_eval(main_eval, pr_oracle, 'pr_oracle')
+
+
+def _find_best_thresh(predictions, scores, na_probs, qid_to_has_ans):
+ """Find the best threshold for no answer probability."""
+ num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
+ cur_score = num_no_ans
+ best_score = cur_score
+ best_thresh = 0.0
+ qid_list = sorted(na_probs, key=lambda k: na_probs[k])
+ for qid in qid_list:
+ if qid not in scores: continue
+ if qid_to_has_ans[qid]:
+ diff = scores[qid]
+ else:
+ if predictions[qid]:
+ diff = -1
+ else:
+ diff = 0
+ cur_score += diff
+ if cur_score > best_score:
+ best_score = cur_score
+ best_thresh = na_probs[qid]
+ return 100.0 * best_score / len(scores), best_thresh
+
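+# Added explanatory note: the sweep starts from the score obtained by
+# predicting "no answer" for every question (cur_score = num_no_ans), then
+# walks questions in order of increasing no-answer probability, flipping each
+# one to "answered"; the threshold kept is the na_prob of the flip that
+# maximizes the running score.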
+
+def _find_all_best_thresh(
+ main_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans):
+ best_exact, exact_thresh = _find_best_thresh(
+ predictions, exact_raw, na_probs, qid_to_has_ans)
+ best_f1, f1_thresh = _find_best_thresh(
+ predictions, f1_raw, na_probs, qid_to_has_ans)
+ main_eval['final_exact'] = best_exact
+ main_eval['final_exact_thresh'] = exact_thresh
+ main_eval['final_f1'] = best_f1
+ main_eval['final_f1_thresh'] = f1_thresh
+
+
+def evaluate(dataset, predictions, na_probs=None):
+ """Evaluate prediction results."""
+ new_orig_data = []
+ for article in dataset:
+ for p in article['paragraphs']:
+ for qa in p['qas']:
+ if qa['id'] in predictions:
+ new_para = {'qas': [qa]}
+ new_article = {'paragraphs': [new_para]}
+ new_orig_data.append(new_article)
+ dataset = new_orig_data
+
+ if na_probs is None:
+ na_probs = {k: 0.0 for k in predictions}
+ qid_to_has_ans = _make_qid_to_has_ans(dataset) # maps qid to True/False
+ has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
+ no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
+ exact_raw, f1_raw = _get_raw_scores(dataset, predictions)
+ exact_thresh = _apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans)
+ f1_thresh = _apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans)
+ out_eval = _make_eval_dict(exact_thresh, f1_thresh)
+ if has_ans_qids:
+ has_ans_eval = _make_eval_dict(
+ exact_thresh, f1_thresh, qid_list=has_ans_qids)
+ _merge_eval(out_eval, has_ans_eval, 'HasAns')
+ if no_ans_qids:
+ no_ans_eval = _make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
+ _merge_eval(out_eval, no_ans_eval, 'NoAns')
+
+ _find_all_best_thresh(
+ out_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans)
+ _run_precision_recall_analysis(
+ out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans)
+ return out_eval
diff --git a/models/official/nlp/bert/tf1_checkpoint_converter_lib.py b/models/official/nlp/bert/tf1_checkpoint_converter_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..122e455210ae70cd9af04912b95a600a3d23d09a
--- /dev/null
+++ b/models/official/nlp/bert/tf1_checkpoint_converter_lib.py
@@ -0,0 +1,195 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Convert checkpoints created by Estimator (tf1) to be Keras compatible."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow.compat.v1 as tf # TF 1.x
+
+# Mapping between old <=> new names. The source pattern in original variable
+# name will be replaced by destination pattern.
+BERT_NAME_REPLACEMENTS = (
+ ("bert", "bert_model"),
+ ("embeddings/word_embeddings", "word_embeddings/embeddings"),
+ ("embeddings/token_type_embeddings",
+ "embedding_postprocessor/type_embeddings"),
+ ("embeddings/position_embeddings",
+ "embedding_postprocessor/position_embeddings"),
+ ("embeddings/LayerNorm", "embedding_postprocessor/layer_norm"),
+ ("attention/self", "self_attention"),
+ ("attention/output/dense", "self_attention_output"),
+ ("attention/output/LayerNorm", "self_attention_layer_norm"),
+ ("intermediate/dense", "intermediate"),
+ ("output/dense", "output"),
+ ("output/LayerNorm", "output_layer_norm"),
+ ("pooler/dense", "pooler_transform"),
+)
+
+BERT_V2_NAME_REPLACEMENTS = (
+ ("bert/", ""),
+ ("encoder", "transformer"),
+ ("embeddings/word_embeddings", "word_embeddings/embeddings"),
+ ("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
+ ("embeddings/position_embeddings", "position_embedding/embeddings"),
+ ("embeddings/LayerNorm", "embeddings/layer_norm"),
+ ("attention/self", "self_attention"),
+ ("attention/output/dense", "self_attention/attention_output"),
+ ("attention/output/LayerNorm", "self_attention_layer_norm"),
+ ("intermediate/dense", "intermediate"),
+ ("output/dense", "output"),
+ ("output/LayerNorm", "output_layer_norm"),
+ ("pooler/dense", "pooler_transform"),
+ ("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
+ ("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
+ ("cls/seq_relationship/output_weights",
+ "predictions/transform/logits/kernel"),
+)
+
+BERT_PERMUTATIONS = ()
+
+BERT_V2_PERMUTATIONS = (("cls/seq_relationship/output_weights", (1, 0)),)
+
+
+def _bert_name_replacement(var_name, name_replacements):
+ """Gets the variable name replacement."""
+ for src_pattern, tgt_pattern in name_replacements:
+ if src_pattern in var_name:
+ old_var_name = var_name
+ var_name = var_name.replace(src_pattern, tgt_pattern)
+ tf.logging.info("Converted: %s --> %s", old_var_name, var_name)
+ return var_name
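+
+
+# Editorial sketch (not part of the upstream file): with BERT_V2_NAME_REPLACEMENTS
+# defined above, a TF1 variable name would be rewritten roughly as follows:
+#
+#   _bert_name_replacement(
+#       "bert/encoder/layer_0/attention/self/query/kernel",
+#       BERT_V2_NAME_REPLACEMENTS)
+#   # -> "transformer/layer_0/self_attention/query/kernel"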
+
+
+def _has_exclude_patterns(name, exclude_patterns):
+ """Checks if a string contains substrings that match patterns to exclude."""
+ for p in exclude_patterns:
+ if p in name:
+ return True
+ return False
+
+
+def _get_permutation(name, permutations):
+ """Checks whether a variable requires transposition by pattern matching."""
+ for src_pattern, permutation in permutations:
+ if src_pattern in name:
+ tf.logging.info("Permuted: %s --> %s", name, permutation)
+ return permutation
+
+ return None
+
+
+def _get_new_shape(name, shape, num_heads):
+ """Checks whether a variable requires reshape by pattern matching."""
+ if "self_attention/attention_output/kernel" in name:
+ return tuple([num_heads, shape[0] // num_heads, shape[1]])
+ if "self_attention/attention_output/bias" in name:
+ return shape
+
+ patterns = [
+ "self_attention/query", "self_attention/value", "self_attention/key"
+ ]
+ for pattern in patterns:
+ if pattern in name:
+ if "kernel" in name:
+ return tuple([shape[0], num_heads, shape[1] // num_heads])
+ if "bias" in name:
+ return tuple([num_heads, shape[0] // num_heads])
+ return None
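+
+
+# Editorial sketch (not part of the upstream file): for a 12-head BERT-base
+# model, a query kernel of shape (768, 768) would be reshaped per head:
+#
+#   _get_new_shape("transformer/layer_0/self_attention/query/kernel",
+#                  (768, 768), num_heads=12)
+#   # -> (768, 12, 64)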
+
+
+def create_v2_checkpoint(model, src_checkpoint, output_path):
+ """Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
+  # Uses streaming-restore in eager mode to read V1 name-based checkpoints.
+ model.load_weights(src_checkpoint).assert_existing_objects_matched()
+ checkpoint = tf.train.Checkpoint(model=model)
+ checkpoint.save(output_path)
+
+
+def convert(checkpoint_from_path,
+ checkpoint_to_path,
+ num_heads,
+ name_replacements,
+ permutations,
+ exclude_patterns=None):
+ """Migrates the names of variables within a checkpoint.
+
+ Args:
+ checkpoint_from_path: Path to source checkpoint to be read in.
+ checkpoint_to_path: Path to checkpoint to be written out.
+ num_heads: The number of heads of the model.
+ name_replacements: A list of tuples of the form (match_str, replace_str)
+ describing variable names to adjust.
+ permutations: A list of tuples of the form (match_str, permutation)
+ describing permutations to apply to given variables. Note that match_str
+ should match the original variable name, not the replaced one.
+ exclude_patterns: A list of string patterns to exclude variables from
+ checkpoint conversion.
+
+  The converted variables are written as a new checkpoint at
+  `checkpoint_to_path`; nothing is returned.
+ """
+ with tf.Graph().as_default():
+ tf.logging.info("Reading checkpoint_from_path %s", checkpoint_from_path)
+ reader = tf.train.NewCheckpointReader(checkpoint_from_path)
+ name_shape_map = reader.get_variable_to_shape_map()
+ new_variable_map = {}
+ conversion_map = {}
+ for var_name in name_shape_map:
+ if exclude_patterns and _has_exclude_patterns(var_name, exclude_patterns):
+ continue
+ # Get the original tensor data.
+ tensor = reader.get_tensor(var_name)
+
+ # Look up the new variable name, if any.
+ new_var_name = _bert_name_replacement(var_name, name_replacements)
+
+ # See if we need to reshape the underlying tensor.
+ new_shape = None
+ if num_heads > 0:
+ new_shape = _get_new_shape(new_var_name, tensor.shape, num_heads)
+ if new_shape:
+        tf.logging.info("Variable %s has a shape change from %s to %s",
+                        var_name, tensor.shape, new_shape)
+ tensor = np.reshape(tensor, new_shape)
+
+ # See if we need to permute the underlying tensor.
+ permutation = _get_permutation(var_name, permutations)
+ if permutation:
+ tensor = np.transpose(tensor, permutation)
+
+ # Create a new variable with the possibly-reshaped or transposed tensor.
+ var = tf.Variable(tensor, name=var_name)
+
+ # Save the variable into the new variable map.
+ new_variable_map[new_var_name] = var
+
+      # Keep a list of converted variables for sanity checking.
+ if new_var_name != var_name:
+ conversion_map[var_name] = new_var_name
+
+ saver = tf.train.Saver(new_variable_map)
+
+ with tf.Session() as sess:
+ sess.run(tf.global_variables_initializer())
+ tf.logging.info("Writing checkpoint_to_path %s", checkpoint_to_path)
+ saver.save(sess, checkpoint_to_path, write_meta_graph=False)
+
+ tf.logging.info("Summary:")
+ tf.logging.info(" Converted %d variable name(s).", len(new_variable_map))
+ tf.logging.info(" Converted: %s", str(conversion_map))
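+
+
+# Editorial usage sketch (not part of the upstream file): converting a TF1 BERT
+# checkpoint with the V2 name map, mirroring how convert_checkpoint() in
+# tf2_encoder_checkpoint_converter.py calls this library. The paths are
+# placeholders.
+#
+#   convert(checkpoint_from_path="/tmp/bert_v1/bert_model.ckpt",
+#           checkpoint_to_path="/tmp/bert_v1_renamed/ckpt",
+#           num_heads=12,
+#           name_replacements=BERT_V2_NAME_REPLACEMENTS,
+#           permutations=BERT_V2_PERMUTATIONS,
+#           exclude_patterns=["adam", "Adam"])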
diff --git a/models/official/nlp/bert/tf2_encoder_checkpoint_converter.py b/models/official/nlp/bert/tf2_encoder_checkpoint_converter.py
new file mode 100644
index 0000000000000000000000000000000000000000..2faf6ea2cfb9f0d71d0a79dff101e0408fa41778
--- /dev/null
+++ b/models/official/nlp/bert/tf2_encoder_checkpoint_converter.py
@@ -0,0 +1,109 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A converter from a V1 BERT encoder checkpoint to a V2 encoder checkpoint.
+
+The conversion will yield an object-oriented checkpoint that can be used
+to restore a TransformerEncoder object.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import app
+from absl import flags
+
+import tensorflow as tf
+from official.modeling import activations
+from official.nlp.bert import configs
+from official.nlp.bert import tf1_checkpoint_converter_lib
+from official.nlp.modeling import networks
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("bert_config_file", None,
+ "Bert configuration file to define core bert layers.")
+flags.DEFINE_string(
+ "checkpoint_to_convert", None,
+ "Initial checkpoint from a pretrained BERT model core (that is, only the "
+ "BertModel, with no task heads.)")
+flags.DEFINE_string("converted_checkpoint_path", None,
+ "Name for the created object-based V2 checkpoint.")
+
+
+def _create_bert_model(cfg):
+ """Creates a BERT keras core model from BERT configuration.
+
+ Args:
+ cfg: A `BertConfig` to create the core model.
+ Returns:
+    A TransformerEncoder network.
+ """
+ bert_encoder = networks.TransformerEncoder(
+ vocab_size=cfg.vocab_size,
+ hidden_size=cfg.hidden_size,
+ num_layers=cfg.num_hidden_layers,
+ num_attention_heads=cfg.num_attention_heads,
+ intermediate_size=cfg.intermediate_size,
+ activation=activations.gelu,
+ dropout_rate=cfg.hidden_dropout_prob,
+ attention_dropout_rate=cfg.attention_probs_dropout_prob,
+ sequence_length=cfg.max_position_embeddings,
+ type_vocab_size=cfg.type_vocab_size,
+ initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=cfg.initializer_range),
+ embedding_width=cfg.embedding_size)
+
+ return bert_encoder
+
+
+def convert_checkpoint(bert_config, output_path, v1_checkpoint):
+ """Converts a V1 checkpoint into an OO V2 checkpoint."""
+ output_dir, _ = os.path.split(output_path)
+
+ # Create a temporary V1 name-converted checkpoint in the output directory.
+ temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
+ temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
+ tf1_checkpoint_converter_lib.convert(
+ checkpoint_from_path=v1_checkpoint,
+ checkpoint_to_path=temporary_checkpoint,
+ num_heads=bert_config.num_attention_heads,
+ name_replacements=tf1_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS,
+ permutations=tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
+ exclude_patterns=["adam", "Adam"])
+
+ # Create a V2 checkpoint from the temporary checkpoint.
+ model = _create_bert_model(bert_config)
+ tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
+ output_path)
+
+ # Clean up the temporary checkpoint, if it exists.
+ try:
+ tf.io.gfile.rmtree(temporary_checkpoint_dir)
+ except tf.errors.OpError:
+ # If it doesn't exist, we don't need to clean it up; continue.
+ pass
+
+
+def main(_):
+ output_path = FLAGS.converted_checkpoint_path
+ v1_checkpoint = FLAGS.checkpoint_to_convert
+ bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+ convert_checkpoint(bert_config, output_path, v1_checkpoint)
+
+
+if __name__ == "__main__":
+ app.run(main)
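+
+# Editorial usage sketch (not part of the upstream file): a typical invocation
+# of this converter, with placeholder paths.
+#
+#   python tf2_encoder_checkpoint_converter.py \
+#     --bert_config_file=/tmp/uncased_L-12_H-768_A-12/bert_config.json \
+#     --checkpoint_to_convert=/tmp/uncased_L-12_H-768_A-12/bert_model.ckpt \
+#     --converted_checkpoint_path=/tmp/bert_v2/bert_model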
diff --git a/models/official/nlp/bert/tokenization.py b/models/official/nlp/bert/tokenization.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0f7e27e320c727c4eee511fc63ebb63929250c7
--- /dev/null
+++ b/models/official/nlp/bert/tokenization.py
@@ -0,0 +1,545 @@
+# coding=utf-8
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tokenization classes implementation.
+
+The file is forked from:
+https://github.com/google-research/bert/blob/master/tokenization.py.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import re
+import unicodedata
+
+import six
+import tensorflow as tf
+
+import sentencepiece as spm
+
+SPIECE_UNDERLINE = "▁"
+
+
+def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
+ """Checks whether the casing config is consistent with the checkpoint name."""
+
+ # The casing has to be passed in by the user and there is no explicit check
+ # as to whether it matches the checkpoint. The casing information probably
+ # should have been stored in the bert_config.json file, but it's not, so
+ # we have to heuristically detect it to validate.
+
+ if not init_checkpoint:
+ return
+
+ m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
+ if m is None:
+ return
+
+ model_name = m.group(1)
+
+ lower_models = [
+ "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
+ "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
+ ]
+
+ cased_models = [
+ "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
+ "multi_cased_L-12_H-768_A-12"
+ ]
+
+ is_bad_config = False
+ if model_name in lower_models and not do_lower_case:
+ is_bad_config = True
+ actual_flag = "False"
+ case_name = "lowercased"
+ opposite_flag = "True"
+
+ if model_name in cased_models and do_lower_case:
+ is_bad_config = True
+ actual_flag = "True"
+ case_name = "cased"
+ opposite_flag = "False"
+
+ if is_bad_config:
+ raise ValueError(
+ "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
+ "However, `%s` seems to be a %s model, so you "
+ "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
+        "how the model was pre-trained. If this error is wrong, please "
+ "just comment out this check." %
+ (actual_flag, init_checkpoint, model_name, case_name, opposite_flag))
+
+
+def convert_to_unicode(text):
+ """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
+ if six.PY3:
+ if isinstance(text, str):
+ return text
+ elif isinstance(text, bytes):
+ return text.decode("utf-8", "ignore")
+ else:
+ raise ValueError("Unsupported string type: %s" % (type(text)))
+ elif six.PY2:
+ if isinstance(text, str):
+ return text.decode("utf-8", "ignore")
+ elif isinstance(text, unicode):
+ return text
+ else:
+ raise ValueError("Unsupported string type: %s" % (type(text)))
+ else:
+ raise ValueError("Not running on Python2 or Python 3?")
+
+
+def printable_text(text):
+ """Returns text encoded in a way suitable for print or `tf.logging`."""
+
+ # These functions want `str` for both Python2 and Python3, but in one case
+ # it's a Unicode string and in the other it's a byte string.
+ if six.PY3:
+ if isinstance(text, str):
+ return text
+ elif isinstance(text, bytes):
+ return text.decode("utf-8", "ignore")
+ else:
+ raise ValueError("Unsupported string type: %s" % (type(text)))
+ elif six.PY2:
+ if isinstance(text, str):
+ return text
+ elif isinstance(text, unicode):
+ return text.encode("utf-8")
+ else:
+ raise ValueError("Unsupported string type: %s" % (type(text)))
+ else:
+ raise ValueError("Not running on Python2 or Python 3?")
+
+
+def load_vocab(vocab_file):
+ """Loads a vocabulary file into a dictionary."""
+ vocab = collections.OrderedDict()
+ index = 0
+ with tf.io.gfile.GFile(vocab_file, "r") as reader:
+ while True:
+ token = convert_to_unicode(reader.readline())
+ if not token:
+ break
+ token = token.strip()
+ vocab[token] = index
+ index += 1
+ return vocab
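+
+
+# Editorial note (not part of the upstream file): the vocab file is expected to
+# contain one token per line; a token's id is simply its 0-based line number.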
+
+
+def convert_by_vocab(vocab, items):
+ """Converts a sequence of [tokens|ids] using the vocab."""
+ output = []
+ for item in items:
+ output.append(vocab[item])
+ return output
+
+
+def convert_tokens_to_ids(vocab, tokens):
+ return convert_by_vocab(vocab, tokens)
+
+
+def convert_ids_to_tokens(inv_vocab, ids):
+ return convert_by_vocab(inv_vocab, ids)
+
+
+def whitespace_tokenize(text):
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
+ text = text.strip()
+ if not text:
+ return []
+ tokens = text.split()
+ return tokens
+
+
+class FullTokenizer(object):
+  """Runs end-to-end tokenization."""
+
+ def __init__(self, vocab_file, do_lower_case=True, split_on_punc=True):
+ self.vocab = load_vocab(vocab_file)
+ self.inv_vocab = {v: k for k, v in self.vocab.items()}
+ self.basic_tokenizer = BasicTokenizer(
+ do_lower_case=do_lower_case, split_on_punc=split_on_punc)
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
+
+ def tokenize(self, text):
+ split_tokens = []
+ for token in self.basic_tokenizer.tokenize(text):
+ for sub_token in self.wordpiece_tokenizer.tokenize(token):
+ split_tokens.append(sub_token)
+
+ return split_tokens
+
+ def convert_tokens_to_ids(self, tokens):
+ return convert_by_vocab(self.vocab, tokens)
+
+ def convert_ids_to_tokens(self, ids):
+ return convert_by_vocab(self.inv_vocab, ids)
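+
+
+# Editorial usage sketch (not part of the upstream file): end-to-end WordPiece
+# tokenization; the vocab file path is a placeholder.
+#
+#   tokenizer = FullTokenizer(vocab_file="vocab.txt", do_lower_case=True)
+#   tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
+#   # With the small test vocab used in tokenization_test.py this yields
+#   # ["un", "##want", "##ed", ",", "runn", "##ing"].
+#   ids = tokenizer.convert_tokens_to_ids(tokens)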
+
+
+class BasicTokenizer(object):
+ """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
+
+ def __init__(self, do_lower_case=True, split_on_punc=True):
+ """Constructs a BasicTokenizer.
+
+ Args:
+ do_lower_case: Whether to lower case the input.
+      split_on_punc: Whether to split on punctuation. By default BERT starts a
+        new token at punctuation, which makes detokenization difficult for
+        tasks like seq2seq decoding.
+ """
+ self.do_lower_case = do_lower_case
+ self.split_on_punc = split_on_punc
+
+ def tokenize(self, text):
+ """Tokenizes a piece of text."""
+ text = convert_to_unicode(text)
+ text = self._clean_text(text)
+
+ # This was added on November 1st, 2018 for the multilingual and Chinese
+ # models. This is also applied to the English models now, but it doesn't
+ # matter since the English models were not trained on any Chinese data
+ # and generally don't have any Chinese data in them (there are Chinese
+ # characters in the vocabulary because Wikipedia does have some Chinese
+    # words in the English Wikipedia).
+ text = self._tokenize_chinese_chars(text)
+
+ orig_tokens = whitespace_tokenize(text)
+ split_tokens = []
+ for token in orig_tokens:
+ if self.do_lower_case:
+ token = token.lower()
+ token = self._run_strip_accents(token)
+ if self.split_on_punc:
+ split_tokens.extend(self._run_split_on_punc(token))
+ else:
+ split_tokens.append(token)
+
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
+ return output_tokens
+
+ def _run_strip_accents(self, text):
+ """Strips accents from a piece of text."""
+ text = unicodedata.normalize("NFD", text)
+ output = []
+ for char in text:
+ cat = unicodedata.category(char)
+ if cat == "Mn":
+ continue
+ output.append(char)
+ return "".join(output)
+
+ def _run_split_on_punc(self, text):
+ """Splits punctuation on a piece of text."""
+ chars = list(text)
+ i = 0
+ start_new_word = True
+ output = []
+ while i < len(chars):
+ char = chars[i]
+ if _is_punctuation(char):
+ output.append([char])
+ start_new_word = True
+ else:
+ if start_new_word:
+ output.append([])
+ start_new_word = False
+ output[-1].append(char)
+ i += 1
+
+ return ["".join(x) for x in output]
+
+ def _tokenize_chinese_chars(self, text):
+ """Adds whitespace around any CJK character."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if self._is_chinese_char(cp):
+ output.append(" ")
+ output.append(char)
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+ def _is_chinese_char(self, cp):
+ """Checks whether CP is the codepoint of a CJK character."""
+ # This defines a "chinese character" as anything in the CJK Unicode block:
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+ #
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+ # despite its name. The modern Korean Hangul alphabet is a different block,
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+ # space-separated words, so they are not treated specially and handled
+    # like all of the other languages.
+ if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
+ (cp >= 0x3400 and cp <= 0x4DBF) or #
+ (cp >= 0x20000 and cp <= 0x2A6DF) or #
+ (cp >= 0x2A700 and cp <= 0x2B73F) or #
+ (cp >= 0x2B740 and cp <= 0x2B81F) or #
+ (cp >= 0x2B820 and cp <= 0x2CEAF) or
+ (cp >= 0xF900 and cp <= 0xFAFF) or #
+ (cp >= 0x2F800 and cp <= 0x2FA1F)): #
+ return True
+
+ return False
+
+ def _clean_text(self, text):
+ """Performs invalid character removal and whitespace cleanup on text."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if cp == 0 or cp == 0xfffd or _is_control(char):
+ continue
+ if _is_whitespace(char):
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+
+class WordpieceTokenizer(object):
+  """Runs WordPiece tokenization."""
+
+ def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=400):
+ self.vocab = vocab
+ self.unk_token = unk_token
+ self.max_input_chars_per_word = max_input_chars_per_word
+
+ def tokenize(self, text):
+ """Tokenizes a piece of text into its word pieces.
+
+ This uses a greedy longest-match-first algorithm to perform tokenization
+ using the given vocabulary.
+
+ For example:
+ input = "unaffable"
+ output = ["un", "##aff", "##able"]
+
+ Args:
+ text: A single token or whitespace separated tokens. This should have
+        already been passed through `BasicTokenizer`.
+
+ Returns:
+ A list of wordpiece tokens.
+ """
+
+ text = convert_to_unicode(text)
+
+ output_tokens = []
+ for token in whitespace_tokenize(text):
+ chars = list(token)
+ if len(chars) > self.max_input_chars_per_word:
+ output_tokens.append(self.unk_token)
+ continue
+
+ is_bad = False
+ start = 0
+ sub_tokens = []
+ while start < len(chars):
+ end = len(chars)
+ cur_substr = None
+ while start < end:
+ substr = "".join(chars[start:end])
+ if start > 0:
+ substr = "##" + substr
+ if substr in self.vocab:
+ cur_substr = substr
+ break
+ end -= 1
+ if cur_substr is None:
+ is_bad = True
+ break
+ sub_tokens.append(cur_substr)
+ start = end
+
+ if is_bad:
+ output_tokens.append(self.unk_token)
+ else:
+ output_tokens.extend(sub_tokens)
+ return output_tokens
+
+
+def _is_whitespace(char):
+  """Checks whether `char` is a whitespace character."""
+ # \t, \n, and \r are technically control characters but we treat them
+ # as whitespace since they are generally considered as such.
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
+ return True
+ cat = unicodedata.category(char)
+ if cat == "Zs":
+ return True
+ return False
+
+
+def _is_control(char):
+  """Checks whether `char` is a control character."""
+ # These are technically control characters but we count them as whitespace
+ # characters.
+ if char == "\t" or char == "\n" or char == "\r":
+ return False
+ cat = unicodedata.category(char)
+ if cat in ("Cc", "Cf"):
+ return True
+ return False
+
+
+def _is_punctuation(char):
+  """Checks whether `char` is a punctuation character."""
+ cp = ord(char)
+ # We treat all non-letter/number ASCII as punctuation.
+ # Characters such as "^", "$", and "`" are not in the Unicode
+ # Punctuation class but we treat them as punctuation anyways, for
+ # consistency.
+ if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
+ (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
+ return True
+ cat = unicodedata.category(char)
+ if cat.startswith("P"):
+ return True
+ return False
+
+
+def preprocess_text(inputs, remove_space=True, lower=False):
+  """Preprocesses data by removing extra spaces and normalizing the text.
+
+  This method is used together with the SentencePiece tokenizer and is forked from:
+ https://github.com/google-research/google-research/blob/master/albert/tokenization.py
+
+ Args:
+ inputs: The input text.
+ remove_space: Whether to remove the extra space.
+ lower: Whether to lowercase the text.
+
+ Returns:
+ The preprocessed text.
+
+ """
+ outputs = inputs
+ if remove_space:
+ outputs = " ".join(inputs.strip().split())
+
+ if six.PY2 and isinstance(outputs, str):
+ try:
+ outputs = six.ensure_text(outputs, "utf-8")
+ except UnicodeDecodeError:
+ outputs = six.ensure_text(outputs, "latin-1")
+
+ outputs = unicodedata.normalize("NFKD", outputs)
+ outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
+ if lower:
+ outputs = outputs.lower()
+
+ return outputs
+
+
+def encode_pieces(sp_model, text, sample=False):
+  """Segments text into pieces.
+
+  This method is used together with the SentencePiece tokenizer and is forked from:
+ https://github.com/google-research/google-research/blob/master/albert/tokenization.py
+
+
+ Args:
+ sp_model: A spm.SentencePieceProcessor object.
+    text: The input text to be segmented.
+ sample: Whether to randomly sample a segmentation output or return a
+ deterministic one.
+
+ Returns:
+ A list of token pieces.
+ """
+ if six.PY2 and isinstance(text, six.text_type):
+ text = six.ensure_binary(text, "utf-8")
+
+ if not sample:
+ pieces = sp_model.EncodeAsPieces(text)
+ else:
+ pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
+ new_pieces = []
+ for piece in pieces:
+ piece = printable_text(piece)
+ if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
+ cur_pieces = sp_model.EncodeAsPieces(piece[:-1].replace(
+ SPIECE_UNDERLINE, ""))
+ if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
+ if len(cur_pieces[0]) == 1:
+ cur_pieces = cur_pieces[1:]
+ else:
+ cur_pieces[0] = cur_pieces[0][1:]
+ cur_pieces.append(piece[-1])
+ new_pieces.extend(cur_pieces)
+ else:
+ new_pieces.append(piece)
+
+ return new_pieces
+
+
+def encode_ids(sp_model, text, sample=False):
+  """Segments text and returns token ids.
+
+  This method is used together with the SentencePiece tokenizer and is forked from:
+ https://github.com/google-research/google-research/blob/master/albert/tokenization.py
+
+ Args:
+ sp_model: A spm.SentencePieceProcessor object.
+    text: The input text to be segmented.
+ sample: Whether to randomly sample a segmentation output or return a
+ deterministic one.
+
+ Returns:
+ A list of token ids.
+ """
+ pieces = encode_pieces(sp_model, text, sample=sample)
+ ids = [sp_model.PieceToId(piece) for piece in pieces]
+ return ids
+
+
+class FullSentencePieceTokenizer(object):
+ """Runs end-to-end sentence piece tokenization.
+
+  The interface of this class is intended to match that of the
+  `FullTokenizer` class above for easier usage.
+ """
+
+ def __init__(self, sp_model_file):
+ """Inits FullSentencePieceTokenizer.
+
+ Args:
+ sp_model_file: The path to the sentence piece model file.
+ """
+ self.sp_model = spm.SentencePieceProcessor()
+ self.sp_model.Load(sp_model_file)
+ self.vocab = {
+ self.sp_model.IdToPiece(i): i
+ for i in six.moves.range(self.sp_model.GetPieceSize())
+ }
+
+ def tokenize(self, text):
+ """Tokenizes text into pieces."""
+ return encode_pieces(self.sp_model, text)
+
+ def convert_tokens_to_ids(self, tokens):
+ """Converts a list of tokens to a list of ids."""
+ return [self.sp_model.PieceToId(printable_text(token)) for token in tokens]
+
+ def convert_ids_to_tokens(self, ids):
+    """Converts a list of ids to a list of tokens."""
+ return [self.sp_model.IdToPiece(id_) for id_ in ids]
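+
+
+# Editorial usage sketch (not part of the upstream file): the SentencePiece
+# model path below is a placeholder.
+#
+#   sp_tokenizer = FullSentencePieceTokenizer("30k-clean.model")
+#   pieces = sp_tokenizer.tokenize("unaffable")
+#   ids = sp_tokenizer.convert_tokens_to_ids(pieces)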
diff --git a/models/official/nlp/bert/tokenization_test.py b/models/official/nlp/bert/tokenization_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a0503c3ed6999e3bd81aec4de8f7d64ec733bd9
--- /dev/null
+++ b/models/official/nlp/bert/tokenization_test.py
@@ -0,0 +1,160 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+
+import six
+import tensorflow as tf
+
+from official.nlp.bert import tokenization
+
+
+class TokenizationTest(tf.test.TestCase):
+ """Tokenization test.
+
+ The implementation is forked from
+  https://github.com/google-research/bert/blob/master/tokenization_test.py.
+ """
+
+ def test_full_tokenizer(self):
+ vocab_tokens = [
+ "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
+ "##ing", ","
+ ]
+ with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
+ if six.PY2:
+ vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+ else:
+ vocab_writer.write("".join([x + "\n" for x in vocab_tokens
+ ]).encode("utf-8"))
+
+ vocab_file = vocab_writer.name
+
+ tokenizer = tokenization.FullTokenizer(vocab_file)
+ os.unlink(vocab_file)
+
+ tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
+ self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
+
+ self.assertAllEqual(
+ tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
+
+ def test_chinese(self):
+ tokenizer = tokenization.BasicTokenizer()
+
+ self.assertAllEqual(
+ tokenizer.tokenize(u"ah\u535A\u63A8zz"),
+ [u"ah", u"\u535A", u"\u63A8", u"zz"])
+
+ def test_basic_tokenizer_lower(self):
+ tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
+
+ self.assertAllEqual(
+ tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
+ ["hello", "!", "how", "are", "you", "?"])
+ self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
+
+ def test_basic_tokenizer_no_lower(self):
+ tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
+
+ self.assertAllEqual(
+ tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
+ ["HeLLo", "!", "how", "Are", "yoU", "?"])
+
+ def test_basic_tokenizer_no_split_on_punc(self):
+ tokenizer = tokenization.BasicTokenizer(
+ do_lower_case=True, split_on_punc=False)
+
+ self.assertAllEqual(
+ tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
+ ["hello!how", "are", "you?"])
+
+ def test_wordpiece_tokenizer(self):
+ vocab_tokens = [
+ "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
+ "##ing", "##!", "!"
+ ]
+
+ vocab = {}
+ for (i, token) in enumerate(vocab_tokens):
+ vocab[token] = i
+ tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
+
+ self.assertAllEqual(tokenizer.tokenize(""), [])
+
+ self.assertAllEqual(
+ tokenizer.tokenize("unwanted running"),
+ ["un", "##want", "##ed", "runn", "##ing"])
+
+ self.assertAllEqual(
+ tokenizer.tokenize("unwanted running !"),
+ ["un", "##want", "##ed", "runn", "##ing", "!"])
+
+ self.assertAllEqual(
+ tokenizer.tokenize("unwanted running!"),
+ ["un", "##want", "##ed", "runn", "##ing", "##!"])
+
+ self.assertAllEqual(
+ tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
+
+ def test_convert_tokens_to_ids(self):
+ vocab_tokens = [
+ "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
+ "##ing"
+ ]
+
+ vocab = {}
+ for (i, token) in enumerate(vocab_tokens):
+ vocab[token] = i
+
+ self.assertAllEqual(
+ tokenization.convert_tokens_to_ids(
+ vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
+
+ def test_is_whitespace(self):
+ self.assertTrue(tokenization._is_whitespace(u" "))
+ self.assertTrue(tokenization._is_whitespace(u"\t"))
+ self.assertTrue(tokenization._is_whitespace(u"\r"))
+ self.assertTrue(tokenization._is_whitespace(u"\n"))
+ self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
+
+ self.assertFalse(tokenization._is_whitespace(u"A"))
+ self.assertFalse(tokenization._is_whitespace(u"-"))
+
+ def test_is_control(self):
+ self.assertTrue(tokenization._is_control(u"\u0005"))
+
+ self.assertFalse(tokenization._is_control(u"A"))
+ self.assertFalse(tokenization._is_control(u" "))
+ self.assertFalse(tokenization._is_control(u"\t"))
+ self.assertFalse(tokenization._is_control(u"\r"))
+ self.assertFalse(tokenization._is_control(u"\U0001F4A9"))
+
+ def test_is_punctuation(self):
+ self.assertTrue(tokenization._is_punctuation(u"-"))
+ self.assertTrue(tokenization._is_punctuation(u"$"))
+ self.assertTrue(tokenization._is_punctuation(u"`"))
+ self.assertTrue(tokenization._is_punctuation(u"."))
+
+ self.assertFalse(tokenization._is_punctuation(u"A"))
+ self.assertFalse(tokenization._is_punctuation(u" "))
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/configs/__init__.py b/models/official/nlp/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/models/official/nlp/configs/__init__.py
@@ -0,0 +1 @@
+
diff --git a/models/official/nlp/configs/bert.py b/models/official/nlp/configs/bert.py
new file mode 100644
index 0000000000000000000000000000000000000000..48b83107f20a2b3251624a24b580412b93ed1979
--- /dev/null
+++ b/models/official/nlp/configs/bert.py
@@ -0,0 +1,151 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Multi-head BERT encoder network with classification heads.
+
+Includes configurations and instantiation methods.
+"""
+from typing import List, Optional, Text
+
+import dataclasses
+import tensorflow as tf
+
+from official.modeling import tf_utils
+from official.modeling.hyperparams import base_config
+from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp.configs import encoders
+from official.nlp.modeling import layers
+from official.nlp.modeling.models import bert_pretrainer
+
+
+@dataclasses.dataclass
+class ClsHeadConfig(base_config.Config):
+ inner_dim: int = 0
+ num_classes: int = 2
+ activation: Optional[Text] = "tanh"
+ dropout_rate: float = 0.0
+ cls_token_idx: int = 0
+ name: Optional[Text] = None
+
+
+@dataclasses.dataclass
+class BertPretrainerConfig(base_config.Config):
+ """BERT encoder configuration."""
+ num_masked_tokens: int = 76
+ encoder: encoders.TransformerEncoderConfig = (
+ encoders.TransformerEncoderConfig())
+ cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
+
+
+def instantiate_classification_heads_from_cfgs(
+ cls_head_configs: List[ClsHeadConfig]) -> List[layers.ClassificationHead]:
+ return [
+ layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
+ ] if cls_head_configs else []
+
+
+def instantiate_bertpretrainer_from_cfg(
+ config: BertPretrainerConfig,
+ encoder_network: Optional[tf.keras.Model] = None
+ ) -> bert_pretrainer.BertPretrainerV2:
+ """Instantiates a BertPretrainer from the config."""
+ encoder_cfg = config.encoder
+ if encoder_network is None:
+ encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
+ return bert_pretrainer.BertPretrainerV2(
+ config.num_masked_tokens,
+ mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
+ mlm_initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ encoder_network=encoder_network,
+ classification_heads=instantiate_classification_heads_from_cfgs(
+ config.cls_heads))
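+
+
+# Editorial usage sketch (not part of the upstream file): building a small
+# pretrainer from configs, mirroring configs/bert_test.py.
+#
+#   config = BertPretrainerConfig(
+#       encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
+#       cls_heads=[ClsHeadConfig(inner_dim=10, num_classes=2,
+#                                name="next_sentence")])
+#   pretrainer = instantiate_bertpretrainer_from_cfg(config)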
+
+
+@dataclasses.dataclass
+class BertPretrainDataConfig(cfg.DataConfig):
+ """Data config for BERT pretraining task (tasks/masked_lm)."""
+ input_path: str = ""
+ global_batch_size: int = 512
+ is_training: bool = True
+ seq_length: int = 512
+ max_predictions_per_seq: int = 76
+ use_next_sentence_label: bool = True
+ use_position_id: bool = False
+
+
+@dataclasses.dataclass
+class BertPretrainEvalDataConfig(BertPretrainDataConfig):
+ """Data config for the eval set in BERT pretraining task (tasks/masked_lm)."""
+ input_path: str = ""
+ global_batch_size: int = 512
+ is_training: bool = False
+
+
+@dataclasses.dataclass
+class SentencePredictionDataConfig(cfg.DataConfig):
+ """Data config for sentence prediction task (tasks/sentence_prediction)."""
+ input_path: str = ""
+ global_batch_size: int = 32
+ is_training: bool = True
+ seq_length: int = 128
+
+
+@dataclasses.dataclass
+class SentencePredictionDevDataConfig(cfg.DataConfig):
+ """Dev Data config for sentence prediction (tasks/sentence_prediction)."""
+ input_path: str = ""
+ global_batch_size: int = 32
+ is_training: bool = False
+ seq_length: int = 128
+ drop_remainder: bool = False
+
+
+@dataclasses.dataclass
+class QADataConfig(cfg.DataConfig):
+ """Data config for question answering task (tasks/question_answering)."""
+ input_path: str = ""
+ global_batch_size: int = 48
+ is_training: bool = True
+ seq_length: int = 384
+
+
+@dataclasses.dataclass
+class QADevDataConfig(cfg.DataConfig):
+  """Dev Data config for question answering (tasks/question_answering)."""
+ input_path: str = ""
+ global_batch_size: int = 48
+ is_training: bool = False
+ seq_length: int = 384
+ drop_remainder: bool = False
+
+
+@dataclasses.dataclass
+class TaggingDataConfig(cfg.DataConfig):
+ """Data config for tagging (tasks/tagging)."""
+ input_path: str = ""
+ global_batch_size: int = 48
+ is_training: bool = True
+ seq_length: int = 384
+
+
+@dataclasses.dataclass
+class TaggingDevDataConfig(cfg.DataConfig):
+ """Dev Data config for tagging (tasks/tagging)."""
+ input_path: str = ""
+ global_batch_size: int = 48
+ is_training: bool = False
+ seq_length: int = 384
+ drop_remainder: bool = False
diff --git a/models/official/nlp/configs/bert_test.py b/models/official/nlp/configs/bert_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c734b190ea71697350cc0fb84cf50582afdb96b3
--- /dev/null
+++ b/models/official/nlp/configs/bert_test.py
@@ -0,0 +1,65 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for BERT configurations and models instantiation."""
+
+import tensorflow as tf
+
+from official.nlp.configs import bert
+from official.nlp.configs import encoders
+
+
+class BertModelsTest(tf.test.TestCase):
+
+ def test_network_invocation(self):
+ config = bert.BertPretrainerConfig(
+ encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
+ _ = bert.instantiate_bertpretrainer_from_cfg(config)
+
+ # Invokes with classification heads.
+ config = bert.BertPretrainerConfig(
+ encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
+ cls_heads=[
+ bert.ClsHeadConfig(
+ inner_dim=10, num_classes=2, name="next_sentence")
+ ])
+ _ = bert.instantiate_bertpretrainer_from_cfg(config)
+
+ with self.assertRaises(ValueError):
+ config = bert.BertPretrainerConfig(
+ encoder=encoders.TransformerEncoderConfig(
+ vocab_size=10, num_layers=1),
+ cls_heads=[
+ bert.ClsHeadConfig(
+ inner_dim=10, num_classes=2, name="next_sentence"),
+ bert.ClsHeadConfig(
+ inner_dim=10, num_classes=2, name="next_sentence")
+ ])
+ _ = bert.instantiate_bertpretrainer_from_cfg(config)
+
+ def test_checkpoint_items(self):
+ config = bert.BertPretrainerConfig(
+ encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
+ cls_heads=[
+ bert.ClsHeadConfig(
+ inner_dim=10, num_classes=2, name="next_sentence")
+ ])
+ encoder = bert.instantiate_bertpretrainer_from_cfg(config)
+ self.assertSameElements(encoder.checkpoint_items.keys(),
+ ["encoder", "next_sentence.pooler_dense"])
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/configs/encoders.py b/models/official/nlp/configs/encoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..0af5b733d9a7b60af21a8be9021fafdfa085e34a
--- /dev/null
+++ b/models/official/nlp/configs/encoders.py
@@ -0,0 +1,62 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Transformer Encoders.
+
+Includes configurations and instantiation methods.
+"""
+
+import dataclasses
+import tensorflow as tf
+
+from official.modeling import tf_utils
+from official.modeling.hyperparams import base_config
+from official.nlp.modeling import networks
+
+
+@dataclasses.dataclass
+class TransformerEncoderConfig(base_config.Config):
+ """BERT encoder configuration."""
+ vocab_size: int = 30522
+ hidden_size: int = 768
+ num_layers: int = 12
+ num_attention_heads: int = 12
+ hidden_activation: str = "gelu"
+ intermediate_size: int = 3072
+ dropout_rate: float = 0.1
+ attention_dropout_rate: float = 0.1
+ max_position_embeddings: int = 512
+ type_vocab_size: int = 2
+ initializer_range: float = 0.02
+
+
+def instantiate_encoder_from_cfg(
+ config: TransformerEncoderConfig) -> networks.TransformerEncoder:
+ """Instantiate a Transformer encoder network from TransformerEncoderConfig."""
+ encoder_network = networks.TransformerEncoder(
+ vocab_size=config.vocab_size,
+ hidden_size=config.hidden_size,
+ num_layers=config.num_layers,
+ num_attention_heads=config.num_attention_heads,
+ intermediate_size=config.intermediate_size,
+ activation=tf_utils.get_activation(config.hidden_activation),
+ dropout_rate=config.dropout_rate,
+ attention_dropout_rate=config.attention_dropout_rate,
+ sequence_length=None,
+ max_sequence_length=config.max_position_embeddings,
+ type_vocab_size=config.type_vocab_size,
+ initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=config.initializer_range))
+ return encoder_network
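+
+
+# Editorial usage sketch (not part of the upstream file): instantiating an
+# encoder with the default BERT-base hyperparameters defined above.
+#
+#   encoder = instantiate_encoder_from_cfg(TransformerEncoderConfig())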
diff --git a/models/official/nlp/data/__init__.py b/models/official/nlp/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/nlp/data/__pycache__/__init__.cpython-310.pyc b/models/official/nlp/data/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3866844f42b63caffc58cebe79447266ef6bad5d
Binary files /dev/null and b/models/official/nlp/data/__pycache__/__init__.cpython-310.pyc differ
diff --git a/models/official/nlp/data/__pycache__/__init__.cpython-38.pyc b/models/official/nlp/data/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b43d026004ab9d6e35ecbbe5888b00e9cd8a4677
Binary files /dev/null and b/models/official/nlp/data/__pycache__/__init__.cpython-38.pyc differ
diff --git a/models/official/nlp/data/__pycache__/__init__.cpython-39.pyc b/models/official/nlp/data/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b2ae6127a7ed41164727b124e616b4a4f69227a4
Binary files /dev/null and b/models/official/nlp/data/__pycache__/__init__.cpython-39.pyc differ
diff --git a/models/official/nlp/data/__pycache__/classifier_data_lib.cpython-310.pyc b/models/official/nlp/data/__pycache__/classifier_data_lib.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..303a264d77eb106f58c195f0d8d4511afee9d9d7
Binary files /dev/null and b/models/official/nlp/data/__pycache__/classifier_data_lib.cpython-310.pyc differ
diff --git a/models/official/nlp/data/__pycache__/classifier_data_lib.cpython-38.pyc b/models/official/nlp/data/__pycache__/classifier_data_lib.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a935018c088877b49469e9a4df486eacb64837f0
Binary files /dev/null and b/models/official/nlp/data/__pycache__/classifier_data_lib.cpython-38.pyc differ
diff --git a/models/official/nlp/data/__pycache__/classifier_data_lib.cpython-39.pyc b/models/official/nlp/data/__pycache__/classifier_data_lib.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19fe41408ff4122c5998a2d2f15b31cf63bba55a
Binary files /dev/null and b/models/official/nlp/data/__pycache__/classifier_data_lib.cpython-39.pyc differ
diff --git a/models/official/nlp/data/classifier_data_lib.py b/models/official/nlp/data/classifier_data_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..67c47d7874c4d3a40b23e5280e14ed8716a23176
--- /dev/null
+++ b/models/official/nlp/data/classifier_data_lib.py
@@ -0,0 +1,1088 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""BERT library to process data for classification task."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import csv
+import importlib
+import os
+
+from absl import logging
+import tensorflow as tf
+import tensorflow_datasets as tfds
+
+from official.nlp.bert import tokenization
+
+
+class InputExample(object):
+ """A single training/test example for simple sequence classification."""
+
+ def __init__(self,
+ guid,
+ text_a,
+ text_b=None,
+ label=None,
+ weight=None,
+ int_iden=None):
+ """Constructs a InputExample.
+
+ Args:
+ guid: Unique id for the example.
+ text_a: string. The untokenized text of the first sequence. For single
+ sequence tasks, only this sequence must be specified.
+ text_b: (Optional) string. The untokenized text of the second sequence.
+ Only must be specified for sequence pair tasks.
+ label: (Optional) string. The label of the example. This should be
+ specified for train and dev examples, but not for test examples.
+ weight: (Optional) float. The weight of the example to be used during
+ training.
+ int_iden: (Optional) int. The int identification number of example in the
+ corpus.
+ """
+ self.guid = guid
+ self.text_a = text_a
+ self.text_b = text_b
+ self.label = label
+ self.weight = weight
+ self.int_iden = int_iden
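+
+
+# Editorial sketch (not part of the upstream file): a single-sentence training
+# example as consumed by the processors below; the text and label are
+# illustrative.
+#
+#   example = InputExample(guid="train-0",
+#                          text_a="aluminium lids for beverage cans",
+#                          text_b=None,
+#                          label="1")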
+
+
+class InputFeatures(object):
+ """A single set of features of data."""
+
+ def __init__(self,
+ input_ids,
+ input_mask,
+ segment_ids,
+ label_id,
+ is_real_example=True,
+ weight=None,
+ int_iden=None):
+ self.input_ids = input_ids
+ self.input_mask = input_mask
+ self.segment_ids = segment_ids
+ self.label_id = label_id
+ self.is_real_example = is_real_example
+ self.weight = weight
+ self.int_iden = int_iden
+
+
+class DataProcessor(object):
+ """Base class for data converters for sequence classification data sets."""
+
+ def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+ self.process_text_fn = process_text_fn
+
+ def get_train_examples(self, data_dir):
+ """Gets a collection of `InputExample`s for the train set."""
+ raise NotImplementedError()
+
+ def get_dev_examples(self, data_dir):
+ """Gets a collection of `InputExample`s for the dev set."""
+ raise NotImplementedError()
+
+ def get_test_examples(self, data_dir):
+ """Gets a collection of `InputExample`s for prediction."""
+ raise NotImplementedError()
+
+ def get_labels(self):
+ """Gets the list of labels for this data set."""
+ raise NotImplementedError()
+
+ @staticmethod
+ def get_processor_name():
+ """Gets the string identifier of the processor."""
+ raise NotImplementedError()
+
+ @classmethod
+ def _read_tsv(cls, input_file, quotechar=None):
+ """Reads a tab separated value file."""
+ with tf.io.gfile.GFile(input_file, "r") as f:
+ reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
+ lines = []
+ for line in reader:
+ lines.append(line)
+ return lines
+
+
+class XnliProcessor(DataProcessor):
+ """Processor for the XNLI data set."""
+ supported_languages = [
+ "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
+ "ur", "vi", "zh"
+ ]
+
+ def __init__(self,
+ language="en",
+ process_text_fn=tokenization.convert_to_unicode):
+ super(XnliProcessor, self).__init__(process_text_fn)
+ if language == "all":
+ self.languages = XnliProcessor.supported_languages
+ elif language not in XnliProcessor.supported_languages:
+ raise ValueError("language %s is not supported for XNLI task." % language)
+ else:
+ self.languages = [language]
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ lines = []
+ for language in self.languages:
+ # Skips the header.
+ lines.extend(
+ self._read_tsv(
+ os.path.join(data_dir, "multinli",
+ "multinli.train.%s.tsv" % language))[1:])
+
+ examples = []
+ for (i, line) in enumerate(lines):
+ guid = "train-%d" % i
+ text_a = self.process_text_fn(line[0])
+ text_b = self.process_text_fn(line[1])
+ label = self.process_text_fn(line[2])
+ if label == self.process_text_fn("contradictory"):
+ label = self.process_text_fn("contradiction")
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
+ examples = []
+ for (i, line) in enumerate(lines):
+ if i == 0:
+ continue
+ guid = "dev-%d" % i
+ text_a = self.process_text_fn(line[6])
+ text_b = self.process_text_fn(line[7])
+ label = self.process_text_fn(line[1])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
+ examples_by_lang = {k: [] for k in XnliProcessor.supported_languages}
+ for (i, line) in enumerate(lines):
+ if i == 0:
+ continue
+ guid = "test-%d" % i
+ language = self.process_text_fn(line[0])
+ text_a = self.process_text_fn(line[6])
+ text_b = self.process_text_fn(line[7])
+ label = self.process_text_fn(line[1])
+ examples_by_lang[language].append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples_by_lang
+
+ def get_labels(self):
+ """See base class."""
+ return ["contradiction", "entailment", "neutral"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "XNLI"
+
+
+class XtremeXnliProcessor(DataProcessor):
+ """Processor for the XTREME XNLI data set."""
+ supported_languages = [
+ "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
+ "ur", "vi", "zh"
+ ]
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
+
+ examples = []
+ for (i, line) in enumerate(lines):
+ guid = "train-%d" % i
+ text_a = self.process_text_fn(line[0])
+ text_b = self.process_text_fn(line[1])
+ label = self.process_text_fn(line[2])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
+ examples = []
+ for (i, line) in enumerate(lines):
+ guid = "dev-%d" % i
+ text_a = self.process_text_fn(line[0])
+ text_b = self.process_text_fn(line[1])
+ label = self.process_text_fn(line[2])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ examples_by_lang = {k: [] for k in self.supported_languages}
+ for lang in self.supported_languages:
+ lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
+ for (i, line) in enumerate(lines):
+ guid = f"test-{i}"
+ text_a = self.process_text_fn(line[0])
+ text_b = self.process_text_fn(line[1])
+ label = "contradiction"
+ examples_by_lang[lang].append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples_by_lang
+
+ def get_labels(self):
+ """See base class."""
+ return ["contradiction", "entailment", "neutral"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "XTREME-XNLI"
+
+
+class PawsxProcessor(DataProcessor):
+ """Processor for the PAWS-X data set."""
+ supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
+
+ def __init__(self,
+ language="en",
+ process_text_fn=tokenization.convert_to_unicode):
+ super(PawsxProcessor, self).__init__(process_text_fn)
+ if language == "all":
+ self.languages = PawsxProcessor.supported_languages
+ elif language not in PawsxProcessor.supported_languages:
+ raise ValueError("language %s is not supported for PAWS-X task." %
+ language)
+ else:
+ self.languages = [language]
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ lines = []
+ for language in self.languages:
+ if language == "en":
+ train_tsv = "train.tsv"
+ else:
+ train_tsv = "translated_train.tsv"
+ # Skips the header.
+ lines.extend(
+ self._read_tsv(os.path.join(data_dir, language, train_tsv))[1:])
+
+ examples = []
+ for (i, line) in enumerate(lines):
+ guid = "train-%d" % i
+ text_a = self.process_text_fn(line[1])
+ text_b = self.process_text_fn(line[2])
+ label = self.process_text_fn(line[3])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ lines = []
+ for lang in PawsxProcessor.supported_languages:
+ lines.extend(self._read_tsv(os.path.join(data_dir, f"dev-{lang}.tsv")))
+
+ examples = []
+ for (i, line) in enumerate(lines):
+ guid = "dev-%d" % i
+ text_a = self.process_text_fn(line[0])
+ text_b = self.process_text_fn(line[1])
+ label = self.process_text_fn(line[2])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ examples_by_lang = {k: [] for k in self.supported_languages}
+ for lang in self.supported_languages:
+ lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
+ for (i, line) in enumerate(lines):
+ guid = "test-%d" % i
+ text_a = self.process_text_fn(line[0])
+ text_b = self.process_text_fn(line[1])
+ label = self.process_text_fn(line[2])
+ examples_by_lang[lang].append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples_by_lang
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "XTREME-PAWS-X"
+
+
+class XtremePawsxProcessor(DataProcessor):
+ """Processor for the XTREME PAWS-X data set."""
+ supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
+ examples = []
+ for (i, line) in enumerate(lines):
+ guid = "train-%d" % i
+ text_a = self.process_text_fn(line[0])
+ text_b = self.process_text_fn(line[1])
+ label = self.process_text_fn(line[2])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
+
+ examples = []
+ for (i, line) in enumerate(lines):
+ guid = "dev-%d" % i
+ text_a = self.process_text_fn(line[0])
+ text_b = self.process_text_fn(line[1])
+ label = self.process_text_fn(line[2])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ examples_by_lang = {k: [] for k in self.supported_languages}
+ for lang in self.supported_languages:
+ lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
+ for (i, line) in enumerate(lines):
+ guid = "test-%d" % i
+ text_a = self.process_text_fn(line[0])
+ text_b = self.process_text_fn(line[1])
+ label = "0"
+ examples_by_lang[lang].append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples_by_lang
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "XTREME-PAWS-X"
+
+
+class MnliProcessor(DataProcessor):
+ """Processor for the MultiNLI data set (GLUE version)."""
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
+ "dev_matched")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["contradiction", "entailment", "neutral"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "MNLI"
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ for (i, line) in enumerate(lines):
+ if i == 0:
+ continue
+ guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
+ text_a = self.process_text_fn(line[8])
+ text_b = self.process_text_fn(line[9])
+ if set_type == "test":
+ label = "contradiction"
+ else:
+ label = self.process_text_fn(line[-1])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+
+class MrpcProcessor(DataProcessor):
+ """Processor for the MRPC data set (GLUE version)."""
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "MRPC"
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ for (i, line) in enumerate(lines):
+ if i == 0:
+ continue
+ guid = "%s-%s" % (set_type, i)
+ text_a = self.process_text_fn(line[3])
+ text_b = self.process_text_fn(line[4])
+ if set_type == "test":
+ label = "0"
+ else:
+ label = self.process_text_fn(line[0])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+
+class QqpProcessor(DataProcessor):
+ """Processor for the QQP data set (GLUE version)."""
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "QQP"
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ for (i, line) in enumerate(lines):
+ if i == 0:
+ continue
+ guid = "%s-%s" % (set_type, line[0])
+ try:
+ text_a = line[3]
+ text_b = line[4]
+ label = line[5]
+ except IndexError:
+ continue
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+
+class ColaProcessor(DataProcessor):
+ """Processor for the CoLA data set (GLUE version)."""
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "COLA"
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ for (i, line) in enumerate(lines):
+ # Only the test set has a header
+ if set_type == "test" and i == 0:
+ continue
+ guid = "%s-%s" % (set_type, i)
+ if set_type == "test":
+ text_a = self.process_text_fn(line[1])
+ label = "0"
+ else:
+ text_a = self.process_text_fn(line[3])
+ label = self.process_text_fn(line[1])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
+ return examples
+
+
+class RteProcessor(DataProcessor):
+ """Processor for the RTE data set (GLUE version)."""
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+ def get_labels(self):
+ """See base class."""
+ # All datasets are converted to 2-class split, where for 3-class datasets we
+ # collapse neutral and contradiction into not_entailment.
+ return ["entailment", "not_entailment"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "RTE"
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ for i, line in enumerate(lines):
+ if i == 0:
+ continue
+ guid = "%s-%s" % (set_type, i)
+ if set_type == "test":
+ text_a = tokenization.convert_to_unicode(line[1])
+ text_b = tokenization.convert_to_unicode(line[2])
+ label = "entailment"
+ else:
+ text_a = tokenization.convert_to_unicode(line[1])
+ text_b = tokenization.convert_to_unicode(line[2])
+ label = tokenization.convert_to_unicode(line[3])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+
+class SstProcessor(DataProcessor):
+ """Processor for the SST-2 data set (GLUE version)."""
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "SST-2"
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ for (i, line) in enumerate(lines):
+ if i == 0:
+ continue
+ guid = "%s-%s" % (set_type, i)
+ if set_type == "test":
+ text_a = tokenization.convert_to_unicode(line[1])
+ label = "0"
+ else:
+ text_a = tokenization.convert_to_unicode(line[0])
+ label = tokenization.convert_to_unicode(line[1])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
+ return examples
+
+
+class QnliProcessor(DataProcessor):
+ """Processor for the QNLI data set (GLUE version)."""
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["entailment", "not_entailment"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "QNLI"
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ for (i, line) in enumerate(lines):
+ if i == 0:
+ continue
+ guid = "%s-%s" % (set_type, 1)
+ if set_type == "test":
+ text_a = tokenization.convert_to_unicode(line[1])
+ text_b = tokenization.convert_to_unicode(line[2])
+ label = "entailment"
+ else:
+ text_a = tokenization.convert_to_unicode(line[1])
+ text_b = tokenization.convert_to_unicode(line[2])
+ label = tokenization.convert_to_unicode(line[-1])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+
+class TfdsProcessor(DataProcessor):
+ """Processor for generic text classification and regression TFDS data set.
+
+ The TFDS parameters are expected to be provided in the tfds_params string, in
+ a comma-separated list of parameter assignments.
+ Examples:
+ tfds_params="dataset=scicite,text_key=string"
+ tfds_params="dataset=imdb_reviews,test_split=,dev_split=test"
+ tfds_params="dataset=glue/cola,text_key=sentence"
+ tfds_params="dataset=glue/sst2,text_key=sentence"
+ tfds_params="dataset=glue/qnli,text_key=question,text_b_key=sentence"
+ tfds_params="dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2"
+ tfds_params="dataset=glue/stsb,text_key=sentence1,text_b_key=sentence2,"
+ "is_regression=true,label_type=float"
+ Possible parameters (please refer to the documentation of Tensorflow Datasets
+ (TFDS) for the meaning of individual parameters):
+ dataset: Required dataset name (potentially with subset and version number).
+ data_dir: Optional TFDS source root directory.
+ module_import: Optional Dataset module to import.
+ train_split: Name of the train split (defaults to `train`).
+ dev_split: Name of the dev split (defaults to `validation`).
+ test_split: Name of the test split (defaults to `test`).
+ text_key: Key of the text_a feature (defaults to `text`).
+ text_b_key: Key of the second text feature if available.
+ label_key: Key of the label feature (defaults to `label`).
+ test_text_key: Key of the text feature to use in test set.
+ test_text_b_key: Key of the second text feature to use in test set.
+ test_label: String to be used as the label for all test examples.
+ label_type: Type of the label key (defaults to `int`).
+ weight_key: Key of the float sample weight (is not used if not provided).
+ is_regression: Whether the task is a regression problem (defaults to False).
+ """
+
+ def __init__(self,
+ tfds_params,
+ process_text_fn=tokenization.convert_to_unicode):
+ super(TfdsProcessor, self).__init__(process_text_fn)
+ self._process_tfds_params_str(tfds_params)
+ if self.module_import:
+ importlib.import_module(self.module_import)
+
+ self.dataset, info = tfds.load(
+ self.dataset_name, data_dir=self.data_dir, with_info=True)
+ if self.is_regression:
+ self._labels = None
+ else:
+ self._labels = list(range(info.features[self.label_key].num_classes))
+
+ def _process_tfds_params_str(self, params_str):
+ """Extracts TFDS parameters from a comma-separated assignements string."""
+ dtype_map = {"int": int, "float": float}
+ cast_str_to_bool = lambda s: s.lower() not in ["false", "0"]
+
+ tuples = [x.split("=") for x in params_str.split(",")]
+ d = {k.strip(): v.strip() for k, v in tuples}
+ self.dataset_name = d["dataset"] # Required.
+ self.data_dir = d.get("data_dir", None)
+ self.module_import = d.get("module_import", None)
+ self.train_split = d.get("train_split", "train")
+ self.dev_split = d.get("dev_split", "validation")
+ self.test_split = d.get("test_split", "test")
+ self.text_key = d.get("text_key", "text")
+ self.text_b_key = d.get("text_b_key", None)
+ self.label_key = d.get("label_key", "label")
+ self.test_text_key = d.get("test_text_key", self.text_key)
+ self.test_text_b_key = d.get("test_text_b_key", self.text_b_key)
+ self.test_label = d.get("test_label", "test_example")
+ self.label_type = dtype_map[d.get("label_type", "int")]
+ self.is_regression = cast_str_to_bool(d.get("is_regression", "False"))
+ self.weight_key = d.get("weight_key", None)
+
+ def get_train_examples(self, data_dir):
+ assert data_dir is None
+ return self._create_examples(self.train_split, "train")
+
+ def get_dev_examples(self, data_dir):
+ assert data_dir is None
+ return self._create_examples(self.dev_split, "dev")
+
+ def get_test_examples(self, data_dir):
+ assert data_dir is None
+ return self._create_examples(self.test_split, "test")
+
+ def get_labels(self):
+ return self._labels
+
+ def get_processor_name(self):
+ return "TFDS_" + self.dataset_name
+
+ def _create_examples(self, split_name, set_type):
+ """Creates examples for the training and dev sets."""
+ if split_name not in self.dataset:
+ raise ValueError("Split {} not available.".format(split_name))
+ dataset = self.dataset[split_name].as_numpy_iterator()
+ examples = []
+ text_b, weight = None, None
+ for i, example in enumerate(dataset):
+ guid = "%s-%s" % (set_type, i)
+ if set_type == "test":
+ text_a = self.process_text_fn(example[self.test_text_key])
+ if self.test_text_b_key:
+ text_b = self.process_text_fn(example[self.test_text_b_key])
+ label = self.test_label
+ else:
+ text_a = self.process_text_fn(example[self.text_key])
+ if self.text_b_key:
+ text_b = self.process_text_fn(example[self.text_b_key])
+ label = self.label_type(example[self.label_key])
+ if self.weight_key:
+ weight = float(example[self.weight_key])
+ examples.append(
+ InputExample(
+ guid=guid,
+ text_a=text_a,
+ text_b=text_b,
+ label=label,
+ weight=weight))
+ return examples
+
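+ # Illustrative usage sketch (an assumption, not part of the library): with the
+ # parameter strings documented above, a GLUE/SST-2 import via TFDS could be
+ # wired up roughly as follows.
+ #
+ #   processor = TfdsProcessor(
+ #       tfds_params="dataset=glue/sst2,text_key=sentence")
+ #   train_examples = processor.get_train_examples(None)  # data_dir must be None
+ #   processor.get_labels()          # [0, 1] for the two SST-2 classes
+ #   processor.get_processor_name()  # "TFDS_glue/sst2"
+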
+
+def convert_single_example(ex_index, example, label_list, max_seq_length,
+ tokenizer):
+ """Converts a single `InputExample` into a single `InputFeatures`."""
+ label_map = {}
+ if label_list:
+ for (i, label) in enumerate(label_list):
+ label_map[label] = i
+
+ tokens_a = tokenizer.tokenize(example.text_a)
+ tokens_b = None
+ if example.text_b:
+ tokens_b = tokenizer.tokenize(example.text_b)
+
+ if tokens_b:
+ # Modifies `tokens_a` and `tokens_b` in place so that the total
+ # length is less than the specified length.
+ # Account for [CLS], [SEP], [SEP] with "- 3"
+ _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
+ else:
+ # Account for [CLS] and [SEP] with "- 2"
+ if len(tokens_a) > max_seq_length - 2:
+ tokens_a = tokens_a[0:(max_seq_length - 2)]
+
+ # The convention in BERT is:
+ # (a) For sequence pairs:
+ # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
+ # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
+ # (b) For single sequences:
+ # tokens: [CLS] the dog is hairy . [SEP]
+ # type_ids: 0 0 0 0 0 0 0
+ #
+ # Where "type_ids" are used to indicate whether this is the first
+ # sequence or the second sequence. The embedding vectors for `type=0` and
+ # `type=1` were learned during pre-training and are added to the wordpiece
+ # embedding vector (and position vector). This is not *strictly* necessary
+ # since the [SEP] token unambiguously separates the sequences, but it makes
+ # it easier for the model to learn the concept of sequences.
+ #
+ # For classification tasks, the first vector (corresponding to [CLS]) is
+ # used as the "sentence vector". Note that this only makes sense because
+ # the entire model is fine-tuned.
+ tokens = []
+ segment_ids = []
+ tokens.append("[CLS]")
+ segment_ids.append(0)
+ for token in tokens_a:
+ tokens.append(token)
+ segment_ids.append(0)
+ tokens.append("[SEP]")
+ segment_ids.append(0)
+
+ if tokens_b:
+ for token in tokens_b:
+ tokens.append(token)
+ segment_ids.append(1)
+ tokens.append("[SEP]")
+ segment_ids.append(1)
+
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
+ # tokens are attended to.
+ input_mask = [1] * len(input_ids)
+
+ # Zero-pad up to the sequence length.
+ while len(input_ids) < max_seq_length:
+ input_ids.append(0)
+ input_mask.append(0)
+ segment_ids.append(0)
+
+ assert len(input_ids) == max_seq_length
+ assert len(input_mask) == max_seq_length
+ assert len(segment_ids) == max_seq_length
+
+ label_id = label_map[example.label] if label_map else example.label
+ if ex_index < 5:
+ logging.info("*** Example ***")
+ logging.info("guid: %s", (example.guid))
+ logging.info("tokens: %s",
+ " ".join([tokenization.printable_text(x) for x in tokens]))
+ logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
+ logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
+ logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
+ logging.info("label: %s (id = %s)", example.label, str(label_id))
+ logging.info("weight: %s", example.weight)
+ logging.info("int_iden: %s", str(example.int_iden))
+
+ feature = InputFeatures(
+ input_ids=input_ids,
+ input_mask=input_mask,
+ segment_ids=segment_ids,
+ label_id=label_id,
+ is_real_example=True,
+ weight=example.weight,
+ int_iden=example.int_iden)
+
+ return feature
+
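+ # Minimal sketch of how a single example flows through the function above
+ # (assumes an InputExample and a FullTokenizer constructed elsewhere; names
+ # and text are illustrative only):
+ #
+ #   example = InputExample(guid="demo-0", text_a="is this jacksonville ?",
+ #                          text_b="no it is not .", label="0")
+ #   feature = convert_single_example(0, example, ["0", "1"], 128, tokenizer)
+ #   # feature.input_ids / input_mask / segment_ids are padded to length 128;
+ #   # feature.label_id == 0 because "0" is the first entry of the label list.
+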
+
+def file_based_convert_examples_to_features(examples,
+ label_list,
+ max_seq_length,
+ tokenizer,
+ output_file,
+ label_type=None):
+ """Convert a set of `InputExample`s to a TFRecord file."""
+
+ tf.io.gfile.makedirs(os.path.dirname(output_file))
+ writer = tf.io.TFRecordWriter(output_file)
+
+ for (ex_index, example) in enumerate(examples):
+ if ex_index % 10000 == 0:
+ logging.info("Writing example %d of %d", ex_index, len(examples))
+
+ feature = convert_single_example(ex_index, example, label_list,
+ max_seq_length, tokenizer)
+
+ def create_int_feature(values):
+ f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
+ return f
+
+ def create_float_feature(values):
+ f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
+ return f
+
+ features = collections.OrderedDict()
+ features["input_ids"] = create_int_feature(feature.input_ids)
+ features["input_mask"] = create_int_feature(feature.input_mask)
+ features["segment_ids"] = create_int_feature(feature.segment_ids)
+ if label_type is not None and label_type == float:
+ features["label_ids"] = create_float_feature([feature.label_id])
+ elif feature.label_id is not None:
+ features["label_ids"] = create_int_feature([feature.label_id])
+ features["is_real_example"] = create_int_feature(
+ [int(feature.is_real_example)])
+ if feature.weight is not None:
+ features["weight"] = create_float_feature([feature.weight])
+ if feature.int_iden is not None:
+ features["int_iden"] = create_int_feature([feature.int_iden])
+
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+ writer.write(tf_example.SerializeToString())
+ writer.close()
+
+
+def _truncate_seq_pair(tokens_a, tokens_b, max_length):
+ """Truncates a sequence pair in place to the maximum length."""
+
+ # This is a simple heuristic which will always truncate the longer sequence
+ # one token at a time. This makes more sense than truncating an equal percent
+ # of tokens from each, since if one sequence is very short then each token
+ # that's truncated likely contains more information than a longer sequence.
+ while True:
+ total_length = len(tokens_a) + len(tokens_b)
+ if total_length <= max_length:
+ break
+ if len(tokens_a) > len(tokens_b):
+ tokens_a.pop()
+ else:
+ tokens_b.pop()
+
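+ # Worked example: with max_length=8, tokens_a of length 6 and tokens_b of
+ # length 5, the loop trims the longer list one token at a time
+ # (6/5 -> 5/5 -> 5/4 -> 4/4) and stops once the total is 8.
+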
+
+def generate_tf_record_from_data_file(processor,
+ data_dir,
+ tokenizer,
+ train_data_output_path=None,
+ eval_data_output_path=None,
+ test_data_output_path=None,
+ max_seq_length=128):
+ """Generates and saves training data into a tf record file.
+
+ Arguments:
+ processor: Input processor object to be used for generating data. Subclass
+ of `DataProcessor`.
+ data_dir: Directory that contains train/eval data to process. Data files
+ should be in the form "dev.tsv", "test.tsv", or "train.tsv".
+ tokenizer: The tokenizer to be applied on the data.
+ train_data_output_path: Output to which processed tf record for training
+ will be saved.
+ eval_data_output_path: Output to which processed tf record for evaluation
+ will be saved.
+ test_data_output_path: Output to which processed tf record for testing
+ will be saved. Must be a pattern template with {} if processor has
+ language specific test data.
+ max_seq_length: Maximum sequence length of the training/eval data
+ to be generated.
+
+ Returns:
+ A dictionary containing input meta data.
+ """
+ assert train_data_output_path or eval_data_output_path
+
+ label_list = processor.get_labels()
+ label_type = getattr(processor, "label_type", None)
+ is_regression = getattr(processor, "is_regression", False)
+ has_sample_weights = getattr(processor, "weight_key", False)
+ assert train_data_output_path
+
+ train_input_data_examples = processor.get_train_examples(data_dir)
+ file_based_convert_examples_to_features(train_input_data_examples, label_list,
+ max_seq_length, tokenizer,
+ train_data_output_path, label_type)
+ num_training_data = len(train_input_data_examples)
+
+ if eval_data_output_path:
+ eval_input_data_examples = processor.get_dev_examples(data_dir)
+ file_based_convert_examples_to_features(eval_input_data_examples,
+ label_list, max_seq_length,
+ tokenizer, eval_data_output_path,
+ label_type)
+
+ if test_data_output_path:
+ test_input_data_examples = processor.get_test_examples(data_dir)
+ if isinstance(test_input_data_examples, dict):
+ for language, examples in test_input_data_examples.items():
+ file_based_convert_examples_to_features(
+ examples, label_list, max_seq_length, tokenizer,
+ test_data_output_path.format(language), label_type)
+ else:
+ file_based_convert_examples_to_features(test_input_data_examples,
+ label_list, max_seq_length,
+ tokenizer, test_data_output_path,
+ label_type)
+
+ meta_data = {
+ "processor_type": processor.get_processor_name(),
+ "train_data_size": num_training_data,
+ "max_seq_length": max_seq_length,
+ }
+ if is_regression:
+ meta_data["task_type"] = "bert_regression"
+ meta_data["label_type"] = {int: "int", float: "float"}[label_type]
+ else:
+ meta_data["task_type"] = "bert_classification"
+ meta_data["num_labels"] = len(processor.get_labels())
+ if has_sample_weights:
+ meta_data["has_sample_weights"] = True
+
+ if eval_data_output_path:
+ meta_data["eval_data_size"] = len(eval_input_data_examples)
+
+ if test_data_output_path:
+ test_input_data_examples = processor.get_test_examples(data_dir)
+ if isinstance(test_input_data_examples, dict):
+ for language, examples in test_input_data_examples.items():
+ meta_data["test_{}_data_size".format(language)] = len(examples)
+ else:
+ meta_data["test_data_size"] = len(test_input_data_examples)
+
+ return meta_data
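+
+ # Illustrative end-to-end sketch (assumed vocab and data paths, not part of
+ # the library):
+ #
+ #   tokenizer = tokenization.FullTokenizer(
+ #       vocab_file="uncased_vocab.txt", do_lower_case=True)
+ #   processor = MrpcProcessor(
+ #       process_text_fn=tokenization.convert_to_unicode)
+ #   meta_data = generate_tf_record_from_data_file(
+ #       processor, "glue_data/MRPC", tokenizer,
+ #       train_data_output_path="mrpc_train.tf_record",
+ #       eval_data_output_path="mrpc_eval.tf_record",
+ #       max_seq_length=128)
+ #   # meta_data then holds processor_type, train/eval sizes and num_labels.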
diff --git a/models/official/nlp/data/create_finetuning_data.py b/models/official/nlp/data/create_finetuning_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fae97e127680d8828d23442ecd7592abb39b584
--- /dev/null
+++ b/models/official/nlp/data/create_finetuning_data.py
@@ -0,0 +1,316 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""BERT finetuning task dataset generator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import json
+import os
+
+from absl import app
+from absl import flags
+import tensorflow as tf
+from official.nlp.bert import tokenization
+from official.nlp.data import classifier_data_lib
+from official.nlp.data import sentence_retrieval_lib
+# word-piece tokenizer based squad_lib
+from official.nlp.data import squad_lib as squad_lib_wp
+# sentence-piece tokenizer based squad_lib
+from official.nlp.data import squad_lib_sp
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_enum(
+ "fine_tuning_task_type", "classification",
+ ["classification", "regression", "squad", "retrieval"],
+ "The name of the BERT fine tuning task for which data "
+ "will be generated..")
+
+# BERT classification specific flags.
+flags.DEFINE_string(
+ "input_data_dir", None,
+ "The input data dir. Should contain the .tsv files (or other data files) "
+ "for the task.")
+
+flags.DEFINE_enum("classification_task_name", "MNLI",
+ ["COLA", "MNLI", "MRPC", "QNLI", "QQP", "SST-2", "XNLI",
+ "PAWS-X", "XTREME-XNLI", "XTREME-PAWS-X"],
+ "The name of the task to train BERT classifier. The "
+ "difference between XTREME-XNLI and XNLI is: 1. the format "
+ "of input tsv files; 2. the dev set for XTREME is english "
+ "only and for XNLI is all languages combined. Same for "
+ "PAWS-X.")
+
+flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
+ "The name of sentence retrieval task for scoring")
+
+# XNLI task specific flag.
+flags.DEFINE_string(
+ "xnli_language", "en",
+ "Language of training data for XNIL task. If the value is 'all', the data "
+ "of all languages will be used for training.")
+
+# PAWS-X task specific flag.
+flags.DEFINE_string(
+ "pawsx_language", "en",
+ "Language of trainig data for PAWS-X task. If the value is 'all', the data "
+ "of all languages will be used for training.")
+
+# BERT Squad task specific flags.
+flags.DEFINE_string(
+ "squad_data_file", None,
+ "The input data file in for generating training data for BERT squad task.")
+
+flags.DEFINE_integer(
+ "doc_stride", 128,
+ "When splitting up a long document into chunks, how much stride to "
+ "take between chunks.")
+
+flags.DEFINE_integer(
+ "max_query_length", 64,
+ "The maximum number of tokens for the question. Questions longer than "
+ "this will be truncated to this length.")
+
+flags.DEFINE_bool(
+ "version_2_with_negative", False,
+ "If true, the SQuAD examples contain some that do not have an answer.")
+
+# Shared flags across BERT fine-tuning tasks.
+flags.DEFINE_string("vocab_file", None,
+ "The vocabulary file that the BERT model was trained on.")
+
+flags.DEFINE_string(
+ "train_data_output_path", None,
+ "The path in which generated training input data will be written as tf"
+ " records.")
+
+flags.DEFINE_string(
+ "eval_data_output_path", None,
+ "The path in which generated evaluation input data will be written as tf"
+ " records.")
+
+flags.DEFINE_string(
+ "test_data_output_path", None,
+ "The path in which generated test input data will be written as tf"
+ " records. If None, do not generate test data. Must be a pattern template"
+ " as test_{}.tfrecords if processor has language specific test data.")
+
+flags.DEFINE_string("meta_data_file_path", None,
+ "The path in which input meta data will be written.")
+
+flags.DEFINE_bool(
+ "do_lower_case", True,
+ "Whether to lower case the input text. Should be True for uncased "
+ "models and False for cased models.")
+
+flags.DEFINE_integer(
+ "max_seq_length", 128,
+ "The maximum total input sequence length after WordPiece tokenization. "
+ "Sequences longer than this will be truncated, and sequences shorter "
+ "than this will be padded.")
+
+flags.DEFINE_string("sp_model_file", "",
+ "The path to the model used by sentence piece tokenizer.")
+
+flags.DEFINE_enum(
+ "tokenizer_impl", "word_piece", ["word_piece", "sentence_piece"],
+ "Specifies the tokenizer implementation, i.e., whehter to use word_piece "
+ "or sentence_piece tokenizer. Canonical BERT uses word_piece tokenizer, "
+ "while ALBERT uses sentence_piece tokenizer.")
+
+flags.DEFINE_string("tfds_params", "",
+ "Comma-separated list of TFDS parameter assigments for "
+ "generic classfication data import (for more details "
+ "see the TfdsProcessor class documentation).")
+
+
+def generate_classifier_dataset():
+ """Generates classifier dataset and returns input meta data."""
+ assert (FLAGS.input_data_dir and FLAGS.classification_task_name
+ or FLAGS.tfds_params)
+
+ if FLAGS.tokenizer_impl == "word_piece":
+ tokenizer = tokenization.FullTokenizer(
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+ processor_text_fn = tokenization.convert_to_unicode
+ else:
+ assert FLAGS.tokenizer_impl == "sentence_piece"
+ tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
+ processor_text_fn = functools.partial(
+ tokenization.preprocess_text, lower=FLAGS.do_lower_case)
+
+ if FLAGS.tfds_params:
+ processor = classifier_data_lib.TfdsProcessor(
+ tfds_params=FLAGS.tfds_params,
+ process_text_fn=processor_text_fn)
+ return classifier_data_lib.generate_tf_record_from_data_file(
+ processor,
+ None,
+ tokenizer,
+ train_data_output_path=FLAGS.train_data_output_path,
+ eval_data_output_path=FLAGS.eval_data_output_path,
+ test_data_output_path=FLAGS.test_data_output_path,
+ max_seq_length=FLAGS.max_seq_length)
+ else:
+ processors = {
+ "cola":
+ classifier_data_lib.ColaProcessor,
+ "mnli":
+ classifier_data_lib.MnliProcessor,
+ "mrpc":
+ classifier_data_lib.MrpcProcessor,
+ "qnli":
+ classifier_data_lib.QnliProcessor,
+ "qqp": classifier_data_lib.QqpProcessor,
+ "rte": classifier_data_lib.RteProcessor,
+ "sst-2":
+ classifier_data_lib.SstProcessor,
+ "xnli":
+ functools.partial(classifier_data_lib.XnliProcessor,
+ language=FLAGS.xnli_language),
+ "paws-x":
+ functools.partial(classifier_data_lib.PawsxProcessor,
+ language=FLAGS.pawsx_language),
+ "xtreme-xnli":
+ functools.partial(classifier_data_lib.XtremeXnliProcessor),
+ "xtreme-paws-x":
+ functools.partial(classifier_data_lib.XtremePawsxProcessor)
+ }
+ task_name = FLAGS.classification_task_name.lower()
+ if task_name not in processors:
+ raise ValueError("Task not found: %s" % (task_name))
+
+ processor = processors[task_name](process_text_fn=processor_text_fn)
+ return classifier_data_lib.generate_tf_record_from_data_file(
+ processor,
+ FLAGS.input_data_dir,
+ tokenizer,
+ train_data_output_path=FLAGS.train_data_output_path,
+ eval_data_output_path=FLAGS.eval_data_output_path,
+ test_data_output_path=FLAGS.test_data_output_path,
+ max_seq_length=FLAGS.max_seq_length)
+
+
+def generate_regression_dataset():
+ """Generates regression dataset and returns input meta data."""
+ if FLAGS.tokenizer_impl == "word_piece":
+ tokenizer = tokenization.FullTokenizer(
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+ processor_text_fn = tokenization.convert_to_unicode
+ else:
+ assert FLAGS.tokenizer_impl == "sentence_piece"
+ tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
+ processor_text_fn = functools.partial(
+ tokenization.preprocess_text, lower=FLAGS.do_lower_case)
+
+ if FLAGS.tfds_params:
+ processor = classifier_data_lib.TfdsProcessor(
+ tfds_params=FLAGS.tfds_params,
+ process_text_fn=processor_text_fn)
+ return classifier_data_lib.generate_tf_record_from_data_file(
+ processor,
+ None,
+ tokenizer,
+ train_data_output_path=FLAGS.train_data_output_path,
+ eval_data_output_path=FLAGS.eval_data_output_path,
+ test_data_output_path=FLAGS.test_data_output_path,
+ max_seq_length=FLAGS.max_seq_length)
+ else:
+ raise ValueError("No data processor found for the given regression task.")
+
+
+def generate_squad_dataset():
+ """Generates squad training dataset and returns input meta data."""
+ assert FLAGS.squad_data_file
+ if FLAGS.tokenizer_impl == "word_piece":
+ return squad_lib_wp.generate_tf_record_from_json_file(
+ FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
+ FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
+ FLAGS.doc_stride, FLAGS.version_2_with_negative)
+ else:
+ assert FLAGS.tokenizer_impl == "sentence_piece"
+ return squad_lib_sp.generate_tf_record_from_json_file(
+ FLAGS.squad_data_file, FLAGS.sp_model_file,
+ FLAGS.train_data_output_path, FLAGS.max_seq_length, FLAGS.do_lower_case,
+ FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.version_2_with_negative)
+
+
+def generate_retrieval_dataset():
+ """Generate retrieval test and dev dataset and returns input meta data."""
+ assert (FLAGS.input_data_dir and FLAGS.retrieval_task_name)
+ if FLAGS.tokenizer_impl == "word_piece":
+ tokenizer = tokenization.FullTokenizer(
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+ processor_text_fn = tokenization.convert_to_unicode
+ else:
+ assert FLAGS.tokenizer_impl == "sentence_piece"
+ tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
+ processor_text_fn = functools.partial(
+ tokenization.preprocess_text, lower=FLAGS.do_lower_case)
+
+ processors = {
+ "bucc": sentence_retrieval_lib.BuccProcessor,
+ "tatoeba": sentence_retrieval_lib.TatoebaProcessor,
+ }
+
+ task_name = FLAGS.retrieval_task_name.lower()
+ if task_name not in processors:
+ raise ValueError("Task not found: %s" % task_name)
+
+ processor = processors[task_name](process_text_fn=processor_text_fn)
+
+ return sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
+ processor,
+ FLAGS.input_data_dir,
+ tokenizer,
+ FLAGS.eval_data_output_path,
+ FLAGS.test_data_output_path,
+ FLAGS.max_seq_length)
+
+
+def main(_):
+ if FLAGS.tokenizer_impl == "word_piece":
+ if not FLAGS.vocab_file:
+ raise ValueError(
+ "FLAG vocab_file for word-piece tokenizer is not specified.")
+ else:
+ assert FLAGS.tokenizer_impl == "sentence_piece"
+ if not FLAGS.sp_model_file:
+ raise ValueError(
+ "FLAG sp_model_file for sentence-piece tokenizer is not specified.")
+
+ if FLAGS.fine_tuning_task_type != "retrieval":
+ flags.mark_flag_as_required("train_data_output_path")
+
+ if FLAGS.fine_tuning_task_type == "classification":
+ input_meta_data = generate_classifier_dataset()
+ elif FLAGS.fine_tuning_task_type == "regression":
+ input_meta_data = generate_regression_dataset()
+ elif FLAGS.fine_tuning_task_type == "retrieval":
+ input_meta_data = generate_retrieval_dataset()
+ else:
+ input_meta_data = generate_squad_dataset()
+
+ tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
+ with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
+ writer.write(json.dumps(input_meta_data, indent=4) + "\n")
+
+
+if __name__ == "__main__":
+ flags.mark_flag_as_required("meta_data_file_path")
+ app.run(main)
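+
+ # Example invocation (paths are placeholders; adjust to your environment):
+ #
+ #   python create_finetuning_data.py \
+ #     --fine_tuning_task_type=classification \
+ #     --classification_task_name=MRPC \
+ #     --input_data_dir=glue_data/MRPC \
+ #     --vocab_file=uncased_vocab.txt \
+ #     --train_data_output_path=mrpc_train.tf_record \
+ #     --eval_data_output_path=mrpc_eval.tf_record \
+ #     --meta_data_file_path=mrpc_meta_data \
+ #     --max_seq_length=128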
diff --git a/models/official/nlp/data/create_pretraining_data.py b/models/official/nlp/data/create_pretraining_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..79dac57ac8775687673604af6fb2fb50c9f74244
--- /dev/null
+++ b/models/official/nlp/data/create_pretraining_data.py
@@ -0,0 +1,486 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Create masked LM/next sentence masked_lm TF examples for BERT."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import random
+
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf
+
+from official.nlp.bert import tokenization
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("input_file", None,
+ "Input raw text file (or comma-separated list of files).")
+
+flags.DEFINE_string(
+ "output_file", None,
+ "Output TF example file (or comma-separated list of files).")
+
+flags.DEFINE_string("vocab_file", None,
+ "The vocabulary file that the BERT model was trained on.")
+
+flags.DEFINE_bool(
+ "do_lower_case", True,
+ "Whether to lower case the input text. Should be True for uncased "
+ "models and False for cased models.")
+
+flags.DEFINE_bool(
+ "do_whole_word_mask", False,
+ "Whether to use whole word masking rather than per-WordPiece masking.")
+
+flags.DEFINE_bool(
+ "gzip_compress", False,
+ "Whether to use `GZIP` compress option to get compressed TFRecord files.")
+
+flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
+
+flags.DEFINE_integer("max_predictions_per_seq", 20,
+ "Maximum number of masked LM predictions per sequence.")
+
+flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
+
+flags.DEFINE_integer(
+ "dupe_factor", 10,
+ "Number of times to duplicate the input data (with different masks).")
+
+flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
+
+flags.DEFINE_float(
+ "short_seq_prob", 0.1,
+ "Probability of creating sequences which are shorter than the "
+ "maximum length.")
+
+
+class TrainingInstance(object):
+ """A single training instance (sentence pair)."""
+
+ def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
+ is_random_next):
+ self.tokens = tokens
+ self.segment_ids = segment_ids
+ self.is_random_next = is_random_next
+ self.masked_lm_positions = masked_lm_positions
+ self.masked_lm_labels = masked_lm_labels
+
+ def __str__(self):
+ s = ""
+ s += "tokens: %s\n" % (" ".join(
+ [tokenization.printable_text(x) for x in self.tokens]))
+ s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
+ s += "is_random_next: %s\n" % self.is_random_next
+ s += "masked_lm_positions: %s\n" % (" ".join(
+ [str(x) for x in self.masked_lm_positions]))
+ s += "masked_lm_labels: %s\n" % (" ".join(
+ [tokenization.printable_text(x) for x in self.masked_lm_labels]))
+ s += "\n"
+ return s
+
+ def __repr__(self):
+ return self.__str__()
+
+
+def write_instance_to_example_files(instances, tokenizer, max_seq_length,
+ max_predictions_per_seq, output_files,
+ gzip_compress):
+ """Create TF example files from `TrainingInstance`s."""
+ writers = []
+ for output_file in output_files:
+ writers.append(
+ tf.io.TFRecordWriter(
+ output_file, options="GZIP" if gzip_compress else ""))
+
+ writer_index = 0
+
+ total_written = 0
+ for (inst_index, instance) in enumerate(instances):
+ input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
+ input_mask = [1] * len(input_ids)
+ segment_ids = list(instance.segment_ids)
+ assert len(input_ids) <= max_seq_length
+
+ while len(input_ids) < max_seq_length:
+ input_ids.append(0)
+ input_mask.append(0)
+ segment_ids.append(0)
+
+ assert len(input_ids) == max_seq_length
+ assert len(input_mask) == max_seq_length
+ assert len(segment_ids) == max_seq_length
+
+ masked_lm_positions = list(instance.masked_lm_positions)
+ masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
+ masked_lm_weights = [1.0] * len(masked_lm_ids)
+
+ while len(masked_lm_positions) < max_predictions_per_seq:
+ masked_lm_positions.append(0)
+ masked_lm_ids.append(0)
+ masked_lm_weights.append(0.0)
+
+ next_sentence_label = 1 if instance.is_random_next else 0
+
+ features = collections.OrderedDict()
+ features["input_ids"] = create_int_feature(input_ids)
+ features["input_mask"] = create_int_feature(input_mask)
+ features["segment_ids"] = create_int_feature(segment_ids)
+ features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
+ features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
+ features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
+ features["next_sentence_labels"] = create_int_feature([next_sentence_label])
+
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+
+ writers[writer_index].write(tf_example.SerializeToString())
+ writer_index = (writer_index + 1) % len(writers)
+
+ total_written += 1
+
+ if inst_index < 20:
+ logging.info("*** Example ***")
+ logging.info("tokens: %s", " ".join(
+ [tokenization.printable_text(x) for x in instance.tokens]))
+
+ for feature_name in features.keys():
+ feature = features[feature_name]
+ values = []
+ if feature.int64_list.value:
+ values = feature.int64_list.value
+ elif feature.float_list.value:
+ values = feature.float_list.value
+ logging.info("%s: %s", feature_name, " ".join([str(x) for x in values]))
+
+ for writer in writers:
+ writer.close()
+
+ logging.info("Wrote %d total instances", total_written)
+
+
+def create_int_feature(values):
+ feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
+ return feature
+
+
+def create_float_feature(values):
+ feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
+ return feature
+
+
+def create_training_instances(input_files,
+ tokenizer,
+ max_seq_length,
+ dupe_factor,
+ short_seq_prob,
+ masked_lm_prob,
+ max_predictions_per_seq,
+ rng,
+ do_whole_word_mask=False):
+ """Create `TrainingInstance`s from raw text."""
+ all_documents = [[]]
+
+ # Input file format:
+ # (1) One sentence per line. These should ideally be actual sentences, not
+ # entire paragraphs or arbitrary spans of text. (Because we use the
+ # sentence boundaries for the "next sentence prediction" task).
+ # (2) Blank lines between documents. Document boundaries are needed so
+ # that the "next sentence prediction" task doesn't span between documents.
+ for input_file in input_files:
+ with tf.io.gfile.GFile(input_file, "rb") as reader:
+ while True:
+ line = tokenization.convert_to_unicode(reader.readline())
+ if not line:
+ break
+ line = line.strip()
+
+ # Empty lines are used as document delimiters
+ if not line:
+ all_documents.append([])
+ tokens = tokenizer.tokenize(line)
+ if tokens:
+ all_documents[-1].append(tokens)
+
+ # Remove empty documents
+ all_documents = [x for x in all_documents if x]
+ rng.shuffle(all_documents)
+
+ vocab_words = list(tokenizer.vocab.keys())
+ instances = []
+ for _ in range(dupe_factor):
+ for document_index in range(len(all_documents)):
+ instances.extend(
+ create_instances_from_document(
+ all_documents, document_index, max_seq_length, short_seq_prob,
+ masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
+ do_whole_word_mask))
+
+ rng.shuffle(instances)
+ return instances
+
+
+def create_instances_from_document(
+ all_documents, document_index, max_seq_length, short_seq_prob,
+ masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
+ do_whole_word_mask=False):
+ """Creates `TrainingInstance`s for a single document."""
+ document = all_documents[document_index]
+
+ # Account for [CLS], [SEP], [SEP]
+ max_num_tokens = max_seq_length - 3
+
+ # We *usually* want to fill up the entire sequence since we are padding
+ # to `max_seq_length` anyways, so short sequences are generally wasted
+ # computation. However, we *sometimes*
+ # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
+ # sequences to minimize the mismatch between pre-training and fine-tuning.
+ # The `target_seq_length` is just a rough target however, whereas
+ # `max_seq_length` is a hard limit.
+ target_seq_length = max_num_tokens
+ if rng.random() < short_seq_prob:
+ target_seq_length = rng.randint(2, max_num_tokens)
+
+ # We DON'T just concatenate all of the tokens from a document into a long
+ # sequence and choose an arbitrary split point because this would make the
+ # next sentence prediction task too easy. Instead, we split the input into
+ # segments "A" and "B" based on the actual "sentences" provided by the user
+ # input.
+ instances = []
+ current_chunk = []
+ current_length = 0
+ i = 0
+ while i < len(document):
+ segment = document[i]
+ current_chunk.append(segment)
+ current_length += len(segment)
+ if i == len(document) - 1 or current_length >= target_seq_length:
+ if current_chunk:
+ # `a_end` is how many segments from `current_chunk` go into the `A`
+ # (first) sentence.
+ a_end = 1
+ if len(current_chunk) >= 2:
+ a_end = rng.randint(1, len(current_chunk) - 1)
+
+ tokens_a = []
+ for j in range(a_end):
+ tokens_a.extend(current_chunk[j])
+
+ tokens_b = []
+ # Random next
+ is_random_next = False
+ if len(current_chunk) == 1 or rng.random() < 0.5:
+ is_random_next = True
+ target_b_length = target_seq_length - len(tokens_a)
+
+ # This should rarely go for more than one iteration for large
+ # corpora. However, just to be careful, we try to make sure that
+ # the random document is not the same as the document
+ # we're processing.
+ for _ in range(10):
+ random_document_index = rng.randint(0, len(all_documents) - 1)
+ if random_document_index != document_index:
+ break
+
+ random_document = all_documents[random_document_index]
+ random_start = rng.randint(0, len(random_document) - 1)
+ for j in range(random_start, len(random_document)):
+ tokens_b.extend(random_document[j])
+ if len(tokens_b) >= target_b_length:
+ break
+ # We didn't actually use these segments so we "put them back" so
+ # they don't go to waste.
+ num_unused_segments = len(current_chunk) - a_end
+ i -= num_unused_segments
+ # Actual next
+ else:
+ is_random_next = False
+ for j in range(a_end, len(current_chunk)):
+ tokens_b.extend(current_chunk[j])
+ truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
+
+ assert len(tokens_a) >= 1
+ assert len(tokens_b) >= 1
+
+ tokens = []
+ segment_ids = []
+ tokens.append("[CLS]")
+ segment_ids.append(0)
+ for token in tokens_a:
+ tokens.append(token)
+ segment_ids.append(0)
+
+ tokens.append("[SEP]")
+ segment_ids.append(0)
+
+ for token in tokens_b:
+ tokens.append(token)
+ segment_ids.append(1)
+ tokens.append("[SEP]")
+ segment_ids.append(1)
+
+ (tokens, masked_lm_positions,
+ masked_lm_labels) = create_masked_lm_predictions(
+ tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
+ do_whole_word_mask)
+ instance = TrainingInstance(
+ tokens=tokens,
+ segment_ids=segment_ids,
+ is_random_next=is_random_next,
+ masked_lm_positions=masked_lm_positions,
+ masked_lm_labels=masked_lm_labels)
+ instances.append(instance)
+ current_chunk = []
+ current_length = 0
+ i += 1
+
+ return instances
+
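+ # Sketch of how the two functions above are typically driven (mirrors main()
+ # at the bottom of this file; file names are placeholders):
+ #
+ #   rng = random.Random(12345)
+ #   tokenizer = tokenization.FullTokenizer(
+ #       vocab_file="uncased_vocab.txt", do_lower_case=True)
+ #   instances = create_training_instances(
+ #       ["corpus.txt"], tokenizer, max_seq_length=128, dupe_factor=10,
+ #       short_seq_prob=0.1, masked_lm_prob=0.15, max_predictions_per_seq=20,
+ #       rng=rng)
+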
+
+MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
+ ["index", "label"])
+
+
+def create_masked_lm_predictions(tokens, masked_lm_prob,
+ max_predictions_per_seq, vocab_words, rng,
+ do_whole_word_mask):
+ """Creates the predictions for the masked LM objective."""
+
+ cand_indexes = []
+ for (i, token) in enumerate(tokens):
+ if token == "[CLS]" or token == "[SEP]":
+ continue
+ # Whole Word Masking means that we mask all of the wordpieces
+ # corresponding to an original word. When a word has been split into
+ # WordPieces, the first token does not have any marker and any subsequent
+ # tokens are prefixed with ##. So whenever we see the ## token, we
+ # append it to the previous set of word indexes.
+ #
+ # Note that Whole Word Masking does *not* change the training code
+ # at all -- we still predict each WordPiece independently, softmaxed
+ # over the entire vocabulary.
+ if (do_whole_word_mask and len(cand_indexes) >= 1 and
+ token.startswith("##")):
+ cand_indexes[-1].append(i)
+ else:
+ cand_indexes.append([i])
+
+ rng.shuffle(cand_indexes)
+
+ output_tokens = list(tokens)
+
+ num_to_predict = min(max_predictions_per_seq,
+ max(1, int(round(len(tokens) * masked_lm_prob))))
+
+ masked_lms = []
+ covered_indexes = set()
+ for index_set in cand_indexes:
+ if len(masked_lms) >= num_to_predict:
+ break
+ # If adding a whole-word mask would exceed the maximum number of
+ # predictions, then just skip this candidate.
+ if len(masked_lms) + len(index_set) > num_to_predict:
+ continue
+ is_any_index_covered = False
+ for index in index_set:
+ if index in covered_indexes:
+ is_any_index_covered = True
+ break
+ if is_any_index_covered:
+ continue
+ for index in index_set:
+ covered_indexes.add(index)
+
+ masked_token = None
+ # 80% of the time, replace with [MASK]
+ if rng.random() < 0.8:
+ masked_token = "[MASK]"
+ else:
+ # 10% of the time, keep original
+ if rng.random() < 0.5:
+ masked_token = tokens[index]
+ # 10% of the time, replace with random word
+ else:
+ masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
+
+ output_tokens[index] = masked_token
+
+ masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
+ assert len(masked_lms) <= num_to_predict
+ masked_lms = sorted(masked_lms, key=lambda x: x.index)
+
+ masked_lm_positions = []
+ masked_lm_labels = []
+ for p in masked_lms:
+ masked_lm_positions.append(p.index)
+ masked_lm_labels.append(p.label)
+
+ return (output_tokens, masked_lm_positions, masked_lm_labels)
+
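+ # Numeric example for the sizing logic above: with masked_lm_prob=0.15,
+ # max_predictions_per_seq=20 and a 40-token sequence, num_to_predict is
+ # min(20, max(1, round(40 * 0.15))) = 6, so at most 6 positions are masked.
+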
+
+def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
+ """Truncates a pair of sequences to a maximum sequence length."""
+ while True:
+ total_length = len(tokens_a) + len(tokens_b)
+ if total_length <= max_num_tokens:
+ break
+
+ trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
+ assert len(trunc_tokens) >= 1
+
+ # We want to sometimes truncate from the front and sometimes from the
+ # back to add more randomness and avoid biases.
+ if rng.random() < 0.5:
+ del trunc_tokens[0]
+ else:
+ trunc_tokens.pop()
+
+
+def main(_):
+ tokenizer = tokenization.FullTokenizer(
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+
+ input_files = []
+ for input_pattern in FLAGS.input_file.split(","):
+ input_files.extend(tf.io.gfile.glob(input_pattern))
+
+ logging.info("*** Reading from input files ***")
+ for input_file in input_files:
+ logging.info(" %s", input_file)
+
+ rng = random.Random(FLAGS.random_seed)
+ instances = create_training_instances(
+ input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
+ FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
+ rng, FLAGS.do_whole_word_mask)
+
+ output_files = FLAGS.output_file.split(",")
+ logging.info("*** Writing to output files ***")
+ for output_file in output_files:
+ logging.info(" %s", output_file)
+
+ write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
+ FLAGS.max_predictions_per_seq, output_files,
+ FLAGS.gzip_compress)
+
+
+if __name__ == "__main__":
+ flags.mark_flag_as_required("input_file")
+ flags.mark_flag_as_required("output_file")
+ flags.mark_flag_as_required("vocab_file")
+ app.run(main)
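+
+ # Example invocation (file names are placeholders):
+ #
+ #   python create_pretraining_data.py \
+ #     --input_file=corpus.txt \
+ #     --output_file=pretrain.tf_record \
+ #     --vocab_file=uncased_vocab.txt \
+ #     --do_lower_case=True \
+ #     --max_seq_length=128 \
+ #     --max_predictions_per_seq=20 \
+ #     --masked_lm_prob=0.15 \
+ #     --dupe_factor=10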
diff --git a/models/official/nlp/data/pretrain_dataloader.py b/models/official/nlp/data/pretrain_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..18325090caa6d83e68b4077aac4a27ee69bea938
--- /dev/null
+++ b/models/official/nlp/data/pretrain_dataloader.py
@@ -0,0 +1,97 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Loads dataset for the BERT pretraining task."""
+from typing import Mapping, Optional
+
+import tensorflow as tf
+
+from official.core import input_reader
+
+
+class BertPretrainDataLoader:
+ """A class to load dataset for bert pretraining task."""
+
+ def __init__(self, params):
+ """Inits `BertPretrainDataLoader` class.
+
+ Args:
+ params: A `BertPretrainDataConfig` object.
+ """
+ self._params = params
+ self._seq_length = params.seq_length
+ self._max_predictions_per_seq = params.max_predictions_per_seq
+ self._use_next_sentence_label = params.use_next_sentence_label
+ self._use_position_id = params.use_position_id
+
+ def _decode(self, record: tf.Tensor):
+ """Decodes a serialized tf.Example."""
+ name_to_features = {
+ 'input_ids':
+ tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ 'input_mask':
+ tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ 'segment_ids':
+ tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ 'masked_lm_positions':
+ tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
+ 'masked_lm_ids':
+ tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
+ 'masked_lm_weights':
+ tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.float32),
+ }
+ if self._use_next_sentence_label:
+ name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
+ tf.int64)
+ if self._use_position_id:
+ name_to_features['position_ids'] = tf.io.FixedLenFeature(
+ [self._seq_length], tf.int64)
+
+ example = tf.io.parse_single_example(record, name_to_features)
+
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+ # So cast all int64 to int32.
+ for name in list(example.keys()):
+ t = example[name]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ example[name] = t
+
+ return example
+
+ def _parse(self, record: Mapping[str, tf.Tensor]):
+ """Parses raw tensors into a dict of tensors to be consumed by the model."""
+ x = {
+ 'input_word_ids': record['input_ids'],
+ 'input_mask': record['input_mask'],
+ 'input_type_ids': record['segment_ids'],
+ 'masked_lm_positions': record['masked_lm_positions'],
+ 'masked_lm_ids': record['masked_lm_ids'],
+ 'masked_lm_weights': record['masked_lm_weights'],
+ }
+ if self._use_next_sentence_label:
+ x['next_sentence_labels'] = record['next_sentence_labels']
+ if self._use_position_id:
+ x['position_ids'] = record['position_ids']
+
+ return x
+
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
+ """Returns a tf.dataset.Dataset."""
+ reader = input_reader.InputReader(
+ params=self._params,
+ decoder_fn=self._decode,
+ parser_fn=self._parse)
+ return reader.read(input_context)
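+
+ # Minimal usage sketch (assumes a `BertPretrainDataConfig`-like params object
+ # defined elsewhere; its field names mirror those read in __init__ above):
+ #
+ #   loader = BertPretrainDataLoader(params)
+ #   dataset = loader.load()
+ #   for batch in dataset.take(1):
+ #       print(batch['input_word_ids'].shape)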
diff --git a/models/official/nlp/data/sentence_prediction_dataloader.py b/models/official/nlp/data/sentence_prediction_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..60dd788403725aeeca2028b237c3330bbf22716c
--- /dev/null
+++ b/models/official/nlp/data/sentence_prediction_dataloader.py
@@ -0,0 +1,64 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Loads dataset for the sentence prediction (classification) task."""
+from typing import Mapping, Optional
+import tensorflow as tf
+
+from official.core import input_reader
+
+
+class SentencePredictionDataLoader:
+ """A class to load dataset for sentence prediction (classification) task."""
+
+ def __init__(self, params):
+ self._params = params
+ self._seq_length = params.seq_length
+
+ def _decode(self, record: tf.Tensor):
+ """Decodes a serialized tf.Example."""
+ name_to_features = {
+ 'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ 'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ 'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ 'label_ids': tf.io.FixedLenFeature([], tf.int64),
+ }
+ example = tf.io.parse_single_example(record, name_to_features)
+
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+ # So cast all int64 to int32.
+ for name in example:
+ t = example[name]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ example[name] = t
+
+ return example
+
+ def _parse(self, record: Mapping[str, tf.Tensor]):
+ """Parses raw tensors into a dict of tensors to be consumed by the model."""
+ x = {
+ 'input_word_ids': record['input_ids'],
+ 'input_mask': record['input_mask'],
+ 'input_type_ids': record['segment_ids']
+ }
+ y = record['label_ids']
+ return (x, y)
+
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
+ """Returns a tf.dataset.Dataset."""
+ reader = input_reader.InputReader(
+ params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
+ return reader.read(input_context)
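+
+ # Minimal usage sketch (assumes a data config object providing `seq_length`
+ # plus the usual input_reader fields such as the input path and batch size):
+ #
+ #   loader = SentencePredictionDataLoader(params)
+ #   dataset = loader.load()  # yields (features, label_ids) pairs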
diff --git a/models/official/nlp/data/sentence_retrieval_lib.py b/models/official/nlp/data/sentence_retrieval_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8e83ae579f8221b93e790ea62b91c3d6d2b9e90
--- /dev/null
+++ b/models/official/nlp/data/sentence_retrieval_lib.py
@@ -0,0 +1,168 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""BERT library to process data for cross lingual sentence retrieval task."""
+
+import os
+
+from absl import logging
+from official.nlp.bert import tokenization
+from official.nlp.data import classifier_data_lib
+
+
+class BuccProcessor(classifier_data_lib.DataProcessor):
+  """Processor for the Xtreme BUCC dataset."""
+ supported_languages = ["de", "fr", "ru", "zh"]
+
+ def __init__(self,
+ process_text_fn=tokenization.convert_to_unicode):
+ super(BuccProcessor, self).__init__(process_text_fn)
+ self.languages = BuccProcessor.supported_languages
+
+ def get_dev_examples(self, data_dir, file_pattern):
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, file_pattern.format("dev"))),
+ "sample")
+
+ def get_test_examples(self, data_dir, file_pattern):
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, file_pattern.format("test"))),
+ "test")
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "BUCC"
+
+ def _create_examples(self, lines, set_type):
+    """Creates examples for the dev and test sets."""
+ examples = []
+ for (i, line) in enumerate(lines):
+ guid = "%s-%s" % (set_type, i)
+ int_iden = int(line[0].split("-")[1])
+ text_a = self.process_text_fn(line[1])
+ examples.append(
+ classifier_data_lib.InputExample(
+ guid=guid, text_a=text_a, int_iden=int_iden))
+ return examples
+
+
+class TatoebaProcessor(classifier_data_lib.DataProcessor):
+  """Processor for the Xtreme Tatoeba dataset."""
+ supported_languages = [
+ "af", "ar", "bg", "bn", "de", "el", "es", "et", "eu", "fa", "fi", "fr",
+ "he", "hi", "hu", "id", "it", "ja", "jv", "ka", "kk", "ko", "ml", "mr",
+ "nl", "pt", "ru", "sw", "ta", "te", "th", "tl", "tr", "ur", "vi", "zh"
+ ]
+
+ def __init__(self,
+ process_text_fn=tokenization.convert_to_unicode):
+ super(TatoebaProcessor, self).__init__(process_text_fn)
+ self.languages = TatoebaProcessor.supported_languages
+
+ def get_test_examples(self, data_dir, file_path):
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, file_path)), "test")
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "TATOEBA"
+
+ def _create_examples(self, lines, set_type):
+    """Creates examples for the test set."""
+ examples = []
+ for (i, line) in enumerate(lines):
+ guid = "%s-%s" % (set_type, i)
+ text_a = self.process_text_fn(line[0])
+ examples.append(
+ classifier_data_lib.InputExample(
+ guid=guid, text_a=text_a, int_iden=i))
+ return examples
+
+
+def generate_sentence_retrevial_tf_record(processor,
+ data_dir,
+ tokenizer,
+ eval_data_output_path=None,
+ test_data_output_path=None,
+ max_seq_length=128):
+ """Generates the tf records for retrieval tasks.
+
+ Args:
+ processor: Input processor object to be used for generating data. Subclass
+ of `DataProcessor`.
+    data_dir: Directory that contains the train/eval data to process. Data
+      files are expected to follow the processor's file-naming pattern.
+ tokenizer: The tokenizer to be applied on the data.
+ eval_data_output_path: Output to which processed tf record for evaluation
+ will be saved.
+ test_data_output_path: Output to which processed tf record for testing
+ will be saved. Must be a pattern template with {} if processor has
+ language specific test data.
+    max_seq_length: Maximum sequence length of the training/eval data to be
+      generated.
+
+ Returns:
+ A dictionary containing input meta data.
+ """
+ assert eval_data_output_path or test_data_output_path
+
+ if processor.get_processor_name() == "BUCC":
+ path_pattern = "{}-en.{{}}.{}"
+
+ if processor.get_processor_name() == "TATOEBA":
+ path_pattern = "{}-en.{}"
+
+ meta_data = {
+ "processor_type": processor.get_processor_name(),
+ "max_seq_length": max_seq_length,
+ "number_eval_data": {},
+ "number_test_data": {},
+ }
+ logging.info("Start to process %s task data", processor.get_processor_name())
+
+ for lang_a in processor.languages:
+ for lang_b in [lang_a, "en"]:
+ if eval_data_output_path:
+ eval_input_data_examples = processor.get_dev_examples(
+ data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))
+
+ num_eval_data = len(eval_input_data_examples)
+ logging.info("Processing %d dev examples of %s-en.%s", num_eval_data,
+ lang_a, lang_b)
+ output_file = os.path.join(
+ eval_data_output_path,
+ "{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "dev"))
+ classifier_data_lib.file_based_convert_examples_to_features(
+ eval_input_data_examples, None, max_seq_length, tokenizer,
+ output_file, None)
+ meta_data["number_eval_data"][f"{lang_a}-en.{lang_b}"] = num_eval_data
+
+ if test_data_output_path:
+ test_input_data_examples = processor.get_test_examples(
+ data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))
+
+ num_test_data = len(test_input_data_examples)
+ logging.info("Processing %d test examples of %s-en.%s", num_test_data,
+ lang_a, lang_b)
+ output_file = os.path.join(
+ test_data_output_path,
+ "{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "test"))
+ classifier_data_lib.file_based_convert_examples_to_features(
+ test_input_data_examples, None, max_seq_length, tokenizer,
+ output_file, None)
+ meta_data["number_test_data"][f"{lang_a}-en.{lang_b}"] = num_test_data
+
+ return meta_data
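+
+# A minimal usage sketch; the paths and tokenizer settings below are
+# placeholders, not values from this repository.
+#
+#   tokenizer = tokenization.FullTokenizer(
+#       vocab_file="/path/to/vocab.txt", do_lower_case=True)
+#   meta_data = generate_sentence_retrevial_tf_record(
+#       processor=TatoebaProcessor(),
+#       data_dir="/path/to/tatoeba",
+#       tokenizer=tokenizer,
+#       test_data_output_path="/path/to/output",
+#       max_seq_length=128)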
diff --git a/models/official/nlp/data/squad_lib.py b/models/official/nlp/data/squad_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbf4c604123c541e7830ffa7176b182a843eef58
--- /dev/null
+++ b/models/official/nlp/data/squad_lib.py
@@ -0,0 +1,898 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Library to process data for SQuAD 1.1 and SQuAD 2.0."""
+
+# pylint: disable=g-bad-import-order
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import copy
+import json
+import math
+import os
+import six
+
+from absl import logging
+import tensorflow as tf
+
+from official.nlp.bert import tokenization
+
+
+class SquadExample(object):
+  """A single training/test example for the SQuAD question answering task.
+
+ For examples without an answer, the start and end position are -1.
+
+ Attributes:
+ qas_id: ID of the question-answer pair.
+ question_text: Original text for the question.
+ doc_tokens: The list of tokens in the context obtained by splitting
+ on whitespace only.
+ orig_answer_text: Original text for the answer.
+ start_position: Starting index of the answer in `doc_tokens`.
+ end_position: Ending index of the answer in `doc_tokens`.
+ is_impossible: Whether the question is impossible to answer given the
+ context. Only used in SQuAD 2.0.
+ """
+
+ def __init__(self,
+ qas_id,
+ question_text,
+ doc_tokens,
+ orig_answer_text=None,
+ start_position=None,
+ end_position=None,
+ is_impossible=False):
+ self.qas_id = qas_id
+ self.question_text = question_text
+ self.doc_tokens = doc_tokens
+ self.orig_answer_text = orig_answer_text
+ self.start_position = start_position
+ self.end_position = end_position
+ self.is_impossible = is_impossible
+
+ def __str__(self):
+ return self.__repr__()
+
+ def __repr__(self):
+ s = ""
+ s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
+ s += ", question_text: %s" % (
+ tokenization.printable_text(self.question_text))
+ s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
+ if self.start_position:
+ s += ", start_position: %d" % (self.start_position)
+    if self.end_position:
+      s += ", end_position: %d" % (self.end_position)
+    if self.is_impossible:
+      s += ", is_impossible: %r" % (self.is_impossible)
+ return s
+
+
+class InputFeatures(object):
+ """A single set of features of data."""
+
+ def __init__(self,
+ unique_id,
+ example_index,
+ doc_span_index,
+ tokens,
+ token_to_orig_map,
+ token_is_max_context,
+ input_ids,
+ input_mask,
+ segment_ids,
+ start_position=None,
+ end_position=None,
+ is_impossible=None):
+ self.unique_id = unique_id
+ self.example_index = example_index
+ self.doc_span_index = doc_span_index
+ self.tokens = tokens
+ self.token_to_orig_map = token_to_orig_map
+ self.token_is_max_context = token_is_max_context
+ self.input_ids = input_ids
+ self.input_mask = input_mask
+ self.segment_ids = segment_ids
+ self.start_position = start_position
+ self.end_position = end_position
+ self.is_impossible = is_impossible
+
+
+class FeatureWriter(object):
+ """Writes InputFeature to TF example file."""
+
+ def __init__(self, filename, is_training):
+ self.filename = filename
+ self.is_training = is_training
+ self.num_features = 0
+ tf.io.gfile.makedirs(os.path.dirname(filename))
+ self._writer = tf.io.TFRecordWriter(filename)
+
+ def process_feature(self, feature):
+    """Writes an InputFeature to the TFRecordWriter as a tf.train.Example."""
+ self.num_features += 1
+
+ def create_int_feature(values):
+ feature = tf.train.Feature(
+ int64_list=tf.train.Int64List(value=list(values)))
+ return feature
+
+ features = collections.OrderedDict()
+ features["unique_ids"] = create_int_feature([feature.unique_id])
+ features["input_ids"] = create_int_feature(feature.input_ids)
+ features["input_mask"] = create_int_feature(feature.input_mask)
+ features["segment_ids"] = create_int_feature(feature.segment_ids)
+
+ if self.is_training:
+ features["start_positions"] = create_int_feature([feature.start_position])
+ features["end_positions"] = create_int_feature([feature.end_position])
+ impossible = 0
+ if feature.is_impossible:
+ impossible = 1
+ features["is_impossible"] = create_int_feature([impossible])
+
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+ self._writer.write(tf_example.SerializeToString())
+
+ def close(self):
+ self._writer.close()
+
+
+def read_squad_examples(input_file, is_training, version_2_with_negative):
+ """Read a SQuAD json file into a list of SquadExample."""
+ with tf.io.gfile.GFile(input_file, "r") as reader:
+ input_data = json.load(reader)["data"]
+
+ def is_whitespace(c):
+ if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
+ return True
+ return False
+
+ examples = []
+ for entry in input_data:
+ for paragraph in entry["paragraphs"]:
+ paragraph_text = paragraph["context"]
+ doc_tokens = []
+ char_to_word_offset = []
+ prev_is_whitespace = True
+ for c in paragraph_text:
+ if is_whitespace(c):
+ prev_is_whitespace = True
+ else:
+ if prev_is_whitespace:
+ doc_tokens.append(c)
+ else:
+ doc_tokens[-1] += c
+ prev_is_whitespace = False
+ char_to_word_offset.append(len(doc_tokens) - 1)
+
+ for qa in paragraph["qas"]:
+ qas_id = qa["id"]
+ question_text = qa["question"]
+ start_position = None
+ end_position = None
+ orig_answer_text = None
+ is_impossible = False
+ if is_training:
+
+ if version_2_with_negative:
+ is_impossible = qa["is_impossible"]
+ if (len(qa["answers"]) != 1) and (not is_impossible):
+ raise ValueError(
+ "For training, each question should have exactly 1 answer.")
+ if not is_impossible:
+ answer = qa["answers"][0]
+ orig_answer_text = answer["text"]
+ answer_offset = answer["answer_start"]
+ answer_length = len(orig_answer_text)
+ start_position = char_to_word_offset[answer_offset]
+ end_position = char_to_word_offset[answer_offset + answer_length -
+ 1]
+ # Only add answers where the text can be exactly recovered from the
+ # document. If this CAN'T happen it's likely due to weird Unicode
+ # stuff so we will just skip the example.
+ #
+ # Note that this means for training mode, every example is NOT
+ # guaranteed to be preserved.
+ actual_text = " ".join(
+ doc_tokens[start_position:(end_position + 1)])
+ cleaned_answer_text = " ".join(
+ tokenization.whitespace_tokenize(orig_answer_text))
+ if actual_text.find(cleaned_answer_text) == -1:
+ logging.warning("Could not find answer: '%s' vs. '%s'",
+ actual_text, cleaned_answer_text)
+ continue
+ else:
+ start_position = -1
+ end_position = -1
+ orig_answer_text = ""
+
+ example = SquadExample(
+ qas_id=qas_id,
+ question_text=question_text,
+ doc_tokens=doc_tokens,
+ orig_answer_text=orig_answer_text,
+ start_position=start_position,
+ end_position=end_position,
+ is_impossible=is_impossible)
+ examples.append(example)
+
+ return examples
+
+
+def convert_examples_to_features(examples,
+ tokenizer,
+ max_seq_length,
+ doc_stride,
+ max_query_length,
+ is_training,
+ output_fn,
+ batch_size=None):
+ """Loads a data file into a list of `InputBatch`s."""
+
+ base_id = 1000000000
+ unique_id = base_id
+ feature = None
+ for (example_index, example) in enumerate(examples):
+ query_tokens = tokenizer.tokenize(example.question_text)
+
+ if len(query_tokens) > max_query_length:
+ query_tokens = query_tokens[0:max_query_length]
+
+ tok_to_orig_index = []
+ orig_to_tok_index = []
+ all_doc_tokens = []
+ for (i, token) in enumerate(example.doc_tokens):
+ orig_to_tok_index.append(len(all_doc_tokens))
+ sub_tokens = tokenizer.tokenize(token)
+ for sub_token in sub_tokens:
+ tok_to_orig_index.append(i)
+ all_doc_tokens.append(sub_token)
+
+ tok_start_position = None
+ tok_end_position = None
+ if is_training and example.is_impossible:
+ tok_start_position = -1
+ tok_end_position = -1
+ if is_training and not example.is_impossible:
+ tok_start_position = orig_to_tok_index[example.start_position]
+ if example.end_position < len(example.doc_tokens) - 1:
+ tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
+ else:
+ tok_end_position = len(all_doc_tokens) - 1
+ (tok_start_position, tok_end_position) = _improve_answer_span(
+ all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
+ example.orig_answer_text)
+
+ # The -3 accounts for [CLS], [SEP] and [SEP]
+ max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
+
+ # We can have documents that are longer than the maximum sequence length.
+ # To deal with this we do a sliding window approach, where we take chunks
+    # of up to our max length with a stride of `doc_stride`.
+ _DocSpan = collections.namedtuple( # pylint: disable=invalid-name
+ "DocSpan", ["start", "length"])
+ doc_spans = []
+ start_offset = 0
+ while start_offset < len(all_doc_tokens):
+ length = len(all_doc_tokens) - start_offset
+ if length > max_tokens_for_doc:
+ length = max_tokens_for_doc
+ doc_spans.append(_DocSpan(start=start_offset, length=length))
+ if start_offset + length == len(all_doc_tokens):
+ break
+ start_offset += min(length, doc_stride)
+
+ for (doc_span_index, doc_span) in enumerate(doc_spans):
+ tokens = []
+ token_to_orig_map = {}
+ token_is_max_context = {}
+ segment_ids = []
+ tokens.append("[CLS]")
+ segment_ids.append(0)
+ for token in query_tokens:
+ tokens.append(token)
+ segment_ids.append(0)
+ tokens.append("[SEP]")
+ segment_ids.append(0)
+
+ for i in range(doc_span.length):
+ split_token_index = doc_span.start + i
+ token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
+
+ is_max_context = _check_is_max_context(doc_spans, doc_span_index,
+ split_token_index)
+ token_is_max_context[len(tokens)] = is_max_context
+ tokens.append(all_doc_tokens[split_token_index])
+ segment_ids.append(1)
+ tokens.append("[SEP]")
+ segment_ids.append(1)
+
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
+ # tokens are attended to.
+ input_mask = [1] * len(input_ids)
+
+ # Zero-pad up to the sequence length.
+ while len(input_ids) < max_seq_length:
+ input_ids.append(0)
+ input_mask.append(0)
+ segment_ids.append(0)
+
+ assert len(input_ids) == max_seq_length
+ assert len(input_mask) == max_seq_length
+ assert len(segment_ids) == max_seq_length
+
+ start_position = None
+ end_position = None
+ if is_training and not example.is_impossible:
+ # For training, if our document chunk does not contain an annotation
+ # we throw it out, since there is nothing to predict.
+ doc_start = doc_span.start
+ doc_end = doc_span.start + doc_span.length - 1
+ out_of_span = False
+ if not (tok_start_position >= doc_start and
+ tok_end_position <= doc_end):
+ out_of_span = True
+ if out_of_span:
+ start_position = 0
+ end_position = 0
+ else:
+ doc_offset = len(query_tokens) + 2
+ start_position = tok_start_position - doc_start + doc_offset
+ end_position = tok_end_position - doc_start + doc_offset
+
+ if is_training and example.is_impossible:
+ start_position = 0
+ end_position = 0
+
+ if example_index < 20:
+ logging.info("*** Example ***")
+ logging.info("unique_id: %s", (unique_id))
+ logging.info("example_index: %s", (example_index))
+ logging.info("doc_span_index: %s", (doc_span_index))
+ logging.info("tokens: %s",
+ " ".join([tokenization.printable_text(x) for x in tokens]))
+ logging.info(
+ "token_to_orig_map: %s", " ".join([
+ "%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)
+ ]))
+ logging.info(
+ "token_is_max_context: %s", " ".join([
+ "%d:%s" % (x, y)
+ for (x, y) in six.iteritems(token_is_max_context)
+ ]))
+ logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
+ logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
+ logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
+ if is_training and example.is_impossible:
+ logging.info("impossible example")
+ if is_training and not example.is_impossible:
+ answer_text = " ".join(tokens[start_position:(end_position + 1)])
+ logging.info("start_position: %d", (start_position))
+ logging.info("end_position: %d", (end_position))
+ logging.info("answer: %s", tokenization.printable_text(answer_text))
+
+ feature = InputFeatures(
+ unique_id=unique_id,
+ example_index=example_index,
+ doc_span_index=doc_span_index,
+ tokens=tokens,
+ token_to_orig_map=token_to_orig_map,
+ token_is_max_context=token_is_max_context,
+ input_ids=input_ids,
+ input_mask=input_mask,
+ segment_ids=segment_ids,
+ start_position=start_position,
+ end_position=end_position,
+ is_impossible=example.is_impossible)
+
+ # Run callback
+ if is_training:
+ output_fn(feature)
+ else:
+ output_fn(feature, is_padding=False)
+
+ unique_id += 1
+
+ if not is_training and feature:
+ assert batch_size
+ num_padding = 0
+ num_examples = unique_id - base_id
+ if unique_id % batch_size != 0:
+ num_padding = batch_size - (num_examples % batch_size)
+ logging.info("Adding padding examples to make sure no partial batch.")
+ logging.info("Adds %d padding examples for inference.", num_padding)
+ dummy_feature = copy.deepcopy(feature)
+ for _ in range(num_padding):
+ dummy_feature.unique_id = unique_id
+
+ # Run callback
+      output_fn(dummy_feature, is_padding=True)
+ unique_id += 1
+ return unique_id - base_id
+
+
+def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
+ orig_answer_text):
+ """Returns tokenized answer spans that better match the annotated answer."""
+
+ # The SQuAD annotations are character based. We first project them to
+ # whitespace-tokenized words. But then after WordPiece tokenization, we can
+ # often find a "better match". For example:
+ #
+ # Question: What year was John Smith born?
+ # Context: The leader was John Smith (1895-1943).
+ # Answer: 1895
+ #
+ # The original whitespace-tokenized answer will be "(1895-1943).". However
+ # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
+ # the exact answer, 1895.
+ #
+ # However, this is not always possible. Consider the following:
+ #
+  #   Question: What country is the top exporter of electronics?
+  #   Context: The Japanese electronics industry is the largest in the world.
+ # Answer: Japan
+ #
+ # In this case, the annotator chose "Japan" as a character sub-span of
+ # the word "Japanese". Since our WordPiece tokenizer does not split
+ # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
+ # in SQuAD, but does happen.
+ tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
+
+ for new_start in range(input_start, input_end + 1):
+ for new_end in range(input_end, new_start - 1, -1):
+ text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
+ if text_span == tok_answer_text:
+ return (new_start, new_end)
+
+ return (input_start, input_end)
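+
+# Illustration of the span improvement above: if the WordPiece tokenizer splits
+# the answer-bearing span into ["(", "1895", "-", "1943", ")", "."] and
+# tokenizes "1895" to ["1895"], then
+#   _improve_answer_span(doc_tokens, 0, 5, tokenizer, "1895")
+# narrows the span to (1, 1), i.e. exactly the "1895" sub-token.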
+
+
+def _check_is_max_context(doc_spans, cur_span_index, position):
+ """Check if this is the 'max context' doc span for the token."""
+
+ # Because of the sliding window approach taken to scoring documents, a single
+ # token can appear in multiple documents. E.g.
+ # Doc: the man went to the store and bought a gallon of milk
+ # Span A: the man went to the
+ # Span B: to the store and bought
+ # Span C: and bought a gallon of
+ # ...
+ #
+ # Now the word 'bought' will have two scores from spans B and C. We only
+ # want to consider the score with "maximum context", which we define as
+ # the *minimum* of its left and right context (the *sum* of left and
+ # right context will always be the same, of course).
+ #
+ # In the example the maximum context for 'bought' would be span C since
+ # it has 1 left context and 3 right context, while span B has 4 left context
+ # and 0 right context.
+ best_score = None
+ best_span_index = None
+ for (span_index, doc_span) in enumerate(doc_spans):
+ end = doc_span.start + doc_span.length - 1
+ if position < doc_span.start:
+ continue
+ if position > end:
+ continue
+ num_left_context = position - doc_span.start
+ num_right_context = end - position
+ score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
+ if best_score is None or score > best_score:
+ best_score = score
+ best_span_index = span_index
+
+ return cur_span_index == best_span_index
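+
+# Worked example for the spans above, with 0-based token positions and
+# 'bought' at position 7 of the 12-token document:
+#   Span B = DocSpan(start=3, length=5): min(4, 0) + 0.01 * 5 = 0.05
+#   Span C = DocSpan(start=6, length=5): min(1, 3) + 0.01 * 5 = 1.05
+# so only the feature built from span C treats 'bought' as max context.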
+
+
+def write_predictions(all_examples,
+ all_features,
+ all_results,
+ n_best_size,
+ max_answer_length,
+ do_lower_case,
+ output_prediction_file,
+ output_nbest_file,
+ output_null_log_odds_file,
+ version_2_with_negative=False,
+ null_score_diff_threshold=0.0,
+ verbose=False):
+ """Write final predictions to the json file and log-odds of null if needed."""
+ logging.info("Writing predictions to: %s", (output_prediction_file))
+ logging.info("Writing nbest to: %s", (output_nbest_file))
+
+ all_predictions, all_nbest_json, scores_diff_json = (
+ postprocess_output(all_examples=all_examples,
+ all_features=all_features,
+ all_results=all_results,
+ n_best_size=n_best_size,
+ max_answer_length=max_answer_length,
+ do_lower_case=do_lower_case,
+ version_2_with_negative=version_2_with_negative,
+ null_score_diff_threshold=null_score_diff_threshold,
+ verbose=verbose))
+
+ write_to_json_files(all_predictions, output_prediction_file)
+ write_to_json_files(all_nbest_json, output_nbest_file)
+ if version_2_with_negative:
+ write_to_json_files(scores_diff_json, output_null_log_odds_file)
+
+
+def postprocess_output(all_examples,
+ all_features,
+ all_results,
+ n_best_size,
+ max_answer_length,
+ do_lower_case,
+ version_2_with_negative=False,
+ null_score_diff_threshold=0.0,
+ verbose=False):
+  """Postprocesses model output to form prediction results."""
+
+ example_index_to_features = collections.defaultdict(list)
+ for feature in all_features:
+ example_index_to_features[feature.example_index].append(feature)
+ unique_id_to_result = {}
+ for result in all_results:
+ unique_id_to_result[result.unique_id] = result
+
+ _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
+ "PrelimPrediction",
+ ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
+
+ all_predictions = collections.OrderedDict()
+ all_nbest_json = collections.OrderedDict()
+ scores_diff_json = collections.OrderedDict()
+
+ for (example_index, example) in enumerate(all_examples):
+ features = example_index_to_features[example_index]
+
+ prelim_predictions = []
+ # keep track of the minimum score of null start+end of position 0
+ score_null = 1000000 # large and positive
+    min_null_feature_index = 0  # the paragraph slice with min null score
+ null_start_logit = 0 # the start logit at the slice with min null score
+ null_end_logit = 0 # the end logit at the slice with min null score
+ for (feature_index, feature) in enumerate(features):
+ result = unique_id_to_result[feature.unique_id]
+ start_indexes = _get_best_indexes(result.start_logits, n_best_size)
+ end_indexes = _get_best_indexes(result.end_logits, n_best_size)
+ # if we could have irrelevant answers, get the min score of irrelevant
+ if version_2_with_negative:
+ feature_null_score = result.start_logits[0] + result.end_logits[0]
+ if feature_null_score < score_null:
+ score_null = feature_null_score
+ min_null_feature_index = feature_index
+ null_start_logit = result.start_logits[0]
+ null_end_logit = result.end_logits[0]
+ for start_index in start_indexes:
+ for end_index in end_indexes:
+ # We could hypothetically create invalid predictions, e.g., predict
+ # that the start of the span is in the question. We throw out all
+ # invalid predictions.
+ if start_index >= len(feature.tokens):
+ continue
+ if end_index >= len(feature.tokens):
+ continue
+ if start_index not in feature.token_to_orig_map:
+ continue
+ if end_index not in feature.token_to_orig_map:
+ continue
+ if not feature.token_is_max_context.get(start_index, False):
+ continue
+ if end_index < start_index:
+ continue
+ length = end_index - start_index + 1
+ if length > max_answer_length:
+ continue
+ prelim_predictions.append(
+ _PrelimPrediction(
+ feature_index=feature_index,
+ start_index=start_index,
+ end_index=end_index,
+ start_logit=result.start_logits[start_index],
+ end_logit=result.end_logits[end_index]))
+
+ if version_2_with_negative:
+ prelim_predictions.append(
+ _PrelimPrediction(
+ feature_index=min_null_feature_index,
+ start_index=0,
+ end_index=0,
+ start_logit=null_start_logit,
+ end_logit=null_end_logit))
+ prelim_predictions = sorted(
+ prelim_predictions,
+ key=lambda x: (x.start_logit + x.end_logit),
+ reverse=True)
+
+ _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
+ "NbestPrediction", ["text", "start_logit", "end_logit"])
+
+ seen_predictions = {}
+ nbest = []
+ for pred in prelim_predictions:
+ if len(nbest) >= n_best_size:
+ break
+ feature = features[pred.feature_index]
+ if pred.start_index > 0: # this is a non-null prediction
+ tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
+ orig_doc_start = feature.token_to_orig_map[pred.start_index]
+ orig_doc_end = feature.token_to_orig_map[pred.end_index]
+ orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
+ tok_text = " ".join(tok_tokens)
+
+ # De-tokenize WordPieces that have been split off.
+ tok_text = tok_text.replace(" ##", "")
+ tok_text = tok_text.replace("##", "")
+
+ # Clean whitespace
+ tok_text = tok_text.strip()
+ tok_text = " ".join(tok_text.split())
+ orig_text = " ".join(orig_tokens)
+
+ final_text = get_final_text(
+ tok_text, orig_text, do_lower_case, verbose=verbose)
+ if final_text in seen_predictions:
+ continue
+
+ seen_predictions[final_text] = True
+ else:
+ final_text = ""
+ seen_predictions[final_text] = True
+
+ nbest.append(
+ _NbestPrediction(
+ text=final_text,
+ start_logit=pred.start_logit,
+ end_logit=pred.end_logit))
+
+    # if we didn't include the empty option in the n-best, include it
+ if version_2_with_negative:
+ if "" not in seen_predictions:
+ nbest.append(
+ _NbestPrediction(
+ text="", start_logit=null_start_logit,
+ end_logit=null_end_logit))
+ # In very rare edge cases we could have no valid predictions. So we
+ # just create a nonce prediction in this case to avoid failure.
+ if not nbest:
+ nbest.append(
+ _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
+
+ assert len(nbest) >= 1
+
+ total_scores = []
+ best_non_null_entry = None
+ for entry in nbest:
+ total_scores.append(entry.start_logit + entry.end_logit)
+ if not best_non_null_entry:
+ if entry.text:
+ best_non_null_entry = entry
+
+ probs = _compute_softmax(total_scores)
+
+ nbest_json = []
+ for (i, entry) in enumerate(nbest):
+ output = collections.OrderedDict()
+ output["text"] = entry.text
+ output["probability"] = probs[i]
+ output["start_logit"] = entry.start_logit
+ output["end_logit"] = entry.end_logit
+ nbest_json.append(output)
+
+ assert len(nbest_json) >= 1
+
+ if not version_2_with_negative:
+ all_predictions[example.qas_id] = nbest_json[0]["text"]
+ else:
+ # pytype: disable=attribute-error
+ # predict "" iff the null score - the score of best non-null > threshold
+ if best_non_null_entry is not None:
+ score_diff = score_null - best_non_null_entry.start_logit - (
+ best_non_null_entry.end_logit)
+ scores_diff_json[example.qas_id] = score_diff
+ if score_diff > null_score_diff_threshold:
+ all_predictions[example.qas_id] = ""
+ else:
+ all_predictions[example.qas_id] = best_non_null_entry.text
+ else:
+ logging.warning("best_non_null_entry is None")
+ scores_diff_json[example.qas_id] = score_null
+ all_predictions[example.qas_id] = ""
+ # pytype: enable=attribute-error
+
+ all_nbest_json[example.qas_id] = nbest_json
+
+ return all_predictions, all_nbest_json, scores_diff_json
+
+
+def write_to_json_files(json_records, json_file):
+ with tf.io.gfile.GFile(json_file, "w") as writer:
+ writer.write(json.dumps(json_records, indent=4) + "\n")
+
+
+def get_final_text(pred_text, orig_text, do_lower_case, verbose=False):
+ """Project the tokenized prediction back to the original text."""
+
+ # When we created the data, we kept track of the alignment between original
+ # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
+ # now `orig_text` contains the span of our original text corresponding to the
+ # span that we predicted.
+ #
+ # However, `orig_text` may contain extra characters that we don't want in
+ # our prediction.
+ #
+ # For example, let's say:
+ # pred_text = steve smith
+ # orig_text = Steve Smith's
+ #
+ # We don't want to return `orig_text` because it contains the extra "'s".
+ #
+ # We don't want to return `pred_text` because it's already been normalized
+ # (the SQuAD eval script also does punctuation stripping/lower casing but
+ # our tokenizer does additional normalization like stripping accent
+ # characters).
+ #
+ # What we really want to return is "Steve Smith".
+ #
+  # Therefore, we have to apply a semi-complicated alignment heuristic between
+  # `pred_text` and `orig_text` to get a character-to-character alignment. This
+ # can fail in certain cases in which case we just return `orig_text`.
+
+ def _strip_spaces(text):
+ ns_chars = []
+ ns_to_s_map = collections.OrderedDict()
+ for (i, c) in enumerate(text):
+ if c == " ":
+ continue
+ ns_to_s_map[len(ns_chars)] = i
+ ns_chars.append(c)
+ ns_text = "".join(ns_chars)
+ return (ns_text, ns_to_s_map)
+
+ # We first tokenize `orig_text`, strip whitespace from the result
+ # and `pred_text`, and check if they are the same length. If they are
+ # NOT the same length, the heuristic has failed. If they are the same
+ # length, we assume the characters are one-to-one aligned.
+ tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
+
+ tok_text = " ".join(tokenizer.tokenize(orig_text))
+
+ start_position = tok_text.find(pred_text)
+ if start_position == -1:
+ if verbose:
+ logging.info("Unable to find text: '%s' in '%s'", pred_text, orig_text)
+ return orig_text
+ end_position = start_position + len(pred_text) - 1
+
+ (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
+ (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
+
+ if len(orig_ns_text) != len(tok_ns_text):
+ if verbose:
+ logging.info("Length not equal after stripping spaces: '%s' vs '%s'",
+ orig_ns_text, tok_ns_text)
+ return orig_text
+
+ # We then project the characters in `pred_text` back to `orig_text` using
+ # the character-to-character alignment.
+ tok_s_to_ns_map = {}
+ for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
+ tok_s_to_ns_map[tok_index] = i
+
+ orig_start_position = None
+ if start_position in tok_s_to_ns_map:
+ ns_start_position = tok_s_to_ns_map[start_position]
+ if ns_start_position in orig_ns_to_s_map:
+ orig_start_position = orig_ns_to_s_map[ns_start_position]
+
+ if orig_start_position is None:
+ if verbose:
+ logging.info("Couldn't map start position")
+ return orig_text
+
+ orig_end_position = None
+ if end_position in tok_s_to_ns_map:
+ ns_end_position = tok_s_to_ns_map[end_position]
+ if ns_end_position in orig_ns_to_s_map:
+ orig_end_position = orig_ns_to_s_map[ns_end_position]
+
+ if orig_end_position is None:
+ if verbose:
+ logging.info("Couldn't map end position")
+ return orig_text
+
+ output_text = orig_text[orig_start_position:(orig_end_position + 1)]
+ return output_text
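+
+# Concrete instance of the example above: with do_lower_case=True,
+#   get_final_text("steve smith", "Steve Smith's", do_lower_case=True)
+# aligns the stripped strings one-to-one and returns "Steve Smith".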
+
+
+def _get_best_indexes(logits, n_best_size):
+ """Get the n-best logits from a list."""
+ index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
+
+ best_indexes = []
+ for i in range(len(index_and_score)): # pylint: disable=consider-using-enumerate
+ if i >= n_best_size:
+ break
+ best_indexes.append(index_and_score[i][0])
+ return best_indexes
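+
+# Example: _get_best_indexes([0.1, 2.3, 1.5], n_best_size=2) returns [1, 2],
+# the positions of the two largest logits, best first.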
+
+
+def _compute_softmax(scores):
+ """Compute softmax probability over raw logits."""
+ if not scores:
+ return []
+
+ max_score = None
+ for score in scores:
+ if max_score is None or score > max_score:
+ max_score = score
+
+ exp_scores = []
+ total_sum = 0.0
+ for score in scores:
+ x = math.exp(score - max_score)
+ exp_scores.append(x)
+ total_sum += x
+
+ probs = []
+ for score in exp_scores:
+ probs.append(score / total_sum)
+ return probs
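+
+# Example: _compute_softmax([1.0, 2.0]) is approximately [0.269, 0.731]; the
+# max score is subtracted before exponentiation for numerical stability.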
+
+
+def generate_tf_record_from_json_file(input_file_path,
+ vocab_file_path,
+ output_path,
+ max_seq_length=384,
+ do_lower_case=True,
+ max_query_length=64,
+ doc_stride=128,
+ version_2_with_negative=False):
+ """Generates and saves training data into a tf record file."""
+ train_examples = read_squad_examples(
+ input_file=input_file_path,
+ is_training=True,
+ version_2_with_negative=version_2_with_negative)
+ tokenizer = tokenization.FullTokenizer(
+ vocab_file=vocab_file_path, do_lower_case=do_lower_case)
+ train_writer = FeatureWriter(filename=output_path, is_training=True)
+ number_of_examples = convert_examples_to_features(
+ examples=train_examples,
+ tokenizer=tokenizer,
+ max_seq_length=max_seq_length,
+ doc_stride=doc_stride,
+ max_query_length=max_query_length,
+ is_training=True,
+ output_fn=train_writer.process_feature)
+ train_writer.close()
+
+ meta_data = {
+ "task_type": "bert_squad",
+ "train_data_size": number_of_examples,
+ "max_seq_length": max_seq_length,
+ "max_query_length": max_query_length,
+ "doc_stride": doc_stride,
+ "version_2_with_negative": version_2_with_negative,
+ }
+
+ return meta_data
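+
+# A minimal usage sketch; the file paths below are placeholders.
+#
+#   meta_data = generate_tf_record_from_json_file(
+#       input_file_path="/path/to/train-v1.1.json",
+#       vocab_file_path="/path/to/vocab.txt",
+#       output_path="/path/to/train.tf_record",
+#       max_seq_length=384)
+#   # meta_data["train_data_size"] is the number of serialized features.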
diff --git a/models/official/nlp/data/squad_lib_sp.py b/models/official/nlp/data/squad_lib_sp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c65f713fd09bc4858f77f8ce823b17467606271c
--- /dev/null
+++ b/models/official/nlp/data/squad_lib_sp.py
@@ -0,0 +1,892 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Run ALBERT on SQuAD 1.1 and SQuAD 2.0 using sentence piece tokenization.
+
+The file is forked from:
+
+https://github.com/google-research/ALBERT/blob/master/run_squad_sp.py
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import copy
+import json
+import math
+import os
+from absl import logging
+import numpy as np
+import tensorflow as tf
+
+from official.nlp.bert import tokenization
+
+
+class SquadExample(object):
+  """A single training/test example for the SQuAD question answering task.
+
+ For examples without an answer, the start and end position are -1.
+ """
+
+ def __init__(self,
+ qas_id,
+ question_text,
+ paragraph_text,
+ orig_answer_text=None,
+ start_position=None,
+ end_position=None,
+ is_impossible=False):
+ self.qas_id = qas_id
+ self.question_text = question_text
+ self.paragraph_text = paragraph_text
+ self.orig_answer_text = orig_answer_text
+ self.start_position = start_position
+ self.end_position = end_position
+ self.is_impossible = is_impossible
+
+ def __str__(self):
+ return self.__repr__()
+
+ def __repr__(self):
+ s = ""
+ s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
+ s += ", question_text: %s" % (
+ tokenization.printable_text(self.question_text))
+ s += ", paragraph_text: [%s]" % (" ".join(self.paragraph_text))
+ if self.start_position:
+ s += ", start_position: %d" % (self.start_position)
+    if self.end_position:
+      s += ", end_position: %d" % (self.end_position)
+    if self.is_impossible:
+      s += ", is_impossible: %r" % (self.is_impossible)
+ return s
+
+
+class InputFeatures(object):
+ """A single set of features of data."""
+
+ def __init__(self,
+ unique_id,
+ example_index,
+ doc_span_index,
+ tok_start_to_orig_index,
+ tok_end_to_orig_index,
+ token_is_max_context,
+ tokens,
+ input_ids,
+ input_mask,
+ segment_ids,
+ paragraph_len,
+ start_position=None,
+ end_position=None,
+ is_impossible=None):
+ self.unique_id = unique_id
+ self.example_index = example_index
+ self.doc_span_index = doc_span_index
+ self.tok_start_to_orig_index = tok_start_to_orig_index
+ self.tok_end_to_orig_index = tok_end_to_orig_index
+ self.token_is_max_context = token_is_max_context
+ self.tokens = tokens
+ self.input_ids = input_ids
+ self.input_mask = input_mask
+ self.segment_ids = segment_ids
+ self.paragraph_len = paragraph_len
+ self.start_position = start_position
+ self.end_position = end_position
+ self.is_impossible = is_impossible
+
+
+def read_squad_examples(input_file, is_training, version_2_with_negative):
+ """Read a SQuAD json file into a list of SquadExample."""
+ del version_2_with_negative
+ with tf.io.gfile.GFile(input_file, "r") as reader:
+ input_data = json.load(reader)["data"]
+
+ examples = []
+ for entry in input_data:
+ for paragraph in entry["paragraphs"]:
+ paragraph_text = paragraph["context"]
+
+ for qa in paragraph["qas"]:
+ qas_id = qa["id"]
+ question_text = qa["question"]
+ start_position = None
+ orig_answer_text = None
+ is_impossible = False
+
+ if is_training:
+ is_impossible = qa.get("is_impossible", False)
+ if (len(qa["answers"]) != 1) and (not is_impossible):
+ raise ValueError(
+ "For training, each question should have exactly 1 answer.")
+ if not is_impossible:
+ answer = qa["answers"][0]
+ orig_answer_text = answer["text"]
+ start_position = answer["answer_start"]
+ else:
+ start_position = -1
+ orig_answer_text = ""
+
+ example = SquadExample(
+ qas_id=qas_id,
+ question_text=question_text,
+ paragraph_text=paragraph_text,
+ orig_answer_text=orig_answer_text,
+ start_position=start_position,
+ is_impossible=is_impossible)
+ examples.append(example)
+
+ return examples
+
+
+def _convert_index(index, pos, m=None, is_start=True):
+ """Converts index."""
+ if index[pos] is not None:
+ return index[pos]
+ n = len(index)
+ rear = pos
+ while rear < n - 1 and index[rear] is None:
+ rear += 1
+ front = pos
+ while front > 0 and index[front] is None:
+ front -= 1
+ assert index[front] is not None or index[rear] is not None
+ if index[front] is None:
+ if index[rear] >= 1:
+ if is_start:
+ return 0
+ else:
+ return index[rear] - 1
+ return index[rear]
+ if index[rear] is None:
+ if m is not None and index[front] < m - 1:
+ if is_start:
+ return index[front] + 1
+ else:
+ return m - 1
+ return index[front]
+ if is_start:
+ if index[rear] > index[front] + 1:
+ return index[front] + 1
+ else:
+ return index[rear]
+ else:
+ if index[rear] > index[front] + 1:
+ return index[rear] - 1
+ else:
+ return index[front]
+
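+
+# Illustration: with index = [0, None, None, 3] (positions 1 and 2 unmapped),
+# _convert_index(index, 1, is_start=True) returns 1 (just after the nearest
+# mapped position on the left), while _convert_index(index, 1, is_start=False)
+# returns 2 (just before the nearest mapped position on the right).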
+
+def convert_examples_to_features(examples,
+ tokenizer,
+ max_seq_length,
+ doc_stride,
+ max_query_length,
+ is_training,
+ output_fn,
+ do_lower_case,
+ batch_size=None):
+ """Loads a data file into a list of `InputBatch`s."""
+ cnt_pos, cnt_neg = 0, 0
+ base_id = 1000000000
+ unique_id = base_id
+ max_n, max_m = 1024, 1024
+  f = np.zeros((max_n, max_m), dtype=np.float32)
+  feature = None
+
+ for (example_index, example) in enumerate(examples):
+
+ if example_index % 100 == 0:
+ logging.info("Converting %d/%d pos %d neg %d", example_index,
+ len(examples), cnt_pos, cnt_neg)
+
+ query_tokens = tokenization.encode_ids(
+ tokenizer.sp_model,
+ tokenization.preprocess_text(
+ example.question_text, lower=do_lower_case))
+
+ if len(query_tokens) > max_query_length:
+ query_tokens = query_tokens[0:max_query_length]
+
+ paragraph_text = example.paragraph_text
+ para_tokens = tokenization.encode_pieces(
+ tokenizer.sp_model,
+ tokenization.preprocess_text(
+ example.paragraph_text, lower=do_lower_case))
+
+ chartok_to_tok_index = []
+ tok_start_to_chartok_index = []
+ tok_end_to_chartok_index = []
+ char_cnt = 0
+ for i, token in enumerate(para_tokens):
+ new_token = token.replace(tokenization.SPIECE_UNDERLINE, " ")
+ chartok_to_tok_index.extend([i] * len(new_token))
+ tok_start_to_chartok_index.append(char_cnt)
+ char_cnt += len(new_token)
+ tok_end_to_chartok_index.append(char_cnt - 1)
+
+ tok_cat_text = "".join(para_tokens).replace(tokenization.SPIECE_UNDERLINE,
+ " ")
+ n, m = len(paragraph_text), len(tok_cat_text)
+
+ if n > max_n or m > max_m:
+ max_n = max(n, max_n)
+ max_m = max(m, max_m)
+ f = np.zeros((max_n, max_m), dtype=np.float32)
+
+ g = {}
+ # pylint: disable=cell-var-from-loop
+ def _lcs_match(max_dist, n=n, m=m):
+      """Longest common subsequence (LCS) matching."""
+ f.fill(0)
+ g.clear()
+
+      ### longest common subsequence
+ # f[i, j] = max(f[i - 1, j], f[i, j - 1], f[i - 1, j - 1] + match(i, j))
+ for i in range(n):
+
+ # unlike standard LCS, this is specifically optimized for the setting
+ # because the mismatch between sentence pieces and original text will
+ # be small
+ for j in range(i - max_dist, i + max_dist):
+ if j >= m or j < 0:
+ continue
+
+ if i > 0:
+ g[(i, j)] = 0
+ f[i, j] = f[i - 1, j]
+
+ if j > 0 and f[i, j - 1] > f[i, j]:
+ g[(i, j)] = 1
+ f[i, j] = f[i, j - 1]
+
+ f_prev = f[i - 1, j - 1] if i > 0 and j > 0 else 0
+ if (tokenization.preprocess_text(
+ paragraph_text[i], lower=do_lower_case,
+ remove_space=False) == tok_cat_text[j] and f_prev + 1 > f[i, j]):
+ g[(i, j)] = 2
+ f[i, j] = f_prev + 1
+ # pylint: enable=cell-var-from-loop
+
+ max_dist = abs(n - m) + 5
+ for _ in range(2):
+ _lcs_match(max_dist)
+ if f[n - 1, m - 1] > 0.8 * n:
+ break
+ max_dist *= 2
+
+ orig_to_chartok_index = [None] * n
+ chartok_to_orig_index = [None] * m
+ i, j = n - 1, m - 1
+ while i >= 0 and j >= 0:
+ if (i, j) not in g:
+ break
+ if g[(i, j)] == 2:
+ orig_to_chartok_index[i] = j
+ chartok_to_orig_index[j] = i
+ i, j = i - 1, j - 1
+ elif g[(i, j)] == 1:
+ j = j - 1
+ else:
+ i = i - 1
+
+ if (all(v is None for v in orig_to_chartok_index) or
+ f[n - 1, m - 1] < 0.8 * n):
+ logging.info("MISMATCH DETECTED!")
+ continue
+
+ tok_start_to_orig_index = []
+ tok_end_to_orig_index = []
+ for i in range(len(para_tokens)):
+ start_chartok_pos = tok_start_to_chartok_index[i]
+ end_chartok_pos = tok_end_to_chartok_index[i]
+ start_orig_pos = _convert_index(
+ chartok_to_orig_index, start_chartok_pos, n, is_start=True)
+ end_orig_pos = _convert_index(
+ chartok_to_orig_index, end_chartok_pos, n, is_start=False)
+
+ tok_start_to_orig_index.append(start_orig_pos)
+ tok_end_to_orig_index.append(end_orig_pos)
+
+ if not is_training:
+ tok_start_position = tok_end_position = None
+
+ if is_training and example.is_impossible:
+ tok_start_position = 0
+ tok_end_position = 0
+
+ if is_training and not example.is_impossible:
+ start_position = example.start_position
+ end_position = start_position + len(example.orig_answer_text) - 1
+
+ start_chartok_pos = _convert_index(
+ orig_to_chartok_index, start_position, is_start=True)
+ tok_start_position = chartok_to_tok_index[start_chartok_pos]
+
+ end_chartok_pos = _convert_index(
+ orig_to_chartok_index, end_position, is_start=False)
+ tok_end_position = chartok_to_tok_index[end_chartok_pos]
+ assert tok_start_position <= tok_end_position
+
+ def _piece_to_id(x):
+ return tokenizer.sp_model.PieceToId(x)
+
+ all_doc_tokens = list(map(_piece_to_id, para_tokens))
+
+ # The -3 accounts for [CLS], [SEP] and [SEP]
+ max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
+
+ # We can have documents that are longer than the maximum sequence length.
+ # To deal with this we do a sliding window approach, where we take chunks
+    # of up to our max length with a stride of `doc_stride`.
+ _DocSpan = collections.namedtuple( # pylint: disable=invalid-name
+ "DocSpan", ["start", "length"])
+ doc_spans = []
+ start_offset = 0
+ while start_offset < len(all_doc_tokens):
+ length = len(all_doc_tokens) - start_offset
+ if length > max_tokens_for_doc:
+ length = max_tokens_for_doc
+ doc_spans.append(_DocSpan(start=start_offset, length=length))
+ if start_offset + length == len(all_doc_tokens):
+ break
+ start_offset += min(length, doc_stride)
+
+ for (doc_span_index, doc_span) in enumerate(doc_spans):
+ tokens = []
+ token_is_max_context = {}
+ segment_ids = []
+
+ cur_tok_start_to_orig_index = []
+ cur_tok_end_to_orig_index = []
+
+ tokens.append(tokenizer.sp_model.PieceToId("[CLS]"))
+ segment_ids.append(0)
+ for token in query_tokens:
+ tokens.append(token)
+ segment_ids.append(0)
+ tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
+ segment_ids.append(0)
+
+ for i in range(doc_span.length):
+ split_token_index = doc_span.start + i
+
+ cur_tok_start_to_orig_index.append(
+ tok_start_to_orig_index[split_token_index])
+ cur_tok_end_to_orig_index.append(
+ tok_end_to_orig_index[split_token_index])
+
+ is_max_context = _check_is_max_context(doc_spans, doc_span_index,
+ split_token_index)
+ token_is_max_context[len(tokens)] = is_max_context
+ tokens.append(all_doc_tokens[split_token_index])
+ segment_ids.append(1)
+ tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
+ segment_ids.append(1)
+
+ paragraph_len = len(tokens)
+ input_ids = tokens
+
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
+ # tokens are attended to.
+ input_mask = [1] * len(input_ids)
+
+ # Zero-pad up to the sequence length.
+ while len(input_ids) < max_seq_length:
+ input_ids.append(0)
+ input_mask.append(0)
+ segment_ids.append(0)
+
+ assert len(input_ids) == max_seq_length
+ assert len(input_mask) == max_seq_length
+ assert len(segment_ids) == max_seq_length
+
+ span_is_impossible = example.is_impossible
+ start_position = None
+ end_position = None
+ if is_training and not span_is_impossible:
+ # For training, if our document chunk does not contain an annotation
+ # we throw it out, since there is nothing to predict.
+ doc_start = doc_span.start
+ doc_end = doc_span.start + doc_span.length - 1
+ out_of_span = False
+ if not (tok_start_position >= doc_start and
+ tok_end_position <= doc_end):
+ out_of_span = True
+ if out_of_span:
+ # continue
+ start_position = 0
+ end_position = 0
+ span_is_impossible = True
+ else:
+ doc_offset = len(query_tokens) + 2
+ start_position = tok_start_position - doc_start + doc_offset
+ end_position = tok_end_position - doc_start + doc_offset
+
+ if is_training and span_is_impossible:
+ start_position = 0
+ end_position = 0
+
+ if example_index < 20:
+ logging.info("*** Example ***")
+ logging.info("unique_id: %s", (unique_id))
+ logging.info("example_index: %s", (example_index))
+ logging.info("doc_span_index: %s", (doc_span_index))
+ logging.info("tok_start_to_orig_index: %s",
+ " ".join([str(x) for x in cur_tok_start_to_orig_index]))
+ logging.info("tok_end_to_orig_index: %s",
+ " ".join([str(x) for x in cur_tok_end_to_orig_index]))
+ logging.info(
+ "token_is_max_context: %s", " ".join(
+ ["%d:%s" % (x, y) for (x, y) in token_is_max_context.items()]))
+ logging.info(
+ "input_pieces: %s",
+ " ".join([tokenizer.sp_model.IdToPiece(x) for x in tokens]))
+ logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
+ logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
+ logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
+
+ if is_training and span_is_impossible:
+ logging.info("impossible example span")
+
+ if is_training and not span_is_impossible:
+ pieces = [
+ tokenizer.sp_model.IdToPiece(token)
+ for token in tokens[start_position:(end_position + 1)]
+ ]
+ answer_text = tokenizer.sp_model.DecodePieces(pieces)
+ logging.info("start_position: %d", (start_position))
+ logging.info("end_position: %d", (end_position))
+ logging.info("answer: %s", (tokenization.printable_text(answer_text)))
+
+ # With multi processing, the example_index is actually the index
+ # within the current process therefore we use example_index=None
+ # to avoid being used in the future.
+ # The current code does not use example_index of training data.
+ if is_training:
+ feat_example_index = None
+ else:
+ feat_example_index = example_index
+
+ feature = InputFeatures(
+ unique_id=unique_id,
+ example_index=feat_example_index,
+ doc_span_index=doc_span_index,
+ tok_start_to_orig_index=cur_tok_start_to_orig_index,
+ tok_end_to_orig_index=cur_tok_end_to_orig_index,
+ token_is_max_context=token_is_max_context,
+ tokens=[tokenizer.sp_model.IdToPiece(x) for x in tokens],
+ input_ids=input_ids,
+ input_mask=input_mask,
+ segment_ids=segment_ids,
+ paragraph_len=paragraph_len,
+ start_position=start_position,
+ end_position=end_position,
+ is_impossible=span_is_impossible)
+
+ # Run callback
+ if is_training:
+ output_fn(feature)
+ else:
+ output_fn(feature, is_padding=False)
+
+ unique_id += 1
+ if span_is_impossible:
+ cnt_neg += 1
+ else:
+ cnt_pos += 1
+
+ if not is_training and feature:
+ assert batch_size
+ num_padding = 0
+ num_examples = unique_id - base_id
+ if unique_id % batch_size != 0:
+ num_padding = batch_size - (num_examples % batch_size)
+ dummy_feature = copy.deepcopy(feature)
+ for _ in range(num_padding):
+ dummy_feature.unique_id = unique_id
+
+ # Run callback
+      output_fn(dummy_feature, is_padding=True)
+ unique_id += 1
+
+ logging.info("Total number of instances: %d = pos %d neg %d",
+ cnt_pos + cnt_neg, cnt_pos, cnt_neg)
+ return unique_id - base_id
+
+
+def _check_is_max_context(doc_spans, cur_span_index, position):
+ """Check if this is the 'max context' doc span for the token."""
+
+ # Because of the sliding window approach taken to scoring documents, a single
+ # token can appear in multiple documents. E.g.
+ # Doc: the man went to the store and bought a gallon of milk
+ # Span A: the man went to the
+ # Span B: to the store and bought
+ # Span C: and bought a gallon of
+ # ...
+ #
+ # Now the word 'bought' will have two scores from spans B and C. We only
+ # want to consider the score with "maximum context", which we define as
+ # the *minimum* of its left and right context (the *sum* of left and
+ # right context will always be the same, of course).
+ #
+ # In the example the maximum context for 'bought' would be span C since
+ # it has 1 left context and 3 right context, while span B has 4 left context
+ # and 0 right context.
+ best_score = None
+ best_span_index = None
+ for (span_index, doc_span) in enumerate(doc_spans):
+ end = doc_span.start + doc_span.length - 1
+ if position < doc_span.start:
+ continue
+ if position > end:
+ continue
+ num_left_context = position - doc_span.start
+ num_right_context = end - position
+ score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
+ if best_score is None or score > best_score:
+ best_score = score
+ best_span_index = span_index
+
+ return cur_span_index == best_span_index
+
+
+def write_predictions(all_examples,
+ all_features,
+ all_results,
+ n_best_size,
+ max_answer_length,
+ do_lower_case,
+ output_prediction_file,
+ output_nbest_file,
+ output_null_log_odds_file,
+ version_2_with_negative=False,
+ null_score_diff_threshold=0.0,
+ verbose=False):
+ """Write final predictions to the json file and log-odds of null if needed."""
+ logging.info("Writing predictions to: %s", (output_prediction_file))
+ logging.info("Writing nbest to: %s", (output_nbest_file))
+
+ all_predictions, all_nbest_json, scores_diff_json = (
+ postprocess_output(all_examples=all_examples,
+ all_features=all_features,
+ all_results=all_results,
+ n_best_size=n_best_size,
+ max_answer_length=max_answer_length,
+ do_lower_case=do_lower_case,
+ version_2_with_negative=version_2_with_negative,
+ null_score_diff_threshold=null_score_diff_threshold,
+ verbose=verbose))
+
+ write_to_json_files(all_predictions, output_prediction_file)
+ write_to_json_files(all_nbest_json, output_nbest_file)
+ if version_2_with_negative:
+ write_to_json_files(scores_diff_json, output_null_log_odds_file)
+
+
+def postprocess_output(all_examples,
+ all_features,
+ all_results,
+ n_best_size,
+ max_answer_length,
+ do_lower_case,
+ version_2_with_negative=False,
+ null_score_diff_threshold=0.0,
+ verbose=False):
+  """Postprocesses model output to form prediction results."""
+
+ del do_lower_case, verbose
+
+ example_index_to_features = collections.defaultdict(list)
+ for feature in all_features:
+ example_index_to_features[feature.example_index].append(feature)
+
+ unique_id_to_result = {}
+ for result in all_results:
+ unique_id_to_result[result.unique_id] = result
+
+ _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
+ "PrelimPrediction",
+ ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
+
+ all_predictions = collections.OrderedDict()
+ all_nbest_json = collections.OrderedDict()
+ scores_diff_json = collections.OrderedDict()
+
+ for (example_index, example) in enumerate(all_examples):
+ features = example_index_to_features[example_index]
+
+ prelim_predictions = []
+ # keep track of the minimum score of null start+end of position 0
+ score_null = 1000000 # large and positive
+    min_null_feature_index = 0  # the paragraph slice with min null score
+ null_start_logit = 0 # the start logit at the slice with min null score
+ null_end_logit = 0 # the end logit at the slice with min null score
+ for (feature_index, feature) in enumerate(features):
+ result = unique_id_to_result[feature.unique_id]
+ start_indexes = _get_best_indexes(result.start_logits, n_best_size)
+ end_indexes = _get_best_indexes(result.end_logits, n_best_size)
+ # if we could have irrelevant answers, get the min score of irrelevant
+ if version_2_with_negative:
+ feature_null_score = result.start_logits[0] + result.end_logits[0]
+ if feature_null_score < score_null:
+ score_null = feature_null_score
+ min_null_feature_index = feature_index
+ null_start_logit = result.start_logits[0]
+ null_end_logit = result.end_logits[0]
+ for start_index in start_indexes:
+ for end_index in end_indexes:
+ doc_offset = feature.tokens.index("[SEP]") + 1
+ # We could hypothetically create invalid predictions, e.g., predict
+ # that the start of the span is in the question. We throw out all
+ # invalid predictions.
+ if start_index - doc_offset >= len(feature.tok_start_to_orig_index):
+ continue
+ if end_index - doc_offset >= len(feature.tok_end_to_orig_index):
+ continue
+ # if start_index not in feature.tok_start_to_orig_index:
+ # continue
+ # if end_index not in feature.tok_end_to_orig_index:
+ # continue
+ if not feature.token_is_max_context.get(start_index, False):
+ continue
+ if end_index < start_index:
+ continue
+ length = end_index - start_index + 1
+ if length > max_answer_length:
+ continue
+ prelim_predictions.append(
+ _PrelimPrediction(
+ feature_index=feature_index,
+ start_index=start_index - doc_offset,
+ end_index=end_index - doc_offset,
+ start_logit=result.start_logits[start_index],
+ end_logit=result.end_logits[end_index]))
+
+ if version_2_with_negative:
+ prelim_predictions.append(
+ _PrelimPrediction(
+ feature_index=min_null_feature_index,
+ start_index=-1,
+ end_index=-1,
+ start_logit=null_start_logit,
+ end_logit=null_end_logit))
+ prelim_predictions = sorted(
+ prelim_predictions,
+ key=lambda x: (x.start_logit + x.end_logit),
+ reverse=True)
+
+ _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
+ "NbestPrediction", ["text", "start_logit", "end_logit"])
+
+ seen_predictions = {}
+ nbest = []
+ for pred in prelim_predictions:
+ if len(nbest) >= n_best_size:
+ break
+ feature = features[pred.feature_index]
+ if pred.start_index >= 0: # this is a non-null prediction
+ tok_start_to_orig_index = feature.tok_start_to_orig_index
+ tok_end_to_orig_index = feature.tok_end_to_orig_index
+ start_orig_pos = tok_start_to_orig_index[pred.start_index]
+ end_orig_pos = tok_end_to_orig_index[pred.end_index]
+
+ paragraph_text = example.paragraph_text
+ final_text = paragraph_text[start_orig_pos:end_orig_pos + 1].strip()
+ if final_text in seen_predictions:
+ continue
+
+ seen_predictions[final_text] = True
+ else:
+ final_text = ""
+ seen_predictions[final_text] = True
+
+ nbest.append(
+ _NbestPrediction(
+ text=final_text,
+ start_logit=pred.start_logit,
+ end_logit=pred.end_logit))
+
+    # if we didn't include the empty option in the n-best, include it
+ if version_2_with_negative:
+ if "" not in seen_predictions:
+ nbest.append(
+ _NbestPrediction(
+ text="", start_logit=null_start_logit,
+ end_logit=null_end_logit))
+ # In very rare edge cases we could have no valid predictions. So we
+ # just create a nonce prediction in this case to avoid failure.
+ if not nbest:
+ nbest.append(
+ _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
+
+ assert len(nbest) >= 1
+
+ total_scores = []
+ best_non_null_entry = None
+ for entry in nbest:
+ total_scores.append(entry.start_logit + entry.end_logit)
+ if not best_non_null_entry:
+ if entry.text:
+ best_non_null_entry = entry
+
+ probs = _compute_softmax(total_scores)
+
+ nbest_json = []
+ for (i, entry) in enumerate(nbest):
+ output = collections.OrderedDict()
+ output["text"] = entry.text
+ output["probability"] = probs[i]
+ output["start_logit"] = entry.start_logit
+ output["end_logit"] = entry.end_logit
+ nbest_json.append(output)
+
+ assert len(nbest_json) >= 1
+
+ if not version_2_with_negative:
+ all_predictions[example.qas_id] = nbest_json[0]["text"]
+ else:
+ assert best_non_null_entry is not None
+ # predict "" iff the null score - the score of best non-null > threshold
+ score_diff = score_null - best_non_null_entry.start_logit - (
+ best_non_null_entry.end_logit)
+ scores_diff_json[example.qas_id] = score_diff
+ if score_diff > null_score_diff_threshold:
+ all_predictions[example.qas_id] = ""
+ else:
+ all_predictions[example.qas_id] = best_non_null_entry.text
+
+ all_nbest_json[example.qas_id] = nbest_json
+
+ return all_predictions, all_nbest_json, scores_diff_json
+
+
+def write_to_json_files(json_records, json_file):
+ with tf.io.gfile.GFile(json_file, "w") as writer:
+ writer.write(json.dumps(json_records, indent=4) + "\n")
+
+
+def _get_best_indexes(logits, n_best_size):
+ """Get the n-best logits from a list."""
+ index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
+
+ best_indexes = []
+ for i in range(len(index_and_score)):
+ if i >= n_best_size:
+ break
+ best_indexes.append(index_and_score[i][0])
+ return best_indexes
+
+
+def _compute_softmax(scores):
+ """Compute softmax probability over raw logits."""
+ if not scores:
+ return []
+
+ max_score = None
+ for score in scores:
+ if max_score is None or score > max_score:
+ max_score = score
+
+ exp_scores = []
+ total_sum = 0.0
+ for score in scores:
+ x = math.exp(score - max_score)
+ exp_scores.append(x)
+ total_sum += x
+
+ probs = []
+ for score in exp_scores:
+ probs.append(score / total_sum)
+ return probs
+
+
+class FeatureWriter(object):
+ """Writes InputFeature to TF example file."""
+
+ def __init__(self, filename, is_training):
+ self.filename = filename
+ self.is_training = is_training
+ self.num_features = 0
+ tf.io.gfile.makedirs(os.path.dirname(filename))
+ self._writer = tf.io.TFRecordWriter(filename)
+
+ def process_feature(self, feature):
+ """Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
+ self.num_features += 1
+
+ def create_int_feature(values):
+ feature = tf.train.Feature(
+ int64_list=tf.train.Int64List(value=list(values)))
+ return feature
+
+ features = collections.OrderedDict()
+ features["unique_ids"] = create_int_feature([feature.unique_id])
+ features["input_ids"] = create_int_feature(feature.input_ids)
+ features["input_mask"] = create_int_feature(feature.input_mask)
+ features["segment_ids"] = create_int_feature(feature.segment_ids)
+
+ if self.is_training:
+ features["start_positions"] = create_int_feature([feature.start_position])
+ features["end_positions"] = create_int_feature([feature.end_position])
+ impossible = 0
+ if feature.is_impossible:
+ impossible = 1
+ features["is_impossible"] = create_int_feature([impossible])
+
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+ self._writer.write(tf_example.SerializeToString())
+
+ def close(self):
+ self._writer.close()
+
+
+def generate_tf_record_from_json_file(input_file_path,
+ sp_model_file,
+ output_path,
+ max_seq_length=384,
+ do_lower_case=True,
+ max_query_length=64,
+ doc_stride=128,
+ version_2_with_negative=False):
+ """Generates and saves training data into a tf record file."""
+ train_examples = read_squad_examples(
+ input_file=input_file_path,
+ is_training=True,
+ version_2_with_negative=version_2_with_negative)
+ tokenizer = tokenization.FullSentencePieceTokenizer(
+ sp_model_file=sp_model_file)
+ train_writer = FeatureWriter(filename=output_path, is_training=True)
+ number_of_examples = convert_examples_to_features(
+ examples=train_examples,
+ tokenizer=tokenizer,
+ max_seq_length=max_seq_length,
+ doc_stride=doc_stride,
+ max_query_length=max_query_length,
+ is_training=True,
+ output_fn=train_writer.process_feature,
+ do_lower_case=do_lower_case)
+ train_writer.close()
+
+ meta_data = {
+ "task_type": "bert_squad",
+ "train_data_size": number_of_examples,
+ "max_seq_length": max_seq_length,
+ "max_query_length": max_query_length,
+ "doc_stride": doc_stride,
+ "version_2_with_negative": version_2_with_negative,
+ }
+
+ return meta_data
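+
+
+# Illustrative usage sketch (editorial addition, not part of the original file;
+# the paths below are hypothetical placeholders for a SQuAD-style JSON file and
+# a SentencePiece model):
+#   meta_data = generate_tf_record_from_json_file(
+#       input_file_path="squad/train-v1.1.json",
+#       sp_model_file="albert_base/30k-clean.model",
+#       output_path="data/train.tf_record")
+#   print(meta_data["train_data_size"])  # number of features written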
diff --git a/models/official/nlp/data/tagging_data_loader.py b/models/official/nlp/data/tagging_data_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..127a5e004023008dbf765aec7bb8bfb7e5f89de1
--- /dev/null
+++ b/models/official/nlp/data/tagging_data_loader.py
@@ -0,0 +1,64 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Loads dataset for the tagging (e.g., NER/POS) task."""
+from typing import Mapping, Optional
+import tensorflow as tf
+
+from official.core import input_reader
+
+
+class TaggingDataLoader:
+ """A class to load dataset for tagging (e.g., NER and POS) task."""
+
+ def __init__(self, params):
+ self._params = params
+ self._seq_length = params.seq_length
+
+ def _decode(self, record: tf.Tensor):
+ """Decodes a serialized tf.Example."""
+ name_to_features = {
+ 'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ 'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ 'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ 'label_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ }
+ example = tf.io.parse_single_example(record, name_to_features)
+
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+ # So cast all int64 to int32.
+ for name in example:
+ t = example[name]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ example[name] = t
+
+ return example
+
+ def _parse(self, record: Mapping[str, tf.Tensor]):
+ """Parses raw tensors into a dict of tensors to be consumed by the model."""
+ x = {
+ 'input_word_ids': record['input_ids'],
+ 'input_mask': record['input_mask'],
+ 'input_type_ids': record['segment_ids']
+ }
+ y = record['label_ids']
+ return (x, y)
+
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
+ """Returns a tf.dataset.Dataset."""
+ reader = input_reader.InputReader(
+ params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
+ return reader.read(input_context)
diff --git a/models/official/nlp/modeling/README.md b/models/official/nlp/modeling/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0e74b4637c44ed82392203ad4fb148420c05f18b
--- /dev/null
+++ b/models/official/nlp/modeling/README.md
@@ -0,0 +1,43 @@
+# NLP Modeling Library
+
+This library provides a set of Keras primitives (Layers, Networks, and Models)
+that can be assembled into transformer-based models. They are
+flexible, validated, interoperable, and both TF1 and TF2 compatible.
+
+* [`layers`](layers) are the fundamental building blocks for NLP models.
+They can be used to assemble new layers, networks, or models.
+
+* [`networks`](networks) are combinations of layers (and possibly other networks). They are sub-units of models that would not be trained alone. They
+encapsulate common network structures like a classification head
+or a transformer encoder into an easily handled object with a
+standardized configuration.
+
+* [`models`](models) are combinations of layers and networks that would be trained. Pre-built canned models are provided as both convenience functions and canonical examples.
+
+* [`losses`](losses) contains common loss computation used in NLP tasks.
+
+Besides the pre-defined primitives, it also provides scaffold classes to allow
+easy experimentation with novel architectures; for example, you don't need to fork a whole Transformer object to try a different kind of attention primitive.
+
+* [`TransformerScaffold`](layers/transformer_scaffold.py) implements the
+Transformer from ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762),
+with a customizable attention layer option. Users can pass a class to
+`attention_cls` and associated config to `attention_cfg`, in which case the
+scaffold will instantiate the class with the config, or pass a class instance
+to `attention_cls`.
+
+* [`EncoderScaffold`](networks/encoder_scaffold.py) implements the transformer
+encoder from ["BERT: Pre-training of Deep Bidirectional Transformers for
+Language Understanding"](https://arxiv.org/abs/1810.04805), with customizable
+embedding subnetwork (which will replace the standard embedding logic) and/or a
+custom hidden layer (which will replace the Transformer instantiation in the
+encoder).
+
+BERT and ALBERT models in this repo are implemented using this library. Code examples can be found in the corresponding model folder.
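+
+A minimal sketch (an editorial addition, not part of the original README) of
+composing the library's layers with standard Keras layers into a single
+transformer-style block; the shapes and hyper-parameters below are
+illustrative only:
+
+```python
+import tensorflow as tf
+from official.nlp.modeling import layers
+
+inputs = tf.keras.Input(shape=(64, 128))            # [batch, seq, width]
+attention_output = layers.MultiHeadAttention(
+    num_heads=8, key_size=16)([inputs, inputs])     # self-attention
+x = tf.keras.layers.LayerNormalization()(inputs + attention_output)
+outputs = tf.keras.layers.Dense(128, activation="relu")(x)
+block = tf.keras.Model(inputs, outputs)
+```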
+
+
+
+
+
+
+
diff --git a/models/official/nlp/modeling/__init__.py b/models/official/nlp/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/models/official/nlp/modeling/__init__.py
@@ -0,0 +1 @@
+
diff --git a/models/official/nlp/modeling/layers/README.md b/models/official/nlp/modeling/layers/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..42f299a3f2308f63f5339bd3f639bef0607f5e97
--- /dev/null
+++ b/models/official/nlp/modeling/layers/README.md
@@ -0,0 +1,64 @@
+# Layers
+
+Layers are the fundamental building blocks for NLP models. They can be used to
+assemble new layers, networks, or models.
+
+* [DenseEinsum](dense_einsum.py) implements a feedforward network using
+ tf.einsum. This layer contains the einsum op, the associated weight, and the
+ logic required to generate the einsum expression for the given
+ initialization parameters.
+
+* [MultiHeadAttention](attention.py) implements an optionally masked attention
+ between query, key, value tensors as described in
+ ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). If
+ `from_tensor` and `to_tensor` are the same, then this is self-attention.
+
+* [CachedAttention](attention.py) implements an attention layer with cache
+  used for auto-regressive decoding.
+
+* [MultiChannelAttention](multi_channel_attention.py) implements a variant of
+ multi-head attention which can be used to merge multiple streams for
+ cross-attentions.
+
+* [TalkingHeadsAttention](talking_heads_attention.py) implements the talking
+  heads attention, as described in
+ ["Talking-Heads Attention"](https://arxiv.org/abs/2003.02436).
+
+* [Transformer](transformer.py) implements an optionally masked transformer as
+ described in
+ ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).
+
+* [TransformerDecoderLayer](transformer.py) implements a single decoder layer
+  made up of self multi-head attention, cross multi-head attention, and a
+  feedforward network.
+
+* [ReZeroTransformer](rezero_transformer.py) implements Transformer with
+ ReZero described in
+ ["ReZero is All You Need: Fast Convergence at Large Depth"](https://arxiv.org/abs/2003.04887).
+
+* [OnDeviceEmbedding](on_device_embedding.py) implements efficient embedding
+ lookups designed for TPU-based models.
+
+* [PositionEmbedding](position_embedding.py) creates a positional embedding
+ as described in ["BERT: Pre-training of Deep Bidirectional Transformers for
+ Language Understanding"](https://arxiv.org/abs/1810.04805).
+
+* [SelfAttentionMask](self_attention_mask.py) creates a 3D attention mask from
+ a 2D tensor mask.
+
+* [MaskedSoftmax](masked_softmax.py) implements a softmax with an optional
+ masking input. If no mask is provided to this layer, it performs a standard
+ softmax; however, if a mask tensor is applied (which should be 1 in
+ positions where the data should be allowed through, and 0 where the data
+ should be masked), the output will have masked positions set to
+ approximately zero.
+
+* [`MaskedLM`](masked_lm.py) implements a masked language model. It assumes
+ the embedding table variable is passed to it.
+
+* [ClassificationHead](cls_head.py) implements a pooling head over a sequence
+  of embeddings, commonly used by classification tasks.
+
+* [GatedFeedforward](gated_feedforward.py) implements the gated linear layer
+ feedforward as described in
+ ["GLU Variants Improve Transformer"](https://arxiv.org/abs/2002.05202).
diff --git a/models/official/nlp/modeling/layers/__init__.py b/models/official/nlp/modeling/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cd8e7b9e59ceab76e268f83907833eec32c73ce
--- /dev/null
+++ b/models/official/nlp/modeling/layers/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Layers package definition."""
+# pylint: disable=wildcard-import
+from official.nlp.modeling.layers.attention import *
+from official.nlp.modeling.layers.cls_head import *
+from official.nlp.modeling.layers.dense_einsum import DenseEinsum
+from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
+from official.nlp.modeling.layers.masked_lm import MaskedLM
+from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
+from official.nlp.modeling.layers.multi_channel_attention import *
+from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
+from official.nlp.modeling.layers.position_embedding import PositionEmbedding
+from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
+from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
+from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
+from official.nlp.modeling.layers.transformer import *
+from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold
diff --git a/models/official/nlp/modeling/layers/attention.py b/models/official/nlp/modeling/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..99692b281794385a97af341d03dea0ee6c46b95b
--- /dev/null
+++ b/models/official/nlp/modeling/layers/attention.py
@@ -0,0 +1,530 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras-based attention layer."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import collections
+import math
+import string
+
+import numpy as np
+import tensorflow as tf
+
+from official.nlp.modeling.layers import masked_softmax
+
+EinsumDense = tf.keras.layers.experimental.EinsumDense
+_CHR_IDX = string.ascii_lowercase
+
+
+def _build_attention_equation(qkv_rank, attn_axes):
+ """Builds einsum equations for the attention computation.
+
+  Query, key, value inputs after projection are expected to have the shape as:
+  (bs, <non-attention dims>, <attention dims>, num_heads, channels).
+  `bs` and `<non-attention dims>` are treated as `<batch dims>`.
+  The attention operations can be generalized:
+  (1) Query-key dot product:
+  (<batch dims>, <query attention dims>, num_heads, channels),
+  (<batch dims>, <key attention dims>, num_heads, channels) ->
+  (<batch dims>, num_heads, <query attention dims>, <key attention dims>)
+  (2) Combination:
+  (<batch dims>, num_heads, <query attention dims>, <key attention dims>),
+  (<batch dims>, <value attention dims>, num_heads, channels) ->
+  (<batch dims>, <query attention dims>, num_heads, channels)
+
+ Args:
+ qkv_rank: the rank of query, key, value tensors.
+ attn_axes: a list/tuple of axes, [1, rank), that will do attention.
+
+ Returns:
+ Einsum equations.
+ """
+ target_notation = _CHR_IDX[:qkv_rank]
+ # `batch_dims` includes the head dim.
+ batch_dims = tuple(np.delete(range(qkv_rank), attn_axes + (qkv_rank - 1,)))
+ letter_offset = qkv_rank
+ source_notation = ""
+ for i in range(qkv_rank):
+ if i in batch_dims or i == qkv_rank - 1:
+ source_notation += target_notation[i]
+ else:
+ source_notation += _CHR_IDX[letter_offset]
+ letter_offset += 1
+
+ product_notation = "".join([target_notation[i] for i in batch_dims] +
+ [target_notation[i] for i in attn_axes] +
+ [source_notation[i] for i in attn_axes])
+ dot_product_equation = "%s,%s->%s" % (source_notation, target_notation,
+ product_notation)
+ attn_scores_rank = len(product_notation)
+ combine_equation = "%s,%s->%s" % (product_notation, source_notation,
+ target_notation)
+ return dot_product_equation, combine_equation, attn_scores_rank
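+
+# Illustrative example (editorial addition, not part of the original file):
+#   _build_attention_equation(qkv_rank=4, attn_axes=(1,)) returns
+#     dot_product_equation = "aecd,abcd->acbe"   (query-key dot product)
+#     combine_equation     = "acbe,aecd->abcd"   (probs-value combination)
+# which corresponds to standard [batch, seq, num_heads, channels] attention.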
+
+
+def _build_proj_equation(free_dims, bound_dims, output_dims):
+ """Builds an einsum equation for projections inside multi-head attention."""
+ input_str = ""
+ kernel_str = ""
+ output_str = ""
+ bias_axes = ""
+ letter_offset = 0
+ for i in range(free_dims):
+ char = _CHR_IDX[i + letter_offset]
+ input_str += char
+ output_str += char
+
+ letter_offset += free_dims
+ for i in range(bound_dims):
+ char = _CHR_IDX[i + letter_offset]
+ input_str += char
+ kernel_str += char
+
+ letter_offset += bound_dims
+ for i in range(output_dims):
+ char = _CHR_IDX[i + letter_offset]
+ kernel_str += char
+ output_str += char
+ bias_axes += char
+ equation = "%s,%s->%s" % (input_str, kernel_str, output_str)
+
+ return equation, bias_axes, len(output_str)
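+
+# Illustrative example (editorial addition, not part of the original file):
+# projecting a [batch, seq, dim] query to [batch, seq, num_heads, key_size]
+# uses _build_proj_equation(free_dims=2, bound_dims=1, output_dims=2),
+# which returns the equation "abc,cde->abde" with bias axes "de".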
+
+
+def _get_output_shape(output_rank, known_last_dims):
+ return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
+
+
+@tf.keras.utils.register_keras_serializable(package="Text")
+class MultiHeadAttention(tf.keras.layers.Layer):
+ """MultiHeadAttention layer.
+
+  This is an implementation of multi-headed attention based on "Attention
+  Is All You Need". If `query`, `key`, and `value` are the same, then
+  this is self-attention. Each timestep in `query` attends to the
+  corresponding sequence in `key`, and returns a fixed-width vector.
+
+ This layer first projects `query`, `key` and `value`. These are
+ (effectively) a list of tensors of length `num_attention_heads`, where the
+  corresponding shapes are [batch_size, <query dimensions>, key_size],
+  [batch_size, <key/value dimensions>, key_size],
+  [batch_size, <key/value dimensions>, value_size].
+
+ Then, the query and key tensors are dot-producted and scaled. These are
+ softmaxed to obtain attention probabilities. The value tensors are then
+ interpolated by these probabilities, then concatenated back to a single
+ tensor.
+
+  Finally, the result tensor with the last dimension as value_size can take a
+  linear projection and be returned.
+
+ Examples:
+
+ Performs 1D cross-attention over two sequence inputs with an attention mask.
+ Returns the additional attention weights over heads.
+
+ >>> layer = MultiHeadAttention(num_heads=2, key_size=2,
+ ... return_attention_scores=True)
+ >>> target = tf.keras.Input(shape=[8, 16])
+ >>> source = tf.keras.Input(shape=[4, 16])
+ >>> mask_tensor = tf.keras.Input(shape=[8, 4])
+  >>> output_tensor, weights = layer([target, source], mask_tensor)
+ >>> print(output_tensor.shape), print(weights.shape)
+ (None, 8, 16) (None, 2, 8, 4)
+
+ Performs 2D self-attention over a 5D input tensor on axes 2 and 3.
+
+ >>> layer = MultiHeadAttention(num_heads=2, key_size=2, attention_axes=(2, 3))
+ >>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
+ >>> output_tensor = layer([input_tensor, input_tensor])
+ >>> print(output_tensor.shape)
+ (None, 5, 3, 4, 16)
+
+ Arguments:
+ num_heads: Number of attention heads.
+ key_size: Size of each attention head for query and key.
+ value_size: Size of each attention head for value.
+ dropout: Dropout probability.
+ use_bias: Boolean, whether the dense layers use bias vectors/matrices.
+ output_shape: The expected shape of an output tensor, besides the batch and
+ sequence dims. If not specified, projects back to the key feature dim.
+ attention_axes: axes over which the attention is applied. `None` means
+ attention over all axes, but batch, heads, and features.
+ return_attention_scores: bool, if `True`, returns the multi-head
+ attention scores as an additional output argument.
+ kernel_initializer: Initializer for dense layer kernels.
+ bias_initializer: Initializer for dense layer biases.
+ kernel_regularizer: Regularizer for dense layer kernels.
+ bias_regularizer: Regularizer for dense layer biases.
+ activity_regularizer: Regularizer for dense layer activity.
+ kernel_constraint: Constraint for dense layer kernels.
+ bias_constraint: Constraint for dense layer kernels.
+ """
+
+ def __init__(self,
+ num_heads,
+ key_size,
+ value_size=None,
+ dropout=0.0,
+ use_bias=True,
+ output_shape=None,
+ attention_axes=None,
+ return_attention_scores=False,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ kernel_constraint=None,
+ bias_constraint=None,
+ **kwargs):
+ super(MultiHeadAttention, self).__init__(**kwargs)
+ self._num_heads = num_heads
+ self._key_size = key_size
+ self._value_size = value_size if value_size else key_size
+ self._dropout = dropout
+ self._use_bias = use_bias
+ self._output_shape = output_shape
+ self._return_attention_scores = return_attention_scores
+ self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
+ self._bias_initializer = tf.keras.initializers.get(bias_initializer)
+ self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
+ self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
+ self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
+ self._bias_constraint = tf.keras.constraints.get(bias_constraint)
+ if attention_axes is not None and not isinstance(attention_axes,
+ collections.abc.Sized):
+ self._attention_axes = (attention_axes,)
+ else:
+ self._attention_axes = attention_axes
+
+ def get_config(self):
+ config = {
+ "num_heads":
+ self._num_heads,
+ "key_size":
+ self._key_size,
+ "value_size":
+ self._value_size,
+ "dropout":
+ self._dropout,
+ "use_bias":
+ self._use_bias,
+ "output_shape":
+ self._output_shape,
+ "attention_axes":
+ self._attention_axes,
+ "return_attention_scores":
+ self._return_attention_scores,
+ "kernel_initializer":
+ tf.keras.initializers.serialize(self._kernel_initializer),
+ "bias_initializer":
+ tf.keras.initializers.serialize(self._bias_initializer),
+ "kernel_regularizer":
+ tf.keras.regularizers.serialize(self._kernel_regularizer),
+ "bias_regularizer":
+ tf.keras.regularizers.serialize(self._bias_regularizer),
+ "activity_regularizer":
+ tf.keras.regularizers.serialize(self._activity_regularizer),
+ "kernel_constraint":
+ tf.keras.constraints.serialize(self._kernel_constraint),
+ "bias_constraint":
+ tf.keras.constraints.serialize(self._bias_constraint)
+ }
+ base_config = super(MultiHeadAttention, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ def build(self, input_shape):
+ inputs_len = len(input_shape)
+ if inputs_len > 3 or inputs_len < 2:
+ raise ValueError(
+ "Expects inputs list of length 2 or 3, namely [query, value] or "
+ "[query, value, key]. "
+ "Given length: %d" % inputs_len)
+ tensor_shapes = tf.nest.map_structure(tf.TensorShape, input_shape)
+ query_shape = tensor_shapes[0]
+ value_shape = tensor_shapes[1]
+ key_shape = tensor_shapes[2] if inputs_len == 3 else value_shape
+
+ common_kwargs = dict(
+ kernel_initializer=self._kernel_initializer,
+ bias_initializer=self._bias_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activity_regularizer=self._activity_regularizer,
+ kernel_constraint=self._kernel_constraint,
+ bias_constraint=self._bias_constraint)
+
+ free_dims = query_shape.rank - 1
+ einsum_equation, bias_axes, output_rank = _build_proj_equation(
+ free_dims, bound_dims=1, output_dims=2)
+ self._query_dense = EinsumDense(
+ einsum_equation,
+ output_shape=_get_output_shape(output_rank - 1,
+ [self._num_heads, self._key_size]),
+ bias_axes=bias_axes if self._use_bias else None,
+ name="query",
+ **common_kwargs)
+ einsum_equation, bias_axes, output_rank = _build_proj_equation(
+ key_shape.rank - 1, bound_dims=1, output_dims=2)
+ self._key_dense = EinsumDense(
+ einsum_equation,
+ output_shape=_get_output_shape(output_rank - 1,
+ [self._num_heads, self._key_size]),
+ bias_axes=bias_axes if self._use_bias else None,
+ name="key",
+ **common_kwargs)
+ einsum_equation, bias_axes, output_rank = _build_proj_equation(
+ value_shape.rank - 1, bound_dims=1, output_dims=2)
+ self._value_dense = EinsumDense(
+ einsum_equation,
+ output_shape=_get_output_shape(output_rank - 1,
+ [self._num_heads, self._value_size]),
+ bias_axes=bias_axes if self._use_bias else None,
+ name="value",
+ **common_kwargs)
+
+ # Builds the attention computations for multi-head dot product attention.
+    # These computations could be wrapped into the keras attention layer once
+    # it supports multi-head einsum computations.
+ self._build_attention(output_rank)
+ if self._output_shape:
+ if not isinstance(self._output_shape, collections.abc.Sized):
+ output_shape = [self._output_shape]
+ else:
+ output_shape = self._output_shape
+ else:
+ output_shape = [query_shape[-1]]
+ einsum_equation, bias_axes, output_rank = _build_proj_equation(
+ free_dims, bound_dims=2, output_dims=len(output_shape))
+ self._output_dense = EinsumDense(
+ einsum_equation,
+ output_shape=_get_output_shape(output_rank - 1, output_shape),
+ bias_axes=bias_axes if self._use_bias else None,
+ name="attention_output",
+ **common_kwargs)
+ super(MultiHeadAttention, self).build(input_shape)
+
+ def _build_attention(self, qkv_rank):
+ """Builds multi-head dot-product attention computations.
+
+ This function builds attributes necessary for `_compute_attention` to
+    customize attention computation to replace the default dot-product
+ attention.
+
+ Args:
+ qkv_rank: the rank of query, key, value tensors.
+ """
+ if self._attention_axes is None:
+ self._attention_axes = tuple(range(1, qkv_rank - 2))
+ else:
+ self._attention_axes = tuple(self._attention_axes)
+ self._dot_product_equation, self._combine_equation, attn_scores_rank = (
+ _build_attention_equation(qkv_rank, attn_axes=self._attention_axes))
+ norm_axes = tuple(
+ range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
+ self._masked_softmax = masked_softmax.MaskedSoftmax(
+ mask_expansion_axes=[1], normalization_axes=norm_axes)
+ self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
+
+ def _compute_attention(self,
+ query_tensor,
+ key_tensor,
+ value_tensor,
+ attention_mask=None):
+ """Applies Dot-product attention with query, key, value tensors.
+
+ This function defines the computation inside `call` with projected
+ multi-head Q, K, V inputs. Users can override this function for customized
+ attention implementation.
+
+ Args:
+ query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`.
+ key_tensor: Projected key `Tensor` of shape `[B, T, N, key_size]`.
+ value_tensor: Projected value `Tensor` of shape `[B, T, N, value_size]`.
+ attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
+ attention to certain positions.
+
+ Returns:
+ attention_output: Multi-headed outputs of attention computation.
+ attention_scores: Multi-headed attention weights.
+ """
+ # Take the dot product between "query" and "key" to get the raw
+ # attention scores.
+ attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
+ query_tensor)
+ attention_scores = tf.multiply(attention_scores,
+ 1.0 / math.sqrt(float(self._key_size)))
+
+ # Normalize the attention scores to probabilities.
+ # `attention_scores` = [B, N, T, S]
+ attention_scores = self._masked_softmax(attention_scores, attention_mask)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_scores_dropout = self._dropout_layer(attention_scores)
+
+ # `context_layer` = [B, T, N, H]
+ attention_output = tf.einsum(self._combine_equation,
+ attention_scores_dropout, value_tensor)
+ return attention_output, attention_scores
+
+ def call(self, inputs, attention_mask=None):
+ """Implements the forward pass.
+
+ Size glossary:
+ * Number of heads (H): the number of attention heads.
+ * Value size (V): the size of each value embedding per head.
+ * Key size (K): the size of each key embedding per head. Equally, the size
+ of each query embedding per head. Typically K <= V.
+ * Batch dimensions (B).
+ * Query (target) attention axes shape (T).
+ * Value (source) attention axes shape (S), the rank must match the target.
+
+ Args:
+ inputs: List of the following tensors:
+ * query: Query `Tensor` of shape `[B, T, dim]`.
+ * value: Value `Tensor` of shape `[B, S, dim]`.
+ * key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will
+ use `value` for both `key` and `value`, which is the most common case.
+ attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
+ attention to certain positions.
+
+ Returns:
+ attention_output: The result of the computation, of shape [B, T, E],
+ where `T` is for target sequence shapes and `E` is the query input last
+        dimension if `output_shape` is `None`. Otherwise, the multi-head outputs
+        are projected to the shape specified by `output_shape`.
+      attention_scores: [Optional] multi-head attention coefficients over
+        attention axes.
+ """
+ inputs_len = len(inputs)
+ if inputs_len > 3 or inputs_len < 2:
+ raise ValueError(
+ "Expects inputs list of length 2 or 3, namely [query, value] or "
+ "[query, value, key]. "
+ "Given length: %d" % inputs_len)
+ query = inputs[0]
+ value = inputs[1]
+ key = inputs[2] if inputs_len == 3 else value
+
+ # N = `num_attention_heads`
+ # H = `size_per_head`
+ # `query_tensor` = [B, T, N ,H]
+ query_tensor = self._query_dense(query)
+
+ # `key_tensor` = [B, S, N, H]
+ key_tensor = self._key_dense(key)
+
+ # `value_tensor` = [B, S, N, H]
+ value_tensor = self._value_dense(value)
+
+ attention_output, attention_scores = self._compute_attention(
+ query_tensor, key_tensor, value_tensor, attention_mask)
+ attention_output = self._output_dense(attention_output)
+
+ if self._return_attention_scores:
+ return attention_output, attention_scores
+ return attention_output
+
+
+@tf.keras.utils.register_keras_serializable(package="Text")
+class CachedAttention(MultiHeadAttention):
+ """Attention layer with cache used for auto-agressive decoding.
+
+ Arguments are the same as `MultiHeadAttention` layer.
+ """
+
+ def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step):
+ """Updates cache states and gets full-length key/value tensors."""
+ # Combines cached keys and values with new keys and values.
+ if decode_loop_step is not None:
+ # TPU special case.
+ key_seq_dim = cache["key"].shape.as_list()[1]
+ indices = tf.reshape(
+ tf.one_hot(decode_loop_step, key_seq_dim, dtype=key_tensor.dtype),
+ [1, key_seq_dim, 1, 1])
+ key_tensor = cache["key"] + key_tensor * indices
+ value_seq_dim = cache["value"].shape.as_list()[1]
+ indices = tf.reshape(
+ tf.one_hot(decode_loop_step, value_seq_dim, dtype=value_tensor.dtype),
+ [1, value_seq_dim, 1, 1])
+ value_tensor = cache["value"] + value_tensor * indices
+ else:
+ key_tensor = tf.concat(
+ [tf.cast(cache["key"], key_tensor.dtype), key_tensor], axis=1)
+ value_tensor = tf.concat(
+ [tf.cast(cache["value"], value_tensor.dtype), value_tensor], axis=1)
+
+ # Update cache
+ cache["key"] = key_tensor
+ cache["value"] = value_tensor
+
+ return key_tensor, value_tensor
+
+ def call(self,
+ inputs,
+ attention_mask=None,
+ cache=None,
+ decode_loop_step=None):
+ from_tensor = inputs[0]
+ to_tensor = inputs[1]
+
+ # Scalar dimensions referenced here:
+ # B = batch size (number of sequences)
+ # F = `from_tensor` sequence length
+ # T = `to_tensor` sequence length
+ # N = `num_attention_heads`
+ # H = `size_per_head`
+ # `query_tensor` = [B, F, N ,H]
+ query_tensor = self._query_dense(from_tensor)
+
+ # `key_tensor` = [B, T, N, H]
+ key_tensor = self._key_dense(to_tensor)
+
+ # `value_tensor` = [B, T, N, H]
+ value_tensor = self._value_dense(to_tensor)
+
+ if cache:
+ key_tensor, value_tensor = self._update_cache(key_tensor, value_tensor,
+ cache, decode_loop_step)
+
+ # Take the dot product between "query" and "key" to get the raw
+ # attention scores.
+ attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
+ query_tensor)
+ attention_scores = tf.multiply(attention_scores,
+ 1.0 / math.sqrt(float(self._key_size)))
+
+ # Normalize the attention scores to probabilities.
+ # `attention_scores` = [B, N, F, T]
+ attention_scores = self._masked_softmax(attention_scores, attention_mask)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_scores = self._dropout_layer(attention_scores)
+ # `context_layer` = [B, F, N, H]
+ attention_output = tf.einsum(self._combine_equation, attention_scores,
+ value_tensor)
+ attention_output = self._output_dense(attention_output)
+ if self._return_attention_scores:
+ return attention_output, attention_scores, cache
+ return attention_output, cache
diff --git a/models/official/nlp/modeling/layers/attention_test.py b/models/official/nlp/modeling/layers/attention_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ceb96f5084d795cdbafa7cdb352fb4692034f803
--- /dev/null
+++ b/models/official/nlp/modeling/layers/attention_test.py
@@ -0,0 +1,255 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the attention layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling.layers import attention
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class MultiHeadAttentionTest(keras_parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("key_value_same_proj", None, None, [40, 80]),
+ ("key_value_different_proj", 32, 60, [40, 60]),
+ )
+ def test_non_masked_attention(self, value_size, output_shape, output_dims):
+ """Test that the attention layer can be created without a mask tensor."""
+ test_layer = attention.MultiHeadAttention(
+ num_heads=12,
+ key_size=64,
+ value_size=value_size,
+ output_shape=output_shape)
+ # Create a 3-dimensional input (the first dimension is implicit).
+ query = tf.keras.Input(shape=(40, 80))
+ value = tf.keras.Input(shape=(20, 80))
+ output = test_layer([query, value])
+ self.assertEqual(output.shape.as_list(), [None] + output_dims)
+
+ def test_non_masked_self_attention(self):
+ """Test with one input (self-attenntion) and no mask tensor."""
+ test_layer = attention.MultiHeadAttention(num_heads=12, key_size=64)
+ # Create a 3-dimensional input (the first dimension is implicit).
+ query = tf.keras.Input(shape=(40, 80))
+ output = test_layer([query, query])
+ self.assertEqual(output.shape.as_list(), [None, 40, 80])
+
+ def test_attention_scores(self):
+ """Test attention outputs with coefficients."""
+ test_layer = attention.MultiHeadAttention(
+ num_heads=12, key_size=64, return_attention_scores=True)
+ # Create a 3-dimensional input (the first dimension is implicit).
+ query = tf.keras.Input(shape=(40, 80))
+ output, coef = test_layer([query, query])
+ self.assertEqual(output.shape.as_list(), [None, 40, 80])
+ self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
+
+ @parameterized.named_parameters(("with_bias", True), ("no_bias", False))
+ def test_masked_attention(self, use_bias):
+ """Test with a mask tensor."""
+ test_layer = attention.MultiHeadAttention(
+ num_heads=2, key_size=2, use_bias=use_bias)
+ # Create a 3-dimensional input (the first dimension is implicit).
+ batch_size = 3
+ query = tf.keras.Input(shape=(4, 8))
+ value = tf.keras.Input(shape=(2, 8))
+ mask_tensor = tf.keras.Input(shape=(4, 2))
+ output = test_layer([query, value], mask_tensor)
+
+ # Create a model containing the test layer.
+ model = tf.keras.Model([query, value, mask_tensor], output)
+
+ # Generate data for the input (non-mask) tensors.
+ from_data = 10 * np.random.random_sample((batch_size, 4, 8))
+ to_data = 10 * np.random.random_sample((batch_size, 2, 8))
+
+ # Invoke the data with a random set of mask data. This should mask at least
+ # one element.
+ mask_data = np.random.randint(2, size=(batch_size, 4, 2))
+ masked_output_data = model.predict([from_data, to_data, mask_data])
+
+ # Invoke the same data, but with a null mask (where no elements are masked).
+ null_mask_data = np.ones((batch_size, 4, 2))
+ unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
+
+ # Because one data is masked and one is not, the outputs should not be the
+ # same.
+ self.assertNotAllClose(masked_output_data, unmasked_output_data)
+
+ # Tests the layer with three inputs: Q, K, V.
+ key = tf.keras.Input(shape=(2, 8))
+ output = test_layer([query, value, key], mask_tensor)
+ model = tf.keras.Model([query, value, key, mask_tensor], output)
+
+ masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
+ unmasked_output_data = model.predict(
+ [from_data, to_data, to_data, null_mask_data])
+ # Because one data is masked and one is not, the outputs should not be the
+ # same.
+ self.assertNotAllClose(masked_output_data, unmasked_output_data)
+
+ if use_bias:
+ self.assertLen(test_layer._query_dense.trainable_variables, 2)
+ self.assertLen(test_layer._output_dense.trainable_variables, 2)
+ else:
+ self.assertLen(test_layer._query_dense.trainable_variables, 1)
+ self.assertLen(test_layer._output_dense.trainable_variables, 1)
+
+ def test_initializer(self):
+ """Test with a specified initializer."""
+ test_layer = attention.MultiHeadAttention(
+ num_heads=12,
+ key_size=64,
+ kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
+ # Create a 3-dimensional input (the first dimension is implicit).
+ query = tf.keras.Input(shape=(40, 80))
+ output = test_layer([query, query])
+ self.assertEqual(output.shape.as_list(), [None, 40, 80])
+
+ @parameterized.named_parameters(
+ ("4d_inputs_one_free_batch", [3, 4], [3, 2], [4, 2], (2,)),
+ ("4D_inputs_2D_attention", [3, 4], [3, 2], [3, 4, 3, 2], (1, 2)),
+ ("5D_inputs_2D_attention", [5, 3, 4], [5, 3, 2], [3, 4, 3, 2], (2, 3)))
+ def test_high_dim_attention(self, q_dims, v_dims, mask_dims, attention_axes):
+ """Test with a mask tensor."""
+ test_layer = attention.MultiHeadAttention(
+ num_heads=2, key_size=2, attention_axes=attention_axes)
+ batch_size, hidden_size = 3, 8
+ # Generate data for the input (non-mask) tensors.
+ query_shape = [batch_size] + q_dims + [hidden_size]
+ value_shape = [batch_size] + v_dims + [hidden_size]
+ mask_shape = [batch_size] + mask_dims
+ query = 10 * np.random.random_sample(query_shape)
+ value = 10 * np.random.random_sample(value_shape)
+
+ # Invoke the data with a random set of mask data. This should mask at least
+ # one element.
+ mask_data = np.random.randint(2, size=mask_shape).astype("bool")
+ output = test_layer([query, value], mask_data)
+
+ # Invoke the same data, but with a null mask (where no elements are masked).
+ null_mask_data = np.ones(mask_shape)
+ unmasked_output = test_layer([query, value], null_mask_data)
+ # Because one data is masked and one is not, the outputs should not be the
+ # same.
+ self.assertNotAllClose(output, unmasked_output)
+
+
+class SubclassAttention(attention.MultiHeadAttention):
+
+ def _build_attention(self, qkv_rank):
+ pass
+
+ def _compute_attention(self,
+ query_tensor,
+ key_tensor,
+ value_tensor,
+ attention_mask=None):
+ return value_tensor, None
+
+
+@keras_parameterized.run_all_keras_modes
+class AttentionSubclassTest(keras_parameterized.TestCase):
+
+ def test_initializer(self):
+ """Test with a specified initializer."""
+ test_layer = SubclassAttention(
+ num_heads=12,
+ key_size=64)
+ # Create a 3-dimensional input (the first dimension is implicit).
+ query = tf.keras.Input(shape=(40, 80))
+ output = test_layer([query, query])
+ self.assertEqual(output.shape.as_list(), [None, 40, 80])
+
+
+def _create_cache(batch_size, init_decode_length, num_heads, head_size):
+ return {
+ "key":
+ tf.zeros([batch_size, init_decode_length, num_heads, head_size],
+ dtype=tf.float32),
+ "value":
+ tf.zeros([batch_size, init_decode_length, num_heads, head_size],
+ dtype=tf.float32)
+ }
+
+
+@keras_parameterized.run_all_keras_modes
+class CachedAttentionTest(keras_parameterized.TestCase):
+
+ def test_masked_attention(self):
+ """Test with a mask tensor."""
+ num_heads, head_size = 2, 2
+ # Create a 3-dimensional input (the first dimension is implicit).
+ from_seq_length = 4
+ batch_size = 3
+ # GPU/CPU case.
+ init_decode_length = 0
+ # Directly tests the keras layer.
+ cache = _create_cache(batch_size, init_decode_length, num_heads, head_size)
+ layer = attention.CachedAttention(num_heads=num_heads, key_size=head_size)
+
+ # Generate data for the input (non-mask) tensors.
+ from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32)
+ # Invoke the data with a random set of mask data. This should mask at least
+ # one element.
+ mask_data = np.random.randint(
+ 2, size=(batch_size, from_seq_length, from_seq_length))
+ masked_output_data, cache = layer([from_data, from_data], mask_data, cache)
+ self.assertEqual(masked_output_data.shape, (3, 4, 8))
+ self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
+
+ # Tests inputs without cache.
+    masked_output_data, cache = layer([from_data, from_data], mask_data)
+ self.assertEqual(masked_output_data.shape, (3, 4, 8))
+ self.assertIsNone(cache)
+
+ def test_padded_decode(self):
+ """Test with a mask tensor."""
+ num_heads, head_size = 2, 2
+ from_seq_length = 4
+ # TPU decoding should pre-allocate the entire sequence.
+ batch_size = 3
+ init_decode_length = from_seq_length
+
+ # Directly tests the keras layer.
+ cache = _create_cache(batch_size, init_decode_length, num_heads, head_size)
+ layer = attention.CachedAttention(num_heads=num_heads, key_size=head_size)
+
+ # Generate data for the input (non-mask) tensors.
+ from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32)
+ decode_loop_step = 2
+ mask_data = np.random.randint(
+ 2, size=(batch_size, from_seq_length, from_seq_length), dtype=np.int32)
+ # Testing the invocation directly as Keras cannot consume inputs correctly.
+ masked_output_data, cache = layer([from_data, from_data],
+ mask_data,
+ cache,
+ decode_loop_step=decode_loop_step)
+ self.assertEqual(masked_output_data.shape, (3, 4, 8))
+ self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/modeling/layers/cls_head.py b/models/official/nlp/modeling/layers/cls_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0240511429e58453fa9c483be120705347a0c754
--- /dev/null
+++ b/models/official/nlp/modeling/layers/cls_head.py
@@ -0,0 +1,90 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A Classification head layer which is common used with sequence encoders."""
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+
+from official.modeling import tf_utils
+
+
+class ClassificationHead(tf.keras.layers.Layer):
+ """Pooling head for sentence-level classification tasks."""
+
+ def __init__(self,
+ inner_dim,
+ num_classes,
+ cls_token_idx=0,
+ activation="tanh",
+ dropout_rate=0.0,
+ initializer="glorot_uniform",
+ **kwargs):
+ """Initializes the `ClassificationHead`.
+
+ Args:
+ inner_dim: The dimensionality of inner projection layer.
+ num_classes: Number of output classes.
+ cls_token_idx: The index inside the sequence to pool.
+ activation: Dense layer activation.
+ dropout_rate: Dropout probability.
+ initializer: Initializer for dense layer kernels.
+ **kwargs: Keyword arguments.
+ """
+ super(ClassificationHead, self).__init__(**kwargs)
+ self.dropout_rate = dropout_rate
+ self.inner_dim = inner_dim
+ self.num_classes = num_classes
+ self.activation = tf_utils.get_activation(activation)
+ self.initializer = tf.keras.initializers.get(initializer)
+ self.cls_token_idx = cls_token_idx
+
+ self.dense = tf.keras.layers.Dense(
+ units=inner_dim,
+ activation=self.activation,
+ kernel_initializer=self.initializer,
+ name="pooler_dense")
+ self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
+ self.out_proj = tf.keras.layers.Dense(
+ units=num_classes, kernel_initializer=self.initializer, name="logits")
+
+ def call(self, features):
+    x = features[:, self.cls_token_idx, :]  # take the <CLS> token.
+ x = self.dense(x)
+ x = self.dropout(x)
+ x = self.out_proj(x)
+ return x
+
+ def get_config(self):
+ config = {
+ "dropout_rate": self.dropout_rate,
+ "num_classes": self.num_classes,
+ "inner_dim": self.inner_dim,
+ "activation": tf.keras.activations.serialize(self.activation),
+ "initializer": tf.keras.initializers.serialize(self.initializer),
+ }
+ config.update(super(ClassificationHead, self).get_config())
+ return config
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ return cls(**config)
+
+ @property
+ def checkpoint_items(self):
+ return {self.dense.name: self.dense}
diff --git a/models/official/nlp/modeling/layers/cls_head_test.py b/models/official/nlp/modeling/layers/cls_head_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea671f94f5806800f1f5ce07df9fffeff7a3ab68
--- /dev/null
+++ b/models/official/nlp/modeling/layers/cls_head_test.py
@@ -0,0 +1,42 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for cls_head."""
+
+import tensorflow as tf
+
+from official.nlp.modeling.layers import cls_head
+
+
+class ClassificationHead(tf.test.TestCase):
+
+ def test_layer_invocation(self):
+ test_layer = cls_head.ClassificationHead(inner_dim=5, num_classes=2)
+ features = tf.zeros(shape=(2, 10, 10), dtype=tf.float32)
+ output = test_layer(features)
+ self.assertAllClose(output, [[0., 0.], [0., 0.]])
+ self.assertSameElements(test_layer.checkpoint_items.keys(),
+ ["pooler_dense"])
+
+ def test_layer_serialization(self):
+ layer = cls_head.ClassificationHead(10, 2)
+ new_layer = cls_head.ClassificationHead.from_config(layer.get_config())
+
+ # If the serialization was successful, the new config should match the old.
+ self.assertAllEqual(layer.get_config(), new_layer.get_config())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/modeling/layers/dense_einsum.py b/models/official/nlp/modeling/layers/dense_einsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba2383e6d9e47f1e1d39898c16bf99748e4d38e3
--- /dev/null
+++ b/models/official/nlp/modeling/layers/dense_einsum.py
@@ -0,0 +1,180 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras-based einsum layer."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+
+_CHR_IDX = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m"]
+
+
+@tf.keras.utils.register_keras_serializable(package="Text")
+class DenseEinsum(tf.keras.layers.Layer):
+ """A densely connected layer that uses tf.einsum as the backing computation.
+
+ This layer can perform einsum calculations of arbitrary dimensionality.
+
+ Arguments:
+ output_shape: Positive integer or tuple, dimensionality of the output space.
+ num_summed_dimensions: The number of dimensions to sum over. Standard 2D
+ matmul should use 1, 3D matmul should use 2, and so forth.
+ activation: Activation function to use. If you don't specify anything, no
+ activation is applied
+    (i.e. "linear" activation: `a(x) = x`).
+ use_bias: Boolean, whether the layer uses a bias vector.
+ kernel_initializer: Initializer for the `kernel` weights matrix.
+ bias_initializer: Initializer for the bias vector.
+ kernel_regularizer: Regularizer function applied to the `kernel` weights
+ matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
+ activity_regularizer: Regularizer function applied to the output of the
+      layer (its "activation").
+ kernel_constraint: Constraint function applied to the `kernel` weights
+ matrix.
+ bias_constraint: Constraint function applied to the bias vector.
+ Input shape:
+ N-D tensor with shape: `(batch_size, ..., input_dim)`. The most common
+ situation would be a 2D input with shape `(batch_size, input_dim)`.
+ Output shape:
+ N-D tensor with shape: `(batch_size, ..., units)`. For instance, for a 2D
+ input with shape `(batch_size, input_dim)`, the output would have shape
+ `(batch_size, units)`.
+ """
+
+ def __init__(self,
+ output_shape,
+ num_summed_dimensions=1,
+ activation=None,
+ use_bias=True,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ kernel_constraint=None,
+ bias_constraint=None,
+ **kwargs):
+ super(DenseEinsum, self).__init__(**kwargs)
+ self._output_shape = output_shape if isinstance(
+ output_shape, (list, tuple)) else (output_shape,)
+ self._activation = tf.keras.activations.get(activation)
+ self._use_bias = use_bias
+ self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
+ self._bias_initializer = tf.keras.initializers.get(bias_initializer)
+ self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
+ self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
+ self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
+ self._bias_constraint = tf.keras.constraints.get(bias_constraint)
+ self._num_summed_dimensions = num_summed_dimensions
+ self._einsum_string = None
+
+ def _build_einsum_string(self, free_input_dims, bound_dims, output_dims):
+ input_str = ""
+ kernel_str = ""
+ output_str = ""
+ letter_offset = 0
+ for i in range(free_input_dims):
+ char = _CHR_IDX[i + letter_offset]
+ input_str += char
+ output_str += char
+
+ letter_offset += free_input_dims
+ for i in range(bound_dims):
+ char = _CHR_IDX[i + letter_offset]
+ input_str += char
+ kernel_str += char
+
+ letter_offset += bound_dims
+ for i in range(output_dims):
+ char = _CHR_IDX[i + letter_offset]
+ kernel_str += char
+ output_str += char
+
+ return input_str + "," + kernel_str + "->" + output_str
+
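+  # Illustrative note (not part of the original file): with a rank-3 input of
+  # shape (batch, seq, 80), num_summed_dimensions=1 and output_shape=(64,),
+  # build() below creates the einsum string "abc,cd->abd" and a kernel of
+  # shape (80, 64), matching the cases exercised in dense_einsum_test.py.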
+ def build(self, input_shape):
+ input_shape = tf.TensorShape(input_shape)
+ input_rank = input_shape.rank
+ free_input_dims = input_rank - self._num_summed_dimensions
+ output_dims = len(self._output_shape)
+
+ self._einsum_string = self._build_einsum_string(free_input_dims,
+ self._num_summed_dimensions,
+ output_dims)
+
+ # This is only saved for testing purposes.
+ self._kernel_shape = (
+ input_shape[free_input_dims:].concatenate(self._output_shape))
+
+ self._kernel = self.add_weight(
+ "kernel",
+ shape=self._kernel_shape,
+ initializer=self._kernel_initializer,
+ regularizer=self._kernel_regularizer,
+ constraint=self._kernel_constraint,
+ dtype=self.dtype,
+ trainable=True)
+ if self._use_bias:
+ self._bias = self.add_weight(
+ "bias",
+ shape=self._output_shape,
+ initializer=self._bias_initializer,
+ regularizer=self._bias_regularizer,
+ constraint=self._bias_constraint,
+ dtype=self.dtype,
+ trainable=True)
+ else:
+ self._bias = None
+ super(DenseEinsum, self).build(input_shape)
+
+ def get_config(self):
+ config = {
+ "output_shape":
+ self._output_shape,
+ "num_summed_dimensions":
+ self._num_summed_dimensions,
+ "activation":
+ tf.keras.activations.serialize(self._activation),
+ "use_bias":
+ self._use_bias,
+ "kernel_initializer":
+ tf.keras.initializers.serialize(self._kernel_initializer),
+ "bias_initializer":
+ tf.keras.initializers.serialize(self._bias_initializer),
+ "kernel_regularizer":
+ tf.keras.regularizers.serialize(self._kernel_regularizer),
+ "bias_regularizer":
+ tf.keras.regularizers.serialize(self._bias_regularizer),
+ "activity_regularizer":
+ tf.keras.regularizers.serialize(self._activity_regularizer),
+ "kernel_constraint":
+ tf.keras.constraints.serialize(self._kernel_constraint),
+ "bias_constraint":
+ tf.keras.constraints.serialize(self._bias_constraint)
+ }
+ base_config = super(DenseEinsum, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ def call(self, inputs):
+ ret = tf.einsum(self._einsum_string, inputs, self._kernel)
+ if self._use_bias:
+ ret += self._bias
+ if self._activation is not None:
+ ret = self._activation(ret)
+ return ret
diff --git a/models/official/nlp/modeling/layers/dense_einsum_test.py b/models/official/nlp/modeling/layers/dense_einsum_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..57a60fe52fa835c09df228274d42ed7eb8f39595
--- /dev/null
+++ b/models/official/nlp/modeling/layers/dense_einsum_test.py
@@ -0,0 +1,123 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Keras-based einsum layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling.layers import dense_einsum
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class DenseEinsumLayer(keras_parameterized.TestCase):
+
+ def test_3D_einsum_with_two_bound_dimensions(self):
+ test_layer = dense_einsum.DenseEinsum(
+ output_shape=(64,), num_summed_dimensions=2)
+ # Create a 4-dimensional input (the first dimension is implicit).
+ input_tensor = tf.keras.Input(shape=(None, 40, 80))
+ _ = test_layer(input_tensor)
+ self.assertEqual(test_layer._einsum_string, "abcd,cde->abe")
+ self.assertEqual(test_layer._kernel_shape, (40, 80, 64))
+
+ def test_3D_einsum_with_one_bound_dimensions(self):
+ test_layer = dense_einsum.DenseEinsum(
+ output_shape=(64, 32), num_summed_dimensions=1)
+ # Create a 3-dimensional input (the first dimension is implicit).
+ input_tensor = tf.keras.Input(shape=(None, 80))
+ _ = test_layer(input_tensor)
+ self.assertEqual(test_layer._einsum_string, "abc,cde->abde")
+ self.assertEqual(test_layer._kernel_shape, (80, 64, 32))
+
+ def test_2D_einsum_with_one_bound_dimensions(self):
+ test_layer = dense_einsum.DenseEinsum(
+ output_shape=(64,), num_summed_dimensions=1)
+ # Create a 3-dimensional input (the first dimension is implicit).
+ input_tensor = tf.keras.Input(shape=(None, 80))
+ _ = test_layer(input_tensor)
+ self.assertEqual(test_layer._einsum_string, "abc,cd->abd")
+ self.assertEqual(test_layer._kernel_shape, (80, 64))
+
+ def test_bias_term_can_be_disabled(self):
+ # A layer created using the bias should have two weights.
+ test_layer = dense_einsum.DenseEinsum(
+ output_shape=64, num_summed_dimensions=1, use_bias=True)
+ input_tensor = tf.keras.Input(shape=(None, 80))
+ _ = test_layer(input_tensor)
+ self.assertEqual(2, len(test_layer.get_weights()))
+
+ # A layer created without the bias should have only one weight.
+ test_layer = dense_einsum.DenseEinsum(
+ output_shape=64, num_summed_dimensions=1, use_bias=False)
+ input_tensor = tf.keras.Input(shape=(None, 80))
+ _ = test_layer(input_tensor)
+ self.assertEqual(1, len(test_layer.get_weights()))
+
+ def test_activation(self):
+ # Create a model that does not use an activation.
+ no_activation_layer = dense_einsum.DenseEinsum(
+ output_shape=64, num_summed_dimensions=1, activation=None)
+ input_tensor = tf.keras.Input(shape=(None, 80))
+ output_tensor = no_activation_layer(input_tensor)
+ no_activation_model = tf.keras.Model(input_tensor, output_tensor)
+
+ # Create a model that uses a softmax activation.
+ activation_layer = dense_einsum.DenseEinsum(
+ output_shape=64, num_summed_dimensions=1, activation="softmax")
+ input_tensor = tf.keras.Input(shape=(None, 80))
+ output_tensor = activation_layer(input_tensor)
+ activation_model = tf.keras.Model(input_tensor, output_tensor)
+
+ # Make sure the models' weights are identical.
+ activation_model.set_weights(no_activation_model.get_weights())
+
+ # Predict using each model on the same input data. The output should be
+ # different, since one is using a softmax - even though the models' weights
+ # are the same.
+ input_values = 10 * np.random.random_sample((10, 4, 80))
+ non_activated_data = no_activation_model.predict(input_values)
+ activated_data = activation_model.predict(input_values)
+ self.assertNotAllClose(activated_data, non_activated_data)
+
+ def test_non_iterable_output_shape(self):
+ test_layer = dense_einsum.DenseEinsum(
+ output_shape=64, num_summed_dimensions=1)
+ # Create a 3-dimensional input (the first dimension is implicit).
+ input_tensor = tf.keras.Input(shape=(None, 80))
+ _ = test_layer(input_tensor)
+ self.assertEqual(test_layer._einsum_string, "abc,cd->abd")
+ self.assertEqual(test_layer._kernel_shape, (80, 64))
+
+ def test_with_explicit_initializer(self):
+ test_layer = dense_einsum.DenseEinsum(
+ output_shape=(64,),
+ num_summed_dimensions=2,
+ kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
+ # Create a 4-dimensional input (the first dimension is implicit).
+ input_tensor = tf.keras.Input(shape=(None, 40, 80))
+ _ = test_layer(input_tensor)
+ self.assertEqual(test_layer._einsum_string, "abcd,cde->abe")
+ self.assertEqual(test_layer._kernel_shape, (40, 80, 64))
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/modeling/layers/gated_feedforward.py b/models/official/nlp/modeling/layers/gated_feedforward.py
new file mode 100644
index 0000000000000000000000000000000000000000..11c912885a7b8eb68e6d764653275fb2b5d2de92
--- /dev/null
+++ b/models/official/nlp/modeling/layers/gated_feedforward.py
@@ -0,0 +1,210 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras-based gated feedforward layer."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import gin
+import tensorflow as tf
+
+
+@tf.keras.utils.register_keras_serializable(package="Text")
+@gin.configurable
+class GatedFeedforward(tf.keras.layers.Layer):
+ """Gated linear feedforward layer.
+
+ This layer follows the paper "GLU Variants Improve Transformer"
+  (https://arxiv.org/abs/2002.05202). In addition, it allows stacking multiple
+  feedforward blocks and specifying the position of the dropout layer.
+
+ Arguments:
+ intermediate_size: Size of the intermediate layer.
+ intermediate_activation: Activation for the intermediate layer.
+ dropout: Dropout probability for the output dropout.
+ use_gate: Whether to use gated linear units. If True, assuming `GELU` as
+ the activation and omitting bias, will apply
+      `GEGLU(x, W, V, W_2) = (GELU(xW) * xV)W_2`; if False, will follow the
+      "Attention Is All You Need" (https://arxiv.org/abs/1706.03762) paper
+      and apply `FFN(x, W_1, W_2) = GELU(xW_1)W_2`.
+ num_blocks: The number of feedforward blocks to stack. Each block contains
+ a (gated) linear layer and a fully connected layer followed by dropout,
+ layer norm and residual.
+ dropout_position: Where to apply the dropout, the value can be either
+ `before_residual` or `after_residual`. If `before_residual`, will apply
+ `layer_output = layer_norm(dropout(layer_output) + layer_input)`;
+      if `after_residual`, will apply
+ `layer_output = dropout(layer_norm(layer_output + layer_input))`.
+ kernel_initializer: Initializer for dense layer kernels.
+ bias_initializer: Initializer for dense layer biases.
+ kernel_regularizer: Regularizer for dense layer kernels.
+ bias_regularizer: Regularizer for dense layer biases.
+ activity_regularizer: Regularizer for dense layer activity.
+ kernel_constraint: Constraint for dense layer kernels.
+    bias_constraint: Constraint for dense layer biases.
+ """
+
+ def __init__(self,
+ intermediate_size,
+ intermediate_activation,
+ dropout,
+ use_gate=True,
+ num_blocks=1,
+ dropout_position="before_residual",
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ kernel_constraint=None,
+ bias_constraint=None,
+ **kwargs):
+ super(GatedFeedforward, self).__init__(**kwargs)
+ self._intermediate_size = intermediate_size
+ self._intermediate_activation = intermediate_activation
+ self._dropout = dropout
+ self._use_gate = use_gate
+ self._num_blocks = num_blocks
+ self._dropout_position = dropout_position
+ if self._dropout_position not in ("before_residual", "after_residual"):
+ raise ValueError(
+ "The dropout_position should be either `before_residual` or"
+ "`after_residual`, got: %s" % self._dropout_position)
+
+ self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
+ self._bias_initializer = tf.keras.initializers.get(bias_initializer)
+ self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
+ self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
+ self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
+ self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
+ self._bias_constraint = tf.keras.constraints.get(bias_constraint)
+
+ def build(self, input_shape):
+ hidden_size = input_shape.as_list()[-1]
+
+ common_kwargs = dict(
+ kernel_initializer=self._kernel_initializer,
+ bias_initializer=self._bias_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activity_regularizer=self._activity_regularizer,
+ kernel_constraint=self._kernel_constraint,
+ bias_constraint=self._bias_constraint)
+ self._intermediate_dense = []
+ self._intermediate_activation_layers = []
+ self._gate_dense = []
+ self._output_dense = []
+ self._output_dropout = []
+ self._output_layer_norm = []
+ activation_policy = tf.keras.mixed_precision.experimental.global_policy()
+ if activation_policy.name == "mixed_bfloat16":
+ # bfloat16 causes BERT with the LAMB optimizer to not converge
+ # as well, so we use float32.
+ # TODO(b/154538392): Investigate this.
+ activation_policy = tf.float32
+ for i in range(self._num_blocks):
+ self._intermediate_dense.append(
+ tf.keras.layers.experimental.EinsumDense(
+ "abc,cd->abd",
+ output_shape=(None, self._intermediate_size),
+ bias_axes="d",
+ name="intermediate_%d" % i,
+ **common_kwargs))
+ self._intermediate_activation_layers.append(tf.keras.layers.Activation(
+ self._intermediate_activation, dtype=activation_policy))
+ if self._use_gate:
+ self._gate_dense.append(
+ tf.keras.layers.experimental.EinsumDense(
+ "abc,cd->abd",
+ output_shape=(None, self._intermediate_size),
+ bias_axes="d",
+ name="gate_%d" % i,
+ **common_kwargs))
+ self._output_dense.append(
+ tf.keras.layers.experimental.EinsumDense(
+ "abc,cd->abd",
+ output_shape=(None, hidden_size),
+ bias_axes="d",
+ name="output_%d" % i,
+ **common_kwargs))
+ self._output_dropout.append(
+ tf.keras.layers.Dropout(rate=self._dropout))
+ # Use float32 in layernorm for numeric stability.
+ self._output_layer_norm.append(
+ tf.keras.layers.LayerNormalization(
+ name="output_layer_norm_%d" % i,
+ axis=-1,
+ epsilon=1e-12,
+ dtype=tf.float32))
+
+ def get_config(self):
+ config = {
+ "intermediate_size":
+ self._intermediate_size,
+ "intermediate_activation":
+ self._intermediate_activation,
+ "dropout":
+ self._dropout,
+ "use_gate":
+ self._use_gate,
+ "num_blocks":
+ self._num_blocks,
+ "dropout_position":
+ self._dropout_position,
+ "kernel_initializer":
+ tf.keras.initializers.serialize(self._kernel_initializer),
+ "bias_initializer":
+ tf.keras.initializers.serialize(self._bias_initializer),
+ "kernel_regularizer":
+ tf.keras.regularizers.serialize(self._kernel_regularizer),
+ "bias_regularizer":
+ tf.keras.regularizers.serialize(self._bias_regularizer),
+ "activity_regularizer":
+ tf.keras.regularizers.serialize(self._activity_regularizer),
+ "kernel_constraint":
+ tf.keras.constraints.serialize(self._kernel_constraint),
+ "bias_constraint":
+ tf.keras.constraints.serialize(self._bias_constraint)
+ }
+ base_config = super(GatedFeedforward, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ def call(self, inputs):
+ layer_output = inputs
+ for i in range(self._num_blocks):
+ layer_input = layer_output
+ intermediate_output = self._intermediate_dense[i](layer_input)
+ intermediate_output = self._intermediate_activation_layers[i](
+ intermediate_output)
+ if self._use_gate:
+ gated_linear = self._gate_dense[i](layer_input)
+ intermediate_output = intermediate_output * gated_linear
+
+ layer_output = self._output_dense[i](intermediate_output)
+ if self._dropout_position == "before_residual":
+ layer_output = self._output_dropout[i](layer_output)
+
+ # During mixed precision training, `layer_input` may be from layer norm.
+ # If so, it is always fp32. Cast layer_output to fp32 for the subsequent
+ # add.
+ if layer_input.dtype == tf.float32:
+ layer_output = tf.cast(layer_output, tf.float32)
+ layer_output = self._output_layer_norm[i](layer_output + layer_input)
+ if self._dropout_position == "after_residual":
+ layer_output = self._output_dropout[i](layer_output)
+
+ return layer_output
diff --git a/models/official/nlp/modeling/layers/gated_feedforward_test.py b/models/official/nlp/modeling/layers/gated_feedforward_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8daeb5d32fde9be2765fe3819b13ee9a13546f55
--- /dev/null
+++ b/models/official/nlp/modeling/layers/gated_feedforward_test.py
@@ -0,0 +1,127 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Keras-based gated feedforward layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling.layers import gated_feedforward
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class GatedFeedforwardTest(keras_parameterized.TestCase):
+
+ def tearDown(self):
+ super(GatedFeedforwardTest, self).tearDown()
+ tf.keras.mixed_precision.experimental.set_policy("float32")
+
+ @parameterized.parameters(
+ (True, 1, "after_residual", "float32"),
+ (True, 1, "after_residual", "mixed_float16"),
+ (False, 4, "before_residual", "float32"),
+ (False, 4, "before_residual", "mixed_float16"),
+ (True, 4, "after_residual", "float32"),
+ (True, 4, "after_residual", "mixed_float16"),
+ (False, 1, "before_residual", "float32"),
+ (False, 1, "before_residual", "mixed_float16"),
+ )
+ def test_layer_creation(self, use_gate, num_blocks, dropout_position, dtype):
+ tf.keras.mixed_precision.experimental.set_policy(dtype)
+ kwargs = dict(
+ intermediate_size=128,
+ intermediate_activation="relu",
+ dropout=0.1,
+ use_gate=use_gate,
+ num_blocks=num_blocks,
+ dropout_position=dropout_position,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros")
+ test_layer = gated_feedforward.GatedFeedforward(**kwargs)
+
+ sequence_length = 64
+ width = 128
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ output_tensor = test_layer(data_tensor)
+    # The output of the feedforward layer should have the same shape as the input.
+ self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
+
+ @parameterized.parameters(
+ (True, 1, "after_residual", "float32"),
+ (True, 1, "after_residual", "mixed_float16"),
+ (False, 4, "before_residual", "float32"),
+ (False, 4, "before_residual", "mixed_float16"),
+ (True, 4, "after_residual", "float32"),
+ (True, 4, "after_residual", "mixed_float16"),
+ (False, 1, "before_residual", "float32"),
+ (False, 1, "before_residual", "mixed_float16"),
+ )
+ def test_layer_invocation(self, use_gate, num_blocks, dropout_position,
+ dtype):
+ tf.keras.mixed_precision.experimental.set_policy(dtype)
+ kwargs = dict(
+ intermediate_size=16,
+ intermediate_activation="relu",
+ dropout=0.1,
+ use_gate=use_gate,
+ num_blocks=num_blocks,
+ dropout_position=dropout_position,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros")
+ test_layer = gated_feedforward.GatedFeedforward(**kwargs)
+
+ sequence_length = 16
+ width = 32
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ output_tensor = test_layer(data_tensor)
+
+ # Create a model from the test layer.
+ model = tf.keras.Model(data_tensor, output_tensor)
+
+ # Invoke the model on test data.
+ batch_size = 6
+ input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, width))
+ output_data = model.predict(input_data)
+ self.assertEqual(output_data.shape, (batch_size, sequence_length, width))
+
+ def test_serialize_deserialize(self):
+ kwargs = dict(
+ intermediate_size=16,
+ intermediate_activation="relu",
+ dropout=0.1,
+ use_gate=False,
+ num_blocks=4,
+ dropout_position="after_residual",
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros")
+ test_layer = gated_feedforward.GatedFeedforward(**kwargs)
+ new_layer = gated_feedforward.GatedFeedforward.from_config(
+ test_layer.get_config())
+
+ # If the serialization was successful, the new config should match the old.
+ self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/modeling/layers/masked_lm.py b/models/official/nlp/modeling/layers/masked_lm.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b81556f4c7d82e79c9d9cda4894a26fde6a93f7
--- /dev/null
+++ b/models/official/nlp/modeling/layers/masked_lm.py
@@ -0,0 +1,124 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Masked language model network."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+
+from official.modeling import tf_utils
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+class MaskedLM(tf.keras.layers.Layer):
+ """Masked language model network head for BERT modeling.
+
+  This layer implements a masked language model head using the provided
+  embedding table, typically obtained from an encoder network's
+  "get_embedding_table()" method.
+
+ Arguments:
+ embedding_table: The embedding table of the targets.
+ activation: The activation, if any, for the dense layer.
+    initializer: The initializer for the dense layer. Defaults to a Glorot
+ uniform initializer.
+ output: The output style for this network. Can be either 'logits' or
+ 'predictions'.
+ """
+
+ def __init__(self,
+ embedding_table,
+ activation=None,
+ initializer='glorot_uniform',
+ output='logits',
+ name='cls/predictions',
+ **kwargs):
+ super(MaskedLM, self).__init__(name=name, **kwargs)
+ self.embedding_table = embedding_table
+ self.activation = activation
+ self.initializer = tf.keras.initializers.get(initializer)
+
+ if output not in ('predictions', 'logits'):
+ raise ValueError(
+ ('Unknown `output` value "%s". `output` can be either "logits" or '
+ '"predictions"') % output)
+ self._output_type = output
+
+ def build(self, input_shape):
+ self._vocab_size, hidden_size = self.embedding_table.shape
+ self.dense = tf.keras.layers.Dense(
+ hidden_size,
+ activation=self.activation,
+ kernel_initializer=self.initializer,
+ name='transform/dense')
+ self.layer_norm = tf.keras.layers.LayerNormalization(
+ axis=-1, epsilon=1e-12, name='transform/LayerNorm')
+ self.bias = self.add_weight(
+ 'output_bias/bias',
+ shape=(self._vocab_size,),
+ initializer='zeros',
+ trainable=True)
+
+ super(MaskedLM, self).build(input_shape)
+
+ def call(self, sequence_data, masked_positions):
+ masked_lm_input = self._gather_indexes(sequence_data, masked_positions)
+ lm_data = self.dense(masked_lm_input)
+ lm_data = self.layer_norm(lm_data)
+ lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
+ logits = tf.nn.bias_add(lm_data, self.bias)
+
+ masked_positions_shape = tf_utils.get_shape_list(
+ masked_positions, name='masked_positions_tensor')
+ logits = tf.reshape(logits,
+ [-1, masked_positions_shape[1], self._vocab_size])
+ if self._output_type == 'logits':
+ return logits
+ return tf.nn.log_softmax(logits)
+
+ def get_config(self):
+ raise NotImplementedError('MaskedLM cannot be directly serialized because '
+ 'it has variable sharing logic.')
+
+ def _gather_indexes(self, sequence_tensor, positions):
+ """Gathers the vectors at the specific positions.
+
+ Args:
+ sequence_tensor: Sequence output of `BertModel` layer of shape
+ (`batch_size`, `seq_length`, num_hidden) where num_hidden is number of
+ hidden units of `BertModel` layer.
+      positions: Position ids of the tokens in the sequence to mask for
+        pretraining, with dimension (batch_size, num_predictions) where
+        `num_predictions` is the maximum number of tokens to mask out and
+        predict per sequence.
+
+ Returns:
+ Masked out sequence tensor of shape (batch_size * num_predictions,
+ num_hidden).
+ """
+ sequence_shape = tf_utils.get_shape_list(
+ sequence_tensor, name='sequence_output_tensor')
+ batch_size, seq_length, width = sequence_shape
+
+ flat_offsets = tf.reshape(
+ tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
+ flat_positions = tf.reshape(positions + flat_offsets, [-1])
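+    # Worked example (illustrative): with batch_size=2, seq_length=3 and
+    # positions=[[0, 2], [1, 2]], flat_offsets is [[0], [3]], so flat_positions
+    # becomes [0, 2, 4, 5], i.e. row-major indices into the flattened
+    # (batch_size * seq_length, width) tensor gathered below.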
+ flat_sequence_tensor = tf.reshape(sequence_tensor,
+ [batch_size * seq_length, width])
+ output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
+
+ return output_tensor
diff --git a/models/official/nlp/modeling/layers/masked_lm_test.py b/models/official/nlp/modeling/layers/masked_lm_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..12e28ec95ff49c95c2729efeae04382bad5c611f
--- /dev/null
+++ b/models/official/nlp/modeling/layers/masked_lm_test.py
@@ -0,0 +1,162 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for masked language model network."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+
+from official.nlp.modeling.layers import masked_lm
+from official.nlp.modeling.networks import transformer_encoder
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class MaskedLMTest(keras_parameterized.TestCase):
+
+ def create_layer(self,
+ vocab_size,
+ sequence_length,
+ hidden_size,
+ output='predictions',
+ xformer_stack=None):
+ # First, create a transformer stack that we can use to get the LM's
+ # vocabulary weight.
+ if xformer_stack is None:
+ xformer_stack = transformer_encoder.TransformerEncoder(
+ vocab_size=vocab_size,
+ num_layers=1,
+ sequence_length=sequence_length,
+ hidden_size=hidden_size,
+ num_attention_heads=4,
+ )
+
+ # Create a maskedLM from the transformer stack.
+ test_layer = masked_lm.MaskedLM(
+ embedding_table=xformer_stack.get_embedding_table(),
+ output=output)
+ return test_layer
+
+ def test_layer_creation(self):
+ vocab_size = 100
+ sequence_length = 32
+ hidden_size = 64
+ num_predictions = 21
+ test_layer = self.create_layer(
+ vocab_size=vocab_size,
+ sequence_length=sequence_length,
+ hidden_size=hidden_size)
+
+ # Make sure that the output tensor of the masked LM is the right shape.
+ lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
+ masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
+ output = test_layer(lm_input_tensor, masked_positions=masked_positions)
+
+ expected_output_shape = [None, num_predictions, vocab_size]
+ self.assertEqual(expected_output_shape, output.shape.as_list())
+
+ def test_layer_invocation_with_external_logits(self):
+ vocab_size = 100
+ sequence_length = 32
+ hidden_size = 64
+ num_predictions = 21
+ xformer_stack = transformer_encoder.TransformerEncoder(
+ vocab_size=vocab_size,
+ num_layers=1,
+ sequence_length=sequence_length,
+ hidden_size=hidden_size,
+ num_attention_heads=4,
+ )
+ test_layer = self.create_layer(
+ vocab_size=vocab_size,
+ sequence_length=sequence_length,
+ hidden_size=hidden_size,
+ xformer_stack=xformer_stack,
+ output='predictions')
+ logit_layer = self.create_layer(
+ vocab_size=vocab_size,
+ sequence_length=sequence_length,
+ hidden_size=hidden_size,
+ xformer_stack=xformer_stack,
+ output='logits')
+
+ # Create a model from the masked LM layer.
+ lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
+ masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
+ output = test_layer(lm_input_tensor, masked_positions)
+ logit_output = logit_layer(lm_input_tensor, masked_positions)
+ logit_output = tf.keras.layers.Activation(tf.nn.log_softmax)(logit_output)
+ logit_layer.set_weights(test_layer.get_weights())
+ model = tf.keras.Model([lm_input_tensor, masked_positions], output)
+ logits_model = tf.keras.Model(([lm_input_tensor, masked_positions]),
+ logit_output)
+
+ # Invoke the masked LM on some fake data to make sure there are no runtime
+ # errors in the code.
+ batch_size = 3
+ lm_input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, hidden_size))
+ masked_position_data = np.random.randint(
+ sequence_length, size=(batch_size, num_predictions))
+ # ref_outputs = model.predict([lm_input_data, masked_position_data])
+ # outputs = logits_model.predict([lm_input_data, masked_position_data])
+ ref_outputs = model([lm_input_data, masked_position_data])
+ outputs = logits_model([lm_input_data, masked_position_data])
+
+ # Ensure that the tensor shapes are correct.
+ expected_output_shape = (batch_size, num_predictions, vocab_size)
+ self.assertEqual(expected_output_shape, ref_outputs.shape)
+ self.assertEqual(expected_output_shape, outputs.shape)
+ self.assertAllClose(ref_outputs, outputs)
+
+ def test_layer_invocation(self):
+ vocab_size = 100
+ sequence_length = 32
+ hidden_size = 64
+ num_predictions = 21
+ test_layer = self.create_layer(
+ vocab_size=vocab_size,
+ sequence_length=sequence_length,
+ hidden_size=hidden_size)
+
+ # Create a model from the masked LM layer.
+ lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
+ masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
+ output = test_layer(lm_input_tensor, masked_positions)
+ model = tf.keras.Model([lm_input_tensor, masked_positions], output)
+
+ # Invoke the masked LM on some fake data to make sure there are no runtime
+ # errors in the code.
+ batch_size = 3
+ lm_input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, hidden_size))
+ masked_position_data = np.random.randint(
+ 2, size=(batch_size, num_predictions))
+ _ = model.predict([lm_input_data, masked_position_data])
+
+ def test_unknown_output_type_fails(self):
+ with self.assertRaisesRegex(ValueError, 'Unknown `output` value "bad".*'):
+ _ = self.create_layer(
+ vocab_size=8, sequence_length=8, hidden_size=8, output='bad')
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/nlp/modeling/layers/masked_softmax.py b/models/official/nlp/modeling/layers/masked_softmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..42a9e97a329e6c2892bb584f38375888a7fbdd2f
--- /dev/null
+++ b/models/official/nlp/modeling/layers/masked_softmax.py
@@ -0,0 +1,72 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras-based softmax layer with optional masking."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+class MaskedSoftmax(tf.keras.layers.Layer):
+ """Performs a softmax with optional masking on a tensor.
+
+ Arguments:
+    mask_expansion_axes: Any axes along which the mask tensor should be
+      expanded so that it can be broadcast against the scores.
+    normalization_axes: The axes over which the softmax is performed. Defaults
+      to the last axis.
+ """
+
+ def __init__(self,
+ mask_expansion_axes=None,
+ normalization_axes=None,
+ **kwargs):
+ self._mask_expansion_axes = mask_expansion_axes
+ if normalization_axes is None:
+ self._normalization_axes = (-1,)
+ else:
+ self._normalization_axes = normalization_axes
+ super(MaskedSoftmax, self).__init__(**kwargs)
+
+ def call(self, scores, mask=None):
+
+ if mask is not None:
+ for _ in range(len(scores.shape) - len(mask.shape)):
+ mask = tf.expand_dims(mask, axis=self._mask_expansion_axes)
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ adder = (1.0 - tf.cast(mask, scores.dtype)) * -10000.0
+
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ scores += adder
+
+ if len(self._normalization_axes) == 1:
+ return tf.nn.softmax(scores, axis=self._normalization_axes[0])
+ else:
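+      # Softmax over several axes at once, using the identity
+      # softmax(x) = exp(x - logsumexp(x)), with keepdims so the result
+      # broadcasts against `scores`.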
+ return tf.math.exp(scores - tf.math.reduce_logsumexp(
+ scores, axis=self._normalization_axes, keepdims=True))
+
+ def get_config(self):
+ config = {
+ 'mask_expansion_axes': self._mask_expansion_axes,
+ 'normalization_axes': self._normalization_axes
+ }
+ base_config = super(MaskedSoftmax, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
diff --git a/models/official/nlp/modeling/layers/masked_softmax_test.py b/models/official/nlp/modeling/layers/masked_softmax_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..befe0f786a7b4d84a5dc975d1780acdd2c964a2c
--- /dev/null
+++ b/models/official/nlp/modeling/layers/masked_softmax_test.py
@@ -0,0 +1,119 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Keras-based masked softmax layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling.layers import masked_softmax
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class MaskedSoftmaxLayerTest(keras_parameterized.TestCase):
+
+ def test_non_masked_softmax(self):
+ test_layer = masked_softmax.MaskedSoftmax()
+ input_tensor = tf.keras.Input(shape=(4, 8))
+ output = test_layer(input_tensor)
+ model = tf.keras.Model(input_tensor, output)
+
+ input_data = 10 * np.random.random_sample((3, 4, 8))
+ output_data = model.predict(input_data)
+ expected_data = tf.nn.softmax(input_data)
+ self.assertAllClose(expected_data, output_data)
+
+ def test_masked_softmax(self):
+ test_layer = masked_softmax.MaskedSoftmax()
+ input_tensor = tf.keras.Input(shape=(4, 8))
+ mask_tensor = tf.keras.Input(shape=(4, 8))
+ output = test_layer(input_tensor, mask_tensor)
+ model = tf.keras.Model([input_tensor, mask_tensor], output)
+
+ input_data = 10 * np.random.random_sample((3, 4, 8))
+ mask_data = np.random.randint(2, size=(3, 4, 8))
+
+ output_data = model.predict([input_data, mask_data])
+ expected_zeros = np.greater(mask_data, 0)
+ is_zeros = np.greater(output_data, 0)
+ self.assertAllEqual(expected_zeros, is_zeros)
+
+ def test_masked_softmax_with_none_mask(self):
+ test_layer = masked_softmax.MaskedSoftmax()
+ input_tensor = tf.keras.Input(shape=(4, 8))
+ output = test_layer(input_tensor, None)
+ model = tf.keras.Model(input_tensor, output)
+
+ input_data = 10 * np.random.random_sample((3, 4, 8))
+ output_data = model.predict(input_data)
+ expected_data = tf.nn.softmax(input_data)
+ self.assertAllClose(expected_data, output_data)
+
+ def test_softmax_with_axes_expansion(self):
+ test_layer = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])
+ input_tensor = tf.keras.Input(shape=(4, 8))
+ mask_tensor = tf.keras.Input(shape=(8))
+ output = test_layer(input_tensor, mask_tensor)
+ model = tf.keras.Model([input_tensor, mask_tensor], output)
+
+ input_data = 10 * np.random.random_sample((3, 4, 8))
+ mask_data = np.random.randint(2, size=(3, 8))
+
+ output_data = model.predict([input_data, mask_data])
+ expanded_mask = np.expand_dims(mask_data, axis=1) * np.ones_like(input_data)
+ expected_zeros = np.greater(expanded_mask, 0)
+ is_zeros = np.greater(output_data, 0)
+ self.assertAllEqual(expected_zeros, is_zeros)
+
+ def test_masked_softmax_high_dims(self):
+ test_layer = masked_softmax.MaskedSoftmax(
+ mask_expansion_axes=[1], normalization_axes=[6, 7])
+ input_shape = [2, 3, 4, 5, 6, 7, 8]
+ mask_shape = [5, 6, 7, 8]
+ input_tensor = tf.keras.Input(shape=input_shape)
+ mask_tensor = tf.keras.Input(shape=mask_shape)
+ output = test_layer(input_tensor, mask_tensor)
+ model = tf.keras.Model([input_tensor, mask_tensor], output)
+
+ input_data = 10 * np.random.random_sample([3] + input_shape)
+ mask_data = np.random.randint(2, size=[3] + mask_shape)
+
+ output_data = model.predict([input_data, mask_data])
+ expanded_mask = np.expand_dims(mask_data, axis=1)
+ expanded_mask = np.expand_dims(expanded_mask, axis=1)
+ expanded_mask = np.expand_dims(
+ expanded_mask, axis=1) * np.ones_like(input_data)
+ expected_zeros = np.greater(expanded_mask, 0)
+ is_zeros = np.greater(output_data, 0)
+ self.assertAllEqual(expected_zeros, is_zeros)
+
+ def test_serialize_deserialize(self):
+ test_layer = masked_softmax.MaskedSoftmax(
+ mask_expansion_axes=[1], normalization_axes=[6, 7])
+ new_layer = masked_softmax.MaskedSoftmax.from_config(
+ test_layer.get_config())
+
+ # If the serialization was successful, the new config should match the old.
+ self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/nlp/modeling/layers/multi_channel_attention.py b/models/official/nlp/modeling/layers/multi_channel_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..499d977c753518f0892267ac98abc6bf7618c2cd
--- /dev/null
+++ b/models/official/nlp/modeling/layers/multi_channel_attention.py
@@ -0,0 +1,165 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Multi-channel Attention."""
+# pylint: disable=g-classes-have-attributes
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import math
+
+import tensorflow as tf
+from official.modeling import tf_utils
+from official.nlp.modeling.layers import attention
+from official.nlp.modeling.layers import dense_einsum
+from official.nlp.modeling.layers import masked_softmax
+
+
+class VotingAttention(tf.keras.layers.Layer):
+ """Voting Attention layer.
+
+ Arguments:
+ num_heads: the number of attention heads.
+ head_size: per-head hidden size.
+ kernel_initializer: Initializer for dense layer kernels.
+ bias_initializer: Initializer for dense layer biases.
+ kernel_regularizer: Regularizer for dense layer kernels.
+ bias_regularizer: Regularizer for dense layer biases.
+ activity_regularizer: Regularizer for dense layer activity.
+ kernel_constraint: Constraint for dense layer kernels.
+    bias_constraint: Constraint for dense layer biases.
+ """
+
+ def __init__(self,
+ num_heads,
+ head_size,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ kernel_constraint=None,
+ bias_constraint=None,
+ **kwargs):
+ super(VotingAttention, self).__init__(**kwargs)
+ self._num_heads = num_heads
+ self._head_size = head_size
+ self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
+ self._bias_initializer = tf.keras.initializers.get(bias_initializer)
+ self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
+    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
+    self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
+ self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
+ self._bias_constraint = tf.keras.constraints.get(bias_constraint)
+
+ def build(self, unused_input_shapes):
+ self._query_dense = dense_einsum.DenseEinsum(
+ output_shape=(self._num_heads, self._head_size),
+ kernel_initializer=self._kernel_initializer,
+ bias_initializer=self._bias_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activity_regularizer=self._activity_regularizer,
+ kernel_constraint=self._kernel_constraint,
+ bias_constraint=self._bias_constraint,
+ dtype=self.dtype,
+ name="encdocatt_query")
+ self._key_dense = dense_einsum.DenseEinsum(
+ output_shape=(self._num_heads, self._head_size),
+ kernel_initializer=self._kernel_initializer,
+ bias_initializer=self._bias_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activity_regularizer=self._activity_regularizer,
+ kernel_constraint=self._kernel_constraint,
+ bias_constraint=self._bias_constraint,
+ dtype=self.dtype,
+ name="encdocatt_key")
+ super(VotingAttention, self).build(unused_input_shapes)
+
+ def call(self, encoder_outputs, doc_attention_mask):
+ num_docs = tf_utils.get_shape_list(encoder_outputs, expected_rank=[4])[1]
+ cls_embeddings = encoder_outputs[:, :, 0, :]
+ key = self._key_dense(cls_embeddings)
+ query = self._query_dense(cls_embeddings)
+ doc_attention_mask = tf.cast(doc_attention_mask, tf.float32)
+
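+    # Scalar dimensions below: B = batch size, A = num_docs, N = num_heads,
+    # H = head_size. Padded documents are zeroed out via the doc mask, every
+    # pair of documents is scored, the diagonal (self-votes) is removed, and
+    # the pairwise scores are summed into a per-document score before the
+    # masked softmax.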
+ key = tf.einsum("BANH,BA->BANH", key, doc_attention_mask)
+ query = tf.einsum("BANH,BA->BANH", query, doc_attention_mask)
+ attention_matrix = tf.einsum("BXNH,BYNH->BNXY", query, key)
+ mask = tf.ones([num_docs, num_docs])
+ mask = tf.linalg.set_diag(mask, tf.zeros(num_docs))
+ attention_matrix = tf.einsum("BNXY,XY->BNXY", attention_matrix, mask)
+ doc_attention_probs = tf.einsum("BNAY->BNA", attention_matrix)
+ doc_attention_probs = tf.einsum("BNA->BA", doc_attention_probs)
+ infadder = (1.0 - doc_attention_mask) * -100000.0
+ return tf.nn.softmax(doc_attention_probs + infadder)
+
+
+class MultiChannelAttention(attention.MultiHeadAttention):
+ """Multi-channel Attention layer.
+
+ Introduced in: https://arxiv.org/abs/2001.09386. Expects multiple
+ cross-attention target sequences.
+ """
+
+ def build(self, input_shape):
+ super(MultiChannelAttention, self).build(input_shape)
+ self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2])
+
+ def call(self, inputs, attention_mask=None):
+ from_tensor = inputs[0]
+ to_tensor = inputs[1]
+ doc_attention_probs = inputs[2]
+
+ # Scalar dimensions referenced here:
+ # B = batch size (number of stories)
+ # A = num_docs (number of docs)
+ # F = `from_tensor` sequence length
+ # T = `to_tensor` sequence length
+ # N = `num_attention_heads`
+ # H = `size_per_head`
+    # `query_tensor` = [B, F, N, H]
+ query_tensor = self._query_dense(from_tensor)
+
+ # `key_tensor` = [B, A, T, N, H]
+ key_tensor = self._key_dense(to_tensor)
+
+ # `value_tensor` = [B, A, T, N, H]
+ value_tensor = self._value_dense(to_tensor)
+
+ # Take the dot product between "query" and "key" to get the raw
+ # attention scores.
+ attention_scores = tf.einsum("BATNH,BFNH->BANFT", key_tensor, query_tensor)
+ attention_scores = tf.multiply(attention_scores,
+ 1.0 / math.sqrt(float(self._key_size)))
+
+ # Normalize the attention scores to probabilities.
+ # `attention_probs` = [B, A, N, F, T]
+ attention_probs = self._masked_softmax(attention_scores, attention_mask)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self._dropout_layer(attention_probs)
+
+    # `context_layer` = [B, A, F, N, H]
+ context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
+ value_tensor)
+ attention_output = tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs,
+ context_layer)
+ attention_output = self._output_dense(attention_output)
+ return attention_output
diff --git a/models/official/nlp/modeling/layers/multi_channel_attention_test.py b/models/official/nlp/modeling/layers/multi_channel_attention_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab6e0e7fec48635d09e6e30c3ad247044ae9785f
--- /dev/null
+++ b/models/official/nlp/modeling/layers/multi_channel_attention_test.py
@@ -0,0 +1,56 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for nlp.nhnet.multi_channel_attention."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from official.nlp.modeling.layers import multi_channel_attention
+
+
+class MultiChannelAttentionTest(tf.test.TestCase):
+
+ def test_doc_attention(self):
+ num_heads = 2
+ doc_attention = multi_channel_attention.VotingAttention(
+ num_heads, head_size=8)
+ num_docs = 3
+ inputs = np.zeros((2, num_docs, 10, 16), dtype=np.float32)
+ doc_mask = np.zeros((2, num_docs), dtype=np.float32)
+ outputs = doc_attention(inputs, doc_mask)
+ self.assertEqual(outputs.shape, (2, num_docs))
+
+ def test_multi_channel_attention(self):
+ num_heads = 2
+ num_docs = 5
+ attention_layer = multi_channel_attention.MultiChannelAttention(
+ num_heads, key_size=2)
+
+ from_data = 10 * np.random.random_sample((3, 4, 8))
+ to_data = 10 * np.random.random_sample((3, num_docs, 2, 8))
+ mask_data = np.random.randint(2, size=(3, num_docs, 4, 2))
+ doc_probs = np.random.randint(
+ 2, size=(3, num_heads, 4, num_docs)).astype(float)
+ outputs = attention_layer([from_data, to_data, doc_probs], mask_data)
+ self.assertEqual(outputs.shape, (3, 4, 8))
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/modeling/layers/on_device_embedding.py b/models/official/nlp/modeling/layers/on_device_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..739cdb7e4dde157ef52d7a98769a4c40819634a7
--- /dev/null
+++ b/models/official/nlp/modeling/layers/on_device_embedding.py
@@ -0,0 +1,88 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras-based one-hot embedding layer."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+@tf.keras.utils.register_keras_serializable(package="Text")
+class OnDeviceEmbedding(tf.keras.layers.Layer):
+ """Performs an embedding lookup suitable for accelerator devices.
+
+ This layer uses either tf.gather or tf.one_hot to translate integer indices to
+ float embeddings.
+
+ Arguments:
+ vocab_size: Number of elements in the vocabulary.
+ embedding_width: Output size of the embedding layer.
+ initializer: The initializer to use for the embedding weights. Defaults to
+ "glorot_uniform".
+ use_one_hot: Whether to use tf.one_hot over tf.gather for the embedding
+ lookup. Defaults to False (that is, using tf.gather). Setting this option
+ to True may improve performance, especially on small vocabulary sizes, but
+ will generally require more memory.
+ """
+
+ def __init__(self,
+ vocab_size,
+ embedding_width,
+ initializer="glorot_uniform",
+ use_one_hot=False,
+ **kwargs):
+
+ super(OnDeviceEmbedding, self).__init__(**kwargs)
+ self._vocab_size = vocab_size
+ self._embedding_width = embedding_width
+ self._initializer = initializer
+ self._use_one_hot = use_one_hot
+
+ def get_config(self):
+ config = {
+ "vocab_size": self._vocab_size,
+ "embedding_width": self._embedding_width,
+ "initializer": self._initializer,
+ "use_one_hot": self._use_one_hot,
+ }
+ base_config = super(OnDeviceEmbedding, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ def build(self, input_shape):
+ self.embeddings = self.add_weight(
+ "embeddings",
+ shape=[self._vocab_size, self._embedding_width],
+ initializer=self._initializer,
+ dtype=tf.float32)
+
+ super(OnDeviceEmbedding, self).build(input_shape)
+
+ def call(self, inputs):
+ flat_inputs = tf.reshape(inputs, [-1])
+ if self._use_one_hot:
+ one_hot_data = tf.one_hot(
+ flat_inputs, depth=self._vocab_size, dtype=self.embeddings.dtype)
+ embeddings = tf.matmul(one_hot_data, self.embeddings)
+ else:
+ embeddings = tf.gather(self.embeddings, flat_inputs)
+ embeddings = tf.reshape(
+ embeddings,
+ # Work around b/142213824: prefer concat to shape over a Python list.
+ tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0))
+ embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width])
+ return embeddings
diff --git a/models/official/nlp/modeling/layers/on_device_embedding_test.py b/models/official/nlp/modeling/layers/on_device_embedding_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b9b98f181470ea233d8297550a2dd92786baae
--- /dev/null
+++ b/models/official/nlp/modeling/layers/on_device_embedding_test.py
@@ -0,0 +1,198 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Keras-based one-hot embedding layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling.layers import on_device_embedding
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
+
+ def test_layer_creation(self):
+ vocab_size = 31
+ embedding_width = 27
+ test_layer = on_device_embedding.OnDeviceEmbedding(
+ vocab_size=vocab_size, embedding_width=embedding_width)
+ # Create a 2-dimensional input (the first dimension is implicit).
+ sequence_length = 23
+ input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
+ output_tensor = test_layer(input_tensor)
+
+ # The output should be the same as the input, save that it has an extra
+ # embedding_width dimension on the end.
+ expected_output_shape = [None, sequence_length, embedding_width]
+ self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
+ self.assertEqual(output_tensor.dtype, tf.float32)
+
+ def test_layer_creation_with_mixed_precision(self):
+ vocab_size = 31
+ embedding_width = 27
+ policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
+ test_layer = on_device_embedding.OnDeviceEmbedding(
+ vocab_size=vocab_size, embedding_width=embedding_width, dtype=policy)
+ # Create a 2-dimensional input (the first dimension is implicit).
+ sequence_length = 23
+ input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
+ output_tensor = test_layer(input_tensor)
+
+ # The output should be the same as the input, save that it has an extra
+ # embedding_width dimension on the end.
+ expected_output_shape = [None, sequence_length, embedding_width]
+ self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
+ self.assertEqual(output_tensor.dtype, tf.float16)
+
+ def test_layer_invocation(self):
+ vocab_size = 31
+ embedding_width = 27
+ test_layer = on_device_embedding.OnDeviceEmbedding(
+ vocab_size=vocab_size, embedding_width=embedding_width)
+ # Create a 2-dimensional input (the first dimension is implicit).
+ sequence_length = 23
+ input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
+ output_tensor = test_layer(input_tensor)
+
+ # Create a model from the test layer.
+ model = tf.keras.Model(input_tensor, output_tensor)
+
+ # Invoke the model on test data. We can't validate the output data itself
+ # (the NN is too complex) but this will rule out structural runtime errors.
+ batch_size = 3
+ input_data = np.random.randint(
+ vocab_size, size=(batch_size, sequence_length))
+ output = model.predict(input_data)
+ self.assertEqual(tf.float32, output.dtype)
+
+ def test_layer_invocation_with_mixed_precision(self):
+ vocab_size = 31
+ embedding_width = 27
+ policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
+ test_layer = on_device_embedding.OnDeviceEmbedding(
+ vocab_size=vocab_size, embedding_width=embedding_width,
+ dtype=policy)
+ # Create a 2-dimensional input (the first dimension is implicit).
+ sequence_length = 23
+ input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
+ output_tensor = test_layer(input_tensor)
+
+ # Create a model from the test layer.
+ model = tf.keras.Model(input_tensor, output_tensor)
+
+ # Invoke the model on test data. We can't validate the output data itself
+ # (the NN is too complex) but this will rule out structural runtime errors.
+ batch_size = 3
+ input_data = np.random.randint(
+ vocab_size, size=(batch_size, sequence_length))
+ output = model.predict(input_data)
+ self.assertEqual(tf.float16, output.dtype)
+
+ def test_one_hot_layer_creation(self):
+ vocab_size = 31
+ embedding_width = 27
+ test_layer = on_device_embedding.OnDeviceEmbedding(
+ vocab_size=vocab_size,
+ embedding_width=embedding_width,
+ use_one_hot=True)
+ # Create a 2-dimensional input (the first dimension is implicit).
+ sequence_length = 23
+ input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
+ output_tensor = test_layer(input_tensor)
+
+ # The output should be the same as the input, save that it has an extra
+ # embedding_width dimension on the end.
+ expected_output_shape = [None, sequence_length, embedding_width]
+ self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
+ self.assertEqual(output_tensor.dtype, tf.float32)
+
+ def test_one_hot_layer_creation_with_mixed_precision(self):
+ vocab_size = 31
+ embedding_width = 27
+ policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
+ test_layer = on_device_embedding.OnDeviceEmbedding(
+ vocab_size=vocab_size,
+ embedding_width=embedding_width,
+ dtype=policy,
+ use_one_hot=True)
+ # Create a 2-dimensional input (the first dimension is implicit).
+ sequence_length = 23
+ input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
+ output_tensor = test_layer(input_tensor)
+
+ # The output should be the same as the input, save that it has an extra
+ # embedding_width dimension on the end.
+ expected_output_shape = [None, sequence_length, embedding_width]
+ self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
+ self.assertEqual(output_tensor.dtype, tf.float16)
+
+ def test_one_hot_layer_invocation(self):
+ vocab_size = 31
+ embedding_width = 27
+ test_layer = on_device_embedding.OnDeviceEmbedding(
+ vocab_size=vocab_size,
+ embedding_width=embedding_width,
+ use_one_hot=True)
+ # Create a 2-dimensional input (the first dimension is implicit).
+ sequence_length = 23
+ input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
+ output_tensor = test_layer(input_tensor)
+
+ # Create a model from the test layer.
+ model = tf.keras.Model(input_tensor, output_tensor)
+
+ # Invoke the model on test data. We can't validate the output data itself
+ # (the NN is too complex) but this will rule out structural runtime errors.
+ batch_size = 3
+ input_data = np.random.randint(
+ vocab_size, size=(batch_size, sequence_length))
+ output = model.predict(input_data)
+ self.assertEqual(tf.float32, output.dtype)
+
+ def test_one_hot_layer_invocation_with_mixed_precision(self):
+ vocab_size = 31
+ embedding_width = 27
+ policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
+ test_layer = on_device_embedding.OnDeviceEmbedding(
+ vocab_size=vocab_size,
+ embedding_width=embedding_width,
+ dtype=policy,
+ use_one_hot=True)
+ # Create a 2-dimensional input (the first dimension is implicit).
+ sequence_length = 23
+ input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
+ output_tensor = test_layer(input_tensor)
+
+ # Create a model from the test layer.
+ model = tf.keras.Model(input_tensor, output_tensor)
+
+ # Invoke the model on test data. We can't validate the output data itself
+ # (the NN is too complex) but this will rule out structural runtime errors.
+ batch_size = 3
+ input_data = np.random.randint(
+ vocab_size, size=(batch_size, sequence_length))
+ output = model.predict(input_data)
+ self.assertEqual(tf.float16, output.dtype)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/modeling/layers/position_embedding.py b/models/official/nlp/modeling/layers/position_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..169e54de112d9a3ce65e9fa68f066a107d35c7a4
--- /dev/null
+++ b/models/official/nlp/modeling/layers/position_embedding.py
@@ -0,0 +1,205 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras-based positional embedding layer."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import math
+
+import tensorflow as tf
+
+from official.modeling import tf_utils
+
+
+@tf.keras.utils.register_keras_serializable(package="Text")
+class PositionEmbedding(tf.keras.layers.Layer):
+ """Creates a positional embedding.
+
+ This layer creates a positional embedding as described in "BERT: Pre-training
+ of Deep Bidirectional Transformers for Language Understanding"
+ (https://arxiv.org/abs/1810.04805).
+
+ This layer can be set up to either create a statically shaped slice or a
+ dynamically shaped slice. If `use_dynamic_slicing` is True, the input tensor
+ can have a dynamic 1st dimension, while if `use_dynamic_slicing` is False the
+ input size must be fixed.
+
+ Arguments:
+ use_dynamic_slicing: Whether to use the dynamic slicing path.
+ max_sequence_length: The maximum size of the dynamic sequence. Only
+ applicable if `use_dynamic_slicing` is True.
+ initializer: The initializer to use for the embedding weights. Defaults to
+ "glorot_uniform".
+ """
+
+ def __init__(self,
+ initializer="glorot_uniform",
+ use_dynamic_slicing=False,
+ max_sequence_length=None,
+ **kwargs):
+ # We need to have a default dtype of float32, since the inputs (which Keras
+ # usually uses to infer the dtype) will always be int32.
+ if "dtype" not in kwargs:
+ kwargs["dtype"] = "float32"
+
+ super(PositionEmbedding, self).__init__(**kwargs)
+ if use_dynamic_slicing and max_sequence_length is None:
+ raise ValueError(
+ "If `use_dynamic_slicing` is True, `max_sequence_length` must be set."
+ )
+ self._max_sequence_length = max_sequence_length
+ self._initializer = tf.keras.initializers.get(initializer)
+ self._use_dynamic_slicing = use_dynamic_slicing
+
+ def get_config(self):
+ config = {
+ "max_sequence_length": self._max_sequence_length,
+ "initializer": tf.keras.initializers.serialize(self._initializer),
+ "use_dynamic_slicing": self._use_dynamic_slicing,
+ }
+ base_config = super(PositionEmbedding, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ def build(self, input_shape):
+ """Implements build() for the layer."""
+ dimension_list = input_shape.as_list()
+
+ if len(dimension_list) != 3:
+ raise ValueError("PositionEmbedding expects a 3-dimensional input tensor "
+ "of shape [batch, sequence, width]")
+ seq_length = dimension_list[1]
+ width = dimension_list[2]
+
+ # If we are not using dynamic slicing, we must assume that the sequence
+ # length is fixed and max_sequence_length should not be specified.
+ if not self._use_dynamic_slicing:
+ if seq_length is None:
+ raise ValueError(
+ "PositionEmbedding must have `use_dynamic_slicing` set "
+ "to True (and max_sequence_length set) when the "
+ "sequence (1st) dimension of the input is None.")
+ if self._max_sequence_length is not None:
+ raise ValueError(
+ "When `use_dynamic_slicing` is False, max_sequence_length should "
+ "not be specified and we ought to use seq_length to get the "
+ "variable shape.")
+
+ if self._max_sequence_length is not None:
+ weight_sequence_length = self._max_sequence_length
+ else:
+ weight_sequence_length = seq_length
+
+ self._position_embeddings = self.add_weight(
+ "embeddings",
+ shape=[weight_sequence_length, width],
+ initializer=self._initializer)
+
+ super(PositionEmbedding, self).build(input_shape)
+
+ def call(self, inputs):
+ """Implements call() for the layer."""
+ input_shape = tf_utils.get_shape_list(inputs, expected_rank=3)
+ if self._use_dynamic_slicing:
+ position_embeddings = self._position_embeddings[:input_shape[1], :]
+ else:
+ position_embeddings = self._position_embeddings
+
+ return tf.broadcast_to(position_embeddings, input_shape)
+
+
+@tf.keras.utils.register_keras_serializable(package="Text")
+class RelativePositionEmbedding(tf.keras.layers.Layer):
+ """Creates a positional embedding.
+
+ This layer calculates the position encoding as a mix of sine and cosine
+ functions with geometrically increasing wavelengths. Defined and formalized in
+ "Attention is All You Need", section 3.5.
+ (https://arxiv.org/abs/1706.03762).
+
+ Arguments:
+ hidden_size: Size of the hidden layer.
+ min_timescale: Minimum scale that will be applied at each position
+ max_timescale: Maximum scale that will be applied at each position.
+ """
+
+ def __init__(self,
+ hidden_size,
+ min_timescale=1.0,
+ max_timescale=1.0e4,
+ **kwargs):
+ # We need to have a default dtype of float32, since the inputs (which Keras
+ # usually uses to infer the dtype) will always be int32.
+ # We compute the positional encoding in float32 even if the model uses
+ # float16, as many of the ops used, like log and exp, are numerically
+ # unstable in float16.
+ if "dtype" not in kwargs:
+ kwargs["dtype"] = "float32"
+
+ super(RelativePositionEmbedding, self).__init__(**kwargs)
+ self._hidden_size = hidden_size
+ self._min_timescale = min_timescale
+ self._max_timescale = max_timescale
+
+ def get_config(self):
+ config = {
+ "hidden_size": self._hidden_size,
+ "min_timescale": self._min_timescale,
+ "max_timescale": self._max_timescale,
+ "length": self._length,
+ }
+ base_config = super(RelativePositionEmbedding, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ def call(self, inputs, length=None):
+ """Implements call() for the layer.
+
+ Args:
+ inputs: A tensor whose second dimension will be used as `length`. If
+ `None`, the `length` argument must be specified.
+ length: An optional integer specifying the number of positions. If both
+ `inputs` and `length` are specified, `length` must be equal to the
+ second dimension of `inputs`.
+
+ Returns:
+ A tensor in shape of [length, hidden_size].
+ """
+ if inputs is None and length is None:
+ raise ValueError(
+ "If inputs is None, `length` must be set in "
+ "RelativePositionEmbedding().")
+ if inputs is not None:
+ input_shape = tf_utils.get_shape_list(inputs)
+ if length is not None and length != input_shape[1]:
+ raise ValueError(
+ "If inputs is not None, `length` must equal to input_shape[1]."
+ )
+ length = input_shape[1]
+ position = tf.cast(tf.range(length), tf.float32)
+ num_timescales = self._hidden_size // 2
+ min_timescale, max_timescale = self._min_timescale, self._max_timescale
+ log_timescale_increment = (
+ math.log(float(max_timescale) / float(min_timescale)) /
+ (tf.cast(num_timescales, tf.float32) - 1))
+ inv_timescales = min_timescale * tf.exp(
+ tf.cast(tf.range(num_timescales), tf.float32) *
+ -log_timescale_increment)
+ scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales,
+ 0)
+ position_embeddings = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)],
+ axis=1)
+ return position_embeddings
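A short sketch of how the two layers above might be called, with illustrative shapes and assuming the `models` repo is importable; `PositionEmbedding` broadcasts a learned [length, width] table across the batch, while `RelativePositionEmbedding` returns the fixed sinusoidal table of shape [length, hidden_size] (mirroring position_embedding_test.py below).

import tensorflow as tf

from official.nlp.modeling.layers import position_embedding

# Learned positions with dynamic slicing: the sequence axis may vary at run
# time, up to max_sequence_length.
pos_layer = position_embedding.PositionEmbedding(
    use_dynamic_slicing=True, max_sequence_length=64)
word_embeddings = tf.zeros([2, 10, 32])   # [batch, seq_length, width]
positions = pos_layer(word_embeddings)    # [2, 10, 32], broadcast over batch

# Fixed sinusoidal positions, as in "Attention Is All You Need".
rel_layer = position_embedding.RelativePositionEmbedding(hidden_size=32)
rel_positions = rel_layer(None, length=10)  # [10, 32]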
diff --git a/models/official/nlp/modeling/layers/position_embedding_test.py b/models/official/nlp/modeling/layers/position_embedding_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..89a29af7a4e8dc7c0369131f635d1c6abe74fdbc
--- /dev/null
+++ b/models/official/nlp/modeling/layers/position_embedding_test.py
@@ -0,0 +1,131 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Keras-based positional embedding layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling.layers import position_embedding
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
+
+ def test_static_layer_output_shape(self):
+ test_layer = position_embedding.PositionEmbedding()
+ # Create a 3-dimensional input (the first dimension is implicit).
+ sequence_length = 21
+ width = 30
+ input_tensor = tf.keras.Input(shape=(sequence_length, width))
+ output_tensor = test_layer(input_tensor)
+
+ # When using static positional embedding shapes, the output is expected
+ # to be the same as the input shape in all dimensions save batch.
+ expected_output_shape = [None, sequence_length, width]
+ self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
+ # The default output dtype for this layer should be tf.float32.
+ self.assertEqual(tf.float32, output_tensor.dtype)
+
+ def test_float16_dtype(self):
+ test_layer = position_embedding.PositionEmbedding(dtype="float16")
+ # Create a 3-dimensional input (the first dimension is implicit).
+ sequence_length = 21
+ width = 30
+ input_tensor = tf.keras.Input(shape=(sequence_length, width))
+ output_tensor = test_layer(input_tensor)
+
+ # When using static positional embedding shapes, the output is expected
+ # to be the same as the input shape in all dimensions save batch.
+ expected_output_shape = [None, sequence_length, width]
+ self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
+ # The default output dtype for this layer should be tf.float32.
+ self.assertEqual(tf.float16, output_tensor.dtype)
+
+ def test_dynamic_layer_output_shape(self):
+ max_sequence_length = 40
+ test_layer = position_embedding.PositionEmbedding(
+ use_dynamic_slicing=True, max_sequence_length=max_sequence_length)
+ # Create a 3-dimensional input (the first dimension is implicit).
+ width = 30
+ input_tensor = tf.keras.Input(shape=(None, width))
+ output_tensor = test_layer(input_tensor)
+
+ # When using dynamic positional embedding shapes, the output is expected
+ # to be the same as the input shape in all dimensions - but may be None if
+ # the input shape is None there.
+ expected_output_shape = [None, None, width]
+ self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
+
+ def test_dynamic_layer_slicing(self):
+ max_sequence_length = 40
+ test_layer = position_embedding.PositionEmbedding(
+ use_dynamic_slicing=True, max_sequence_length=max_sequence_length)
+ # Create a 3-dimensional input (the first dimension is implicit).
+ width = 30
+ input_tensor = tf.keras.Input(shape=(None, width))
+ output_tensor = test_layer(input_tensor)
+
+ model = tf.keras.Model(input_tensor, output_tensor)
+
+ # Create input data that is shorter than max_sequence_length, which should
+ # trigger a down-slice.
+ input_length = 17
+ # Note: This test explicitly uses a batch size of 1. This is to get around
+ # Keras' restriction on Model invocations: inputs are expected to have the
+ # same batch cardinality as outputs. In practice, this layer should be used
+ # inside a model, where it can be projected when added to another tensor.
+ input_data = np.ones((1, input_length, width))
+ output_data = model.predict(input_data)
+
+ self.assertAllEqual([1, input_length, width], output_data.shape)
+
+ def test_relative_tensor_input(self):
+ hidden_size = 8
+ test_layer = position_embedding.RelativePositionEmbedding(
+ hidden_size=hidden_size)
+
+ # create a 3-dimensional input for test_layer to infer length as 1.
+ input_tensor = tf.constant([[[0] * hidden_size]])
+ output_tensor = test_layer(input_tensor)
+
+ # expected output is the theoretical result of the input based on
+ # sine cosine relative position embedding formula.
+ expected_output_tensor = tf.constant([[0, 0, 0, 0, 1, 1, 1, 1]])
+ self.assertAllEqual(output_tensor, expected_output_tensor)
+
+ def test_relative_length_input(self):
+ hidden_size = 8
+
+ # When we do not have tensor as input, we explicitly specify length
+ # value when initializing test_layer.
+ test_layer = position_embedding.RelativePositionEmbedding(
+ hidden_size=hidden_size)
+ input_tensor = None
+ output_tensor = test_layer(input_tensor, length=1)
+
+ # expected output is the theoretical result of the input based on
+ # sine cosine relative position embedding formula.
+ expected_output_tensor = tf.constant([[0, 0, 0, 0, 1, 1, 1, 1]])
+ self.assertAllEqual(output_tensor, expected_output_tensor)
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/modeling/layers/rezero_transformer.py b/models/official/nlp/modeling/layers/rezero_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..42bc1af0031db97a12a730b8e1abe98f3c9318e0
--- /dev/null
+++ b/models/official/nlp/modeling/layers/rezero_transformer.py
@@ -0,0 +1,247 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras-based rezero-transformer block layer (Transformer with ReZero)."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import gin
+import tensorflow as tf
+
+from official.nlp.modeling.layers import attention
+from official.nlp.modeling.layers import dense_einsum
+
+
+@tf.keras.utils.register_keras_serializable(package="Text")
+@gin.configurable
+class ReZeroTransformer(tf.keras.layers.Layer):
+ """Transformer layer with ReZero.
+
+ This layer implements the Transformer from "Attention Is All You Need".
+ (https://arxiv.org/abs/1706.03762).
+ The residual connection implements the ReZero method.
+ (https://arxiv.org/abs/2003.04887)
+
+ Arguments:
+ num_attention_heads: Number of attention heads.
+ intermediate_size: Size of the intermediate layer.
+ intermediate_activation: Activation for the intermediate layer.
+ dropout_rate: Dropout probability for the post-attention and output dropout.
+ attention_dropout_rate: Dropout probability for within the attention layer.
+ output_range: the sequence output range, [0, output_range) by slicing the
+ target sequence. `None` means the target sequence is not sliced.
+ kernel_initializer: Initializer for dense layer kernels.
+ bias_initializer: Initializer for dense layer biases.
+ kernel_regularizer: Regularizer for dense layer kernels.
+ bias_regularizer: Regularizer for dense layer biases.
+ activity_regularizer: Regularizer for dense layer activity.
+ kernel_constraint: Constraint for dense layer kernels.
+ bias_constraint: Constraint for dense layer biases.
+ use_layer_norm: Whether to add a layer norm on top of the ReZero residual.
+ """
+
+ def __init__(self,
+ num_attention_heads,
+ intermediate_size,
+ intermediate_activation,
+ dropout_rate=0.0,
+ attention_dropout_rate=0.0,
+ output_range=None,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ kernel_constraint=None,
+ bias_constraint=None,
+ use_layer_norm=False,
+ **kwargs):
+ super(ReZeroTransformer, self).__init__(**kwargs)
+
+ self._num_heads = num_attention_heads
+ self._intermediate_size = intermediate_size
+ self._intermediate_activation = intermediate_activation
+ self._attention_dropout_rate = attention_dropout_rate
+ self._dropout_rate = dropout_rate
+ self._output_range = output_range
+ self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
+ self._bias_initializer = tf.keras.initializers.get(bias_initializer)
+ self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
+ self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
+ self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
+ self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
+ self._bias_constraint = tf.keras.constraints.get(bias_constraint)
+ self._use_layer_norm = use_layer_norm
+
+ def build(self, input_shape):
+ input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
+ input_tensor_shape = tf.TensorShape(input_tensor)
+ if len(input_tensor_shape) != 3:
+ raise ValueError("TransformerLayer expects a three-dimensional input of "
+ "shape [batch, sequence, width].")
+ batch_size, sequence_length, hidden_size = input_tensor_shape
+
+ if len(input_shape) == 2:
+ mask_tensor_shape = tf.TensorShape(input_shape[1])
+ expected_mask_tensor_shape = tf.TensorShape(
+ [batch_size, sequence_length, sequence_length])
+ if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
+ raise ValueError("When passing a mask tensor to TransformerLayer, the "
+ "mask tensor must be of shape [batch, "
+ "sequence_length, sequence_length] (here %s). Got a "
+ "mask tensor of shape %s." %
+ (expected_mask_tensor_shape, mask_tensor_shape))
+ if hidden_size % self._num_heads != 0:
+ raise ValueError(
+ "The input size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (hidden_size, self._num_heads))
+ self._attention_head_size = int(hidden_size // self._num_heads)
+
+ self._attention_layer = attention.MultiHeadAttention(
+ num_heads=self._num_heads,
+ key_size=self._attention_head_size,
+ dropout=self._attention_dropout_rate,
+ kernel_initializer=self._kernel_initializer,
+ bias_initializer=self._bias_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activity_regularizer=self._activity_regularizer,
+ kernel_constraint=self._kernel_constraint,
+ bias_constraint=self._bias_constraint,
+ name="self_attention")
+ self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
+ if self._use_layer_norm:
+ # Use float32 in layernorm for numeric stability.
+ # It is probably safe in mixed_float16, but we haven't validated this yet.
+ self._attention_layer_norm = (
+ tf.keras.layers.LayerNormalization(
+ name="self_attention_layer_norm",
+ axis=-1,
+ epsilon=1e-12,
+ dtype=tf.float32))
+ self._intermediate_dense = dense_einsum.DenseEinsum(
+ output_shape=self._intermediate_size,
+ activation=None,
+ kernel_initializer=self._kernel_initializer,
+ bias_initializer=self._bias_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activity_regularizer=self._activity_regularizer,
+ kernel_constraint=self._kernel_constraint,
+ bias_constraint=self._bias_constraint,
+ name="intermediate")
+ policy = tf.keras.mixed_precision.experimental.global_policy()
+ if policy.name == "mixed_bfloat16":
+ # bfloat16 causes BERT with the LAMB optimizer to not converge
+ # as well, so we use float32.
+ # TODO(b/154538392): Investigate this.
+ policy = tf.float32
+ self._intermediate_activation_layer = tf.keras.layers.Activation(
+ self._intermediate_activation, dtype=policy)
+ self._output_dense = dense_einsum.DenseEinsum(
+ output_shape=hidden_size,
+ kernel_initializer=self._kernel_initializer,
+ bias_initializer=self._bias_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activity_regularizer=self._activity_regularizer,
+ kernel_constraint=self._kernel_constraint,
+ bias_constraint=self._bias_constraint,
+ name="output")
+ self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
+ if self._use_layer_norm:
+ # Use float32 in layernorm for numeric stability.
+ self._output_layer_norm = tf.keras.layers.LayerNormalization(
+ name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
+
+ self._rezero_a = self.add_weight(
+ name="rezero_alpha",
+ initializer=tf.keras.initializers.Zeros(),
+ trainable=True, dtype=tf.float32)
+
+ super(ReZeroTransformer, self).build(input_shape)
+
+ def get_config(self):
+ config = {
+ "num_attention_heads":
+ self._num_heads,
+ "intermediate_size":
+ self._intermediate_size,
+ "intermediate_activation":
+ self._intermediate_activation,
+ "dropout_rate":
+ self._dropout_rate,
+ "attention_dropout_rate":
+ self._attention_dropout_rate,
+ "output_range":
+ self._output_range,
+ "use_layer_norm":
+ self._use_layer_norm,
+ "kernel_initializer":
+ tf.keras.initializers.serialize(self._kernel_initializer),
+ "bias_initializer":
+ tf.keras.initializers.serialize(self._bias_initializer),
+ "kernel_regularizer":
+ tf.keras.regularizers.serialize(self._kernel_regularizer),
+ "bias_regularizer":
+ tf.keras.regularizers.serialize(self._bias_regularizer),
+ "activity_regularizer":
+ tf.keras.regularizers.serialize(self._activity_regularizer),
+ "kernel_constraint":
+ tf.keras.constraints.serialize(self._kernel_constraint),
+ "bias_constraint":
+ tf.keras.constraints.serialize(self._bias_constraint),
+ }
+ base_config = super(ReZeroTransformer, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ def reset_rezero(self):
+ self._rezero_a.assign(0.)
+
+ def call(self, inputs):
+ if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
+ input_tensor, attention_mask = inputs
+ else:
+ input_tensor, attention_mask = (inputs, None)
+
+ if self._output_range:
+ target_tensor = input_tensor[:, 0:self._output_range, :]
+ attention_mask = attention_mask[:, 0:self._output_range, :]
+ else:
+ target_tensor = input_tensor
+ attention_inputs = [target_tensor, input_tensor]
+
+ attention_output = self._attention_layer(attention_inputs, attention_mask)
+ attention_output = self._attention_dropout(attention_output)
+ attention_output = target_tensor + self._rezero_a * attention_output
+ if self._use_layer_norm:
+ attention_output = self._attention_layer_norm(attention_output)
+ else:
+ attention_output = tf.cast(attention_output, tf.float32)
+
+ intermediate_output = self._intermediate_dense(attention_output)
+ intermediate_output = self._intermediate_activation_layer(
+ intermediate_output)
+ layer_output = self._output_dense(intermediate_output)
+ layer_output = self._output_dropout(layer_output)
+ # During mixed precision training, attention_output is from layer norm and
+ # is always fp32 for now. Cast layer_output to fp32 for the subsequent add.
+ layer_output = attention_output + tf.cast(self._rezero_a * layer_output,
+ tf.float32)
+ if self._use_layer_norm:
+ layer_output = self._output_layer_norm(layer_output)
+
+ return layer_output
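A minimal sketch of the ReZero block above, mirroring test_rezero_without_layer_norm in the test file that follows: each residual branch contributes x + alpha * f(x), and since the trainable rezero_alpha starts at zero, a freshly built block acts as the identity. Sizes are illustrative and assume a TF build matching these layers (the experimental mixed-precision API).

import numpy as np
import tensorflow as tf

from official.nlp.modeling.layers import rezero_transformer

# The width (32) must be divisible by num_attention_heads (4).
block = rezero_transformer.ReZeroTransformer(
    num_attention_heads=4,
    intermediate_size=128,
    intermediate_activation="relu",
    use_layer_norm=False)

data = tf.constant(np.random.rand(2, 8, 32), dtype=tf.float32)
out = block(data)
# rezero_alpha is initialized to 0, so the block starts as the identity map.
print(np.allclose(out.numpy(), data.numpy()))  # True until alpha is trained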
diff --git a/models/official/nlp/modeling/layers/rezero_transformer_test.py b/models/official/nlp/modeling/layers/rezero_transformer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ef0aa218c70c919f62492b00ef5f53348dd5938
--- /dev/null
+++ b/models/official/nlp/modeling/layers/rezero_transformer_test.py
@@ -0,0 +1,133 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Keras-based rezero-transformer block layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling.layers import rezero_transformer
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class TransformerWithReZeroLayerTest(keras_parameterized.TestCase):
+
+ def tearDown(self):
+ super(TransformerWithReZeroLayerTest, self).tearDown()
+ tf.keras.mixed_precision.experimental.set_policy('float32')
+
+ def test_layer_invocation_with_float16_dtype(self):
+ tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
+ test_layer = rezero_transformer.ReZeroTransformer(
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu')
+ sequence_length = 21
+ width = 80
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ # Create a 2-dimensional input (the first dimension is implicit).
+ mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
+ output_tensor = test_layer([data_tensor, mask_tensor])
+
+ # Create a model from the test layer.
+ model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
+
+ # Invoke the model on test data. We can't validate the output data itself
+ # (the NN is too complex) but this will rule out structural runtime errors.
+ batch_size = 6
+ input_data = (10 * np.random.random_sample(
+ (batch_size, sequence_length, width)))
+ # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
+ # which here is (batch, sequence_length, sequence_length)
+ mask_data = np.random.randint(
+ 2, size=(batch_size, sequence_length, sequence_length))
+ _ = model.predict([input_data, mask_data])
+
+ def test_rezero_without_layer_norm(self):
+ test_layer = rezero_transformer.ReZeroTransformer(
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu',
+ use_layer_norm=False)
+
+ input_length, width = 16, 30
+ input_tensor = tf.keras.Input(shape=(input_length, width))
+ output_tensor = test_layer(input_tensor)
+ model = tf.keras.Model(input_tensor, output_tensor)
+
+ input_data = np.random.rand(2, input_length, width)
+ test_layer._rezero_a.assign(1.0)
+ test_layer.reset_rezero()
+ output_data = model.predict(input_data)
+
+ self.assertAllClose(input_data, output_data)
+
+ def test_rezero_with_layer_norm(self):
+ test_layer = rezero_transformer.ReZeroTransformer(
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu',
+ use_layer_norm=True)
+
+ input_length, width = 16, 30
+ input_tensor = tf.keras.Input(shape=(input_length, width))
+ output_tensor = test_layer(input_tensor)
+ model = tf.keras.Model(input_tensor, output_tensor)
+
+ input_data = np.random.rand(2, input_length, width) + 2.0
+ output_data = model.predict(input_data)
+ input_data_normed = (
+ input_data - np.mean(input_data, axis=-1, keepdims=True)) / (
+ np.std(input_data, axis=-1, keepdims=True))
+
+ self.assertAllClose(input_data_normed, output_data)
+
+ def test_layer_output_range(self):
+ test_layer = rezero_transformer.ReZeroTransformer(
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu')
+ sequence_length = 21
+ width = 80
+
+ batch_size = 6
+ input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, width))
+ mask_data = np.random.randint(
+ 2, size=(batch_size, sequence_length, sequence_length))
+ output_tensor = test_layer([input_data, mask_data])
+
+ # The layer only attends to the first token and outputs the first token
+ # embedding.
+ new_layer = rezero_transformer.ReZeroTransformer(
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu',
+ output_range=1)
+ _ = new_layer([input_data, mask_data])
+ new_layer.set_weights(test_layer.get_weights())
+ new_output_tensor = new_layer([input_data, mask_data])
+ self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/nlp/modeling/layers/self_attention_mask.py b/models/official/nlp/modeling/layers/self_attention_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..933b4960dc0a86d4a0a767f8017853c2f2290d16
--- /dev/null
+++ b/models/official/nlp/modeling/layers/self_attention_mask.py
@@ -0,0 +1,63 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras layer that creates a self-attention mask."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+from official.modeling import tf_utils
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+class SelfAttentionMask(tf.keras.layers.Layer):
+ """Create 3D attention mask from a 2D tensor mask.
+
+ inputs[0]: from_tensor: 2D or 3D Tensor of shape
+ [batch_size, from_seq_length, ...].
+ inputs[1]: to_mask: int32 Tensor of shape [batch_size, to_seq_length].
+
+ Returns:
+ float Tensor of shape [batch_size, from_seq_length, to_seq_length].
+ """
+
+ def call(self, inputs):
+ from_tensor = inputs[0]
+ to_mask = inputs[1]
+ from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
+ batch_size = from_shape[0]
+ from_seq_length = from_shape[1]
+
+ to_shape = tf_utils.get_shape_list(to_mask, expected_rank=2)
+ to_seq_length = to_shape[1]
+
+ to_mask = tf.cast(
+ tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
+ dtype=from_tensor.dtype)
+
+ # We don't assume that `from_tensor` is a mask (although it could be). We
+ # don't actually care if we attend *from* padding tokens (only *to* padding
+ # tokens), so we create a tensor of all ones.
+ #
+ # `broadcast_ones` = [batch_size, from_seq_length, 1]
+ broadcast_ones = tf.ones(
+ shape=[batch_size, from_seq_length, 1], dtype=from_tensor.dtype)
+
+ # Here we broadcast along two dimensions to create the mask.
+ mask = broadcast_ones * to_mask
+
+ return mask
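A small sketch of the mask layer above under toy shapes: given a [batch, to_seq_length] padding mask, it produces the [batch, from_seq_length, to_seq_length] float mask that the transformer layers in this package expect.

import tensorflow as tf

from official.nlp.modeling.layers import self_attention_mask

embeddings = tf.zeros([2, 4, 8])                      # [batch, from_seq, width]
padding_mask = tf.constant([[1, 1, 1, 1],
                            [1, 1, 0, 0]], tf.int32)  # [batch, to_seq]

mask_3d = self_attention_mask.SelfAttentionMask()([embeddings, padding_mask])
# mask_3d.shape == (2, 4, 4); in the second example every query position may
# attend only to the first two tokens, since the rest are padding.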
diff --git a/models/official/nlp/modeling/layers/talking_heads_attention.py b/models/official/nlp/modeling/layers/talking_heads_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..c65ba1e66165617aaf5652c2f77015e9a3eb7ccb
--- /dev/null
+++ b/models/official/nlp/modeling/layers/talking_heads_attention.py
@@ -0,0 +1,153 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Talking Head Attention layer."""
+# pylint: disable=g-classes-have-attributes
+import math
+import string
+
+import gin
+import tensorflow as tf
+
+from official.nlp.modeling.layers import attention
+
+_CHR_IDX = string.ascii_lowercase
+
+
+@tf.keras.utils.register_keras_serializable(package="Text")
+@gin.configurable
+class TalkingHeadsAttention(attention.MultiHeadAttention):
+ """Implements Talking-Heads Attention.
+
+ This is an implementation of Talking-Heads Attention based on the paper
+ Talking-Heads Attention (https://arxiv.org/abs/2003.02436): it enhances
+ multi-head attention by including linear projections across the attention-heads
+ dimension, immediately before and after the softmax operation.
+
+ See the base class `MultiHeadAttention` for more details.
+
+ Arguments:
+ num_heads: Number of attention heads.
+ key_size: Size of each attention head for query and key.
+ value_size: Size of each attention head for value.
+ dropout: Dropout probability.
+ use_bias: Boolean, whether the dense layers use bias vectors/matrices.
+ output_shape: The expected shape of an output tensor, besides the batch and
+ sequence dims. If not specified, projects back to the key feature dim.
+ attention_axes: axes over which the attention is applied. `None` means
+ attention over all axes except batch, heads, and features.
+ return_attention_scores: bool, if `True`, returns the multi-head attention
+ scores as an additional output argument.
+ kernel_initializer: Initializer for dense layer kernels.
+ bias_initializer: Initializer for dense layer biases.
+ kernel_regularizer: Regularizer for dense layer kernels.
+ bias_regularizer: Regularizer for dense layer biases.
+ activity_regularizer: Regularizer for dense layer activity.
+ kernel_constraint: Constraint for dense layer kernels.
+ bias_constraint: Constraint for dense layer biases.
+ """
+
+ def _build_attention(self, qkv_rank):
+ """Builds multi-head dot-product attention computations.
+
+ This function overrides base class to create additional linear projection
+ that will be applied on attention scores before and after softmax.
+
+ Args:
+ qkv_rank: the rank of query, key, value tensors after projection.
+ """
+ super(TalkingHeadsAttention, self)._build_attention(qkv_rank)
+
+ # Build an equation:
+ # (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
+ # (<batch_dims>, num_heads_b, ...)
+ # qkv_ranks has `batch_dims`, `attention_dims`, `num_heads` and `channels`.
+ num_batch_dims = qkv_rank - len(self._attention_axes) - 2
+
+ # The shape of attn_scores is:
+ # (<batch_dims>, num_heads, <query_attention_dims>, <key_attention_dims>)
+ attn_scores_rank = num_batch_dims + 1 + len(self._attention_axes) * 2
+ scores_notation = _CHR_IDX[:attn_scores_rank]
+ projection_notation = scores_notation[num_batch_dims] + (
+ _CHR_IDX[attn_scores_rank])
+ projected_scores_notation = scores_notation[:num_batch_dims] + (
+ _CHR_IDX[attn_scores_rank] + scores_notation[num_batch_dims + 1:])
+ self._talking_heads_equation = "%s,%s->%s" % (
+ scores_notation, projection_notation, projected_scores_notation)
+
+ self._pre_softmax_weight = self.add_weight(
+ "pre_softmax_weight",
+ shape=(self._num_heads, self._num_heads),
+ initializer=self._kernel_initializer,
+ regularizer=self._kernel_regularizer,
+ constraint=self._kernel_constraint,
+ dtype=self.dtype,
+ trainable=True)
+ self._post_softmax_weight = self.add_weight(
+ "post_softmax_weight",
+ shape=(self._num_heads, self._num_heads),
+ initializer=self._kernel_initializer,
+ regularizer=self._kernel_regularizer,
+ constraint=self._kernel_constraint,
+ dtype=self.dtype,
+ trainable=True)
+
+ def _compute_attention(self,
+ query_tensor,
+ key_tensor,
+ value_tensor,
+ attention_mask=None):
+ """Applies Dot-product attention with query, key, value tensors.
+
+ This function overrides base class to apply additional linear projection
+ on attention scores before and after softmax.
+
+ Args:
+ query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`.
+ key_tensor: Projected key `Tensor` of shape `[B, T, N, key_size]`.
+ value_tensor: Projected value `Tensor` of shape `[B, T, N, value_size]`.
+ attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
+ attention to certain positions.
+
+ Returns:
+ attention_output: Multi-headed outputs of attention computation.
+ attention_scores: Multi-headed attention weights.
+ """
+ # Take the dot product between "query" and "key" to get the raw
+ # attention scores.
+ attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
+ query_tensor)
+ attention_scores = tf.multiply(attention_scores,
+ 1.0 / math.sqrt(float(self._key_size)))
+
+ # Apply linear projection before softmax
+ attention_scores = tf.einsum(self._talking_heads_equation, attention_scores,
+ self._pre_softmax_weight)
+
+ # Normalize the attention scores to probabilities.
+ # `attention_scores` = [B, N, T, S]
+ attention_scores = self._masked_softmax(attention_scores, attention_mask)
+
+ # Apply linear projection after softmax
+ attention_scores = tf.einsum(self._talking_heads_equation, attention_scores,
+ self._post_softmax_weight)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_scores_dropout = self._dropout_layer(attention_scores)
+
+ # `context_layer` = [B, T, N, H]
+ attention_output = tf.einsum(self._combine_equation,
+ attention_scores_dropout, value_tensor)
+ return attention_output, attention_scores
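A minimal call sketch for the layer above, with toy shapes and assuming the `models` repo is importable: the two num_heads x num_heads talking-heads projections are created in _build_attention and applied to the attention scores just before and just after the softmax.

import tensorflow as tf

from official.nlp.modeling.layers import talking_heads_attention

layer = talking_heads_attention.TalkingHeadsAttention(num_heads=4, key_size=8)

query = tf.zeros([2, 6, 16])    # [batch, target_seq, width]
value = tf.zeros([2, 3, 16])    # [batch, source_seq, width]
output = layer([query, value])  # [2, 6, 16]; projected back to the query width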
diff --git a/models/official/nlp/modeling/layers/talking_heads_attention_test.py b/models/official/nlp/modeling/layers/talking_heads_attention_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed24eda26c6f532b5e5011f5bfc8109eeca68a03
--- /dev/null
+++ b/models/official/nlp/modeling/layers/talking_heads_attention_test.py
@@ -0,0 +1,163 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the attention layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling.layers import talking_heads_attention
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+# This test is revised based on attention.MultiHeadAttentionTest.
+@keras_parameterized.run_all_keras_modes
+class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("key_value_same_proj", None, None, [40, 80]),
+ ("key_value_different_proj", 32, 60, [40, 60]),
+ )
+ def test_non_masked_attention(self, value_size, output_shape, output_dims):
+ """Test that the attention layer can be created without a mask tensor."""
+ test_layer = talking_heads_attention.TalkingHeadsAttention(
+ num_heads=12,
+ key_size=64,
+ value_size=value_size,
+ output_shape=output_shape)
+ # Create a 3-dimensional input (the first dimension is implicit).
+ query = tf.keras.Input(shape=(40, 80))
+ value = tf.keras.Input(shape=(20, 80))
+ output = test_layer([query, value])
+ self.assertEqual(output.shape.as_list(), [None] + output_dims)
+
+ def test_non_masked_self_attention(self):
+ """Test with one input (self-attenntion) and no mask tensor."""
+ test_layer = talking_heads_attention.TalkingHeadsAttention(
+ num_heads=12, key_size=64)
+ # Create a 3-dimensional input (the first dimension is implicit).
+ query = tf.keras.Input(shape=(40, 80))
+ output = test_layer([query, query])
+ self.assertEqual(output.shape.as_list(), [None, 40, 80])
+
+ def test_attention_scores(self):
+ """Test attention outputs with coefficients."""
+ test_layer = talking_heads_attention.TalkingHeadsAttention(
+ num_heads=12, key_size=64, return_attention_scores=True)
+ # Create a 3-dimensional input (the first dimension is implicit).
+ query = tf.keras.Input(shape=(40, 80))
+ output, coef = test_layer([query, query])
+ self.assertEqual(output.shape.as_list(), [None, 40, 80])
+ self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
+
+ @parameterized.named_parameters(("with_bias", True), ("no_bias", False))
+ def test_masked_attention(self, use_bias):
+ """Test with a mask tensor."""
+ test_layer = talking_heads_attention.TalkingHeadsAttention(
+ num_heads=12, key_size=2, use_bias=use_bias)
+ # Create a 3-dimensional input (the first dimension is implicit).
+ batch_size = 3
+ query = tf.keras.Input(shape=(4, 8))
+ value = tf.keras.Input(shape=(2, 8))
+ mask_tensor = tf.keras.Input(shape=(4, 2))
+ output = test_layer([query, value], mask_tensor)
+
+ # Create a model containing the test layer.
+ model = tf.keras.Model([query, value, mask_tensor], output)
+
+ # Generate data for the input (non-mask) tensors.
+ from_data = 10 * np.random.random_sample((batch_size, 4, 8))
+ to_data = 10 * np.random.random_sample((batch_size, 2, 8))
+
+ # Invoke the data with a random set of mask data. This should mask at least
+ # one element.
+ mask_data = np.random.randint(2, size=(batch_size, 4, 2))
+ masked_output_data = model.predict([from_data, to_data, mask_data])
+
+ # Invoke the same data, but with a null mask (where no elements are masked).
+ null_mask_data = np.ones((batch_size, 4, 2))
+ unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
+
+ # Because one data is masked and one is not, the outputs should not be the
+ # same.
+ self.assertNotAllClose(masked_output_data, unmasked_output_data)
+
+ # Tests the layer with three inputs: Q, K, V.
+ key = tf.keras.Input(shape=(2, 8))
+ output = test_layer([query, value, key], mask_tensor)
+ model = tf.keras.Model([query, value, key, mask_tensor], output)
+
+ masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
+ unmasked_output_data = model.predict(
+ [from_data, to_data, to_data, null_mask_data])
+ # Because one data is masked and one is not, the outputs should not be the
+ # same.
+ self.assertNotAllClose(masked_output_data, unmasked_output_data)
+
+ if use_bias:
+ self.assertLen(test_layer._query_dense.trainable_variables, 2)
+ self.assertLen(test_layer._output_dense.trainable_variables, 2)
+ else:
+ self.assertLen(test_layer._query_dense.trainable_variables, 1)
+ self.assertLen(test_layer._output_dense.trainable_variables, 1)
+
+ def test_initializer(self):
+ """Test with a specified initializer."""
+ test_layer = talking_heads_attention.TalkingHeadsAttention(
+ num_heads=12,
+ key_size=64,
+ kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
+ # Create a 3-dimensional input (the first dimension is implicit).
+ query = tf.keras.Input(shape=(40, 80))
+ output = test_layer([query, query])
+ self.assertEqual(output.shape.as_list(), [None, 40, 80])
+
+ @parameterized.named_parameters(
+ ("4d_inputs_one_free_batch", [3, 4], [3, 2], [4, 2], (2,)),
+ ("4D_inputs_2D_attention", [3, 4], [3, 2], [3, 4, 3, 2], (1, 2)),
+ ("5D_inputs_2D_attention", [5, 3, 4], [5, 3, 2], [3, 4, 3, 2], (2, 3)))
+ def test_high_dim_attention(self, q_dims, v_dims, mask_dims, attention_axes):
+ """Test with a mask tensor."""
+ test_layer = talking_heads_attention.TalkingHeadsAttention(
+ num_heads=12, key_size=2, attention_axes=attention_axes)
+ batch_size, hidden_size = 3, 8
+ # Generate data for the input (non-mask) tensors.
+ query_shape = [batch_size] + q_dims + [hidden_size]
+ value_shape = [batch_size] + v_dims + [hidden_size]
+ mask_shape = [batch_size] + mask_dims
+ query = 10 * np.random.random_sample(query_shape)
+ value = 10 * np.random.random_sample(value_shape)
+
+ # Invoke the data with a random set of mask data. This should mask at least
+ # one element.
+ mask_data = np.random.randint(2, size=mask_shape).astype("bool")
+ output = test_layer([query, value], mask_data)
+
+ # Invoke the same data, but with a null mask (where no elements are masked).
+ null_mask_data = np.ones(mask_shape)
+ unmasked_output = test_layer([query, value], null_mask_data)
+ # Because one data is masked and one is not, the outputs should not be the
+ # same.
+ self.assertNotAllClose(output, unmasked_output)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/modeling/layers/transformer.py b/models/official/nlp/modeling/layers/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..92f509cf26b802dcd769b97e4c11987e713d8d16
--- /dev/null
+++ b/models/official/nlp/modeling/layers/transformer.py
@@ -0,0 +1,437 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras-based transformer block layer."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import gin
+import tensorflow as tf
+
+from official.nlp.modeling.layers import attention
+from official.nlp.modeling.layers import dense_einsum
+from official.nlp.modeling.layers import multi_channel_attention
+from official.nlp.modeling.layers.util import tf_function_if_eager
+
+
+@tf.keras.utils.register_keras_serializable(package="Text")
+class Transformer(tf.keras.layers.Layer):
+ """Transformer layer.
+
+ This layer implements the Transformer from "Attention Is All You Need".
+ (https://arxiv.org/abs/1706.03762).
+
+ Arguments:
+ num_attention_heads: Number of attention heads.
+ intermediate_size: Size of the intermediate layer.
+ intermediate_activation: Activation for the intermediate layer.
+ dropout_rate: Dropout probability for the post-attention and output dropout.
+ attention_dropout_rate: Dropout probability for within the attention layer.
+ output_range: the sequence output range, [0, output_range) by slicing the
+ target sequence. `None` means the target sequence is not sliced.
+ kernel_initializer: Initializer for dense layer kernels.
+ bias_initializer: Initializer for dense layer biases.
+ kernel_regularizer: Regularizer for dense layer kernels.
+ bias_regularizer: Regularizer for dense layer biases.
+ activity_regularizer: Regularizer for dense layer activity.
+ kernel_constraint: Constraint for dense layer kernels.
+ bias_constraint: Constraint for dense layer biases.
+ """
+
+ def __init__(self,
+ num_attention_heads,
+ intermediate_size,
+ intermediate_activation,
+ dropout_rate=0.0,
+ attention_dropout_rate=0.0,
+ output_range=None,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ kernel_constraint=None,
+ bias_constraint=None,
+ **kwargs):
+ super(Transformer, self).__init__(**kwargs)
+
+ self._num_heads = num_attention_heads
+ self._intermediate_size = intermediate_size
+ self._intermediate_activation = intermediate_activation
+ self._attention_dropout_rate = attention_dropout_rate
+ self._dropout_rate = dropout_rate
+ self._output_range = output_range
+ self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
+ self._bias_initializer = tf.keras.initializers.get(bias_initializer)
+ self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
+ self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
+ self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
+ self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
+ self._bias_constraint = tf.keras.constraints.get(bias_constraint)
+
+ def build(self, input_shape):
+ input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
+ input_tensor_shape = tf.TensorShape(input_tensor)
+ if len(input_tensor_shape) != 3:
+ raise ValueError("TransformerLayer expects a three-dimensional input of "
+ "shape [batch, sequence, width].")
+ batch_size, sequence_length, hidden_size = input_tensor_shape
+
+ if len(input_shape) == 2:
+ mask_tensor_shape = tf.TensorShape(input_shape[1])
+ expected_mask_tensor_shape = tf.TensorShape(
+ [batch_size, sequence_length, sequence_length])
+ if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
+ raise ValueError("When passing a mask tensor to TransformerLayer, the "
+ "mask tensor must be of shape [batch, "
+ "sequence_length, sequence_length] (here %s). Got a "
+ "mask tensor of shape %s." %
+ (expected_mask_tensor_shape, mask_tensor_shape))
+ if hidden_size % self._num_heads != 0:
+ raise ValueError(
+ "The input size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (hidden_size, self._num_heads))
+ self._attention_head_size = int(hidden_size // self._num_heads)
+
+ self._attention_layer = attention.MultiHeadAttention(
+ num_heads=self._num_heads,
+ key_size=self._attention_head_size,
+ dropout=self._attention_dropout_rate,
+ kernel_initializer=self._kernel_initializer,
+ bias_initializer=self._bias_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activity_regularizer=self._activity_regularizer,
+ kernel_constraint=self._kernel_constraint,
+ bias_constraint=self._bias_constraint,
+ name="self_attention")
+ # pylint: disable=protected-access
+ self._attention_layer.build([input_tensor_shape] * 3)
+ self._attention_output_dense = self._attention_layer._output_dense
+ # pylint: enable=protected-access
+ self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
+ # Use float32 in layernorm for numeric stability.
+ # It is probably safe in mixed_float16, but we haven't validated this yet.
+ self._attention_layer_norm = (
+ tf.keras.layers.LayerNormalization(
+ name="self_attention_layer_norm",
+ axis=-1,
+ epsilon=1e-12,
+ dtype=tf.float32))
+ self._intermediate_dense = dense_einsum.DenseEinsum(
+ output_shape=self._intermediate_size,
+ activation=None,
+ kernel_initializer=self._kernel_initializer,
+ bias_initializer=self._bias_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activity_regularizer=self._activity_regularizer,
+ kernel_constraint=self._kernel_constraint,
+ bias_constraint=self._bias_constraint,
+ name="intermediate")
+ policy = tf.keras.mixed_precision.experimental.global_policy()
+ if policy.name == "mixed_bfloat16":
+ # bfloat16 causes BERT with the LAMB optimizer to not converge
+ # as well, so we use float32.
+ # TODO(b/154538392): Investigate this.
+ policy = tf.float32
+ self._intermediate_activation_layer = tf.keras.layers.Activation(
+ self._intermediate_activation, dtype=policy)
+ self._output_dense = dense_einsum.DenseEinsum(
+ output_shape=hidden_size,
+ kernel_initializer=self._kernel_initializer,
+ bias_initializer=self._bias_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activity_regularizer=self._activity_regularizer,
+ kernel_constraint=self._kernel_constraint,
+ bias_constraint=self._bias_constraint,
+ name="output")
+ self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
+ # Use float32 in layernorm for numeric stability.
+ self._output_layer_norm = tf.keras.layers.LayerNormalization(
+ name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
+
+ super(Transformer, self).build(input_shape)
+
+ def get_config(self):
+ config = {
+ "num_attention_heads":
+ self._num_heads,
+ "intermediate_size":
+ self._intermediate_size,
+ "intermediate_activation":
+ self._intermediate_activation,
+ "dropout_rate":
+ self._dropout_rate,
+ "attention_dropout_rate":
+ self._attention_dropout_rate,
+ "output_range":
+ self._output_range,
+ "kernel_initializer":
+ tf.keras.initializers.serialize(self._kernel_initializer),
+ "bias_initializer":
+ tf.keras.initializers.serialize(self._bias_initializer),
+ "kernel_regularizer":
+ tf.keras.regularizers.serialize(self._kernel_regularizer),
+ "bias_regularizer":
+ tf.keras.regularizers.serialize(self._bias_regularizer),
+ "activity_regularizer":
+ tf.keras.regularizers.serialize(self._activity_regularizer),
+ "kernel_constraint":
+ tf.keras.constraints.serialize(self._kernel_constraint),
+ "bias_constraint":
+ tf.keras.constraints.serialize(self._bias_constraint)
+ }
+ base_config = super(Transformer, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ def call(self, inputs):
+ if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
+ input_tensor, attention_mask = inputs
+ else:
+ input_tensor, attention_mask = (inputs, None)
+
+ if self._output_range:
+ target_tensor = input_tensor[:, 0:self._output_range, :]
+ attention_mask = attention_mask[:, 0:self._output_range, :]
+ else:
+ target_tensor = input_tensor
+ attention_inputs = [target_tensor, input_tensor]
+
+ attention_output = self._attention_layer(attention_inputs, attention_mask)
+ attention_output = self._attention_dropout(attention_output)
+ attention_output = self._attention_layer_norm(target_tensor +
+ attention_output)
+ intermediate_output = self._intermediate_dense(attention_output)
+ intermediate_output = self._intermediate_activation_layer(
+ intermediate_output)
+ layer_output = self._output_dense(intermediate_output)
+ layer_output = self._output_dropout(layer_output)
+ # During mixed precision training, attention_output is from layer norm and
+ # is always fp32 for now. Cast layer_output to fp32 for the subsequent
+ # add.
+ layer_output = tf.cast(layer_output, tf.float32)
+ layer_output = self._output_layer_norm(layer_output + attention_output)
+
+ return layer_output
+
+
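For orientation, here is a minimal usage sketch of the layer defined above. It is not part of the upstream file; the shapes and hyperparameters simply mirror the unit tests added later in this diff (width 80 with 10 heads gives a head size of 8).

import numpy as np
import tensorflow as tf

# Illustrative sketch only: a Transformer block preserves the
# [batch, sequence, width] shape of its input.
layer = Transformer(num_attention_heads=10,
                    intermediate_size=2048,
                    intermediate_activation="relu")
data = tf.keras.Input(shape=(21, 80))      # [batch, sequence, width]
mask = tf.keras.Input(shape=(21, 21))      # [batch, sequence, sequence]
output = layer([data, mask])
model = tf.keras.Model([data, mask], output)

x = 10 * np.random.random_sample((6, 21, 80))
m = np.random.randint(2, size=(6, 21, 21))
print(model.predict([x, m]).shape)         # (6, 21, 80)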
+@tf.keras.utils.register_keras_serializable(package="Text")
+@gin.configurable
+class CompiledTransformer(Transformer):
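+ """Transformer block whose `call` is wrapped in an XLA-compiled `tf.function` when running eagerly."""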
+
+ @tf_function_if_eager(experimental_compile=True)
+ def call(self, inputs):
+ return super(CompiledTransformer, self).call(inputs)
+
+
+@tf.keras.utils.register_keras_serializable(package="Text")
+class TransformerDecoderLayer(tf.keras.layers.Layer):
+ """Single transformer layer for decoder.
+
+ It has three sub-layers:
+ (1) a multi-head self-attention mechanism.
+ (2) an encoder-decoder attention mechanism.
+ (3) a position-wise fully connected feed-forward network.
+
+ Arguments:
+ num_attention_heads: Number of attention heads.
+ intermediate_size: Size of the intermediate layer.
+ intermediate_activation: Activation for the intermediate layer.
+ dropout_rate: Dropout probability for the post-attention and output dropout.
+ attention_dropout_rate: Dropout probability for within the attention layer.
+ multi_channel_cross_attention: Whether to use `MultiChannelAttention` for
+ cross-attention between target sequences and source sequences.
+ kernel_initializer: Initializer for dense layer kernels.
+ bias_initializer: Initializer for dense layer biases.
+ kernel_regularizer: Regularizer for dense layer kernels.
+ bias_regularizer: Regularizer for dense layer biases.
+ activity_regularizer: Regularizer for dense layer activity.
+ kernel_constraint: Constraint for dense layer kernels.
+ bias_constraint: Constraint for dense layer biases.
+ """
+
+ def __init__(self,
+ num_attention_heads,
+ intermediate_size,
+ intermediate_activation,
+ dropout_rate=0.0,
+ attention_dropout_rate=0.0,
+ multi_channel_cross_attention=False,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ kernel_constraint=None,
+ bias_constraint=None,
+ **kwargs):
+ super(TransformerDecoderLayer, self).__init__(**kwargs)
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.intermediate_activation = tf.keras.activations.get(
+ intermediate_activation)
+ self.dropout_rate = dropout_rate
+ self.attention_dropout_rate = attention_dropout_rate
+ self.multi_channel_cross_attention = multi_channel_cross_attention
+ self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
+ self._bias_initializer = tf.keras.initializers.get(bias_initializer)
+ self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
+ self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
+ self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
+ self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
+ self._bias_constraint = tf.keras.constraints.get(bias_constraint)
+ if self.multi_channel_cross_attention:
+ self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
+ else:
+ self._cross_attention_cls = attention.MultiHeadAttention
+
+ def build(self, input_shape):
+ target_tensor_shape = tf.TensorShape(input_shape[0])
+ if len(target_tensor_shape) != 3:
+ raise ValueError("TransformerLayer expects a three-dimensional input of "
+ "shape [batch, sequence, width].")
+ hidden_size = target_tensor_shape[2]
+ if hidden_size % self.num_attention_heads != 0:
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (hidden_size, self.num_attention_heads))
+ self.attention_head_size = int(hidden_size / self.num_attention_heads)
+ # Self attention.
+ self.self_attention = attention.CachedAttention(
+ num_heads=self.num_attention_heads,
+ key_size=self.attention_head_size,
+ dropout=self.attention_dropout_rate,
+ kernel_initializer=self._kernel_initializer,
+ bias_initializer=self._bias_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activity_regularizer=self._activity_regularizer,
+ kernel_constraint=self._kernel_constraint,
+ bias_constraint=self._bias_constraint,
+ name="self_attention")
+ self.self_attention_output_dense = dense_einsum.DenseEinsum(
+ output_shape=hidden_size,
+ num_summed_dimensions=2,
+ kernel_initializer=self._kernel_initializer,
+ bias_initializer=self._bias_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activity_regularizer=self._activity_regularizer,
+ kernel_constraint=self._kernel_constraint,
+ bias_constraint=self._bias_constraint,
+ name="self_attention_output")
+ self.self_attention_dropout = tf.keras.layers.Dropout(
+ rate=self.dropout_rate)
+ self.self_attention_layer_norm = (
+ tf.keras.layers.LayerNormalization(
+ name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
+ # Encoder-decoder attention.
+ self.encdec_attention = self._cross_attention_cls(
+ num_heads=self.num_attention_heads,
+ key_size=self.attention_head_size,
+ dropout=self.attention_dropout_rate,
+ output_shape=hidden_size,
+ kernel_initializer=self._kernel_initializer,
+ bias_initializer=self._bias_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activity_regularizer=self._activity_regularizer,
+ kernel_constraint=self._kernel_constraint,
+ bias_constraint=self._bias_constraint,
+ name="attention/encdec")
+
+ self.encdec_attention_dropout = tf.keras.layers.Dropout(
+ rate=self.dropout_rate)
+ self.encdec_attention_layer_norm = (
+ tf.keras.layers.LayerNormalization(
+ name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))
+
+ # Feed-forward projection.
+ self.intermediate_dense = dense_einsum.DenseEinsum(
+ output_shape=self.intermediate_size,
+ activation=None,
+ kernel_initializer=self._kernel_initializer,
+ bias_initializer=self._bias_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activity_regularizer=self._activity_regularizer,
+ kernel_constraint=self._kernel_constraint,
+ bias_constraint=self._bias_constraint,
+ name="intermediate")
+ self.intermediate_activation_layer = tf.keras.layers.Activation(
+ self.intermediate_activation)
+ self.output_dense = dense_einsum.DenseEinsum(
+ output_shape=hidden_size,
+ kernel_initializer=self._kernel_initializer,
+ bias_initializer=self._bias_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activity_regularizer=self._activity_regularizer,
+ kernel_constraint=self._kernel_constraint,
+ bias_constraint=self._bias_constraint,
+ name="output")
+ self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
+ self.output_layer_norm = tf.keras.layers.LayerNormalization(
+ name="output_layer_norm", axis=-1, epsilon=1e-12)
+ super(TransformerDecoderLayer, self).build(input_shape)
+
+ def common_layers_with_encoder(self):
+ """Gets layer objects that can make a Transformer encoder block."""
+ return [
+ self.self_attention, self.self_attention_layer_norm,
+ self.intermediate_dense, self.output_dense, self.output_layer_norm
+ ]
+
+ def call(self, inputs, cache=None, decode_loop_step=None):
+ if self.multi_channel_cross_attention:
+ if len(inputs) != 5:
+ raise ValueError(
+ "TransformerDecoderLayer must have 5 inputs, when it uses "
+ "multi_channel_cross_attention. But it got: %d" % len(inputs))
+ elif len(inputs) != 4:
+ raise ValueError(
+ "TransformerDecoderLayer must have 4 inputs, but it got: %d" %
+ len(inputs))
+ input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
+ self_attention_inputs = [input_tensor, input_tensor]
+ self_attention_output, cache = self.self_attention(
+ self_attention_inputs,
+ attention_mask=self_attention_mask,
+ cache=cache,
+ decode_loop_step=decode_loop_step)
+ self_attention_output = self.self_attention_dropout(self_attention_output)
+ self_attention_output = self.self_attention_layer_norm(
+ input_tensor + self_attention_output)
+
+ cross_attn_inputs = [self_attention_output, memory]
+ if self.multi_channel_cross_attention:
+ # Accesses the 5-th input tensor for the doc-attention probabilities.
+ cross_attn_inputs.append(inputs[-1])
+ attention_output = self.encdec_attention(cross_attn_inputs, attention_mask)
+ attention_output = self.encdec_attention_dropout(attention_output)
+ attention_output = self.encdec_attention_layer_norm(self_attention_output +
+ attention_output)
+
+ intermediate_output = self.intermediate_dense(attention_output)
+ intermediate_output = self.intermediate_activation_layer(
+ intermediate_output)
+ layer_output = self.output_dense(intermediate_output)
+ layer_output = self.output_dropout(layer_output)
+ layer_output = self.output_layer_norm(layer_output + attention_output)
+ return layer_output, cache
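A minimal decoding sketch, mirroring the cache test later in this diff (again illustrative rather than part of the upstream file): the layer consumes [target, memory, cross-attention mask, self-attention mask] and returns the updated target plus the decoding cache.

import tensorflow as tf

# Illustrative sketch only; hyperparameters follow the decoder unit test.
decoder = TransformerDecoderLayer(num_attention_heads=2,
                                  intermediate_size=32,
                                  intermediate_activation="relu",
                                  dropout_rate=0.1,
                                  attention_dropout_rate=0.1)
target = tf.zeros([2, 4, 16])    # [batch, target_length, hidden_size]
memory = tf.zeros([2, 4, 16])    # encoder output
mask = tf.zeros([2, 4, 4])       # [batch, target_length, source_length]
cache = {                        # empty cache: nothing decoded yet
    "key": tf.zeros([2, 0, 2, 8]),
    "value": tf.zeros([2, 0, 2, 8]),
}
output, cache = decoder([target, memory, mask, mask], cache)
print(output.shape)              # (2, 4, 16)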
diff --git a/models/official/nlp/modeling/layers/transformer_scaffold.py b/models/official/nlp/modeling/layers/transformer_scaffold.py
new file mode 100644
index 0000000000000000000000000000000000000000..d988febfa68a3e45d3919892ba677c85350f71d6
--- /dev/null
+++ b/models/official/nlp/modeling/layers/transformer_scaffold.py
@@ -0,0 +1,285 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras-based transformer scaffold layer."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import gin
+import tensorflow as tf
+
+from official.nlp.modeling.layers import attention
+
+
+@tf.keras.utils.register_keras_serializable(package="Text")
+@gin.configurable
+class TransformerScaffold(tf.keras.layers.Layer):
+ """Transformer scaffold layer.
+
+ This layer implements the Transformer from "Attention Is All You Need"
+ (https://arxiv.org/abs/1706.03762), with a customizable attention layer and
+ feedforward layer option. Users can pass a class to
+ `attention_cls`/`feedforward_cls` and associated config to
+ `attention_cfg`/`feedforward_cfg`, in which case the scaffold will
+ instantiate the class with the config, or pass a class instance to
+ `attention_cls`/`feedforward_cls`.
+
+ Arguments:
+ num_attention_heads: Number of attention heads.
+ intermediate_size: Size of the intermediate layer.
+ intermediate_activation: Activation for the intermediate layer.
+ attention_cls: A class used to instantiate the attention layer, or a layer instance.
+ attention_cfg: The config with which to instantiate `attention_cls`. Ignored
+ if attention_cls is a layer instance or None. If `attention_cls` is a
+ class, but `attention_cfg` is None, following kwargs will be used to
+ instantiate the attention instance:
+ {
+ "num_heads": num_attention_heads,
+ "key_size": int(hidden_size // num_attention_heads),
+ "dropout": attention_dropout_rate,
+ "name": "self_attention"
+ }, where `hidden_size` is the input tensor's last dimension.
+ feedforward_cls: A class used to instantiate the feedforward layer, or a
+ layer instance. If None, the standard feedforward layer described in the
+ "Attention Is All You Need" paper is used. If not None, the instantiated
+ feedforward layer is expected to take the output of the attention sub-layer
+ as input, and its output is this transformer layer's output.
+ feedforward_cfg: The config with which to instantiate `feedforward_cls`.
+ Ignored if feedforward_cls is a layer instance or is None.
+ If `feedforward_cls` is a class, but `feedforward_cfg` is None, following
+ kwargs will be used to instantiate the feedforward instance:
+ {
+ "intermediate_size": intermediate_size,
+ "intermediate_activation": intermediate_activation,
+ "dropout": dropout_rate,
+ "name": "feedforward"
+ }.
+ dropout_rate: Dropout probability for the post-attention and output dropout.
+ attention_dropout_rate: Dropout probability for within the attention layer.
+ kernel_initializer: Initializer for dense layer kernels.
+ bias_initializer: Initializer for dense layer biases.
+ kernel_regularizer: Regularizer for dense layer kernels.
+ bias_regularizer: Regularizer for dense layer biases.
+ activity_regularizer: Regularizer for dense layer activity.
+ kernel_constraint: Constraint for dense layer kernels.
+ bias_constraint: Constraint for dense layer biases.
+ """
+
+ def __init__(self,
+ num_attention_heads,
+ intermediate_size,
+ intermediate_activation,
+ attention_cls=attention.MultiHeadAttention,
+ attention_cfg=None,
+ feedforward_cls=None,
+ feedforward_cfg=None,
+ dropout_rate=0.0,
+ attention_dropout_rate=0.0,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ kernel_constraint=None,
+ bias_constraint=None,
+ **kwargs):
+ super(TransformerScaffold, self).__init__(**kwargs)
+
+ self._attention_cfg = attention_cfg
+ self._attention_cls = attention_cls
+ self._feedforward_cls = feedforward_cls
+ self._feedforward_cfg = feedforward_cfg
+ self._num_heads = num_attention_heads
+ self._intermediate_size = intermediate_size
+ self._intermediate_activation = intermediate_activation
+ self._attention_dropout_rate = attention_dropout_rate
+ self._dropout_rate = dropout_rate
+ self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
+ self._bias_initializer = tf.keras.initializers.get(bias_initializer)
+ self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
+ self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
+ self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
+ self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
+ self._bias_constraint = tf.keras.constraints.get(bias_constraint)
+
+ def build(self, input_shape):
+ input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
+ input_tensor_shape = tf.TensorShape(input_tensor)
+ if len(input_tensor_shape) != 3:
+ raise ValueError(
+ "TransformerScaffold expects a three-dimensional input of "
+ "shape [batch, sequence, width].")
+ batch_size, sequence_length, hidden_size = input_tensor_shape
+
+ if len(input_shape) == 2:
+ mask_tensor_shape = tf.TensorShape(input_shape[1])
+ expected_mask_tensor_shape = tf.TensorShape(
+ [batch_size, sequence_length, sequence_length])
+ if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
+ raise ValueError("When passing a mask tensor to TransformerLayer, the "
+ "mask tensor must be of shape [batch, "
+ "sequence_length, sequence_length] (here %s). Got a "
+ "mask tensor of shape %s." %
+ (expected_mask_tensor_shape, mask_tensor_shape))
+ if hidden_size % self._num_heads != 0:
+ raise ValueError(
+ "The input size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (hidden_size, self._num_heads))
+ self._attention_head_size = int(hidden_size // self._num_heads)
+
+ common_kwargs = dict(
+ kernel_initializer=self._kernel_initializer,
+ bias_initializer=self._bias_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activity_regularizer=self._activity_regularizer,
+ kernel_constraint=self._kernel_constraint,
+ bias_constraint=self._bias_constraint)
+
+ def get_layer_instance(instance_or_cls, config, default_config):
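+ """Returns `instance_or_cls` unchanged if it is already a layer; otherwise instantiates it from `config` (or from `default_config` when `config` is None)."""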
+ if isinstance(instance_or_cls, tf.keras.layers.Layer):
+ return instance_or_cls
+ else:
+ if config is None:
+ return instance_or_cls(**default_config)
+ else:
+ return instance_or_cls(**config)
+
+ default_attention_cfg = {
+ "num_heads": self._num_heads,
+ "key_size": self._attention_head_size,
+ "dropout": self._attention_dropout_rate,
+ "name": "self_attention"
+ }
+ default_attention_cfg.update(common_kwargs)
+ self._attention_layer = get_layer_instance(
+ self._attention_cls,
+ config=self._attention_cfg,
+ default_config=default_attention_cfg)
+
+ if self._feedforward_cls is not None:
+ default_feedforward_cfg = {
+ "intermediate_size": self._intermediate_size,
+ "intermediate_activation": self._intermediate_activation,
+ "dropout": self._dropout_rate,
+ "name": "feedforward",
+ }
+ default_feedforward_cfg.update(common_kwargs)
+ self._feedforward_block = get_layer_instance(
+ self._feedforward_cls,
+ config=self._feedforward_cfg,
+ default_config=default_feedforward_cfg)
+ else:
+ self._feedforward_block = None
+
+ self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
+ # Use float32 in layernorm for numeric stability.
+ # It is probably safe in mixed_float16, but we haven't validated this yet.
+ self._attention_layer_norm = (
+ tf.keras.layers.LayerNormalization(
+ name="self_attention_layer_norm", axis=-1, epsilon=1e-12,
+ dtype=tf.float32))
+
+ if self._feedforward_block is None:
+ self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+ "abc,cd->abd",
+ output_shape=(None, self._intermediate_size),
+ bias_axes="d",
+ name="intermediate",
+ **common_kwargs)
+ policy = tf.keras.mixed_precision.experimental.global_policy()
+ if policy.name == "mixed_bfloat16":
+ # bfloat16 causes BERT with the LAMB optimizer to not converge
+ # as well, so we use float32.
+ # TODO(b/154538392): Investigate this.
+ policy = tf.float32
+ self._intermediate_activation_layer = tf.keras.layers.Activation(
+ self._intermediate_activation, dtype=policy)
+ self._output_dense = tf.keras.layers.experimental.EinsumDense(
+ "abc,cd->abd",
+ output_shape=(None, hidden_size),
+ bias_axes="d",
+ name="output",
+ **common_kwargs)
+
+ self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
+ # Use float32 in layernorm for numeric stability.
+ self._output_layer_norm = tf.keras.layers.LayerNormalization(
+ name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
+
+ super(TransformerScaffold, self).build(input_shape)
+
+ def get_config(self):
+ config = {
+ "attention_cls":
+ self._attention_layer,
+ "feedforward_cls":
+ self._feedforward_block,
+ "num_attention_heads":
+ self._num_heads,
+ "intermediate_size":
+ self._intermediate_size,
+ "intermediate_activation":
+ self._intermediate_activation,
+ "dropout_rate":
+ self._dropout_rate,
+ "attention_dropout_rate":
+ self._attention_dropout_rate,
+ "kernel_initializer":
+ tf.keras.initializers.serialize(self._kernel_initializer),
+ "bias_initializer":
+ tf.keras.initializers.serialize(self._bias_initializer),
+ "kernel_regularizer":
+ tf.keras.regularizers.serialize(self._kernel_regularizer),
+ "bias_regularizer":
+ tf.keras.regularizers.serialize(self._bias_regularizer),
+ "activity_regularizer":
+ tf.keras.regularizers.serialize(self._activity_regularizer),
+ "kernel_constraint":
+ tf.keras.constraints.serialize(self._kernel_constraint),
+ "bias_constraint":
+ tf.keras.constraints.serialize(self._bias_constraint)
+ }
+ base_config = super(TransformerScaffold, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ def call(self, inputs):
+ if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
+ input_tensor, attention_mask = inputs
+ else:
+ input_tensor, attention_mask = (inputs, None)
+
+ attention_inputs = [input_tensor, input_tensor]
+
+ attention_output = self._attention_layer(attention_inputs, attention_mask)
+ attention_output = self._attention_dropout(attention_output)
+ attention_output = self._attention_layer_norm(input_tensor +
+ attention_output)
+ if self._feedforward_block is None:
+ intermediate_output = self._intermediate_dense(attention_output)
+ intermediate_output = self._intermediate_activation_layer(
+ intermediate_output)
+ layer_output = self._output_dense(intermediate_output)
+ layer_output = self._output_dropout(layer_output)
+ # During mixed precision training, attention_output is from layer norm
+ # and is always fp32 for now. Cast layer_output to fp32 for the subsequent
+ # add.
+ layer_output = tf.cast(layer_output, tf.float32)
+ layer_output = self._output_layer_norm(layer_output + attention_output)
+ else:
+ layer_output = self._feedforward_block(attention_output)
+
+ return layer_output
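As a rough sketch of the scaffold mechanism (not part of the upstream file), the stock `MultiHeadAttention` class can be passed together with an explicit config; the keys below are the same ones the scaffold itself uses when it builds the default attention config.

import tensorflow as tf

# Illustrative sketch only: attention_cls + attention_cfg drive instantiation.
layer = TransformerScaffold(
    num_attention_heads=10,
    intermediate_size=2048,
    intermediate_activation="relu",
    attention_cls=attention.MultiHeadAttention,
    attention_cfg={"num_heads": 10, "key_size": 8, "dropout": 0.0,
                   "name": "self_attention"})

data = tf.keras.Input(shape=(21, 80))   # [batch, sequence, width]
output = layer(data)                    # shape preserved: (None, 21, 80)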
diff --git a/models/official/nlp/modeling/layers/transformer_scaffold_test.py b/models/official/nlp/modeling/layers/transformer_scaffold_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad919889569501c1c29a3c0f88f3e1d1621aec3a
--- /dev/null
+++ b/models/official/nlp/modeling/layers/transformer_scaffold_test.py
@@ -0,0 +1,544 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Keras-based transformer block layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling.layers import attention
+from official.nlp.modeling.layers import transformer_scaffold
+
+
+# Test class that wraps a standard attention layer. If this layer is called
+# at any point, the list passed to the config object will be filled with a
+# boolean 'True'. We register this class as a Keras serializable so we can
+# test serialization below.
+@tf.keras.utils.register_keras_serializable(package='TestOnlyAttention')
+class ValidatedAttentionLayer(attention.MultiHeadAttention):
+
+ def __init__(self, call_list, **kwargs):
+ super(ValidatedAttentionLayer, self).__init__(**kwargs)
+ self.list = call_list
+
+ def call(self, inputs, attention_mask=None):
+ self.list.append(True)
+ return super(ValidatedAttentionLayer, self).call(
+ inputs, attention_mask=attention_mask)
+
+ def get_config(self):
+ config = super(ValidatedAttentionLayer, self).get_config()
+ config['call_list'] = []
+ return config
+
+
+# Test class implements a simple feedforward layer. If this layer is called
+# at any point, the list passed to the config object will be filled with a
+# boolean 'True'. We register this class as a Keras serializable so we can
+# test serialization below.
+@tf.keras.utils.register_keras_serializable(package='TestOnlyFeedforward')
+class ValidatedFeedforwardLayer(tf.keras.layers.Layer):
+
+ def __init__(self, call_list, activation, **kwargs):
+ super(ValidatedFeedforwardLayer, self).__init__(**kwargs)
+ self.list = call_list
+ self.activation = activation
+
+ def build(self, input_shape):
+ hidden_size = input_shape.as_list()[-1]
+ self._feedforward_dense = tf.keras.layers.experimental.EinsumDense(
+ '...x,xy->...y',
+ output_shape=hidden_size,
+ bias_axes='y',
+ activation=self.activation,
+ name='feedforward')
+
+ def call(self, inputs):
+ self.list.append(True)
+ return self._feedforward_dense(inputs)
+
+ def get_config(self):
+ config = super(ValidatedFeedforwardLayer, self).get_config()
+ config['call_list'] = []
+ config['activation'] = self.activation
+ return config
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class TransformerLayerTest(keras_parameterized.TestCase):
+
+ def tearDown(self):
+ super(TransformerLayerTest, self).tearDown()
+ tf.keras.mixed_precision.experimental.set_policy('float32')
+
+ def test_layer_creation(self):
+ sequence_length = 21
+ width = 80
+
+ call_list = []
+ attention_layer_cfg = {
+ 'num_heads': 10,
+ 'key_size': 8,
+ 'call_list': call_list,
+ }
+ test_layer = transformer_scaffold.TransformerScaffold(
+ attention_cls=ValidatedAttentionLayer,
+ attention_cfg=attention_layer_cfg,
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu')
+
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ output_tensor = test_layer(data_tensor)
+ # The default output of a transformer layer should be the same as the input.
+ self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
+
+ # If call_list[0] exists and is True, the passed layer class was
+ # instantiated from the given config properly.
+ self.assertNotEmpty(call_list)
+ self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
+
+ def test_layer_creation_with_feedforward_cls(self):
+ sequence_length = 21
+ width = 80
+
+ call_list = []
+ attention_layer_cfg = {
+ 'num_heads': 10,
+ 'key_size': 8,
+ 'call_list': call_list,
+ }
+ feedforward_call_list = []
+ feedforward_layer_cfg = {
+ 'activation': 'relu',
+ 'call_list': feedforward_call_list,
+ }
+ test_layer = transformer_scaffold.TransformerScaffold(
+ attention_cls=ValidatedAttentionLayer,
+ attention_cfg=attention_layer_cfg,
+ feedforward_cls=ValidatedFeedforwardLayer,
+ feedforward_cfg=feedforward_layer_cfg,
+ num_attention_heads=10,
+ intermediate_size=None,
+ intermediate_activation=None)
+
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ output_tensor = test_layer(data_tensor)
+ # The default output of a transformer layer should be the same as the input.
+ self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
+
+ # If call_list[0] exists and is True, the passed layer class was
+ # instantiated from the given config properly.
+ self.assertNotEmpty(call_list)
+ self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
+ self.assertNotEmpty(feedforward_call_list)
+ self.assertTrue(feedforward_call_list[0],
+ "The passed layer class wasn't instantiated.")
+
+ def test_layer_creation_with_mask(self):
+ sequence_length = 21
+ width = 80
+
+ call_list = []
+ attention_layer_cfg = {
+ 'num_heads': 10,
+ 'key_size': 8,
+ 'call_list': call_list,
+ }
+ test_layer = transformer_scaffold.TransformerScaffold(
+ attention_cls=ValidatedAttentionLayer,
+ attention_cfg=attention_layer_cfg,
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu')
+
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ # Create a 2-dimensional input (the first dimension is implicit).
+ mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
+ output_tensor = test_layer([data_tensor, mask_tensor])
+ # The default output of a transformer layer should be the same as the input.
+ self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
+ # If call_list[0] exists and is True, the passed layer class was
+ # instantiated from the given config properly.
+ self.assertNotEmpty(call_list)
+ self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
+
+ def test_layer_creation_with_incorrect_mask_fails(self):
+ sequence_length = 21
+ width = 80
+
+ call_list = []
+ attention_layer_cfg = {
+ 'num_heads': 10,
+ 'key_size': 8,
+ 'call_list': call_list,
+ }
+ test_layer = transformer_scaffold.TransformerScaffold(
+ attention_cls=ValidatedAttentionLayer,
+ attention_cfg=attention_layer_cfg,
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu')
+
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ # Create a 2-dimensional input (the first dimension is implicit).
+ mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length - 3))
+ with self.assertRaisesRegex(ValueError, 'When passing a mask tensor.*'):
+ _ = test_layer([data_tensor, mask_tensor])
+
+ def test_layer_invocation(self):
+ sequence_length = 21
+ width = 80
+
+ call_list = []
+ attention_layer_cfg = {
+ 'num_heads': 10,
+ 'key_size': 8,
+ 'call_list': call_list,
+ }
+ test_layer = transformer_scaffold.TransformerScaffold(
+ attention_cls=ValidatedAttentionLayer,
+ attention_cfg=attention_layer_cfg,
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu')
+
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ output_tensor = test_layer(data_tensor)
+
+ # Create a model from the test layer.
+ model = tf.keras.Model(data_tensor, output_tensor)
+
+ # Invoke the model on test data. We can't validate the output data itself
+ # (the NN is too complex) but this will rule out structural runtime errors.
+ batch_size = 6
+ input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, width))
+ _ = model.predict(input_data)
+ # If call_list[0] exists and is True, the passed layer class was
+ # instantiated from the given config properly.
+ self.assertNotEmpty(call_list)
+ self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
+
+ def test_layer_invocation_with_feedforward_cls(self):
+ sequence_length = 21
+ width = 80
+
+ call_list = []
+ attention_layer_cfg = {
+ 'num_heads': 10,
+ 'key_size': 8,
+ 'call_list': call_list,
+ }
+ feedforward_call_list = []
+ feedforward_layer_cfg = {
+ 'activation': 'relu',
+ 'call_list': feedforward_call_list,
+ }
+ feedforward_layer = ValidatedFeedforwardLayer(**feedforward_layer_cfg)
+ test_layer = transformer_scaffold.TransformerScaffold(
+ attention_cls=ValidatedAttentionLayer,
+ attention_cfg=attention_layer_cfg,
+ feedforward_cls=feedforward_layer,
+ num_attention_heads=10,
+ intermediate_size=None,
+ intermediate_activation=None)
+
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ # Create a 2-dimensional input (the first dimension is implicit).
+ mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
+ output_tensor = test_layer([data_tensor, mask_tensor])
+
+ # Create a model from the test layer.
+ model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
+
+ # Invoke the model on test data. We can't validate the output data itself
+ # (the NN is too complex) but this will rule out structural runtime errors.
+ batch_size = 6
+ input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, width))
+ # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
+ # which here is (batch, sequence_length, sequence_length)
+ mask_data = np.random.randint(
+ 2, size=(batch_size, sequence_length, sequence_length))
+ _ = model.predict([input_data, mask_data])
+ # If call_list[0] exists and is True, the passed layer class was
+ # instantiated from the given config properly.
+ self.assertNotEmpty(call_list)
+ self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
+ self.assertNotEmpty(feedforward_call_list)
+ self.assertTrue(feedforward_call_list[0],
+ "The passed layer class wasn't instantiated.")
+
+ def test_layer_invocation_with_mask(self):
+ sequence_length = 21
+ width = 80
+
+ call_list = []
+ attention_layer_cfg = {
+ 'num_heads': 10,
+ 'key_size': 8,
+ 'call_list': call_list,
+ }
+ test_layer = transformer_scaffold.TransformerScaffold(
+ attention_cls=ValidatedAttentionLayer,
+ attention_cfg=attention_layer_cfg,
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu')
+
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ # Create a 2-dimensional input (the first dimension is implicit).
+ mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
+ output_tensor = test_layer([data_tensor, mask_tensor])
+
+ # Create a model from the test layer.
+ model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
+
+ # Invoke the model on test data. We can't validate the output data itself
+ # (the NN is too complex) but this will rule out structural runtime errors.
+ batch_size = 6
+ input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, width))
+ # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
+ # which here is (batch, sequence_length, sequence_length)
+ mask_data = np.random.randint(
+ 2, size=(batch_size, sequence_length, sequence_length))
+ _ = model.predict([input_data, mask_data])
+ # If call_list[0] exists and is True, the passed layer class was
+ # instantiated from the given config properly.
+ self.assertNotEmpty(call_list)
+ self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
+
+ def test_layer_invocation_with_float16_dtype(self):
+ tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
+ sequence_length = 21
+ width = 80
+
+ call_list = []
+ attention_layer_cfg = {
+ 'num_heads': 10,
+ 'key_size': 8,
+ 'call_list': call_list,
+ }
+ test_layer = transformer_scaffold.TransformerScaffold(
+ attention_cls=ValidatedAttentionLayer,
+ attention_cfg=attention_layer_cfg,
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu')
+
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ # Create a 2-dimensional input (the first dimension is implicit).
+ mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
+ output_tensor = test_layer([data_tensor, mask_tensor])
+
+ # Create a model from the test layer.
+ model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
+
+ # Invoke the model on test data. We can't validate the output data itself
+ # (the NN is too complex) but this will rule out structural runtime errors.
+ batch_size = 6
+ input_data = (10 * np.random.random_sample(
+ (batch_size, sequence_length, width)))
+ # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
+ # which here is (batch, sequence_length, sequence_length)
+ mask_data = np.random.randint(
+ 2, size=(batch_size, sequence_length, sequence_length))
+ _ = model.predict([input_data, mask_data])
+ # If call_list[0] exists and is True, the passed layer class was
+ # instantiated from the given config properly.
+ self.assertNotEmpty(call_list)
+ self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
+
+ def test_transform_with_initializer(self):
+ sequence_length = 21
+ width = 80
+
+ call_list = []
+ attention_layer_cfg = {
+ 'num_heads': 10,
+ 'key_size': 8,
+ 'call_list': call_list,
+ }
+ test_layer = transformer_scaffold.TransformerScaffold(
+ attention_cls=ValidatedAttentionLayer,
+ attention_cfg=attention_layer_cfg,
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu',
+ kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
+
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ output = test_layer(data_tensor)
+ # The default output of a transformer layer should be the same as the input.
+ self.assertEqual(data_tensor.shape.as_list(), output.shape.as_list())
+ # If call_list[0] exists and is True, the passed layer class was
+ # instantiated from the given config properly.
+ self.assertNotEmpty(call_list)
+ self.assertTrue(call_list[0])
+
+ def test_layer_restoration_from_config(self):
+ sequence_length = 21
+ width = 80
+
+ call_list = []
+ attention_layer_cfg = {
+ 'num_heads': 10,
+ 'key_size': 8,
+ 'call_list': call_list,
+ 'name': 'test_layer',
+ }
+ test_layer = transformer_scaffold.TransformerScaffold(
+ attention_cls=ValidatedAttentionLayer,
+ attention_cfg=attention_layer_cfg,
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu')
+
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ # Create a 2-dimensional input (the first dimension is implicit).
+ mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
+ output_tensor = test_layer([data_tensor, mask_tensor])
+
+ # Create a model from the test layer.
+ model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
+
+ # Invoke the model on test data. We can't validate the output data itself
+ # (the NN is too complex) but this will rule out structural runtime errors.
+ batch_size = 6
+ input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, width))
+ # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
+ # which here is (batch, sequence_length, sequence_length)
+ mask_data = np.random.randint(
+ 2, size=(batch_size, sequence_length, sequence_length))
+ pre_serialization_output = model.predict([input_data, mask_data])
+
+ # Serialize the model config. Pass the serialized data through json to
+ # ensure that we can serialize this layer to disk.
+ serialized_data = json.dumps(model.get_config())
+ post_string_serialized_data = json.loads(serialized_data)
+
+ # Create a new model from the old config, and copy the weights. These models
+ # should have identical outputs.
+ new_model = tf.keras.Model.from_config(post_string_serialized_data)
+ new_model.set_weights(model.get_weights())
+ output = new_model.predict([input_data, mask_data])
+
+ self.assertAllClose(pre_serialization_output, output)
+
+ # If the layer was configured correctly, it should have a list attribute
+ # (since it should have the custom class and config passed to it).
+ new_model.summary()
+ new_call_list = new_model.get_layer(
+ name='transformer_scaffold')._attention_layer.list
+ self.assertNotEmpty(new_call_list)
+ self.assertTrue(new_call_list[0],
+ "The passed layer class wasn't instantiated.")
+
+ def test_layer_with_feedforward_cls_restoration_from_config(self):
+ sequence_length = 21
+ width = 80
+
+ call_list = []
+ attention_layer_cfg = {
+ 'num_heads': 10,
+ 'key_size': 8,
+ 'call_list': call_list,
+ 'name': 'test_layer',
+ }
+ feedforward_call_list = []
+ feedforward_layer_cfg = {
+ 'activation': 'relu',
+ 'call_list': feedforward_call_list,
+ }
+ test_layer = transformer_scaffold.TransformerScaffold(
+ attention_cls=ValidatedAttentionLayer,
+ attention_cfg=attention_layer_cfg,
+ feedforward_cls=ValidatedFeedforwardLayer,
+ feedforward_cfg=feedforward_layer_cfg,
+ num_attention_heads=10,
+ intermediate_size=None,
+ intermediate_activation=None)
+
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ # Create a 2-dimensional input (the first dimension is implicit).
+ mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
+ output_tensor = test_layer([data_tensor, mask_tensor])
+
+ # Create a model from the test layer.
+ model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
+
+ # Invoke the model on test data. We can't validate the output data itself
+ # (the NN is too complex) but this will rule out structural runtime errors.
+ batch_size = 6
+ input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, width))
+ # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
+ # which here is (batch, sequence_length, sequence_length)
+ mask_data = np.random.randint(
+ 2, size=(batch_size, sequence_length, sequence_length))
+ pre_serialization_output = model.predict([input_data, mask_data])
+
+ # Serialize the model config. Pass the serialized data through json to
+ # ensure that we can serialize this layer to disk.
+ serialized_data = json.dumps(model.get_config())
+ post_string_serialized_data = json.loads(serialized_data)
+
+ # Create a new model from the old config, and copy the weights. These models
+ # should have identical outputs.
+ new_model = tf.keras.Model.from_config(post_string_serialized_data)
+ new_model.set_weights(model.get_weights())
+ output = new_model.predict([input_data, mask_data])
+
+ self.assertAllClose(pre_serialization_output, output)
+
+ # If the layer was configured correctly, it should have a list attribute
+ # (since it should have the custom class and config passed to it).
+ new_model.summary()
+ new_call_list = new_model.get_layer(
+ name='transformer_scaffold')._attention_layer.list
+ self.assertNotEmpty(new_call_list)
+ self.assertTrue(new_call_list[0],
+ "The passed layer class wasn't instantiated.")
+ new_feedforward_call_list = new_model.get_layer(
+ name='transformer_scaffold')._feedforward_block.list
+ self.assertNotEmpty(new_feedforward_call_list)
+ self.assertTrue(new_feedforward_call_list[0],
+ "The passed layer class wasn't instantiated.")
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/nlp/modeling/layers/transformer_test.py b/models/official/nlp/modeling/layers/transformer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..841feb9948cb69abe1b1b73364b6f09fa2bde836
--- /dev/null
+++ b/models/official/nlp/modeling/layers/transformer_test.py
@@ -0,0 +1,253 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Keras-based transformer block layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling.layers import transformer
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+@parameterized.named_parameters(('base', transformer.Transformer),
+ ('xla', transformer.CompiledTransformer))
+class TransformerLayerTest(keras_parameterized.TestCase):
+
+ def tearDown(self):
+ super(TransformerLayerTest, self).tearDown()
+ tf.keras.mixed_precision.experimental.set_policy('float32')
+
+ def test_layer_creation(self, transformer_cls):
+ test_layer = transformer_cls(
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu')
+ sequence_length = 21
+ width = 80
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ output_tensor = test_layer(data_tensor)
+ # The default output of a transformer layer should be the same as the input.
+ self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
+
+ def test_layer_creation_with_mask(self, transformer_cls):
+ test_layer = transformer_cls(
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu')
+ sequence_length = 21
+ width = 80
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ # Create a 2-dimensional input (the first dimension is implicit).
+ mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
+ output_tensor = test_layer([data_tensor, mask_tensor])
+ # The default output of a transformer layer should be the same as the input.
+ self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
+
+ def test_layer_creation_with_incorrect_mask_fails(self, transformer_cls):
+ test_layer = transformer_cls(
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu')
+ sequence_length = 21
+ width = 80
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ # Create a 2-dimensional input (the first dimension is implicit).
+ mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length - 3))
+ with self.assertRaisesRegex(ValueError, 'When passing a mask tensor.*'):
+ _ = test_layer([data_tensor, mask_tensor])
+
+ def test_layer_invocation(self, transformer_cls):
+ test_layer = transformer_cls(
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu')
+ sequence_length = 21
+ width = 80
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ output_tensor = test_layer(data_tensor)
+
+ # Create a model from the test layer.
+ model = tf.keras.Model(data_tensor, output_tensor)
+
+ # Invoke the model on test data. We can't validate the output data itself
+ # (the NN is too complex) but this will rule out structural runtime errors.
+ batch_size = 6
+ input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, width))
+ _ = model.predict(input_data)
+
+ def test_layer_invocation_with_mask(self, transformer_cls):
+ test_layer = transformer_cls(
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu')
+ sequence_length = 21
+ width = 80
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ # Create a 2-dimensional input (the first dimension is implicit).
+ mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
+ output_tensor = test_layer([data_tensor, mask_tensor])
+
+ # Create a model from the test layer.
+ model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
+
+ # Invoke the model on test data. We can't validate the output data itself
+ # (the NN is too complex) but this will rule out structural runtime errors.
+ batch_size = 6
+ input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, width))
+ # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
+ # which here is (batch, sequence_length, sequence_length)
+ mask_data = np.random.randint(
+ 2, size=(batch_size, sequence_length, sequence_length))
+ _ = model.predict([input_data, mask_data])
+
+ def test_layer_output_range(self, transformer_cls):
+ test_layer = transformer_cls(
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu')
+ sequence_length = 21
+ width = 80
+
+ batch_size = 6
+ input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, width))
+ mask_data = np.random.randint(
+ 2, size=(batch_size, sequence_length, sequence_length))
+ output_tensor = test_layer([input_data, mask_data])
+
+ # The layer only attends to the first token and outputs the first token
+ # embedding.
+ new_layer = transformer_cls(
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu',
+ output_range=1)
+ _ = new_layer([input_data, mask_data])
+ new_layer.set_weights(test_layer.get_weights())
+ new_output_tensor = new_layer([input_data, mask_data])
+ self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :])
+
+ def test_layer_invocation_with_float16_dtype(self, transformer_cls):
+ tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
+ test_layer = transformer_cls(
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu')
+ sequence_length = 21
+ width = 80
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ # Create a 2-dimensional input (the first dimension is implicit).
+ mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
+ output_tensor = test_layer([data_tensor, mask_tensor])
+
+ # Create a model from the test layer.
+ model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
+
+ # Invoke the model on test data. We can't validate the output data itself
+ # (the NN is too complex) but this will rule out structural runtime errors.
+ batch_size = 6
+ input_data = (10 * np.random.random_sample(
+ (batch_size, sequence_length, width)))
+ # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
+ # which here is (batch, sequence_length, sequence_length)
+ mask_data = np.random.randint(
+ 2, size=(batch_size, sequence_length, sequence_length))
+ _ = model.predict([input_data, mask_data])
+
+ def test_transform_with_initializer(self, transformer_cls):
+ test_layer = transformer_cls(
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu',
+ kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
+ sequence_length = 21
+ width = 80
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf.keras.Input(shape=(sequence_length, width))
+ output = test_layer(data_tensor)
+ # The default output of a transformer layer should be the same as the input.
+ self.assertEqual(data_tensor.shape.as_list(), output.shape.as_list())
+
+ def test_dynamic_layer_sequence(self, transformer_cls):
+ test_layer = transformer_cls(
+ num_attention_heads=10,
+ intermediate_size=2048,
+ intermediate_activation='relu',
+ kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
+ # Create a 3-dimensional input (the first dimension is implicit).
+ width = 30
+ input_tensor = tf.keras.Input(shape=(None, width))
+ output_tensor = test_layer(input_tensor)
+ model = tf.keras.Model(input_tensor, output_tensor)
+
+ input_length = 17
+ input_data = np.ones((1, input_length, width))
+ output_data = model.predict(input_data)
+
+ self.assertAllEqual([1, input_length, width], output_data.shape)
+
+
+def _create_cache(batch_size, init_decode_length, num_heads, head_size):
+ return {
+ 'key':
+ tf.zeros([batch_size, init_decode_length, num_heads, head_size],
+ dtype=tf.float32),
+ 'value':
+ tf.zeros([batch_size, init_decode_length, num_heads, head_size],
+ dtype=tf.float32)
+ }
+
+
+@keras_parameterized.run_all_keras_modes
+class TransformerDecoderLayerTest(keras_parameterized.TestCase):
+
+ def test_decoder_block_with_cache(self):
+ num_attention_heads = 2
+ hidden_size = 16
+ decoder_block = transformer.TransformerDecoderLayer(
+ num_attention_heads=num_attention_heads,
+ intermediate_size=32,
+ intermediate_activation='relu',
+ dropout_rate=0.1,
+ attention_dropout_rate=0.1)
+ # Forward path.
+ dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
+ dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
+ inputs = [dummy_tensor, dummy_tensor, dummy_mask, dummy_mask]
+ cache = _create_cache(2, 0, num_attention_heads,
+ hidden_size // num_attention_heads)
+ output, cache = decoder_block(inputs, cache)
+ self.assertEqual(output.shape, (2, 4, hidden_size))
+ self.assertEqual(cache['value'].shape, (2, 4, 2, 8))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/nlp/modeling/layers/util.py b/models/official/nlp/modeling/layers/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..354f216ea4ea743fb48be256126df100abe5cfa9
--- /dev/null
+++ b/models/official/nlp/modeling/layers/util.py
@@ -0,0 +1,51 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras-based transformer block layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import functools
+
+import tensorflow as tf
+
+
+class TfFunctionIfEagerDecorator(object):
+ """Helper decorator function to optionally apply the @tf.function annotation."""
+
+ def __init__(self, **kwargs):
+ self.func_kwargs = kwargs
+
+ def __call__(self, func):
+
+ @functools.wraps(func)
+ def wrapped_func(*args):
+ # TODO(b/150147476, b/150024785): Fix tf.function in TF1 crash.
+ if not hasattr(tf.compat.v1, "executing_eagerly_outside_functions"
+ ) or tf.compat.v1.executing_eagerly_outside_functions():
+ return tf.function(func=func, **self.func_kwargs)(*args)
+ return func(*args)
+
+ # Cache the created function in self._call_impl.
+ if not hasattr(self, "_call_impl"):
+ self._call_impl = wrapped_func
+ return self._call_impl
+
+
+def tf_function_if_eager(**kwargs):
+ """Applies the @tf.function decorator only if running in eager mode."""
+ return TfFunctionIfEagerDecorator(**kwargs)
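As a quick illustration of the helper above, here is a hedged sketch (not part of the diff) of decorating a function so that it is wrapped in `tf.function` only when eager execution is active; the function name and values are invented for illustration.

```python
import tensorflow as tf

from official.nlp.modeling.layers import util


@util.tf_function_if_eager()
def scaled_sum(x, y):
  # Traced by tf.function when running eagerly; called directly (untraced)
  # otherwise, per the wrapper above.
  return tf.reduce_sum(x) + 2.0 * tf.reduce_sum(y)


result = scaled_sum(tf.ones([2, 3]), tf.ones([3]))  # tf.Tensor(12.0)
```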
diff --git a/models/official/nlp/modeling/losses/README.md b/models/official/nlp/modeling/losses/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..522150cfa1518797b488146fae506bfcaf063b8e
--- /dev/null
+++ b/models/official/nlp/modeling/losses/README.md
@@ -0,0 +1,9 @@
+# Losses
+
+Losses contains common loss computations used in NLP tasks; a short usage
+sketch follows the list below.
+
+* `weighted_sparse_categorical_crossentropy_loss` computes per-batch sparse
+categorical crossentropy loss.
+
+* `weighted_sparse_categorical_crossentropy_per_example_loss` computes
+per-example sparse categorical crossentropy loss.
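A hedged usage sketch for these two exports (package path taken from the `__init__.py` below; tensor values are invented). Note that the implementation multiplies predictions by one-hot labels, so log-probabilities (e.g. log-softmax outputs) are the expected inputs, as the legacy compatibility tests later in this diff suggest.

```python
import tensorflow as tf

from official.nlp.modeling import losses

# Log-probabilities for 2 examples over 3 classes.
predictions = tf.math.log_softmax(
    tf.constant([[2.0, 0.5, 0.1],
                 [0.2, 1.5, 0.3]]), axis=-1)
labels = tf.constant([0, 1])

per_example = losses.weighted_sparse_categorical_crossentropy_per_example_loss(
    labels=labels, predictions=predictions)   # shape [2]
batch_loss = losses.weighted_sparse_categorical_crossentropy_loss(
    labels=labels, predictions=predictions)   # scalar (mean over examples)
```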
diff --git a/models/official/nlp/modeling/losses/__init__.py b/models/official/nlp/modeling/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..919bad30809b1a4967ecb7edcb206e92637477db
--- /dev/null
+++ b/models/official/nlp/modeling/losses/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Activations package definition. Subject to change."""
+from official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy import loss as weighted_sparse_categorical_crossentropy_loss
+from official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy import per_example_loss as weighted_sparse_categorical_crossentropy_per_example_loss
diff --git a/models/official/nlp/modeling/losses/weighted_sparse_categorical_crossentropy.py b/models/official/nlp/modeling/losses/weighted_sparse_categorical_crossentropy.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88d8e3665b70be63aaa4aa2f90bb78e4bd9af3f
--- /dev/null
+++ b/models/official/nlp/modeling/losses/weighted_sparse_categorical_crossentropy.py
@@ -0,0 +1,106 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Sparse categorical cross-entropy losses."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+def _adjust_labels(labels, predictions):
+ """Adjust the 'labels' tensor by squeezing it if needed."""
+ labels = tf.cast(labels, tf.int32)
+ if len(predictions.shape) == len(labels.shape):
+ labels = tf.squeeze(labels, [-1])
+ return labels, predictions
+
+
+def _validate_rank(labels, predictions, weights):
+ if weights is not None and len(weights.shape) != len(labels.shape):
+ raise RuntimeError(
+ ("Weight and label tensors were not of the same rank. weights.shape "
+ "was %s, and labels.shape was %s.") %
+ (weights.shape, labels.shape))
+ if (len(predictions.shape) - 1) != len(labels.shape):
+ raise RuntimeError(
+ ("Weighted sparse categorical crossentropy expects `labels` to have a "
+ "rank of one less than `predictions`. labels.shape was %s, and "
+ "predictions.shape was %s.") % (labels.shape, predictions.shape))
+
+
+def per_example_loss(labels, predictions, weights=None):
+ """Calculate a per-example sparse categorical crossentropy loss.
+
+ This loss function assumes that the predictions are post-softmax.
+ Args:
+ labels: The labels to evaluate against. Should be a set of integer indices
+ ranging from 0 to (vocab_size-1).
+ predictions: The network predictions. Should have softmax already applied.
+ weights: An optional weight array of the same shape as the 'labels' array.
+ If None, all examples will be used.
+
+ Returns:
+ A tensor of shape predictions.shape[:-1] containing the per-example
+ loss.
+ """
+ # When using these functions with the Keras core API, we will need to squeeze
+ # the labels tensor - Keras adds a spurious inner dimension.
+ labels, predictions = _adjust_labels(labels, predictions)
+ _validate_rank(labels, predictions, weights)
+
+ labels_one_hot = tf.one_hot(labels, predictions.shape[-1])
+ labels_one_hot = tf.cast(labels_one_hot, predictions.dtype)
+ per_example_loss_data = -tf.reduce_sum(
+ predictions * labels_one_hot, axis=[-1])
+ if weights is not None:
+ weights = tf.cast(weights, per_example_loss_data.dtype)
+ per_example_loss_data = weights * per_example_loss_data
+ return per_example_loss_data
+
+
+def loss(labels, predictions, weights=None):
+ """Calculate a per-batch sparse categorical crossentropy loss.
+
+ This loss function assumes that the predictions are post-softmax.
+ Args:
+ labels: The labels to evaluate against. Should be a set of integer indices
+ ranging from 0 to (vocab_size-1).
+ predictions: The network predictions. Should have softmax already applied.
+ weights: An optional weight array of the same shape as the 'labels' array.
+ If None, all examples will be used.
+
+ Returns:
+ A loss scalar.
+
+ Raises:
+ RuntimeError if the passed tensors do not have the same rank.
+ """
+ # When using these functions with the Keras core API, we will need to squeeze
+ # the labels tensor - Keras adds a spurious inner dimension.
+ labels, predictions = _adjust_labels(labels, predictions)
+ _validate_rank(labels, predictions, weights)
+
+ per_example_loss_data = per_example_loss(labels, predictions, weights)
+
+ if weights is None:
+ return tf.reduce_mean(per_example_loss_data)
+ else:
+ numerator = tf.reduce_sum(per_example_loss_data)
+ weights = tf.cast(weights, predictions.dtype)
+ denominator = tf.reduce_sum(weights) + 1e-5
+ return numerator / denominator
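As a sanity check on the weighting logic above, here is a hedged sketch (values invented) that masks out one of two examples; the batch loss then reduces to the unmasked example's loss divided by the weight sum plus the 1e-5 stabilizer.

```python
import numpy as np
import tensorflow as tf

from official.nlp.modeling.losses import weighted_sparse_categorical_crossentropy

# Log-probabilities for 2 examples over 3 classes.
predictions = tf.math.log(tf.constant([[0.7, 0.2, 0.1],
                                       [0.1, 0.8, 0.1]]))
labels = np.array([0, 1])
weights = np.array([1.0, 0.0])  # mask out the second example

per_example = weighted_sparse_categorical_crossentropy.per_example_loss(
    labels=labels, predictions=predictions, weights=weights)
# per_example ~= [-log(0.7), 0.0]

batch_loss = weighted_sparse_categorical_crossentropy.loss(
    labels=labels, predictions=predictions, weights=weights)
# batch_loss ~= -log(0.7) / (1.0 + 1e-5)
```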
diff --git a/models/official/nlp/modeling/losses/weighted_sparse_categorical_crossentropy_test.py b/models/official/nlp/modeling/losses/weighted_sparse_categorical_crossentropy_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fec2a318b06f0af44b73d200b22d8a22ba88ddf
--- /dev/null
+++ b/models/official/nlp/modeling/losses/weighted_sparse_categorical_crossentropy_test.py
@@ -0,0 +1,380 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for masked LM loss."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling import layers
+from official.nlp.modeling import networks
+from official.nlp.modeling.losses import weighted_sparse_categorical_crossentropy
+
+
+@keras_parameterized.run_all_keras_modes
+class ClassificationLossTest(keras_parameterized.TestCase):
+
+ def create_lm_model(self,
+ vocab_size,
+ sequence_length,
+ hidden_size,
+ num_predictions,
+ output="predictions"):
+ # First, create a transformer stack that we can use to get the LM's
+ # vocabulary weight.
+ xformer_stack = networks.TransformerEncoder(
+ vocab_size=vocab_size,
+ num_layers=1,
+ sequence_length=sequence_length,
+ hidden_size=hidden_size,
+ num_attention_heads=4,
+ )
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ _ = xformer_stack([word_ids, mask, type_ids])
+
+ # Create a maskedLM from the transformer stack.
+ test_layer = layers.MaskedLM(
+ embedding_table=xformer_stack.get_embedding_table(),
+ output=output)
+
+ # Create a model from the masked LM layer.
+ lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
+ masked_lm_positions = tf.keras.Input(
+ shape=(num_predictions,), dtype=tf.int32)
+ output = test_layer(lm_input_tensor, masked_positions=masked_lm_positions)
+ return tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
+
+ def create_classification_model(self, input_width, num_classes):
+ test_object = networks.Classification(
+ input_width=input_width, num_classes=num_classes)
+ # Create a 2-dimensional input (the first dimension is implicit).
+ pooled_data = tf.keras.Input(shape=(input_width,), dtype=tf.float32)
+ output = test_object(pooled_data)
+ return tf.keras.Model(pooled_data, output)
+
+ def test_per_example_loss_3d_input(self):
+ """Test per-example loss with a 3-dimensional input, from a masked LM."""
+ vocab_size = 100
+ sequence_length = 32
+ hidden_size = 64
+ num_predictions = 21
+ model = self.create_lm_model(
+ vocab_size=vocab_size,
+ sequence_length=sequence_length,
+ hidden_size=hidden_size,
+ num_predictions=num_predictions)
+
+ # Get the output of the masked LM.
+ batch_size = 3
+ lm_input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, hidden_size))
+ masked_position_data = np.random.randint(
+ 2, size=(batch_size, num_predictions))
+ output_data = model.predict([lm_input_data, masked_position_data])
+
+ # Calculate per-example loss.
+ labels = np.random.randint(vocab_size, size=(batch_size, num_predictions))
+ per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
+ predictions=output_data, labels=labels)
+
+ # Per-example loss data should have one value per prediction, and those
+ # values shouldn't be zero in this case (as we're using random data).
+ expected_shape = [batch_size, num_predictions]
+ self.assertEqual(expected_shape, per_example_loss_data.shape.as_list())
+ self.assertNotAllClose(
+ tf.zeros_like(per_example_loss_data), per_example_loss_data)
+
+ def test_per_example_loss_2d_input(self):
+ """Test per-example loss with a 2-d input, from a classifier."""
+ input_width = 512
+ num_classes = 10
+ model = self.create_classification_model(input_width, num_classes)
+
+ # Invoke the network as part of a Model.
+ batch_size = 3
+ input_data = 10 * np.random.random_sample((batch_size, input_width))
+ output_data = model.predict(input_data)
+
+ # Calculate per example loss.
+ labels = np.random.randint(num_classes, size=(batch_size))
+ per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
+ predictions=output_data, labels=labels)
+
+ # Per-example loss data should have one value per batch item, and those
+ # values shouldn't be zero in this case (as we're using random data).
+ self.assertEqual([batch_size], per_example_loss_data.shape.as_list())
+ self.assertNotAllClose(
+ tf.zeros_like(per_example_loss_data), per_example_loss_data)
+
+ def test_per_example_loss_weights_3d_input(self):
+ """Test weighted per-example loss with a 3-d input, from a masked LM."""
+ vocab_size = 100
+ sequence_length = 32
+ hidden_size = 64
+ num_predictions = 21
+ model = self.create_lm_model(
+ vocab_size=vocab_size,
+ sequence_length=sequence_length,
+ hidden_size=hidden_size,
+ num_predictions=num_predictions)
+
+ # Get the output of the masked LM.
+ batch_size = 3
+ lm_input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, hidden_size))
+ masked_position_data = np.random.randint(
+ 2, size=(batch_size, num_predictions))
+ output_data = model.predict([lm_input_data, masked_position_data])
+
+ # Calculate per-example loss with weights.
+ labels = np.random.randint(vocab_size, size=(batch_size, num_predictions))
+ weights = np.random.randint(2, size=(batch_size, num_predictions))
+
+ per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
+ predictions=output_data, labels=labels, weights=weights)
+
+ # Weighted per-example loss data should be equivalent to multiplying the
+ # loss tensor by the weights tensor.
+ expected_weighted_loss = per_example_loss_data * weights
+ self.assertAllClose(expected_weighted_loss, per_example_loss_data)
+
+ def test_per_example_loss_weights_2d_input(self):
+ """Test weighted per-example loss with a 2-d input, from a classifier."""
+ input_width = 512
+ num_classes = 10
+ model = self.create_classification_model(input_width, num_classes)
+
+ # Invoke the network as part of a Model.
+ batch_size = 3
+ input_data = 10 * np.random.random_sample((batch_size, input_width))
+ output_data = model.predict(input_data)
+
+ # Calculate per-example loss with weights.
+ labels = np.random.randint(num_classes, size=(batch_size))
+ weights = np.random.randint(2, size=(batch_size))
+
+ per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
+ predictions=output_data, labels=labels, weights=weights)
+
+ # Weighted per-example loss data should be equivalent to multiplying the
+ # loss tensor by the weights tensor.
+ expected_weighted_loss = per_example_loss_data * weights
+ self.assertAllClose(expected_weighted_loss, per_example_loss_data)
+
+ def test_loss_3d_input(self):
+ """Test overall loss with a 3-dimensional input, from a masked LM."""
+ vocab_size = 100
+ sequence_length = 32
+ hidden_size = 64
+ num_predictions = 21
+ model = self.create_lm_model(
+ vocab_size=vocab_size,
+ sequence_length=sequence_length,
+ hidden_size=hidden_size,
+ num_predictions=num_predictions)
+
+ # Get the output of the masked LM.
+ batch_size = 3
+ lm_input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, hidden_size))
+ masked_position_data = np.random.randint(
+ 2, size=(batch_size, num_predictions))
+ output_data = model.predict([lm_input_data, masked_position_data])
+
+ # Calculate loss.
+ labels = np.random.randint(vocab_size, size=(batch_size, num_predictions))
+ weights = np.random.randint(2, size=(batch_size, num_predictions))
+ per_example_loss_data = weighted_sparse_categorical_crossentropy.loss(
+ predictions=output_data, labels=labels, weights=weights)
+
+ # Total loss data should have one value, and that value shouldn't be zero
+ # in this case (as we're using random data).
+ expected_shape = [] # Scalar
+ self.assertEqual(expected_shape, per_example_loss_data.shape.as_list())
+ self.assertNotAllClose(
+ tf.zeros_like(per_example_loss_data), per_example_loss_data)
+
+ def test_loss_2d_input(self):
+ """Test overall loss with a 2-d input, from a classifier."""
+ input_width = 512
+ num_classes = 10
+ model = self.create_classification_model(input_width, num_classes)
+
+ # Invoke the network as part of a Model.
+ batch_size = 3
+ input_data = 10 * np.random.random_sample((batch_size, input_width))
+ output_data = model.predict(input_data)
+
+ # Calculate per example loss.
+ labels = np.random.randint(num_classes, size=(batch_size))
+ loss_data = weighted_sparse_categorical_crossentropy.loss(
+ predictions=output_data, labels=labels)
+
+ # Loss data should have one value only, and that value shouldn't be zero in
+ # this case (as we're using random data).
+ self.assertNotAllClose(0, loss_data)
+
+ def test_loss_weights_3d_input(self):
+ """Test masked loss with a 3-dimensional input, from a masked LM."""
+ vocab_size = 100
+ sequence_length = 32
+ hidden_size = 64
+ num_predictions = 21
+ model = self.create_lm_model(
+ vocab_size=vocab_size,
+ sequence_length=sequence_length,
+ hidden_size=hidden_size,
+ num_predictions=num_predictions)
+
+ # Get the output of the masked LM.
+ batch_size = 3
+ lm_input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, hidden_size))
+ masked_position_data = np.random.randint(
+ 2, size=(batch_size, num_predictions))
+ output_data = model.predict([lm_input_data, masked_position_data])
+
+ # Calculate a fully masked weight tensor. This should give a loss of zero.
+ labels = np.random.randint(vocab_size, size=(batch_size, num_predictions))
+ null_weights = np.zeros((batch_size, num_predictions))
+ weighted_loss_data = weighted_sparse_categorical_crossentropy.loss(
+ predictions=output_data, labels=labels, weights=null_weights)
+
+ # Because the tensor is fully masked, the loss should be 0.
+ self.assertAllClose(0, weighted_loss_data)
+
+ def test_loss_weights_2d_input(self):
+ """Test masked loss with a 2-d input, from a classifier."""
+ input_width = 512
+ num_classes = 10
+ model = self.create_classification_model(input_width, num_classes)
+
+ # Invoke the network as part of a Model.
+ batch_size = 3
+ input_data = 10 * np.random.random_sample((batch_size, input_width))
+ output_data = model.predict(input_data)
+
+ # Calculate a fully masked weight tensor. This should give a loss of zero.
+ labels = np.random.randint(num_classes, size=(batch_size))
+ null_weights = np.zeros((batch_size))
+ weighted_loss_data = weighted_sparse_categorical_crossentropy.loss(
+ predictions=output_data, labels=labels, weights=null_weights)
+
+ # Because the tensor is fully masked, the loss should be 0.
+ self.assertAllClose(0, weighted_loss_data)
+
+ def test_mismatched_predictions_and_labels_ranks_squeezes(self):
+ """Test that the loss asserts when rank(predictions)-1 != rank(labels)."""
+ batch_size = 3
+ output_data = np.random.random_sample((batch_size, 10))
+ labels = np.random.randint(10, size=(batch_size, 1))
+
+ # All that this test tests is that the squeeze is successful.
+ _ = weighted_sparse_categorical_crossentropy.per_example_loss(
+ predictions=output_data, labels=labels)
+
+ def test_mismatched_weights_and_labels_ranks_fail(self):
+ """Test that the loss asserts when rank(predictions) != rank(labels)."""
+ batch_size = 3
+ output_data = np.random.random_sample((batch_size, 10, 15))
+ labels = np.random.randint(10, size=(batch_size, 10))
+ weights = np.random.randint(2, size=(batch_size))
+
+ with self.assertRaisesRegex(RuntimeError, ".*of the same rank.*"):
+ _ = weighted_sparse_categorical_crossentropy.per_example_loss(
+ predictions=output_data, labels=labels, weights=weights)
+ with self.assertRaisesRegex(RuntimeError, ".*of the same rank.*"):
+ _ = weighted_sparse_categorical_crossentropy.loss(
+ predictions=output_data, labels=labels, weights=weights)
+
+ def test_tf_tensor_inputs(self):
+ """Test that tf.Tensors can be used as inputs to the loss function."""
+ batch_size = 3
+ output_data = tf.convert_to_tensor(
+ np.random.random_sample((batch_size, 10, 15)))
+ labels = tf.convert_to_tensor(np.random.randint(10, size=(batch_size, 10)))
+ weights = tf.convert_to_tensor(np.random.randint(2, size=(batch_size, 10)))
+
+ # We're not trying to validate numerical correctness, just ensure that
+ # we can in fact pass tensors to these functions without causing runtime
+ # errors from the shape checking code.
+ _ = weighted_sparse_categorical_crossentropy.per_example_loss(
+ predictions=output_data, labels=labels, weights=weights)
+ _ = weighted_sparse_categorical_crossentropy.loss(
+ predictions=output_data, labels=labels, weights=weights)
+
+ def test_legacy_lm_loss_compatibility(self):
+ """Test to validate computational correctness during refactors."""
+ # This is the empirical output of a masked LM with the following parameters:
+ # batch_size = 3
+ # vocab_size = 5
+ # sequence_length = 4
+ # num_predictions = 2
+ output_data = np.array(
+ [[[-2.5286622, -1.0963473, -1.4925185, -2.4451098, -1.2923571],
+ [-2.7117882, -1.1205841, -4.02187, -0.9966936, -1.5119683]],
+ [[-2.5379114, -0.82479054, -2.287932, -1.3747153, -2.053741],
+ [-2.5379114, -0.82479054, -2.287932, -1.3747153, -2.053741]],
+ [[-2.7760355, -1.8219438, -3.0924666, -1.0779881, -0.9407509],
+ [-2.7760355, -1.8219438, -3.0924666, -1.0779881, -0.9407509]]])
+ labels = np.array([[4, 0], [2, 2], [2, 1]])
+
+ # Validate that per_example loss calculations are the same.
+ per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
+ predictions=output_data, labels=labels)
+ expected_per_example_loss_data = [[1.2923571, 2.7117882],
+ [2.287932, 2.287932],
+ [3.0924666, 1.8219438]]
+ self.assertAllClose(expected_per_example_loss_data, per_example_loss_data)
+
+ # Validate that overall loss calculations are the same.
+ weights = np.array([[1, 0], [0, 0], [0, 0]])
+ loss_data = weighted_sparse_categorical_crossentropy.loss(
+ predictions=output_data, labels=labels, weights=weights)
+ expected_loss_data = 1.2923441
+ self.assertAllClose(expected_loss_data, loss_data)
+
+ def test_legacy_classification_loss_compatibility(self):
+ """Test to validate computational correctness during refactors."""
+ # This is the empirical output of a classifier with the following params:
+ # batch_size = 2
+ # num_classes = 3
+ output_data = np.array([[-1.6094601e-03, -1.0966038e+01, -6.4434357e+00],
+ [-1.6975292e-03, -6.4009643e+00, -1.0226612e+01]])
+ labels = np.array([2, 1])
+
+ # Validate that per_example loss calculations are the same.
+ per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
+ predictions=output_data, labels=labels)
+ expected_per_example_loss_data = [6.4434357, 6.4009643]
+ self.assertAllClose(expected_per_example_loss_data, per_example_loss_data)
+
+ # Validate that overall loss calculations are the same.
+ weights = None
+ loss_data = weighted_sparse_categorical_crossentropy.loss(
+ predictions=output_data, labels=labels, weights=weights)
+ expected_loss_data = 6.4222
+ self.assertAllClose(expected_loss_data, loss_data)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/modeling/models/README.md b/models/official/nlp/modeling/models/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c2e572b6fe07631c17f37b29723fc7a0ac94a81e
--- /dev/null
+++ b/models/official/nlp/modeling/models/README.md
@@ -0,0 +1,22 @@
+# Models
+
+Models are combinations of layers and networks that can be trained; a short
+usage sketch follows the list below.
+
+Several pre-built canned models are provided to train encoder networks. These
+models are intended as both convenience functions and canonical examples.
+
+* [`BertClassifier`](bert_classifier.py) implements a simple classification
+model containing a single classification head using the Classification network.
+It can be used as a regression model as well.
+
+* [`BertTokenClassifier`](bert_token_classifier.py) implements a simple token
+classification model containing a single classification head using the
+TokenClassification network.
+
+* [`BertSpanLabeler`](bert_span_labeler.py) implements a simple single-span
+start-end predictor (that is, a model that predicts two values: a start token
+index and an end token index), suitable for SQuAD-style tasks.
+
+* [`BertPretrainer`](bert_pretrainer.py) implements a masked LM and a
+classification head using the Masked LM and Classification networks,
+respectively.
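A hedged sketch of wiring one of the canned models above around a `TransformerEncoder`; the hyperparameters and tensors are illustrative only and mirror the unit tests later in this diff.

```python
import tensorflow as tf

from official.nlp.modeling import models, networks

encoder = networks.TransformerEncoder(
    vocab_size=100, num_layers=2, sequence_length=4)
classifier = models.BertClassifier(encoder, num_classes=3)

word_ids = tf.constant([[1, 2, 3, 0]], dtype=tf.int32)
mask = tf.constant([[1, 1, 1, 0]], dtype=tf.int32)
type_ids = tf.constant([[0, 0, 0, 0]], dtype=tf.int32)

logits = classifier([word_ids, mask, type_ids])  # shape [1, 3]
```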
diff --git a/models/official/nlp/modeling/models/__init__.py b/models/official/nlp/modeling/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7912e3cf8a70c19a35ef51a123b5ef3d1335617f
--- /dev/null
+++ b/models/official/nlp/modeling/models/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Models package definition."""
+from official.nlp.modeling.models.bert_classifier import BertClassifier
+from official.nlp.modeling.models.bert_pretrainer import BertPretrainer
+from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
+from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier
diff --git a/models/official/nlp/modeling/models/bert_classifier.py b/models/official/nlp/modeling/models/bert_classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..8db6faeba0dbebfe4f7b63cb4c3c4c33607c56cc
--- /dev/null
+++ b/models/official/nlp/modeling/models/bert_classifier.py
@@ -0,0 +1,91 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Trainer network for BERT-style models."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+
+from official.nlp.modeling import networks
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+class BertClassifier(tf.keras.Model):
+ """Classifier model based on a BERT-style transformer-based encoder.
+
+ This is an implementation of the network structure surrounding a transformer
+ encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers
+ for Language Understanding" (https://arxiv.org/abs/1810.04805).
+
+ The BertClassifier allows a user to pass in a transformer stack, and
+ instantiates a classification network based on the passed `num_classes`
+ argument. If `num_classes` is set to 1, a regression network is instantiated.
+
+ Arguments:
+ network: A transformer network. This network should output a sequence output
+ and a classification output. Furthermore, it should expose its embedding
+ table via a "get_embedding_table" method.
+ num_classes: Number of classes to predict from the classification network.
+ initializer: The initializer (if any) to use in the classification networks.
+ Defaults to a Glorot uniform initializer.
+ output: The output style for this network. Can be either 'logits' or
+ 'predictions'.
+ dropout_rate: The dropout probability applied to the pooled output before
+ the classification head.
+ """
+
+ def __init__(self,
+ network,
+ num_classes,
+ initializer='glorot_uniform',
+ output='logits',
+ dropout_rate=0.1,
+ **kwargs):
+ self._self_setattr_tracking = False
+ self._config = {
+ 'network': network,
+ 'num_classes': num_classes,
+ 'initializer': initializer,
+ 'output': output,
+ }
+
+ # We want to use the inputs of the passed network as the inputs to this
+ # Model. To do this, we need to keep a handle to the network inputs for use
+ # when we construct the Model object at the end of init.
+ inputs = network.inputs
+
+ # Because we have a copy of inputs to create this Model object, we can
+ # invoke the Network object with its own input tensors to start the Model.
+ _, cls_output = network(inputs)
+ cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output)
+
+ self.classifier = networks.Classification(
+ input_width=cls_output.shape[-1],
+ num_classes=num_classes,
+ initializer=initializer,
+ output=output,
+ name='classification')
+ predictions = self.classifier(cls_output)
+
+ super(BertClassifier, self).__init__(
+ inputs=inputs, outputs=predictions, **kwargs)
+
+ def get_config(self):
+ return self._config
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ return cls(**config)
diff --git a/models/official/nlp/modeling/models/bert_classifier_test.py b/models/official/nlp/modeling/models/bert_classifier_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dade8508592d5e3344b79e50dce74ccc27526c7
--- /dev/null
+++ b/models/official/nlp/modeling/models/bert_classifier_test.py
@@ -0,0 +1,107 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for BERT trainer network."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling import networks
+from official.nlp.modeling.models import bert_classifier
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class BertClassifierTest(keras_parameterized.TestCase):
+
+ @parameterized.parameters(1, 3)
+ def test_bert_trainer(self, num_classes):
+ """Validate that the Keras object can be created."""
+ # Build a transformer network to use within the BERT trainer.
+ vocab_size = 100
+ sequence_length = 512
+ test_network = networks.TransformerEncoder(
+ vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+
+ # Create a BERT trainer with the created network.
+ bert_trainer_model = bert_classifier.BertClassifier(
+ test_network,
+ num_classes=num_classes)
+
+ # Create a set of 2-dimensional inputs (the first dimension is implicit).
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+
+ # Invoke the trainer model on the inputs. This causes the layer to be built.
+ cls_outs = bert_trainer_model([word_ids, mask, type_ids])
+
+ # Validate that the outputs are of the expected shape.
+ expected_classification_shape = [None, num_classes]
+ self.assertAllEqual(expected_classification_shape, cls_outs.shape.as_list())
+
+ @parameterized.parameters(1, 2)
+ def test_bert_trainer_tensor_call(self, num_classes):
+ """Validate that the Keras object can be invoked."""
+ # Build a transformer network to use within the BERT trainer. (Here, we use
+ # a short sequence_length for convenience.)
+ test_network = networks.TransformerEncoder(
+ vocab_size=100, num_layers=2, sequence_length=2)
+
+ # Create a BERT trainer with the created network.
+ bert_trainer_model = bert_classifier.BertClassifier(
+ test_network, num_classes=num_classes)
+
+ # Create a set of 2-dimensional data tensors to feed into the model.
+ word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
+ mask = tf.constant([[1, 1], [1, 0]], dtype=tf.int32)
+ type_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
+
+ # Invoke the trainer model on the tensors. In Eager mode, this does the
+ # actual calculation. (We can't validate the outputs, since the network is
+ # too complex: this simply ensures we're not hitting runtime errors.)
+ _ = bert_trainer_model([word_ids, mask, type_ids])
+
+ def test_serialize_deserialize(self):
+ """Validate that the BERT trainer can be serialized and deserialized."""
+ # Build a transformer network to use within the BERT trainer. (Here, we use
+ # a short sequence_length for convenience.)
+ test_network = networks.TransformerEncoder(
+ vocab_size=100, num_layers=2, sequence_length=5)
+
+ # Create a BERT trainer with the created network. (Note that all the args
+ # are different, so we can catch any serialization mismatches.)
+ bert_trainer_model = bert_classifier.BertClassifier(
+ test_network, num_classes=4, initializer='zeros', output='predictions')
+
+ # Create another BERT trainer via serialization and deserialization.
+ config = bert_trainer_model.get_config()
+ new_bert_trainer_model = bert_classifier.BertClassifier.from_config(config)
+
+ # Validate that the config can be forced to JSON.
+ _ = new_bert_trainer_model.to_json()
+
+ # If the serialization was successful, the new config should match the old.
+ self.assertAllEqual(bert_trainer_model.get_config(),
+ new_bert_trainer_model.get_config())
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/nlp/modeling/models/bert_pretrainer.py b/models/official/nlp/modeling/models/bert_pretrainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..bce33747f03af723927fba138ddec55160262449
--- /dev/null
+++ b/models/official/nlp/modeling/models/bert_pretrainer.py
@@ -0,0 +1,231 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Trainer network for BERT-style models."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import copy
+from typing import List, Optional
+
+import gin
+import tensorflow as tf
+
+from official.nlp.modeling import layers
+from official.nlp.modeling import networks
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+class BertPretrainer(tf.keras.Model):
+ """BERT network training model.
+
+ This is an implementation of the network structure surrounding a transformer
+ encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers
+ for Language Understanding" (https://arxiv.org/abs/1810.04805).
+
+ The BertPretrainer allows a user to pass in a transformer stack, and
+ instantiates the masked language model and classification networks that are
+ used to create the training objectives.
+
+ Arguments:
+ network: A transformer network. This network should output a sequence output
+ and a classification output.
+ num_classes: Number of classes to predict from the classification network.
+ num_token_predictions: Number of tokens to predict from the masked LM.
+ embedding_table: Embedding table of a network. If None, the
+ "network.get_embedding_table()" is used.
+ activation: The activation (if any) to use in the masked LM network. If
+ None, no activation will be used.
+ initializer: The initializer (if any) to use in the masked LM and
+ classification networks. Defaults to a Glorot uniform initializer.
+ output: The output style for this network. Can be either 'logits' or
+ 'predictions'.
+ """
+
+ def __init__(self,
+ network,
+ num_classes,
+ num_token_predictions,
+ embedding_table=None,
+ activation=None,
+ initializer='glorot_uniform',
+ output='logits',
+ **kwargs):
+ self._self_setattr_tracking = False
+ self._config = {
+ 'network': network,
+ 'num_classes': num_classes,
+ 'num_token_predictions': num_token_predictions,
+ 'activation': activation,
+ 'initializer': initializer,
+ 'output': output,
+ }
+ self.encoder = network
+ # We want to use the inputs of the passed network as the inputs to this
+ # Model. To do this, we need to keep a copy of the network inputs for use
+ # when we construct the Model object at the end of init. (We keep a copy
+ # because we'll be adding another tensor to the copy later.)
+ network_inputs = self.encoder.inputs
+ inputs = copy.copy(network_inputs)
+
+ # Because we have a copy of inputs to create this Model object, we can
+ # invoke the Network object with its own input tensors to start the Model.
+ # Note that, because of how deferred construction happens, we can't use
+ # the copy of the list here - by the time the network is invoked, the list
+ # object contains the additional input added below.
+ sequence_output, cls_output = self.encoder(network_inputs)
+
+ # The encoder network may get outputs from all layers.
+ if isinstance(sequence_output, list):
+ sequence_output = sequence_output[-1]
+ if isinstance(cls_output, list):
+ cls_output = cls_output[-1]
+ sequence_output_length = sequence_output.shape.as_list()[1]
+ if sequence_output_length < num_token_predictions:
+ raise ValueError(
+ "The passed network's output length is %s, which is less than the "
+ 'requested num_token_predictions %s.' %
+ (sequence_output_length, num_token_predictions))
+
+ masked_lm_positions = tf.keras.layers.Input(
+ shape=(num_token_predictions,),
+ name='masked_lm_positions',
+ dtype=tf.int32)
+ inputs.append(masked_lm_positions)
+
+ if embedding_table is None:
+ embedding_table = self.encoder.get_embedding_table()
+ self.masked_lm = layers.MaskedLM(
+ embedding_table=embedding_table,
+ activation=activation,
+ initializer=initializer,
+ output=output,
+ name='cls/predictions')
+ lm_outputs = self.masked_lm(
+ sequence_output, masked_positions=masked_lm_positions)
+
+ self.classification = networks.Classification(
+ input_width=cls_output.shape[-1],
+ num_classes=num_classes,
+ initializer=initializer,
+ output=output,
+ name='classification')
+ sentence_outputs = self.classification(cls_output)
+
+ super(BertPretrainer, self).__init__(
+ inputs=inputs,
+ outputs=dict(masked_lm=lm_outputs, classification=sentence_outputs),
+ **kwargs)
+
+ def get_config(self):
+ return self._config
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ return cls(**config)
+
+
+# TODO(hongkuny): Migrate to BertPretrainerV2 for all usages.
+@tf.keras.utils.register_keras_serializable(package='Text')
+@gin.configurable
+class BertPretrainerV2(tf.keras.Model):
+ """BERT pretraining model V2.
+
+ (Experimental).
+ Adds a masked language model head and optional classification heads on top of
+ the transformer encoder. When num_masked_tokens == 0, no masked LM head is
+ added.
+
+ Arguments:
+ num_masked_tokens: Number of tokens to predict from the masked LM.
+ encoder_network: A transformer network. This network should output a
+ sequence output and a classification output.
+ mlm_activation: The activation (if any) to use in the masked LM network. If
+ None, no activation will be used.
+ mlm_initializer: The initializer (if any) to use in the masked LM. Default
+ to a Glorot uniform initializer.
+ classification_heads: A list of optional head layers to transform on encoder
+ sequence outputs.
+ name: The name of the model.
+ Inputs: Inputs defined by the encoder network, plus `masked_lm_positions` as a
+ dictionary.
+ Outputs: A dictionary of `lm_output` and classification head outputs keyed by
+ head names.
+ """
+
+ def __init__(
+ self,
+ num_masked_tokens: int,
+ encoder_network: tf.keras.Model,
+ mlm_activation=None,
+ mlm_initializer='glorot_uniform',
+ classification_heads: Optional[List[tf.keras.layers.Layer]] = None,
+ name: str = 'bert',
+ **kwargs):
+ self._self_setattr_tracking = False
+ self._config = {
+ 'encoder_network': encoder_network,
+ 'num_masked_tokens': num_masked_tokens,
+ 'mlm_initializer': mlm_initializer,
+ 'classification_heads': classification_heads,
+ 'name': name,
+ }
+
+ self.encoder_network = encoder_network
+ inputs = copy.copy(self.encoder_network.inputs)
+ sequence_output, _ = self.encoder_network(inputs)
+
+ self.classification_heads = classification_heads or []
+ if len(set([cls.name for cls in self.classification_heads])) != len(
+ self.classification_heads):
+ raise ValueError('Classification heads should have unique names.')
+
+ outputs = dict()
+ if num_masked_tokens > 0:
+ self.masked_lm = layers.MaskedLM(
+ embedding_table=self.encoder_network.get_embedding_table(),
+ activation=mlm_activation,
+ initializer=mlm_initializer,
+ name='cls/predictions')
+ masked_lm_positions = tf.keras.layers.Input(
+ shape=(num_masked_tokens,),
+ name='masked_lm_positions',
+ dtype=tf.int32)
+ inputs.append(masked_lm_positions)
+ outputs['lm_output'] = self.masked_lm(
+ sequence_output, masked_positions=masked_lm_positions)
+ for cls_head in self.classification_heads:
+ outputs[cls_head.name] = cls_head(sequence_output)
+
+ super(BertPretrainerV2, self).__init__(
+ inputs=inputs, outputs=outputs, name=name, **kwargs)
+
+ @property
+ def checkpoint_items(self):
+ """Returns a dictionary of items to be additionally checkpointed."""
+ items = dict(encoder=self.encoder_network)
+ for head in self.classification_heads:
+ for key, item in head.checkpoint_items.items():
+ items['.'.join([head.name, key])] = item
+ return items
+
+ def get_config(self):
+ return self._config
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ return cls(**config)
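A hedged sketch of the V2 pretrainer above (shapes are illustrative and mirror the unit test that follows); the extra `masked_lm_positions` input selects which token positions feed the masked LM head.

```python
import tensorflow as tf

from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_pretrainer

encoder = networks.TransformerEncoder(
    vocab_size=100, num_layers=2, sequence_length=4)
pretrainer = bert_pretrainer.BertPretrainerV2(
    num_masked_tokens=2, encoder_network=encoder)

word_ids = tf.constant([[5, 6, 7, 0]], dtype=tf.int32)
mask = tf.constant([[1, 1, 1, 0]], dtype=tf.int32)
type_ids = tf.constant([[0, 0, 0, 0]], dtype=tf.int32)
masked_positions = tf.constant([[1, 2]], dtype=tf.int32)

outputs = pretrainer([word_ids, mask, type_ids, masked_positions])
# outputs['lm_output'] has shape [1, 2, 100].
```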
diff --git a/models/official/nlp/modeling/models/bert_pretrainer_test.py b/models/official/nlp/modeling/models/bert_pretrainer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb9ace5ccf132ec0423276b28fa1e1e473a97290
--- /dev/null
+++ b/models/official/nlp/modeling/models/bert_pretrainer_test.py
@@ -0,0 +1,164 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for BERT trainer network."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling import networks
+from official.nlp.modeling.models import bert_pretrainer
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class BertPretrainerTest(keras_parameterized.TestCase):
+
+ def test_bert_pretrainer(self):
+ """Validate that the Keras object can be created."""
+ # Build a transformer network to use within the BERT trainer.
+ vocab_size = 100
+ sequence_length = 512
+ test_network = networks.TransformerEncoder(
+ vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+
+ # Create a BERT trainer with the created network.
+ num_classes = 3
+ num_token_predictions = 2
+ bert_trainer_model = bert_pretrainer.BertPretrainer(
+ test_network,
+ num_classes=num_classes,
+ num_token_predictions=num_token_predictions)
+
+ # Create a set of 2-dimensional inputs (the first dimension is implicit).
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ masked_lm_positions = tf.keras.Input(
+ shape=(num_token_predictions,), dtype=tf.int32)
+
+ # Invoke the trainer model on the inputs. This causes the layer to be built.
+ outputs = bert_trainer_model(
+ [word_ids, mask, type_ids, masked_lm_positions])
+
+ # Validate that the outputs are of the expected shape.
+ expected_lm_shape = [None, num_token_predictions, vocab_size]
+ expected_classification_shape = [None, num_classes]
+ self.assertAllEqual(expected_lm_shape, outputs['masked_lm'].shape.as_list())
+ self.assertAllEqual(expected_classification_shape,
+ outputs['classification'].shape.as_list())
+
+ def test_bert_trainer_tensor_call(self):
+ """Validate that the Keras object can be invoked."""
+ # Build a transformer network to use within the BERT trainer. (Here, we use
+ # a short sequence_length for convenience.)
+ test_network = networks.TransformerEncoder(
+ vocab_size=100, num_layers=2, sequence_length=2)
+
+ # Create a BERT trainer with the created network.
+ bert_trainer_model = bert_pretrainer.BertPretrainer(
+ test_network, num_classes=2, num_token_predictions=2)
+
+ # Create a set of 2-dimensional data tensors to feed into the model.
+ word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
+ mask = tf.constant([[1, 1], [1, 0]], dtype=tf.int32)
+ type_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
+ lm_mask = tf.constant([[1, 1], [1, 0]], dtype=tf.int32)
+
+ # Invoke the trainer model on the tensors. In Eager mode, this does the
+ # actual calculation. (We can't validate the outputs, since the network is
+ # too complex: this simply ensures we're not hitting runtime errors.)
+ _ = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
+
+ def test_serialize_deserialize(self):
+ """Validate that the BERT trainer can be serialized and deserialized."""
+ # Build a transformer network to use within the BERT trainer. (Here, we use
+ # a short sequence_length for convenience.)
+ test_network = networks.TransformerEncoder(
+ vocab_size=100, num_layers=2, sequence_length=5)
+
+ # Create a BERT trainer with the created network. (Note that all the args
+ # are different, so we can catch any serialization mismatches.)
+ bert_trainer_model = bert_pretrainer.BertPretrainer(
+ test_network, num_classes=4, num_token_predictions=3)
+
+ # Create another BERT trainer via serialization and deserialization.
+ config = bert_trainer_model.get_config()
+ new_bert_trainer_model = bert_pretrainer.BertPretrainer.from_config(config)
+
+ # Validate that the config can be forced to JSON.
+ _ = new_bert_trainer_model.to_json()
+
+ # If the serialization was successful, the new config should match the old.
+ self.assertAllEqual(bert_trainer_model.get_config(),
+ new_bert_trainer_model.get_config())
+
+ def test_bert_pretrainerv2(self):
+ """Validate that the Keras object can be created."""
+ # Build a transformer network to use within the BERT trainer.
+ vocab_size = 100
+ sequence_length = 512
+ test_network = networks.TransformerEncoder(
+ vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+
+ # Create a BERT trainer with the created network.
+ num_token_predictions = 2
+ bert_trainer_model = bert_pretrainer.BertPretrainerV2(
+ encoder_network=test_network, num_masked_tokens=num_token_predictions)
+
+ # Create a set of 2-dimensional inputs (the first dimension is implicit).
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ lm_mask = tf.keras.Input(shape=(num_token_predictions,), dtype=tf.int32)
+
+ # Invoke the trainer model on the inputs. This causes the layer to be built.
+ outputs = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
+
+ # Validate that the outputs are of the expected shape.
+ expected_lm_shape = [None, num_token_predictions, vocab_size]
+ self.assertAllEqual(expected_lm_shape, outputs['lm_output'].shape.as_list())
+
+ def test_v2_serialize_deserialize(self):
+ """Validate that the BERT trainer can be serialized and deserialized."""
+ # Build a transformer network to use within the BERT trainer. (Here, we use
+ # a short sequence_length for convenience.)
+ test_network = networks.TransformerEncoder(
+ vocab_size=100, num_layers=2, sequence_length=5)
+
+ # Create a BERT trainer with the created network. (Note that all the args
+ # are different, so we can catch any serialization mismatches.)
+ bert_trainer_model = bert_pretrainer.BertPretrainerV2(
+ encoder_network=test_network, num_masked_tokens=2)
+
+ # Create another BERT trainer via serialization and deserialization.
+ config = bert_trainer_model.get_config()
+ new_bert_trainer_model = bert_pretrainer.BertPretrainerV2.from_config(
+ config)
+
+ # Validate that the config can be forced to JSON.
+ _ = new_bert_trainer_model.to_json()
+
+ # If the serialization was successful, the new config should match the old.
+ self.assertAllEqual(bert_trainer_model.get_config(),
+ new_bert_trainer_model.get_config())
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/nlp/modeling/models/bert_span_labeler.py b/models/official/nlp/modeling/models/bert_span_labeler.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dd9ab13f518373b6bf82800256d75df9d553750
--- /dev/null
+++ b/models/official/nlp/modeling/models/bert_span_labeler.py
@@ -0,0 +1,103 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Trainer network for BERT-style models."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+
+from official.nlp.modeling import networks
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+class BertSpanLabeler(tf.keras.Model):
+ """Span labeler model based on a BERT-style transformer-based encoder.
+
+ This is an implementation of the network structure surrounding a transformer
+ encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers
+ for Language Understanding" (https://arxiv.org/abs/1810.04805).
+
+ The BertSpanLabeler allows a user to pass in a transformer stack, and
+ instantiates a span labeling network based on a single dense layer.
+
+ Arguments:
+ network: A transformer network. This network should output a sequence output
+ and a classification output. Furthermore, it should expose its embedding
+ table via a "get_embedding_table" method.
+ initializer: The initializer (if any) to use in the span labeling network.
+ Defaults to a Glorot uniform initializer.
+ output: The output style for this network. Can be either 'logits' or
+ 'predictions'.
+ """
+
+ def __init__(self,
+ network,
+ initializer='glorot_uniform',
+ output='logits',
+ **kwargs):
+ self._self_setattr_tracking = False
+ self._network = network
+ self._config = {
+ 'network': network,
+ 'initializer': initializer,
+ 'output': output,
+ }
+
+ # We want to use the inputs of the passed network as the inputs to this
+ # Model. To do this, we need to keep a handle to the network inputs for use
+ # when we construct the Model object at the end of init.
+ inputs = network.inputs
+
+ # Because we have a copy of inputs to create this Model object, we can
+ # invoke the Network object with its own input tensors to start the Model.
+ sequence_output, _ = network(inputs)
+
+ # This is an instance variable for ease of access to the underlying task
+ # network.
+ self.span_labeling = networks.SpanLabeling(
+ input_width=sequence_output.shape[-1],
+ initializer=initializer,
+ output=output,
+ name='span_labeling')
+ start_logits, end_logits = self.span_labeling(sequence_output)
+
+ # Use identity layers wrapped in lambdas to explicitly name the output
+ # tensors. This allows us to use string-keyed dicts in Keras fit/predict/
+ # evaluate calls.
+ start_logits = tf.keras.layers.Lambda(
+ tf.identity, name='start_positions')(
+ start_logits)
+ end_logits = tf.keras.layers.Lambda(
+ tf.identity, name='end_positions')(
+ end_logits)
+
+ logits = [start_logits, end_logits]
+
+ super(BertSpanLabeler, self).__init__(
+ inputs=inputs, outputs=logits, **kwargs)
+
+ @property
+ def checkpoint_items(self):
+ return dict(encoder=self._network)
+
+ def get_config(self):
+ return self._config
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ return cls(**config)
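A hedged sketch showing how the named start/end outputs above allow compilation with a string-keyed loss dict (mirroring the named-compilation test that follows); the choice of loss function is illustrative only.

```python
import tensorflow as tf

from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_span_labeler

encoder = networks.TransformerEncoder(
    vocab_size=100, num_layers=2, sequence_length=4)
span_labeler = bert_span_labeler.BertSpanLabeler(encoder)

# The identity Lambda layers name the outputs, so Keras accepts a dict of
# losses keyed by 'start_positions' and 'end_positions'.
span_labeler.compile(
    optimizer='sgd',
    loss={
        'start_positions': tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True),
        'end_positions': tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True),
    })
```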
diff --git a/models/official/nlp/modeling/models/bert_span_labeler_test.py b/models/official/nlp/modeling/models/bert_span_labeler_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d05e91b52c9ba69a65df7dee4783ffc4113b8a3c
--- /dev/null
+++ b/models/official/nlp/modeling/models/bert_span_labeler_test.py
@@ -0,0 +1,124 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for BERT trainer network."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling import networks
+from official.nlp.modeling.models import bert_span_labeler
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class BertSpanLabelerTest(keras_parameterized.TestCase):
+
+ def test_bert_trainer(self):
+ """Validate that the Keras object can be created."""
+ # Build a transformer network to use within the BERT trainer.
+ vocab_size = 100
+ sequence_length = 512
+ test_network = networks.TransformerEncoder(
+ vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+
+ # Create a BERT trainer with the created network.
+ bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
+
+ # Create a set of 2-dimensional inputs (the first dimension is implicit).
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+
+ # Invoke the trainer model on the inputs. This causes the layer to be built.
+ cls_outs = bert_trainer_model([word_ids, mask, type_ids])
+
+    # Validate that the 2 outputs are of the expected shape.
+ self.assertEqual(2, len(cls_outs))
+ expected_shape = [None, sequence_length]
+ for out in cls_outs:
+ self.assertAllEqual(expected_shape, out.shape.as_list())
+
+ def test_bert_trainer_named_compilation(self):
+ """Validate compilation using explicit output names."""
+ # Build a transformer network to use within the BERT trainer.
+ vocab_size = 100
+ sequence_length = 512
+ test_network = networks.TransformerEncoder(
+ vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+
+ # Create a BERT trainer with the created network.
+ bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
+
+ # Attempt to compile the model using a string-keyed dict of output names to
+ # loss functions. This will validate that the outputs are named as we
+ # expect.
+ bert_trainer_model.compile(
+ optimizer='sgd',
+ loss={
+ 'start_positions': 'mse',
+ 'end_positions': 'mse'
+ })
+
+ def test_bert_trainer_tensor_call(self):
+ """Validate that the Keras object can be invoked."""
+ # Build a transformer network to use within the BERT trainer. (Here, we use
+ # a short sequence_length for convenience.)
+ test_network = networks.TransformerEncoder(
+ vocab_size=100, num_layers=2, sequence_length=2)
+
+ # Create a BERT trainer with the created network.
+ bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
+
+ # Create a set of 2-dimensional data tensors to feed into the model.
+ word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
+ mask = tf.constant([[1, 1], [1, 0]], dtype=tf.int32)
+ type_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
+
+ # Invoke the trainer model on the tensors. In Eager mode, this does the
+ # actual calculation. (We can't validate the outputs, since the network is
+ # too complex: this simply ensures we're not hitting runtime errors.)
+ _ = bert_trainer_model([word_ids, mask, type_ids])
+
+ def test_serialize_deserialize(self):
+ """Validate that the BERT trainer can be serialized and deserialized."""
+ # Build a transformer network to use within the BERT trainer. (Here, we use
+ # a short sequence_length for convenience.)
+ test_network = networks.TransformerEncoder(
+ vocab_size=100, num_layers=2, sequence_length=5)
+
+ # Create a BERT trainer with the created network. (Note that all the args
+ # are different, so we can catch any serialization mismatches.)
+ bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
+
+ # Create another BERT trainer via serialization and deserialization.
+ config = bert_trainer_model.get_config()
+ new_bert_trainer_model = bert_span_labeler.BertSpanLabeler.from_config(
+ config)
+
+ # Validate that the config can be forced to JSON.
+ _ = new_bert_trainer_model.to_json()
+
+ # If the serialization was successful, the new config should match the old.
+ self.assertAllEqual(bert_trainer_model.get_config(),
+ new_bert_trainer_model.get_config())
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/nlp/modeling/models/bert_token_classifier.py b/models/official/nlp/modeling/models/bert_token_classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..4967d71776d685c8631d19d3c07a9fc1e8a25bf6
--- /dev/null
+++ b/models/official/nlp/modeling/models/bert_token_classifier.py
@@ -0,0 +1,97 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Trainer network for BERT-style models."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+
+from official.nlp.modeling import networks
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+class BertTokenClassifier(tf.keras.Model):
+ """Token classifier model based on a BERT-style transformer-based encoder.
+
+ This is an implementation of the network structure surrounding a transformer
+ encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers
+ for Language Understanding" (https://arxiv.org/abs/1810.04805).
+
+ The BertTokenClassifier allows a user to pass in a transformer stack, and
+ instantiates a token classification network based on the passed `num_classes`
+ argument.
+
+ Arguments:
+ network: A transformer network. This network should output a sequence output
+ and a classification output. Furthermore, it should expose its embedding
+ table via a "get_embedding_table" method.
+ num_classes: Number of classes to predict from the classification network.
+ initializer: The initializer (if any) to use in the classification networks.
+ Defaults to a Glorot uniform initializer.
+    output: The output style for this network. Can be either 'logits' or
+      'predictions'.
+    dropout_rate: The dropout probability applied to the encoder's sequence
+      output before the classification head.
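+
+  A minimal usage sketch (mirroring the accompanying unit tests; `word_ids`,
+  `mask`, and `type_ids` stand in for integer input tensors of shape
+  [batch_size, sequence_length]):
+
+    encoder = networks.TransformerEncoder(
+        vocab_size=100, num_layers=2, sequence_length=128)
+    classifier = BertTokenClassifier(encoder, num_classes=3)
+    per_token_logits = classifier([word_ids, mask, type_ids])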
+ """
+
+ def __init__(self,
+ network,
+ num_classes,
+ initializer='glorot_uniform',
+ output='logits',
+ dropout_rate=0.1,
+ **kwargs):
+ self._self_setattr_tracking = False
+ self._network = network
+ self._config = {
+ 'network': network,
+ 'num_classes': num_classes,
+ 'initializer': initializer,
+        'output': output,
+        'dropout_rate': dropout_rate,
+    }
+
+ # We want to use the inputs of the passed network as the inputs to this
+ # Model. To do this, we need to keep a handle to the network inputs for use
+ # when we construct the Model object at the end of init.
+ inputs = network.inputs
+
+ # Because we have a copy of inputs to create this Model object, we can
+ # invoke the Network object with its own input tensors to start the Model.
+ sequence_output, _ = network(inputs)
+ sequence_output = tf.keras.layers.Dropout(
+ rate=dropout_rate)(sequence_output)
+
+ self.classifier = networks.TokenClassification(
+ input_width=sequence_output.shape[-1],
+ num_classes=num_classes,
+ initializer=initializer,
+ output=output,
+ name='classification')
+ predictions = self.classifier(sequence_output)
+
+ super(BertTokenClassifier, self).__init__(
+ inputs=inputs, outputs=predictions, **kwargs)
+
+ @property
+ def checkpoint_items(self):
+ return dict(encoder=self._network)
+
+ def get_config(self):
+ return self._config
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ return cls(**config)
diff --git a/models/official/nlp/modeling/models/bert_token_classifier_test.py b/models/official/nlp/modeling/models/bert_token_classifier_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..970b531cf5673e4040ceb417ffb67a8ef6aea70a
--- /dev/null
+++ b/models/official/nlp/modeling/models/bert_token_classifier_test.py
@@ -0,0 +1,107 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for BERT trainer network."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling import networks
+from official.nlp.modeling.models import bert_token_classifier
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class BertTokenClassifierTest(keras_parameterized.TestCase):
+
+ def test_bert_trainer(self):
+ """Validate that the Keras object can be created."""
+ # Build a transformer network to use within the BERT trainer.
+ vocab_size = 100
+ sequence_length = 512
+ test_network = networks.TransformerEncoder(
+ vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+
+ # Create a BERT trainer with the created network.
+ num_classes = 3
+ bert_trainer_model = bert_token_classifier.BertTokenClassifier(
+ test_network,
+ num_classes=num_classes)
+
+ # Create a set of 2-dimensional inputs (the first dimension is implicit).
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+
+ # Invoke the trainer model on the inputs. This causes the layer to be built.
+ sequence_outs = bert_trainer_model([word_ids, mask, type_ids])
+
+ # Validate that the outputs are of the expected shape.
+ expected_classification_shape = [None, sequence_length, num_classes]
+ self.assertAllEqual(expected_classification_shape,
+ sequence_outs.shape.as_list())
+
+ def test_bert_trainer_tensor_call(self):
+ """Validate that the Keras object can be invoked."""
+ # Build a transformer network to use within the BERT trainer. (Here, we use
+ # a short sequence_length for convenience.)
+ test_network = networks.TransformerEncoder(
+ vocab_size=100, num_layers=2, sequence_length=2)
+
+ # Create a BERT trainer with the created network.
+ bert_trainer_model = bert_token_classifier.BertTokenClassifier(
+ test_network, num_classes=2)
+
+ # Create a set of 2-dimensional data tensors to feed into the model.
+ word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
+ mask = tf.constant([[1, 1], [1, 0]], dtype=tf.int32)
+ type_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
+
+ # Invoke the trainer model on the tensors. In Eager mode, this does the
+ # actual calculation. (We can't validate the outputs, since the network is
+ # too complex: this simply ensures we're not hitting runtime errors.)
+ _ = bert_trainer_model([word_ids, mask, type_ids])
+
+ def test_serialize_deserialize(self):
+ """Validate that the BERT trainer can be serialized and deserialized."""
+ # Build a transformer network to use within the BERT trainer. (Here, we use
+ # a short sequence_length for convenience.)
+ test_network = networks.TransformerEncoder(
+ vocab_size=100, num_layers=2, sequence_length=5)
+
+ # Create a BERT trainer with the created network. (Note that all the args
+ # are different, so we can catch any serialization mismatches.)
+ bert_trainer_model = bert_token_classifier.BertTokenClassifier(
+ test_network, num_classes=4, initializer='zeros', output='predictions')
+
+ # Create another BERT trainer via serialization and deserialization.
+ config = bert_trainer_model.get_config()
+ new_bert_trainer_model = (
+ bert_token_classifier.BertTokenClassifier.from_config(config))
+
+ # Validate that the config can be forced to JSON.
+ _ = new_bert_trainer_model.to_json()
+
+ # If the serialization was successful, the new config should match the old.
+ self.assertAllEqual(bert_trainer_model.get_config(),
+ new_bert_trainer_model.get_config())
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/nlp/modeling/models/electra_pretrainer.py b/models/official/nlp/modeling/models/electra_pretrainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..21fe3a0d9719739fa1adce3d628a3df6b261c177
--- /dev/null
+++ b/models/official/nlp/modeling/models/electra_pretrainer.py
@@ -0,0 +1,307 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Trainer network for ELECTRA models."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import copy
+import tensorflow as tf
+
+from official.modeling import tf_utils
+from official.nlp.modeling import layers
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+class ElectraPretrainer(tf.keras.Model):
+ """ELECTRA network training model.
+
+ This is an implementation of the network structure described in "ELECTRA:
+ Pre-training Text Encoders as Discriminators Rather Than Generators" (
+ https://arxiv.org/abs/2003.10555).
+
+ The ElectraPretrainer allows a user to pass in two transformer models, one for
+ generator, the other for discriminator, and instantiates the masked language
+ model (at generator side) and classification networks (at discriminator side)
+ that are used to create the training objectives.
+
+ Arguments:
+    generator_network: A transformer network for the generator; it should
+      output a sequence output and an optional classification output.
+    discriminator_network: A transformer network for the discriminator; it
+      should output a sequence output.
+    vocab_size: Size of the generator output vocabulary.
+    num_classes: Number of classes to predict from the classification network
+      for the generator network (not used now).
+    sequence_length: Input sequence length.
+    last_hidden_dim: Last hidden dim of the generator transformer output.
+    num_token_predictions: Number of tokens to predict from the masked LM.
+    mlm_activation: The activation (if any) to use in the masked LM and
+      classification networks. If None, no activation will be used.
+    mlm_initializer: The initializer (if any) to use in the masked LM and
+      classification networks. Defaults to a Glorot uniform initializer.
+    output_type: The output style for this network. Can be either 'logits' or
+      'predictions'.
+    disallow_correct: Whether to disallow the generator from generating the
+      exact token that appears in the original sentence.
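+
+  A minimal usage sketch (mirroring the accompanying unit tests; the two
+  encoders and all sizes below are illustrative only, and `inputs` is a dict
+  with the keys listed in `call`, including 'masked_lm_ids'):
+
+    pretrainer = ElectraPretrainer(
+        generator_network=generator_encoder,
+        discriminator_network=discriminator_encoder,
+        vocab_size=100,
+        num_classes=2,
+        sequence_length=128,
+        last_hidden_dim=768,
+        num_token_predictions=20)
+    lm_outs, sentence_outs, disc_logits, disc_label = pretrainer(inputs)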
+ """
+
+ def __init__(self,
+ generator_network,
+ discriminator_network,
+ vocab_size,
+ num_classes,
+ sequence_length,
+ last_hidden_dim,
+ num_token_predictions,
+ mlm_activation=None,
+ mlm_initializer='glorot_uniform',
+ output_type='logits',
+ disallow_correct=False,
+ **kwargs):
+ super(ElectraPretrainer, self).__init__()
+ self._config = {
+ 'generator_network': generator_network,
+ 'discriminator_network': discriminator_network,
+ 'vocab_size': vocab_size,
+ 'num_classes': num_classes,
+ 'sequence_length': sequence_length,
+ 'last_hidden_dim': last_hidden_dim,
+ 'num_token_predictions': num_token_predictions,
+ 'mlm_activation': mlm_activation,
+ 'mlm_initializer': mlm_initializer,
+ 'output_type': output_type,
+ 'disallow_correct': disallow_correct,
+ }
+ for k, v in kwargs.items():
+ self._config[k] = v
+
+ self.generator_network = generator_network
+ self.discriminator_network = discriminator_network
+ self.vocab_size = vocab_size
+ self.num_classes = num_classes
+ self.sequence_length = sequence_length
+ self.last_hidden_dim = last_hidden_dim
+ self.num_token_predictions = num_token_predictions
+ self.mlm_activation = mlm_activation
+ self.mlm_initializer = mlm_initializer
+ self.output_type = output_type
+ self.disallow_correct = disallow_correct
+ self.masked_lm = layers.MaskedLM(
+ embedding_table=generator_network.get_embedding_table(),
+ activation=mlm_activation,
+ initializer=mlm_initializer,
+ output=output_type,
+ name='generator_masked_lm')
+ self.classification = layers.ClassificationHead(
+ inner_dim=last_hidden_dim,
+ num_classes=num_classes,
+ initializer=mlm_initializer,
+ name='generator_classification_head')
+ self.discriminator_head = tf.keras.layers.Dense(
+ units=1, kernel_initializer=mlm_initializer)
+
+ def call(self, inputs):
+ input_word_ids = inputs['input_word_ids']
+ input_mask = inputs['input_mask']
+ input_type_ids = inputs['input_type_ids']
+ masked_lm_positions = inputs['masked_lm_positions']
+
+ ### Generator ###
+ sequence_output, cls_output = self.generator_network(
+ [input_word_ids, input_mask, input_type_ids])
+
+ # The generator encoder network may get outputs from all layers.
+ if isinstance(sequence_output, list):
+ sequence_output = sequence_output[-1]
+ if isinstance(cls_output, list):
+ cls_output = cls_output[-1]
+
+ lm_outputs = self.masked_lm(sequence_output, masked_lm_positions)
+ sentence_outputs = self.classification(sequence_output)
+
+ ### Sampling from generator ###
+ fake_data = self._get_fake_data(inputs, lm_outputs, duplicate=True)
+
+ ### Discriminator ###
+ disc_input = fake_data['inputs']
+ disc_label = fake_data['is_fake_tokens']
+ disc_sequence_output, _ = self.discriminator_network([
+ disc_input['input_word_ids'], disc_input['input_mask'],
+ disc_input['input_type_ids']
+ ])
+
+ # The discriminator encoder network may get outputs from all layers.
+ if isinstance(disc_sequence_output, list):
+ disc_sequence_output = disc_sequence_output[-1]
+
+ disc_logits = self.discriminator_head(disc_sequence_output)
+ disc_logits = tf.squeeze(disc_logits, axis=-1)
+
+ return lm_outputs, sentence_outputs, disc_logits, disc_label
+
+ def _get_fake_data(self, inputs, mlm_logits, duplicate=True):
+ """Generate corrupted data for discriminator.
+
+ Args:
+      inputs: A dict of all inputs, the same as the input to the call()
+        function.
+      mlm_logits: The generator's output logits.
+      duplicate: Whether to copy the original inputs dict during modifications.
+
+    Returns:
+      A dict of generated fake data.
+ """
+ inputs = unmask(inputs, duplicate)
+
+ if self.disallow_correct:
+ disallow = tf.one_hot(
+ inputs['masked_lm_ids'], depth=self.vocab_size, dtype=tf.float32)
+ else:
+ disallow = None
+
+ sampled_tokens = tf.stop_gradient(
+ sample_from_softmax(mlm_logits, disallow=disallow))
+ sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
+ updated_input_ids, masked = scatter_update(inputs['input_word_ids'],
+ sampled_tokids,
+ inputs['masked_lm_positions'])
+ labels = masked * (1 - tf.cast(
+ tf.equal(updated_input_ids, inputs['input_word_ids']), tf.int32))
+
+ updated_inputs = get_updated_inputs(
+ inputs, duplicate, input_word_ids=updated_input_ids)
+
+ return {
+ 'inputs': updated_inputs,
+ 'is_fake_tokens': labels,
+ 'sampled_tokens': sampled_tokens
+ }
+
+ def get_config(self):
+ return self._config
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ return cls(**config)
+
+
+def scatter_update(sequence, updates, positions):
+ """Scatter-update a sequence.
+
+ Args:
+    sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor.
+    updates: A tensor of size batch_size*n_positions(*depth) holding the
+      values to write at "positions".
+    positions: A [batch_size, n_positions] tensor.
+
+ Returns:
+ updated_sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth]
+ tensor of "sequence" with elements at "positions" replaced by the values
+ at "updates". Updates to index 0 are ignored. If there are duplicated
+ positions the update is only applied once.
+ updates_mask: A [batch_size, seq_len] mask tensor of which inputs were
+ updated.
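+
+  For example, with int32 inputs (a small illustrative case; updates that
+  target position 0 would be ignored):
+
+    scatter_update([[1, 2, 3], [4, 5, 6]], [[9], [8]], [[1], [2]])
+    # -> ([[1, 9, 3], [4, 5, 8]], [[0, 1, 0], [0, 0, 1]])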
+ """
+ shape = tf_utils.get_shape_list(sequence, expected_rank=[2, 3])
+ depth_dimension = (len(shape) == 3)
+ if depth_dimension:
+ batch_size, seq_len, depth = shape
+ else:
+ batch_size, seq_len = shape
+ depth = 1
+ sequence = tf.expand_dims(sequence, -1)
+ n_positions = tf_utils.get_shape_list(positions)[1]
+
+ shift = tf.expand_dims(seq_len * tf.range(batch_size), -1)
+ flat_positions = tf.reshape(positions + shift, [-1, 1])
+ flat_updates = tf.reshape(updates, [-1, depth])
+ updates = tf.scatter_nd(flat_positions, flat_updates,
+ [batch_size * seq_len, depth])
+ updates = tf.reshape(updates, [batch_size, seq_len, depth])
+
+ flat_updates_mask = tf.ones([batch_size * n_positions], tf.int32)
+ updates_mask = tf.scatter_nd(flat_positions, flat_updates_mask,
+ [batch_size * seq_len])
+ updates_mask = tf.reshape(updates_mask, [batch_size, seq_len])
+ not_first_token = tf.concat([
+ tf.zeros((batch_size, 1), tf.int32),
+ tf.ones((batch_size, seq_len - 1), tf.int32)
+ ], -1)
+ updates_mask *= not_first_token
+ updates_mask_3d = tf.expand_dims(updates_mask, -1)
+
+ # account for duplicate positions
+ if sequence.dtype == tf.float32:
+ updates_mask_3d = tf.cast(updates_mask_3d, tf.float32)
+ updates /= tf.maximum(1.0, updates_mask_3d)
+ else:
+ assert sequence.dtype == tf.int32
+ updates = tf.math.floordiv(updates, tf.maximum(1, updates_mask_3d))
+ updates_mask = tf.minimum(updates_mask, 1)
+ updates_mask_3d = tf.minimum(updates_mask_3d, 1)
+
+ updated_sequence = (((1 - updates_mask_3d) * sequence) +
+ (updates_mask_3d * updates))
+ if not depth_dimension:
+ updated_sequence = tf.squeeze(updated_sequence, -1)
+
+ return updated_sequence, updates_mask
+
+
+def sample_from_softmax(logits, disallow=None):
+ """Implement softmax sampling using gumbel softmax trick.
+
+ Args:
+ logits: A [batch_size, num_token_predictions, vocab_size] tensor indicating
+ the generator output logits for each masked position.
+ disallow: If `None`, we directly sample tokens from the logits. Otherwise,
+ this is a tensor of size [batch_size, num_token_predictions, vocab_size]
+ indicating the true word id in each masked position.
+
+ Returns:
+ sampled_tokens: A [batch_size, num_token_predictions, vocab_size] one hot
+ tensor indicating the sampled word id in each masked position.
+ """
+ if disallow is not None:
+ logits -= 1000.0 * disallow
+ uniform_noise = tf.random.uniform(
+ tf_utils.get_shape_list(logits), minval=0, maxval=1)
+ gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
+
+ # Here we essentially follow the original paper and use temperature 1.0 for
+ # generator output logits.
+ sampled_tokens = tf.one_hot(
+ tf.argmax(tf.nn.softmax(logits + gumbel_noise), -1, output_type=tf.int32),
+ logits.shape[-1])
+ return sampled_tokens
+
+
+def unmask(inputs, duplicate):
+ unmasked_input_word_ids, _ = scatter_update(inputs['input_word_ids'],
+ inputs['masked_lm_ids'],
+ inputs['masked_lm_positions'])
+ return get_updated_inputs(
+ inputs, duplicate, input_word_ids=unmasked_input_word_ids)
+
+
+def get_updated_inputs(inputs, duplicate, **kwargs):
+ if duplicate:
+ new_inputs = copy.copy(inputs)
+ else:
+ new_inputs = inputs
+ for k, v in kwargs.items():
+ new_inputs[k] = v
+ return new_inputs
diff --git a/models/official/nlp/modeling/models/electra_pretrainer_test.py b/models/official/nlp/modeling/models/electra_pretrainer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5644ab1a0e5812bac4a3202e8d53cffb260550f
--- /dev/null
+++ b/models/official/nlp/modeling/models/electra_pretrainer_test.py
@@ -0,0 +1,156 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for ELECTRA pre trainer network."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling import networks
+from official.nlp.modeling.models import electra_pretrainer
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class ElectraPretrainerTest(keras_parameterized.TestCase):
+
+ def test_electra_pretrainer(self):
+ """Validate that the Keras object can be created."""
+ # Build a transformer network to use within the ELECTRA trainer.
+ vocab_size = 100
+ sequence_length = 512
+ test_generator_network = networks.TransformerEncoder(
+ vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+ test_discriminator_network = networks.TransformerEncoder(
+ vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+
+    # Create an ELECTRA trainer with the created network.
+    num_classes = 3
+    num_token_predictions = 2
+    electra_trainer_model = electra_pretrainer.ElectraPretrainer(
+ generator_network=test_generator_network,
+ discriminator_network=test_discriminator_network,
+ vocab_size=vocab_size,
+ num_classes=num_classes,
+ sequence_length=sequence_length,
+ last_hidden_dim=768,
+ num_token_predictions=num_token_predictions,
+ disallow_correct=True)
+
+ # Create a set of 2-dimensional inputs (the first dimension is implicit).
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ lm_positions = tf.keras.Input(
+ shape=(num_token_predictions,), dtype=tf.int32)
+ lm_ids = tf.keras.Input(shape=(num_token_predictions,), dtype=tf.int32)
+ inputs = {
+ 'input_word_ids': word_ids,
+ 'input_mask': mask,
+ 'input_type_ids': type_ids,
+ 'masked_lm_positions': lm_positions,
+ 'masked_lm_ids': lm_ids
+ }
+
+ # Invoke the trainer model on the inputs. This causes the layer to be built.
+    lm_outs, cls_outs, disc_logits, disc_label = electra_trainer_model(inputs)
+
+ # Validate that the outputs are of the expected shape.
+ expected_lm_shape = [None, num_token_predictions, vocab_size]
+ expected_classification_shape = [None, num_classes]
+ expected_disc_logits_shape = [None, sequence_length]
+ expected_disc_label_shape = [None, sequence_length]
+ self.assertAllEqual(expected_lm_shape, lm_outs.shape.as_list())
+ self.assertAllEqual(expected_classification_shape, cls_outs.shape.as_list())
+ self.assertAllEqual(expected_disc_logits_shape, disc_logits.shape.as_list())
+ self.assertAllEqual(expected_disc_label_shape, disc_label.shape.as_list())
+
+ def test_electra_trainer_tensor_call(self):
+ """Validate that the Keras object can be invoked."""
+ # Build a transformer network to use within the ELECTRA trainer. (Here, we
+ # use a short sequence_length for convenience.)
+ test_generator_network = networks.TransformerEncoder(
+ vocab_size=100, num_layers=4, sequence_length=3)
+ test_discriminator_network = networks.TransformerEncoder(
+ vocab_size=100, num_layers=4, sequence_length=3)
+
+    # Create an ELECTRA trainer with the created network.
+    electra_trainer_model = electra_pretrainer.ElectraPretrainer(
+ generator_network=test_generator_network,
+ discriminator_network=test_discriminator_network,
+ vocab_size=100,
+ num_classes=2,
+ sequence_length=3,
+ last_hidden_dim=768,
+ num_token_predictions=2)
+
+ # Create a set of 2-dimensional data tensors to feed into the model.
+ word_ids = tf.constant([[1, 1, 1], [2, 2, 2]], dtype=tf.int32)
+ mask = tf.constant([[1, 1, 1], [1, 0, 0]], dtype=tf.int32)
+ type_ids = tf.constant([[1, 1, 1], [2, 2, 2]], dtype=tf.int32)
+ lm_positions = tf.constant([[0, 1], [0, 2]], dtype=tf.int32)
+ lm_ids = tf.constant([[10, 20], [20, 30]], dtype=tf.int32)
+ inputs = {
+ 'input_word_ids': word_ids,
+ 'input_mask': mask,
+ 'input_type_ids': type_ids,
+ 'masked_lm_positions': lm_positions,
+ 'masked_lm_ids': lm_ids
+ }
+
+ # Invoke the trainer model on the tensors. In Eager mode, this does the
+ # actual calculation. (We can't validate the outputs, since the network is
+ # too complex: this simply ensures we're not hitting runtime errors.)
+    _, _, _, _ = electra_trainer_model(inputs)
+
+ def test_serialize_deserialize(self):
+ """Validate that the ELECTRA trainer can be serialized and deserialized."""
+    # Build a transformer network to use within the ELECTRA trainer. (Here, we
+    # use a short sequence_length for convenience.)
+ test_generator_network = networks.TransformerEncoder(
+ vocab_size=100, num_layers=4, sequence_length=3)
+ test_discriminator_network = networks.TransformerEncoder(
+ vocab_size=100, num_layers=4, sequence_length=3)
+
+    # Create an ELECTRA trainer with the created network. (Note that all the
+    # args are different, so we can catch any serialization mismatches.)
+ electra_trainer_model = electra_pretrainer.ElectraPretrainer(
+ generator_network=test_generator_network,
+ discriminator_network=test_discriminator_network,
+ vocab_size=100,
+ num_classes=2,
+ sequence_length=3,
+ last_hidden_dim=768,
+ num_token_predictions=2)
+
+    # Create another ELECTRA trainer via serialization and deserialization.
+ config = electra_trainer_model.get_config()
+ new_electra_trainer_model = electra_pretrainer.ElectraPretrainer.from_config(
+ config)
+
+ # Validate that the config can be forced to JSON.
+ _ = new_electra_trainer_model.to_json()
+
+ # If the serialization was successful, the new config should match the old.
+ self.assertAllEqual(electra_trainer_model.get_config(),
+ new_electra_trainer_model.get_config())
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/nlp/modeling/networks/README.md b/models/official/nlp/modeling/networks/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..42347373edc1e2999019c7259dda78bc58138ef2
--- /dev/null
+++ b/models/official/nlp/modeling/networks/README.md
@@ -0,0 +1,27 @@
+# Networks
+
+Networks are combinations of layers (and possibly other networks). They are
+sub-units of models that would not be trained alone. Each network encapsulates
+a common structure, such as a classification head or a transformer encoder, in
+an easily handled object with a standardized configuration.
+
+* [`TransformerEncoder`](transformer_encoder.py) implements a bi-directional
+Transformer-based encoder as described in ["BERT: Pre-training of Deep
+Bidirectional Transformers for Language Understanding"](https://arxiv.org/abs/1810.04805). It includes the embedding lookups,
+transformer layers and pooling layer.
+
+* [`AlbertTransformerEncoder`](albert_transformer_encoder.py) implements a
+Transformer-encoder described in the paper ["ALBERT: A Lite BERT for
+Self-supervised Learning of Language Representations"](https://arxiv.org/abs/1909.11942).
+Compared with [BERT](https://arxiv.org/abs/1810.04805), ALBERT refactorizes
+embedding parameters into two smaller matrices and shares parameters across
+layers.
+
+* [`Classification`](classification.py) contains a single hidden layer, and is
+intended for use as a classification or regression (if number of classes is set
+to 1) head.
+
+* [`TokenClassification`](token_classification.py) contains a single hidden
+layer, and is intended for use as a token classification head.
+
+* [`SpanLabeling`](span_labeling.py) implements a single-span labeler (that is,
+a prediction head that can predict one start and end index per batch item)
+based on a single dense hidden layer. It can be used in the SQuAD task.
+
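+A minimal usage sketch (assuming the `official.nlp.modeling` package is
+importable; the sizes below are illustrative, and `word_ids`, `mask` and
+`type_ids` stand in for integer input tensors):
+
+```python
+from official.nlp.modeling import networks
+
+encoder = networks.TransformerEncoder(
+    vocab_size=30522, num_layers=12, sequence_length=128)
+sequence_output, cls_output = encoder([word_ids, mask, type_ids])
+```
+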
diff --git a/models/official/nlp/modeling/networks/__init__.py b/models/official/nlp/modeling/networks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8443e9f9303326a82212ef3da4e3057218522bb
--- /dev/null
+++ b/models/official/nlp/modeling/networks/__init__.py
@@ -0,0 +1,21 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Networks package definition."""
+from official.nlp.modeling.networks.albert_transformer_encoder import AlbertTransformerEncoder
+from official.nlp.modeling.networks.classification import Classification
+from official.nlp.modeling.networks.encoder_scaffold import EncoderScaffold
+from official.nlp.modeling.networks.span_labeling import SpanLabeling
+from official.nlp.modeling.networks.token_classification import TokenClassification
+from official.nlp.modeling.networks.transformer_encoder import TransformerEncoder
diff --git a/models/official/nlp/modeling/networks/albert_transformer_encoder.py b/models/official/nlp/modeling/networks/albert_transformer_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..398fb00c18c7341765beec50e9b0e6ecaee46e5c
--- /dev/null
+++ b/models/official/nlp/modeling/networks/albert_transformer_encoder.py
@@ -0,0 +1,192 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ALBERT (https://arxiv.org/abs/1810.04805) text encoder network."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+
+from official.modeling import activations
+from official.nlp.modeling import layers
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+class AlbertTransformerEncoder(tf.keras.Model):
+ """ALBERT (https://arxiv.org/abs/1810.04805) text encoder network.
+
+ This network implements the encoder described in the paper "ALBERT: A Lite
+ BERT for Self-supervised Learning of Language Representations"
+ (https://arxiv.org/abs/1909.11942).
+
+ Compared with BERT (https://arxiv.org/abs/1810.04805), ALBERT refactorizes
+ embedding parameters into two smaller matrices and shares parameters
+ across layers.
+
+ The default values for this object are taken from the ALBERT-Base
+ implementation described in the paper.
+
+ Arguments:
+ vocab_size: The size of the token vocabulary.
+ embedding_width: The width of the word embeddings. If the embedding width is
+ not equal to hidden size, embedding parameters will be factorized into two
+ matrices in the shape of ['vocab_size', 'embedding_width'] and
+ ['embedding_width', 'hidden_size'] ('embedding_width' is usually much
+ smaller than 'hidden_size').
+ hidden_size: The size of the transformer hidden layers.
+ num_layers: The number of transformer layers.
+ num_attention_heads: The number of attention heads for each transformer. The
+ hidden size must be divisible by the number of attention heads.
+ sequence_length: The sequence length that this encoder expects. If None, the
+ sequence length is dynamic; if an integer, the encoder will require
+ sequences padded to this length.
+ max_sequence_length: The maximum sequence length that this encoder can
+ consume. If None, max_sequence_length uses the value from sequence length.
+ This determines the variable shape for positional embeddings.
+ type_vocab_size: The number of types that the 'type_ids' input can take.
+ intermediate_size: The intermediate size for the transformer layers.
+ activation: The activation to use for the transformer layers.
+ dropout_rate: The dropout rate to use for the transformer layers.
+ attention_dropout_rate: The dropout rate to use for the attention layers
+ within the transformer layers.
+    initializer: The initializer to use for all weights in this encoder.
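+
+  A minimal usage sketch (mirroring the accompanying unit tests; all sizes are
+  illustrative, and `word_ids`, `mask` and `type_ids` stand in for integer
+  input tensors of shape [batch_size, sequence_length]):
+
+    encoder = AlbertTransformerEncoder(
+        vocab_size=100,
+        embedding_width=8,
+        hidden_size=32,
+        num_layers=3,
+        num_attention_heads=2,
+        sequence_length=128)
+    sequence_output, cls_output = encoder([word_ids, mask, type_ids])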
+ """
+
+ def __init__(self,
+ vocab_size,
+ embedding_width=128,
+ hidden_size=768,
+ num_layers=12,
+ num_attention_heads=12,
+ sequence_length=512,
+ max_sequence_length=None,
+ type_vocab_size=16,
+ intermediate_size=3072,
+ activation=activations.gelu,
+ dropout_rate=0.1,
+ attention_dropout_rate=0.1,
+ initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ **kwargs):
+ activation = tf.keras.activations.get(activation)
+ initializer = tf.keras.initializers.get(initializer)
+
+ if not max_sequence_length:
+ max_sequence_length = sequence_length
+ self._self_setattr_tracking = False
+ self._config_dict = {
+ 'vocab_size': vocab_size,
+ 'embedding_width': embedding_width,
+ 'hidden_size': hidden_size,
+ 'num_layers': num_layers,
+ 'num_attention_heads': num_attention_heads,
+ 'sequence_length': sequence_length,
+ 'max_sequence_length': max_sequence_length,
+ 'type_vocab_size': type_vocab_size,
+ 'intermediate_size': intermediate_size,
+ 'activation': tf.keras.activations.serialize(activation),
+ 'dropout_rate': dropout_rate,
+ 'attention_dropout_rate': attention_dropout_rate,
+ 'initializer': tf.keras.initializers.serialize(initializer),
+ }
+
+ word_ids = tf.keras.layers.Input(
+ shape=(sequence_length,), dtype=tf.int32, name='input_word_ids')
+ mask = tf.keras.layers.Input(
+ shape=(sequence_length,), dtype=tf.int32, name='input_mask')
+ type_ids = tf.keras.layers.Input(
+ shape=(sequence_length,), dtype=tf.int32, name='input_type_ids')
+
+ if embedding_width is None:
+ embedding_width = hidden_size
+ self._embedding_layer = layers.OnDeviceEmbedding(
+ vocab_size=vocab_size,
+ embedding_width=embedding_width,
+ initializer=initializer,
+ name='word_embeddings')
+ word_embeddings = self._embedding_layer(word_ids)
+
+ # Always uses dynamic slicing for simplicity.
+ self._position_embedding_layer = layers.PositionEmbedding(
+ initializer=initializer,
+ use_dynamic_slicing=True,
+ max_sequence_length=max_sequence_length,
+ name='position_embedding')
+ position_embeddings = self._position_embedding_layer(word_embeddings)
+
+ type_embeddings = (
+ layers.OnDeviceEmbedding(
+ vocab_size=type_vocab_size,
+ embedding_width=embedding_width,
+ initializer=initializer,
+ use_one_hot=True,
+ name='type_embeddings')(type_ids))
+
+ embeddings = tf.keras.layers.Add()(
+ [word_embeddings, position_embeddings, type_embeddings])
+ embeddings = (
+ tf.keras.layers.LayerNormalization(
+ name='embeddings/layer_norm',
+ axis=-1,
+ epsilon=1e-12,
+ dtype=tf.float32)(embeddings))
+ embeddings = (tf.keras.layers.Dropout(rate=dropout_rate)(embeddings))
+ # We project the 'embedding' output to 'hidden_size' if it is not already
+ # 'hidden_size'.
+ if embedding_width != hidden_size:
+ embeddings = tf.keras.layers.experimental.EinsumDense(
+ '...x,xy->...y',
+ output_shape=hidden_size,
+ bias_axes='y',
+ kernel_initializer=initializer,
+ name='embedding_projection')(
+ embeddings)
+
+ data = embeddings
+ attention_mask = layers.SelfAttentionMask()([data, mask])
+ shared_layer = layers.Transformer(
+ num_attention_heads=num_attention_heads,
+ intermediate_size=intermediate_size,
+ intermediate_activation=activation,
+ dropout_rate=dropout_rate,
+ attention_dropout_rate=attention_dropout_rate,
+ kernel_initializer=initializer,
+ name='transformer')
+ for _ in range(num_layers):
+ data = shared_layer([data, attention_mask])
+
+ first_token_tensor = (
+ tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(data)
+ )
+ cls_output = tf.keras.layers.Dense(
+ units=hidden_size,
+ activation='tanh',
+ kernel_initializer=initializer,
+ name='pooler_transform')(
+ first_token_tensor)
+
+ super(AlbertTransformerEncoder, self).__init__(
+ inputs=[word_ids, mask, type_ids], outputs=[data, cls_output], **kwargs)
+
+ def get_embedding_table(self):
+ return self._embedding_layer.embeddings
+
+ def get_config(self):
+ return self._config_dict
+
+ @classmethod
+ def from_config(cls, config):
+ return cls(**config)
diff --git a/models/official/nlp/modeling/networks/albert_transformer_encoder_test.py b/models/official/nlp/modeling/networks/albert_transformer_encoder_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..44368e494ae04dd9b92c63987e6881aabd8ff4c2
--- /dev/null
+++ b/models/official/nlp/modeling/networks/albert_transformer_encoder_test.py
@@ -0,0 +1,174 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for ALBERT transformer-based text encoder network."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling.networks import albert_transformer_encoder
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
+
+ def tearDown(self):
+ super(AlbertTransformerEncoderTest, self).tearDown()
+ tf.keras.mixed_precision.experimental.set_policy("float32")
+
+ @parameterized.named_parameters(
+ dict(testcase_name="default", expected_dtype=tf.float32),
+ dict(
+ testcase_name="with_float16_dtype",
+ expected_dtype=tf.float16),
+ )
+ def test_network_creation(self, expected_dtype):
+ hidden_size = 32
+ sequence_length = 21
+
+ kwargs = dict(
+ vocab_size=100,
+ hidden_size=hidden_size,
+ sequence_length=sequence_length,
+ num_attention_heads=2,
+ num_layers=3)
+ if expected_dtype == tf.float16:
+ tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
+
+ # Create a small TransformerEncoder for testing.
+ test_network = albert_transformer_encoder.AlbertTransformerEncoder(**kwargs)
+
+ # Create the inputs (note that the first dimension is implicit).
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ data, pooled = test_network([word_ids, mask, type_ids])
+
+ expected_data_shape = [None, sequence_length, hidden_size]
+ expected_pooled_shape = [None, hidden_size]
+ self.assertAllEqual(expected_data_shape, data.shape.as_list())
+ self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
+
+ # If float_dtype is set to float16, the data output is float32 (from a layer
+ # norm) and pool output should be float16.
+ self.assertEqual(tf.float32, data.dtype)
+ self.assertEqual(expected_dtype, pooled.dtype)
+
+    # ALBERT has additional 'embedding_hidden_mapping_in' weights and
+ # it shares transformer weights.
+ self.assertNotEmpty(
+ [x for x in test_network.weights if "embedding_projection/" in x.name])
+ self.assertNotEmpty(
+ [x for x in test_network.weights if "transformer/" in x.name])
+ self.assertEmpty(
+ [x for x in test_network.weights if "transformer/layer" in x.name])
+
+ def test_network_invocation(self):
+ hidden_size = 32
+ sequence_length = 21
+ vocab_size = 57
+ num_types = 7
+ # Create a small TransformerEncoder for testing.
+ test_network = albert_transformer_encoder.AlbertTransformerEncoder(
+ vocab_size=vocab_size,
+ embedding_width=8,
+ hidden_size=hidden_size,
+ sequence_length=sequence_length,
+ num_attention_heads=2,
+ num_layers=3,
+ type_vocab_size=num_types)
+ self.assertTrue(
+ test_network._position_embedding_layer._use_dynamic_slicing)
+ # Create the inputs (note that the first dimension is implicit).
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ data, pooled = test_network([word_ids, mask, type_ids])
+
+ # Create a model based off of this network:
+ model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
+
+ # Invoke the model. We can't validate the output data here (the model is too
+ # complex) but this will catch structural runtime errors.
+ batch_size = 3
+ word_id_data = np.random.randint(
+ vocab_size, size=(batch_size, sequence_length))
+ mask_data = np.random.randint(2, size=(batch_size, sequence_length))
+ type_id_data = np.random.randint(
+ num_types, size=(batch_size, sequence_length))
+ _ = model.predict([word_id_data, mask_data, type_id_data])
+
+ # Creates a TransformerEncoder with max_sequence_length != sequence_length
+ max_sequence_length = 128
+ test_network = albert_transformer_encoder.AlbertTransformerEncoder(
+ vocab_size=vocab_size,
+ embedding_width=8,
+ hidden_size=hidden_size,
+ sequence_length=sequence_length,
+ max_sequence_length=max_sequence_length,
+ num_attention_heads=2,
+ num_layers=3,
+ type_vocab_size=num_types)
+ self.assertTrue(test_network._position_embedding_layer._use_dynamic_slicing)
+ model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
+ _ = model.predict([word_id_data, mask_data, type_id_data])
+
+ def test_serialize_deserialize(self):
+ tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
+ # Create a network object that sets all of its config options.
+ kwargs = dict(
+ vocab_size=100,
+ embedding_width=8,
+ hidden_size=32,
+ num_layers=3,
+ num_attention_heads=2,
+ sequence_length=21,
+ max_sequence_length=21,
+ type_vocab_size=12,
+ intermediate_size=1223,
+ activation="relu",
+ dropout_rate=0.05,
+ attention_dropout_rate=0.22,
+ initializer="glorot_uniform")
+ network = albert_transformer_encoder.AlbertTransformerEncoder(**kwargs)
+
+ expected_config = dict(kwargs)
+ expected_config["activation"] = tf.keras.activations.serialize(
+ tf.keras.activations.get(expected_config["activation"]))
+ expected_config["initializer"] = tf.keras.initializers.serialize(
+ tf.keras.initializers.get(expected_config["initializer"]))
+ self.assertEqual(network.get_config(), expected_config)
+
+ # Create another network object from the first object's config.
+ new_network = (
+ albert_transformer_encoder.AlbertTransformerEncoder.from_config(
+ network.get_config()))
+
+ # Validate that the config can be forced to JSON.
+ _ = new_network.to_json()
+
+ # If the serialization was successful, the new config should match the old.
+ self.assertAllEqual(network.get_config(), new_network.get_config())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/modeling/networks/classification.py b/models/official/nlp/modeling/networks/classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc326136cd18593bc5e06dd2f68a1e0da17a1409
--- /dev/null
+++ b/models/official/nlp/modeling/networks/classification.py
@@ -0,0 +1,91 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Classification and regression network."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+class Classification(tf.keras.Model):
+ """Classification network head for BERT modeling.
+
+ This network implements a simple classifier head based on a dense layer. If
+ num_classes is one, it can be considered as a regression problem.
+
+ Arguments:
+ input_width: The innermost dimension of the input tensor to this network.
+ num_classes: The number of classes that this network should classify to. If
+ equal to 1, a regression problem is assumed.
+ activation: The activation, if any, for the dense layer in this network.
+    initializer: The initializer for the dense layer in this network. Defaults
+      to a Glorot uniform initializer.
+ output: The output style for this network. Can be either 'logits' or
+ 'predictions'.
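+
+  A minimal usage sketch (mirroring the accompanying unit tests;
+  `pooled_output` stands in for a float tensor of shape
+  [batch_size, input_width]):
+
+    head = Classification(input_width=768, num_classes=3)
+    logits = head(pooled_output)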
+ """
+
+ def __init__(self,
+ input_width,
+ num_classes,
+ initializer='glorot_uniform',
+ output='logits',
+ **kwargs):
+ self._self_setattr_tracking = False
+ self._config_dict = {
+ 'input_width': input_width,
+ 'num_classes': num_classes,
+ 'initializer': initializer,
+ 'output': output,
+ }
+
+ cls_output = tf.keras.layers.Input(
+ shape=(input_width,), name='cls_output', dtype=tf.float32)
+
+ self.logits = tf.keras.layers.Dense(
+ num_classes,
+ activation=None,
+ kernel_initializer=initializer,
+ name='predictions/transform/logits')(
+ cls_output)
+
+ policy = tf.keras.mixed_precision.experimental.global_policy()
+ if policy.name == 'mixed_bfloat16':
+ # b/158514794: bf16 is not stable with post-softmax cross-entropy.
+ policy = tf.float32
+ predictions = tf.keras.layers.Activation(tf.nn.log_softmax,
+ dtype=policy)(self.logits)
+
+ if output == 'logits':
+ output_tensors = self.logits
+ elif output == 'predictions':
+ output_tensors = predictions
+ else:
+ raise ValueError(
+ ('Unknown `output` value "%s". `output` can be either "logits" or '
+ '"predictions"') % output)
+
+ super(Classification, self).__init__(
+ inputs=[cls_output], outputs=output_tensors, **kwargs)
+
+ def get_config(self):
+ return self._config_dict
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ return cls(**config)
diff --git a/models/official/nlp/modeling/networks/classification_test.py b/models/official/nlp/modeling/networks/classification_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..457c135be4bce0c11faef36f099515ba4b0e8c53
--- /dev/null
+++ b/models/official/nlp/modeling/networks/classification_test.py
@@ -0,0 +1,181 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for classification network."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling.networks import classification
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class ClassificationTest(keras_parameterized.TestCase):
+
+ @parameterized.parameters(1, 10)
+ def test_network_creation(self, num_classes):
+ """Validate that the Keras object can be created."""
+ input_width = 512
+ test_object = classification.Classification(
+ input_width=input_width, num_classes=num_classes)
+ # Create a 2-dimensional input (the first dimension is implicit).
+ cls_data = tf.keras.Input(shape=(input_width,), dtype=tf.float32)
+ output = test_object(cls_data)
+
+ # Validate that the outputs are of the expected shape.
+ expected_output_shape = [None, num_classes]
+ self.assertEqual(expected_output_shape, output.shape.as_list())
+
+ @parameterized.parameters(1, 10)
+ def test_network_invocation(self, num_classes):
+ """Validate that the Keras object can be invoked."""
+ input_width = 512
+ test_object = classification.Classification(
+ input_width=input_width, num_classes=num_classes, output='predictions')
+ # Create a 2-dimensional input (the first dimension is implicit).
+ cls_data = tf.keras.Input(shape=(input_width,), dtype=tf.float32)
+ output = test_object(cls_data)
+
+ # Invoke the network as part of a Model.
+ model = tf.keras.Model(cls_data, output)
+ input_data = 10 * np.random.random_sample((3, input_width))
+ _ = model.predict(input_data)
+
+ def test_network_invocation_with_internal_logits(self):
+ """Validate that the logit outputs are correct."""
+ input_width = 512
+ num_classes = 10
+ test_object = classification.Classification(
+ input_width=input_width, num_classes=num_classes, output='predictions')
+
+ # Create a 2-dimensional input (the first dimension is implicit).
+ cls_data = tf.keras.Input(shape=(input_width,), dtype=tf.float32)
+ output = test_object(cls_data)
+ model = tf.keras.Model(cls_data, output)
+ logits_model = tf.keras.Model(test_object.inputs, test_object.logits)
+
+ batch_size = 3
+ input_data = 10 * np.random.random_sample((batch_size, input_width))
+ outputs = model.predict(input_data)
+ logits = logits_model.predict(input_data)
+
+ # Ensure that the tensor shapes are correct.
+ expected_output_shape = (batch_size, num_classes)
+ self.assertEqual(expected_output_shape, outputs.shape)
+ self.assertEqual(expected_output_shape, logits.shape)
+
+ # Ensure that the logits, when softmaxed, create the outputs.
+ input_tensor = tf.keras.Input(expected_output_shape[1:])
+ output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
+ softmax_model = tf.keras.Model(input_tensor, output_tensor)
+
+ calculated_softmax = softmax_model.predict(logits)
+ self.assertAllClose(outputs, calculated_softmax)
+
+ @parameterized.parameters(1, 10)
+ def test_network_invocation_with_internal_and_external_logits(self,
+ num_classes):
+ """Validate that the logit outputs are correct."""
+ input_width = 512
+ test_object = classification.Classification(
+ input_width=input_width, num_classes=num_classes, output='logits')
+
+ # Create a 2-dimensional input (the first dimension is implicit).
+ cls_data = tf.keras.Input(shape=(input_width,), dtype=tf.float32)
+ output = test_object(cls_data)
+ model = tf.keras.Model(cls_data, output)
+ logits_model = tf.keras.Model(test_object.inputs, test_object.logits)
+
+ batch_size = 3
+ input_data = 10 * np.random.random_sample((batch_size, input_width))
+ outputs = model.predict(input_data)
+ logits = logits_model.predict(input_data)
+
+ # Ensure that the tensor shapes are correct.
+ expected_output_shape = (batch_size, num_classes)
+ self.assertEqual(expected_output_shape, outputs.shape)
+ self.assertEqual(expected_output_shape, logits.shape)
+
+ self.assertAllClose(outputs, logits)
+
+ def test_network_invocation_with_logit_output(self):
+ """Validate that the logit outputs are correct."""
+ input_width = 512
+ num_classes = 10
+ test_object = classification.Classification(
+ input_width=input_width, num_classes=num_classes, output='predictions')
+ logit_object = classification.Classification(
+ input_width=input_width, num_classes=num_classes, output='logits')
+ logit_object.set_weights(test_object.get_weights())
+
+ # Create a 2-dimensional input (the first dimension is implicit).
+ cls_data = tf.keras.Input(shape=(input_width,), dtype=tf.float32)
+ output = test_object(cls_data)
+ logit_output = logit_object(cls_data)
+
+ model = tf.keras.Model(cls_data, output)
+ logits_model = tf.keras.Model(cls_data, logit_output)
+
+ batch_size = 3
+ input_data = 10 * np.random.random_sample((batch_size, input_width))
+ outputs = model.predict(input_data)
+ logits = logits_model.predict(input_data)
+
+ # Ensure that the tensor shapes are correct.
+ expected_output_shape = (batch_size, num_classes)
+ self.assertEqual(expected_output_shape, outputs.shape)
+ self.assertEqual(expected_output_shape, logits.shape)
+
+ # Ensure that the logits, when softmaxed, create the outputs.
+ input_tensor = tf.keras.Input(expected_output_shape[1:])
+ output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
+ softmax_model = tf.keras.Model(input_tensor, output_tensor)
+
+ calculated_softmax = softmax_model.predict(logits)
+ self.assertAllClose(outputs, calculated_softmax)
+
+ def test_serialize_deserialize(self):
+ # Create a network object that sets all of its config options.
+ network = classification.Classification(
+ input_width=128,
+ num_classes=10,
+ initializer='zeros',
+ output='predictions')
+
+ # Create another network object from the first object's config.
+ new_network = classification.Classification.from_config(
+ network.get_config())
+
+ # Validate that the config can be forced to JSON.
+ _ = new_network.to_json()
+
+ # If the serialization was successful, the new config should match the old.
+ self.assertAllEqual(network.get_config(), new_network.get_config())
+
+ def test_unknown_output_type_fails(self):
+ with self.assertRaisesRegex(ValueError, 'Unknown `output` value "bad".*'):
+ _ = classification.Classification(
+ input_width=128, num_classes=10, output='bad')
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/nlp/modeling/networks/encoder_scaffold.py b/models/official/nlp/modeling/networks/encoder_scaffold.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec9b2d102db9c3a49de509e9d9011bcf6a758e7f
--- /dev/null
+++ b/models/official/nlp/modeling/networks/encoder_scaffold.py
@@ -0,0 +1,273 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Transformer-based text encoder network."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import inspect
+
+import gin
+import tensorflow as tf
+
+from official.nlp.modeling import layers
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+@gin.configurable
+class EncoderScaffold(tf.keras.Model):
+ """Bi-directional Transformer-based encoder network scaffold.
+
+ This network allows users to flexibly implement an encoder similar to the one
+ described in "BERT: Pre-training of Deep Bidirectional Transformers for
+ Language Understanding" (https://arxiv.org/abs/1810.04805).
+
+ In this network, users can choose to provide a custom embedding subnetwork
+ (which will replace the standard embedding logic) and/or a custom hidden layer
+ class (which will replace the Transformer instantiation in the encoder). For
+ each of these custom injection points, users can pass either a class or a
+ class instance. If a class is passed, that class will be instantiated using
+ the 'embedding_cfg' or 'hidden_cfg' argument, respectively; if an instance
+ is passed, that instance will be invoked. (In the case of hidden_cls, the
+  instance will be invoked 'num_hidden_instances' times.)
+
+ If the hidden_cls is not overridden, a default transformer layer will be
+ instantiated.
+
+ Arguments:
+ pooled_output_dim: The dimension of pooled output.
+    pooler_layer_initializer: The initializer for the pooler ('cls_transform')
+      dense layer.
+ embedding_cls: The class or instance to use to embed the input data. This
+ class or instance defines the inputs to this encoder and outputs
+      (1) an embeddings tensor with shape [batch_size, seq_length, hidden_size]
+      and (2) an attention mask tensor with shape
+      [batch_size, seq_length, seq_length].
+ If embedding_cls is not set, a default embedding network
+ (from the original BERT paper) will be created.
+ embedding_cfg: A dict of kwargs to pass to the embedding_cls, if it needs to
+ be instantiated. If embedding_cls is not set, a config dict must be
+ passed to 'embedding_cfg' with the following values:
+ "vocab_size": The size of the token vocabulary.
+ "type_vocab_size": The size of the type vocabulary.
+ "hidden_size": The hidden size for this encoder.
+ "max_seq_length": The maximum sequence length for this encoder.
+ "seq_length": The sequence length for this encoder.
+ "initializer": The initializer for the embedding portion of this encoder.
+ "dropout_rate": The dropout rate to apply before the encoding layers.
+ embedding_data: A reference to the embedding weights that will be used to
+ train the masked language model, if necessary. This is optional, and only
+      needed if (1) you are overriding embedding_cls and (2) you are doing
+      standard pretraining.
+ num_hidden_instances: The number of times to instantiate and/or invoke the
+ hidden_cls.
+ hidden_cls: The class or instance to encode the input data. If hidden_cls is
+ not set, a KerasBERT transformer layer will be used as the encoder class.
+ hidden_cfg: A dict of kwargs to pass to the hidden_cls, if it needs to be
+ instantiated. If hidden_cls is not set, a config dict must be passed to
+ 'hidden_cfg' with the following values:
+ "num_attention_heads": The number of attention heads. The hidden size
+ must be divisible by num_attention_heads.
+ "intermediate_size": The intermediate size of the transformer.
+ "intermediate_activation": The activation to apply in the transfomer.
+ "dropout_rate": The overall dropout rate for the transformer layers.
+ "attention_dropout_rate": The dropout rate for the attention layers.
+ "kernel_initializer": The initializer for the transformer layers.
+ return_all_layer_outputs: Whether to output sequence embedding outputs of
+ all encoder transformer layers.
+ """
+
+ def __init__(
+ self,
+ pooled_output_dim,
+ pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=0.02),
+ embedding_cls=None,
+ embedding_cfg=None,
+ embedding_data=None,
+ num_hidden_instances=1,
+ hidden_cls=layers.Transformer,
+ hidden_cfg=None,
+ return_all_layer_outputs=False,
+ **kwargs):
+ self._self_setattr_tracking = False
+ self._hidden_cls = hidden_cls
+ self._hidden_cfg = hidden_cfg
+ self._num_hidden_instances = num_hidden_instances
+ self._pooled_output_dim = pooled_output_dim
+ self._pooler_layer_initializer = pooler_layer_initializer
+ self._embedding_cls = embedding_cls
+ self._embedding_cfg = embedding_cfg
+ self._embedding_data = embedding_data
+ self._return_all_layer_outputs = return_all_layer_outputs
+ self._kwargs = kwargs
+
+ if embedding_cls:
+ if inspect.isclass(embedding_cls):
+ self._embedding_network = embedding_cls(
+ **embedding_cfg) if embedding_cfg else embedding_cls()
+ else:
+ self._embedding_network = embedding_cls
+ inputs = self._embedding_network.inputs
+ embeddings, attention_mask = self._embedding_network(inputs)
+ else:
+ self._embedding_network = None
+ word_ids = tf.keras.layers.Input(
+ shape=(embedding_cfg['seq_length'],),
+ dtype=tf.int32,
+ name='input_word_ids')
+ mask = tf.keras.layers.Input(
+ shape=(embedding_cfg['seq_length'],),
+ dtype=tf.int32,
+ name='input_mask')
+ type_ids = tf.keras.layers.Input(
+ shape=(embedding_cfg['seq_length'],),
+ dtype=tf.int32,
+ name='input_type_ids')
+ inputs = [word_ids, mask, type_ids]
+
+ self._embedding_layer = layers.OnDeviceEmbedding(
+ vocab_size=embedding_cfg['vocab_size'],
+ embedding_width=embedding_cfg['hidden_size'],
+ initializer=embedding_cfg['initializer'],
+ name='word_embeddings')
+
+ word_embeddings = self._embedding_layer(word_ids)
+
+ # Always uses dynamic slicing for simplicity.
+ self._position_embedding_layer = layers.PositionEmbedding(
+ initializer=embedding_cfg['initializer'],
+ use_dynamic_slicing=True,
+ max_sequence_length=embedding_cfg['max_seq_length'],
+ name='position_embedding')
+ position_embeddings = self._position_embedding_layer(word_embeddings)
+
+ type_embeddings = (
+ layers.OnDeviceEmbedding(
+ vocab_size=embedding_cfg['type_vocab_size'],
+ embedding_width=embedding_cfg['hidden_size'],
+ initializer=embedding_cfg['initializer'],
+ use_one_hot=True,
+ name='type_embeddings')(type_ids))
+
+ embeddings = tf.keras.layers.Add()(
+ [word_embeddings, position_embeddings, type_embeddings])
+ embeddings = (
+ tf.keras.layers.LayerNormalization(
+ name='embeddings/layer_norm',
+ axis=-1,
+ epsilon=1e-12,
+ dtype=tf.float32)(embeddings))
+ embeddings = (
+ tf.keras.layers.Dropout(
+ rate=embedding_cfg['dropout_rate'])(embeddings))
+
+ attention_mask = layers.SelfAttentionMask()([embeddings, mask])
+
+ data = embeddings
+
+ layer_output_data = []
+ self._hidden_layers = []
+ for _ in range(num_hidden_instances):
+ if inspect.isclass(hidden_cls):
+ layer = hidden_cls(**hidden_cfg) if hidden_cfg else hidden_cls()
+ else:
+ layer = hidden_cls
+ data = layer([data, attention_mask])
+ layer_output_data.append(data)
+ self._hidden_layers.append(layer)
+
+ first_token_tensor = (
+ tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(
+ layer_output_data[-1]))
+ self._pooler_layer = tf.keras.layers.Dense(
+ units=pooled_output_dim,
+ activation='tanh',
+ kernel_initializer=pooler_layer_initializer,
+ name='cls_transform')
+ cls_output = self._pooler_layer(first_token_tensor)
+
+ if return_all_layer_outputs:
+ outputs = [layer_output_data, cls_output]
+ else:
+ outputs = [layer_output_data[-1], cls_output]
+
+ super(EncoderScaffold, self).__init__(
+ inputs=inputs, outputs=outputs, **kwargs)
+
+ def get_config(self):
+ config_dict = {
+ 'num_hidden_instances':
+ self._num_hidden_instances,
+ 'pooled_output_dim':
+ self._pooled_output_dim,
+ 'pooler_layer_initializer':
+ self._pooler_layer_initializer,
+ 'embedding_cls':
+ self._embedding_network,
+ 'embedding_cfg':
+ self._embedding_cfg,
+ 'hidden_cfg':
+ self._hidden_cfg,
+ 'return_all_layer_outputs':
+ self._return_all_layer_outputs,
+ }
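+    # Keras configs must be JSON-serializable, so when hidden_cls is a class it
+    # is stored by its registered name and resolved again in from_config.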
+ if inspect.isclass(self._hidden_cls):
+ config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name(
+ self._hidden_cls)
+ else:
+ config_dict['hidden_cls'] = self._hidden_cls
+
+ config_dict.update(self._kwargs)
+ return config_dict
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ if 'hidden_cls_string' in config:
+ config['hidden_cls'] = tf.keras.utils.get_registered_object(
+ config['hidden_cls_string'], custom_objects=custom_objects)
+ del config['hidden_cls_string']
+ return cls(**config)
+
+ def get_embedding_table(self):
+ if self._embedding_network is None:
+ # In this case, we don't have a custom embedding network and can return
+ # the standard embedding data.
+ return self._embedding_layer.embeddings
+
+ if self._embedding_data is None:
+ raise RuntimeError(('The EncoderScaffold %s does not have a reference '
+ 'to the embedding data. This is required when you '
+ 'pass a custom embedding network to the scaffold. '
+ 'It is also possible that you are trying to get '
+ 'embedding data from an embedding scaffold with a '
+ 'custom embedding network where the scaffold has '
+ 'been serialized and deserialized. Unfortunately, '
+ 'accessing custom embedding references after '
+ 'serialization is not yet supported.') % self.name)
+ else:
+ return self._embedding_data
+
+ @property
+ def hidden_layers(self):
+ """List of hidden layers in the encoder."""
+ return self._hidden_layers
+
+ @property
+ def pooler_layer(self):
+ """The pooler dense layer after the transformer layers."""
+ return self._pooler_layer
diff --git a/models/official/nlp/modeling/networks/encoder_scaffold_test.py b/models/official/nlp/modeling/networks/encoder_scaffold_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..664bccd08e11720918e0060458dc934350d2d594
--- /dev/null
+++ b/models/official/nlp/modeling/networks/encoder_scaffold_test.py
@@ -0,0 +1,646 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for transformer-based text encoder network."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.modeling import activations
+from official.nlp.modeling import layers
+from official.nlp.modeling.networks import encoder_scaffold
+
+
+# Test class that wraps a standard transformer layer. Whenever this layer is
+# called, it appends True to the list passed to its constructor, so the tests
+# can verify that the layer was actually invoked. We register this class as a
+# Keras serializable so we can test serialization below.
+@tf.keras.utils.register_keras_serializable(package="TestOnly")
+class ValidatedTransformerLayer(layers.Transformer):
+
+ def __init__(self, call_list, **kwargs):
+ super(ValidatedTransformerLayer, self).__init__(**kwargs)
+ self.list = call_list
+
+ def call(self, inputs):
+ self.list.append(True)
+ return super(ValidatedTransformerLayer, self).call(inputs)
+
+ def get_config(self):
+ config = super(ValidatedTransformerLayer, self).get_config()
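+    # Reset the recorded call list in the config so it stays serializable; a
+    # layer rebuilt from this config starts with an empty list.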
+ config["call_list"] = []
+ return config
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
+
+ def tearDown(self):
+ super(EncoderScaffoldLayerClassTest, self).tearDown()
+ tf.keras.mixed_precision.experimental.set_policy("float32")
+
+ @parameterized.named_parameters(
+ dict(testcase_name="only_final_output", return_all_layer_outputs=False),
+ dict(testcase_name="all_layer_outputs", return_all_layer_outputs=True))
+ def test_network_creation(self, return_all_layer_outputs):
+ hidden_size = 32
+ sequence_length = 21
+ num_hidden_instances = 3
+ embedding_cfg = {
+ "vocab_size": 100,
+ "type_vocab_size": 16,
+ "hidden_size": hidden_size,
+ "seq_length": sequence_length,
+ "max_seq_length": sequence_length,
+ "initializer": tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ "dropout_rate": 0.1,
+ }
+
+ call_list = []
+ hidden_cfg = {
+ "num_attention_heads":
+ 2,
+ "intermediate_size":
+ 3072,
+ "intermediate_activation":
+ activations.gelu,
+ "dropout_rate":
+ 0.1,
+ "attention_dropout_rate":
+ 0.1,
+ "kernel_initializer":
+ tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ "call_list":
+ call_list
+ }
+ # Create a small EncoderScaffold for testing.
+ test_network = encoder_scaffold.EncoderScaffold(
+ num_hidden_instances=num_hidden_instances,
+ pooled_output_dim=hidden_size,
+ pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=0.02),
+ hidden_cls=ValidatedTransformerLayer,
+ hidden_cfg=hidden_cfg,
+ embedding_cfg=embedding_cfg,
+ return_all_layer_outputs=return_all_layer_outputs)
+ # Create the inputs (note that the first dimension is implicit).
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ output_data, pooled = test_network([word_ids, mask, type_ids])
+
+ if return_all_layer_outputs:
+ self.assertIsInstance(output_data, list)
+ self.assertLen(output_data, num_hidden_instances)
+ data = output_data[-1]
+ else:
+ data = output_data
+ self.assertIsInstance(test_network.hidden_layers, list)
+ self.assertLen(test_network.hidden_layers, num_hidden_instances)
+ self.assertIsInstance(test_network.pooler_layer, tf.keras.layers.Dense)
+
+ expected_data_shape = [None, sequence_length, hidden_size]
+ expected_pooled_shape = [None, hidden_size]
+ self.assertAllEqual(expected_data_shape, data.shape.as_list())
+ self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
+
+ # The default output dtype is float32.
+ self.assertAllEqual(tf.float32, data.dtype)
+ self.assertAllEqual(tf.float32, pooled.dtype)
+
+ # If call_list[0] exists and is True, the passed layer class was
+ # instantiated from the given config properly.
+ self.assertNotEmpty(call_list)
+ self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
+
+ def test_network_creation_with_float16_dtype(self):
+ tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
+ hidden_size = 32
+ sequence_length = 21
+ embedding_cfg = {
+ "vocab_size": 100,
+ "type_vocab_size": 16,
+ "hidden_size": hidden_size,
+ "seq_length": sequence_length,
+ "max_seq_length": sequence_length,
+ "initializer": tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ "dropout_rate": 0.1,
+ }
+ hidden_cfg = {
+ "num_attention_heads":
+ 2,
+ "intermediate_size":
+ 3072,
+ "intermediate_activation":
+ activations.gelu,
+ "dropout_rate":
+ 0.1,
+ "attention_dropout_rate":
+ 0.1,
+ "kernel_initializer":
+ tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ }
+ # Create a small EncoderScaffold for testing.
+ test_network = encoder_scaffold.EncoderScaffold(
+ num_hidden_instances=3,
+ pooled_output_dim=hidden_size,
+ pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=0.02),
+ hidden_cfg=hidden_cfg,
+ embedding_cfg=embedding_cfg)
+ # Create the inputs (note that the first dimension is implicit).
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ data, pooled = test_network([word_ids, mask, type_ids])
+
+ expected_data_shape = [None, sequence_length, hidden_size]
+ expected_pooled_shape = [None, hidden_size]
+ self.assertAllEqual(expected_data_shape, data.shape.as_list())
+ self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
+
+    # Under the mixed_float16 policy, the data output is float32 (it comes from
+    # a layer norm that runs in float32) and the pooled output should be
+    # float16.
+ self.assertAllEqual(tf.float32, data.dtype)
+ self.assertAllEqual(tf.float16, pooled.dtype)
+
+ def test_network_invocation(self):
+ hidden_size = 32
+ sequence_length = 21
+ vocab_size = 57
+ num_types = 7
+ embedding_cfg = {
+ "vocab_size": vocab_size,
+ "type_vocab_size": num_types,
+ "hidden_size": hidden_size,
+ "seq_length": sequence_length,
+ "max_seq_length": sequence_length,
+ "initializer": tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ "dropout_rate": 0.1,
+ }
+ hidden_cfg = {
+ "num_attention_heads":
+ 2,
+ "intermediate_size":
+ 3072,
+ "intermediate_activation":
+ activations.gelu,
+ "dropout_rate":
+ 0.1,
+ "attention_dropout_rate":
+ 0.1,
+ "kernel_initializer":
+ tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ }
+ # Create a small EncoderScaffold for testing.
+ test_network = encoder_scaffold.EncoderScaffold(
+ num_hidden_instances=3,
+ pooled_output_dim=hidden_size,
+ pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=0.02),
+ hidden_cfg=hidden_cfg,
+ embedding_cfg=embedding_cfg)
+
+ # Create the inputs (note that the first dimension is implicit).
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ data, pooled = test_network([word_ids, mask, type_ids])
+
+ # Create a model based off of this network:
+ model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
+
+ # Invoke the model. We can't validate the output data here (the model is too
+ # complex) but this will catch structural runtime errors.
+ batch_size = 3
+ word_id_data = np.random.randint(
+ vocab_size, size=(batch_size, sequence_length))
+ mask_data = np.random.randint(2, size=(batch_size, sequence_length))
+ type_id_data = np.random.randint(
+ num_types, size=(batch_size, sequence_length))
+ _ = model.predict([word_id_data, mask_data, type_id_data])
+
+    # Creates an EncoderScaffold with max_seq_length != seq_length.
+ num_types = 7
+ embedding_cfg = {
+ "vocab_size": vocab_size,
+ "type_vocab_size": num_types,
+ "hidden_size": hidden_size,
+ "seq_length": sequence_length,
+ "max_seq_length": sequence_length * 2,
+ "initializer": tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ "dropout_rate": 0.1,
+ }
+ hidden_cfg = {
+ "num_attention_heads":
+ 2,
+ "intermediate_size":
+ 3072,
+ "intermediate_activation":
+ activations.gelu,
+ "dropout_rate":
+ 0.1,
+ "attention_dropout_rate":
+ 0.1,
+ "kernel_initializer":
+ tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ }
+ # Create a small EncoderScaffold for testing.
+ test_network = encoder_scaffold.EncoderScaffold(
+ num_hidden_instances=3,
+ pooled_output_dim=hidden_size,
+ pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=0.02),
+ hidden_cfg=hidden_cfg,
+ embedding_cfg=embedding_cfg)
+
+    # Run the inputs through the new scaffold so that the model under test
+    # actually exercises the max_seq_length != seq_length configuration.
+    data, pooled = test_network([word_ids, mask, type_ids])
+    model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
+    _ = model.predict([word_id_data, mask_data, type_id_data])
+
+ def test_serialize_deserialize(self):
+ # Create a network object that sets all of its config options.
+ hidden_size = 32
+ sequence_length = 21
+ embedding_cfg = {
+ "vocab_size": 100,
+ "type_vocab_size": 16,
+ "hidden_size": hidden_size,
+ "seq_length": sequence_length,
+ "max_seq_length": sequence_length,
+ "initializer": tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ "dropout_rate": 0.1,
+ }
+ hidden_cfg = {
+ "num_attention_heads":
+ 2,
+ "intermediate_size":
+ 3072,
+ "intermediate_activation":
+ activations.gelu,
+ "dropout_rate":
+ 0.1,
+ "attention_dropout_rate":
+ 0.1,
+ "kernel_initializer":
+ tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ }
+ # Create a small EncoderScaffold for testing.
+ network = encoder_scaffold.EncoderScaffold(
+ num_hidden_instances=3,
+ pooled_output_dim=hidden_size,
+ pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=0.02),
+ hidden_cfg=hidden_cfg,
+ embedding_cfg=embedding_cfg)
+
+ # Create another network object from the first object's config.
+ new_network = encoder_scaffold.EncoderScaffold.from_config(
+ network.get_config())
+
+ # Validate that the config can be forced to JSON.
+ _ = new_network.to_json()
+
+ # If the serialization was successful, the new config should match the old.
+ self.assertAllEqual(network.get_config(), new_network.get_config())
+
+
+@keras_parameterized.run_all_keras_modes
+class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
+
+ def test_network_invocation(self):
+ hidden_size = 32
+ sequence_length = 21
+ vocab_size = 57
+
+ # Build an embedding network to swap in for the default network. This one
+ # will have 2 inputs (mask and word_ids) instead of 3, and won't use
+ # positional embeddings.
+
+ word_ids = tf.keras.layers.Input(
+ shape=(sequence_length,), dtype=tf.int32, name="input_word_ids")
+ mask = tf.keras.layers.Input(
+ shape=(sequence_length,), dtype=tf.int32, name="input_mask")
+ embedding_layer = layers.OnDeviceEmbedding(
+ vocab_size=vocab_size,
+ embedding_width=hidden_size,
+ initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ name="word_embeddings")
+ word_embeddings = embedding_layer(word_ids)
+ attention_mask = layers.SelfAttentionMask()([word_embeddings, mask])
+ network = tf.keras.Model([word_ids, mask],
+ [word_embeddings, attention_mask])
+
+ hidden_cfg = {
+ "num_attention_heads":
+ 2,
+ "intermediate_size":
+ 3072,
+ "intermediate_activation":
+ activations.gelu,
+ "dropout_rate":
+ 0.1,
+ "attention_dropout_rate":
+ 0.1,
+ "kernel_initializer":
+ tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ }
+
+ # Create a small EncoderScaffold for testing.
+ test_network = encoder_scaffold.EncoderScaffold(
+ num_hidden_instances=3,
+ pooled_output_dim=hidden_size,
+ pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=0.02),
+ hidden_cfg=hidden_cfg,
+ embedding_cls=network,
+ embedding_data=embedding_layer.embeddings)
+
+ # Create the inputs (note that the first dimension is implicit).
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ data, pooled = test_network([word_ids, mask])
+
+ # Create a model based off of this network:
+ model = tf.keras.Model([word_ids, mask], [data, pooled])
+
+ # Invoke the model. We can't validate the output data here (the model is too
+ # complex) but this will catch structural runtime errors.
+ batch_size = 3
+ word_id_data = np.random.randint(
+ vocab_size, size=(batch_size, sequence_length))
+ mask_data = np.random.randint(2, size=(batch_size, sequence_length))
+ _ = model.predict([word_id_data, mask_data])
+
+ # Test that we can get the embedding data that we passed to the object. This
+ # is necessary to support standard language model training.
+ self.assertIs(embedding_layer.embeddings,
+ test_network.get_embedding_table())
+
+ def test_serialize_deserialize(self):
+ hidden_size = 32
+ sequence_length = 21
+ vocab_size = 57
+
+ # Build an embedding network to swap in for the default network. This one
+ # will have 2 inputs (mask and word_ids) instead of 3, and won't use
+ # positional embeddings.
+
+ word_ids = tf.keras.layers.Input(
+ shape=(sequence_length,), dtype=tf.int32, name="input_word_ids")
+ mask = tf.keras.layers.Input(
+ shape=(sequence_length,), dtype=tf.int32, name="input_mask")
+ embedding_layer = layers.OnDeviceEmbedding(
+ vocab_size=vocab_size,
+ embedding_width=hidden_size,
+ initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ name="word_embeddings")
+ word_embeddings = embedding_layer(word_ids)
+ attention_mask = layers.SelfAttentionMask()([word_embeddings, mask])
+ network = tf.keras.Model([word_ids, mask],
+ [word_embeddings, attention_mask])
+
+ hidden_cfg = {
+ "num_attention_heads":
+ 2,
+ "intermediate_size":
+ 3072,
+ "intermediate_activation":
+ activations.gelu,
+ "dropout_rate":
+ 0.1,
+ "attention_dropout_rate":
+ 0.1,
+ "kernel_initializer":
+ tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ }
+
+ # Create a small EncoderScaffold for testing.
+ test_network = encoder_scaffold.EncoderScaffold(
+ num_hidden_instances=3,
+ pooled_output_dim=hidden_size,
+ pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=0.02),
+ hidden_cfg=hidden_cfg,
+ embedding_cls=network,
+ embedding_data=embedding_layer.embeddings)
+
+ # Create another network object from the first object's config.
+ new_network = encoder_scaffold.EncoderScaffold.from_config(
+ test_network.get_config())
+
+ # Validate that the config can be forced to JSON.
+ _ = new_network.to_json()
+
+ # If the serialization was successful, the new config should match the old.
+ self.assertAllEqual(test_network.get_config(), new_network.get_config())
+
+ # Create a model based off of the old and new networks:
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+
+ data, pooled = new_network([word_ids, mask])
+ new_model = tf.keras.Model([word_ids, mask], [data, pooled])
+
+ data, pooled = test_network([word_ids, mask])
+ model = tf.keras.Model([word_ids, mask], [data, pooled])
+
+ # Copy the weights between models.
+ new_model.set_weights(model.get_weights())
+
+ # Invoke the models.
+ batch_size = 3
+ word_id_data = np.random.randint(
+ vocab_size, size=(batch_size, sequence_length))
+ mask_data = np.random.randint(2, size=(batch_size, sequence_length))
+ data, cls = model.predict([word_id_data, mask_data])
+ new_data, new_cls = new_model.predict([word_id_data, mask_data])
+
+ # The output should be equal.
+ self.assertAllEqual(data, new_data)
+ self.assertAllEqual(cls, new_cls)
+
+ # We should not be able to get a reference to the embedding data.
+ with self.assertRaisesRegex(RuntimeError, ".*does not have a reference.*"):
+ new_network.get_embedding_table()
+
+
+@keras_parameterized.run_all_keras_modes
+class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
+
+ def test_network_invocation(self):
+ hidden_size = 32
+ sequence_length = 21
+ vocab_size = 57
+ num_types = 7
+
+ embedding_cfg = {
+ "vocab_size": vocab_size,
+ "type_vocab_size": num_types,
+ "hidden_size": hidden_size,
+ "seq_length": sequence_length,
+ "max_seq_length": sequence_length,
+ "initializer": tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ "dropout_rate": 0.1,
+ }
+
+ call_list = []
+ hidden_cfg = {
+ "num_attention_heads":
+ 2,
+ "intermediate_size":
+ 3072,
+ "intermediate_activation":
+ activations.gelu,
+ "dropout_rate":
+ 0.1,
+ "attention_dropout_rate":
+ 0.1,
+ "kernel_initializer":
+ tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ "call_list":
+ call_list
+ }
+ # Create a small EncoderScaffold for testing. This time, we pass an already-
+ # instantiated layer object.
+
+ xformer = ValidatedTransformerLayer(**hidden_cfg)
+
+ test_network = encoder_scaffold.EncoderScaffold(
+ num_hidden_instances=3,
+ pooled_output_dim=hidden_size,
+ pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=0.02),
+ hidden_cls=xformer,
+ embedding_cfg=embedding_cfg)
+
+ # Create the inputs (note that the first dimension is implicit).
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ data, pooled = test_network([word_ids, mask, type_ids])
+
+ # Create a model based off of this network:
+ model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
+
+ # Invoke the model. We can't validate the output data here (the model is too
+ # complex) but this will catch structural runtime errors.
+ batch_size = 3
+ word_id_data = np.random.randint(
+ vocab_size, size=(batch_size, sequence_length))
+ mask_data = np.random.randint(2, size=(batch_size, sequence_length))
+ type_id_data = np.random.randint(
+ num_types, size=(batch_size, sequence_length))
+ _ = model.predict([word_id_data, mask_data, type_id_data])
+
+ # If call_list[0] exists and is True, the passed layer class was
+ # called as part of the graph creation.
+ self.assertNotEmpty(call_list)
+ self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
+
+ def test_serialize_deserialize(self):
+ hidden_size = 32
+ sequence_length = 21
+ vocab_size = 57
+ num_types = 7
+
+ embedding_cfg = {
+ "vocab_size": vocab_size,
+ "type_vocab_size": num_types,
+ "hidden_size": hidden_size,
+ "seq_length": sequence_length,
+ "max_seq_length": sequence_length,
+ "initializer": tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ "dropout_rate": 0.1,
+ }
+
+ call_list = []
+ hidden_cfg = {
+ "num_attention_heads":
+ 2,
+ "intermediate_size":
+ 3072,
+ "intermediate_activation":
+ activations.gelu,
+ "dropout_rate":
+ 0.1,
+ "attention_dropout_rate":
+ 0.1,
+ "kernel_initializer":
+ tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ "call_list":
+ call_list
+ }
+ # Create a small EncoderScaffold for testing. This time, we pass an already-
+ # instantiated layer object.
+
+ xformer = ValidatedTransformerLayer(**hidden_cfg)
+
+ test_network = encoder_scaffold.EncoderScaffold(
+ num_hidden_instances=3,
+ pooled_output_dim=hidden_size,
+ pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=0.02),
+ hidden_cls=xformer,
+ embedding_cfg=embedding_cfg)
+
+ # Create another network object from the first object's config.
+ new_network = encoder_scaffold.EncoderScaffold.from_config(
+ test_network.get_config())
+
+ # Validate that the config can be forced to JSON.
+ _ = new_network.to_json()
+
+ # If the serialization was successful, the new config should match the old.
+ self.assertAllEqual(test_network.get_config(), new_network.get_config())
+
+ # Create a model based off of the old and new networks:
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+
+ data, pooled = new_network([word_ids, mask, type_ids])
+ new_model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
+
+ data, pooled = test_network([word_ids, mask, type_ids])
+ model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
+
+ # Copy the weights between models.
+ new_model.set_weights(model.get_weights())
+
+ # Invoke the models.
+ batch_size = 3
+ word_id_data = np.random.randint(
+ vocab_size, size=(batch_size, sequence_length))
+ mask_data = np.random.randint(2, size=(batch_size, sequence_length))
+ type_id_data = np.random.randint(
+ num_types, size=(batch_size, sequence_length))
+ data, cls = model.predict([word_id_data, mask_data, type_id_data])
+ new_data, new_cls = new_model.predict(
+ [word_id_data, mask_data, type_id_data])
+
+ # The output should be equal.
+ self.assertAllEqual(data, new_data)
+ self.assertAllEqual(cls, new_cls)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/modeling/networks/span_labeling.py b/models/official/nlp/modeling/networks/span_labeling.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d704c33b6d62ae059d01b81bca146ca1c5adca4
--- /dev/null
+++ b/models/official/nlp/modeling/networks/span_labeling.py
@@ -0,0 +1,92 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Span labeling network."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+class SpanLabeling(tf.keras.Model):
+ """Span labeling network head for BERT modeling.
+
+ This network implements a simple single-span labeler based on a dense layer.
+
+ Arguments:
+ input_width: The innermost dimension of the input tensor to this network.
+ activation: The activation, if any, for the dense layer in this network.
+    initializer: The initializer for the dense layer in this network. Defaults to
+ a Glorot uniform initializer.
+ output: The output style for this network. Can be either 'logits' or
+ 'predictions'.
+ """
+
+ def __init__(self,
+ input_width,
+ activation=None,
+ initializer='glorot_uniform',
+ output='logits',
+ **kwargs):
+ self._self_setattr_tracking = False
+ self._config = {
+ 'input_width': input_width,
+ 'activation': activation,
+ 'initializer': initializer,
+ 'output': output,
+ }
+
+ sequence_data = tf.keras.layers.Input(
+ shape=(None, input_width), name='sequence_data', dtype=tf.float32)
+
+ intermediate_logits = tf.keras.layers.Dense(
+ 2, # This layer predicts start location and end location.
+ activation=activation,
+ kernel_initializer=initializer,
+ name='predictions/transform/logits')(
+ sequence_data)
+ self.start_logits, self.end_logits = (
+ tf.keras.layers.Lambda(self._split_output_tensor)(intermediate_logits))
+
+ start_predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(
+ self.start_logits)
+ end_predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(
+ self.end_logits)
+
+ if output == 'logits':
+ output_tensors = [self.start_logits, self.end_logits]
+ elif output == 'predictions':
+ output_tensors = [start_predictions, end_predictions]
+ else:
+ raise ValueError(
+ ('Unknown `output` value "%s". `output` can be either "logits" or '
+ '"predictions"') % output)
+
+ super(SpanLabeling, self).__init__(
+ inputs=[sequence_data], outputs=output_tensors, **kwargs)
+
+ def _split_output_tensor(self, tensor):
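+    # The dense layer above produces a [batch, seq_length, 2] tensor; transpose
+    # to [2, batch, seq_length] and unstack so the start and end logits come
+    # back as two separate [batch, seq_length] tensors.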
+ transposed_tensor = tf.transpose(tensor, [2, 0, 1])
+ return tf.unstack(transposed_tensor)
+
+ def get_config(self):
+ return self._config
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ return cls(**config)
diff --git a/models/official/nlp/modeling/networks/span_labeling_test.py b/models/official/nlp/modeling/networks/span_labeling_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8533a77b7830c1abe921fa93cd4e0cd7e8229475
--- /dev/null
+++ b/models/official/nlp/modeling/networks/span_labeling_test.py
@@ -0,0 +1,174 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for span_labeling network."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling.networks import span_labeling
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class SpanLabelingTest(keras_parameterized.TestCase):
+
+ def test_network_creation(self):
+ """Validate that the Keras object can be created."""
+ sequence_length = 15
+ input_width = 512
+ test_network = span_labeling.SpanLabeling(
+ input_width=input_width, output='predictions')
+ # Create a 3-dimensional input (the first dimension is implicit).
+ sequence_data = tf.keras.Input(
+ shape=(sequence_length, input_width), dtype=tf.float32)
+ start_outputs, end_outputs = test_network(sequence_data)
+
+ # Validate that the outputs are of the expected shape.
+ expected_output_shape = [None, sequence_length]
+ self.assertEqual(expected_output_shape, start_outputs.shape.as_list())
+ self.assertEqual(expected_output_shape, end_outputs.shape.as_list())
+
+ def test_network_invocation(self):
+ """Validate that the Keras object can be invoked."""
+ sequence_length = 15
+ input_width = 512
+ test_network = span_labeling.SpanLabeling(input_width=input_width)
+
+ # Create a 3-dimensional input (the first dimension is implicit).
+ sequence_data = tf.keras.Input(
+ shape=(sequence_length, input_width), dtype=tf.float32)
+ outputs = test_network(sequence_data)
+ model = tf.keras.Model(sequence_data, outputs)
+
+ # Invoke the network as part of a Model.
+ batch_size = 3
+ input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, input_width))
+ start_outputs, end_outputs = model.predict(input_data)
+
+ # Validate that the outputs are of the expected shape.
+ expected_output_shape = (batch_size, sequence_length)
+ self.assertEqual(expected_output_shape, start_outputs.shape)
+ self.assertEqual(expected_output_shape, end_outputs.shape)
+
+ def test_network_invocation_with_internal_logit_output(self):
+ """Validate that the logit outputs are correct."""
+ sequence_length = 15
+ input_width = 512
+ test_network = span_labeling.SpanLabeling(
+ input_width=input_width, output='predictions')
+ # Create a 3-dimensional input (the first dimension is implicit).
+ sequence_data = tf.keras.Input(
+ shape=(sequence_length, input_width), dtype=tf.float32)
+ output = test_network(sequence_data)
+ model = tf.keras.Model(sequence_data, output)
+ logit_model = tf.keras.Model(
+ test_network.inputs,
+ [test_network.start_logits, test_network.end_logits])
+
+ batch_size = 3
+ input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, input_width))
+ start_outputs, end_outputs = model.predict(input_data)
+ start_logits, end_logits = logit_model.predict(input_data)
+
+ # Ensure that the tensor shapes are correct.
+ expected_output_shape = (batch_size, sequence_length)
+ self.assertEqual(expected_output_shape, start_outputs.shape)
+ self.assertEqual(expected_output_shape, end_outputs.shape)
+ self.assertEqual(expected_output_shape, start_logits.shape)
+ self.assertEqual(expected_output_shape, end_logits.shape)
+
+ # Ensure that the logits, when softmaxed, create the outputs.
+ input_tensor = tf.keras.Input(expected_output_shape[1:])
+ output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
+ softmax_model = tf.keras.Model(input_tensor, output_tensor)
+
+ start_softmax = softmax_model.predict(start_logits)
+ self.assertAllClose(start_outputs, start_softmax)
+ end_softmax = softmax_model.predict(end_logits)
+ self.assertAllClose(end_outputs, end_softmax)
+
+ def test_network_invocation_with_external_logit_output(self):
+ """Validate that the logit outputs are correct."""
+ sequence_length = 15
+ input_width = 512
+ test_network = span_labeling.SpanLabeling(
+ input_width=input_width, output='predictions')
+ logit_network = span_labeling.SpanLabeling(
+ input_width=input_width, output='logits')
+ logit_network.set_weights(test_network.get_weights())
+
+ # Create a 3-dimensional input (the first dimension is implicit).
+ sequence_data = tf.keras.Input(
+ shape=(sequence_length, input_width), dtype=tf.float32)
+ output = test_network(sequence_data)
+ logit_output = logit_network(sequence_data)
+ model = tf.keras.Model(sequence_data, output)
+ logit_model = tf.keras.Model(sequence_data, logit_output)
+
+ batch_size = 3
+ input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, input_width))
+ start_outputs, end_outputs = model.predict(input_data)
+ start_logits, end_logits = logit_model.predict(input_data)
+
+ # Ensure that the tensor shapes are correct.
+ expected_output_shape = (batch_size, sequence_length)
+ self.assertEqual(expected_output_shape, start_outputs.shape)
+ self.assertEqual(expected_output_shape, end_outputs.shape)
+ self.assertEqual(expected_output_shape, start_logits.shape)
+ self.assertEqual(expected_output_shape, end_logits.shape)
+
+ # Ensure that the logits, when softmaxed, create the outputs.
+ input_tensor = tf.keras.Input(expected_output_shape[1:])
+ output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
+ softmax_model = tf.keras.Model(input_tensor, output_tensor)
+
+ start_softmax = softmax_model.predict(start_logits)
+ self.assertAllClose(start_outputs, start_softmax)
+ end_softmax = softmax_model.predict(end_logits)
+ self.assertAllClose(end_outputs, end_softmax)
+
+ def test_serialize_deserialize(self):
+ # Create a network object that sets all of its config options.
+ network = span_labeling.SpanLabeling(
+ input_width=128,
+ activation='relu',
+ initializer='zeros',
+ output='predictions')
+
+ # Create another network object from the first object's config.
+ new_network = span_labeling.SpanLabeling.from_config(network.get_config())
+
+ # Validate that the config can be forced to JSON.
+ _ = new_network.to_json()
+
+ # If the serialization was successful, the new config should match the old.
+ self.assertAllEqual(network.get_config(), new_network.get_config())
+
+ def test_unknown_output_type_fails(self):
+ with self.assertRaisesRegex(ValueError, 'Unknown `output` value "bad".*'):
+ _ = span_labeling.SpanLabeling(input_width=10, output='bad')
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/nlp/modeling/networks/token_classification.py b/models/official/nlp/modeling/networks/token_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff6163481e6f267a5aefac352ff38447a275a13a
--- /dev/null
+++ b/models/official/nlp/modeling/networks/token_classification.py
@@ -0,0 +1,83 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Classification network."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+class TokenClassification(tf.keras.Model):
+ """TokenClassification network head for BERT modeling.
+
+ This network implements a simple token classifier head based on a dense layer.
+
+ Arguments:
+ input_width: The innermost dimension of the input tensor to this network.
+ num_classes: The number of classes that this network should classify to.
+ activation: The activation, if any, for the dense layer in this network.
+    initializer: The initializer for the dense layer in this network. Defaults to
+ a Glorot uniform initializer.
+ output: The output style for this network. Can be either 'logits' or
+ 'predictions'.
+ """
+
+ def __init__(self,
+ input_width,
+ num_classes,
+ initializer='glorot_uniform',
+ output='logits',
+ **kwargs):
+ self._self_setattr_tracking = False
+ self._config_dict = {
+ 'input_width': input_width,
+ 'num_classes': num_classes,
+ 'initializer': initializer,
+ 'output': output,
+ }
+
+ sequence_data = tf.keras.layers.Input(
+ shape=(None, input_width), name='sequence_data', dtype=tf.float32)
+
+ self.logits = tf.keras.layers.Dense(
+ num_classes,
+ activation=None,
+ kernel_initializer=initializer,
+ name='predictions/transform/logits')(
+ sequence_data)
+ predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(self.logits)
+
+ if output == 'logits':
+ output_tensors = self.logits
+ elif output == 'predictions':
+ output_tensors = predictions
+ else:
+ raise ValueError(
+ ('Unknown `output` value "%s". `output` can be either "logits" or '
+ '"predictions"') % output)
+
+ super(TokenClassification, self).__init__(
+ inputs=[sequence_data], outputs=output_tensors, **kwargs)
+
+ def get_config(self):
+ return self._config_dict
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ return cls(**config)
diff --git a/models/official/nlp/modeling/networks/token_classification_test.py b/models/official/nlp/modeling/networks/token_classification_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb695c7845b125a5f34d82ff38218ca2dccdfe54
--- /dev/null
+++ b/models/official/nlp/modeling/networks/token_classification_test.py
@@ -0,0 +1,192 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for token classification network."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling.networks import token_classification
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class TokenClassificationTest(keras_parameterized.TestCase):
+
+ def test_network_creation(self):
+ """Validate that the Keras object can be created."""
+ sequence_length = 5
+ input_width = 512
+ num_classes = 10
+ test_object = token_classification.TokenClassification(
+ input_width=input_width, num_classes=num_classes)
+ # Create a 3-dimensional input (the first dimension is implicit).
+ sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
+ dtype=tf.float32)
+ output = test_object(sequence_data)
+
+ # Validate that the outputs are of the expected shape.
+ expected_output_shape = [None, sequence_length, num_classes]
+ self.assertEqual(expected_output_shape, output.shape.as_list())
+
+ def test_network_invocation(self):
+ """Validate that the Keras object can be invoked."""
+ sequence_length = 5
+ input_width = 512
+ num_classes = 10
+ test_object = token_classification.TokenClassification(
+ input_width=input_width, num_classes=num_classes, output='predictions')
+ # Create a 3-dimensional input (the first dimension is implicit).
+ sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
+ dtype=tf.float32)
+ output = test_object(sequence_data)
+
+ # Invoke the network as part of a Model.
+ model = tf.keras.Model(sequence_data, output)
+ input_data = 10 * np.random.random_sample((3, sequence_length, input_width))
+ _ = model.predict(input_data)
+
+ def test_network_invocation_with_internal_logits(self):
+ """Validate that the logit outputs are correct."""
+ sequence_length = 5
+ input_width = 512
+ num_classes = 10
+ test_object = token_classification.TokenClassification(
+ input_width=input_width, num_classes=num_classes, output='predictions')
+
+ # Create a 3-dimensional input (the first dimension is implicit).
+ sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
+ dtype=tf.float32)
+ output = test_object(sequence_data)
+ model = tf.keras.Model(sequence_data, output)
+ logits_model = tf.keras.Model(test_object.inputs, test_object.logits)
+
+ batch_size = 3
+ input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, input_width))
+ outputs = model.predict(input_data)
+ logits = logits_model.predict(input_data)
+
+ # Ensure that the tensor shapes are correct.
+ expected_output_shape = (batch_size, sequence_length, num_classes)
+ self.assertEqual(expected_output_shape, outputs.shape)
+ self.assertEqual(expected_output_shape, logits.shape)
+
+ # Ensure that the logits, when softmaxed, create the outputs.
+ input_tensor = tf.keras.Input(expected_output_shape[1:])
+ output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
+ softmax_model = tf.keras.Model(input_tensor, output_tensor)
+
+ calculated_softmax = softmax_model.predict(logits)
+ self.assertAllClose(outputs, calculated_softmax)
+
+ def test_network_invocation_with_internal_and_external_logits(self):
+ """Validate that the logit outputs are correct."""
+ sequence_length = 5
+ input_width = 512
+ num_classes = 10
+ test_object = token_classification.TokenClassification(
+ input_width=input_width, num_classes=num_classes, output='logits')
+
+ # Create a 3-dimensional input (the first dimension is implicit).
+ sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
+ dtype=tf.float32)
+ output = test_object(sequence_data)
+ model = tf.keras.Model(sequence_data, output)
+ logits_model = tf.keras.Model(test_object.inputs, test_object.logits)
+
+ batch_size = 3
+ input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, input_width))
+ outputs = model.predict(input_data)
+ logits = logits_model.predict(input_data)
+
+ # Ensure that the tensor shapes are correct.
+ expected_output_shape = (batch_size, sequence_length, num_classes)
+ self.assertEqual(expected_output_shape, outputs.shape)
+ self.assertEqual(expected_output_shape, logits.shape)
+
+ self.assertAllClose(outputs, logits)
+
+ def test_network_invocation_with_logit_output(self):
+ """Validate that the logit outputs are correct."""
+ sequence_length = 5
+ input_width = 512
+ num_classes = 10
+ test_object = token_classification.TokenClassification(
+ input_width=input_width, num_classes=num_classes, output='predictions')
+ logit_object = token_classification.TokenClassification(
+ input_width=input_width, num_classes=num_classes, output='logits')
+ logit_object.set_weights(test_object.get_weights())
+
+ # Create a 3-dimensional input (the first dimension is implicit).
+ sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
+ dtype=tf.float32)
+ output = test_object(sequence_data)
+ logit_output = logit_object(sequence_data)
+
+ model = tf.keras.Model(sequence_data, output)
+ logits_model = tf.keras.Model(sequence_data, logit_output)
+
+ batch_size = 3
+ input_data = 10 * np.random.random_sample(
+ (batch_size, sequence_length, input_width))
+ outputs = model.predict(input_data)
+ logits = logits_model.predict(input_data)
+
+ # Ensure that the tensor shapes are correct.
+ expected_output_shape = (batch_size, sequence_length, num_classes)
+ self.assertEqual(expected_output_shape, outputs.shape)
+ self.assertEqual(expected_output_shape, logits.shape)
+
+    # Ensure that the logits, when log-softmaxed, create the outputs.
+ input_tensor = tf.keras.Input(expected_output_shape[1:])
+ output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
+ softmax_model = tf.keras.Model(input_tensor, output_tensor)
+
+ calculated_softmax = softmax_model.predict(logits)
+ self.assertAllClose(outputs, calculated_softmax)
+
+ def test_serialize_deserialize(self):
+ # Create a network object that sets all of its config options.
+ network = token_classification.TokenClassification(
+ input_width=128,
+ num_classes=10,
+ initializer='zeros',
+ output='predictions')
+
+ # Create another network object from the first object's config.
+ new_network = token_classification.TokenClassification.from_config(
+ network.get_config())
+
+ # Validate that the config can be forced to JSON.
+ _ = new_network.to_json()
+
+ # If the serialization was successful, the new config should match the old.
+ self.assertAllEqual(network.get_config(), new_network.get_config())
+
+ def test_unknown_output_type_fails(self):
+ with self.assertRaisesRegex(ValueError, 'Unknown `output` value "bad".*'):
+ _ = token_classification.TokenClassification(
+ input_width=128, num_classes=10, output='bad')
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/nlp/modeling/networks/transformer_encoder.py b/models/official/nlp/modeling/networks/transformer_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c6054ddcc242d5184c6e0e4dcd5102e6955b915
--- /dev/null
+++ b/models/official/nlp/modeling/networks/transformer_encoder.py
@@ -0,0 +1,238 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Transformer-based text encoder network."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+
+from official.modeling import activations
+from official.nlp.modeling import layers
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+class TransformerEncoder(tf.keras.Model):
+ """Bi-directional Transformer-based encoder network.
+
+ This network implements a bi-directional Transformer-based encoder as
+ described in "BERT: Pre-training of Deep Bidirectional Transformers for
+ Language Understanding" (https://arxiv.org/abs/1810.04805). It includes the
+ embedding lookups and transformer layers, but not the masked language model
+ or classification task networks.
+
+ The default values for this object are taken from the BERT-Base implementation
+ in "BERT: Pre-training of Deep Bidirectional Transformers for Language
+ Understanding".
+
+ Arguments:
+ vocab_size: The size of the token vocabulary.
+ hidden_size: The size of the transformer hidden layers.
+ num_layers: The number of transformer layers.
+ num_attention_heads: The number of attention heads for each transformer. The
+ hidden size must be divisible by the number of attention heads.
+ sequence_length: The sequence length that this encoder expects. If None, the
+ sequence length is dynamic; if an integer, the encoder will require
+ sequences padded to this length.
+ max_sequence_length: The maximum sequence length that this encoder can
+ consume. If None, max_sequence_length uses the value from sequence length.
+ This determines the variable shape for positional embeddings.
+ type_vocab_size: The number of types that the 'type_ids' input can take.
+ intermediate_size: The intermediate size for the transformer layers.
+ activation: The activation to use for the transformer layers.
+ dropout_rate: The dropout rate to use for the transformer layers.
+ attention_dropout_rate: The dropout rate to use for the attention layers
+ within the transformer layers.
+    initializer: The initializer to use for all weights in this encoder.
+ return_all_encoder_outputs: Whether to output sequence embedding outputs of
+ all encoder transformer layers.
+ output_range: The sequence output range, [0, output_range), by slicing the
+ target sequence of the last transformer layer. `None` means the entire
+      target sequence will attend to the source sequence, which yields the full
+ output.
+ embedding_width: The width of the word embeddings. If the embedding width
+ is not equal to hidden size, embedding parameters will be factorized into
+ two matrices in the shape of ['vocab_size', 'embedding_width'] and
+ ['embedding_width', 'hidden_size'] ('embedding_width' is usually much
+ smaller than 'hidden_size').
+ embedding_layer: The word embedding layer. `None` means we will create a new
+ embedding layer. Otherwise, we will reuse the given embedding layer. This
+      parameter was originally added for the ELECTRA model, which needs to tie
+      the generator embeddings with the discriminator embeddings.
+ """
+
+ def __init__(self,
+ vocab_size,
+ hidden_size=768,
+ num_layers=12,
+ num_attention_heads=12,
+ sequence_length=512,
+ max_sequence_length=None,
+ type_vocab_size=16,
+ intermediate_size=3072,
+ activation=activations.gelu,
+ dropout_rate=0.1,
+ attention_dropout_rate=0.1,
+ initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
+ return_all_encoder_outputs=False,
+ output_range=None,
+ embedding_width=None,
+ embedding_layer=None,
+ **kwargs):
+ activation = tf.keras.activations.get(activation)
+ initializer = tf.keras.initializers.get(initializer)
+
+ if not max_sequence_length:
+ max_sequence_length = sequence_length
+ self._self_setattr_tracking = False
+ self._config_dict = {
+ 'vocab_size': vocab_size,
+ 'hidden_size': hidden_size,
+ 'num_layers': num_layers,
+ 'num_attention_heads': num_attention_heads,
+ 'sequence_length': sequence_length,
+ 'max_sequence_length': max_sequence_length,
+ 'type_vocab_size': type_vocab_size,
+ 'intermediate_size': intermediate_size,
+ 'activation': tf.keras.activations.serialize(activation),
+ 'dropout_rate': dropout_rate,
+ 'attention_dropout_rate': attention_dropout_rate,
+ 'initializer': tf.keras.initializers.serialize(initializer),
+ 'return_all_encoder_outputs': return_all_encoder_outputs,
+ 'output_range': output_range,
+ 'embedding_width': embedding_width,
+ }
+
+ word_ids = tf.keras.layers.Input(
+ shape=(sequence_length,), dtype=tf.int32, name='input_word_ids')
+ mask = tf.keras.layers.Input(
+ shape=(sequence_length,), dtype=tf.int32, name='input_mask')
+ type_ids = tf.keras.layers.Input(
+ shape=(sequence_length,), dtype=tf.int32, name='input_type_ids')
+
+ if embedding_width is None:
+ embedding_width = hidden_size
+ if embedding_layer is None:
+ self._embedding_layer = layers.OnDeviceEmbedding(
+ vocab_size=vocab_size,
+ embedding_width=embedding_width,
+ initializer=initializer,
+ name='word_embeddings')
+ else:
+ self._embedding_layer = embedding_layer
+ word_embeddings = self._embedding_layer(word_ids)
+
+ # Always uses dynamic slicing for simplicity.
+ self._position_embedding_layer = layers.PositionEmbedding(
+ initializer=initializer,
+ use_dynamic_slicing=True,
+ max_sequence_length=max_sequence_length,
+ name='position_embedding')
+ position_embeddings = self._position_embedding_layer(word_embeddings)
+ self._type_embedding_layer = layers.OnDeviceEmbedding(
+ vocab_size=type_vocab_size,
+ embedding_width=embedding_width,
+ initializer=initializer,
+ use_one_hot=True,
+ name='type_embeddings')
+ type_embeddings = self._type_embedding_layer(type_ids)
+
+ embeddings = tf.keras.layers.Add()(
+ [word_embeddings, position_embeddings, type_embeddings])
+
+ embeddings = (
+ tf.keras.layers.LayerNormalization(
+ name='embeddings/layer_norm',
+ axis=-1,
+ epsilon=1e-12,
+ dtype=tf.float32)(embeddings))
+ embeddings = (
+ tf.keras.layers.Dropout(rate=dropout_rate)(embeddings))
+
+ # We project the 'embedding' output to 'hidden_size' if it is not already
+ # 'hidden_size'.
+ if embedding_width != hidden_size:
+ self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
+ '...x,xy->...y',
+ output_shape=hidden_size,
+ bias_axes='y',
+ kernel_initializer=initializer,
+ name='embedding_projection')
+ embeddings = self._embedding_projection(embeddings)
+
+ self._transformer_layers = []
+ data = embeddings
+ attention_mask = layers.SelfAttentionMask()([data, mask])
+ encoder_outputs = []
+ for i in range(num_layers):
+ if i == num_layers - 1 and output_range is not None:
+ transformer_output_range = output_range
+ else:
+ transformer_output_range = None
+ layer = layers.Transformer(
+ num_attention_heads=num_attention_heads,
+ intermediate_size=intermediate_size,
+ intermediate_activation=activation,
+ dropout_rate=dropout_rate,
+ attention_dropout_rate=attention_dropout_rate,
+ output_range=transformer_output_range,
+ kernel_initializer=initializer,
+ name='transformer/layer_%d' % i)
+ self._transformer_layers.append(layer)
+ data = layer([data, attention_mask])
+ encoder_outputs.append(data)
+
+ first_token_tensor = (
+ tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(
+ encoder_outputs[-1]))
+ self._pooler_layer = tf.keras.layers.Dense(
+ units=hidden_size,
+ activation='tanh',
+ kernel_initializer=initializer,
+ name='pooler_transform')
+ cls_output = self._pooler_layer(first_token_tensor)
+
+ if return_all_encoder_outputs:
+ outputs = [encoder_outputs, cls_output]
+ else:
+ outputs = [encoder_outputs[-1], cls_output]
+
+ super(TransformerEncoder, self).__init__(
+ inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)
+
+ def get_embedding_table(self):
+ return self._embedding_layer.embeddings
+
+ def get_embedding_layer(self):
+ return self._embedding_layer
+
+ def get_config(self):
+ return self._config_dict
+
+ @property
+ def transformer_layers(self):
+ """List of Transformer layers in the encoder."""
+ return self._transformer_layers
+
+ @property
+ def pooler_layer(self):
+ """The pooler dense layer after the transformer layers."""
+ return self._pooler_layer
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ return cls(**config)
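+
+
+# Example usage (an illustrative sketch, not part of the library API; the
+# hyperparameter values below are arbitrary). The encoder takes
+# [word_ids, mask, type_ids] and returns the sequence output together with the
+# pooled [CLS] output, mirroring transformer_encoder_test.py:
+#
+#   encoder = TransformerEncoder(
+#       vocab_size=100, hidden_size=32, sequence_length=21,
+#       num_attention_heads=2, num_layers=3)
+#   word_ids = tf.keras.Input(shape=(21,), dtype=tf.int32)
+#   mask = tf.keras.Input(shape=(21,), dtype=tf.int32)
+#   type_ids = tf.keras.Input(shape=(21,), dtype=tf.int32)
+#   sequence_output, pooled_output = encoder([word_ids, mask, type_ids])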
diff --git a/models/official/nlp/modeling/networks/transformer_encoder_test.py b/models/official/nlp/modeling/networks/transformer_encoder_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9fbc3aaa25e39908618626538902643edaabe72
--- /dev/null
+++ b/models/official/nlp/modeling/networks/transformer_encoder_test.py
@@ -0,0 +1,231 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for transformer-based text encoder network."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling.networks import transformer_encoder
+
+
+# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
+# guarantees forward compatibility of this code for the V2 switchover.
+@keras_parameterized.run_all_keras_modes
+class TransformerEncoderTest(keras_parameterized.TestCase):
+
+ def tearDown(self):
+ super(TransformerEncoderTest, self).tearDown()
+ tf.keras.mixed_precision.experimental.set_policy("float32")
+
+ def test_network_creation(self):
+ hidden_size = 32
+ sequence_length = 21
+ # Create a small TransformerEncoder for testing.
+ test_network = transformer_encoder.TransformerEncoder(
+ vocab_size=100,
+ hidden_size=hidden_size,
+ sequence_length=sequence_length,
+ num_attention_heads=2,
+ num_layers=3)
+ # Create the inputs (note that the first dimension is implicit).
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ data, pooled = test_network([word_ids, mask, type_ids])
+
+ self.assertIsInstance(test_network.transformer_layers, list)
+ self.assertLen(test_network.transformer_layers, 3)
+ self.assertIsInstance(test_network.pooler_layer, tf.keras.layers.Dense)
+
+ expected_data_shape = [None, sequence_length, hidden_size]
+ expected_pooled_shape = [None, hidden_size]
+ self.assertAllEqual(expected_data_shape, data.shape.as_list())
+ self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
+
+ # The default output dtype is float32.
+ self.assertAllEqual(tf.float32, data.dtype)
+ self.assertAllEqual(tf.float32, pooled.dtype)
+
+ def test_all_encoder_outputs_network_creation(self):
+ hidden_size = 32
+ sequence_length = 21
+ # Create a small TransformerEncoder for testing.
+ test_network = transformer_encoder.TransformerEncoder(
+ vocab_size=100,
+ hidden_size=hidden_size,
+ sequence_length=sequence_length,
+ num_attention_heads=2,
+ num_layers=3,
+ return_all_encoder_outputs=True)
+ # Create the inputs (note that the first dimension is implicit).
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ all_encoder_outputs, pooled = test_network([word_ids, mask, type_ids])
+
+ expected_data_shape = [None, sequence_length, hidden_size]
+ expected_pooled_shape = [None, hidden_size]
+ self.assertLen(all_encoder_outputs, 3)
+ for data in all_encoder_outputs:
+ self.assertAllEqual(expected_data_shape, data.shape.as_list())
+ self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
+
+ # The default output dtype is float32.
+ self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype)
+ self.assertAllEqual(tf.float32, pooled.dtype)
+
+ def test_network_creation_with_float16_dtype(self):
+ hidden_size = 32
+ sequence_length = 21
+ tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
+ # Create a small TransformerEncoder for testing.
+ test_network = transformer_encoder.TransformerEncoder(
+ vocab_size=100,
+ hidden_size=hidden_size,
+ sequence_length=sequence_length,
+ num_attention_heads=2,
+ num_layers=3)
+ # Create the inputs (note that the first dimension is implicit).
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ data, pooled = test_network([word_ids, mask, type_ids])
+
+ expected_data_shape = [None, sequence_length, hidden_size]
+ expected_pooled_shape = [None, hidden_size]
+ self.assertAllEqual(expected_data_shape, data.shape.as_list())
+ self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
+
+ # If float_dtype is set to float16, the data output is float32 (from a layer
+ # norm) and pool output should be float16.
+ self.assertAllEqual(tf.float32, data.dtype)
+ self.assertAllEqual(tf.float16, pooled.dtype)
+
+ @parameterized.named_parameters(
+ ("all_sequence", None, 21),
+ ("output_range", 1, 1),
+ )
+ def test_network_invocation(self, output_range, out_seq_len):
+ hidden_size = 32
+ sequence_length = 21
+ vocab_size = 57
+ num_types = 7
+ # Create a small TransformerEncoder for testing.
+ test_network = transformer_encoder.TransformerEncoder(
+ vocab_size=vocab_size,
+ hidden_size=hidden_size,
+ sequence_length=sequence_length,
+ num_attention_heads=2,
+ num_layers=3,
+ type_vocab_size=num_types,
+ output_range=output_range)
+ self.assertTrue(
+ test_network._position_embedding_layer._use_dynamic_slicing)
+ # Create the inputs (note that the first dimension is implicit).
+ word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+ data, pooled = test_network([word_ids, mask, type_ids])
+
+ # Create a model based off of this network:
+ model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
+
+ # Invoke the model. We can't validate the output data here (the model is too
+ # complex) but this will catch structural runtime errors.
+ batch_size = 3
+ word_id_data = np.random.randint(
+ vocab_size, size=(batch_size, sequence_length))
+ mask_data = np.random.randint(2, size=(batch_size, sequence_length))
+ type_id_data = np.random.randint(
+ num_types, size=(batch_size, sequence_length))
+ _ = model.predict([word_id_data, mask_data, type_id_data])
+
+ # Creates a TransformerEncoder with max_sequence_length != sequence_length
+ max_sequence_length = 128
+ test_network = transformer_encoder.TransformerEncoder(
+ vocab_size=vocab_size,
+ hidden_size=hidden_size,
+ sequence_length=sequence_length,
+ max_sequence_length=max_sequence_length,
+ num_attention_heads=2,
+ num_layers=3,
+ type_vocab_size=num_types)
+ self.assertTrue(test_network._position_embedding_layer._use_dynamic_slicing)
+ model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
+ outputs = model.predict([word_id_data, mask_data, type_id_data])
+ self.assertEqual(outputs[0].shape[1], out_seq_len)
+
+ # Creates a TransformerEncoder with embedding_width != hidden_size
+ test_network = transformer_encoder.TransformerEncoder(
+ vocab_size=vocab_size,
+ hidden_size=hidden_size,
+ sequence_length=sequence_length,
+ max_sequence_length=max_sequence_length,
+ num_attention_heads=2,
+ num_layers=3,
+ type_vocab_size=num_types,
+ embedding_width=16)
+ model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
+ outputs = model.predict([word_id_data, mask_data, type_id_data])
+ self.assertEqual(outputs[0].shape[-1], hidden_size)
+ self.assertTrue(hasattr(test_network, "_embedding_projection"))
+
+ def test_serialize_deserialize(self):
+ tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
+ # Create a network object that sets all of its config options.
+ kwargs = dict(
+ vocab_size=100,
+ hidden_size=32,
+ num_layers=3,
+ num_attention_heads=2,
+ sequence_length=21,
+ max_sequence_length=21,
+ type_vocab_size=12,
+ intermediate_size=1223,
+ activation="relu",
+ dropout_rate=0.05,
+ attention_dropout_rate=0.22,
+ initializer="glorot_uniform",
+ return_all_encoder_outputs=False,
+ output_range=-1,
+ embedding_width=16)
+ network = transformer_encoder.TransformerEncoder(**kwargs)
+
+ expected_config = dict(kwargs)
+ expected_config["activation"] = tf.keras.activations.serialize(
+ tf.keras.activations.get(expected_config["activation"]))
+ expected_config["initializer"] = tf.keras.initializers.serialize(
+ tf.keras.initializers.get(expected_config["initializer"]))
+ self.assertEqual(network.get_config(), expected_config)
+
+ # Create another network object from the first object's config.
+ new_network = transformer_encoder.TransformerEncoder.from_config(
+ network.get_config())
+
+ # Validate that the config can be forced to JSON.
+ _ = new_network.to_json()
+
+ # If the serialization was successful, the new config should match the old.
+ self.assertAllEqual(network.get_config(), new_network.get_config())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/nhnet/README.md b/models/official/nlp/nhnet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..14c55636ab52b4582cb6b12e88a282c7adbb059e
--- /dev/null
+++ b/models/official/nlp/nhnet/README.md
@@ -0,0 +1,168 @@
+# Multi-doc News Headline Generation Model: NHNet
+
+This repository contains TensorFlow 2.x implementation for NHNet [[1]](#1) as
+well as instructions for producing the data we described in the paper.
+
+## Introduction
+
+NHNet is a multi-doc news headline generation model. It extends a standard
+Transformer-based encoder-decoder model to the multi-doc setting and relies on an
+article-level attention layer to capture information common to most (if not all)
+input news articles in a news cluster or story, and to provide robustness against
+potential outliers in the input due to clustering quality.
+
+Our academic paper [[1]](#1), which describes NHNet in detail, can be found here:
+https://arxiv.org/abs/2001.09386.
+
+## Dataset
+
+**Raw Data:** One can [download](https://github.com/google-research-datasets/NewSHead)
+our multi-doc headline dataset, which
+contains 369,940 news stories and 932,571 unique URLs. We split these stories
+into train (359,940 stories), validation (5,000 stories), and test (5,000
+stories) sets by timestamp.
+
+For more information, please check out:
+https://github.com/google-research-datasets/NewSHead
+
+### Crawling
+
+Unfortunately, we are not able to release the exact pre-processed dataset used
+in the paper. Users need to crawl the URLs themselves; the recommended
+pre-processing is to use an open-source library to download and parse the news
+content, including the title and leading paragraphs. To ease this process, we
+provide a config for [news-please](https://github.com/fhamborg/news-please) that
+will crawl and extract news articles on a local machine.
+
+First, install the `news-please` CLI (requires Python 3.x):
+```shell
+$ pip3 install news-please
+```
+
+Next, run the crawler with our provided [config and URL list](https://github.com/google-research-datasets/NewSHead/releases)
+
+```shell
+# Set to the path of the downloaded data folder.
+$ DATA_FOLDER=/path/to/downloaded_dataset
+
+# Use the CLI interface to crawl. We assume the news_please subfolder contains
+# the decompressed config.cfg and sitelist.hjson.
+$ news-please -c $DATA_FOLDER/news_please
+```
+By default, it will store crawled
+articles under `/tmp/nhnet/`. To terminate the process, press `CTRL+C`.
+
+Crawling may take several days (48 hours in our test), depending on the
+network environment and the number of threads set in the config. As the crawling
+tool won't stop automatically, it is not straightforward to check the progress.
+We suggest terminating the job once no new articles have been crawled for a
+short period (e.g., 10 minutes), which you can check by running
+```shell
+$ find /tmp/nhnet -type f | wc -l
+```
+Please note that, as time goes by, some URLs are expected to no longer be
+available on the web.
+
+### Data Processing
+
+Given the crawled articles under `/tmp/nhnet/`, we would like to transform these
+textual articles into a set of `TFRecord` files containing serialized
+tensorflow.Example protocol buffers, with feature keys following the BERT
+[[2]](#2) tradition but extended for multiple text segments. We will later
+use these processed TFRecords for training and evaluation.
+
+To do this, please first download a [BERT pretrained checkpoint](https://github.com/tensorflow/models/tree/master/official/nlp/bert#access-to-pretrained-checkpoints)
+(`BERT-Base,Uncased` preferred for efficiency) and decompress the `tar.gz` file.
+We need the vocabulary file and will later use the checkpoint for NHNet
+initialization.
+
+Next, we can run the following data preprocessing script, which may take a few
+hours to read files and tokenize article content.
+
+
+```shell
+# Recall that we use DATA_FOLDER=/path/to/downloaded_dataset.
+$ python3 raw_data_preprocess.py \
+ -crawled_articles=/tmp/nhnet \
+ -vocab=/path/to/bert_checkpoint/vocab.txt \
+ -do_lower_case=True \
+ -len_title=15 \
+ -len_passage=200 \
+ -max_num_articles=5 \
+ -data_folder=$DATA_FOLDER
+```
+
+This Python script will export processed train/valid/eval files under
+`$DATA_FOLDER/processed/`.
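+
+To sanity-check the preprocessing output, a minimal sketch like the following can
+be used to peek at the feature keys of one processed example (the path below is a
+placeholder; the exact feature keys depend on the preprocessing script):
+
+```python
+import tensorflow as tf
+
+files = tf.io.gfile.glob("/path/to/downloaded_dataset/processed/train.tfrecord*")
+dataset = tf.data.TFRecordDataset(files)
+for raw_record in dataset.take(1):
+    example = tf.train.Example.FromString(raw_record.numpy())
+    print(sorted(example.features.feature.keys()))
+```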
+
+## Training
+
+Please first install TensorFlow 2 and the TensorFlow Model Garden following the
+[requirements section](https://github.com/tensorflow/models/tree/master/official#requirements).
+
+### CPU/GPU
+```shell
+$ python3 trainer.py \
+ --mode=train_and_eval \
+ --vocab=/path/to/bert_checkpoint/vocab.txt \
+ --init_checkpoint=/path/to/bert_checkpoint/bert_model.ckpt \
+ --params_override='init_from_bert2bert=false' \
+ --train_file_pattern=$DATA_FOLDER/processed/train.tfrecord* \
+ --model_dir=/path/to/output/model \
+ --len_title=15 \
+ --len_passage=200 \
+ --max_num_articles=5 \
+ --model_type=nhnet \
+ --train_batch_size=16 \
+ --train_steps=10000 \
+ --steps_per_loop=1 \
+ --checkpoint_interval=100
+```
+
+### TPU
+```shell
+$ python3 trainer.py \
+ --mode=train_and_eval \
+ --vocab=/path/to/bert_checkpoint/vocab.txt \
+ --init_checkpoint=/path/to/bert_checkpoint/bert_model.ckpt \
+ --params_override='init_from_bert2bert=false' \
+ --train_file_pattern=$DATA_FOLDER/processed/train.tfrecord* \
+ --model_dir=/path/to/output/model \
+ --len_title=15 \
+ --len_passage=200 \
+ --max_num_articles=5 \
+ --model_type=nhnet \
+ --train_batch_size=1024 \
+ --train_steps=10000 \
+ --steps_per_loop=1000 \
+ --checkpoint_interval=1000 \
+ --distribution_strategy=tpu \
+ --tpu=grpc://${TPU_IP_ADDRESS}:8470
+```
+In the paper, we train for more than 10k steps with a batch size of 1024 on a
+TPU v3-64.
+
+Note that `trainer.py` also supports a `train` mode and a continuous `eval` mode.
+For large-scale TPU training, we recommend having one process run the
+`train` mode and another process run the continuous `eval` mode, which can
+run on GPUs.
+This is the setting we commonly use for large-scale experiments, because `eval`
+is then non-blocking to the expensive training load.
+
+### Metrics
+**Note: the metrics reported by `evaluation.py` are approximated at the
+word-piece level rather than on the real string tokens. Some metrics, like BLEU
+scores, can be off.**
+
+We will soon release a colab to evaluate results at the string level.
+
+## References
+
+[1] Xiaotao Gu, Yuning Mao, Jiawei Han, Jialu Liu, You Wu, Cong
+Yu, Daniel Finnie, Hongkun Yu, Jiaqi Zhai and Nicholas Zukoski "Generating
+Representative Headlines for News Stories": https://arxiv.org/abs/2001.09386.
+World Wide Web Conf. (WWW’2020).
+
+[2] Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina
+Toutanova "BERT: Pre-training of Deep Bidirectional Transformers for Language
+Understanding": https://arxiv.org/abs/1810.04805.
diff --git a/models/official/nlp/nhnet/__init__.py b/models/official/nlp/nhnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/nlp/nhnet/configs.py b/models/official/nlp/nhnet/configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..41cfa6117cb49e00224becb87b129401562a9807
--- /dev/null
+++ b/models/official/nlp/nhnet/configs.py
@@ -0,0 +1,107 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Common NHNet/Bert2Bert configuration."""
+
+from typing import List, Text
+
+import dataclasses
+
+from official.modeling.hyperparams import base_config
+
+
+@dataclasses.dataclass
+class BERT2BERTConfig(base_config.Config):
+ """High-level configurations for BERT2BERT model.
+
+ These include parameters that are not directly related to the experiment,
+ e.g. encoder, decoder, prediction, training, etc.
+ """
+ vocab_size: int = 30522
+ hidden_size: int = 768
+ num_hidden_layers: int = 12
+ num_attention_heads: int = 12
+ intermediate_size: int = 3072
+ hidden_act: str = "gelu"
+ hidden_dropout_prob: float = 0.1
+ attention_probs_dropout_prob: float = 0.1
+ max_position_embeddings: int = 512
+ type_vocab_size: int = 2
+ initializer_range: float = 0.02
+ decoder_intermediate_size: int = 3072
+ num_decoder_attn_heads: int = 12
+ num_decoder_layers: int = 12
+
+ label_smoothing: float = 0.1
+ learning_rate: float = 0.05
+ learning_rate_warmup_steps: int = 20000
+ optimizer: str = "Adam"
+ adam_beta1: float = 0.9
+ adam_beta2: float = 0.997
+ adam_epsilon: float = 1e-09
+
+ # predict params
+ beam_size: int = 5
+ alpha: float = 0.6
+ initializer_gain: float = 1.0
+ use_cache: bool = True
+
+ # input params
+ input_sharding: bool = False
+ input_data_not_padded: bool = False
+ pad_token_id: int = 0
+ end_token_id: int = 102
+ start_token_id: int = 101
+
+
+@dataclasses.dataclass
+class NHNetConfig(BERT2BERTConfig):
+ """High-level configurations for NHNet model.
+
+ These include parameters that are not directly related to the experiment,
+ e.g. encoder, decoder, prediction, training, etc.
+ """
+ multi_channel_cross_attention: bool = True
+ passage_list: List[Text] = dataclasses.field(
+ default_factory=lambda: [chr(ord("b") + i) for i in range(5)])
+
+ # Initialization method.
+ # If init_from_bert2bert is false, we assume the checkpoint is from BERT
+ # pretraining and only encoder and self-attention variables are initialized.
+ init_from_bert2bert: bool = True
+
+
+UNITTEST_CONFIG = {
+ "attention_probs_dropout_prob": 0.0,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.0,
+ "hidden_size": 16,
+ "initializer_range": 0.02,
+ "intermediate_size": 32,
+ "max_position_embeddings": 128,
+ "num_attention_heads": 2,
+ "num_hidden_layers": 1,
+ "type_vocab_size": 2,
+ "vocab_size": 30522,
+ "initializer_gain": 1.0,
+ "decoder_intermediate_size": 32,
+ "num_decoder_attn_heads": 2,
+ "num_decoder_layers": 1,
+ "use_cache": True,
+ "input_data_not_padded": False,
+ "pad_token_id": 0,
+ "end_token_id": 102,
+ "start_token_id": 101,
+}
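+
+
+# Example usage (an illustrative sketch based on configs_test.py): the configs are
+# dataclasses, so individual fields can be overridden at construction time, then
+# validated and serialized to a plain dict:
+#
+#   cfg = NHNetConfig(num_decoder_layers=6)
+#   cfg.validate()
+#   cfg_dict = cfg.as_dict()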
diff --git a/models/official/nlp/nhnet/configs_test.py b/models/official/nlp/nhnet/configs_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b855ec6a955cd7f2a50fb173f7f5efb68b84263
--- /dev/null
+++ b/models/official/nlp/nhnet/configs_test.py
@@ -0,0 +1,121 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for configs."""
+
+import tensorflow as tf
+from official.nlp.nhnet import configs
+
+BERT2BERT_CONFIG = {
+ "vocab_size": 30522,
+ "hidden_size": 768,
+ "num_hidden_layers": 12,
+ "num_attention_heads": 12,
+ "intermediate_size": 3072,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "attention_probs_dropout_prob": 0.1,
+ "max_position_embeddings": 512,
+ "type_vocab_size": 2,
+ "initializer_range": 0.02,
+
+ # model params
+ "decoder_intermediate_size": 3072,
+ "num_decoder_attn_heads": 12,
+ "num_decoder_layers": 12,
+
+ # training params
+ "label_smoothing": 0.1,
+ "learning_rate": 0.05,
+ "learning_rate_warmup_steps": 20000,
+ "optimizer": "Adam",
+ "adam_beta1": 0.9,
+ "adam_beta2": 0.997,
+ "adam_epsilon": 1e-09,
+
+ # predict params
+ "beam_size": 5,
+ "alpha": 0.6,
+ "initializer_gain": 1.0,
+ "use_cache": True,
+
+ # input params
+ "input_sharding": False,
+ "input_data_not_padded": False,
+ "pad_token_id": 0,
+ "end_token_id": 102,
+ "start_token_id": 101,
+}
+
+NHNET_CONFIG = {
+ "vocab_size": 30522,
+ "hidden_size": 768,
+ "num_hidden_layers": 12,
+ "num_attention_heads": 12,
+ "intermediate_size": 3072,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "attention_probs_dropout_prob": 0.1,
+ "max_position_embeddings": 512,
+ "type_vocab_size": 2,
+ "initializer_range": 0.02,
+
+ # model params
+ "decoder_intermediate_size": 3072,
+ "num_decoder_attn_heads": 12,
+ "num_decoder_layers": 12,
+ "multi_channel_cross_attention": True,
+
+ # training params
+ "label_smoothing": 0.1,
+ "learning_rate": 0.05,
+ "learning_rate_warmup_steps": 20000,
+ "optimizer": "Adam",
+ "adam_beta1": 0.9,
+ "adam_beta2": 0.997,
+ "adam_epsilon": 1e-09,
+
+ # predict params
+ "beam_size": 5,
+ "alpha": 0.6,
+ "initializer_gain": 1.0,
+ "use_cache": True,
+
+ # input params
+ "passage_list": ["b", "c", "d", "e", "f"],
+ "input_sharding": False,
+ "input_data_not_padded": False,
+ "pad_token_id": 0,
+ "end_token_id": 102,
+ "start_token_id": 101,
+
+ "init_from_bert2bert": True,
+}
+
+
+class ConfigsTest(tf.test.TestCase):
+
+ def test_configs(self):
+ cfg = configs.BERT2BERTConfig()
+ cfg.validate()
+ self.assertEqual(cfg.as_dict(), BERT2BERT_CONFIG)
+
+ def test_nhnet_config(self):
+ cfg = configs.NHNetConfig()
+ cfg.validate()
+ self.assertEqual(cfg.as_dict(), NHNET_CONFIG)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/nhnet/decoder.py b/models/official/nlp/nhnet/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..b38fa2a6b6a251af48848e5d0a8d684be8f4c098
--- /dev/null
+++ b/models/official/nlp/nhnet/decoder.py
@@ -0,0 +1,375 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Transformer decoder that mimics a BERT encoder, to load BERT checkpoints."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+from official.modeling import tf_utils
+from official.nlp.modeling import layers
+from official.nlp.modeling.layers import transformer
+from official.nlp.transformer import model_utils as transformer_utils
+
+
+class TransformerDecoder(tf.keras.layers.Layer):
+ """Transformer decoder stack."""
+
+ def __init__(self,
+ num_hidden_layers=12,
+ hidden_size=768,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ intermediate_activation="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ attend_to_last_layer=True,
+ multi_channel_cross_attention=False,
+ **kwargs):
+ super(TransformerDecoder, self).__init__(**kwargs)
+ self.num_hidden_layers = num_hidden_layers
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.intermediate_activation = tf_utils.get_activation(
+ intermediate_activation)
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.attend_to_last_layer = attend_to_last_layer
+ self.multi_channel_cross_attention = multi_channel_cross_attention
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+ self.layers = []
+ for i in range(self.num_hidden_layers):
+ self.layers.append(
+ transformer.TransformerDecoderLayer(
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ intermediate_activation=self.intermediate_activation,
+ dropout_rate=self.hidden_dropout_prob,
+ attention_dropout_rate=self.attention_probs_dropout_prob,
+ kernel_initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=self.initializer_range),
+ multi_channel_cross_attention=self.multi_channel_cross_attention,
+ name=("layer_%d" % i)))
+ super(TransformerDecoder, self).build(unused_input_shapes)
+
+ def call(self, inputs, cache=None, decode_loop_step=None):
+ """Return the output of the decoder layer stacks.
+
+ Args:
+ inputs: A dictionary of inputs. `decoder_inputs` is a tf.int32 tensor for
+ input ids. `encoder_outputs` is a list of tensors with shape
+ [batch_size, input_length, hidden_size]. `self_attention_mask` is the
+ bias for decoder self-attention layer. [1, 1, target_length,
+ target_length]. `attention_mask` is the bias for encoder-decoder
+ attention layer, [batch_size, 1, 1, input_length].
+ cache: A dictionary of cache tensors, including key & value attentions.
+ decode_loop_step: an integer to indicate the step inside a decoding loop.
+
+ Returns:
+ Output of decoder layer stack.
+ float32 tensor with shape [batch_size, target_length, hidden_size]
+ """
+ decoder_inputs = inputs["decoder_inputs"]
+ encoder_outputs = inputs["encoder_outputs"]
+ self_attention_mask = inputs["self_attention_mask"]
+ attention_mask = inputs["attention_mask"]
+ decoder_shape = tf_utils.get_shape_list(decoder_inputs, expected_rank=3)
+ batch_size = decoder_shape[0]
+ decoder_length = decoder_shape[1]
+
+ def _to_bert_self_attention_mask(matrix):
+ """[1, 1, target_len, target_len] -> [bs, target_len, target_len]."""
+ matrix = tf.squeeze(matrix, axis=[1])
+ matrix = tf.tile(matrix, [batch_size, 1, 1])
+ return matrix
+
+ def _to_bert_encdec_attention_mask(matrix):
+ """[bs, 1, 1, input_len] -> [bs, target_len, input_len]."""
+ if self.multi_channel_cross_attention:
+ matrix = tf.expand_dims(matrix, axis=2)
+ matrix = tf.tile(matrix, [1, 1, decoder_length, 1])
+ else:
+ matrix = tf.squeeze(matrix, axis=[1])
+ matrix = tf.tile(matrix, [1, decoder_length, 1])
+ return matrix
+
+ attention_mask = _to_bert_encdec_attention_mask(attention_mask)
+ self_attention_mask = _to_bert_self_attention_mask(self_attention_mask)
+
+ output_tensor = decoder_inputs
+ for layer_idx in range(self.num_hidden_layers):
+ if self.attend_to_last_layer:
+ memory = encoder_outputs[-1]
+ else:
+ memory = encoder_outputs[layer_idx]
+ if self.multi_channel_cross_attention:
+ transformer_inputs = [
+ output_tensor, memory, attention_mask, self_attention_mask,
+ inputs["doc_attention_probs"]
+ ]
+ else:
+ transformer_inputs = [
+ output_tensor, memory, attention_mask, self_attention_mask
+ ]
+ # Gets the cache for decoding.
+ if cache is None:
+ output_tensor, _ = self.layers[layer_idx](transformer_inputs)
+ else:
+ cache_layer_idx = str(layer_idx)
+ output_tensor, cache[cache_layer_idx] = self.layers[layer_idx](
+ transformer_inputs,
+ cache=cache[cache_layer_idx],
+ decode_loop_step=decode_loop_step)
+ return output_tensor, cache
+
+
+def get_attention_bias(input_tensor,
+ bias_type,
+ padding_value=0,
+ max_length=None):
+ """A helper function to get various attention bias tensors."""
+ if bias_type not in ("single_cross", "multi_cross", "decoder_self"):
+ raise ValueError("Invalid attention bias type: %s" % bias_type)
+ if bias_type == "single_cross":
+ length = tf_utils.get_shape_list(input_tensor, expected_rank=2)[1]
+ bias = transformer_utils.get_padding_bias(
+ input_tensor, padding_value=padding_value)
+ elif bias_type == "multi_cross":
+ length = tf_utils.get_shape_list(input_tensor, expected_rank=3)[2]
+ padding = transformer_utils.get_padding(
+ input_tensor, padding_value=padding_value)
+ bias = padding * -1e9
+ else:
+ if max_length is not None:
+ length = max_length
+ else:
+ length = tf_utils.get_shape_list(input_tensor, expected_rank=2)[1]
+ bias = transformer_utils.get_decoder_self_attention_bias(length)
+
+ return tf.where(bias < 0, tf.zeros_like(bias), tf.ones_like(bias))
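+
+# Illustrative usage (see decoder_test.py): the AttentionBias layer below wraps
+# this helper, e.g.
+#   cross_attention_bias = AttentionBias(bias_type="single_cross")(encoder_input_ids)
+#   self_attention_bias = AttentionBias(bias_type="decoder_self")(target_ids)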
+
+
+class AttentionBias(tf.keras.layers.Layer):
+
+ def __init__(self, bias_type, **kwargs):
+ super(AttentionBias, self).__init__(**kwargs)
+ self.bias_type = bias_type
+
+ def call(self, inputs):
+ return get_attention_bias(inputs, self.bias_type)
+
+
+class EmbeddingPostprocessor(tf.keras.layers.Layer):
+ """Performs various post-processing on a word embedding tensor."""
+
+ def __init__(self,
+ use_type_embeddings=False,
+ token_type_vocab_size=None,
+ use_position_embeddings=True,
+ max_position_embeddings=512,
+ dropout_prob=0.0,
+ initializer_range=0.02,
+ initializer=None,
+ **kwargs):
+ super(EmbeddingPostprocessor, self).__init__(**kwargs)
+ self.use_type_embeddings = use_type_embeddings
+ self.token_type_vocab_size = token_type_vocab_size
+ self.use_position_embeddings = use_position_embeddings
+ self.max_position_embeddings = max_position_embeddings
+ self.dropout_prob = dropout_prob
+ self.initializer_range = initializer_range
+
+ if not initializer:
+ self.initializer = tf.keras.initializers.TruncatedNormal(
+ stddev=initializer_range)
+ else:
+ self.initializer = initializer
+
+ if self.use_type_embeddings and not self.token_type_vocab_size:
+ raise ValueError("If `use_type_embeddings` is True, then "
+ "`token_type_vocab_size` must be specified.")
+
+ def build(self, input_shapes):
+ """Implements build() for the layer."""
+ (word_embeddings_shape, _) = input_shapes
+ width = word_embeddings_shape.as_list()[-1]
+ self.type_embeddings = None
+ if self.use_type_embeddings:
+ self.type_embeddings = self.add_weight(
+ "type_embeddings",
+ shape=[self.token_type_vocab_size, width],
+ initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=self.initializer_range),
+ dtype=self.dtype)
+
+ self.position_embeddings = None
+ if self.use_position_embeddings:
+ self.position_embeddings = self.add_weight(
+ "position_embeddings",
+ shape=[self.max_position_embeddings, width],
+ initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=self.initializer_range),
+ dtype=self.dtype)
+
+ self.output_layer_norm = tf.keras.layers.LayerNormalization(
+ name="layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
+ self.output_dropout = tf.keras.layers.Dropout(
+ rate=self.dropout_prob, dtype=tf.float32)
+ super(EmbeddingPostprocessor, self).build(input_shapes)
+
+ def __call__(self, word_embeddings, token_type_ids=None, **kwargs):
+ inputs = tf_utils.pack_inputs([word_embeddings, token_type_ids])
+ return super(EmbeddingPostprocessor, self).__call__(inputs, **kwargs)
+
+ def call(self, inputs):
+ """Implements call() for the layer."""
+ unpacked_inputs = tf_utils.unpack_inputs(inputs)
+ word_embeddings = unpacked_inputs[0]
+ token_type_ids = unpacked_inputs[1]
+ input_shape = tf_utils.get_shape_list(word_embeddings, expected_rank=3)
+ batch_size = input_shape[0]
+ seq_length = input_shape[1]
+ width = input_shape[2]
+
+ output = word_embeddings
+ if self.use_type_embeddings:
+ flat_token_type_ids = tf.reshape(token_type_ids, [-1])
+ token_type_embeddings = tf.gather(self.type_embeddings,
+ flat_token_type_ids)
+ token_type_embeddings = tf.reshape(token_type_embeddings,
+ [batch_size, seq_length, width])
+ output += token_type_embeddings
+
+ if self.use_position_embeddings:
+ position_embeddings = tf.expand_dims(
+ tf.slice(self.position_embeddings, [0, 0], [seq_length, width]),
+ axis=0)
+
+ output += position_embeddings
+
+ output = self.output_layer_norm(output)
+ output = self.output_dropout(output)
+
+ return output
+
+
+class Decoder(tf.keras.layers.Layer):
+ """The decoder network which can reuse encoder embeddings for target."""
+
+ def __init__(self, config, embedding_lookup=None, **kwargs):
+ super(Decoder, self).__init__(**kwargs)
+ self.config = config
+ # Shares vocabulary embedding.
+ self.embedding_lookup = None
+ if embedding_lookup:
+ self.embedding_lookup = embedding_lookup
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+ if self.embedding_lookup is None:
+ self.embedding_lookup = layers.OnDeviceEmbedding(
+ vocab_size=self.config.vocab_size,
+ embedding_width=self.config.hidden_size,
+ initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=self.config.initializer_range),
+ name="target_embeddings")
+ self.embedding_postprocessor = EmbeddingPostprocessor(
+ use_type_embeddings=False,
+ use_position_embeddings=True,
+ max_position_embeddings=self.config.max_position_embeddings,
+ dropout_prob=self.config.hidden_dropout_prob,
+ initializer=tf.keras.initializers.VarianceScaling(
+ scale=self.config.initializer_gain,
+ mode="fan_avg",
+ distribution="uniform"),
+ name="embedding_postprocessor")
+ # Decoder can use a different intermediate size.
+ self.multi_channel_cross_attention = self.config.get(
+ "multi_channel_cross_attention", False)
+ self.decoder = TransformerDecoder(
+ num_hidden_layers=self.config.num_decoder_layers,
+ hidden_size=self.config.hidden_size,
+ num_attention_heads=self.config.num_decoder_attn_heads,
+ intermediate_size=self.config.decoder_intermediate_size,
+ intermediate_activation=self.config.hidden_act,
+ hidden_dropout_prob=self.config.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.config.attention_probs_dropout_prob,
+ initializer_range=self.config.initializer_range,
+ multi_channel_cross_attention=self.multi_channel_cross_attention,
+ name="decoder")
+ super(Decoder, self).build(unused_input_shapes)
+
+ def _decoding_step_time_signal(self, target_embeds, decode_loop_step):
+ """Applies time signal (positional embeddings) for decoded embeddings."""
+ # TODO(hongkuny): migrate to keras bert and design a module to handle this.
+ output = target_embeds
+ if self.embedding_postprocessor.use_position_embeddings:
+ position_embeddings = tf.gather(
+ self.embedding_postprocessor.position_embeddings, [decode_loop_step])
+ # Broadcasts to all sequences inside a batch.
+ output += position_embeddings
+
+ output = self.embedding_postprocessor.output_layer_norm(output)
+ output = self.embedding_postprocessor.output_dropout(output)
+ return output
+
+ def call(self,
+ inputs,
+ cache=None,
+ decode_loop_step=None,
+ padded_decode=False):
+ """Implements call() for the layer.
+
+ Args:
+      inputs: a dictionary of input tensors.
+      cache: A dictionary of cache tensors, including key & value attentions.
+        Due to a Keras limitation, we use side effects to update the cache; the
+        states of the cache tensors will be mutated.
+ decode_loop_step: an integer to indicate the step inside a decoding loop.
+ padded_decode: a boolean indicates if the pass is for padded decoding.
+
+ Returns:
+ Decoder output tensors.
+ """
+ attention_bias = inputs["attention_bias"]
+ target_ids = inputs["target_ids"]
+ all_encoder_outputs = inputs["all_encoder_outputs"]
+ self_attention_bias = inputs["self_attention_bias"]
+ if not isinstance(all_encoder_outputs, list):
+ all_encoder_outputs = [all_encoder_outputs]
+
+ target_embeds = self.embedding_lookup(target_ids)
+ if decode_loop_step is None:
+ target_embeds = self.embedding_postprocessor(target_embeds)
+ else:
+ target_embeds = self._decoding_step_time_signal(target_embeds,
+ decode_loop_step)
+ decoder_inputs = dict(
+ decoder_inputs=target_embeds,
+ encoder_outputs=all_encoder_outputs,
+ self_attention_mask=self_attention_bias,
+ attention_mask=attention_bias)
+ if self.multi_channel_cross_attention:
+ decoder_inputs["doc_attention_probs"] = inputs["doc_attention_probs"]
+ decode_outputs, cache = self.decoder(
+ decoder_inputs, cache, decode_loop_step if padded_decode else None)
+ return decode_outputs
diff --git a/models/official/nlp/nhnet/decoder_test.py b/models/official/nlp/nhnet/decoder_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5effbdb090e9c08939bfc203091e960741700c6
--- /dev/null
+++ b/models/official/nlp/nhnet/decoder_test.py
@@ -0,0 +1,151 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for nlp.nhnet.decoder."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+from official.nlp.modeling import layers
+from official.nlp.nhnet import configs
+from official.nlp.nhnet import decoder
+from official.nlp.nhnet import utils
+
+
+class DecoderTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(DecoderTest, self).setUp()
+ self._config = utils.get_test_params()
+
+ def test_transformer_decoder(self):
+ decoder_block = decoder.TransformerDecoder(
+ num_hidden_layers=self._config.num_hidden_layers,
+ hidden_size=self._config.hidden_size,
+ num_attention_heads=self._config.num_attention_heads,
+ intermediate_size=self._config.intermediate_size,
+ intermediate_activation=self._config.hidden_act,
+ hidden_dropout_prob=self._config.hidden_dropout_prob,
+ attention_probs_dropout_prob=self._config.attention_probs_dropout_prob,
+ initializer_range=self._config.initializer_range)
+ decoder_block.build(None)
+ self.assertEqual(len(decoder_block.layers), self._config.num_hidden_layers)
+
+ def test_bert_decoder(self):
+ seq_length = 10
+ encoder_input_ids = tf.keras.layers.Input(
+ shape=(seq_length,), name="encoder_input_ids", dtype=tf.int32)
+ target_ids = tf.keras.layers.Input(
+ shape=(seq_length,), name="target_ids", dtype=tf.int32)
+ encoder_outputs = tf.keras.layers.Input(
+ shape=(seq_length, self._config.hidden_size),
+ name="all_encoder_outputs",
+ dtype=tf.float32)
+ embedding_lookup = layers.OnDeviceEmbedding(
+ vocab_size=self._config.vocab_size,
+ embedding_width=self._config.hidden_size,
+ initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=self._config.initializer_range),
+ name="word_embeddings")
+ cross_attention_bias = decoder.AttentionBias(bias_type="single_cross")(
+ encoder_input_ids)
+ self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
+ target_ids)
+ inputs = dict(
+ attention_bias=cross_attention_bias,
+ self_attention_bias=self_attention_bias,
+ target_ids=target_ids,
+ all_encoder_outputs=encoder_outputs)
+ decoder_layer = decoder.Decoder(self._config, embedding_lookup)
+ outputs = decoder_layer(inputs)
+ model_inputs = dict(
+ encoder_input_ids=encoder_input_ids,
+ target_ids=target_ids,
+ all_encoder_outputs=encoder_outputs)
+ model = tf.keras.Model(inputs=model_inputs, outputs=outputs, name="test")
+ self.assertLen(decoder_layer.trainable_weights, 30)
+ # Forward path.
+ fake_inputs = {
+ "encoder_input_ids": np.zeros((2, 10), dtype=np.int32),
+ "target_ids": np.zeros((2, 10), dtype=np.int32),
+ "all_encoder_outputs": np.zeros((2, 10, 16), dtype=np.float32),
+ }
+ output_tensor = model(fake_inputs)
+ self.assertEqual(output_tensor.shape, (2, 10, 16))
+
+ def test_multi_doc_decoder(self):
+ self._config = utils.get_test_params(cls=configs.NHNetConfig)
+ seq_length = 10
+ num_docs = 5
+ encoder_input_ids = tf.keras.layers.Input(
+ shape=(num_docs, seq_length), name="encoder_input_ids", dtype=tf.int32)
+ target_ids = tf.keras.layers.Input(
+ shape=(seq_length,), name="target_ids", dtype=tf.int32)
+ encoder_outputs = tf.keras.layers.Input(
+ shape=(num_docs, seq_length, self._config.hidden_size),
+ name="all_encoder_outputs",
+ dtype=tf.float32)
+ embedding_lookup = layers.OnDeviceEmbedding(
+ vocab_size=self._config.vocab_size,
+ embedding_width=self._config.hidden_size,
+ initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=self._config.initializer_range),
+ name="word_embeddings")
+ doc_attention_probs = tf.keras.layers.Input(
+ shape=(self._config.num_decoder_attn_heads, seq_length, num_docs),
+ name="doc_attention_probs",
+ dtype=tf.float32)
+ cross_attention_bias = decoder.AttentionBias(bias_type="multi_cross")(
+ encoder_input_ids)
+ self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
+ target_ids)
+
+ inputs = dict(
+ attention_bias=cross_attention_bias,
+ self_attention_bias=self_attention_bias,
+ target_ids=target_ids,
+ all_encoder_outputs=encoder_outputs,
+ doc_attention_probs=doc_attention_probs)
+
+ decoder_layer = decoder.Decoder(self._config, embedding_lookup)
+ outputs = decoder_layer(inputs)
+ model_inputs = dict(
+ encoder_input_ids=encoder_input_ids,
+ target_ids=target_ids,
+ all_encoder_outputs=encoder_outputs,
+ doc_attention_probs=doc_attention_probs)
+ model = tf.keras.Model(inputs=model_inputs, outputs=outputs, name="test")
+ self.assertLen(decoder_layer.trainable_weights, 30)
+ # Forward path.
+ fake_inputs = {
+ "encoder_input_ids":
+ np.zeros((2, num_docs, seq_length), dtype=np.int32),
+ "target_ids":
+ np.zeros((2, seq_length), dtype=np.int32),
+ "all_encoder_outputs":
+ np.zeros((2, num_docs, seq_length, 16), dtype=np.float32),
+ "doc_attention_probs":
+ np.zeros(
+ (2, self._config.num_decoder_attn_heads, seq_length, num_docs),
+ dtype=np.float32)
+ }
+ output_tensor = model(fake_inputs)
+ self.assertEqual(output_tensor.shape, (2, seq_length, 16))
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/nhnet/evaluation.py b/models/official/nlp/nhnet/evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9c94dcfb71aa763c2acab5ffd022db94c20d776
--- /dev/null
+++ b/models/official/nlp/nhnet/evaluation.py
@@ -0,0 +1,185 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Evaluation for Bert2Bert."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import os
+from absl import logging
+import numpy as np
+import tensorflow as tf
+
+from official.nlp.nhnet import input_pipeline
+from official.nlp.nhnet import models
+from official.nlp.transformer import metrics as metrics_v2
+from official.nlp.transformer.utils import metrics
+
+
+def rouge_l_fscore(logits, labels):
+ """ROUGE scores computation between labels and predictions.
+
+ This is an approximate ROUGE scoring method since we do not glue word pieces
+ or decode the ids and tokenize the output.
+
+ Args:
+ logits: tensor, model predictions
+ labels: tensor, gold output.
+
+ Returns:
+ rouge_l_fscore: approx rouge-l f1 score.
+ """
+ predictions = np.argmax(logits, axis=-1)
+ rouge_l_f_score = metrics.rouge_l_sentence_level(predictions, labels)
+ return rouge_l_f_score
+
+
+def rouge_2_fscore(logits, labels):
+ """ROUGE-2 F1 score computation between labels and predictions.
+
+ This is an approximate ROUGE scoring method since we do not glue word pieces
+ or decode the ids and tokenize the output.
+
+ Args:
+ logits: tensor, model predictions
+ labels: tensor, gold output.
+
+ Returns:
+ rouge2_fscore: approx rouge-2 f1 score.
+ """
+ predictions = np.argmax(logits, axis=-1)
+ rouge_2_f_score = metrics.rouge_n(predictions, labels)
+ return rouge_2_f_score
+
+
+def bleu_score(logits, labels):
+ """Approximate BLEU score computation between labels and predictions.
+
+ An approximate BLEU scoring method since we do not glue word pieces or
+ decode the ids and tokenize the output. By default, we use ngram order of 4
+ and use brevity penalty. Also, this does not have beam search.
+
+ Args:
+ logits: Tensor of size [batch_size, length_logits, vocab_size]
+    labels: Tensor of size [batch_size, length_labels]
+
+  Returns:
+    bleu: float, approx bleu score
+ """
+ predictions = np.argmax(logits, axis=-1)
+ bleu = metrics.compute_bleu(labels, predictions)
+ return bleu
+
+
+def continuous_eval(strategy,
+ params,
+ model_type,
+ eval_file_pattern=None,
+ batch_size=4,
+ eval_steps=None,
+ model_dir=None,
+ timeout=3000):
+ """Continuously evaluate checkpoints on testing data."""
+ test_dataset = input_pipeline.get_input_dataset(
+ eval_file_pattern,
+ batch_size=batch_size,
+ params=params,
+ is_training=False,
+ strategy=strategy)
+
+ with strategy.scope():
+ model = models.create_model(model_type, params)
+ metric_layer = metrics_v2.MetricLayer(params.vocab_size)
+ eval_summary_writer = tf.summary.create_file_writer(
+ os.path.join(model_dir, "summaries/eval"))
+ global_step = tf.Variable(
+ 0,
+ trainable=False,
+ dtype=tf.int64,
+ aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
+ shape=[])
+ model.global_step = global_step
+
+ @tf.function
+ def test_step(inputs):
+ """Calculates evaluation metrics on distributed devices."""
+
+ def _test_step_fn(inputs):
+ """Replicated accuracy calculation."""
+ targets = models.remove_sos_from_seq(inputs["target_ids"],
+ params.pad_token_id)
+
+ # Using ground truth sequences as targets to calculate logits for accuracy
+ # and perplexity metrics.
+ logits, _, _ = model(inputs, training=False, mode="train")
+ metric_layer([logits, targets])
+
+ # Get logits from top beam search results for bleu and rouge metrics.
+ logits = model(inputs, training=False, mode="eval")
+
+ return targets, logits
+
+ outputs = strategy.run(_test_step_fn, args=(inputs,))
+
+ return tf.nest.map_structure(strategy.experimental_local_results, outputs)
+
+ metrics_and_funcs = [
+ (tf.keras.metrics.Mean("bleu", dtype=tf.float32), bleu_score),
+ (tf.keras.metrics.Mean("rouge_2_fscore",
+ dtype=tf.float32), rouge_2_fscore),
+ (tf.keras.metrics.Mean("rouge_l_fscore",
+ dtype=tf.float32), rouge_l_fscore),
+ ]
+ eval_results = {}
+ for latest_checkpoint in tf.train.checkpoints_iterator(
+ model_dir, timeout=timeout):
+ checkpoint = tf.train.Checkpoint(model=model)
+ checkpoint.restore(latest_checkpoint).expect_partial()
+ logging.info("Loaded checkpoint %s", latest_checkpoint)
+
+ for i, inputs in enumerate(test_dataset):
+ if eval_steps and i >= eval_steps:
+ break
+ outputs = test_step(inputs)
+ for metric, func in metrics_and_funcs:
+ for targets, logits in zip(outputs[0], outputs[1]):
+ metric.update_state(func(logits.numpy(), targets.numpy()))
+
+ with eval_summary_writer.as_default():
+ step = model.global_step.numpy()
+ for metric, _ in metrics_and_funcs:
+ eval_results[metric.name] = metric.result().numpy().astype(float)
+ tf.summary.scalar(
+ metric.name,
+ eval_results[metric.name],
+ step=step)
+ for metric in metric_layer.metrics:
+ eval_results[metric.name] = metric.result().numpy().astype(float)
+ tf.summary.scalar(
+ metric.name,
+ eval_results[metric.name],
+ step=step)
+ logging.info("Step %d Metrics= %s", step, str(eval_results))
+ eval_summary_writer.flush()
+
+ # Resets metrics.
+ for metric, _ in metrics_and_funcs:
+ metric.reset_states()
+ for metric in metric_layer.metrics:
+ metric.reset_states()
+ return eval_results
diff --git a/models/official/nlp/nhnet/input_pipeline.py b/models/official/nlp/nhnet/input_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..cadf3f085c868e56039679fdb2124b23f33fc19b
--- /dev/null
+++ b/models/official/nlp/nhnet/input_pipeline.py
@@ -0,0 +1,254 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Input pipelines."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.compat.v2 as tf
+
+
+def decode_record(record, name_to_features):
+ """Decodes a record to a TensorFlow example."""
+ example = tf.io.parse_single_example(record, name_to_features)
+
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+ # So cast all int64 to int32.
+ for name in list(example.keys()):
+ t = example[name]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ example[name] = t
+
+ return example
+
+
+def process_singledoc_dataset(dataset, batch_size, params):
+ """Parses and batches single-doc dataset."""
+ name_to_features = {
+ "input_ids_a": tf.io.FixedLenFeature([params.len_title], tf.int64),
+ "input_ids_b": tf.io.FixedLenFeature([params.len_passage], tf.int64),
+ "input_mask_b": tf.io.FixedLenFeature([params.len_passage], tf.int64),
+ "segment_ids_b": tf.io.FixedLenFeature([params.len_passage], tf.int64),
+ }
+ decode_fn = lambda record: decode_record(record, name_to_features)
+ dataset = dataset.map(
+ decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ def _select_data_from_record(record):
+    """Selects the features to use for training."""
+ return {
+ "input_ids": record["input_ids_b"],
+ "input_mask": record["input_mask_b"],
+ "segment_ids": record["segment_ids_b"],
+ "target_ids": record["input_ids_a"],
+ }
+
+ dataset = dataset.map(
+ _select_data_from_record,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ dataset = dataset.batch(batch_size, drop_remainder=True)
+ return dataset
+
+
+def decode_sparse_record(record, name_to_features):
+ """Decodes a sparse record to a TensorFlow example."""
+ example = tf.io.parse_single_example(record, name_to_features)
+
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+ # So cast all int64 to int32.
+ for name in list(example.keys()):
+ t = example[name]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ example[name] = tf.sparse.to_dense(t)
+
+ return example
+
+
+def _filter_max_length(example, max_title_length=256):
+  """Indicates whether the example's target length does not exceed the maximum length."""
+ return tf.size(example["targets"]) <= max_title_length
+
+
+def process_singledoc_transformer_dataset(dataset, batch_size, params):
+ """Parses, batches and pads single-doc dataset."""
+ name_to_features = {
+ "inputs": tf.io.VarLenFeature(tf.int64),
+ "targets": tf.io.VarLenFeature(tf.int64),
+ }
+ decode_fn = lambda record: decode_sparse_record(record, name_to_features)
+ dataset = dataset.map(
+ decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ def _select_data_from_record(record):
+    """Selects the features to use for training."""
+ input_ids = record["inputs"][:params.len_passage]
+ target_ids = record["targets"]
+ input_mask = tf.ones_like(input_ids)
+ segment_ids = tf.zeros_like(input_ids)
+ return {
+ "input_ids": input_ids,
+ "input_mask": input_mask,
+ "segment_ids": segment_ids,
+ "target_ids": target_ids,
+ }
+
+ dataset = dataset.filter(lambda x: _filter_max_length(x, params.len_title))
+
+ dataset = dataset.map(
+ _select_data_from_record,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ dataset = dataset.padded_batch(
+ batch_size, {
+ "input_ids": [params.len_passage],
+ "input_mask": [params.len_passage],
+ "segment_ids": [params.len_passage],
+ "target_ids": [params.len_title],
+ },
+ padding_values={
+ "input_ids": params.pad_token_id,
+ "input_mask": 0,
+ "segment_ids": 0,
+ "target_ids": params.pad_token_id,
+ },
+ drop_remainder=True)
+
+ return dataset
+
+
+def multidoc_parse_spec(params, training=True):
+  """Gets the multi-doc tf.Example parsing spec."""
+ len_p = params.len_passage
+ name_to_features = {}
+ feature_list = ["input_ids", "input_mask", "segment_ids"]
+ for idx in params.passage_list:
+ for feature in feature_list:
+ name_to_features["%s_%s" % (feature, idx)] = tf.io.FixedLenFeature(
+ [len_p], tf.int64)
+ if training:
+ # Cluster title.
+ name_to_features["input_ids_a"] = tf.io.FixedLenFeature([params.len_title],
+ tf.int64)
+ return name_to_features, feature_list
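+  # Illustrative note (added comment, not in the original source): if
+  # params.passage_list were, say, ["b", "c"], the spec would contain the keys
+  # input_ids_b, input_mask_b, segment_ids_b, input_ids_c, input_mask_c and
+  # segment_ids_c, plus input_ids_a for the cluster title when training=True.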
+
+
+def process_multidoc_dataset(dataset, batch_size, params):
+ """Parses, organizes and batches multi-doc dataset."""
+ name_to_features, feature_list = multidoc_parse_spec(params)
+ decode_fn = lambda record: decode_record(record, name_to_features)
+ dataset = dataset.map(
+ decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ def _select_data_from_record(record):
+    """Selects the features to use for training."""
+ features = {"target_ids": record["input_ids_a"]}
+ for feature in feature_list:
+ tensors = [record["%s_%s" % (feature, i)] for i in params.passage_list]
+ features[feature] = tf.stack(tensors)
+ return features
+
+ dataset = dataset.map(
+ _select_data_from_record,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ dataset = dataset.batch(batch_size, drop_remainder=True)
+ return dataset
+
+
+def create_dataset(file_paths,
+ batch_size,
+ params,
+ is_training=True,
+ input_pipeline_context=None):
+  """Creates an input dataset from TFRecord files for pretraining."""
+ dataset = tf.data.Dataset.list_files(file_paths, shuffle=is_training)
+
+ if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
+ if not is_training or params.input_sharding:
+ dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
+ input_pipeline_context.input_pipeline_id)
+
+ if is_training:
+ dataset = dataset.repeat()
+    # We set the shuffle buffer to exactly match the total number of
+    # training files to ensure that the training data is well shuffled.
+ dataset = dataset.shuffle(len(file_paths))
+
+  # In parallel, create a TFRecord dataset for each training file.
+ # cycle_length = 8 means that up to 8 files will be read and deserialized in
+ # parallel. You may want to increase this number if you have a large number of
+ # CPU cores.
+ dataset = dataset.interleave(
+ tf.data.TFRecordDataset,
+ cycle_length=8,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ if is_training:
+ dataset = dataset.shuffle(100)
+
+ if params.get("multi_channel_cross_attention", value=False):
+ dataset = process_multidoc_dataset(dataset, batch_size, params)
+ else:
+ if not params.input_data_not_padded:
+ dataset = process_singledoc_dataset(dataset, batch_size, params)
+ else:
+ dataset = process_singledoc_transformer_dataset(dataset, batch_size,
+ params)
+ dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+ return dataset
+
+
+def get_input_dataset(input_file_pattern,
+ batch_size,
+ params,
+ is_training,
+ strategy=None):
+ """Returns input dataset from input file string."""
+
+  # When using TPU pods, we need to clone the dataset across workers and
+  # pass in a function that returns the dataset rather than the dataset
+  # instance itself.
+ use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+ if use_dataset_fn:
+ if batch_size % strategy.num_replicas_in_sync != 0:
+ raise ValueError(
+ "Batch size must be divisible by number of replicas : {}".format(
+ strategy.num_replicas_in_sync))
+
+    # As auto rebatching is not supported by the
+    # `experimental_distribute_datasets_from_function()` API, which is
+    # required when cloning the dataset to multiple workers in eager mode,
+    # we use the per-replica batch size.
+ batch_size = int(batch_size / strategy.num_replicas_in_sync)
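+    # Illustrative example (added comment, not in the original source): a
+    # global batch_size of 32 with strategy.num_replicas_in_sync == 8 becomes
+    # a per-replica batch_size of 4 here.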
+
+ def _dataset_fn(ctx=None):
+ """Returns tf.data.Dataset for distributed BERT pretraining."""
+ input_files = []
+ for input_pattern in input_file_pattern.split(","):
+ input_files.extend(tf.io.gfile.glob(input_pattern))
+
+ return create_dataset(
+ input_files,
+ batch_size,
+ params,
+ is_training=is_training,
+ input_pipeline_context=ctx)
+
+ if use_dataset_fn:
+ return strategy.experimental_distribute_datasets_from_function(_dataset_fn)
+ else:
+ return strategy.experimental_distribute_dataset(_dataset_fn())
diff --git a/models/official/nlp/nhnet/models.py b/models/official/nlp/nhnet/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6f70e7f36d8a30ed869c1ca135ef3262fd2150e
--- /dev/null
+++ b/models/official/nlp/nhnet/models.py
@@ -0,0 +1,590 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""tf.keras Models for NHNet."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+from absl import logging
+import gin
+import tensorflow as tf
+from typing import Optional, Text
+
+from official.modeling import tf_utils
+from official.modeling.hyperparams import params_dict
+from official.nlp.modeling import networks
+from official.nlp.modeling.layers import multi_channel_attention
+from official.nlp.nhnet import configs
+from official.nlp.nhnet import decoder
+from official.nlp.nhnet import utils
+from official.nlp.transformer import beam_search
+
+
+def embedding_linear(embedding_matrix, x):
+ """Uses embeddings as linear transformation weights."""
+ with tf.name_scope("presoftmax_linear"):
+ batch_size = tf.shape(x)[0]
+ length = tf.shape(x)[1]
+ hidden_size = tf.shape(x)[2]
+ vocab_size = tf.shape(embedding_matrix)[0]
+
+ x = tf.reshape(x, [-1, hidden_size])
+ logits = tf.matmul(x, embedding_matrix, transpose_b=True)
+
+ return tf.reshape(logits, [batch_size, length, vocab_size])
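+    # Added note (not in the original source): this ties the decoder output
+    # projection to the input embeddings. With embedding_matrix of shape
+    # [vocab_size, hidden_size] and x of shape [batch, length, hidden_size],
+    # the result has shape [batch, length, vocab_size].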
+
+
+def _add_sos_to_seq(seq, start_token_id):
+ """Add a start sequence token while keeping seq length."""
+ batch_size = tf.shape(seq)[0]
+ seq_len = tf.shape(seq)[1]
+ sos_ids = tf.ones([batch_size], tf.int32) * start_token_id
+ targets = tf.concat([tf.expand_dims(sos_ids, axis=1), seq], axis=1)
+ targets = targets[:, :-1]
+ tf.assert_equal(tf.shape(targets), (batch_size, seq_len))
+ return targets
+
+
+def remove_sos_from_seq(seq, pad_token_id):
+ """Remove the start sequence token while keeping seq length."""
+ batch_size, seq_len = tf_utils.get_shape_list(seq, expected_rank=2)
+ # remove
+ targets = seq[:, 1:]
+ # pad
+ pad_ids = tf.ones([batch_size], tf.int32) * pad_token_id
+ targets = tf.concat([targets, tf.expand_dims(pad_ids, axis=1)], axis=1)
+ tf.assert_equal(tf.shape(targets), (batch_size, seq_len))
+ return targets
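+  # Worked example (added comment, not in the original source): with
+  # start_token_id=1 and pad_token_id=0, _add_sos_to_seq([[5, 6, 7]], 1)
+  # yields [[1, 5, 6]] (shift right, drop last) and
+  # remove_sos_from_seq([[1, 5, 6]], 0) yields [[5, 6, 0]] (drop first, pad),
+  # so both keep the original sequence length.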
+
+
+class Bert2Bert(tf.keras.Model):
+ """Bert2Bert encoder decoder model for training."""
+
+ def __init__(self, params, bert_layer, decoder_layer, name=None):
+ super(Bert2Bert, self).__init__(name=name)
+ self.params = params
+ if not bert_layer.built:
+ raise ValueError("bert_layer should be built.")
+ if not decoder_layer.built:
+ raise ValueError("decoder_layer should be built.")
+ self.bert_layer = bert_layer
+ self.decoder_layer = decoder_layer
+
+ def get_config(self):
+ return {"params": self.params.as_dict()}
+
+ def get_decode_logits(self,
+ decoder_inputs,
+ ids,
+ decoder_self_attention_bias,
+ step,
+ cache=None):
+ if cache:
+ if self.params.get("padded_decode", False):
+ bias_shape = decoder_self_attention_bias.shape.as_list()
+ self_attention_bias = tf.slice(
+ decoder_self_attention_bias, [0, 0, step, 0],
+ [bias_shape[0], bias_shape[1], 1, bias_shape[3]])
+ else:
+ self_attention_bias = decoder_self_attention_bias[:, :, step:step +
+ 1, :step + 1]
+ # Sets decoder input to the last generated IDs.
+ decoder_input = ids[:, -1:]
+ else:
+ self_attention_bias = decoder_self_attention_bias[:, :, :step + 1, :step +
+ 1]
+ decoder_input = ids
+ decoder_inputs["target_ids"] = decoder_input
+ decoder_inputs["self_attention_bias"] = self_attention_bias
+ if cache:
+ decoder_outputs = self.decoder_layer(
+ decoder_inputs,
+ cache,
+ decode_loop_step=step,
+ padded_decode=self.params.get("padded_decode", False))
+ else:
+ decoder_outputs = self.decoder_layer(decoder_inputs)
+ logits = embedding_linear(self.decoder_layer.embedding_lookup.embeddings,
+ decoder_outputs[:, -1:, :])
+ logits = tf.squeeze(logits, axis=[1])
+ return logits
+
+ def _get_symbols_to_logits_fn(self, max_decode_length):
+ """Returns a decoding function that calculates logits of the next tokens."""
+ # Max decode length should be smaller than the positional embedding max
+ # sequence length.
+ decoder_self_attention_bias = decoder.get_attention_bias(
+ input_tensor=None,
+ bias_type="decoder_self",
+ max_length=max_decode_length)
+
+ def _symbols_to_logits_fn(ids, i, cache):
+ """Generate logits for next candidate IDs.
+
+ Args:
+ ids: Current decoded sequences. int tensor with shape [batch_size *
+ beam_size, i + 1]
+ i: Loop index
+ cache: dictionary of values storing the encoder output, encoder-decoder
+ attention bias, and previous decoder attention values.
+
+ Returns:
+ Tuple of
+ (logits with shape [batch_size * beam_size, vocab_size],
+ updated cache values)
+ """
+ decoder_inputs = dict(
+ all_encoder_outputs=cache["all_encoder_outputs"],
+ attention_bias=cache["attention_bias"])
+ logits = self.get_decode_logits(
+ decoder_inputs,
+ ids,
+ decoder_self_attention_bias,
+ step=i,
+ cache=cache if self.params.use_cache else None)
+ return logits, cache
+
+ return _symbols_to_logits_fn
+
+ def train_decode(self, decode_outputs):
+ logits = embedding_linear(self.decoder_layer.embedding_lookup.embeddings,
+ decode_outputs)
+ decode_output_ids = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
+ output_log_probs = tf.nn.log_softmax(logits, axis=-1)
+ return logits, decode_output_ids, output_log_probs
+
+ def predict_decode(self, start_token_ids, cache):
+ symbols_to_logits_fn = self._get_symbols_to_logits_fn(self.params.len_title)
+ # Use beam search to find the top beam_size sequences and scores.
+ decoded_ids, scores = beam_search.sequence_beam_search(
+ symbols_to_logits_fn=symbols_to_logits_fn,
+ initial_ids=start_token_ids,
+ initial_cache=cache,
+ vocab_size=self.params.vocab_size,
+ beam_size=self.params.beam_size,
+ alpha=self.params.alpha,
+ max_decode_length=self.params.len_title,
+ padded_decode=self.params.get("padded_decode", False),
+ eos_id=self.params.end_token_id)
+ return decoded_ids, scores
+
+ def _get_logits_for_decode_ids(self, decoder_inputs, top_decoded_ids):
+    """Returns the logits for the given decoded ids."""
+ target_ids = _add_sos_to_seq(top_decoded_ids, self.params.start_token_id)
+ decoder_inputs["self_attention_bias"] = decoder.get_attention_bias(
+ target_ids, bias_type="decoder_self")
+ decoder_inputs["target_ids"] = target_ids
+ decoder_outputs = self.decoder_layer(decoder_inputs)
+ logits = embedding_linear(self.decoder_layer.embedding_lookup.embeddings,
+ decoder_outputs)
+ return logits
+
+ def _init_cache(self, batch_size):
+ num_heads = self.params.num_decoder_attn_heads
+ dim_per_head = self.params.hidden_size // num_heads
+ init_decode_length = (
+ self.params.len_title if self.params.get("padded_decode", False) else 0)
+ cache = {}
+ for layer in range(self.params.num_decoder_layers):
+ cache[str(layer)] = {
+ "key":
+ tf.zeros(
+ [batch_size, init_decode_length, num_heads, dim_per_head],
+ dtype=tf.float32),
+ "value":
+ tf.zeros(
+ [batch_size, init_decode_length, num_heads, dim_per_head],
+ dtype=tf.float32)
+ }
+ return cache
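+    # Added note (not in the original source): the cache stores per-layer
+    # "key"/"value" tensors for incremental decoding. With padded_decode they
+    # are pre-allocated to len_title steps; otherwise they start empty and
+    # grow as tokens are generated.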
+
+ def call(self, inputs, mode="train"):
+ """Implements call().
+
+ Args:
+ inputs: a dictionary of tensors.
+ mode: string, an enum for mode, train/eval.
+
+ Returns:
+ logits, decode_output_ids, output_log_probs for training. top_decoded_ids
+ for eval.
+ """
+ input_ids = inputs["input_ids"]
+ input_mask = inputs["input_mask"]
+ segment_ids = inputs["segment_ids"]
+ all_encoder_outputs, _ = self.bert_layer(
+ [input_ids, input_mask, segment_ids])
+
+ if mode not in ("train", "eval", "predict"):
+ raise ValueError("Invalid call mode: %s" % mode)
+ encoder_decoder_attention_bias = decoder.get_attention_bias(
+ input_ids,
+ bias_type="single_cross",
+ padding_value=self.params.pad_token_id)
+ if mode == "train":
+ self_attention_bias = decoder.get_attention_bias(
+ inputs["target_ids"], bias_type="decoder_self")
+ decoder_inputs = dict(
+ attention_bias=encoder_decoder_attention_bias,
+ all_encoder_outputs=all_encoder_outputs,
+ target_ids=inputs["target_ids"],
+ self_attention_bias=self_attention_bias)
+ decoder_outputs = self.decoder_layer(decoder_inputs)
+ return self.train_decode(decoder_outputs)
+
+ batch_size = tf.shape(input_ids)[0]
+ start_token_ids = tf.ones([batch_size],
+ tf.int32) * self.params.start_token_id
+ # Add encoder output and attention bias to the cache.
+ if self.params.use_cache:
+ cache = self._init_cache(batch_size)
+ else:
+ cache = {}
+ cache["all_encoder_outputs"] = all_encoder_outputs
+ cache["attention_bias"] = encoder_decoder_attention_bias
+ decoded_ids, scores = self.predict_decode(start_token_ids, cache)
+ if mode == "predict":
+ return decoded_ids[:, :self.params.beam_size,
+ 1:], scores[:, :self.params.beam_size]
+
+ decoder_inputs = dict(
+ attention_bias=encoder_decoder_attention_bias,
+ all_encoder_outputs=all_encoder_outputs)
+ top_decoded_ids = decoded_ids[:, 0, 1:]
+ return self._get_logits_for_decode_ids(decoder_inputs, top_decoded_ids)
+
+
+class NHNet(Bert2Bert):
+ """NHNet model which performs multi-doc decoding."""
+
+ def __init__(self, params, bert_layer, decoder_layer, name=None):
+ super(NHNet, self).__init__(params, bert_layer, decoder_layer, name=name)
+ self.doc_attention = multi_channel_attention.VotingAttention(
+ num_heads=params.num_decoder_attn_heads,
+ head_size=params.hidden_size // params.num_decoder_attn_heads)
+
+ def _expand_doc_attention_probs(self, doc_attention_probs, target_length):
+ """Expands doc attention probs to fit the decoding sequence length."""
+ doc_attention_probs = tf.expand_dims(
+ doc_attention_probs, axis=[1]) # [B, 1, A]
+ doc_attention_probs = tf.expand_dims(
+ doc_attention_probs, axis=[2]) # [B, 1, 1, A]
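+    # The tile below broadcasts to [B, num_decoder_attn_heads, target_length,
+    # A] so every decoder head and target position shares the same
+    # per-document weights (comment added for clarity; not in the original
+    # source).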
+ return tf.tile(doc_attention_probs,
+ [1, self.params.num_decoder_attn_heads, target_length, 1])
+
+ def _get_symbols_to_logits_fn(self, max_decode_length):
+ """Returns a decoding function that calculates logits of the next tokens."""
+ # Max decode length should be smaller than the positional embedding max
+ # sequence length.
+ decoder_self_attention_bias = decoder.get_attention_bias(
+ input_tensor=None,
+ bias_type="decoder_self",
+ max_length=max_decode_length)
+
+ def _symbols_to_logits_fn(ids, i, cache):
+ """Generate logits for next candidate IDs."""
+ if self.params.use_cache:
+ target_length = 1
+ else:
+ target_length = i + 1
+ decoder_inputs = dict(
+ doc_attention_probs=self._expand_doc_attention_probs(
+ cache["doc_attention_probs"], target_length),
+ all_encoder_outputs=cache["all_encoder_outputs"],
+ attention_bias=cache["attention_bias"])
+ logits = self.get_decode_logits(
+ decoder_inputs,
+ ids,
+ decoder_self_attention_bias,
+ step=i,
+ cache=cache if self.params.use_cache else None)
+ return logits, cache
+
+ return _symbols_to_logits_fn
+
+ def call(self, inputs, mode="training"):
+ input_shape = tf_utils.get_shape_list(inputs["input_ids"], expected_rank=3)
+ batch_size, num_docs, len_passage = (input_shape[0], input_shape[1],
+ input_shape[2])
+ input_ids = tf.reshape(inputs["input_ids"], [-1, len_passage])
+ input_mask = tf.reshape(inputs["input_mask"], [-1, len_passage])
+ segment_ids = tf.reshape(inputs["segment_ids"], [-1, len_passage])
+ all_encoder_outputs, _ = self.bert_layer(
+ [input_ids, input_mask, segment_ids])
+ encoder_outputs = tf.reshape(
+ all_encoder_outputs[-1],
+ [batch_size, num_docs, len_passage, self.params.hidden_size])
+ doc_attention_mask = tf.reshape(
+ tf.cast(
+ tf.math.count_nonzero(input_mask, axis=1, dtype=tf.int32) > 2,
+ tf.int32), [batch_size, num_docs])
+
+ doc_attention_probs = self.doc_attention(encoder_outputs,
+ doc_attention_mask)
+ encoder_decoder_attention_bias = decoder.get_attention_bias(
+ inputs["input_ids"],
+ bias_type="multi_cross",
+ padding_value=self.params.pad_token_id)
+
+ if mode == "train":
+ target_length = tf_utils.get_shape_list(
+ inputs["target_ids"], expected_rank=2)[1]
+ doc_attention_probs = self._expand_doc_attention_probs(
+ doc_attention_probs, target_length)
+ self_attention_bias = decoder.get_attention_bias(
+ inputs["target_ids"], bias_type="decoder_self")
+ decoder_inputs = dict(
+ attention_bias=encoder_decoder_attention_bias,
+ self_attention_bias=self_attention_bias,
+ target_ids=inputs["target_ids"],
+ all_encoder_outputs=encoder_outputs,
+ doc_attention_probs=doc_attention_probs)
+ decoder_outputs = self.decoder_layer(decoder_inputs)
+ return self.train_decode(decoder_outputs)
+
+ # Adds encoder output and attention bias to the cache.
+ if self.params.use_cache:
+ cache = self._init_cache(batch_size)
+ else:
+ cache = {}
+ cache["all_encoder_outputs"] = [encoder_outputs]
+ cache["attention_bias"] = encoder_decoder_attention_bias
+ cache["doc_attention_probs"] = doc_attention_probs
+
+ start_token_ids = tf.ones([batch_size],
+ tf.int32) * self.params.start_token_id
+ decoded_ids, scores = self.predict_decode(start_token_ids, cache)
+ if mode == "predict":
+ return decoded_ids[:, :self.params.beam_size,
+ 1:], scores[:, :self.params.beam_size]
+
+ top_decoded_ids = decoded_ids[:, 0, 1:]
+ target_length = tf_utils.get_shape_list(top_decoded_ids)[-1]
+ decoder_inputs = dict(
+ attention_bias=encoder_decoder_attention_bias,
+ all_encoder_outputs=[encoder_outputs],
+ doc_attention_probs=self._expand_doc_attention_probs(
+ doc_attention_probs, target_length))
+ return self._get_logits_for_decode_ids(decoder_inputs, top_decoded_ids)
+
+
+def get_bert2bert_layers(params: configs.BERT2BERTConfig):
+ """Creates a Bert2Bert stem model and returns Bert encoder/decoder.
+
+  We use a functional-style stem model because we need all layers to be
+  built in order to restore variables in a customized way. The layers are
+  called with placeholder inputs so that they are fully built.
+
+ Args:
+ params: ParamsDict.
+
+ Returns:
+ two keras Layers, bert_model_layer and decoder_layer
+ """
+ input_ids = tf.keras.layers.Input(
+ shape=(None,), name="input_ids", dtype=tf.int32)
+ input_mask = tf.keras.layers.Input(
+ shape=(None,), name="input_mask", dtype=tf.int32)
+ segment_ids = tf.keras.layers.Input(
+ shape=(None,), name="segment_ids", dtype=tf.int32)
+ target_ids = tf.keras.layers.Input(
+ shape=(None,), name="target_ids", dtype=tf.int32)
+ bert_config = utils.get_bert_config_from_params(params)
+ bert_model_layer = networks.TransformerEncoder(
+ vocab_size=bert_config.vocab_size,
+ hidden_size=bert_config.hidden_size,
+ num_layers=bert_config.num_hidden_layers,
+ num_attention_heads=bert_config.num_attention_heads,
+ intermediate_size=bert_config.intermediate_size,
+ activation=tf_utils.get_activation(bert_config.hidden_act),
+ dropout_rate=bert_config.hidden_dropout_prob,
+ attention_dropout_rate=bert_config.attention_probs_dropout_prob,
+ sequence_length=None,
+ max_sequence_length=bert_config.max_position_embeddings,
+ type_vocab_size=bert_config.type_vocab_size,
+ initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=bert_config.initializer_range),
+ return_all_encoder_outputs=True,
+ name="bert_encoder")
+ all_encoder_outputs, _ = bert_model_layer(
+ [input_ids, input_mask, segment_ids])
+ # pylint: disable=protected-access
+ decoder_layer = decoder.Decoder(params, bert_model_layer._embedding_layer)
+ # pylint: enable=protected-access
+ cross_attention_bias = decoder.AttentionBias(bias_type="single_cross")(
+ input_ids)
+ self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
+ target_ids)
+ decoder_inputs = dict(
+ attention_bias=cross_attention_bias,
+ self_attention_bias=self_attention_bias,
+ target_ids=target_ids,
+ all_encoder_outputs=all_encoder_outputs)
+ _ = decoder_layer(decoder_inputs)
+
+ return bert_model_layer, decoder_layer
+
+
+def get_nhnet_layers(params: configs.NHNetConfig):
+  """Creates a multi-doc encoder/decoder.
+
+ Args:
+ params: ParamsDict.
+
+ Returns:
+ two keras Layers, bert_model_layer and decoder_layer
+ """
+ input_ids = tf.keras.layers.Input(
+ shape=(None,), name="input_ids", dtype=tf.int32)
+ input_mask = tf.keras.layers.Input(
+ shape=(None,), name="input_mask", dtype=tf.int32)
+ segment_ids = tf.keras.layers.Input(
+ shape=(None,), name="segment_ids", dtype=tf.int32)
+ bert_config = utils.get_bert_config_from_params(params)
+ bert_model_layer = networks.TransformerEncoder(
+ vocab_size=bert_config.vocab_size,
+ hidden_size=bert_config.hidden_size,
+ num_layers=bert_config.num_hidden_layers,
+ num_attention_heads=bert_config.num_attention_heads,
+ intermediate_size=bert_config.intermediate_size,
+ activation=tf_utils.get_activation(bert_config.hidden_act),
+ dropout_rate=bert_config.hidden_dropout_prob,
+ attention_dropout_rate=bert_config.attention_probs_dropout_prob,
+ sequence_length=None,
+ max_sequence_length=bert_config.max_position_embeddings,
+ type_vocab_size=bert_config.type_vocab_size,
+ initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=bert_config.initializer_range),
+ return_all_encoder_outputs=True,
+ name="bert_encoder")
+ bert_model_layer([input_ids, input_mask, segment_ids])
+
+ input_ids = tf.keras.layers.Input(
+ shape=(None, None), name="input_ids", dtype=tf.int32)
+ all_encoder_outputs = tf.keras.layers.Input((None, None, params.hidden_size),
+ dtype=tf.float32)
+ target_ids = tf.keras.layers.Input(
+ shape=(None,), name="target_ids", dtype=tf.int32)
+ doc_attention_probs = tf.keras.layers.Input(
+ (params.num_decoder_attn_heads, None, None), dtype=tf.float32)
+ # pylint: disable=protected-access
+ decoder_layer = decoder.Decoder(params, bert_model_layer._embedding_layer)
+ # pylint: enable=protected-access
+ cross_attention_bias = decoder.AttentionBias(bias_type="multi_cross")(
+ input_ids)
+ self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
+ target_ids)
+ decoder_inputs = dict(
+ attention_bias=cross_attention_bias,
+ self_attention_bias=self_attention_bias,
+ target_ids=target_ids,
+ all_encoder_outputs=all_encoder_outputs,
+ doc_attention_probs=doc_attention_probs)
+ _ = decoder_layer(decoder_inputs)
+
+ return bert_model_layer, decoder_layer
+
+
+def create_transformer_model(params,
+ init_checkpoint: Optional[Text] = None
+ ) -> tf.keras.Model:
+ """A helper to create Transformer model."""
+ bert_layer, decoder_layer = get_bert2bert_layers(params=params)
+ model = Bert2Bert(
+ params=params,
+ bert_layer=bert_layer,
+ decoder_layer=decoder_layer,
+ name="transformer")
+
+ if init_checkpoint:
+ logging.info(
+ "Checkpoint file %s found and restoring from "
+ "initial checkpoint.", init_checkpoint)
+ ckpt = tf.train.Checkpoint(model=model)
+ ckpt.restore(init_checkpoint).expect_partial()
+
+ return model
+
+
+def create_bert2bert_model(
+ params: configs.BERT2BERTConfig,
+ cls=Bert2Bert,
+ init_checkpoint: Optional[Text] = None) -> tf.keras.Model:
+ """A helper to create Bert2Bert model."""
+ bert_layer, decoder_layer = get_bert2bert_layers(params=params)
+ if init_checkpoint:
+ utils.initialize_bert2bert_from_pretrained_bert(bert_layer, decoder_layer,
+ init_checkpoint)
+ return cls(
+ params=params,
+ bert_layer=bert_layer,
+ decoder_layer=decoder_layer,
+ name="bert2bert")
+
+
+def create_nhnet_model(
+ params: configs.NHNetConfig,
+ cls=NHNet,
+ init_checkpoint: Optional[Text] = None) -> tf.keras.Model:
+ """A helper to create NHNet model."""
+ bert_layer, decoder_layer = get_nhnet_layers(params=params)
+ model = cls(
+ params=params,
+ bert_layer=bert_layer,
+ decoder_layer=decoder_layer,
+ name="nhnet")
+ if init_checkpoint:
+ logging.info(
+ "Checkpoint file %s found and restoring from "
+ "initial checkpoint.", init_checkpoint)
+ if params.init_from_bert2bert:
+ ckpt = tf.train.Checkpoint(model=model)
+ ckpt.restore(init_checkpoint).assert_existing_objects_matched()
+ else:
+ utils.initialize_bert2bert_from_pretrained_bert(bert_layer, decoder_layer,
+ init_checkpoint)
+ return model
+
+
+@gin.configurable
+def get_model_params(model: Optional[Text] = "bert2bert",
+ config_class=None) -> params_dict.ParamsDict:
+ """Helper function to convert config file to ParamsDict."""
+ if model == "bert2bert":
+ return configs.BERT2BERTConfig()
+ elif model == "nhnet":
+ return configs.NHNetConfig()
+ elif config_class:
+ return config_class()
+ else:
+ raise KeyError("The model type is not defined: %s" % model)
+
+
+@gin.configurable
+def create_model(model_type: Text,
+ params,
+ init_checkpoint: Optional[Text] = None):
+ """A factory function to create different types of models."""
+ if model_type == "bert2bert":
+ return create_bert2bert_model(params, init_checkpoint=init_checkpoint)
+ elif model_type == "nhnet":
+ return create_nhnet_model(params, init_checkpoint=init_checkpoint)
+ elif "transformer" in model_type:
+ return create_transformer_model(
+ params, init_checkpoint=init_checkpoint)
+ else:
+ raise KeyError("The model type is not defined: %s" % model_type)
diff --git a/models/official/nlp/nhnet/models_test.py b/models/official/nlp/nhnet/models_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..39676a347d65e2dc19e99a7dec4d22dfb4c60df4
--- /dev/null
+++ b/models/official/nlp/nhnet/models_test.py
@@ -0,0 +1,324 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for nlp.nhnet.models."""
+
+import os
+
+from absl import logging
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+
+# pylint: disable=g-direct-tensorflow-import
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+# pylint: enable=g-direct-tensorflow-import
+from official.nlp.nhnet import configs
+from official.nlp.nhnet import models
+from official.nlp.nhnet import utils
+
+
+def all_strategy_combinations():
+ return combinations.combine(
+ distribution=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.mirrored_strategy_with_two_gpus,
+ ],
+ mode="eager",
+ )
+
+
+def distribution_forward_path(strategy,
+ model,
+ inputs,
+ batch_size,
+ mode="train"):
+ dataset = tf.data.Dataset.from_tensor_slices((inputs))
+ dataset = dataset.batch(batch_size)
+ dataset = strategy.experimental_distribute_dataset(dataset)
+
+ @tf.function
+ def test_step(inputs):
+ """Calculates evaluation metrics on distributed devices."""
+
+ def _test_step_fn(inputs):
+ """Replicated accuracy calculation."""
+ return model(inputs, mode=mode, training=False)
+
+ outputs = strategy.run(_test_step_fn, args=(inputs,))
+ return tf.nest.map_structure(strategy.experimental_local_results, outputs)
+
+ return [test_step(inputs) for inputs in dataset]
+
+
+def process_decoded_ids(predictions, end_token_id):
+  """Transforms decoded id tensors to lists truncated at end_token_id."""
+ if isinstance(predictions, tf.Tensor):
+ predictions = predictions.numpy()
+ flatten_ids = predictions.reshape((-1, predictions.shape[-1]))
+ results = []
+ for ids in flatten_ids:
+ ids = list(ids)
+ if end_token_id in ids:
+ ids = ids[:ids.index(end_token_id)]
+ results.append(ids)
+ return results
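+  # Worked example (added comment, not in the original source): with
+  # end_token_id=2, predictions [[5, 7, 2, 0]] becomes [[5, 7]]; ids from the
+  # end token onward are dropped so that cached, uncached and padded decodes
+  # can be compared.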
+
+
+class Bert2BertTest(tf.test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ super(Bert2BertTest, self).setUp()
+ self._config = utils.get_test_params()
+
+ def test_model_creation(self):
+ model = models.create_bert2bert_model(params=self._config)
+ fake_ids = np.zeros((2, 10), dtype=np.int32)
+ fake_inputs = {
+ "input_ids": fake_ids,
+ "input_mask": fake_ids,
+ "segment_ids": fake_ids,
+ "target_ids": fake_ids,
+ }
+ model(fake_inputs)
+
+ @combinations.generate(all_strategy_combinations())
+ def test_bert2bert_train_forward(self, distribution):
+ seq_length = 10
+ # Defines the model inside distribution strategy scope.
+ with distribution.scope():
+ # Forward path.
+ batch_size = 2
+ batches = 4
+ fake_ids = np.zeros((batch_size * batches, seq_length), dtype=np.int32)
+ fake_inputs = {
+ "input_ids": fake_ids,
+ "input_mask": fake_ids,
+ "segment_ids": fake_ids,
+ "target_ids": fake_ids,
+ }
+ model = models.create_bert2bert_model(params=self._config)
+ results = distribution_forward_path(distribution, model, fake_inputs,
+ batch_size)
+ logging.info("Forward path results: %s", str(results))
+ self.assertLen(results, batches)
+
+ def test_bert2bert_decoding(self):
+ seq_length = 10
+ self._config.override(
+ {
+ "beam_size": 3,
+ "len_title": seq_length,
+ "alpha": 0.6,
+ },
+ is_strict=False)
+
+ batch_size = 2
+ fake_ids = np.zeros((batch_size, seq_length), dtype=np.int32)
+ fake_inputs = {
+ "input_ids": fake_ids,
+ "input_mask": fake_ids,
+ "segment_ids": fake_ids,
+ }
+ self._config.override({
+ "padded_decode": False,
+ "use_cache": False,
+ },
+ is_strict=False)
+ model = models.create_bert2bert_model(params=self._config)
+ ckpt = tf.train.Checkpoint(model=model)
+
+ # Initializes variables from checkpoint to keep outputs deterministic.
+ init_checkpoint = ckpt.save(os.path.join(self.get_temp_dir(), "ckpt"))
+ ckpt.restore(init_checkpoint).assert_existing_objects_matched()
+ top_ids, scores = model(fake_inputs, mode="predict")
+
+ self._config.override({
+ "padded_decode": False,
+ "use_cache": True,
+ },
+ is_strict=False)
+ model = models.create_bert2bert_model(params=self._config)
+ ckpt = tf.train.Checkpoint(model=model)
+ ckpt.restore(init_checkpoint).assert_existing_objects_matched()
+ cached_top_ids, cached_scores = model(fake_inputs, mode="predict")
+ self.assertEqual(
+ process_decoded_ids(top_ids, self._config.end_token_id),
+ process_decoded_ids(cached_top_ids, self._config.end_token_id))
+ self.assertAllClose(scores, cached_scores)
+
+ self._config.override({
+ "padded_decode": True,
+ "use_cache": True,
+ },
+ is_strict=False)
+ model = models.create_bert2bert_model(params=self._config)
+ ckpt = tf.train.Checkpoint(model=model)
+ ckpt.restore(init_checkpoint).assert_existing_objects_matched()
+ padded_top_ids, padded_scores = model(fake_inputs, mode="predict")
+ self.assertEqual(
+ process_decoded_ids(top_ids, self._config.end_token_id),
+ process_decoded_ids(padded_top_ids, self._config.end_token_id))
+ self.assertAllClose(scores, padded_scores)
+
+ @combinations.generate(all_strategy_combinations())
+ def test_bert2bert_eval(self, distribution):
+ seq_length = 10
+ padded_decode = isinstance(distribution,
+ tf.distribute.experimental.TPUStrategy)
+ self._config.override(
+ {
+ "beam_size": 3,
+ "len_title": seq_length,
+ "alpha": 0.6,
+ "padded_decode": padded_decode,
+ },
+ is_strict=False)
+ # Defines the model inside distribution strategy scope.
+ with distribution.scope():
+ # Forward path.
+ batch_size = 2
+ batches = 4
+ fake_ids = np.zeros((batch_size * batches, seq_length), dtype=np.int32)
+ fake_inputs = {
+ "input_ids": fake_ids,
+ "input_mask": fake_ids,
+ "segment_ids": fake_ids,
+ }
+ model = models.create_bert2bert_model(params=self._config)
+ results = distribution_forward_path(
+ distribution, model, fake_inputs, batch_size, mode="predict")
+ self.assertLen(results, batches)
+ results = distribution_forward_path(
+ distribution, model, fake_inputs, batch_size, mode="eval")
+ self.assertLen(results, batches)
+
+
+class NHNetTest(tf.test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ super(NHNetTest, self).setUp()
+ self._nhnet_config = configs.NHNetConfig()
+ self._nhnet_config.override(utils.get_test_params().as_dict())
+ self._bert2bert_config = configs.BERT2BERTConfig()
+ self._bert2bert_config.override(utils.get_test_params().as_dict())
+
+ def _count_params(self, layer, trainable_only=True):
+ """Returns the count of all model parameters, or just trainable ones."""
+ if not trainable_only:
+ return layer.count_params()
+ else:
+ return int(
+ np.sum([
+ tf.keras.backend.count_params(p) for p in layer.trainable_weights
+ ]))
+
+ def test_create_nhnet_layers(self):
+ single_doc_bert, single_doc_decoder = models.get_bert2bert_layers(
+ self._bert2bert_config)
+ multi_doc_bert, multi_doc_decoder = models.get_nhnet_layers(
+ self._nhnet_config)
+ # Expects multi-doc encoder/decoder have the same number of parameters as
+ # single-doc encoder/decoder.
+ self.assertEqual(
+ self._count_params(multi_doc_bert), self._count_params(single_doc_bert))
+ self.assertEqual(
+ self._count_params(multi_doc_decoder),
+ self._count_params(single_doc_decoder))
+
+ def test_checkpoint_restore(self):
+ bert2bert_model = models.create_bert2bert_model(self._bert2bert_config)
+ ckpt = tf.train.Checkpoint(model=bert2bert_model)
+ init_checkpoint = ckpt.save(os.path.join(self.get_temp_dir(), "ckpt"))
+ nhnet_model = models.create_nhnet_model(
+ params=self._nhnet_config, init_checkpoint=init_checkpoint)
+ source_weights = (
+ bert2bert_model.bert_layer.trainable_weights +
+ bert2bert_model.decoder_layer.trainable_weights)
+ dest_weights = (
+ nhnet_model.bert_layer.trainable_weights +
+ nhnet_model.decoder_layer.trainable_weights)
+ for source_weight, dest_weight in zip(source_weights, dest_weights):
+ self.assertAllClose(source_weight.numpy(), dest_weight.numpy())
+
+ @combinations.generate(all_strategy_combinations())
+ def test_nhnet_train_forward(self, distribution):
+ seq_length = 10
+ # Defines the model inside distribution strategy scope.
+ with distribution.scope():
+ # Forward path.
+ batch_size = 2
+ num_docs = 2
+ batches = 4
+ fake_ids = np.zeros((batch_size * batches, num_docs, seq_length),
+ dtype=np.int32)
+ fake_inputs = {
+ "input_ids":
+ fake_ids,
+ "input_mask":
+ fake_ids,
+ "segment_ids":
+ fake_ids,
+ "target_ids":
+ np.zeros((batch_size * batches, seq_length * 2), dtype=np.int32),
+ }
+ model = models.create_nhnet_model(params=self._nhnet_config)
+ results = distribution_forward_path(distribution, model, fake_inputs,
+ batch_size)
+ logging.info("Forward path results: %s", str(results))
+ self.assertLen(results, batches)
+
+ @combinations.generate(all_strategy_combinations())
+ def test_nhnet_eval(self, distribution):
+ seq_length = 10
+ padded_decode = isinstance(distribution,
+ tf.distribute.experimental.TPUStrategy)
+ self._nhnet_config.override(
+ {
+ "beam_size": 4,
+ "len_title": seq_length,
+ "alpha": 0.6,
+ "multi_channel_cross_attention": True,
+ "padded_decode": padded_decode,
+ },
+ is_strict=False)
+ # Defines the model inside distribution strategy scope.
+ with distribution.scope():
+ # Forward path.
+ batch_size = 2
+ num_docs = 2
+ batches = 4
+ fake_ids = np.zeros((batch_size * batches, num_docs, seq_length),
+ dtype=np.int32)
+ fake_inputs = {
+ "input_ids": fake_ids,
+ "input_mask": fake_ids,
+ "segment_ids": fake_ids,
+ "target_ids": np.zeros((batch_size * batches, 5), dtype=np.int32),
+ }
+ model = models.create_nhnet_model(params=self._nhnet_config)
+ results = distribution_forward_path(
+ distribution, model, fake_inputs, batch_size, mode="predict")
+ self.assertLen(results, batches)
+ results = distribution_forward_path(
+ distribution, model, fake_inputs, batch_size, mode="eval")
+ self.assertLen(results, batches)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/nhnet/optimizer.py b/models/official/nlp/nhnet/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..15c7e248019399f1abc94f64f5acd509db104f38
--- /dev/null
+++ b/models/official/nlp/nhnet/optimizer.py
@@ -0,0 +1,82 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Optimizer and learning rate scheduler."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+
+from official.modeling.hyperparams import params_dict
+
+
+class LearningRateSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
+ """Learning rate schedule."""
+
+ def __init__(self, initial_learning_rate, hidden_size, warmup_steps):
+ """Initialize configuration of the learning rate schedule.
+
+ Args:
+ initial_learning_rate: A float, the initial learning rate.
+ hidden_size: An integer, the model dimension in the hidden layers.
+ warmup_steps: An integer, the number of steps required for linear warmup.
+ """
+ super(LearningRateSchedule, self).__init__()
+ self.initial_learning_rate = initial_learning_rate
+ self.hidden_size = hidden_size
+ self.warmup_steps = tf.cast(warmup_steps, tf.float32)
+
+ def __call__(self, global_step):
+ """Calculate learning rate with linear warmup and rsqrt decay.
+
+ Args:
+ global_step: An integer, the current global step used for learning rate
+ calculation.
+
+ Returns:
+ A float, the learning rate needs to be used for current global step.
+ """
+ with tf.name_scope('learning_rate_schedule'):
+ global_step = tf.cast(global_step, tf.float32)
+ learning_rate = self.initial_learning_rate
+ learning_rate *= (self.hidden_size**-0.5)
+ # Apply linear warmup
+ learning_rate *= tf.minimum(1.0, global_step / self.warmup_steps)
+ # Apply rsqrt decay
+ learning_rate /= tf.sqrt(tf.maximum(global_step, self.warmup_steps))
+ return learning_rate
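+      # Illustrative numbers (added comment, not in the original source): with
+      # initial_learning_rate=2.0, hidden_size=512 and warmup_steps=10000,
+      # step 1000 gives 2.0 * 512**-0.5 * 0.1 / sqrt(10000) ~= 8.8e-5; after
+      # warmup the rate decays as 1/sqrt(step).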
+
+ def get_config(self):
+ """Get the configuration of the learning rate schedule."""
+ return {
+ 'initial_learning_rate': self.initial_learning_rate,
+ 'hidden_size': self.hidden_size,
+ 'warmup_steps': self.warmup_steps,
+ }
+
+
+def create_optimizer(params: params_dict.ParamsDict):
+ """Creates optimizer."""
+ lr_schedule = LearningRateSchedule(
+ params.learning_rate,
+ params.hidden_size,
+ params.learning_rate_warmup_steps)
+ return tf.keras.optimizers.Adam(
+ learning_rate=lr_schedule,
+ beta_1=params.adam_beta1,
+ beta_2=params.adam_beta2,
+ epsilon=params.adam_epsilon)
diff --git a/models/official/nlp/nhnet/raw_data_process.py b/models/official/nlp/nhnet/raw_data_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..9597043237355f2c4b4399c490e105672e406b62
--- /dev/null
+++ b/models/official/nlp/nhnet/raw_data_process.py
@@ -0,0 +1,91 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Processes crawled content from news URLs by generating tfrecords."""
+
+import os
+from absl import app
+from absl import flags
+from official.nlp.nhnet import raw_data_processor
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("crawled_articles", "/tmp/nhnet/",
+ "Folder path to the crawled articles using news-please.")
+flags.DEFINE_string("vocab", None, "Filepath of the BERT vocabulary.")
+flags.DEFINE_bool("do_lower_case", True,
+ "Whether the vocabulary is uncased or not.")
+flags.DEFINE_integer("len_title", 15,
+ "Maximum number of tokens in story headline.")
+flags.DEFINE_integer("len_passage", 200,
+ "Maximum number of tokens in article passage.")
+flags.DEFINE_integer("max_num_articles", 5,
+ "Maximum number of articles in a story.")
+flags.DEFINE_bool("include_article_title_in_passage", False,
+ "Whether to include article title in article passage.")
+flags.DEFINE_string("data_folder", None,
+ "Folder path to the downloaded data folder (output).")
+flags.DEFINE_integer("num_tfrecords_shards", 20,
+ "Number of shards for train/valid/test.")
+
+
+def transform_as_tfrecords(data_processor, filename):
+ """Transforms story from json to tfrecord (sharded).
+
+ Args:
+ data_processor: Instance of RawDataProcessor.
+ filename: 'train', 'valid', or 'test'.
+ """
+ print("Transforming json to tfrecord for %s..." % filename)
+ story_filepath = os.path.join(FLAGS.data_folder, filename + ".json")
+ output_folder = os.path.join(FLAGS.data_folder, "processed")
+ os.makedirs(output_folder, exist_ok=True)
+ output_filepaths = []
+ for i in range(FLAGS.num_tfrecords_shards):
+ output_filepaths.append(
+ os.path.join(
+ output_folder, "%s.tfrecord-%.5d-of-%.5d" %
+ (filename, i, FLAGS.num_tfrecords_shards)))
+ (total_num_examples,
+ generated_num_examples) = data_processor.generate_examples(
+ story_filepath, output_filepaths)
+ print("For %s, %d examples have been generated from %d stories in json." %
+ (filename, generated_num_examples, total_num_examples))
+
+
+def main(_):
+ if not FLAGS.data_folder:
+ raise ValueError("data_folder must be set as the downloaded folder path.")
+ if not FLAGS.vocab:
+ raise ValueError("vocab must be set as the filepath of BERT vocabulary.")
+ data_processor = raw_data_processor.RawDataProcessor(
+ vocab=FLAGS.vocab,
+ do_lower_case=FLAGS.do_lower_case,
+ len_title=FLAGS.len_title,
+ len_passage=FLAGS.len_passage,
+ max_num_articles=FLAGS.max_num_articles,
+ include_article_title_in_passage=FLAGS.include_article_title_in_passage,
+ include_text_snippet_in_example=True)
+ print("Loading crawled articles...")
+ num_articles = data_processor.read_crawled_articles(FLAGS.crawled_articles)
+ print("Total number of articles loaded: %d" % num_articles)
+ print()
+ transform_as_tfrecords(data_processor, "train")
+ transform_as_tfrecords(data_processor, "valid")
+ transform_as_tfrecords(data_processor, "test")
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/models/official/nlp/nhnet/raw_data_processor.py b/models/official/nlp/nhnet/raw_data_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a30532f4f401e6f2b29430d353767c6cdea0966
--- /dev/null
+++ b/models/official/nlp/nhnet/raw_data_processor.py
@@ -0,0 +1,228 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Library for processing crawled content and generating tfrecords."""
+
+import collections
+import json
+import multiprocessing
+import os
+import urllib.parse
+import tensorflow as tf
+
+from official.nlp.bert import tokenization
+from official.nlp.data import classifier_data_lib
+
+
+class RawDataProcessor(object):
+ """Data converter for story examples."""
+
+ def __init__(self,
+ vocab: str,
+ do_lower_case: bool,
+ len_title: int = 15,
+ len_passage: int = 200,
+ max_num_articles: int = 5,
+ include_article_title_in_passage: bool = False,
+ include_text_snippet_in_example: bool = False):
+ """Constructs a RawDataProcessor.
+
+ Args:
+ vocab: Filepath of the BERT vocabulary.
+ do_lower_case: Whether the vocabulary is uncased or not.
+ len_title: Maximum number of tokens in story headline.
+ len_passage: Maximum number of tokens in article passage.
+ max_num_articles: Maximum number of articles in a story.
+ include_article_title_in_passage: Whether to include article title in
+ article passage.
+ include_text_snippet_in_example: Whether to include text snippet
+ (headline and article content) in generated tensorflow Examples, for
+ debug usage. If include_article_title_in_passage=True, title and body
+ will be separated by [SEP].
+ """
+ self.articles = dict()
+ self.tokenizer = tokenization.FullTokenizer(
+ vocab, do_lower_case=do_lower_case, split_on_punc=False)
+ self.len_title = len_title
+ self.len_passage = len_passage
+ self.max_num_articles = max_num_articles
+ self.include_article_title_in_passage = include_article_title_in_passage
+ self.include_text_snippet_in_example = include_text_snippet_in_example
+ # ex_index=5 deactivates printing inside convert_single_example.
+ self.ex_index = 5
+ # Parameters used in InputExample, not used in NHNet.
+ self.label = 0
+ self.guid = 0
+ self.num_generated_examples = 0
+
+ def read_crawled_articles(self, folder_path):
+ """Reads crawled articles under folder_path."""
+ for path, _, files in os.walk(folder_path):
+ for name in files:
+ if not name.endswith(".json"):
+ continue
+ url, article = self._get_article_content_from_json(
+ os.path.join(path, name))
+ if not article.text_a:
+ continue
+ self.articles[RawDataProcessor.normalize_url(url)] = article
+ if len(self.articles) % 5000 == 0:
+ print("Number of articles loaded: %d\r" % len(self.articles), end="")
+ print()
+ return len(self.articles)
+
+ def generate_examples(self, input_file, output_files):
+    """Loads stories from the input json file and writes examples to output_files."""
+ writers = []
+ story_partition = []
+ for output_file in output_files:
+ writers.append(tf.io.TFRecordWriter(output_file))
+ story_partition.append(list())
+ with tf.io.gfile.GFile(input_file, "r") as story_json_file:
+ stories = json.load(story_json_file)
+ writer_index = 0
+ for story in stories:
+ articles = []
+ for url in story["urls"]:
+ normalized_url = RawDataProcessor.normalize_url(url)
+ if normalized_url in self.articles:
+ articles.append(self.articles[normalized_url])
+ if not articles:
+ continue
+ story_partition[writer_index].append((story["label"], articles))
+ writer_index = (writer_index + 1) % len(writers)
+ lock = multiprocessing.Lock()
+ pool = multiprocessing.pool.ThreadPool(len(writers))
+ data = [(story_partition[i], writers[i], lock) for i in range(len(writers))]
+ pool.map(self._write_story_partition, data)
+ return len(stories), self.num_generated_examples
+
+ @classmethod
+ def normalize_url(cls, url):
+ """Normalize url for better matching."""
+ url = urllib.parse.unquote(
+ urllib.parse.urlsplit(url)._replace(query=None).geturl())
+ output, part = [], None
+ for part in url.split("//"):
+ if part == "http:" or part == "https:":
+ continue
+ else:
+ output.append(part)
+ return "//".join(output)
+
+ def _get_article_content_from_json(self, file_path):
+ """Returns (url, InputExample) keeping content extracted from file_path."""
+ with tf.io.gfile.GFile(file_path, "r") as article_json_file:
+ article = json.load(article_json_file)
+ if self.include_article_title_in_passage:
+ return article["url"], classifier_data_lib.InputExample(
+ guid=self.guid,
+ text_a=article["title"],
+ text_b=article["maintext"],
+ label=self.label)
+ else:
+ return article["url"], classifier_data_lib.InputExample(
+ guid=self.guid, text_a=article["maintext"], label=self.label)
+
+ def _write_story_partition(self, data):
+ """Writes stories in a partition into file."""
+ for (story_headline, articles) in data[0]:
+ story_example = tf.train.Example(
+ features=tf.train.Features(
+ feature=self._get_single_story_features(story_headline,
+ articles)))
+ data[1].write(story_example.SerializeToString())
+ data[2].acquire()
+ try:
+ self.num_generated_examples += 1
+ if self.num_generated_examples % 1000 == 0:
+ print(
+ "Number of stories written: %d\r" % self.num_generated_examples,
+ end="")
+ finally:
+ data[2].release()
+
+ def _get_single_story_features(self, story_headline, articles):
+ """Converts a list of articles to a tensorflow Example."""
+ def get_text_snippet(article):
+ if article.text_b:
+ return " [SEP] ".join([article.text_a, article.text_b])
+ else:
+ return article.text_a
+
+ story_features = collections.OrderedDict()
+ story_headline_feature = classifier_data_lib.convert_single_example(
+ ex_index=self.ex_index,
+ example=classifier_data_lib.InputExample(
+ guid=self.guid, text_a=story_headline, label=self.label),
+ label_list=[self.label],
+ max_seq_length=self.len_title,
+ tokenizer=self.tokenizer)
+ if self.include_text_snippet_in_example:
+ story_headline_feature.label_id = story_headline
+ self._add_feature_with_suffix(
+ feature=story_headline_feature,
+ suffix="a",
+ story_features=story_features)
+ for (article_index, article) in enumerate(articles):
+ if article_index == self.max_num_articles:
+ break
+ article_feature = classifier_data_lib.convert_single_example(
+ ex_index=self.ex_index,
+ example=article,
+ label_list=[self.label],
+ max_seq_length=self.len_passage,
+ tokenizer=self.tokenizer)
+ if self.include_text_snippet_in_example:
+ article_feature.label_id = get_text_snippet(article)
+ suffix = chr(ord("b") + article_index)
+ self._add_feature_with_suffix(
+ feature=article_feature, suffix=suffix, story_features=story_features)
+
+ # Adds empty features as placeholder.
+ for article_index in range(len(articles), self.max_num_articles):
+ suffix = chr(ord("b") + article_index)
+ empty_article = classifier_data_lib.InputExample(
+ guid=self.guid, text_a="", label=self.label)
+ empty_feature = classifier_data_lib.convert_single_example(
+ ex_index=self.ex_index,
+ example=empty_article,
+ label_list=[self.label],
+ max_seq_length=self.len_passage,
+ tokenizer=self.tokenizer)
+ if self.include_text_snippet_in_example:
+ empty_feature.label_id = ""
+ self._add_feature_with_suffix(
+ feature=empty_feature, suffix=suffix, story_features=story_features)
+ return story_features
+
+ def _add_feature_with_suffix(self, feature, suffix, story_features):
+ """Appends suffix to feature names and fills in the corresponding values."""
+
+ def _create_int_feature(values):
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
+
+ def _create_string_feature(value):
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+ story_features["input_ids_%c" % suffix] = _create_int_feature(
+ feature.input_ids)
+ story_features["input_mask_%c" % suffix] = _create_int_feature(
+ feature.input_mask)
+ story_features["segment_ids_%c" % suffix] = _create_int_feature(
+ feature.segment_ids)
+ if self.include_text_snippet_in_example:
+ story_features["text_snippet_%c" % suffix] = _create_string_feature(
+ bytes(feature.label_id.encode()))
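+
+
+# --- Usage sketch (illustrative only; not part of the original library) ---
+# The folder and file paths below are assumptions for demonstration; the
+# RawDataProcessor itself is defined above in this file.
+if __name__ == "__main__":
+  processor = RawDataProcessor(
+      vocab="testdata/vocab.txt",
+      do_lower_case=True,
+      len_title=15,
+      len_passage=200,
+      max_num_articles=5)
+  # normalize_url drops the scheme and query string, e.g. this prints
+  # "domain_0.com/url_000.html".
+  print(RawDataProcessor.normalize_url("http://domain_0.com/url_000.html?x=1"))
+  num_articles = processor.read_crawled_articles("testdata/crawled_articles")
+  num_stories, num_examples = processor.generate_examples(
+      "testdata/stories.json", ["/tmp/train.tfrecord"])
+  print("articles=%d stories=%d examples=%d" %
+        (num_articles, num_stories, num_examples))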
diff --git a/models/official/nlp/nhnet/testdata/crawled_articles/domain_0.com/url_000.html b/models/official/nlp/nhnet/testdata/crawled_articles/domain_0.com/url_000.html
new file mode 100644
index 0000000000000000000000000000000000000000..0a8549c1d274dc2ba29862860391e65bca391242
--- /dev/null
+++ b/models/official/nlp/nhnet/testdata/crawled_articles/domain_0.com/url_000.html
@@ -0,0 +1,3 @@
+
+
+Page Title 0
diff --git a/models/official/nlp/nhnet/testdata/crawled_articles/domain_0.com/url_000.json b/models/official/nlp/nhnet/testdata/crawled_articles/domain_0.com/url_000.json
new file mode 100644
index 0000000000000000000000000000000000000000..b7308592b77a0d3b6b3534a3ecbf00b717b62d26
--- /dev/null
+++ b/models/official/nlp/nhnet/testdata/crawled_articles/domain_0.com/url_000.json
@@ -0,0 +1,5 @@
+{
+ "title": "title for 0",
+ "maintext": "text snippet for 0",
+ "url": "http://url_000.html"
+}
diff --git a/models/official/nlp/nhnet/testdata/crawled_articles/domain_1.com/url_001.html b/models/official/nlp/nhnet/testdata/crawled_articles/domain_1.com/url_001.html
new file mode 100644
index 0000000000000000000000000000000000000000..7c8bb8d285c3e9da41ea8ca546d6d1503e3a7e51
--- /dev/null
+++ b/models/official/nlp/nhnet/testdata/crawled_articles/domain_1.com/url_001.html
@@ -0,0 +1,3 @@
+
+
+Page Title 1
diff --git a/models/official/nlp/nhnet/testdata/crawled_articles/domain_1.com/url_001.json b/models/official/nlp/nhnet/testdata/crawled_articles/domain_1.com/url_001.json
new file mode 100644
index 0000000000000000000000000000000000000000..dbc2322c7debc7ae695f763a75843bc1ea0a2f22
--- /dev/null
+++ b/models/official/nlp/nhnet/testdata/crawled_articles/domain_1.com/url_001.json
@@ -0,0 +1,5 @@
+{
+ "title": "title for 1",
+ "maintext": "text snippet for 1",
+ "url": "url_001.html"
+}
diff --git a/models/official/nlp/nhnet/testdata/stories.json b/models/official/nlp/nhnet/testdata/stories.json
new file mode 100644
index 0000000000000000000000000000000000000000..0618f3d5c8afdbd7e164d02dd663507da467e8b2
--- /dev/null
+++ b/models/official/nlp/nhnet/testdata/stories.json
@@ -0,0 +1,29 @@
+[
+ {
+ "urls": [
+ "http://url_000.html",
+ "http://url_001.html"
+ ],
+ "label": "headline 0"
+ },
+ {
+ "urls": [
+ "http://url_000.html",
+ "http://url_001.html"
+ ],
+ "label": "headline 1"
+ },
+ {
+ "urls": [
+ "http://url_002.html",
+ "http://url_001.html"
+ ],
+ "label": "headline 2"
+ },
+ {
+ "urls": [
+ "http://url_003.html"
+ ],
+ "label": "headline 3"
+ }
+]
diff --git a/models/official/nlp/nhnet/testdata/vocab.txt b/models/official/nlp/nhnet/testdata/vocab.txt
new file mode 100644
index 0000000000000000000000000000000000000000..dd708d71c2fec475901fe9b2a23c7e6c2b539d95
--- /dev/null
+++ b/models/official/nlp/nhnet/testdata/vocab.txt
@@ -0,0 +1,23 @@
+[UNK]
+[CLS]
+[SEP]
+[MASK]
+0
+1
+this
+is
+a
+title
+snippet
+for
+url
+main
+text
+http
+www
+html
+:
+//
+.
+_
+headline
diff --git a/models/official/nlp/nhnet/trainer.py b/models/official/nlp/nhnet/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fa26a53a0d002247eb656691e8f49fa42bbe80f
--- /dev/null
+++ b/models/official/nlp/nhnet/trainer.py
@@ -0,0 +1,232 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Run NHNet model training and eval."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import app
+from absl import flags
+from absl import logging
+from six.moves import zip
+import tensorflow as tf
+from official.modeling.hyperparams import params_dict
+from official.nlp.nhnet import evaluation
+from official.nlp.nhnet import input_pipeline
+from official.nlp.nhnet import models
+from official.nlp.nhnet import optimizer
+from official.nlp.transformer import metrics as transformer_metrics
+from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
+
+FLAGS = flags.FLAGS
+
+
+def define_flags():
+ """Defines command line flags used by NHNet trainer."""
+ ## Required parameters
+ flags.DEFINE_enum("mode", "train", ["train", "eval", "train_and_eval"],
+ "Execution mode.")
+ flags.DEFINE_string("train_file_pattern", "", "Train file pattern.")
+ flags.DEFINE_string("eval_file_pattern", "", "Eval file pattern.")
+ flags.DEFINE_string(
+ "model_dir", None,
+ "The output directory where the model checkpoints will be written.")
+
+ # Model training specific flags.
+ flags.DEFINE_enum(
+ "distribution_strategy", "mirrored", ["tpu", "mirrored"],
+ "Distribution Strategy type to use for training. `tpu` uses TPUStrategy "
+ "for running on TPUs, `mirrored` uses GPUs with single host.")
+ flags.DEFINE_string("tpu", "", "TPU address to connect to.")
+ flags.DEFINE_string(
+ "init_checkpoint", None,
+ "Initial checkpoint (usually from a pre-trained BERT model).")
+ flags.DEFINE_integer("train_steps", 100000, "Max train steps")
+ flags.DEFINE_integer("eval_steps", 32, "Number of eval steps per run.")
+ flags.DEFINE_integer("eval_timeout", 3000, "Timeout waiting for checkpoints.")
+ flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
+ flags.DEFINE_integer("eval_batch_size", 4, "Total batch size for evaluation.")
+ flags.DEFINE_integer(
+ "steps_per_loop", 1000,
+ "Number of steps per graph-mode loop. Only training step "
+ "happens inside the loop.")
+ flags.DEFINE_integer("checkpoint_interval", 2000, "Checkpointing interval.")
+ flags.DEFINE_integer("len_title", 15, "Title length.")
+ flags.DEFINE_integer("len_passage", 200, "Passage length.")
+ flags.DEFINE_integer("num_encoder_layers", 12,
+ "Number of hidden layers of encoder.")
+ flags.DEFINE_integer("num_decoder_layers", 12,
+ "Number of hidden layers of decoder.")
+ flags.DEFINE_string("model_type", "nhnet",
+ "Model type to choose a model configuration.")
+ flags.DEFINE_integer(
+ "num_nhnet_articles", 5,
+ "Maximum number of articles in NHNet, only used when model_type=nhnet")
+ flags.DEFINE_string(
+ "params_override",
+ default=None,
+ help=("a YAML/JSON string or a YAML file which specifies additional "
+ "overrides over the default parameters"))
+
+
+# pylint: disable=protected-access
+
+
+class Trainer(tf.keras.Model):
+ """A training only model."""
+
+ def __init__(self, model, params):
+ super(Trainer, self).__init__()
+ self.model = model
+ self.params = params
+ self._num_replicas_in_sync = tf.distribute.get_strategy(
+ ).num_replicas_in_sync
+
+ def call(self, inputs, mode="train"):
+ return self.model(inputs, mode)
+
+ def train_step(self, inputs):
+ """The logic for one training step."""
+ with tf.GradientTape() as tape:
+ logits, _, _ = self(inputs, mode="train", training=True)
+ targets = models.remove_sos_from_seq(inputs["target_ids"],
+ self.params.pad_token_id)
+ loss = transformer_metrics.transformer_loss(logits, targets,
+ self.params.label_smoothing,
+ self.params.vocab_size)
+ # Scales the loss, which results in using the average loss across all
+ # of the replicas for backprop.
+ scaled_loss = loss / self._num_replicas_in_sync
+
+ tvars = self.trainable_variables
+ grads = tape.gradient(scaled_loss, tvars)
+ self.optimizer.apply_gradients(list(zip(grads, tvars)))
+ return {
+ "training_loss": loss,
+ "learning_rate": self.optimizer._decayed_lr(var_dtype=tf.float32)
+ }
+
+
+def train(params, strategy, dataset=None):
+ """Runs training."""
+
+ if not dataset:
+ dataset = input_pipeline.get_input_dataset(
+ FLAGS.train_file_pattern,
+ FLAGS.train_batch_size,
+ params,
+ is_training=True,
+ strategy=strategy)
+
+ with strategy.scope():
+ model = models.create_model(
+ FLAGS.model_type, params, init_checkpoint=FLAGS.init_checkpoint)
+ opt = optimizer.create_optimizer(params)
+ trainer = Trainer(model, params)
+ model.global_step = opt.iterations
+
+ trainer.compile(
+ optimizer=opt,
+ experimental_steps_per_execution=FLAGS.steps_per_loop)
+ summary_dir = os.path.join(FLAGS.model_dir, "summaries")
+ summary_callback = tf.keras.callbacks.TensorBoard(
+ summary_dir, update_freq=max(100, FLAGS.steps_per_loop))
+ checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
+ checkpoint_manager = tf.train.CheckpointManager(
+ checkpoint,
+ directory=FLAGS.model_dir,
+ max_to_keep=10,
+ step_counter=model.global_step,
+ checkpoint_interval=FLAGS.checkpoint_interval)
+ if checkpoint_manager.restore_or_initialize():
+ logging.info("Training restored from the checkpoints in: %s",
+ FLAGS.model_dir)
+ checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)
+
+ # Trains the model.
+ steps_per_epoch = min(FLAGS.train_steps, FLAGS.checkpoint_interval)
+ epochs = FLAGS.train_steps // steps_per_epoch
+ history = trainer.fit(
+ x=dataset,
+ steps_per_epoch=steps_per_epoch,
+ epochs=epochs,
+ callbacks=[summary_callback, checkpoint_callback],
+ verbose=2)
+ train_hist = history.history
+ # Gets final loss from training.
+ stats = dict(training_loss=float(train_hist["training_loss"][-1]))
+ return stats
+
+
+def run():
+ """Runs NHNet using Keras APIs."""
+ strategy = distribution_utils.get_distribution_strategy(
+ distribution_strategy=FLAGS.distribution_strategy, tpu_address=FLAGS.tpu)
+ if strategy:
+ logging.info("***** Number of cores used : %d",
+ strategy.num_replicas_in_sync)
+
+ params = models.get_model_params(FLAGS.model_type)
+ params = params_dict.override_params_dict(
+ params, FLAGS.params_override, is_strict=True)
+ params.override(
+ {
+ "len_title":
+ FLAGS.len_title,
+ "len_passage":
+ FLAGS.len_passage,
+ "num_hidden_layers":
+ FLAGS.num_encoder_layers,
+ "num_decoder_layers":
+ FLAGS.num_decoder_layers,
+ "passage_list":
+ [chr(ord("b") + i) for i in range(FLAGS.num_nhnet_articles)],
+ },
+ is_strict=False)
+ stats = {}
+ if "train" in FLAGS.mode:
+ stats = train(params, strategy)
+ if "eval" in FLAGS.mode:
+ timeout = 0 if FLAGS.mode == "train_and_eval" else FLAGS.eval_timeout
+ # Uses padded decoding for TPU. Always uses cache.
+ padded_decode = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+ params.override({
+ "padded_decode": padded_decode,
+ }, is_strict=False)
+ stats = evaluation.continuous_eval(
+ strategy,
+ params,
+ model_type=FLAGS.model_type,
+ eval_file_pattern=FLAGS.eval_file_pattern,
+ batch_size=FLAGS.eval_batch_size,
+ eval_steps=FLAGS.eval_steps,
+ model_dir=FLAGS.model_dir,
+ timeout=timeout)
+ return stats
+
+
+def main(_):
+ stats = run()
+ if stats:
+ logging.info("Stats:\n%s", stats)
+
+if __name__ == "__main__":
+ define_flags()
+ app.run(main)
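+
+# Example invocation (sketch): the paths and values below are illustrative
+# assumptions; the flags themselves are defined in define_flags() above.
+#
+#   python -m official.nlp.nhnet.trainer \
+#     --mode=train_and_eval \
+#     --model_type=nhnet \
+#     --train_file_pattern=/path/to/train.tfrecord* \
+#     --eval_file_pattern=/path/to/eval.tfrecord* \
+#     --model_dir=/tmp/nhnet_model \
+#     --init_checkpoint=/path/to/pretrained_bert/bert_model.ckpt \
+#     --train_batch_size=32 --train_steps=10000 --checkpoint_interval=2000 \
+#     --distribution_strategy=mirrored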
diff --git a/models/official/nlp/nhnet/trainer_test.py b/models/official/nlp/nhnet/trainer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..39673dd2c7afe0f7310e556395824b9ba4582262
--- /dev/null
+++ b/models/official/nlp/nhnet/trainer_test.py
@@ -0,0 +1,104 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for official.nlp.nhnet.trainer."""
+
+import os
+
+from absl import flags
+from absl.testing import parameterized
+import tensorflow as tf
+
+# pylint: disable=g-direct-tensorflow-import
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+# pylint: enable=g-direct-tensorflow-import
+from official.nlp.nhnet import trainer
+from official.nlp.nhnet import utils
+
+FLAGS = flags.FLAGS
+trainer.define_flags()
+
+
+def all_strategy_combinations():
+ return combinations.combine(
+ distribution=[
+ strategy_combinations.one_device_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.tpu_strategy,
+ ],
+ mode="eager",
+ )
+
+
+def get_trivial_data(config) -> tf.data.Dataset:
+ """Gets trivial data in the ImageNet size."""
+ batch_size, num_docs = 2, len(config.passage_list),
+ len_passage = config.len_passage
+ len_title = config.len_title
+
+  def generate_data(_):
+ fake_ids = tf.zeros((num_docs, len_passage), dtype=tf.int32)
+ title = tf.zeros((len_title), dtype=tf.int32)
+ return dict(
+ input_ids=fake_ids,
+ input_mask=fake_ids,
+ segment_ids=fake_ids,
+ target_ids=title)
+
+ dataset = tf.data.Dataset.range(1)
+ dataset = dataset.repeat()
+ dataset = dataset.map(
+ generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ dataset = dataset.prefetch(buffer_size=1).batch(batch_size)
+ return dataset
+
+
+class TrainerTest(tf.test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ super(TrainerTest, self).setUp()
+ self._config = utils.get_test_params()
+ self._config.override(
+ {
+ "vocab_size": 49911,
+ "max_position_embeddings": 200,
+ "len_title": 15,
+ "len_passage": 20,
+ "beam_size": 5,
+ "alpha": 0.6,
+ "learning_rate": 0.0,
+ "learning_rate_warmup_steps": 0,
+ "multi_channel_cross_attention": True,
+ "passage_list": ["a", "b"],
+ },
+ is_strict=False)
+
+ @combinations.generate(all_strategy_combinations())
+ def test_train(self, distribution):
+ FLAGS.train_steps = 10
+ FLAGS.checkpoint_interval = 5
+ FLAGS.model_dir = self.get_temp_dir()
+ FLAGS.model_type = "nhnet"
+ stats = trainer.train(self._config, distribution,
+ get_trivial_data(self._config))
+ self.assertIn("training_loss", stats)
+ self.assertLen(
+ tf.io.gfile.glob(os.path.join(FLAGS.model_dir, "ckpt*.index")), 2)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/nhnet/utils.py b/models/official/nlp/nhnet/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f588798b7feee95a33b3b003f77570fe48340fe7
--- /dev/null
+++ b/models/official/nlp/nhnet/utils.py
@@ -0,0 +1,90 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility helpers for Bert2Bert."""
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+from absl import logging
+import tensorflow as tf
+from typing import Optional, Text
+from official.modeling.hyperparams import params_dict
+from official.nlp.bert import configs
+from official.nlp.nhnet import configs as nhnet_configs
+
+
+def get_bert_config_from_params(
+ params: params_dict.ParamsDict) -> configs.BertConfig:
+ """Converts a BertConfig to ParamsDict."""
+ return configs.BertConfig.from_dict(params.as_dict())
+
+
+def get_test_params(cls=nhnet_configs.BERT2BERTConfig):
+ return cls.from_args(**nhnet_configs.UNITTEST_CONFIG)
+
+
+# pylint: disable=protected-access
+def encoder_common_layers(transformer_block):
+ return [
+ transformer_block._attention_layer,
+ transformer_block._attention_layer_norm,
+ transformer_block._intermediate_dense, transformer_block._output_dense,
+ transformer_block._output_layer_norm
+ ]
+# pylint: enable=protected-access
+
+
+def initialize_bert2bert_from_pretrained_bert(
+ bert_encoder: tf.keras.layers.Layer,
+ bert_decoder: tf.keras.layers.Layer,
+ init_checkpoint: Optional[Text] = None) -> None:
+ """Helper function to initialze Bert2Bert from Bert pretrained checkpoint."""
+ ckpt = tf.train.Checkpoint(model=bert_encoder)
+ logging.info(
+ "Checkpoint file %s found and restoring from "
+ "initial checkpoint for core model.", init_checkpoint)
+ status = ckpt.restore(init_checkpoint)
+
+  # Expects the BERT model to be a subset of the checkpoint, as the pooling
+  # layer is not used.
+ status.assert_existing_objects_matched()
+ logging.info("Loading from checkpoint file completed.")
+
+  # Collects the transformer layers from the encoder.
+ encoder_layers = []
+ for transformer_block in bert_encoder.transformer_layers:
+ encoder_layers.extend(encoder_common_layers(transformer_block))
+
+  # Collects the decoder layers that share weights with the encoder.
+ decoder_layers_to_initialize = []
+ for decoder_block in bert_decoder.decoder.layers:
+ decoder_layers_to_initialize.extend(
+ decoder_block.common_layers_with_encoder())
+
+ if len(decoder_layers_to_initialize) != len(encoder_layers):
+ raise ValueError(
+ "Source encoder layers with %d objects does not match destination "
+ "decoder layers with %d objects." %
+ (len(decoder_layers_to_initialize), len(encoder_layers)))
+
+ for dest_layer, source_layer in zip(decoder_layers_to_initialize,
+ encoder_layers):
+ try:
+ dest_layer.set_weights(source_layer.get_weights())
+ except ValueError as e:
+ logging.error(
+ "dest_layer: %s failed to set weights from "
+ "source_layer: %s as %s", dest_layer.name, source_layer.name, str(e))
diff --git a/models/official/nlp/optimization.py b/models/official/nlp/optimization.py
new file mode 100644
index 0000000000000000000000000000000000000000..51289a535b239d5831dd76bae57d6306f604e746
--- /dev/null
+++ b/models/official/nlp/optimization.py
@@ -0,0 +1,228 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions and classes related to optimization (weight updates)."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+from absl import logging
+import gin
+import tensorflow as tf
+import tensorflow_addons.optimizers as tfa_optimizers
+
+
+class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
+ """Applies a warmup schedule on a given learning rate decay schedule."""
+
+ def __init__(self,
+ initial_learning_rate,
+ decay_schedule_fn,
+ warmup_steps,
+ power=1.0,
+ name=None):
+ super(WarmUp, self).__init__()
+ self.initial_learning_rate = initial_learning_rate
+ self.warmup_steps = warmup_steps
+ self.power = power
+ self.decay_schedule_fn = decay_schedule_fn
+ self.name = name
+
+ def __call__(self, step):
+ with tf.name_scope(self.name or 'WarmUp') as name:
+ # Implements polynomial warmup. i.e., if global_step < warmup_steps, the
+ # learning rate will be `global_step/num_warmup_steps * init_lr`.
+ global_step_float = tf.cast(step, tf.float32)
+ warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
+ warmup_percent_done = global_step_float / warmup_steps_float
+ warmup_learning_rate = (
+ self.initial_learning_rate *
+ tf.math.pow(warmup_percent_done, self.power))
+ return tf.cond(
+ global_step_float < warmup_steps_float,
+ lambda: warmup_learning_rate,
+ lambda: self.decay_schedule_fn(step),
+ name=name)
+
+ def get_config(self):
+ return {
+ 'initial_learning_rate': self.initial_learning_rate,
+ 'decay_schedule_fn': self.decay_schedule_fn,
+ 'warmup_steps': self.warmup_steps,
+ 'power': self.power,
+ 'name': self.name
+ }
+
+
+@gin.configurable
+def create_optimizer(init_lr,
+ num_train_steps,
+ num_warmup_steps,
+ end_lr=0.0,
+ optimizer_type='adamw'):
+ """Creates an optimizer with learning rate schedule."""
+ # Implements linear decay of the learning rate.
+ lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
+ initial_learning_rate=init_lr,
+ decay_steps=num_train_steps,
+ end_learning_rate=end_lr)
+ if num_warmup_steps:
+ lr_schedule = WarmUp(
+ initial_learning_rate=init_lr,
+ decay_schedule_fn=lr_schedule,
+ warmup_steps=num_warmup_steps)
+
+ if optimizer_type == 'adamw':
+ logging.info('using Adamw optimizer')
+ optimizer = AdamWeightDecay(
+ learning_rate=lr_schedule,
+ weight_decay_rate=0.01,
+ beta_1=0.9,
+ beta_2=0.999,
+ epsilon=1e-6,
+ exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
+ elif optimizer_type == 'lamb':
+ logging.info('using Lamb optimizer')
+ optimizer = tfa_optimizers.LAMB(
+ learning_rate=lr_schedule,
+ weight_decay_rate=0.01,
+ beta_1=0.9,
+ beta_2=0.999,
+ epsilon=1e-6,
+ exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
+ else:
+    raise ValueError('Unsupported optimizer type: %s' % optimizer_type)
+
+ return optimizer
+
+
+class AdamWeightDecay(tf.keras.optimizers.Adam):
+ """Adam enables L2 weight decay and clip_by_global_norm on gradients.
+
+ Just adding the square of the weights to the loss function is *not* the
+ correct way of using L2 regularization/weight decay with Adam, since that will
+ interact with the m and v parameters in strange ways.
+
+  Instead we want to decay the weights in a manner that doesn't interact with
+ the m/v parameters. This is equivalent to adding the square of the weights to
+ the loss with plain (non-momentum) SGD.
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ beta_1=0.9,
+ beta_2=0.999,
+ epsilon=1e-7,
+ amsgrad=False,
+ weight_decay_rate=0.0,
+ include_in_weight_decay=None,
+ exclude_from_weight_decay=None,
+ name='AdamWeightDecay',
+ **kwargs):
+ super(AdamWeightDecay, self).__init__(learning_rate, beta_1, beta_2,
+ epsilon, amsgrad, name, **kwargs)
+ self.weight_decay_rate = weight_decay_rate
+ self._include_in_weight_decay = include_in_weight_decay
+ self._exclude_from_weight_decay = exclude_from_weight_decay
+
+ @classmethod
+ def from_config(cls, config):
+ """Creates an optimizer from its config with WarmUp custom object."""
+ custom_objects = {'WarmUp': WarmUp}
+ return super(AdamWeightDecay, cls).from_config(
+ config, custom_objects=custom_objects)
+
+ def _prepare_local(self, var_device, var_dtype, apply_state):
+ super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype,
+ apply_state)
+ apply_state[(var_device, var_dtype)]['weight_decay_rate'] = tf.constant(
+ self.weight_decay_rate, name='adam_weight_decay_rate')
+
+ def _decay_weights_op(self, var, learning_rate, apply_state):
+ do_decay = self._do_use_weight_decay(var.name)
+ if do_decay:
+ return var.assign_sub(
+ learning_rate * var *
+ apply_state[(var.device, var.dtype.base_dtype)]['weight_decay_rate'],
+ use_locking=self._use_locking)
+ return tf.no_op()
+
+ def apply_gradients(self,
+ grads_and_vars,
+ name=None,
+ experimental_aggregate_gradients=True):
+ grads, tvars = list(zip(*grads_and_vars))
+ if experimental_aggregate_gradients:
+      # When experimental_aggregate_gradients = False, apply_gradients() no
+      # longer implicitly allreduces gradients; users must allreduce gradients
+      # themselves and pass in the allreduced grads_and_vars. For now,
+      # clipping by global norm should happen before the explicit allreduce to
+      # keep the math the same as the TF 1 and pre-TF 2.2 implementations.
+ (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
+ return super(AdamWeightDecay, self).apply_gradients(
+ zip(grads, tvars),
+ name=name,
+ experimental_aggregate_gradients=experimental_aggregate_gradients)
+
+ def _get_lr(self, var_device, var_dtype, apply_state):
+ """Retrieves the learning rate with the given state."""
+ if apply_state is None:
+ return self._decayed_lr_t[var_dtype], {}
+
+ apply_state = apply_state or {}
+ coefficients = apply_state.get((var_device, var_dtype))
+ if coefficients is None:
+ coefficients = self._fallback_apply_state(var_device, var_dtype)
+ apply_state[(var_device, var_dtype)] = coefficients
+
+ return coefficients['lr_t'], dict(apply_state=apply_state)
+
+ def _resource_apply_dense(self, grad, var, apply_state=None):
+ lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
+ decay = self._decay_weights_op(var, lr_t, apply_state)
+ with tf.control_dependencies([decay]):
+ return super(AdamWeightDecay,
+ self)._resource_apply_dense(grad, var, **kwargs)
+
+ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
+ lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
+ decay = self._decay_weights_op(var, lr_t, apply_state)
+ with tf.control_dependencies([decay]):
+ return super(AdamWeightDecay,
+ self)._resource_apply_sparse(grad, var, indices, **kwargs)
+
+ def get_config(self):
+ config = super(AdamWeightDecay, self).get_config()
+ config.update({
+ 'weight_decay_rate': self.weight_decay_rate,
+ })
+ return config
+
+ def _do_use_weight_decay(self, param_name):
+ """Whether to use L2 weight decay for `param_name`."""
+ if self.weight_decay_rate == 0:
+ return False
+
+ if self._include_in_weight_decay:
+ for r in self._include_in_weight_decay:
+ if re.search(r, param_name) is not None:
+ return True
+
+ if self._exclude_from_weight_decay:
+ for r in self._exclude_from_weight_decay:
+ if re.search(r, param_name) is not None:
+ return False
+ return True
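+
+
+# --- Usage sketch (illustrative only; not part of the original file) ---
+# With init_lr=1e-4 and num_warmup_steps=100, warmup is linear (power=1.0):
+# at step 50 the learning rate is 50/100 * 1e-4 = 5e-5, after which the
+# polynomial decay schedule takes over until num_train_steps.
+if __name__ == '__main__':
+  example_opt = create_optimizer(
+      init_lr=1e-4, num_train_steps=1000, num_warmup_steps=100,
+      optimizer_type='adamw')
+  print(type(example_opt).__name__)  # AdamWeightDecay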
diff --git a/models/official/nlp/tasks/__init__.py b/models/official/nlp/tasks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/nlp/tasks/masked_lm.py b/models/official/nlp/tasks/masked_lm.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d392ad1117f54f5539c76988495f4e5999eb4ba
--- /dev/null
+++ b/models/official/nlp/tasks/masked_lm.py
@@ -0,0 +1,171 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Masked language task."""
+import dataclasses
+import tensorflow as tf
+
+from official.core import base_task
+from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp.configs import bert
+from official.nlp.data import pretrain_dataloader
+from official.nlp.modeling import losses as loss_lib
+
+
+@dataclasses.dataclass
+class MaskedLMConfig(cfg.TaskConfig):
+ """The model config."""
+ network: bert.BertPretrainerConfig = bert.BertPretrainerConfig(cls_heads=[
+ bert.ClsHeadConfig(
+ inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
+ ])
+ train_data: cfg.DataConfig = cfg.DataConfig()
+ validation_data: cfg.DataConfig = cfg.DataConfig()
+
+
+@base_task.register_task_cls(MaskedLMConfig)
+class MaskedLMTask(base_task.Task):
+ """Mock task object for testing."""
+
+ def build_model(self):
+ return bert.instantiate_bertpretrainer_from_cfg(self.task_config.network)
+
+ def build_losses(self,
+ labels,
+ model_outputs,
+ metrics,
+ aux_losses=None) -> tf.Tensor:
+ metrics = dict([(metric.name, metric) for metric in metrics])
+ lm_output = tf.nn.log_softmax(
+ tf.cast(model_outputs['lm_output'], tf.float32), axis=-1)
+ mlm_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
+ labels=labels['masked_lm_ids'],
+ predictions=lm_output,
+ weights=labels['masked_lm_weights'])
+ metrics['lm_example_loss'].update_state(mlm_loss)
+ if 'next_sentence_labels' in labels:
+ sentence_labels = labels['next_sentence_labels']
+ sentence_outputs = tf.cast(
+ model_outputs['next_sentence'], dtype=tf.float32)
+ sentence_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
+ labels=sentence_labels,
+ predictions=tf.nn.log_softmax(sentence_outputs, axis=-1))
+ metrics['next_sentence_loss'].update_state(sentence_loss)
+ total_loss = mlm_loss + sentence_loss
+ else:
+ total_loss = mlm_loss
+
+ if aux_losses:
+ total_loss += tf.add_n(aux_losses)
+ return total_loss
+
+ def build_inputs(self, params, input_context=None):
+ """Returns tf.data.Dataset for pretraining."""
+ if params.input_path == 'dummy':
+ def dummy_data(_):
+ dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
+ dummy_lm = tf.zeros((1, params.max_predictions_per_seq), dtype=tf.int32)
+ return dict(
+ input_word_ids=dummy_ids,
+ input_mask=dummy_ids,
+ input_type_ids=dummy_ids,
+ masked_lm_positions=dummy_lm,
+ masked_lm_ids=dummy_lm,
+ masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32),
+ next_sentence_labels=tf.zeros((1, 1), dtype=tf.int32))
+
+ dataset = tf.data.Dataset.range(1)
+ dataset = dataset.repeat()
+ dataset = dataset.map(
+ dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+ return pretrain_dataloader.BertPretrainDataLoader(params).load(
+ input_context)
+
+ def build_metrics(self, training=None):
+ del training
+ metrics = [
+ tf.keras.metrics.SparseCategoricalAccuracy(name='masked_lm_accuracy'),
+ tf.keras.metrics.Mean(name='lm_example_loss')
+ ]
+ # TODO(hongkuny): rethink how to manage metrics creation with heads.
+ if self.task_config.train_data.use_next_sentence_label:
+ metrics.append(
+ tf.keras.metrics.SparseCategoricalAccuracy(
+ name='next_sentence_accuracy'))
+ metrics.append(tf.keras.metrics.Mean(name='next_sentence_loss'))
+ return metrics
+
+ def process_metrics(self, metrics, labels, model_outputs):
+ metrics = dict([(metric.name, metric) for metric in metrics])
+ if 'masked_lm_accuracy' in metrics:
+ metrics['masked_lm_accuracy'].update_state(labels['masked_lm_ids'],
+ model_outputs['lm_output'],
+ labels['masked_lm_weights'])
+ if 'next_sentence_accuracy' in metrics:
+ metrics['next_sentence_accuracy'].update_state(
+ labels['next_sentence_labels'], model_outputs['next_sentence'])
+
+ def train_step(self, inputs, model: tf.keras.Model,
+ optimizer: tf.keras.optimizers.Optimizer, metrics):
+ """Does forward and backward.
+
+ Args:
+ inputs: a dictionary of input tensors.
+ model: the model, forward pass definition.
+ optimizer: the optimizer for this training step.
+ metrics: a nested structure of metrics objects.
+
+ Returns:
+ A dictionary of logs.
+ """
+ with tf.GradientTape() as tape:
+ outputs = model(inputs, training=True)
+ # Computes per-replica loss.
+ loss = self.build_losses(
+ labels=inputs,
+ model_outputs=outputs,
+ metrics=metrics,
+ aux_losses=model.losses)
+ # Scales loss as the default gradients allreduce performs sum inside the
+ # optimizer.
+ # TODO(b/154564893): enable loss scaling.
+ # scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
+ tvars = model.trainable_variables
+ grads = tape.gradient(loss, tvars)
+ optimizer.apply_gradients(list(zip(grads, tvars)))
+ self.process_metrics(metrics, inputs, outputs)
+ return {self.loss: loss}
+
+ def validation_step(self, inputs, model: tf.keras.Model, metrics):
+ """Validatation step.
+
+ Args:
+ inputs: a dictionary of input tensors.
+ model: the keras.Model.
+ metrics: a nested structure of metrics objects.
+
+ Returns:
+ A dictionary of logs.
+ """
+ outputs = self.inference_step(inputs, model)
+ loss = self.build_losses(
+ labels=inputs,
+ model_outputs=outputs,
+ metrics=metrics,
+ aux_losses=model.losses)
+ self.process_metrics(metrics, inputs, outputs)
+ return {self.loss: loss}
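+
+
+# --- Loss sketch (illustrative only; not part of the original file) ---
+# The total pretraining loss above is the weighted masked-LM loss plus, when
+# next-sentence labels are present, the next-sentence loss. Toy masked-LM
+# numbers (batch of 1, two masked positions, vocabulary of 3; all values are
+# made up):
+if __name__ == '__main__':
+  toy_log_probs = tf.nn.log_softmax(
+      tf.constant([[[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]]]), axis=-1)
+  toy_mlm_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
+      labels=tf.constant([[0, 1]]),
+      predictions=toy_log_probs,
+      weights=tf.constant([[1.0, 1.0]]))
+  print(float(toy_mlm_loss))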
diff --git a/models/official/nlp/tasks/masked_lm_test.py b/models/official/nlp/tasks/masked_lm_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..0124165ed097d80d31d83ad82c5fac256dfddc5d
--- /dev/null
+++ b/models/official/nlp/tasks/masked_lm_test.py
@@ -0,0 +1,53 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for official.nlp.tasks.masked_lm."""
+
+import tensorflow as tf
+
+from official.nlp.configs import bert
+from official.nlp.configs import encoders
+from official.nlp.tasks import masked_lm
+
+
+class MLMTaskTest(tf.test.TestCase):
+
+ def test_task(self):
+ config = masked_lm.MaskedLMConfig(
+ network=bert.BertPretrainerConfig(
+ encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
+ num_masked_tokens=20,
+ cls_heads=[
+ bert.ClsHeadConfig(
+ inner_dim=10, num_classes=2, name="next_sentence")
+ ]),
+ train_data=bert.BertPretrainDataConfig(
+ input_path="dummy",
+ max_predictions_per_seq=20,
+ seq_length=128,
+ global_batch_size=1))
+ task = masked_lm.MaskedLMTask(config)
+ model = task.build_model()
+ metrics = task.build_metrics()
+ dataset = task.build_inputs(config.train_data)
+
+ iterator = iter(dataset)
+ optimizer = tf.keras.optimizers.SGD(lr=0.1)
+ task.train_step(next(iterator), model, optimizer, metrics=metrics)
+ task.validation_step(next(iterator), model, metrics=metrics)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/tasks/question_answering.py b/models/official/nlp/tasks/question_answering.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3cb7f4e6d78c0f5b9d758ac76768100ad703f9
--- /dev/null
+++ b/models/official/nlp/tasks/question_answering.py
@@ -0,0 +1,156 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Question answering task."""
+import logging
+import dataclasses
+import tensorflow as tf
+import tensorflow_hub as hub
+
+from official.core import base_task
+from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp.bert import input_pipeline
+from official.nlp.configs import encoders
+from official.nlp.modeling import models
+from official.nlp.tasks import utils
+
+
+@dataclasses.dataclass
+class QuestionAnsweringConfig(cfg.TaskConfig):
+ """The model config."""
+ # At most one of `init_checkpoint` and `hub_module_url` can be specified.
+ init_checkpoint: str = ''
+ hub_module_url: str = ''
+ network: encoders.TransformerEncoderConfig = (
+ encoders.TransformerEncoderConfig())
+ train_data: cfg.DataConfig = cfg.DataConfig()
+ validation_data: cfg.DataConfig = cfg.DataConfig()
+
+
+@base_task.register_task_cls(QuestionAnsweringConfig)
+class QuestionAnsweringTask(base_task.Task):
+ """Task object for question answering.
+
+ TODO(lehou): Add post-processing.
+ """
+
+ def __init__(self, params=cfg.TaskConfig):
+ super(QuestionAnsweringTask, self).__init__(params)
+ if params.hub_module_url and params.init_checkpoint:
+ raise ValueError('At most one of `hub_module_url` and '
+ '`init_checkpoint` can be specified.')
+ if params.hub_module_url:
+ self._hub_module = hub.load(params.hub_module_url)
+ else:
+ self._hub_module = None
+
+ def build_model(self):
+ if self._hub_module:
+ encoder_network = utils.get_encoder_from_hub(self._hub_module)
+ else:
+ encoder_network = encoders.instantiate_encoder_from_cfg(
+ self.task_config.network)
+
+ return models.BertSpanLabeler(
+ network=encoder_network,
+ initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=self.task_config.network.initializer_range))
+
+ def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
+ start_positions = labels['start_positions']
+ end_positions = labels['end_positions']
+ start_logits, end_logits = model_outputs
+
+ start_loss = tf.keras.losses.sparse_categorical_crossentropy(
+ start_positions,
+ tf.cast(start_logits, dtype=tf.float32),
+ from_logits=True)
+ end_loss = tf.keras.losses.sparse_categorical_crossentropy(
+ end_positions,
+ tf.cast(end_logits, dtype=tf.float32),
+ from_logits=True)
+
+ loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
+ return loss
+
+ def build_inputs(self, params, input_context=None):
+ """Returns tf.data.Dataset for sentence_prediction task."""
+ if params.input_path == 'dummy':
+ def dummy_data(_):
+ dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
+ x = dict(
+ input_word_ids=dummy_ids,
+ input_mask=dummy_ids,
+ input_type_ids=dummy_ids)
+ y = dict(
+ start_positions=tf.constant(0, dtype=tf.int32),
+ end_positions=tf.constant(1, dtype=tf.int32))
+ return (x, y)
+
+ dataset = tf.data.Dataset.range(1)
+ dataset = dataset.repeat()
+ dataset = dataset.map(
+ dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+ batch_size = input_context.get_per_replica_batch_size(
+ params.global_batch_size) if input_context else params.global_batch_size
+ # TODO(chendouble): add and use nlp.data.question_answering_dataloader.
+ dataset = input_pipeline.create_squad_dataset(
+ params.input_path,
+ params.seq_length,
+ batch_size,
+ is_training=params.is_training,
+ input_pipeline_context=input_context)
+ return dataset
+
+ def build_metrics(self, training=None):
+ del training
+ # TODO(lehou): a list of metrics doesn't work the same as in compile/fit.
+ metrics = [
+ tf.keras.metrics.SparseCategoricalAccuracy(
+ name='start_position_accuracy'),
+ tf.keras.metrics.SparseCategoricalAccuracy(
+ name='end_position_accuracy'),
+ ]
+ return metrics
+
+ def process_metrics(self, metrics, labels, model_outputs):
+ metrics = dict([(metric.name, metric) for metric in metrics])
+ start_logits, end_logits = model_outputs
+ metrics['start_position_accuracy'].update_state(
+ labels['start_positions'], start_logits)
+ metrics['end_position_accuracy'].update_state(
+ labels['end_positions'], end_logits)
+
+ def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
+ start_logits, end_logits = model_outputs
+ compiled_metrics.update_state(
+ y_true=labels, # labels has keys 'start_positions' and 'end_positions'.
+ y_pred={'start_positions': start_logits, 'end_positions': end_logits})
+
+ def initialize(self, model):
+ """Load a pretrained checkpoint (if exists) and then train from iter 0."""
+ ckpt_dir_or_file = self.task_config.init_checkpoint
+ if tf.io.gfile.isdir(ckpt_dir_or_file):
+ ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
+ if not ckpt_dir_or_file:
+ return
+
+ ckpt = tf.train.Checkpoint(**model.checkpoint_items)
+ status = ckpt.restore(ckpt_dir_or_file)
+ status.expect_partial().assert_existing_objects_matched()
+ logging.info('finished loading pretrained checkpoint from %s',
+ ckpt_dir_or_file)
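+
+
+# --- Loss sketch (illustrative only; not part of the original file) ---
+# build_losses above averages a sparse crossentropy over start positions and
+# one over end positions. Toy numbers (batch of 1, seq_length of 4; values
+# are made up):
+if __name__ == '__main__':
+  toy_start_logits = tf.constant([[2.0, 0.1, 0.1, 0.1]])
+  toy_end_logits = tf.constant([[0.1, 0.1, 2.0, 0.1]])
+  toy_start_loss = tf.keras.losses.sparse_categorical_crossentropy(
+      tf.constant([0]), toy_start_logits, from_logits=True)
+  toy_end_loss = tf.keras.losses.sparse_categorical_crossentropy(
+      tf.constant([2]), toy_end_logits, from_logits=True)
+  print(float(
+      (tf.reduce_mean(toy_start_loss) + tf.reduce_mean(toy_end_loss)) / 2))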
diff --git a/models/official/nlp/tasks/question_answering_test.py b/models/official/nlp/tasks/question_answering_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e0f3f10bcc28e324b73528465cdcebda5633b56
--- /dev/null
+++ b/models/official/nlp/tasks/question_answering_test.py
@@ -0,0 +1,130 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for official.nlp.tasks.question_answering."""
+import functools
+import os
+import tensorflow as tf
+
+from official.nlp.bert import configs
+from official.nlp.bert import export_tfhub
+from official.nlp.configs import bert
+from official.nlp.configs import encoders
+from official.nlp.tasks import question_answering
+
+
+class QuestionAnsweringTaskTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(QuestionAnsweringTaskTest, self).setUp()
+ self._encoder_config = encoders.TransformerEncoderConfig(
+ vocab_size=30522, num_layers=1)
+ self._train_data_config = bert.QADataConfig(
+ input_path="dummy", seq_length=128, global_batch_size=1)
+
+ def _run_task(self, config):
+ task = question_answering.QuestionAnsweringTask(config)
+ model = task.build_model()
+ metrics = task.build_metrics()
+
+ strategy = tf.distribute.get_strategy()
+ dataset = strategy.experimental_distribute_datasets_from_function(
+ functools.partial(task.build_inputs, config.train_data))
+
+ iterator = iter(dataset)
+ optimizer = tf.keras.optimizers.SGD(lr=0.1)
+ task.train_step(next(iterator), model, optimizer, metrics=metrics)
+ task.validation_step(next(iterator), model, metrics=metrics)
+
+ def test_task(self):
+ # Saves a checkpoint.
+ pretrain_cfg = bert.BertPretrainerConfig(
+ encoder=self._encoder_config,
+ num_masked_tokens=20,
+ cls_heads=[
+ bert.ClsHeadConfig(
+ inner_dim=10, num_classes=3, name="next_sentence")
+ ])
+ pretrain_model = bert.instantiate_bertpretrainer_from_cfg(pretrain_cfg)
+ ckpt = tf.train.Checkpoint(
+ model=pretrain_model, **pretrain_model.checkpoint_items)
+ saved_path = ckpt.save(self.get_temp_dir())
+
+ config = question_answering.QuestionAnsweringConfig(
+ init_checkpoint=saved_path,
+ network=self._encoder_config,
+ train_data=self._train_data_config)
+ task = question_answering.QuestionAnsweringTask(config)
+ model = task.build_model()
+ metrics = task.build_metrics()
+ dataset = task.build_inputs(config.train_data)
+
+ iterator = iter(dataset)
+ optimizer = tf.keras.optimizers.SGD(lr=0.1)
+ task.train_step(next(iterator), model, optimizer, metrics=metrics)
+ task.validation_step(next(iterator), model, metrics=metrics)
+ task.initialize(model)
+
+ def test_task_with_fit(self):
+ config = question_answering.QuestionAnsweringConfig(
+ network=self._encoder_config,
+ train_data=self._train_data_config)
+ task = question_answering.QuestionAnsweringTask(config)
+ model = task.build_model()
+ model = task.compile_model(
+ model,
+ optimizer=tf.keras.optimizers.SGD(lr=0.1),
+ train_step=task.train_step,
+ metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")])
+ dataset = task.build_inputs(config.train_data)
+ logs = model.fit(dataset, epochs=1, steps_per_epoch=2)
+ self.assertIn("loss", logs.history)
+ self.assertIn("start_positions_accuracy", logs.history)
+ self.assertIn("end_positions_accuracy", logs.history)
+
+ def _export_bert_tfhub(self):
+ bert_config = configs.BertConfig(
+ vocab_size=30522,
+ hidden_size=16,
+ intermediate_size=32,
+ max_position_embeddings=128,
+ num_attention_heads=2,
+ num_hidden_layers=1)
+ _, encoder = export_tfhub.create_bert_model(bert_config)
+ model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
+ checkpoint = tf.train.Checkpoint(model=encoder)
+ checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
+ model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
+
+ vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
+ with tf.io.gfile.GFile(vocab_file, "w") as f:
+ f.write("dummy content")
+
+ hub_destination = os.path.join(self.get_temp_dir(), "hub")
+ export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
+ hub_destination, vocab_file)
+ return hub_destination
+
+ def test_task_with_hub(self):
+ hub_module_url = self._export_bert_tfhub()
+ config = question_answering.QuestionAnsweringConfig(
+ hub_module_url=hub_module_url,
+ network=self._encoder_config,
+ train_data=self._train_data_config)
+ self._run_task(config)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/tasks/sentence_prediction.py b/models/official/nlp/tasks/sentence_prediction.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2eb0bf47de273408459e35cf45ff01ac69a9d2c
--- /dev/null
+++ b/models/official/nlp/tasks/sentence_prediction.py
@@ -0,0 +1,190 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Sentence prediction (classification) task."""
+from absl import logging
+import dataclasses
+import numpy as np
+from scipy import stats
+from sklearn import metrics as sklearn_metrics
+import tensorflow as tf
+import tensorflow_hub as hub
+
+from official.core import base_task
+from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp.configs import bert
+from official.nlp.data import sentence_prediction_dataloader
+from official.nlp.modeling import losses as loss_lib
+from official.nlp.tasks import utils
+
+
+@dataclasses.dataclass
+class SentencePredictionConfig(cfg.TaskConfig):
+ """The model config."""
+ # At most one of `init_checkpoint` and `hub_module_url` can
+ # be specified.
+ init_checkpoint: str = ''
+ hub_module_url: str = ''
+ metric_type: str = 'accuracy'
+ network: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
+ num_masked_tokens=0, # No masked language modeling head.
+ cls_heads=[
+ bert.ClsHeadConfig(
+ inner_dim=768,
+ num_classes=3,
+ dropout_rate=0.1,
+ name='sentence_prediction')
+ ])
+ train_data: cfg.DataConfig = cfg.DataConfig()
+ validation_data: cfg.DataConfig = cfg.DataConfig()
+
+
+@base_task.register_task_cls(SentencePredictionConfig)
+class SentencePredictionTask(base_task.Task):
+ """Task object for sentence_prediction."""
+
+ def __init__(self, params=cfg.TaskConfig):
+ super(SentencePredictionTask, self).__init__(params)
+ if params.hub_module_url and params.init_checkpoint:
+ raise ValueError('At most one of `hub_module_url` and '
+                       '`init_checkpoint` can be specified.')
+ if params.hub_module_url:
+ self._hub_module = hub.load(params.hub_module_url)
+ else:
+ self._hub_module = None
+ self.metric_type = params.metric_type
+
+ def build_model(self):
+ if self._hub_module:
+ encoder_from_hub = utils.get_encoder_from_hub(self._hub_module)
+ return bert.instantiate_bertpretrainer_from_cfg(
+ self.task_config.network, encoder_network=encoder_from_hub)
+ else:
+ return bert.instantiate_bertpretrainer_from_cfg(self.task_config.network)
+
+ def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
+ loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
+ labels=labels,
+ predictions=tf.nn.log_softmax(
+ tf.cast(model_outputs['sentence_prediction'], tf.float32), axis=-1))
+
+ if aux_losses:
+ loss += tf.add_n(aux_losses)
+ return loss
+
+ def build_inputs(self, params, input_context=None):
+ """Returns tf.data.Dataset for sentence_prediction task."""
+ if params.input_path == 'dummy':
+
+ def dummy_data(_):
+ dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
+ x = dict(
+ input_word_ids=dummy_ids,
+ input_mask=dummy_ids,
+ input_type_ids=dummy_ids)
+ y = tf.ones((1, 1), dtype=tf.int32)
+ return (x, y)
+
+ dataset = tf.data.Dataset.range(1)
+ dataset = dataset.repeat()
+ dataset = dataset.map(
+ dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+ return sentence_prediction_dataloader.SentencePredictionDataLoader(
+ params).load(input_context)
+
+ def build_metrics(self, training=None):
+ del training
+ metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]
+ return metrics
+
+ def process_metrics(self, metrics, labels, model_outputs):
+ for metric in metrics:
+ metric.update_state(labels, model_outputs['sentence_prediction'])
+
+ def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
+ compiled_metrics.update_state(labels, model_outputs['sentence_prediction'])
+
+ def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
+ if self.metric_type == 'accuracy':
+ return super(SentencePredictionTask,
+ self).validation_step(inputs, model, metrics)
+ features, labels = inputs
+ outputs = self.inference_step(features, model)
+ loss = self.build_losses(
+ labels=labels, model_outputs=outputs, aux_losses=model.losses)
+ if self.metric_type == 'matthews_corrcoef':
+ return {
+ self.loss:
+ loss,
+ 'sentence_prediction':
+ tf.expand_dims(
+ tf.math.argmax(outputs['sentence_prediction'], axis=1),
+ axis=0),
+ 'labels':
+ labels,
+ }
+ if self.metric_type == 'pearson_spearman_corr':
+ return {
+ self.loss: loss,
+ 'sentence_prediction': outputs['sentence_prediction'],
+ 'labels': labels,
+ }
+
+ def aggregate_logs(self, state=None, step_outputs=None):
+ if state is None:
+ state = {'sentence_prediction': [], 'labels': []}
+ state['sentence_prediction'].append(
+ np.concatenate([v.numpy() for v in step_outputs['sentence_prediction']],
+ axis=0))
+ state['labels'].append(
+ np.concatenate([v.numpy() for v in step_outputs['labels']], axis=0))
+ return state
+
+ def reduce_aggregated_logs(self, aggregated_logs):
+ if self.metric_type == 'matthews_corrcoef':
+ preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
+ labels = np.concatenate(aggregated_logs['labels'], axis=0)
+ return {
+ self.metric_type: sklearn_metrics.matthews_corrcoef(preds, labels)
+ }
+ if self.metric_type == 'pearson_spearman_corr':
+ preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
+ labels = np.concatenate(aggregated_logs['labels'], axis=0)
+ pearson_corr = stats.pearsonr(preds, labels)[0]
+ spearman_corr = stats.spearmanr(preds, labels)[0]
+ corr_metric = (pearson_corr + spearman_corr) / 2
+ return {self.metric_type: corr_metric}
+
+ def initialize(self, model):
+    """Load a pretrained checkpoint (if it exists) and then train from iter 0."""
+ ckpt_dir_or_file = self.task_config.init_checkpoint
+ if tf.io.gfile.isdir(ckpt_dir_or_file):
+ ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
+ if not ckpt_dir_or_file:
+ return
+
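+    # Map pretraining checkpoint items onto the fine-tuning model: reuse the
+    # encoder weights and initialize the fine-tuning `sentence_prediction`
+    # pooler from the pretraining `next_sentence` pooler.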
+ pretrain2finetune_mapping = {
+ 'encoder':
+ model.checkpoint_items['encoder'],
+ 'next_sentence.pooler_dense':
+ model.checkpoint_items['sentence_prediction.pooler_dense'],
+ }
+ ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
+ status = ckpt.restore(ckpt_dir_or_file)
+ status.expect_partial().assert_existing_objects_matched()
+ logging.info('finished loading pretrained checkpoint from %s',
+ ckpt_dir_or_file)
diff --git a/models/official/nlp/tasks/sentence_prediction_test.py b/models/official/nlp/tasks/sentence_prediction_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..09419f54c4642f08ca37e2588103c45d0847b7bc
--- /dev/null
+++ b/models/official/nlp/tasks/sentence_prediction_test.py
@@ -0,0 +1,163 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for official.nlp.tasks.sentence_prediction."""
+import functools
+import os
+
+from absl.testing import parameterized
+import tensorflow as tf
+
+from official.nlp.bert import configs
+from official.nlp.bert import export_tfhub
+from official.nlp.configs import bert
+from official.nlp.configs import encoders
+from official.nlp.tasks import sentence_prediction
+
+
+class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ super(SentencePredictionTaskTest, self).setUp()
+ self._train_data_config = bert.SentencePredictionDataConfig(
+ input_path="dummy", seq_length=128, global_batch_size=1)
+
+ def get_network_config(self, num_classes):
+ return bert.BertPretrainerConfig(
+ encoder=encoders.TransformerEncoderConfig(
+ vocab_size=30522, num_layers=1),
+ num_masked_tokens=0,
+ cls_heads=[
+ bert.ClsHeadConfig(
+ inner_dim=10,
+ num_classes=num_classes,
+ name="sentence_prediction")
+ ])
+
+ def _run_task(self, config):
+ task = sentence_prediction.SentencePredictionTask(config)
+ model = task.build_model()
+ metrics = task.build_metrics()
+
+ strategy = tf.distribute.get_strategy()
+ dataset = strategy.experimental_distribute_datasets_from_function(
+ functools.partial(task.build_inputs, config.train_data))
+
+ iterator = iter(dataset)
+ optimizer = tf.keras.optimizers.SGD(lr=0.1)
+ task.train_step(next(iterator), model, optimizer, metrics=metrics)
+ task.validation_step(next(iterator), model, metrics=metrics)
+
+ def test_task(self):
+ config = sentence_prediction.SentencePredictionConfig(
+ init_checkpoint=self.get_temp_dir(),
+ network=self.get_network_config(2),
+ train_data=self._train_data_config)
+ task = sentence_prediction.SentencePredictionTask(config)
+ model = task.build_model()
+ metrics = task.build_metrics()
+ dataset = task.build_inputs(config.train_data)
+
+ iterator = iter(dataset)
+ optimizer = tf.keras.optimizers.SGD(lr=0.1)
+ task.train_step(next(iterator), model, optimizer, metrics=metrics)
+ task.validation_step(next(iterator), model, metrics=metrics)
+
+ # Saves a checkpoint.
+ pretrain_cfg = bert.BertPretrainerConfig(
+ encoder=encoders.TransformerEncoderConfig(
+ vocab_size=30522, num_layers=1),
+ num_masked_tokens=20,
+ cls_heads=[
+ bert.ClsHeadConfig(
+ inner_dim=10, num_classes=3, name="next_sentence")
+ ])
+ pretrain_model = bert.instantiate_bertpretrainer_from_cfg(pretrain_cfg)
+ ckpt = tf.train.Checkpoint(
+ model=pretrain_model, **pretrain_model.checkpoint_items)
+ ckpt.save(config.init_checkpoint)
+ task.initialize(model)
+
+ @parameterized.parameters(("matthews_corrcoef", 2),
+ ("pearson_spearman_corr", 1))
+ def test_np_metrics(self, metric_type, num_classes):
+ config = sentence_prediction.SentencePredictionConfig(
+ metric_type=metric_type,
+ init_checkpoint=self.get_temp_dir(),
+ network=self.get_network_config(num_classes),
+ train_data=self._train_data_config)
+ task = sentence_prediction.SentencePredictionTask(config)
+ model = task.build_model()
+ dataset = task.build_inputs(config.train_data)
+
+ iterator = iter(dataset)
+ strategy = tf.distribute.get_strategy()
+ distributed_outputs = strategy.run(
+ functools.partial(task.validation_step, model=model),
+ args=(next(iterator),))
+ outputs = tf.nest.map_structure(strategy.experimental_local_results,
+ distributed_outputs)
+ aggregated = task.aggregate_logs(step_outputs=outputs)
+ aggregated = task.aggregate_logs(state=aggregated, step_outputs=outputs)
+ self.assertIn(metric_type, task.reduce_aggregated_logs(aggregated))
+
+ def test_task_with_fit(self):
+ config = sentence_prediction.SentencePredictionConfig(
+ network=self.get_network_config(2), train_data=self._train_data_config)
+ task = sentence_prediction.SentencePredictionTask(config)
+ model = task.build_model()
+ model = task.compile_model(
+ model,
+ optimizer=tf.keras.optimizers.SGD(lr=0.1),
+ train_step=task.train_step,
+ metrics=task.build_metrics())
+ dataset = task.build_inputs(config.train_data)
+ logs = model.fit(dataset, epochs=1, steps_per_epoch=2)
+ self.assertIn("loss", logs.history)
+
+ def _export_bert_tfhub(self):
+ bert_config = configs.BertConfig(
+ vocab_size=30522,
+ hidden_size=16,
+ intermediate_size=32,
+ max_position_embeddings=128,
+ num_attention_heads=2,
+ num_hidden_layers=1)
+ _, encoder = export_tfhub.create_bert_model(bert_config)
+ model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
+ checkpoint = tf.train.Checkpoint(model=encoder)
+ checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
+ model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
+
+ vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
+ with tf.io.gfile.GFile(vocab_file, "w") as f:
+ f.write("dummy content")
+
+ hub_destination = os.path.join(self.get_temp_dir(), "hub")
+ export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
+ hub_destination, vocab_file)
+ return hub_destination
+
+ def test_task_with_hub(self):
+ hub_module_url = self._export_bert_tfhub()
+ config = sentence_prediction.SentencePredictionConfig(
+ hub_module_url=hub_module_url,
+ network=self.get_network_config(2),
+ train_data=self._train_data_config)
+ self._run_task(config)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/tasks/tagging.py b/models/official/nlp/tasks/tagging.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1f20b1360a952b0a5c6fabc2a3ee252c2ef5137
--- /dev/null
+++ b/models/official/nlp/tasks/tagging.py
@@ -0,0 +1,147 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tagging (e.g., NER/POS) task."""
+import logging
+import dataclasses
+import tensorflow as tf
+import tensorflow_hub as hub
+
+from official.core import base_task
+from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp.configs import encoders
+from official.nlp.data import tagging_data_loader
+from official.nlp.modeling import models
+from official.nlp.tasks import utils
+
+
+@dataclasses.dataclass
+class TaggingConfig(cfg.TaskConfig):
+ """The model config."""
+ # At most one of `init_checkpoint` and `hub_module_url` can be specified.
+ init_checkpoint: str = ''
+ hub_module_url: str = ''
+ network: encoders.TransformerEncoderConfig = (
+ encoders.TransformerEncoderConfig())
+ num_classes: int = 0
+ # The ignored label id will not contribute to loss.
+ # A word may be tokenized into multiple word_pieces tokens, and we usually
+ # assign the real label id for the first token of the word, and
+ # `ignore_label_id` for the remaining tokens.
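+  # For example, if the word "playing" is tokenized into ["play", "##ing"],
+  # its labels could be [label_id_of_playing, ignore_label_id].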
+ ignore_label_id: int = 0
+ train_data: cfg.DataConfig = cfg.DataConfig()
+ validation_data: cfg.DataConfig = cfg.DataConfig()
+
+
+@base_task.register_task_cls(TaggingConfig)
+class TaggingTask(base_task.Task):
+ """Task object for tagging (e.g., NER or POS)."""
+
+ def __init__(self, params=cfg.TaskConfig):
+ super(TaggingTask, self).__init__(params)
+ if params.hub_module_url and params.init_checkpoint:
+ raise ValueError('At most one of `hub_module_url` and '
+ '`init_checkpoint` can be specified.')
+ if params.num_classes == 0:
+ raise ValueError('TaggingConfig.num_classes cannot be 0.')
+
+ if params.hub_module_url:
+ self._hub_module = hub.load(params.hub_module_url)
+ else:
+ self._hub_module = None
+
+ def build_model(self):
+ if self._hub_module:
+ encoder_network = utils.get_encoder_from_hub(self._hub_module)
+ else:
+ encoder_network = encoders.instantiate_encoder_from_cfg(
+ self.task_config.network)
+
+ return models.BertTokenClassifier(
+ network=encoder_network,
+ num_classes=self.task_config.num_classes,
+ initializer=tf.keras.initializers.TruncatedNormal(
+ stddev=self.task_config.network.initializer_range),
+ dropout_rate=self.task_config.network.dropout_rate,
+ output='logits')
+
+ def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
+ model_outputs = tf.cast(model_outputs, tf.float32)
+ loss = tf.keras.losses.sparse_categorical_crossentropy(
+ labels, model_outputs, from_logits=True)
+ # `ignore_label_id` will not contribute to loss.
+ label_weights = tf.cast(
+ tf.not_equal(labels, self.task_config.ignore_label_id),
+ dtype=tf.float32)
+ numerator_loss = tf.reduce_sum(loss * label_weights)
+ denominator_loss = tf.reduce_sum(label_weights)
+ loss = tf.math.divide_no_nan(numerator_loss, denominator_loss)
+ return loss
+
+ def build_inputs(self, params, input_context=None):
+    """Returns tf.data.Dataset for the tagging task."""
+ if params.input_path == 'dummy':
+
+ def dummy_data(_):
+ dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
+ x = dict(
+ input_word_ids=dummy_ids,
+ input_mask=dummy_ids,
+ input_type_ids=dummy_ids)
+ y = tf.ones((1, params.seq_length), dtype=tf.int32)
+ return (x, y)
+
+ dataset = tf.data.Dataset.range(1)
+ dataset = dataset.repeat()
+ dataset = dataset.map(
+ dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+ dataset = tagging_data_loader.TaggingDataLoader(params).load(input_context)
+ return dataset
+
+ def build_metrics(self, training=None):
+ del training
+ # TODO(chendouble): evaluate using seqeval's f1/precision/recall.
+ return [tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]
+
+ def process_metrics(self, metrics, labels, model_outputs):
+ # `ignore_label_id` will not contribute to metrics.
+ sample_weight = tf.cast(
+ tf.not_equal(labels, self.task_config.ignore_label_id),
+ dtype=tf.float32)
+ for metric in metrics:
+ metric.update_state(labels, model_outputs, sample_weight)
+
+ def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
+ # `ignore_label_id` will not contribute to metrics.
+ sample_weight = tf.cast(
+ tf.not_equal(labels, self.task_config.ignore_label_id),
+ dtype=tf.float32)
+ compiled_metrics.update_state(labels, model_outputs, sample_weight)
+
+ def initialize(self, model):
+    """Load a pretrained checkpoint (if it exists) and then train from iter 0."""
+ ckpt_dir_or_file = self.task_config.init_checkpoint
+ if tf.io.gfile.isdir(ckpt_dir_or_file):
+ ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
+ if not ckpt_dir_or_file:
+ return
+
+ ckpt = tf.train.Checkpoint(**model.checkpoint_items)
+ status = ckpt.restore(ckpt_dir_or_file)
+ status.expect_partial().assert_existing_objects_matched()
+ logging.info('finished loading pretrained checkpoint from %s',
+ ckpt_dir_or_file)
diff --git a/models/official/nlp/tasks/tagging_test.py b/models/official/nlp/tasks/tagging_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..6707a50a824246062be9387c9c83f49c1a3309f5
--- /dev/null
+++ b/models/official/nlp/tasks/tagging_test.py
@@ -0,0 +1,125 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for official.nlp.tasks.tagging."""
+import functools
+import os
+import tensorflow as tf
+
+from official.nlp.bert import configs
+from official.nlp.bert import export_tfhub
+from official.nlp.configs import bert
+from official.nlp.configs import encoders
+from official.nlp.tasks import tagging
+
+
+class TaggingTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(TaggingTest, self).setUp()
+ self._encoder_config = encoders.TransformerEncoderConfig(
+ vocab_size=30522, num_layers=1)
+ self._train_data_config = bert.TaggingDataConfig(
+ input_path="dummy", seq_length=128, global_batch_size=1)
+
+ def _run_task(self, config):
+ task = tagging.TaggingTask(config)
+ model = task.build_model()
+ metrics = task.build_metrics()
+
+ strategy = tf.distribute.get_strategy()
+ dataset = strategy.experimental_distribute_datasets_from_function(
+ functools.partial(task.build_inputs, config.train_data))
+
+ iterator = iter(dataset)
+ optimizer = tf.keras.optimizers.SGD(lr=0.1)
+ task.train_step(next(iterator), model, optimizer, metrics=metrics)
+ task.validation_step(next(iterator), model, metrics=metrics)
+
+ def test_task(self):
+ # Saves a checkpoint.
+ encoder = encoders.instantiate_encoder_from_cfg(self._encoder_config)
+ ckpt = tf.train.Checkpoint(encoder=encoder)
+ saved_path = ckpt.save(self.get_temp_dir())
+
+ config = tagging.TaggingConfig(
+ init_checkpoint=saved_path,
+ network=self._encoder_config,
+ train_data=self._train_data_config,
+ num_classes=3)
+ task = tagging.TaggingTask(config)
+ model = task.build_model()
+ metrics = task.build_metrics()
+ dataset = task.build_inputs(config.train_data)
+
+ iterator = iter(dataset)
+ optimizer = tf.keras.optimizers.SGD(lr=0.1)
+ task.train_step(next(iterator), model, optimizer, metrics=metrics)
+ task.validation_step(next(iterator), model, metrics=metrics)
+ task.initialize(model)
+
+ def test_task_with_fit(self):
+ config = tagging.TaggingConfig(
+ network=self._encoder_config,
+ train_data=self._train_data_config,
+ num_classes=3)
+
+ task = tagging.TaggingTask(config)
+ model = task.build_model()
+ model = task.compile_model(
+ model,
+ optimizer=tf.keras.optimizers.SGD(lr=0.1),
+ train_step=task.train_step,
+ metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")])
+ dataset = task.build_inputs(config.train_data)
+ logs = model.fit(dataset, epochs=1, steps_per_epoch=2)
+ self.assertIn("loss", logs.history)
+ self.assertIn("accuracy", logs.history)
+
+ def _export_bert_tfhub(self):
+ bert_config = configs.BertConfig(
+ vocab_size=30522,
+ hidden_size=16,
+ intermediate_size=32,
+ max_position_embeddings=128,
+ num_attention_heads=2,
+ num_hidden_layers=1)
+ _, encoder = export_tfhub.create_bert_model(bert_config)
+ model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
+ checkpoint = tf.train.Checkpoint(model=encoder)
+ checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
+ model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
+
+ vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
+ with tf.io.gfile.GFile(vocab_file, "w") as f:
+ f.write("dummy content")
+
+ hub_destination = os.path.join(self.get_temp_dir(), "hub")
+ export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
+ hub_destination, vocab_file)
+ return hub_destination
+
+ def test_task_with_hub(self):
+ hub_module_url = self._export_bert_tfhub()
+ config = tagging.TaggingConfig(
+ hub_module_url=hub_module_url,
+ network=self._encoder_config,
+ num_classes=4,
+ train_data=self._train_data_config)
+ self._run_task(config)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/tasks/utils.py b/models/official/nlp/tasks/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..467dafe31f813779b7af5ea0209aadccb6d1bdf8
--- /dev/null
+++ b/models/official/nlp/tasks/utils.py
@@ -0,0 +1,34 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Common utils for tasks."""
+import tensorflow as tf
+import tensorflow_hub as hub
+
+
+def get_encoder_from_hub(hub_module: str) -> tf.keras.Model:
+ """Gets an encoder from hub."""
+ input_word_ids = tf.keras.layers.Input(
+ shape=(None,), dtype=tf.int32, name='input_word_ids')
+ input_mask = tf.keras.layers.Input(
+ shape=(None,), dtype=tf.int32, name='input_mask')
+ input_type_ids = tf.keras.layers.Input(
+ shape=(None,), dtype=tf.int32, name='input_type_ids')
+ hub_layer = hub.KerasLayer(hub_module, trainable=True)
+ pooled_output, sequence_output = hub_layer(
+ [input_word_ids, input_mask, input_type_ids])
+ return tf.keras.Model(
+ inputs=[input_word_ids, input_mask, input_type_ids],
+ outputs=[sequence_output, pooled_output])
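+
+
+# Illustrative usage sketch (an assumption, not part of this module): wrap a
+# loaded TF-Hub BERT SavedModel as a Keras encoder, mirroring how
+# SentencePredictionTask and TaggingTask call this helper. The hub handle
+# below is only a placeholder.
+#
+#   hub_module = hub.load('https://tfhub.dev/<bert-encoder-handle>')
+#   encoder = get_encoder_from_hub(hub_module)
+#   sequence_output, pooled_output = encoder(
+#       [input_word_ids, input_mask, input_type_ids])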
diff --git a/models/official/nlp/transformer/README.md b/models/official/nlp/transformer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1215ed574b316030f69713de8dc3000ea64e3df6
--- /dev/null
+++ b/models/official/nlp/transformer/README.md
@@ -0,0 +1,218 @@
+# Transformer Translation Model
+This is an implementation of the Transformer translation model as described in
+the [Attention is All You Need](https://arxiv.org/abs/1706.03762) paper. The
+implementation leverages tf.keras and makes sure it is compatible with TF 2.x.
+
+**Note: this transformer folder is subject to be integrated into official/nlp
+folder. Due to its dependencies, we will finish the refactoring after the model
+garden 2.1 release.**
+
+## Contents
+ * [Contents](#contents)
+ * [Walkthrough](#walkthrough)
+ * [Detailed instructions](#detailed-instructions)
+ * [Environment preparation](#environment-preparation)
+ * [Download and preprocess datasets](#download-and-preprocess-datasets)
+ * [Model training and evaluation](#model-training-and-evaluation)
+ * [Implementation overview](#implementation-overview)
+ * [Model Definition](#model-definition)
+ * [Model Trainer](#model-trainer)
+ * [Test dataset](#test-dataset)
+
+## Walkthrough
+
+Below are the commands for running the Transformer model. See the
+[Detailed instructions](#detailed-instructions) for more details on running the
+model.
+
+```
+# Ensure that PYTHONPATH is correctly defined as described in
+# https://github.com/tensorflow/models/tree/master/official#requirements
+export PYTHONPATH="$PYTHONPATH:/path/to/models"
+
+cd /path/to/models/official/nlp/transformer
+
+# Export variables
+PARAM_SET=big
+DATA_DIR=$HOME/transformer/data
+MODEL_DIR=$HOME/transformer/model_$PARAM_SET
+VOCAB_FILE=$DATA_DIR/vocab.ende.32768
+
+# Download training/evaluation/test datasets
+python3 data_download.py --data_dir=$DATA_DIR
+
+# Train the model for 100000 steps and evaluate every 5000 steps on a single GPU.
+# Each train step takes 4096 tokens as a batch budget, with a maximum
+# sequence length of 64.
+python3 transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
+ --vocab_file=$VOCAB_FILE --param_set=$PARAM_SET \
+ --train_steps=100000 --steps_between_evals=5000 \
+ --batch_size=4096 --max_length=64 \
+ --bleu_source=$DATA_DIR/newstest2014.en \
+ --bleu_ref=$DATA_DIR/newstest2014.de \
+ --num_gpus=1 \
+ --enable_time_history=false
+
+# Run during training in a separate process to get continuous updates,
+# or after training is complete.
+tensorboard --logdir=$MODEL_DIR
+```
+
+## Detailed instructions
+
+
+0. ### Environment preparation
+
+ #### Add models repo to PYTHONPATH
+ Follow the instructions described in the [Requirements](https://github.com/tensorflow/models/tree/master/official#requirements) section to add the models folder to the python path.
+
+ #### Export variables (optional)
+
+ Export the following variables, or modify the values in each of the snippets below:
+
+ ```shell
+ PARAM_SET=big
+ DATA_DIR=$HOME/transformer/data
+ MODEL_DIR=$HOME/transformer/model_$PARAM_SET
+ VOCAB_FILE=$DATA_DIR/vocab.ende.32768
+ ```
+
+1. ### Download and preprocess datasets
+
+ [data_download.py](data_download.py) downloads and preprocesses the training and evaluation WMT datasets. After the data is downloaded and extracted, the training data is used to generate a vocabulary of subtokens. The evaluation and training strings are tokenized, and the resulting data is sharded, shuffled, and saved as TFRecords.
+
+ 1.75GB of compressed data will be downloaded. In total, the raw files (compressed, extracted, and combined files) take up 8.4GB of disk space. The resulting TFRecord and vocabulary files are 722MB. The script takes around 40 minutes to run, with the bulk of the time spent downloading and ~15 minutes spent on preprocessing.
+
+ Command to run:
+ ```
+ python3 data_download.py --data_dir=$DATA_DIR
+ ```
+
+ Arguments:
+ * `--data_dir`: Path where the preprocessed TFRecord data, and vocab file will be saved.
+ * Use the `--help` or `-h` flag to get a full list of possible arguments.
+
+2. ### Model training and evaluation
+
+   [transformer_main.py](transformer_main.py) creates a Transformer Keras model
+   and trains it using Keras `model.fit()`.
+
+   Users need to adjust `batch_size` and `num_gpus` to get good performance
+   when running on multiple GPUs.
+
+   **Note that** when using multiple GPUs or TPUs, `batch_size` is the global
+   batch size for all devices. For example, if the batch size is `4096*4` and
+   there are 4 devices, each device will take 4096 tokens as a batch budget.
+
+ Command to run:
+ ```
+ python3 transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
+ --vocab_file=$VOCAB_FILE --param_set=$PARAM_SET
+ ```
+
+ Arguments:
+ * `--data_dir`: This should be set to the same directory given to the `data_download`'s `data_dir` argument.
+ * `--model_dir`: Directory to save Transformer model training checkpoints.
+ * `--vocab_file`: Path to subtoken vocabulary file. If data_download was used, you may find the file in `data_dir`.
+ * `--param_set`: Parameter set to use when creating and training the model. Options are `base` and `big` (default).
+   * `--enable_time_history`: Whether to add the TimeHistory callback. If so, `--log_steps` must be specified.
+   * `--batch_size`: The number of tokens to consider in a batch. Combined with
+     `--max_length`, this determines how many sequences are used per batch.
+ * Use the `--help` or `-h` flag to get a full list of possible arguments.
+
+ #### Using multiple GPUs
+ You can train these models on multiple GPUs using `tf.distribute.Strategy` API.
+ You can read more about them in this
+ [guide](https://www.tensorflow.org/guide/distribute_strategy).
+
+   In this example, we have made it easier to use with just a command line flag
+   `--num_gpus`. By default, this flag is 1 if TensorFlow is compiled with CUDA,
+   and 0 otherwise.
+
+   - `--num_gpus=0`: Uses `tf.distribute.OneDeviceStrategy` with CPU as the device.
+   - `--num_gpus=1`: Uses `tf.distribute.OneDeviceStrategy` with GPU as the device.
+   - `--num_gpus=2+`: Uses `tf.distribute.MirroredStrategy` to run synchronous
+     distributed training across the GPUs.
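+
+   For example, an illustrative command (reusing the variables exported in the
+   walkthrough above) for synchronous training on 4 GPUs, with a batch budget of
+   4096 tokens per device (a global batch size of 16384), would be:
+
+   ```
+   python3 transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
+       --vocab_file=$VOCAB_FILE --param_set=$PARAM_SET \
+       --batch_size=16384 --max_length=64 --num_gpus=4
+   ```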
+
+ #### Using Cloud TPUs
+
+ You can train the Transformer model on Cloud TPUs using
+ `tf.distribute.TPUStrategy`. If you are not familiar with Cloud TPUs, it is
+ strongly recommended that you go through the
+ [quickstart](https://cloud.google.com/tpu/docs/quickstart) to learn how to
+ create a TPU and GCE VM.
+
+ To run the Transformer model on a TPU, you must set
+ `--distribution_strategy=tpu`, `--tpu=$TPU_NAME`, and `--use_ctl=True` where
+   `$TPU_NAME` is the name of your TPU in the Cloud Console.
+
+ An example command to run Transformer on a v2-8 or v3-8 TPU would be:
+
+ ```bash
+ python transformer_main.py \
+ --tpu=$TPU_NAME \
+ --model_dir=$MODEL_DIR \
+ --data_dir=$DATA_DIR \
+ --vocab_file=$DATA_DIR/vocab.ende.32768 \
+ --bleu_source=$DATA_DIR/newstest2014.en \
+     --bleu_ref=$DATA_DIR/newstest2014.de \
+ --batch_size=6144 \
+ --train_steps=2000 \
+ --static_batch=true \
+ --use_ctl=true \
+ --param_set=big \
+ --max_length=64 \
+ --decode_batch_size=32 \
+ --decode_max_length=97 \
+ --padded_decode=true \
+ --distribution_strategy=tpu
+ ```
+ Note: `$MODEL_DIR` and `$DATA_DIR` must be GCS paths.
+
+ #### Customizing training schedule
+
+ By default, the model will train for 10 epochs, and evaluate after every epoch. The training schedule may be defined through the flags:
+
+ * Training with steps:
+     * `--train_steps`: Total number of training steps to run.
+     * `--steps_between_evals`: Number of training steps to run between evaluations.
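+
+   For example, the walkthrough's illustrative schedule of 100000 total training
+   steps with an evaluation every 5000 steps corresponds to:
+
+   ```
+   python3 transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
+       --vocab_file=$VOCAB_FILE --param_set=$PARAM_SET \
+       --train_steps=100000 --steps_between_evals=5000
+   ```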
+
+ #### Compute BLEU score during model evaluation
+
+ Use these flags to compute the BLEU when the model evaluates:
+
+ * `--bleu_source`: Path to file containing text to translate.
+ * `--bleu_ref`: Path to file containing the reference translation.
+
+ When running `transformer_main.py`, use the flags: `--bleu_source=$DATA_DIR/newstest2014.en --bleu_ref=$DATA_DIR/newstest2014.de`
+
+ #### Tensorboard
+ Training and evaluation metrics (loss, accuracy, approximate BLEU score, etc.) are logged, and can be displayed in the browser using Tensorboard.
+ ```
+ tensorboard --logdir=$MODEL_DIR
+ ```
+   The values are displayed at [localhost:6006](http://localhost:6006).
+
+## Implementation overview
+
+A brief look at each component in the code:
+
+### Model Definition
+* [transformer.py](transformer.py): Defines a tf.keras.Model: `Transformer`.
+* [embedding_layer.py](embedding_layer.py): Contains the layer that calculates the embeddings. The embedding weights are also used to calculate the pre-softmax probabilities from the decoder output.
+* [attention_layer.py](attention_layer.py): Defines the multi-headed and self attention layers that are used in the encoder/decoder stacks.
+* [ffn_layer.py](ffn_layer.py): Defines the feedforward network that is used in the encoder/decoder stacks. The network is composed of 2 fully connected layers.
+
+Other files:
+* [beam_search.py](beam_search.py) contains the beam search implementation, which is used during model inference to find high scoring translations.
+
+### Model Trainer
+[transformer_main.py](transformer_main.py) creates a `TransformerTask` to train and evaluate the model using tf.keras.
+
+### Test dataset
+The [newstest2014 files](https://storage.googleapis.com/tf-perf-public/official_transformer/test_data/newstest2014.tgz)
+are extracted from the [NMT Seq2Seq tutorial](https://google.github.io/seq2seq/nmt/#download-data).
+The raw text files are converted from the SGM format of the
+[WMT 2016](http://www.statmt.org/wmt16/translation-task.html) test sets. The
+newstest2014 files are put into the `$DATA_DIR` when executing `data_download.py`
diff --git a/models/official/nlp/transformer/__init__.py b/models/official/nlp/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/nlp/transformer/attention_layer.py b/models/official/nlp/transformer/attention_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..114bd5fadc3064f0ff3c895245d8676e47a0bad4
--- /dev/null
+++ b/models/official/nlp/transformer/attention_layer.py
@@ -0,0 +1,170 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implementation of multiheaded attention and self-attention layers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import tensorflow as tf
+from official.nlp.modeling import layers
+
+
+class Attention(tf.keras.layers.Layer):
+ """Multi-headed attention layer."""
+
+ def __init__(self, hidden_size, num_heads, attention_dropout):
+ """Initialize Attention.
+
+ Args:
+ hidden_size: int, output dim of hidden layer.
+ num_heads: int, number of heads to repeat the same attention structure.
+ attention_dropout: float, dropout rate inside attention for training.
+ """
+ if hidden_size % num_heads:
+ raise ValueError(
+ "Hidden size ({}) must be divisible by the number of heads ({})."
+ .format(hidden_size, num_heads))
+
+ super(Attention, self).__init__()
+ self.hidden_size = hidden_size
+ self.num_heads = num_heads
+ self.attention_dropout = attention_dropout
+
+ def build(self, input_shape):
+ """Builds the layer."""
+ # Layers for linearly projecting the queries, keys, and values.
+ size_per_head = self.hidden_size // self.num_heads
+
+ def _glorot_initializer(fan_in, fan_out):
+ limit = math.sqrt(6.0 / (fan_in + fan_out))
+ return tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit)
+
+ attention_initializer = _glorot_initializer(input_shape.as_list()[-1],
+ self.hidden_size)
+ self.query_dense_layer = layers.DenseEinsum(
+ output_shape=(self.num_heads, size_per_head),
+ kernel_initializer=attention_initializer,
+ use_bias=False,
+ name="query")
+ self.key_dense_layer = layers.DenseEinsum(
+ output_shape=(self.num_heads, size_per_head),
+ kernel_initializer=attention_initializer,
+ use_bias=False,
+ name="key")
+ self.value_dense_layer = layers.DenseEinsum(
+ output_shape=(self.num_heads, size_per_head),
+ kernel_initializer=attention_initializer,
+ use_bias=False,
+ name="value")
+
+ output_initializer = _glorot_initializer(self.hidden_size, self.hidden_size)
+ self.output_dense_layer = layers.DenseEinsum(
+ output_shape=self.hidden_size,
+ num_summed_dimensions=2,
+ kernel_initializer=output_initializer,
+ use_bias=False,
+ name="output_transform")
+ super(Attention, self).build(input_shape)
+
+ def get_config(self):
+ return {
+ "hidden_size": self.hidden_size,
+ "num_heads": self.num_heads,
+ "attention_dropout": self.attention_dropout,
+ }
+
+ def call(self, query_input, source_input, bias, training, cache=None,
+ decode_loop_step=None):
+ """Apply attention mechanism to query_input and source_input.
+
+ Args:
+ query_input: A tensor with shape [batch_size, length_query, hidden_size].
+ source_input: A tensor with shape [batch_size, length_source,
+ hidden_size].
+ bias: A tensor with shape [batch_size, 1, length_query, length_source],
+ the attention bias that will be added to the result of the dot product.
+ training: A bool, whether in training mode or not.
+ cache: (Used during prediction) A dictionary with tensors containing
+ results of previous attentions. The dictionary must have the items:
+ {"k": tensor with shape [batch_size, i, heads, dim_per_head],
+ "v": tensor with shape [batch_size, i, heads, dim_per_head]}
+ where i is the current decoded length for non-padded decode, or max
+ sequence length for padded decode.
+ decode_loop_step: An integer, step number of the decoding loop. Used only
+ for autoregressive inference on TPU.
+
+ Returns:
+ Attention layer output with shape [batch_size, length_query, hidden_size]
+ """
+ # Linearly project the query, key and value using different learned
+ # projections. Splitting heads is automatically done during the linear
+ # projections --> [batch_size, length, num_heads, dim_per_head].
+ query = self.query_dense_layer(query_input)
+ key = self.key_dense_layer(source_input)
+ value = self.value_dense_layer(source_input)
+
+ if cache is not None:
+ # Combine cached keys and values with new keys and values.
+ if decode_loop_step is not None:
+ cache_k_shape = cache["k"].shape.as_list()
+ indices = tf.reshape(
+ tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
+ [1, cache_k_shape[1], 1, 1])
+ key = cache["k"] + key * indices
+ cache_v_shape = cache["v"].shape.as_list()
+ indices = tf.reshape(
+ tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
+ [1, cache_v_shape[1], 1, 1])
+ value = cache["v"] + value * indices
+ else:
+ key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
+ value = tf.concat([tf.cast(cache["v"], value.dtype), value], axis=1)
+
+ # Update cache
+ cache["k"] = key
+ cache["v"] = value
+
+ # Scale query to prevent the dot product between query and key from growing
+ # too large.
+ depth = (self.hidden_size // self.num_heads)
+ query *= depth ** -0.5
+
+ # Calculate dot product attention
+ logits = tf.einsum("BTNH,BFNH->BNFT", key, query)
+ logits += bias
+ # Note that softmax internally performs math operations using float32
+ # for numeric stability. When training with float16, we keep the input
+ # and output in float16 for better performance.
+ weights = tf.nn.softmax(logits, name="attention_weights")
+ if training:
+ weights = tf.nn.dropout(weights, rate=self.attention_dropout)
+ attention_output = tf.einsum("BNFT,BTNH->BFNH", weights, value)
+
+ # Run the outputs through another linear projection layer. Recombining heads
+ # is automatically done --> [batch_size, length, hidden_size]
+ attention_output = self.output_dense_layer(attention_output)
+ return attention_output
+
+
+class SelfAttention(Attention):
+ """Multiheaded self-attention layer."""
+
+ def call(self, query_input, bias, training, cache=None,
+ decode_loop_step=None):
+ return super(SelfAttention, self).call(
+ query_input, query_input, bias, training, cache, decode_loop_step)
diff --git a/models/official/nlp/transformer/beam_search.py b/models/official/nlp/transformer/beam_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4c1127535e6ae805f6619819737c379cadca6f2
--- /dev/null
+++ b/models/official/nlp/transformer/beam_search.py
@@ -0,0 +1,132 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Beam search in TF v2."""
+
+import tensorflow as tf
+
+from official.nlp.transformer import beam_search_v1 as v1
+
+_StateKeys = v1._StateKeys # pylint: disable=protected-access
+
+
+class SequenceBeamSearchV2(v1.SequenceBeamSearch):
+ """Implementation of beam search loop in v2."""
+
+ def search(self, initial_ids, initial_cache):
+ """Beam search for sequences with highest scores."""
+ state, state_shapes = self._create_initial_state(initial_ids, initial_cache)
+
+ finished_state = tf.nest.map_structure(
+ tf.stop_gradient,
+ tf.while_loop(self._continue_search,
+ self._search_step,
+ loop_vars=[state],
+ shape_invariants=[state_shapes],
+ parallel_iterations=1))
+ finished_state = finished_state[0]
+
+ alive_seq = finished_state[_StateKeys.ALIVE_SEQ]
+ alive_log_probs = finished_state[_StateKeys.ALIVE_LOG_PROBS]
+ finished_seq = finished_state[_StateKeys.FINISHED_SEQ]
+ finished_scores = finished_state[_StateKeys.FINISHED_SCORES]
+ finished_flags = finished_state[_StateKeys.FINISHED_FLAGS]
+
+ # 2.0 changes tf.where behavior. Should make parameters broadcastable.
+ finished_cond = tf.reduce_any(finished_flags, 1, name="finished_cond")
+ seq_cond = _expand_to_same_rank(finished_cond, finished_seq)
+ score_cond = _expand_to_same_rank(finished_cond, finished_scores)
+
+ # Account for corner case where there are no finished sequences for a
+ # particular batch item. In that case, return alive sequences for that batch
+ # item.
+ finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
+ finished_scores = tf.where(
+ score_cond, finished_scores, alive_log_probs)
+ return finished_seq, finished_scores
+
+
+def sequence_beam_search(symbols_to_logits_fn,
+ initial_ids,
+ initial_cache,
+ vocab_size,
+ beam_size,
+ alpha,
+ max_decode_length,
+ eos_id,
+ padded_decode=False,
+ dtype="float32"):
+ """Search for sequence of subtoken ids with the largest probability.
+
+ Args:
+ symbols_to_logits_fn: A function that takes in ids, index, and cache as
+ arguments. The passed in arguments will have shape:
+ ids -> A tensor with shape [batch_size * beam_size, index].
+ index -> A scalar.
+ cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
+ The function must return a tuple of logits and new cache:
+ logits -> A tensor with shape [batch * beam_size, vocab_size].
+ new cache -> A nested dictionary with the same shape/structure as the
+ inputted cache.
+ initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
+ each batch item.
+ initial_cache: A dictionary, containing starting decoder variables
+ information.
+    vocab_size: An integer, the size of the vocabulary.
+ beam_size: An integer, the number of beams.
+ alpha: A float, defining the strength of length normalization.
+    max_decode_length: An integer, the maximum length to decode a sequence.
+ eos_id: An integer, ID of eos token, used to determine when a sequence has
+ finished.
+ padded_decode: A bool, indicating if max_sequence_length padding is used
+ for beam search.
+ dtype: A tensorflow data type used for score computation. The default is
+ tf.float32.
+
+ Returns:
+ Top decoded sequences [batch_size, beam_size, max_decode_length]
+ sequence scores [batch_size, beam_size]
+ """
+ batch_size = (
+ initial_ids.shape.as_list()[0] if padded_decode else
+ tf.shape(initial_ids)[0])
+ sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
+ beam_size, alpha, max_decode_length, eos_id,
+ padded_decode, dtype)
+ return sbs.search(initial_ids, initial_cache)
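+
+
+# Illustrative sketch (an assumption, not part of this module) of the expected
+# `symbols_to_logits_fn` signature for `sequence_beam_search`: a toy function
+# that always favours token id 1.
+#
+#   def toy_symbols_to_logits_fn(ids, i, cache):
+#     num_seqs = tf.shape(ids)[0]
+#     logits = tf.one_hot(tf.fill([num_seqs], 1), depth=8, dtype=tf.float32)
+#     return logits, cache
+#
+#   decoded_ids, scores = sequence_beam_search(
+#       symbols_to_logits_fn=toy_symbols_to_logits_fn,
+#       initial_ids=tf.zeros([2], dtype=tf.int32),
+#       initial_cache={},
+#       vocab_size=8,
+#       beam_size=4,
+#       alpha=0.6,
+#       max_decode_length=10,
+#       eos_id=1)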
+
+
+def _expand_to_same_rank(tensor, target):
+ """Expands a given tensor to target's rank to be broadcastable.
+
+ Args:
+ tensor: input tensor to tile. Shape: [b, d1, ..., da]
+ target: target tensor. Shape: [b, d1, ..., da, ..., dn]
+
+ Returns:
+ Tiled tensor of shape [b, d1, ..., da, 1, ..., 1] with same rank of target.
+
+ Raises:
+    ValueError, if the shape rank of tensor/target is None.
+ """
+ if tensor.shape.rank is None:
+ raise ValueError("Expect rank for tensor shape, but got None.")
+ if target.shape.rank is None:
+ raise ValueError("Expect rank for target shape, but got None.")
+
+ with tf.name_scope("expand_rank"):
+ diff_rank = target.shape.rank - tensor.shape.rank
+ for _ in range(diff_rank):
+ tensor = tf.expand_dims(tensor, -1)
+ return tensor
diff --git a/models/official/nlp/transformer/beam_search_v1.py b/models/official/nlp/transformer/beam_search_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b143b1b30ef462f6187850b12a5ca9dfe3ab39b
--- /dev/null
+++ b/models/official/nlp/transformer/beam_search_v1.py
@@ -0,0 +1,675 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Beam search to find the translated sequence with the highest probability.
+
+Source implementation from Tensor2Tensor:
+https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/beam_search.py
+"""
+
+import numpy as np
+import tensorflow.compat.v1 as tf
+from tensorflow.python.util import nest
+
+
+def inf(dtype):
+ """Returns a value close to infinity, but is still finite in `dtype`.
+
+ This is useful to get a very large value that is still zero when multiplied by
+ zero. The floating-point "Inf" value is NaN when multiplied by zero.
+
+ Args:
+ dtype: A dtype. The returned value will be finite when casted to this dtype.
+
+ Returns:
+ A very large value.
+ """
+ if dtype == "float32" or dtype == "bfloat16":
+ return 1e7
+ elif dtype == "float16":
+ # Disable no-member lint error, as the linter thinks np.float16 does not
+ # exist for some reason.
+ return np.finfo(np.float16).max # pylint: disable=no-member
+ else:
+ raise AssertionError('Invalid dtype: %s' % dtype)
+
+
+class _StateKeys(object):
+ """Keys to dictionary storing the state of the beam search loop."""
+
+ # Variable storing the loop index.
+ CUR_INDEX = "CUR_INDEX"
+
+ # Top sequences that are alive for each batch item. Alive sequences are ones
+ # that have not generated an EOS token. Sequences that reach EOS are marked as
+ # finished and moved to the FINISHED_SEQ tensor.
+ # Has shape [batch_size, beam_size, CUR_INDEX + 1]
+ ALIVE_SEQ = "ALIVE_SEQ"
+ # Log probabilities of each alive sequence. Shape [batch_size, beam_size]
+ ALIVE_LOG_PROBS = "ALIVE_LOG_PROBS"
+ # Dictionary of cached values for each alive sequence. The cache stores
+ # the encoder output, attention bias, and the decoder attention output from
+ # the previous iteration.
+ ALIVE_CACHE = "ALIVE_CACHE"
+
+ # Top finished sequences for each batch item.
+ # Has shape [batch_size, beam_size, CUR_INDEX + 1]. Sequences that are
+ # shorter than CUR_INDEX + 1 are padded with 0s.
+ FINISHED_SEQ = "FINISHED_SEQ"
+ # Scores for each finished sequence. Score = log probability / length norm
+ # Shape [batch_size, beam_size]
+ FINISHED_SCORES = "FINISHED_SCORES"
+ # Flags indicating which sequences in the finished sequences are finished.
+ # At the beginning, all of the sequences in FINISHED_SEQ are filler values.
+ # True -> finished sequence, False -> filler. Shape [batch_size, beam_size]
+ FINISHED_FLAGS = "FINISHED_FLAGS"
+
+
+class SequenceBeamSearch(object):
+ """Implementation of beam search loop."""
+
+ def __init__(self,
+ symbols_to_logits_fn,
+ vocab_size,
+ batch_size,
+ beam_size,
+ alpha,
+ max_decode_length,
+ eos_id,
+ padded_decode,
+ dtype=tf.float32):
+ """Initialize sequence beam search.
+
+ Args:
+ symbols_to_logits_fn: A function to provide logits, which is the
+ interface to the Transformer model. The passed in arguments are:
+ ids -> A tensor with shape [batch_size * beam_size, index].
+ index -> A scalar.
+ cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
+ The function must return a tuple of logits and the updated cache:
+ logits -> A tensor with shape [batch * beam_size, vocab_size].
+ updated cache -> A nested dictionary with the same structure as the
+ input cache.
+ vocab_size: An integer, the size of the vocabulary, used for topk
+ computation.
+ batch_size: An integer, the decode batch size.
+ beam_size: An integer, number of beams for beam search.
+ alpha: A float, defining the strength of length normalization.
+ max_decode_length: An integer, the maximum number of steps to decode
+ a sequence.
+ eos_id: An integer. ID of end of sentence token.
+ padded_decode: A bool, indicating if max_sequence_length padding is used
+ for beam search.
+ dtype: A tensorflow data type used for score computation. The default is
+ tf.float32.
+ """
+ self.symbols_to_logits_fn = symbols_to_logits_fn
+ self.vocab_size = vocab_size
+ self.batch_size = batch_size
+ self.beam_size = beam_size
+ self.alpha = alpha
+ self.max_decode_length = max_decode_length
+ self.eos_id = eos_id
+ self.padded_decode = padded_decode
+ self.dtype = tf.as_dtype(dtype)
+
+ def search(self, initial_ids, initial_cache):
+ """Beam search for sequences with highest scores."""
+ state, state_shapes = self._create_initial_state(initial_ids, initial_cache)
+
+ finished_state = tf.while_loop(
+ self._continue_search, self._search_step, loop_vars=[state],
+ shape_invariants=[state_shapes], parallel_iterations=1, back_prop=False)
+ finished_state = finished_state[0]
+
+ alive_seq = finished_state[_StateKeys.ALIVE_SEQ]
+ alive_log_probs = finished_state[_StateKeys.ALIVE_LOG_PROBS]
+ finished_seq = finished_state[_StateKeys.FINISHED_SEQ]
+ finished_scores = finished_state[_StateKeys.FINISHED_SCORES]
+ finished_flags = finished_state[_StateKeys.FINISHED_FLAGS]
+
+ # Account for corner case where there are no finished sequences for a
+ # particular batch item. In that case, return alive sequences for that batch
+ # item.
+ finished_seq = tf.where(
+ tf.reduce_any(finished_flags, 1), finished_seq, alive_seq)
+ finished_scores = tf.where(
+ tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs)
+ return finished_seq, finished_scores
+
+ def _create_initial_state(self, initial_ids, initial_cache):
+ """Return initial state dictionary and its shape invariants.
+
+ Args:
+ initial_ids: initial ids to pass into the symbols_to_logits_fn.
+ int tensor with shape [batch_size, 1]
+ initial_cache: dictionary storing values to be passed into the
+ symbols_to_logits_fn.
+
+ Returns:
+ state and shape invariant dictionaries with keys from _StateKeys
+ """
+ for key, value in initial_cache.items():
+ for inner_value in nest.flatten(value):
+ if inner_value.dtype != self.dtype:
+ raise TypeError(
+ "initial_cache element for key '%s' has dtype %s that does not "
+ "match SequenceBeamSearch's dtype of %s. Value: %s" %
+              (key, inner_value.dtype.name, self.dtype.name, inner_value))
+
+ # Current loop index (starts at 0)
+ cur_index = tf.constant(0)
+
+ # Create alive sequence with shape [batch_size, beam_size, 1]
+ alive_seq = _expand_to_beam_size(initial_ids, self.beam_size)
+ alive_seq = tf.expand_dims(alive_seq, axis=2)
+ if self.padded_decode:
+ alive_seq = tf.tile(alive_seq, [1, 1, self.max_decode_length + 1])
+
+ # Create tensor for storing initial log probabilities.
+ # Assume initial_ids are prob 1.0
+ initial_log_probs = tf.constant(
+ [[0.] + [-float("inf")] * (self.beam_size - 1)], dtype=self.dtype)
+ alive_log_probs = tf.tile(initial_log_probs, [self.batch_size, 1])
+
+ # Expand all values stored in the dictionary to the beam size, so that each
+ # beam has a separate cache.
+ alive_cache = nest.map_structure(
+ lambda t: _expand_to_beam_size(t, self.beam_size), initial_cache)
+
+ # Initialize tensor storing finished sequences with filler values.
+ finished_seq = tf.zeros(tf.shape(alive_seq), tf.int32)
+
+ # Set scores of the initial finished seqs to negative infinity.
+ finished_scores = tf.ones([self.batch_size, self.beam_size],
+ dtype=self.dtype) * -inf(self.dtype)
+
+ # Initialize finished flags with all False values.
+ finished_flags = tf.zeros([self.batch_size, self.beam_size], tf.bool)
+
+ # Create state dictionary
+ state = {
+ _StateKeys.CUR_INDEX: cur_index,
+ _StateKeys.ALIVE_SEQ: alive_seq,
+ _StateKeys.ALIVE_LOG_PROBS: alive_log_probs,
+ _StateKeys.ALIVE_CACHE: alive_cache,
+ _StateKeys.FINISHED_SEQ: finished_seq,
+ _StateKeys.FINISHED_SCORES: finished_scores,
+ _StateKeys.FINISHED_FLAGS: finished_flags
+ }
+
+ # Create state invariants for each value in the state dictionary. Each
+ # dimension must be a constant or None. A None dimension means either:
+ # 1) the dimension's value is a tensor that remains the same but may
+ # depend on the input sequence to the model (e.g. batch size).
+ # 2) the dimension may have different values on different iterations.
+ if self.padded_decode:
+ state_shape_invariants = {
+ _StateKeys.CUR_INDEX:
+ tf.TensorShape([]),
+ _StateKeys.ALIVE_SEQ:
+ tf.TensorShape(
+ [self.batch_size, self.beam_size,
+ self.max_decode_length + 1]),
+ _StateKeys.ALIVE_LOG_PROBS:
+ tf.TensorShape([self.batch_size, self.beam_size]),
+ _StateKeys.ALIVE_CACHE:
+ nest.map_structure(_get_shape, alive_cache),
+ _StateKeys.FINISHED_SEQ:
+ tf.TensorShape(
+ [self.batch_size, self.beam_size,
+ self.max_decode_length + 1]),
+ _StateKeys.FINISHED_SCORES:
+ tf.TensorShape([self.batch_size, self.beam_size]),
+ _StateKeys.FINISHED_FLAGS:
+ tf.TensorShape([self.batch_size, self.beam_size])
+ }
+ else:
+ state_shape_invariants = {
+ _StateKeys.CUR_INDEX:
+ tf.TensorShape([]),
+ _StateKeys.ALIVE_SEQ:
+ tf.TensorShape([None, self.beam_size, None]),
+ _StateKeys.ALIVE_LOG_PROBS:
+ tf.TensorShape([None, self.beam_size]),
+ _StateKeys.ALIVE_CACHE:
+ nest.map_structure(_get_shape_keep_last_dim, alive_cache),
+ _StateKeys.FINISHED_SEQ:
+ tf.TensorShape([None, self.beam_size, None]),
+ _StateKeys.FINISHED_SCORES:
+ tf.TensorShape([None, self.beam_size]),
+ _StateKeys.FINISHED_FLAGS:
+ tf.TensorShape([None, self.beam_size])
+ }
+
+ return state, state_shape_invariants
+
+ def _continue_search(self, state):
+ """Return whether to continue the search loop.
+
+ The loops should terminate when
+ 1) when decode length has been reached, or
+ 2) when the worst score in the finished sequences is better than the best
+ score in the alive sequences (i.e. the finished sequences are provably
+ unchanging)
+
+ Args:
+ state: A dictionary with the current loop state.
+
+ Returns:
+ Bool tensor with value True if loop should continue, False if loop should
+ terminate.
+ """
+ i = state[_StateKeys.CUR_INDEX]
+ alive_log_probs = state[_StateKeys.ALIVE_LOG_PROBS]
+ finished_scores = state[_StateKeys.FINISHED_SCORES]
+ finished_flags = state[_StateKeys.FINISHED_FLAGS]
+
+ not_at_max_decode_length = tf.less(i, self.max_decode_length)
+
+ # Calculate largest length penalty (the larger penalty, the better score).
+ max_length_norm = _length_normalization(self.alpha, self.max_decode_length,
+ dtype=self.dtype)
+ # Get the best possible scores from alive sequences.
+ best_alive_scores = alive_log_probs[:, 0] / max_length_norm
+
+ # Compute worst score in finished sequences for each batch element
+ finished_scores *= tf.cast(finished_flags,
+ self.dtype) # set filler scores to zero
+ lowest_finished_scores = tf.reduce_min(finished_scores, axis=1)
+
+ # If there are no finished sequences in a batch element, then set the lowest
+ # finished score to -INF for that element.
+ finished_batches = tf.reduce_any(finished_flags, 1)
+ lowest_finished_scores += ((1.0 -
+ tf.cast(finished_batches, self.dtype)) *
+ -inf(self.dtype))
+
+ worst_finished_score_better_than_best_alive_score = tf.reduce_all(
+ tf.greater(lowest_finished_scores, best_alive_scores)
+ )
+
+ return tf.logical_and(
+ not_at_max_decode_length,
+ tf.logical_not(worst_finished_score_better_than_best_alive_score)
+ )
+
+ def _search_step(self, state):
+ """Beam search loop body.
+
+ Grow alive sequences by a single ID. Sequences that have reached the EOS
+ token are marked as finished. The alive and finished sequences with the
+ highest log probabilities and scores are returned.
+
+    A sequence's finished score is calculated by dividing the log probability
+ by the length normalization factor. Without length normalization, the
+ search is more likely to return shorter sequences.
+
+ Args:
+ state: A dictionary with the current loop state.
+
+ Returns:
+ new state dictionary.
+ """
+ # Grow alive sequences by one token.
+ new_seq, new_log_probs, topk_ids, new_cache = self._grow_alive_seq(state)
+ new_finished_flags = tf.equal(topk_ids, self.eos_id)
+ # Collect top beam_size alive sequences
+ alive_state = self._get_new_alive_state(new_seq, new_log_probs,
+ new_finished_flags, new_cache)
+
+ # Combine newly finished sequences with existing finished sequences, and
+ # collect the top k scoring sequences.
+ finished_state = self._get_new_finished_state(state, new_seq, new_log_probs,
+ new_finished_flags)
+
+ # Increment loop index and create new state dictionary
+ new_state = {_StateKeys.CUR_INDEX: state[_StateKeys.CUR_INDEX] + 1}
+ new_state.update(alive_state)
+ new_state.update(finished_state)
+ return [new_state]
+
+ def _grow_alive_seq(self, state):
+ """Grow alive sequences by one token, and collect top 2*beam_size sequences.
+
+ 2*beam_size sequences are collected because some sequences may have reached
+ the EOS token. 2*beam_size ensures that at least beam_size sequences are
+ still alive.
+
+ Args:
+ state: A dictionary with the current loop state.
+ Returns:
+ Tuple of
+        (Top 2*beam_size sequences [batch_size, 2 * beam_size, cur_index + 1],
+         Scores of returned sequences [batch_size, 2 * beam_size],
+         The most recently appended ids [batch_size, 2 * beam_size],
+         New alive cache, for each of the 2 * beam_size sequences)
+ """
+ i = state[_StateKeys.CUR_INDEX]
+ alive_seq = state[_StateKeys.ALIVE_SEQ]
+ alive_log_probs = state[_StateKeys.ALIVE_LOG_PROBS]
+ alive_cache = state[_StateKeys.ALIVE_CACHE]
+
+ beams_to_keep = 2 * self.beam_size
+
+ # Get logits for the next candidate IDs for the alive sequences. Get the new
+ # cache values at the same time.
+ if self.padded_decode:
+ flat_ids = tf.reshape(
+ tf.slice(alive_seq, [0, 0, i], [self.batch_size, self.beam_size, 1]),
+ [self.batch_size * self.beam_size, -1])
+ else:
+ flat_ids = _flatten_beam_dim(alive_seq) # [batch_size * beam_size]
+ flat_cache = nest.map_structure(_flatten_beam_dim, alive_cache)
+
+ flat_logits, flat_cache = self.symbols_to_logits_fn(flat_ids, i, flat_cache)
+
+ # Unflatten logits to shape [batch_size, beam_size, vocab_size]
+ logits = _unflatten_beam_dim(flat_logits, self.batch_size, self.beam_size)
+ new_cache = nest.map_structure(
+ lambda t: _unflatten_beam_dim(t, self.batch_size, self.beam_size),
+ flat_cache)
+
+ # Convert logits to normalized log probs
+ candidate_log_probs = _log_prob_from_logits(logits)
+
+ # Calculate new log probabilities if each of the alive sequences were
+    # extended by the candidate IDs.
+ # Shape [batch_size, beam_size, vocab_size]
+ log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, axis=2)
+
+ # Each batch item has beam_size * vocab_size candidate sequences. For each
+ # batch item, get the k candidates with the highest log probabilities.
+ flat_log_probs = tf.reshape(log_probs,
+ [-1, self.beam_size * self.vocab_size])
+ topk_log_probs, topk_indices = tf.nn.top_k(flat_log_probs, k=beams_to_keep)
+
+ # Extract the alive sequences that generate the highest log probabilities
+ # after being extended.
+ topk_beam_indices = topk_indices // self.vocab_size
+ topk_seq, new_cache = _gather_beams(
+ [alive_seq, new_cache], topk_beam_indices, self.batch_size,
+ beams_to_keep)
+
+ # Append the most probable IDs to the topk sequences
+ topk_ids = topk_indices % self.vocab_size
+ if self.padded_decode:
+ topk_seq = tf.transpose(topk_seq, perm=[2, 0, 1])
+ # TODO(b/145533236, hongkuny): Reverts once TF fix the validation.
+ topk_seq = tf.tensor_scatter_nd_update(topk_seq, [[i + 1]],
+ tf.expand_dims(topk_ids, axis=0))
+ topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])
+ else:
+ topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)
+ return topk_seq, topk_log_probs, topk_ids, new_cache
+
+ def _get_new_alive_state(self, new_seq, new_log_probs, new_finished_flags,
+ new_cache):
+ """Gather the top k sequences that are still alive.
+
+ Args:
+ new_seq: New sequences generated by growing the current alive sequences
+ int32 tensor with shape [batch_size, 2 * beam_size, cur_index + 1]
+ new_log_probs: Log probabilities of new sequences; float32 tensor with
+ shape [batch_size, beam_size].
+ new_finished_flags: A boolean Tensor indicating which of the new
+ sequences have reached the EOS token (i.e. are finished).
+ new_cache: Dict of cached values for each sequence.
+
+ Returns:
+ Dictionary with alive keys from _StateKeys:
+ {Top beam_size sequences that are still alive (don't end with eos_id)
+ Log probabilities of top alive sequences
+ Dict cache storing decoder states for top alive sequences}
+ """
+ # To prevent finished sequences from being considered, set log probs to -inf
+ new_log_probs += tf.cast(new_finished_flags, self.dtype) * -inf(self.dtype)
+
+ top_alive_seq, top_alive_log_probs, top_alive_cache = _gather_topk_beams(
+ [new_seq, new_log_probs, new_cache], new_log_probs, self.batch_size,
+ self.beam_size)
+
+ return {
+ _StateKeys.ALIVE_SEQ: top_alive_seq,
+ _StateKeys.ALIVE_LOG_PROBS: top_alive_log_probs,
+ _StateKeys.ALIVE_CACHE: top_alive_cache
+ }
+
+ def _get_new_finished_state(self, state, new_seq, new_log_probs,
+ new_finished_flags):
+ """Combine new and old finished sequences, and gather the top k sequences.
+
+ Args:
+ state: A dictionary with the current loop state.
+ new_seq: New sequences generated by growing the current alive sequences
+ int32 tensor with shape [batch_size, beam_size, i + 1]
+ new_log_probs: Log probabilities of new sequences; float32 tensor with
+ shape [batch_size, beam_size].
+ new_finished_flags: A boolean Tensor indicating which of the new
+ sequences have reached the EOS token (i.e. are finished).
+
+ Returns:
+ Dictionary with finished keys from _StateKeys:
+ {Top beam_size finished sequences based on score,
+ Scores of finished sequences,
+ Finished flags of finished sequences}
+ """
+ i = state[_StateKeys.CUR_INDEX]
+ finished_seq = state[_StateKeys.FINISHED_SEQ]
+ finished_scores = state[_StateKeys.FINISHED_SCORES]
+ finished_flags = state[_StateKeys.FINISHED_FLAGS]
+
+ # First append a column of 0-ids to finished_seq to increment the length.
+ # New shape of finished_seq: [batch_size, beam_size, i + 1]
+ if not self.padded_decode:
+ finished_seq = tf.concat([
+ finished_seq,
+ tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)
+ ],
+ axis=2)
+
+ # Calculate new seq scores from log probabilities.
+ length_norm = _length_normalization(self.alpha, i + 1, dtype=self.dtype)
+ new_scores = new_log_probs / length_norm
+
+ # Set the scores of the still-alive seq in new_seq to large negative values.
+ new_scores += ((1. - tf.cast(new_finished_flags, self.dtype)) *
+ -inf(self.dtype))
+
+ # Combine sequences, scores, and flags.
+ finished_seq = tf.concat([finished_seq, new_seq], axis=1)
+ finished_scores = tf.concat([finished_scores, new_scores], axis=1)
+ finished_flags = tf.concat([finished_flags, new_finished_flags], axis=1)
+
+ # Return the finished sequences with the best scores.
+ top_finished_seq, top_finished_scores, top_finished_flags = (
+ _gather_topk_beams([finished_seq, finished_scores, finished_flags],
+ finished_scores, self.batch_size, self.beam_size))
+
+ return {
+ _StateKeys.FINISHED_SEQ: top_finished_seq,
+ _StateKeys.FINISHED_SCORES: top_finished_scores,
+ _StateKeys.FINISHED_FLAGS: top_finished_flags
+ }
+
+
+def sequence_beam_search(
+ symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
+ alpha, max_decode_length, eos_id, padded_decode=False):
+ """Search for sequence of subtoken ids with the largest probability.
+
+ Args:
+ symbols_to_logits_fn: A function that takes in ids, index, and cache as
+ arguments. The passed in arguments will have shape:
+ ids -> A tensor with shape [batch_size * beam_size, index].
+ index -> A scalar.
+ cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
+ The function must return a tuple of logits and new cache:
+ logits -> A tensor with shape [batch * beam_size, vocab_size].
+ new cache -> A nested dictionary with the same shape/structure as the
+ inputted cache.
+ initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
+ each batch item.
+ initial_cache: A dictionary, containing starting decoder variables
+ information.
+ vocab_size: An integer, the size of the vocabulary, used for topk
+ computation.
+ beam_size: An integer, the number of beams.
+ alpha: A float, defining the strength of length normalization.
+ max_decode_length: An integer, the maximum length to decode a sequence.
+ eos_id: An integer, ID of eos token, used to determine when a sequence has
+ finished.
+ padded_decode: A bool, indicating if max_sequence_length padding is used
+ for beam search.
+
+ Returns:
+ Top decoded sequences [batch_size, beam_size, max_decode_length]
+ sequence scores [batch_size, beam_size]
+ """
+ batch_size = (
+ initial_ids.shape.as_list()[0] if padded_decode else
+ tf.shape(initial_ids)[0])
+ sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
+ beam_size, alpha, max_decode_length, eos_id,
+ padded_decode)
+ return sbs.search(initial_ids, initial_cache)
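+
+# Editor's sketch (illustrative usage, not part of the original file): a toy
+# call with a symbols_to_logits_fn that ignores its cache and always returns
+# uniform logits over a vocabulary of 5 tokens; all values below are made up.
+#
+#   def toy_symbols_to_logits(ids, index, cache):
+#       return tf.zeros([tf.shape(ids)[0], 5]), cache
+#
+#   seqs, scores = sequence_beam_search(
+#       symbols_to_logits_fn=toy_symbols_to_logits,
+#       initial_ids=tf.zeros([2], dtype=tf.int32),  # batch_size = 2
+#       initial_cache={}, vocab_size=5, beam_size=3, alpha=0.6,
+#       max_decode_length=4, eos_id=1)
+#   # seqs: [batch_size, beam_size, decoded_length], scores: [batch_size, beam_size]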
+
+
+def _log_prob_from_logits(logits):
+ return logits - tf.reduce_logsumexp(logits, axis=2, keepdims=True)
+
+
+def _length_normalization(alpha, length, dtype=tf.float32):
+ """Return length normalization factor."""
+ return tf.pow(((5. + tf.cast(length, dtype)) / 6.), alpha)
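+# Editor's note (illustrative arithmetic, not part of the original file): with
+# alpha = 0.6, this GNMT-style penalty grows slowly with length, e.g.
+# ((5 + 10) / 6) ** 0.6 ~= 1.73 and ((5 + 50) / 6) ** 0.6 ~= 3.78, so dividing
+# a finished sequence's log probability by this factor keeps longer hypotheses
+# competitive instead of always favoring the shortest one.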
+
+
+def _expand_to_beam_size(tensor, beam_size):
+ """Tiles a given tensor by beam_size.
+
+ Args:
+ tensor: tensor to tile [batch_size, ...]
+ beam_size: How much to tile the tensor by.
+
+ Returns:
+ Tiled tensor [batch_size, beam_size, ...]
+ """
+ tensor = tf.expand_dims(tensor, axis=1)
+ tile_dims = [1] * tensor.shape.ndims
+ tile_dims[1] = beam_size
+
+ return tf.tile(tensor, tile_dims)
+
+
+def _shape_list(tensor):
+ """Return a list of the tensor's shape, and ensure no None values in list."""
+ # Get statically known shape (may contain None's for unknown dimensions)
+ shape = tensor.get_shape().as_list()
+
+ # Ensure that the shape values are not None
+ dynamic_shape = tf.shape(tensor)
+ for i in range(len(shape)): # pylint: disable=consider-using-enumerate
+ if shape[i] is None:
+ shape[i] = dynamic_shape[i]
+ return shape
+
+
+def _get_shape_keep_last_dim(tensor):
+ shape_list = _shape_list(tensor)
+
+ # Set every dimension except the last to None; the last is kept only if it
+ # is statically known.
+ for i in range(len(shape_list) - 1):
+ shape_list[i] = None
+
+ if isinstance(shape_list[-1], tf.Tensor):
+ shape_list[-1] = None
+ return tf.TensorShape(shape_list)
+
+
+def _get_shape(tensor):
+ """Return the shape of the input tensor."""
+ return tf.TensorShape(_shape_list(tensor))
+
+
+def _flatten_beam_dim(tensor):
+ """Reshapes first two dimensions in to single dimension.
+
+ Args:
+ tensor: Tensor to reshape of shape [A, B, ...]
+
+ Returns:
+ Reshaped tensor of shape [A*B, ...]
+ """
+ shape = _shape_list(tensor)
+ shape[0] *= shape[1]
+ shape.pop(1) # Remove beam dim
+ return tf.reshape(tensor, shape)
+
+
+def _unflatten_beam_dim(tensor, batch_size, beam_size):
+ """Reshapes first dimension back to [batch_size, beam_size].
+
+ Args:
+ tensor: Tensor to reshape of shape [batch_size*beam_size, ...]
+ batch_size: Tensor, original batch size.
+ beam_size: int, original beam size.
+
+ Returns:
+ Reshaped tensor of shape [batch_size, beam_size, ...]
+ """
+ shape = _shape_list(tensor)
+ new_shape = [batch_size, beam_size] + shape[1:]
+ return tf.reshape(tensor, new_shape)
+
+
+def _gather_beams(nested, beam_indices, batch_size, new_beam_size):
+ """Gather beams from nested structure of tensors.
+
+ Each tensor in nested represents a batch of beams, where beam refers to a
+ single search state (beam search involves searching through multiple states
+ in parallel).
+
+ This function is used to gather the top beams, specified by
+ beam_indices, from the nested tensors.
+
+ Args:
+ nested: Nested structure (tensor, list, tuple or dict) containing tensors
+ with shape [batch_size, beam_size, ...].
+ beam_indices: int32 tensor with shape [batch_size, new_beam_size]. Each
+ value in beam_indices must be between [0, beam_size), and are not
+ necessarily unique.
+ batch_size: int size of batch
+ new_beam_size: int number of beams to be pulled from the nested tensors.
+
+ Returns:
+ Nested structure containing tensors with shape
+ [batch_size, new_beam_size, ...]
+ """
+ # Compute the i'th coordinate that contains the batch index for gather_nd.
+ # Batch pos is a tensor like [[0,0,0,0,],[1,1,1,1],..].
+ batch_pos = tf.range(batch_size * new_beam_size) // new_beam_size
+ batch_pos = tf.reshape(batch_pos, [batch_size, new_beam_size])
+
+ # Create coordinates to be passed to tf.gather_nd. Stacking creates a tensor
+ # with shape [batch_size, beam_size, 2], where the last dimension contains
+ # the (i, j) gathering coordinates.
+ coordinates = tf.stack([batch_pos, beam_indices], axis=2)
+
+ return nest.map_structure(
+ lambda state: tf.gather_nd(state, coordinates), nested)
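+# Editor's note (worked example, not part of the original file): with
+# batch_size = 2 and new_beam_size = 2, batch_pos is [[0, 0], [1, 1]]; stacked
+# with beam_indices = [[1, 2], [0, 2]] it yields coordinates
+# [[[0, 1], [0, 2]], [[1, 0], [1, 2]]], i.e. take beams 1 and 2 from batch
+# item 0 and beams 0 and 2 from batch item 1.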
+
+
+def _gather_topk_beams(nested, score_or_log_prob, batch_size, beam_size):
+ """Gather top beams from nested structure."""
+ _, topk_indexes = tf.nn.top_k(score_or_log_prob, k=beam_size)
+ return _gather_beams(nested, topk_indexes, batch_size, beam_size)
diff --git a/models/official/nlp/transformer/beam_search_v1_test.py b/models/official/nlp/transformer/beam_search_v1_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..53cf921fb90e93950a05e999807fc497390674a1
--- /dev/null
+++ b/models/official/nlp/transformer/beam_search_v1_test.py
@@ -0,0 +1,101 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test beam search helper methods."""
+
+import tensorflow.compat.v1 as tf
+
+from official.nlp.transformer import beam_search_v1 as beam_search
+
+
+class BeamSearchHelperTests(tf.test.TestCase):
+
+ def setUp(self):
+ super(BeamSearchHelperTests, self).setUp()
+ tf.compat.v1.disable_eager_execution()
+
+ def test_expand_to_beam_size(self):
+ x = tf.ones([7, 4, 2, 5])
+ x = beam_search._expand_to_beam_size(x, 3)
+ with self.session() as sess:
+ shape = sess.run(tf.shape(x))
+ self.assertAllEqual([7, 3, 4, 2, 5], shape)
+
+ def test_shape_list(self):
+ y = tf.compat.v1.placeholder(dtype=tf.int32, shape=[])
+ x = tf.ones([7, y, 2, 5])
+ shape = beam_search._shape_list(x)
+ self.assertIsInstance(shape[0], int)
+ self.assertIsInstance(shape[1], tf.Tensor)
+ self.assertIsInstance(shape[2], int)
+ self.assertIsInstance(shape[3], int)
+
+ def test_get_shape_keep_last_dim(self):
+ y = tf.constant(4.0)
+ x = tf.ones([7, tf.cast(tf.sqrt(y), tf.int32), 2, 5])
+ shape = beam_search._get_shape_keep_last_dim(x)
+ self.assertAllEqual([None, None, None, 5],
+ shape.as_list())
+
+ def test_flatten_beam_dim(self):
+ x = tf.ones([7, 4, 2, 5])
+ x = beam_search._flatten_beam_dim(x)
+ with self.session() as sess:
+ shape = sess.run(tf.shape(x))
+ self.assertAllEqual([28, 2, 5], shape)
+
+ def test_unflatten_beam_dim(self):
+ x = tf.ones([28, 2, 5])
+ x = beam_search._unflatten_beam_dim(x, 7, 4)
+ with self.session() as sess:
+ shape = sess.run(tf.shape(x))
+ self.assertAllEqual([7, 4, 2, 5], shape)
+
+ def test_gather_beams(self):
+ x = tf.reshape(tf.range(24), [2, 3, 4])
+ # x looks like: [[[ 0 1 2 3]
+ # [ 4 5 6 7]
+ # [ 8 9 10 11]]
+ #
+ # [[12 13 14 15]
+ # [16 17 18 19]
+ # [20 21 22 23]]]
+
+ y = beam_search._gather_beams(x, [[1, 2], [0, 2]], 2, 2)
+ with self.session() as sess:
+ y = sess.run(y)
+
+ self.assertAllEqual([[[4, 5, 6, 7],
+ [8, 9, 10, 11]],
+ [[12, 13, 14, 15],
+ [20, 21, 22, 23]]],
+ y)
+
+ def test_gather_topk_beams(self):
+ x = tf.reshape(tf.range(24), [2, 3, 4])
+ x_scores = [[0, 1, 1], [1, 0, 1]]
+
+ y = beam_search._gather_topk_beams(x, x_scores, 2, 2)
+ with self.session() as sess:
+ y = sess.run(y)
+
+ self.assertAllEqual([[[4, 5, 6, 7],
+ [8, 9, 10, 11]],
+ [[12, 13, 14, 15],
+ [20, 21, 22, 23]]],
+ y)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/transformer/compute_bleu.py b/models/official/nlp/transformer/compute_bleu.py
new file mode 100644
index 0000000000000000000000000000000000000000..92d54c30ecbc844d271b49f49ed19abc09098abf
--- /dev/null
+++ b/models/official/nlp/transformer/compute_bleu.py
@@ -0,0 +1,148 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to compute official BLEU score.
+
+Source:
+https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+import sys
+import unicodedata
+
+from absl import app as absl_app
+from absl import flags
+import six
+from six.moves import range
+import tensorflow as tf
+
+from official.nlp.transformer.utils import metrics
+from official.nlp.transformer.utils import tokenizer
+from official.utils.flags import core as flags_core
+
+
+class UnicodeRegex(object):
+ """Ad-hoc hack to recognize all punctuation and symbols."""
+
+ def __init__(self):
+ punctuation = self.property_chars("P")
+ self.nondigit_punct_re = re.compile(r"([^\d])([" + punctuation + r"])")
+ self.punct_nondigit_re = re.compile(r"([" + punctuation + r"])([^\d])")
+ self.symbol_re = re.compile("([" + self.property_chars("S") + "])")
+
+ def property_chars(self, prefix):
+ return "".join(
+ six.unichr(x)
+ for x in range(sys.maxunicode)
+ if unicodedata.category(six.unichr(x)).startswith(prefix))
+
+
+uregex = UnicodeRegex()
+
+
+def bleu_tokenize(string):
+ r"""Tokenize a string following the official BLEU implementation.
+
+ See https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983
+ In our case, the input string is expected to be just one line
+ and no HTML entities de-escaping is needed.
+ So we just tokenize on punctuation and symbols,
+ except when a punctuation is preceded and followed by a digit
+ (e.g. a comma/dot as a thousand/decimal separator).
+
+ Note that a number (e.g. a year) followed by a dot at the end of sentence
+ is NOT tokenized,
+ i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g`
+ does not match this case (unless we add a space after each sentence).
+ However, this error is already in the original mteval-v14.pl
+ and we want to be consistent with it.
+
+ Args:
+ string: the input string
+
+ Returns:
+ a list of tokens
+ """
+ string = uregex.nondigit_punct_re.sub(r"\1 \2 ", string)
+ string = uregex.punct_nondigit_re.sub(r" \1 \2", string)
+ string = uregex.symbol_re.sub(r" \1 ", string)
+ return string.split()
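+# Editor's note (illustrative example, not part of the original file):
+#   bleu_tokenize("Hello, world 1,000.") -> ["Hello", ",", "world", "1,000."]
+# The comma after "Hello" is split off, while the digit-internal comma and the
+# sentence-final dot stay attached to the number, as described above.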
+
+
+def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
+ """Compute BLEU for two files (reference and hypothesis translation)."""
+ ref_lines = tokenizer.native_to_unicode(
+ tf.io.gfile.GFile(ref_filename).read()).strip().splitlines()
+ hyp_lines = tokenizer.native_to_unicode(
+ tf.io.gfile.GFile(hyp_filename).read()).strip().splitlines()
+
+ if len(ref_lines) != len(hyp_lines):
+ raise ValueError(
+ "Reference and translation files have different number of "
+ "lines (%d VS %d). If training only a few steps (100-200), the "
+ "translation may be empty." % (len(ref_lines), len(hyp_lines)))
+ if not case_sensitive:
+ ref_lines = [x.lower() for x in ref_lines]
+ hyp_lines = [x.lower() for x in hyp_lines]
+ ref_tokens = [bleu_tokenize(x) for x in ref_lines]
+ hyp_tokens = [bleu_tokenize(x) for x in hyp_lines]
+ return metrics.compute_bleu(ref_tokens, hyp_tokens) * 100
+
+
+def main(unused_argv):
+ if FLAGS.bleu_variant in ("both", "uncased"):
+ score = bleu_wrapper(FLAGS.reference, FLAGS.translation, False)
+ tf.logging.info("Case-insensitive results: %f" % score)
+
+ if FLAGS.bleu_variant in ("both", "cased"):
+ score = bleu_wrapper(FLAGS.reference, FLAGS.translation, True)
+ tf.logging.info("Case-sensitive results: %f" % score)
+
+
+def define_compute_bleu_flags():
+ """Add flags for computing BLEU score."""
+ flags.DEFINE_string(
+ name="translation",
+ default=None,
+ help=flags_core.help_wrap("File containing translated text."))
+ flags.mark_flag_as_required("translation")
+
+ flags.DEFINE_string(
+ name="reference",
+ default=None,
+ help=flags_core.help_wrap("File containing reference translation."))
+ flags.mark_flag_as_required("reference")
+
+ flags.DEFINE_enum(
+ name="bleu_variant",
+ short_name="bv",
+ default="both",
+ enum_values=["both", "uncased", "cased"],
+ case_sensitive=False,
+ help=flags_core.help_wrap(
+ "Specify one or more BLEU variants to calculate. Variants: \"cased\""
+ ", \"uncased\", or \"both\"."))
+
+
+if __name__ == "__main__":
+ tf.logging.set_verbosity(tf.logging.INFO)
+ define_compute_bleu_flags()
+ FLAGS = flags.FLAGS
+ absl_app.run(main)
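+
+# Editor's note (assumed invocation, not part of the original file):
+#   python compute_bleu.py --reference=newstest2014.de \
+#       --translation=model_output.de --bleu_variant=both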
diff --git a/models/official/nlp/transformer/compute_bleu_test.py b/models/official/nlp/transformer/compute_bleu_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c578e3698a7e6cc2d7170433d4565cd3d8091ed
--- /dev/null
+++ b/models/official/nlp/transformer/compute_bleu_test.py
@@ -0,0 +1,64 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test functions in compute_blue.py."""
+
+import tempfile
+
+import tensorflow as tf
+
+from official.nlp.transformer import compute_bleu
+
+
+class ComputeBleuTest(tf.test.TestCase):
+
+ def _create_temp_file(self, text):
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
+ with tf.io.gfile.GFile(temp_file.name, "w") as w:
+ w.write(text)
+ return temp_file.name
+
+ def test_bleu_same(self):
+ ref = self._create_temp_file("test 1 two 3\nmore tests!")
+ hyp = self._create_temp_file("test 1 two 3\nmore tests!")
+
+ uncased_score = compute_bleu.bleu_wrapper(ref, hyp, False)
+ cased_score = compute_bleu.bleu_wrapper(ref, hyp, True)
+ self.assertEqual(100, uncased_score)
+ self.assertEqual(100, cased_score)
+
+ def test_bleu_same_different_case(self):
+ ref = self._create_temp_file("Test 1 two 3\nmore tests!")
+ hyp = self._create_temp_file("test 1 two 3\nMore tests!")
+ uncased_score = compute_bleu.bleu_wrapper(ref, hyp, False)
+ cased_score = compute_bleu.bleu_wrapper(ref, hyp, True)
+ self.assertEqual(100, uncased_score)
+ self.assertLess(cased_score, 100)
+
+ def test_bleu_different(self):
+ ref = self._create_temp_file("Testing\nmore tests!")
+ hyp = self._create_temp_file("Dog\nCat")
+ uncased_score = compute_bleu.bleu_wrapper(ref, hyp, False)
+ cased_score = compute_bleu.bleu_wrapper(ref, hyp, True)
+ self.assertLess(uncased_score, 100)
+ self.assertLess(cased_score, 100)
+
+ def test_bleu_tokenize(self):
+ s = "Test0, 1 two, 3"
+ tokenized = compute_bleu.bleu_tokenize(s)
+ self.assertEqual(["Test0", ",", "1", "two", ",", "3"], tokenized)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/transformer/data_download.py b/models/official/nlp/transformer/data_download.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5f66685611e1ad379d05dcf321a679527914b19
--- /dev/null
+++ b/models/official/nlp/transformer/data_download.py
@@ -0,0 +1,439 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Download and preprocess WMT17 ende training and evaluation datasets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import random
+import tarfile
+
+# pylint: disable=g-bad-import-order
+from absl import app as absl_app
+from absl import flags
+from absl import logging
+import six
+from six.moves import range
+from six.moves import urllib
+from six.moves import zip
+import tensorflow.compat.v1 as tf
+
+from official.nlp.transformer.utils import tokenizer
+from official.utils.flags import core as flags_core
+# pylint: enable=g-bad-import-order
+
+# Data sources for training/evaluating the transformer translation model.
+# If any of the training sources are changed, then either:
+# 1) use the flag `--search` to find the best min count or
+# 2) update the _TRAIN_DATA_MIN_COUNT constant.
+# min_count is the minimum number of times a token must appear in the data
+# before it is added to the vocabulary. "Best min count" refers to the value
+# that generates a vocabulary set that is closest in size to _TARGET_VOCAB_SIZE.
+_TRAIN_DATA_SOURCES = [
+ {
+ "url": "http://data.statmt.org/wmt17/translation-task/"
+ "training-parallel-nc-v12.tgz",
+ "input": "news-commentary-v12.de-en.en",
+ "target": "news-commentary-v12.de-en.de",
+ },
+ {
+ "url": "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz",
+ "input": "commoncrawl.de-en.en",
+ "target": "commoncrawl.de-en.de",
+ },
+ {
+ "url": "http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz",
+ "input": "europarl-v7.de-en.en",
+ "target": "europarl-v7.de-en.de",
+ },
+]
+# Use pre-defined minimum count to generate subtoken vocabulary.
+_TRAIN_DATA_MIN_COUNT = 6
+
+_EVAL_DATA_SOURCES = [
+ {
+ "url": "http://data.statmt.org/wmt17/translation-task/dev.tgz",
+ "input": "newstest2013.en",
+ "target": "newstest2013.de",
+ }
+]
+
+_TEST_DATA_SOURCES = [
+ {
+ "url": ("https://storage.googleapis.com/tf-perf-public/"
+ "official_transformer/test_data/newstest2014.tgz"),
+ "input": "newstest2014.en",
+ "target": "newstest2014.de",
+ }
+]
+
+# Vocabulary constants
+_TARGET_VOCAB_SIZE = 32768 # Number of subtokens in the vocabulary list.
+_TARGET_THRESHOLD = 327 # Accept vocabulary if size is within this threshold
+VOCAB_FILE = "vocab.ende.%d" % _TARGET_VOCAB_SIZE
+
+# Strings to include in the generated files.
+_PREFIX = "wmt32k"
+_TRAIN_TAG = "train"
+_EVAL_TAG = "dev" # Following WMT and Tensor2Tensor conventions, in which the
+# evaluation datasets are tagged as "dev" for development.
+
+# Number of files to split train and evaluation data
+_TRAIN_SHARDS = 100
+_EVAL_SHARDS = 1
+
+
+def find_file(path, filename, max_depth=5):
+ """Returns full filepath if the file is in path or a subdirectory."""
+ for root, dirs, files in os.walk(path):
+ if filename in files:
+ return os.path.join(root, filename)
+
+ # Don't search past max_depth
+ depth = root[len(path) + 1:].count(os.sep)
+ if depth > max_depth:
+ del dirs[:] # Clear dirs
+ return None
+
+
+###############################################################################
+# Download and extraction functions
+###############################################################################
+def get_raw_files(raw_dir, data_source):
+ """Return raw files from source. Downloads/extracts if needed.
+
+ Args:
+ raw_dir: string directory to store raw files
+ data_source: dictionary with
+ {"url": url of compressed dataset containing input and target files
+ "input": file with data in input language
+ "target": file with data in target language}
+
+ Returns:
+ dictionary with
+ {"inputs": list of files containing data in input language
+ "targets": list of files containing corresponding data in target language
+ }
+ """
+ raw_files = {
+ "inputs": [],
+ "targets": [],
+ }
+ for d in data_source:
+ input_file, target_file = download_and_extract(
+ raw_dir, d["url"], d["input"], d["target"])
+ raw_files["inputs"].append(input_file)
+ raw_files["targets"].append(target_file)
+ return raw_files
+
+
+def download_report_hook(count, block_size, total_size):
+ """Report hook for download progress.
+
+ Args:
+ count: current block number
+ block_size: block size
+ total_size: total size
+ """
+ percent = int(count * block_size * 100 / total_size)
+ print(six.ensure_str("\r%d%%" % percent) + " completed", end="\r")
+
+
+def download_from_url(path, url):
+ """Download content from a url.
+
+ Args:
+ path: string directory where file will be downloaded
+ url: string url
+
+ Returns:
+ Full path to downloaded file
+ """
+ filename = six.ensure_str(url).split("/")[-1]
+ found_file = find_file(path, filename, max_depth=0)
+ if found_file is None:
+ filename = os.path.join(path, filename)
+ logging.info("Downloading from %s to %s." % (url, filename))
+ inprogress_filepath = six.ensure_str(filename) + ".incomplete"
+ inprogress_filepath, _ = urllib.request.urlretrieve(
+ url, inprogress_filepath, reporthook=download_report_hook)
+ # Print newline to clear the carriage return from the download progress.
+ print()
+ tf.gfile.Rename(inprogress_filepath, filename)
+ return filename
+ else:
+ logging.info("Already downloaded: %s (at %s)." % (url, found_file))
+ return found_file
+
+
+def download_and_extract(path, url, input_filename, target_filename):
+ """Extract files from downloaded compressed archive file.
+
+ Args:
+ path: string directory where the files will be downloaded
+ url: url containing the compressed input and target files
+ input_filename: name of file containing data in source language
+ target_filename: name of file containing data in target language
+
+ Returns:
+ Full paths to extracted input and target files.
+
+ Raises:
+ OSError: if the download/extraction fails.
+ """
+ # Check if extracted files already exist in path
+ input_file = find_file(path, input_filename)
+ target_file = find_file(path, target_filename)
+ if input_file and target_file:
+ logging.info("Already downloaded and extracted %s." % url)
+ return input_file, target_file
+
+ # Download archive file if it doesn't already exist.
+ compressed_file = download_from_url(path, url)
+
+ # Extract compressed files
+ logging.info("Extracting %s." % compressed_file)
+ with tarfile.open(compressed_file, "r:gz") as corpus_tar:
+ corpus_tar.extractall(path)
+
+ # Return file paths of the requested files.
+ input_file = find_file(path, input_filename)
+ target_file = find_file(path, target_filename)
+
+ if input_file and target_file:
+ return input_file, target_file
+
+ raise OSError("Download/extraction failed for url %s to path %s" %
+ (url, path))
+
+
+def txt_line_iterator(path):
+ """Iterate through lines of file."""
+ with tf.io.gfile.GFile(path) as f:
+ for line in f:
+ yield line.strip()
+
+
+def compile_files(raw_dir, raw_files, tag):
+ """Compile raw files into a single file for each language.
+
+ Args:
+ raw_dir: Directory containing downloaded raw files.
+ raw_files: Dict containing filenames of input and target data.
+ {"inputs": list of files containing data in input language
+ "targets": list of files containing corresponding data in target language
+ }
+ tag: String to append to the compiled filename.
+
+ Returns:
+ Full path of compiled input and target files.
+ """
+ logging.info("Compiling files with tag %s." % tag)
+ filename = "%s-%s" % (_PREFIX, tag)
+ input_compiled_file = os.path.join(raw_dir,
+ six.ensure_str(filename) + ".lang1")
+ target_compiled_file = os.path.join(raw_dir,
+ six.ensure_str(filename) + ".lang2")
+
+ with tf.io.gfile.GFile(input_compiled_file, mode="w") as input_writer:
+ with tf.io.gfile.GFile(target_compiled_file, mode="w") as target_writer:
+ for i in range(len(raw_files["inputs"])):
+ input_file = raw_files["inputs"][i]
+ target_file = raw_files["targets"][i]
+
+ logging.info("Reading files %s and %s." % (input_file, target_file))
+ write_file(input_writer, input_file)
+ write_file(target_writer, target_file)
+ return input_compiled_file, target_compiled_file
+
+
+def write_file(writer, filename):
+ """Write all of lines from file using the writer."""
+ for line in txt_line_iterator(filename):
+ writer.write(line)
+ writer.write("\n")
+
+
+###############################################################################
+# Data preprocessing
+###############################################################################
+def encode_and_save_files(
+ subtokenizer, data_dir, raw_files, tag, total_shards):
+ """Save data from files as encoded Examples in TFrecord format.
+
+ Args:
+ subtokenizer: Subtokenizer object that will be used to encode the strings.
+ data_dir: The directory in which to write the examples
+ raw_files: A tuple of (input, target) data files. Each line in the input and
+ the corresponding line in target file will be saved in a tf.Example.
+ tag: String that will be added onto the file names.
+ total_shards: Number of files to divide the data into.
+
+ Returns:
+ List of all files produced.
+ """
+ # Create a file for each shard.
+ filepaths = [shard_filename(data_dir, tag, n + 1, total_shards)
+ for n in range(total_shards)]
+
+ if all_exist(filepaths):
+ logging.info("Files with tag %s already exist." % tag)
+ return filepaths
+
+ logging.info("Saving files with tag %s." % tag)
+ input_file = raw_files[0]
+ target_file = raw_files[1]
+
+ # Write examples to each shard in round robin order.
+ tmp_filepaths = [six.ensure_str(fname) + ".incomplete" for fname in filepaths]
+ writers = [tf.python_io.TFRecordWriter(fname) for fname in tmp_filepaths]
+ counter, shard = 0, 0
+ for counter, (input_line, target_line) in enumerate(zip(
+ txt_line_iterator(input_file), txt_line_iterator(target_file))):
+ if counter > 0 and counter % 100000 == 0:
+ logging.info("\tSaving case %d." % counter)
+ example = dict_to_example(
+ {"inputs": subtokenizer.encode(input_line, add_eos=True),
+ "targets": subtokenizer.encode(target_line, add_eos=True)})
+ writers[shard].write(example.SerializeToString())
+ shard = (shard + 1) % total_shards
+ for writer in writers:
+ writer.close()
+
+ for tmp_name, final_name in zip(tmp_filepaths, filepaths):
+ tf.gfile.Rename(tmp_name, final_name)
+
+ logging.info("Saved %d Examples", counter + 1)
+ return filepaths
+
+
+def shard_filename(path, tag, shard_num, total_shards):
+ """Create filename for data shard."""
+ return os.path.join(
+ path, "%s-%s-%.5d-of-%.5d" % (_PREFIX, tag, shard_num, total_shards))
+
+
+def shuffle_records(fname):
+ """Shuffle records in a single file."""
+ logging.info("Shuffling records in file %s" % fname)
+
+ # Rename file prior to shuffling
+ tmp_fname = six.ensure_str(fname) + ".unshuffled"
+ tf.gfile.Rename(fname, tmp_fname)
+
+ reader = tf.io.tf_record_iterator(tmp_fname)
+ records = []
+ for record in reader:
+ records.append(record)
+ if len(records) % 100000 == 0:
+ logging.info("\tRead: %d", len(records))
+
+ random.shuffle(records)
+
+ # Write shuffled records to original file name
+ with tf.python_io.TFRecordWriter(fname) as w:
+ for count, record in enumerate(records):
+ w.write(record)
+ if count > 0 and count % 100000 == 0:
+ logging.info("\tWriting record: %d" % count)
+
+ tf.gfile.Remove(tmp_fname)
+
+
+def dict_to_example(dictionary):
+ """Converts a dictionary of string->int to a tf.Example."""
+ features = {}
+ for k, v in six.iteritems(dictionary):
+ features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=v))
+ return tf.train.Example(features=tf.train.Features(feature=features))
+
+
+def all_exist(filepaths):
+ """Returns true if all files in the list exist."""
+ for fname in filepaths:
+ if not tf.gfile.Exists(fname):
+ return False
+ return True
+
+
+def make_dir(path):
+ if not tf.gfile.Exists(path):
+ logging.info("Creating directory %s" % path)
+ tf.gfile.MakeDirs(path)
+
+
+def main(unused_argv):
+ """Obtain training and evaluation data for the Transformer model."""
+ make_dir(FLAGS.raw_dir)
+ make_dir(FLAGS.data_dir)
+
+ # Download test_data
+ logging.info("Step 1/5: Downloading test data")
+ get_raw_files(FLAGS.data_dir, _TEST_DATA_SOURCES)
+
+ # Get paths of download/extracted training and evaluation files.
+ logging.info("Step 2/5: Downloading data from source")
+ train_files = get_raw_files(FLAGS.raw_dir, _TRAIN_DATA_SOURCES)
+ eval_files = get_raw_files(FLAGS.raw_dir, _EVAL_DATA_SOURCES)
+
+ # Create subtokenizer based on the training files.
+ logging.info("Step 3/5: Creating subtokenizer and building vocabulary")
+ train_files_flat = train_files["inputs"] + train_files["targets"]
+ vocab_file = os.path.join(FLAGS.data_dir, VOCAB_FILE)
+ subtokenizer = tokenizer.Subtokenizer.init_from_files(
+ vocab_file, train_files_flat, _TARGET_VOCAB_SIZE, _TARGET_THRESHOLD,
+ min_count=None if FLAGS.search else _TRAIN_DATA_MIN_COUNT)
+
+ logging.info("Step 4/5: Compiling training and evaluation data")
+ compiled_train_files = compile_files(FLAGS.raw_dir, train_files, _TRAIN_TAG)
+ compiled_eval_files = compile_files(FLAGS.raw_dir, eval_files, _EVAL_TAG)
+
+ # Tokenize and save data as Examples in the TFRecord format.
+ logging.info("Step 5/5: Preprocessing and saving data")
+ train_tfrecord_files = encode_and_save_files(
+ subtokenizer, FLAGS.data_dir, compiled_train_files, _TRAIN_TAG,
+ _TRAIN_SHARDS)
+ encode_and_save_files(
+ subtokenizer, FLAGS.data_dir, compiled_eval_files, _EVAL_TAG,
+ _EVAL_SHARDS)
+
+ for fname in train_tfrecord_files:
+ shuffle_records(fname)
+
+
+def define_data_download_flags():
+ """Add flags specifying data download arguments."""
+ flags.DEFINE_string(
+ name="data_dir", short_name="dd", default="/tmp/translate_ende",
+ help=flags_core.help_wrap(
+ "Directory for where the translate_ende_wmt32k dataset is saved."))
+ flags.DEFINE_string(
+ name="raw_dir", short_name="rd", default="/tmp/translate_ende_raw",
+ help=flags_core.help_wrap(
+ "Path where the raw data will be downloaded and extracted."))
+ flags.DEFINE_bool(
+ name="search", default=False,
+ help=flags_core.help_wrap(
+ "If set, use binary search to find the vocabulary set with size"
+ "closest to the target size (%d)." % _TARGET_VOCAB_SIZE))
+
+
+if __name__ == "__main__":
+ logging.set_verbosity(logging.INFO)
+ define_data_download_flags()
+ FLAGS = flags.FLAGS
+ absl_app.run(main)
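+
+# Editor's note (assumed invocation, not part of the original file):
+#   python data_download.py --data_dir=/tmp/translate_ende \
+#       --raw_dir=/tmp/translate_ende_raw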
diff --git a/models/official/nlp/transformer/data_pipeline.py b/models/official/nlp/transformer/data_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..cedd2c309d3194a07841610f8f1039a1a1e7ac51
--- /dev/null
+++ b/models/official/nlp/transformer/data_pipeline.py
@@ -0,0 +1,316 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Input pipeline for the transformer model to read, filter, and batch examples.
+
+Two things to note in the pipeline:
+
+1. Batching scheme
+
+ The examples encoded in the TFRecord files contain data in the format:
+ {"inputs": [variable length array of integers],
+ "targets": [variable length array of integers]}
+ Where integers in the arrays refer to tokens in the English and German vocab
+ file (named `vocab.ende.32768`).
+
+ Prior to batching, elements in the dataset are grouped by length (max between
+ "inputs" and "targets" length). Each group is then batched such that:
+ group_batch_size * length <= batch_size.
+
+ Another way to view batch_size is the maximum number of tokens in each batch.
+
+ Once batched, each element in the dataset will have the shape:
+ {"inputs": [group_batch_size, padded_input_length],
+ "targets": [group_batch_size, padded_target_length]}
+ Lengths are padded to the longest "inputs" or "targets" sequence in the batch
+ (padded_input_length and padded_target_length can be different).
+
+ This batching scheme decreases the fraction of padding tokens per training
+ batch, thus improving the training speed significantly.
+
+2. Shuffling
+
+ While training, the dataset is shuffled in two places in the code. The first
+ is the list of training files. Second, while reading records using
+ `parallel_interleave`, the `sloppy` argument is used to generate randomness
+ in the order of the examples.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import logging
+import tensorflow as tf
+
+from official.utils.misc import model_helpers
+
+# Buffer size for reading records from a TFRecord file. Each training file is
+# 7.2 MB, so 8 MB allows an entire file to be kept in memory.
+_READ_RECORD_BUFFER = 8 * 1000 * 1000
+
+# Example grouping constants. Defines length boundaries for each group.
+# These values are the defaults used in Tensor2Tensor.
+_MIN_BOUNDARY = 8
+_BOUNDARY_SCALE = 1.1
+
+
+def _load_records(filename):
+ """Read file and return a dataset of tf.Examples."""
+ return tf.data.TFRecordDataset(filename, buffer_size=_READ_RECORD_BUFFER)
+
+
+def _parse_example(serialized_example):
+ """Return inputs and targets Tensors from a serialized tf.Example."""
+ data_fields = {
+ "inputs": tf.io.VarLenFeature(tf.int64),
+ "targets": tf.io.VarLenFeature(tf.int64)
+ }
+ parsed = tf.io.parse_single_example(serialized_example, data_fields)
+ inputs = tf.sparse.to_dense(parsed["inputs"])
+ targets = tf.sparse.to_dense(parsed["targets"])
+ return inputs, targets
+
+
+def _filter_max_length(example, max_length=256):
+ """Indicates whether the example's length is lower than the maximum length."""
+ return tf.logical_and(tf.size(example[0]) <= max_length,
+ tf.size(example[1]) <= max_length)
+
+
+def _get_example_length(example):
+ """Returns the maximum length between the example inputs and targets."""
+ length = tf.maximum(tf.shape(example[0])[0], tf.shape(example[1])[0])
+ return length
+
+
+def _create_min_max_boundaries(
+ max_length, min_boundary=_MIN_BOUNDARY, boundary_scale=_BOUNDARY_SCALE):
+ """Create min and max boundary lists up to max_length.
+
+ For example, when max_length=24, min_boundary=4 and boundary_scale=2, the
+ returned values will be:
+ buckets_min = [0, 4, 8, 16, 24]
+ buckets_max = [4, 8, 16, 24, 25]
+
+ Args:
+ max_length: The maximum length of example in dataset.
+ min_boundary: Minimum length in boundary.
+ boundary_scale: Amount to scale consecutive boundaries in the list.
+
+ Returns:
+ min and max boundary lists
+
+ """
+ # Create bucket boundaries list by scaling the previous boundary or adding 1
+ # (to ensure increasing boundary sizes).
+ bucket_boundaries = []
+ x = min_boundary
+ while x < max_length:
+ bucket_boundaries.append(x)
+ x = max(x + 1, int(x * boundary_scale))
+
+ # Create min and max boundary lists from the initial list.
+ buckets_min = [0] + bucket_boundaries
+ buckets_max = bucket_boundaries + [max_length + 1]
+ return buckets_min, buckets_max
+
+
+def _batch_examples(dataset, batch_size, max_length):
+ """Group examples by similar lengths, and return batched dataset.
+
+ Each batch of similar-length examples is padded to the same length; batches
+ may contain different numbers of elements, such that:
+ group_batch_size * padded_length <= batch_size.
+
+ This decreases the number of padding tokens per batch, which improves the
+ training speed.
+
+ Args:
+ dataset: Dataset of unbatched examples.
+ batch_size: Max number of tokens per batch of examples.
+ max_length: Max number of tokens in an example input or target sequence.
+
+ Returns:
+ Dataset of batched examples with similar lengths.
+ """
+ # Get min and max boundary lists for each example. These are used to calculate
+ # the `bucket_id`, which is the index at which:
+ # buckets_min[bucket_id] <= len(example) < buckets_max[bucket_id]
+ # Note that using both min and max lists improves the performance.
+ buckets_min, buckets_max = _create_min_max_boundaries(max_length)
+
+ # Create list of batch sizes for each bucket_id, so that
+ # bucket_batch_size[bucket_id] * buckets_max[bucket_id] <= batch_size
+ bucket_batch_sizes = [int(batch_size) // x for x in buckets_max]
+ # bucket_id will be a tensor, so convert this list to a tensor as well.
+ bucket_batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64)
+
+ def example_to_bucket_id(example_input, example_target):
+ """Return int64 bucket id for this example, calculated based on length."""
+ seq_length = _get_example_length((example_input, example_target))
+
+ # TODO(xunkai): investigate if removing code branching improves performance.
+ conditions_c = tf.logical_and(
+ tf.less_equal(buckets_min, seq_length),
+ tf.less(seq_length, buckets_max))
+ bucket_id = tf.reduce_min(tf.where(conditions_c))
+ return bucket_id
+
+ def window_size_fn(bucket_id):
+ """Return number of examples to be grouped when given a bucket id."""
+ return bucket_batch_sizes[bucket_id]
+
+ def batching_fn(bucket_id, grouped_dataset):
+ """Batch and add padding to a dataset of elements with similar lengths."""
+ bucket_batch_size = window_size_fn(bucket_id)
+
+ # Batch the dataset and add padding so that all input sequences in the
+ # examples have the same length, and all target sequences have the same
+ # lengths as well. Resulting lengths of inputs and targets can differ.
+ return grouped_dataset.padded_batch(bucket_batch_size, ([None], [None]))
+
+ return dataset.apply(tf.data.experimental.group_by_window(
+ key_func=example_to_bucket_id,
+ reduce_func=batching_fn,
+ window_size=None,
+ window_size_func=window_size_fn))
+
+
+def _read_and_batch_from_files(
+ file_pattern, batch_size, max_length, max_io_parallelism, shuffle, repeat,
+ static_batch=False, num_replicas=1, ctx=None):
+ """Create dataset where each item is a dict of "inputs" and "targets".
+
+ Args:
+ file_pattern: String used to match the input TFRecord files.
+ batch_size: Maximum number of tokens per global batch of examples.
+ max_length: Maximum number of tokens per example
+ max_io_parallelism: Max number of cpu cores for parallel input processing.
+ shuffle: If true, randomizes order of elements.
+ repeat: Number of times to repeat the dataset. If None, the dataset is
+ repeated forever.
+ static_batch: Whether the batches in the dataset should have static shapes.
+ If True, the input is batched so that every batch has the
+ shape [batch_size // max_length, max_length]. If False, the input is
+ grouped by length, and batched so that batches may have different
+ shapes [N, M], where:
+ N * M <= batch_size
+ M <= max_length
+ In general, this setting should be False. Dynamic shapes allow the inputs
+ to be grouped so that the number of padding tokens is minimized, and helps
+ model training. In cases where the input shape must be static
+ (e.g. running on TPU), this setting should be set to True.
+ num_replicas: Number of GPUs or other workers. We will generate global
+ batches, and each global batch is equally divisible by number of replicas.
+ Currently it is only effective when static_batch==True. TODO: make it
+ effective when static_batch=False.
+ ctx: Input context.
+
+ Returns:
+ tf.data.Dataset object containing examples loaded from the files.
+ """
+ dataset = tf.data.Dataset.list_files(file_pattern, shuffle=shuffle)
+
+ if ctx and ctx.num_input_pipelines > 1:
+ logging.info("Shard %d of the dataset.", ctx.input_pipeline_id)
+ dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
+
+ # Read files and interleave results. When training, the order of the examples
+ # will be non-deterministic.
+ options = tf.data.Options()
+ options.experimental_deterministic = False
+ dataset = dataset.interleave(
+ _load_records,
+ cycle_length=max_io_parallelism,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE).with_options(options)
+
+ # Parse each tf.Example into a dictionary
+ # TODO: Look into prefetch_input_elements for performance optimization.
+ dataset = dataset.map(_parse_example,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ # Remove examples where the input or target length exceeds the maximum length.
+ dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length))
+
+ if static_batch:
+ dataset = dataset.padded_batch(
+ # First calculate the batch size (number of tokens) per worker, then
+ # divide it into sentences, and finally expand to a global batch. This
+ # keeps the global batch evenly divisible across replicas for the
+ # distribution strategy.
+ int(batch_size // num_replicas // max_length * num_replicas),
+ ([max_length], [max_length]), drop_remainder=True)
+ else:
+ # Group and batch such that each batch has examples of similar length.
+ # TODO(xunkai): _batch_examples might need to do something special for
+ # num_replicas.
+ dataset = _batch_examples(dataset, batch_size, max_length)
+
+ dataset = dataset.repeat(repeat)
+
+ # Prefetch the next element to improve speed of input pipeline.
+ dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+
+def _generate_synthetic_data(params):
+ """Create synthetic data based on the parameter batch size."""
+ batch_size = int(params["batch_size"] // params["max_length"])
+ length = params["max_length"]
+ dataset = model_helpers.generate_synthetic_data(
+ input_shape=tf.TensorShape([length]),
+ input_value=1,
+ input_dtype=tf.int64,
+ label_shape=tf.TensorShape([length]),
+ label_value=1,
+ label_dtype=tf.int64,
+ )
+ if params["static_batch"]:
+ dataset = dataset.batch(batch_size, drop_remainder=True)
+ else:
+ dataset = dataset.padded_batch(batch_size, ([None], [None]))
+ return dataset
+
+
+def train_input_fn(params, ctx=None):
+ """Load and return dataset of batched examples for use during training."""
+ file_pattern = os.path.join(params["data_dir"] or "", "*train*")
+ if params["use_synthetic_data"]:
+ return _generate_synthetic_data(params)
+ return _read_and_batch_from_files(
+ file_pattern, params["batch_size"], params["max_length"],
+ params["max_io_parallelism"], shuffle=True,
+ repeat=params["repeat_dataset"], static_batch=params["static_batch"],
+ num_replicas=params["num_gpus"], ctx=ctx)
+
+
+def eval_input_fn(params, ctx=None):
+ """Load and return dataset of batched examples for use during evaluation."""
+ file_pattern = os.path.join(params["data_dir"] or "", "*dev*")
+ if params["use_synthetic_data"]:
+ return _generate_synthetic_data(params)
+ return _read_and_batch_from_files(
+ file_pattern, params["batch_size"], params["max_length"],
+ params["max_io_parallelism"], shuffle=False, repeat=1,
+ static_batch=params["static_batch"], num_replicas=params["num_gpus"],
+ ctx=ctx)
+
+
+def map_data_for_transformer_fn(x, y):
+ """Maps data for training, and handles weried behaviors for different vers."""
+ # Will transform input x and targets y into tuple(x, y) as new model inputs.
+ # For TF v2, the 2nd parameter is omitted to make Keras training work.
+ return ((x, y),)
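+
+# Editor's sketch (assumed parameter dictionary, not part of the original
+# file); the keys mirror those read by train_input_fn above:
+#   params = {
+#       "data_dir": "/tmp/translate_ende", "batch_size": 4096,
+#       "max_length": 256, "max_io_parallelism": 8, "repeat_dataset": None,
+#       "static_batch": False, "num_gpus": 1, "use_synthetic_data": False,
+#   }
+#   dataset = train_input_fn(params)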
diff --git a/models/official/nlp/transformer/embedding_layer.py b/models/official/nlp/transformer/embedding_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6694e2b42af47673ee3ce0b9572ec5867d69cb7d
--- /dev/null
+++ b/models/official/nlp/transformer/embedding_layer.py
@@ -0,0 +1,103 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implementation of embedding layer with shared weights."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+class EmbeddingSharedWeights(tf.keras.layers.Layer):
+ """Calculates input embeddings and pre-softmax linear with shared weights."""
+
+ def __init__(self, vocab_size, hidden_size):
+ """Specify characteristic parameters of embedding layer.
+
+ Args:
+ vocab_size: Number of tokens in the embedding. (Typically ~32,000)
+ hidden_size: Dimensionality of the embedding. (Typically 512 or 1024)
+ """
+ super(EmbeddingSharedWeights, self).__init__()
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+
+ def build(self, input_shape):
+ """Build embedding layer."""
+ with tf.name_scope("embedding_and_softmax"):
+ # Create and initialize weights. The random normal initializer was chosen
+ # arbitrarily, and works well.
+ self.shared_weights = self.add_weight(
+ "weights",
+ shape=[self.vocab_size, self.hidden_size],
+ initializer=tf.random_normal_initializer(
+ mean=0., stddev=self.hidden_size**-0.5))
+ super(EmbeddingSharedWeights, self).build(input_shape)
+
+ def get_config(self):
+ return {
+ "vocab_size": self.vocab_size,
+ "hidden_size": self.hidden_size,
+ }
+
+ def call(self, inputs, mode="embedding"):
+ """Get token embeddings of inputs.
+
+ Args:
+ inputs: An int64 tensor with shape [batch_size, length]
+ mode: string, a valid value is one of "embedding" and "linear".
+ Returns:
+ outputs: (1) If mode == "embedding", output embedding tensor, float32 with
+ shape [batch_size, length, embedding_size]; (2) mode == "linear", output
+ linear tensor, float32 with shape [batch_size, length, vocab_size].
+ Raises:
+ ValueError: if mode is not valid.
+ """
+ if mode == "embedding":
+ return self._embedding(inputs)
+ elif mode == "linear":
+ return self._linear(inputs)
+ else:
+ raise ValueError("mode {} is not valid.".format(mode))
+
+ def _embedding(self, inputs):
+ """Applies embedding based on inputs tensor."""
+ with tf.name_scope("embedding"):
+ # Create binary mask of size [batch_size, length]
+ embeddings = tf.gather(self.shared_weights, inputs)
+ mask = tf.cast(tf.not_equal(inputs, 0), embeddings.dtype)
+ embeddings *= tf.expand_dims(mask, -1)
+ # Scale embedding by the sqrt of the hidden size
+ embeddings *= self.hidden_size ** 0.5
+
+ return embeddings
+
+ def _linear(self, inputs):
+ """Computes logits by running inputs through a linear layer.
+
+ Args:
+ inputs: A float32 tensor with shape [batch_size, length, hidden_size]
+ Returns:
+ float32 tensor with shape [batch_size, length, vocab_size].
+ """
+ with tf.name_scope("presoftmax_linear"):
+ batch_size = tf.shape(inputs)[0]
+ length = tf.shape(inputs)[1]
+
+ x = tf.reshape(inputs, [-1, self.hidden_size])
+ logits = tf.matmul(x, self.shared_weights, transpose_b=True)
+
+ return tf.reshape(logits, [batch_size, length, self.vocab_size])
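+
+# Editor's sketch (assumed usage, not part of the original file): the same
+# weight matrix embeds token ids and produces pre-softmax logits.
+#   layer = EmbeddingSharedWeights(vocab_size=320, hidden_size=16)
+#   ids = tf.constant([[1, 2, 3, 0]], dtype=tf.int64)  # [batch_size, length]
+#   emb = layer(ids, mode="embedding")                 # -> [1, 4, 16]
+#   logits = layer(emb, mode="linear")                 # -> [1, 4, 320]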
diff --git a/models/official/nlp/transformer/ffn_layer.py b/models/official/nlp/transformer/ffn_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7785f27dd0c3fed01c514d052749dcafd163605
--- /dev/null
+++ b/models/official/nlp/transformer/ffn_layer.py
@@ -0,0 +1,77 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implementation of fully connected network."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+class FeedForwardNetwork(tf.keras.layers.Layer):
+ """Fully connected feedforward network."""
+
+ def __init__(self, hidden_size, filter_size, relu_dropout):
+ """Initialize FeedForwardNetwork.
+
+ Args:
+ hidden_size: int, output dim of hidden layer.
+ filter_size: int, filter size for the inner (first) dense layer.
+ relu_dropout: float, dropout rate for training.
+ """
+ super(FeedForwardNetwork, self).__init__()
+ self.hidden_size = hidden_size
+ self.filter_size = filter_size
+ self.relu_dropout = relu_dropout
+
+ def build(self, input_shape):
+ self.filter_dense_layer = tf.keras.layers.Dense(
+ self.filter_size,
+ use_bias=True,
+ activation=tf.nn.relu,
+ name="filter_layer")
+ self.output_dense_layer = tf.keras.layers.Dense(
+ self.hidden_size, use_bias=True, name="output_layer")
+ super(FeedForwardNetwork, self).build(input_shape)
+
+ def get_config(self):
+ return {
+ "hidden_size": self.hidden_size,
+ "filter_size": self.filter_size,
+ "relu_dropout": self.relu_dropout,
+ }
+
+ def call(self, x, training):
+ """Return outputs of the feedforward network.
+
+ Args:
+ x: tensor with shape [batch_size, length, hidden_size]
+ training: boolean, whether in training mode or not.
+
+ Returns:
+ Output of the feedforward network.
+ tensor with shape [batch_size, length, hidden_size]
+ """
+ # Retrieve dynamically known shapes
+ batch_size = tf.shape(x)[0]
+ length = tf.shape(x)[1]
+
+ output = self.filter_dense_layer(x)
+ if training:
+ output = tf.nn.dropout(output, rate=self.relu_dropout)
+ output = self.output_dense_layer(output)
+
+ return output
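+
+# Editor's sketch (assumed usage, not part of the original file):
+#   ffn = FeedForwardNetwork(hidden_size=512, filter_size=2048, relu_dropout=0.1)
+#   y = ffn(tf.zeros([8, 10, 512]), training=False)  # -> shape [8, 10, 512]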
diff --git a/models/official/nlp/transformer/metrics.py b/models/official/nlp/transformer/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bd6bba6e6862d643c6cb9bb9fb857b70b3cc00f
--- /dev/null
+++ b/models/official/nlp/transformer/metrics.py
@@ -0,0 +1,183 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions for calculating loss, accuracy, and other model metrics.
+
+Metrics:
+ - Padded loss, accuracy, and negative log perplexity. Source:
+ https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/metrics.py
+ - BLEU approximation. Source:
+ https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py
+ - ROUGE score. Source:
+ https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/rouge.py
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import tensorflow as tf
+
+
+def _pad_tensors_to_same_length(x, y):
+ """Pad x and y so that the results have the same length (second dimension)."""
+ with tf.name_scope("pad_to_same_length"):
+ x_length = tf.shape(x)[1]
+ y_length = tf.shape(y)[1]
+
+ max_length = tf.maximum(x_length, y_length)
+
+ x = tf.pad(x, [[0, 0], [0, max_length - x_length], [0, 0]])
+ y = tf.pad(y, [[0, 0], [0, max_length - y_length]])
+ return x, y
+
+
+def padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
+ """Calculate cross entropy loss while ignoring padding.
+
+ Args:
+ logits: Tensor of size [batch_size, length_logits, vocab_size]
+ labels: Tensor of size [batch_size, length_labels]
+ smoothing: Label smoothing constant, used to determine the on and off values
+ vocab_size: int size of the vocabulary
+
+ Returns:
+ Returns the cross entropy loss and weight tensors: float32 tensors with
+ shape [batch_size, max(length_logits, length_labels)]
+ """
+ with tf.name_scope("loss"):
+ logits, labels = _pad_tensors_to_same_length(logits, labels)
+
+ # Calculate smoothing cross entropy
+ with tf.name_scope("smoothing_cross_entropy"):
+ confidence = 1.0 - smoothing
+ low_confidence = (1.0 - confidence) / tf.cast(vocab_size - 1, tf.float32)
+ soft_targets = tf.one_hot(
+ tf.cast(labels, tf.int32),
+ depth=vocab_size,
+ on_value=confidence,
+ off_value=low_confidence)
+ xentropy = tf.nn.softmax_cross_entropy_with_logits(
+ logits=logits, labels=soft_targets)
+
+ # Calculate the best (lowest) possible value of cross entropy, and
+ # subtract from the cross entropy loss.
+ normalizing_constant = -(
+ confidence * tf.math.log(confidence) +
+ tf.cast(vocab_size - 1, tf.float32) * low_confidence *
+ tf.math.log(low_confidence + 1e-20))
+ xentropy -= normalizing_constant
+
+ weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
+ return xentropy * weights, weights
+
+
+def padded_accuracy(logits, labels):
+ """Percentage of times that predictions matches labels on non-0s."""
+ with tf.name_scope("padded_accuracy"):
+ logits, labels = _pad_tensors_to_same_length(logits, labels)
+ weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
+ outputs = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
+ padded_labels = tf.cast(labels, tf.int32)
+ return tf.cast(tf.equal(outputs, padded_labels), tf.float32), weights
+
+
+def padded_accuracy_topk(logits, labels, k):
+ """Percentage of times that top-k predictions matches labels on non-0s."""
+ with tf.name_scope("padded_accuracy_topk"):
+ logits, labels = _pad_tensors_to_same_length(logits, labels)
+ weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
+ effective_k = tf.minimum(k, tf.shape(logits)[-1])
+ _, outputs = tf.nn.top_k(logits, k=effective_k)
+ outputs = tf.cast(outputs, tf.int32)
+ padded_labels = tf.cast(labels, tf.int32)
+ padded_labels = tf.expand_dims(padded_labels, axis=-1)
+ padded_labels += tf.zeros_like(outputs) # Pad to same shape.
+ same = tf.cast(tf.equal(outputs, padded_labels), tf.float32)
+ same_topk = tf.reduce_sum(same, axis=-1)
+ return same_topk, weights
+
+
+def padded_accuracy_top5(logits, labels):
+ return padded_accuracy_topk(logits, labels, 5)
+
+
+def padded_sequence_accuracy(logits, labels):
+ """Percentage of times that predictions matches labels everywhere (non-0)."""
+ with tf.name_scope("padded_sequence_accuracy"):
+ logits, labels = _pad_tensors_to_same_length(logits, labels)
+ weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
+ outputs = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
+ padded_labels = tf.cast(labels, tf.int32)
+ not_correct = tf.cast(tf.not_equal(outputs, padded_labels),
+ tf.float32) * weights
+ axis = list(range(1, len(outputs.get_shape())))
+ correct_seq = 1.0 - tf.minimum(1.0, tf.reduce_sum(not_correct, axis=axis))
+ return correct_seq, tf.constant(1.0)
+
+
+def padded_neg_log_perplexity(logits, labels, vocab_size):
+ """Average log-perplexity excluding padding 0s. No smoothing."""
+ num, den = padded_cross_entropy_loss(logits, labels, 0, vocab_size)
+ return -num, den
+
+
+class MetricLayer(tf.keras.layers.Layer):
+ """Custom a layer of metrics for Transformer model."""
+
+ def __init__(self, vocab_size):
+ super(MetricLayer, self).__init__()
+ self.vocab_size = vocab_size
+ self.metric_mean_fns = []
+
+ def build(self, input_shape):
+ """"Builds metric layer."""
+ neg_log_perplexity = functools.partial(
+ padded_neg_log_perplexity, vocab_size=self.vocab_size)
+ self.metric_mean_fns = [
+ (tf.keras.metrics.Mean("accuracy"), padded_accuracy),
+ (tf.keras.metrics.Mean("accuracy_top5"), padded_accuracy_top5),
+ (tf.keras.metrics.Mean("accuracy_per_sequence"),
+ padded_sequence_accuracy),
+ (tf.keras.metrics.Mean("neg_log_perplexity"), neg_log_perplexity),
+ ]
+ super(MetricLayer, self).build(input_shape)
+
+ def get_config(self):
+ return {"vocab_size": self.vocab_size}
+
+ def call(self, inputs):
+ logits, targets = inputs[0], inputs[1]
+ for mean, fn in self.metric_mean_fns:
+ m = mean(*fn(logits, targets))
+ self.add_metric(m)
+ return logits
+
+
+def transformer_loss(logits, labels, smoothing, vocab_size):
+ """Calculates total loss containing cross entropy with padding ignored.
+
+ Args:
+ logits: Tensor of size [batch_size, length_logits, vocab_size]
+ labels: Tensor of size [batch_size, length_labels]
+ smoothing: Label smoothing constant, used to determine the on and off values
+ vocab_size: int size of the vocabulary
+
+ Returns:
+ A scalar float tensor for loss.
+ """
+ xentropy, weights = padded_cross_entropy_loss(logits, labels, smoothing,
+ vocab_size)
+ return tf.reduce_sum(xentropy) / tf.reduce_sum(weights)
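A small sanity check of the padded loss above (commentary, not part of the diff; the vocabulary size and ids are made up for illustration):

import tensorflow as tf
from official.nlp.transformer import metrics

vocab_size = 100
logits = tf.random.uniform([2, 4, vocab_size])      # [batch_size, length, vocab_size]
labels = tf.constant([[5, 9, 0, 0], [3, 0, 0, 0]])  # 0 marks padding
loss = metrics.transformer_loss(logits, labels, smoothing=0.1, vocab_size=vocab_size)
print(float(loss))  # label-smoothed cross entropy averaged over non-padded positions only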
diff --git a/models/official/nlp/transformer/misc.py b/models/official/nlp/transformer/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b351ae652b7f644c8d598aef67b188ced01d68
--- /dev/null
+++ b/models/official/nlp/transformer/misc.py
@@ -0,0 +1,260 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Misc for Transformer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=g-bad-import-order
+from absl import flags
+import tensorflow as tf
+
+from official.nlp.transformer import model_params
+from official.utils.flags import core as flags_core
+from official.utils.misc import keras_utils
+
+FLAGS = flags.FLAGS
+
+PARAMS_MAP = {
+ 'tiny': model_params.TINY_PARAMS,
+ 'base': model_params.BASE_PARAMS,
+ 'big': model_params.BIG_PARAMS,
+}
+
+
+def get_model_params(param_set, num_gpus):
+ """Gets predefined model params."""
+ if num_gpus > 1:
+ if param_set == 'big':
+ return model_params.BIG_MULTI_GPU_PARAMS.copy()
+ elif param_set == 'base':
+ return model_params.BASE_MULTI_GPU_PARAMS.copy()
+ else:
+    raise ValueError('Invalid params: param_set={} num_gpus={}'.format(
+ param_set, num_gpus))
+
+ return PARAMS_MAP[param_set].copy()
+
+
+def define_transformer_flags():
+ """Add flags and flag validators for running transformer_main."""
+ # Add common flags (data_dir, model_dir, etc.).
+ flags_core.define_base(num_gpu=True, distribution_strategy=True)
+ flags_core.define_performance(
+ num_parallel_calls=True,
+ inter_op=False,
+ intra_op=False,
+ synthetic_data=True,
+ max_train_steps=False,
+ dtype=True,
+ loss_scale=True,
+ all_reduce_alg=True,
+ num_packs=True,
+ tf_gpu_thread_mode=True,
+ datasets_num_private_threads=True,
+ enable_xla=True,
+ fp16_implementation=True
+ )
+
+ flags_core.define_benchmark()
+ flags_core.define_device(tpu=True)
+
+ flags.DEFINE_integer(
+ name='train_steps', short_name='ts', default=300000,
+ help=flags_core.help_wrap('The number of steps used to train.'))
+ flags.DEFINE_integer(
+ name='steps_between_evals', short_name='sbe', default=5000,
+ help=flags_core.help_wrap(
+          'The number of training steps to run between evaluations. This is '
+ 'used if --train_steps is defined.'))
+ flags.DEFINE_boolean(
+ name='enable_time_history', default=True,
+ help='Whether to enable TimeHistory callback.')
+ flags.DEFINE_boolean(
+ name='enable_tensorboard', default=False,
+ help='Whether to enable Tensorboard callback.')
+ flags.DEFINE_boolean(
+ name='enable_metrics_in_training', default=False,
+ help='Whether to enable metrics during training.')
+ flags.DEFINE_boolean(
+ name='enable_mlir_bridge',
+ default=False,
+ help='Whether to enable the TF to XLA bridge.')
+ # Set flags from the flags_core module as 'key flags' so they're listed when
+ # the '-h' flag is used. Without this line, the flags defined above are
+ # only shown in the full `--helpful` help text.
+ flags.adopt_module_key_flags(flags_core)
+
+ # Add transformer-specific flags
+ flags.DEFINE_enum(
+ name='param_set', short_name='mp', default='big',
+ enum_values=PARAMS_MAP.keys(),
+ help=flags_core.help_wrap(
+ 'Parameter set to use when creating and training the model. The '
+ 'parameters define the input shape (batch size and max length), '
+ 'model configuration (size of embedding, # of hidden layers, etc.), '
+ 'and various other settings. The big parameter set increases the '
+ 'default batch size, embedding/hidden size, and filter size. For a '
+ 'complete list of parameters, please see model/model_params.py.'))
+
+ flags.DEFINE_bool(
+ name='static_batch', short_name='sb', default=False,
+ help=flags_core.help_wrap(
+ 'Whether the batches in the dataset should have static shapes. In '
+ 'general, this setting should be False. Dynamic shapes allow the '
+ 'inputs to be grouped so that the number of padding tokens is '
+ 'minimized, and helps model training. In cases where the input shape '
+ 'must be static (e.g. running on TPU), this setting will be ignored '
+ 'and static batching will always be used.'))
+ flags.DEFINE_integer(
+ name='max_length', short_name='ml', default=256,
+ help=flags_core.help_wrap(
+ 'Max sentence length for Transformer. Default is 256. Note: Usually '
+ 'it is more effective to use a smaller max length if static_batch is '
+ 'enabled, e.g. 64.'))
+
+ # Flags for training with steps (may be used for debugging)
+ flags.DEFINE_integer(
+ name='validation_steps', short_name='vs', default=64,
+ help=flags_core.help_wrap('The number of steps used in validation.'))
+
+ # BLEU score computation
+ flags.DEFINE_string(
+ name='bleu_source', short_name='bls', default=None,
+ help=flags_core.help_wrap(
+          'Path to source file containing text to translate when calculating '
+          'the official BLEU score. Both --bleu_source and --bleu_ref must be set. '
+ ))
+ flags.DEFINE_string(
+ name='bleu_ref', short_name='blr', default=None,
+ help=flags_core.help_wrap(
+          'Path to file containing the reference translations used when calculating '
+          'the official BLEU score. Both --bleu_source and --bleu_ref must be set. '
+ ))
+ flags.DEFINE_string(
+ name='vocab_file', short_name='vf', default=None,
+ help=flags_core.help_wrap(
+ 'Path to subtoken vocabulary file. If data_download.py was used to '
+ 'download and encode the training data, look in the data_dir to find '
+ 'the vocab file.'))
+ flags.DEFINE_string(
+ name='mode', default='train',
+ help=flags_core.help_wrap('mode: train, eval, or predict'))
+ flags.DEFINE_bool(
+ name='use_ctl',
+ default=False,
+ help=flags_core.help_wrap(
+          'Whether the model runs with a custom training loop.'))
+ flags.DEFINE_integer(
+ name='decode_batch_size',
+ default=32,
+ help=flags_core.help_wrap(
+ 'Global batch size used for Transformer autoregressive decoding on '
+ 'TPU.'))
+ flags.DEFINE_integer(
+ name='decode_max_length',
+ default=97,
+ help=flags_core.help_wrap(
+ 'Max sequence length of the decode/eval data. This is used by '
+          'Transformer autoregressive decoding on TPU to minimize '
+          'padding.'))
+ flags.DEFINE_bool(
+ name='padded_decode',
+ default=False,
+ help=flags_core.help_wrap(
+ 'Whether the autoregressive decoding runs with input data padded to '
+ 'the decode_max_length. For TPU/XLA-GPU runs, this flag has to be '
+          'set due to the static shape requirement. Although CPU/GPU could also '
+ 'use padded_decode, it has not been tested. In addition, this method '
+ 'will introduce unnecessary overheads which grow quadratically with '
+ 'the max sequence length.'))
+ flags.DEFINE_bool(
+ name='enable_checkpointing',
+ default=True,
+ help=flags_core.help_wrap(
+ 'Whether to do checkpointing during training. When running under '
+ 'benchmark harness, we will avoid checkpointing.'))
+
+ flags_core.set_defaults(data_dir='/tmp/translate_ende',
+ model_dir='/tmp/transformer_model',
+ batch_size=None)
+
+ # pylint: disable=unused-variable
+ @flags.multi_flags_validator(
+ ['bleu_source', 'bleu_ref'],
+ message='Both or neither --bleu_source and --bleu_ref must be defined.')
+ def _check_bleu_files(flags_dict):
+ return (flags_dict['bleu_source'] is None) == (
+ flags_dict['bleu_ref'] is None)
+
+ @flags.multi_flags_validator(
+ ['bleu_source', 'bleu_ref', 'vocab_file'],
+ message='--vocab_file must be defined if --bleu_source and --bleu_ref '
+ 'are defined.')
+ def _check_bleu_vocab_file(flags_dict):
+ if flags_dict['bleu_source'] and flags_dict['bleu_ref']:
+ return flags_dict['vocab_file'] is not None
+ return True
+ # pylint: enable=unused-variable
+
+
+def get_callbacks():
+ """Returns common callbacks."""
+ callbacks = []
+ if FLAGS.enable_time_history:
+ time_callback = keras_utils.TimeHistory(
+ FLAGS.batch_size,
+ FLAGS.log_steps,
+ logdir=FLAGS.model_dir if FLAGS.enable_tensorboard else None)
+ callbacks.append(time_callback)
+
+ if FLAGS.enable_tensorboard:
+ tensorboard_callback = tf.keras.callbacks.TensorBoard(
+ log_dir=FLAGS.model_dir)
+ callbacks.append(tensorboard_callback)
+
+ return callbacks
+
+
+def update_stats(history, stats, callbacks):
+ """Normalizes and updates dictionary of stats.
+
+ Args:
+ history: Results of the training step.
+ stats: Dict with pre-existing training stats.
+ callbacks: a list of callbacks which might include a time history callback
+ used during keras.fit.
+ """
+
+ if history and history.history:
+ train_hist = history.history
+ # Gets final loss from training.
+ stats['loss'] = float(train_hist['loss'][-1])
+
+ if not callbacks:
+ return
+
+ # Look for the time history callback which was used during keras.fit
+ for callback in callbacks:
+ if isinstance(callback, keras_utils.TimeHistory):
+ timestamp_log = callback.timestamp_log
+ stats['step_timestamp_log'] = timestamp_log
+ stats['train_finish_time'] = callback.train_finish_time
+ if len(timestamp_log) > 1:
+ stats['avg_exp_per_second'] = (
+ callback.batch_size * callback.log_steps *
+ (len(callback.timestamp_log)-1) /
+ (timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
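As a hedged illustration of get_model_params above (commentary, not part of the diff): with more than one GPU the 'base' and 'big' sets switch to their multi-GPU variants, which shorten the warmup as defined in model_params.py:

from official.nlp.transformer import misc

single_gpu = misc.get_model_params('base', num_gpus=1)
multi_gpu = misc.get_model_params('base', num_gpus=8)
print(single_gpu['learning_rate_warmup_steps'])  # 16000
print(multi_gpu['learning_rate_warmup_steps'])   # 8000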
diff --git a/models/official/nlp/transformer/model_params.py b/models/official/nlp/transformer/model_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..e978abeafca5a627c698f291432f24119ae3fa68
--- /dev/null
+++ b/models/official/nlp/transformer/model_params.py
@@ -0,0 +1,96 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines Transformer model parameters."""
+
+from collections import defaultdict
+
+
+BASE_PARAMS = defaultdict(
+ lambda: None, # Set default value to None.
+
+ # Input params
+ default_batch_size=2048, # Maximum number of tokens per batch of examples.
+ default_batch_size_tpu=32768,
+ max_length=256, # Maximum number of tokens per example.
+
+ # Model params
+ initializer_gain=1.0, # Used in trainable variable initialization.
+ vocab_size=33708, # Number of tokens defined in the vocabulary file.
+ hidden_size=512, # Model dimension in the hidden layers.
+ num_hidden_layers=6, # Number of layers in the encoder and decoder stacks.
+ num_heads=8, # Number of heads to use in multi-headed attention.
+ filter_size=2048, # Inner layer dimension in the feedforward network.
+
+ # Dropout values (only used when training)
+ layer_postprocess_dropout=0.1,
+ attention_dropout=0.1,
+ relu_dropout=0.1,
+
+ # Training params
+ label_smoothing=0.1,
+ learning_rate=2.0,
+ learning_rate_decay_rate=1.0,
+ learning_rate_warmup_steps=16000,
+
+ # Optimizer params
+ optimizer_adam_beta1=0.9,
+ optimizer_adam_beta2=0.997,
+ optimizer_adam_epsilon=1e-09,
+
+ # Default prediction params
+ extra_decode_length=50,
+ beam_size=4,
+ alpha=0.6, # used to calculate length normalization in beam search
+
+ # TPU specific parameters
+ use_tpu=False,
+ static_batch=False,
+ allow_ffn_pad=True,
+)
+
+BIG_PARAMS = BASE_PARAMS.copy()
+BIG_PARAMS.update(
+ default_batch_size=4096,
+
+    # The default TPU batch size is smaller than for BASE_PARAMS due to memory limits.
+ default_batch_size_tpu=16384,
+
+ hidden_size=1024,
+ filter_size=4096,
+ num_heads=16,
+)
+
+# Parameters for running the model in multi gpu. These should not change the
+# params that modify the model shape (such as the hidden_size or num_heads).
+BASE_MULTI_GPU_PARAMS = BASE_PARAMS.copy()
+BASE_MULTI_GPU_PARAMS.update(
+ learning_rate_warmup_steps=8000
+)
+
+BIG_MULTI_GPU_PARAMS = BIG_PARAMS.copy()
+BIG_MULTI_GPU_PARAMS.update(
+ layer_postprocess_dropout=0.3,
+ learning_rate_warmup_steps=8000
+)
+
+# Parameters for testing the model
+TINY_PARAMS = BASE_PARAMS.copy()
+TINY_PARAMS.update(
+ default_batch_size=1024,
+ default_batch_size_tpu=1024,
+ hidden_size=32,
+ num_heads=4,
+ filter_size=256,
+)
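Because every parameter set above is a defaultdict, keys that a script does not set simply resolve to None instead of raising. A brief sketch (commentary, not part of the diff):

from official.nlp.transformer import model_params

params = model_params.TINY_PARAMS.copy()           # copy() keeps the None default
print(params['hidden_size'], params['num_heads'])  # 32 4
print(params['dtype'])                             # None until the training script sets it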
diff --git a/models/official/nlp/transformer/model_utils.py b/models/official/nlp/transformer/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f860f049cd0bcf0467913c91ee6312356f3ad23
--- /dev/null
+++ b/models/official/nlp/transformer/model_utils.py
@@ -0,0 +1,123 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Transformer model helper methods."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import numpy as np
+import tensorflow as tf
+
+# Very low numbers to represent -infinity. We do not actually use -Inf, since we
+# want to be able to multiply these values by zero to get zero. (-Inf * 0 = NaN)
+_NEG_INF_FP32 = -1e9
+_NEG_INF_FP16 = np.finfo(np.float16).min
+
+
+def get_position_encoding(
+ length, hidden_size, min_timescale=1.0, max_timescale=1.0e4):
+ """Return positional encoding.
+
+ Calculates the position encoding as a mix of sine and cosine functions with
+ geometrically increasing wavelengths.
+  Defined and formalized in "Attention Is All You Need", section 3.5.
+
+ Args:
+ length: Sequence length.
+    hidden_size: Size of the hidden dimension (number of encoding channels).
+ min_timescale: Minimum scale that will be applied at each position
+ max_timescale: Maximum scale that will be applied at each position
+
+ Returns:
+ Tensor with shape [length, hidden_size]
+ """
+ # We compute the positional encoding in float32 even if the model uses
+ # float16, as many of the ops used, like log and exp, are numerically unstable
+ # in float16.
+ position = tf.cast(tf.range(length), tf.float32)
+ num_timescales = hidden_size // 2
+ log_timescale_increment = (
+ math.log(float(max_timescale) / float(min_timescale)) /
+ (tf.cast(num_timescales, tf.float32) - 1))
+ inv_timescales = min_timescale * tf.exp(
+ tf.cast(tf.range(num_timescales), tf.float32) * -log_timescale_increment)
+ scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0)
+ signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
+ return signal
+
+
+def get_decoder_self_attention_bias(length, dtype=tf.float32):
+ """Calculate bias for decoder that maintains model's autoregressive property.
+
+ Creates a tensor that masks out locations that correspond to illegal
+ connections, so prediction at position i cannot draw information from future
+ positions.
+
+ Args:
+ length: int length of sequences in batch.
+ dtype: The dtype of the return value.
+
+ Returns:
+ float tensor of shape [1, 1, length, length]
+ """
+ neg_inf = _NEG_INF_FP16 if dtype == tf.float16 else _NEG_INF_FP32
+ with tf.name_scope("decoder_self_attention_bias"):
+ valid_locs = tf.linalg.band_part(tf.ones([length, length], dtype=dtype),
+ -1, 0)
+ valid_locs = tf.reshape(valid_locs, [1, 1, length, length])
+ decoder_bias = neg_inf * (1.0 - valid_locs)
+ return decoder_bias
+
+
+def get_padding(x, padding_value=0, dtype=tf.float32):
+ """Return float tensor representing the padding values in x.
+
+ Args:
+ x: int tensor with any shape
+ padding_value: int which represents padded values in input
+ dtype: The dtype of the return value.
+
+ Returns:
+ float tensor with same shape as x containing values 0 or 1.
+ 0 -> non-padding, 1 -> padding
+ """
+ with tf.name_scope("padding"):
+ return tf.cast(tf.equal(x, padding_value), dtype)
+
+
+def get_padding_bias(x, padding_value=0, dtype=tf.float32):
+ """Calculate bias tensor from padding values in tensor.
+
+ Bias tensor that is added to the pre-softmax multi-headed attention logits,
+ which has shape [batch_size, num_heads, length, length]. The tensor is zero at
+ non-padding locations, and -1e9 (negative infinity) at padding locations.
+
+ Args:
+ x: int tensor with shape [batch_size, length]
+ padding_value: int which represents padded values in input
+ dtype: The dtype of the return value
+
+ Returns:
+ Attention bias tensor of shape [batch_size, 1, 1, length].
+ """
+ with tf.name_scope("attention_bias"):
+ padding = get_padding(x, padding_value, dtype)
+ attention_bias = padding * _NEG_INF_FP32
+ attention_bias = tf.expand_dims(
+ tf.expand_dims(attention_bias, axis=1), axis=1)
+ return attention_bias
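A short shape check for the helpers above (commentary, not part of the diff; lengths are illustrative):

import tensorflow as tf
from official.nlp.transformer import model_utils

pe = model_utils.get_position_encoding(length=6, hidden_size=8)
bias = model_utils.get_decoder_self_attention_bias(4)
print(pe.shape)    # (6, 8): sine channels in the first half, cosine in the second
print(bias.shape)  # (1, 1, 4, 4): 0 on and below the diagonal, -1e9 above it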
diff --git a/models/official/nlp/transformer/model_utils_test.py b/models/official/nlp/transformer/model_utils_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8c4a15c9aba8dbff043088a392fe415f22206ca
--- /dev/null
+++ b/models/official/nlp/transformer/model_utils_test.py
@@ -0,0 +1,62 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test Transformer model helper methods."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from official.nlp.transformer import model_utils
+
+NEG_INF = -1e9
+
+
+class ModelUtilsTest(tf.test.TestCase):
+
+ def test_get_padding(self):
+ x = tf.constant([[1, 0, 0, 0, 2], [3, 4, 0, 0, 0], [0, 5, 6, 0, 7]])
+ padding = model_utils.get_padding(x, padding_value=0)
+
+ self.assertAllEqual([[0, 1, 1, 1, 0], [0, 0, 1, 1, 1], [1, 0, 0, 1, 0]],
+ padding)
+
+ def test_get_padding_bias(self):
+ x = tf.constant([[1, 0, 0, 0, 2], [3, 4, 0, 0, 0], [0, 5, 6, 0, 7]])
+ bias = model_utils.get_padding_bias(x)
+ bias_shape = tf.shape(bias)
+ flattened_bias = tf.reshape(bias, [3, 5])
+
+ self.assertAllEqual([[0, NEG_INF, NEG_INF, NEG_INF, 0],
+ [0, 0, NEG_INF, NEG_INF, NEG_INF],
+ [NEG_INF, 0, 0, NEG_INF, 0]],
+ flattened_bias)
+ self.assertAllEqual([3, 1, 1, 5], bias_shape)
+
+ def test_get_decoder_self_attention_bias(self):
+ length = 5
+ bias = model_utils.get_decoder_self_attention_bias(length)
+
+ self.assertAllEqual([[[[0, NEG_INF, NEG_INF, NEG_INF, NEG_INF],
+ [0, 0, NEG_INF, NEG_INF, NEG_INF],
+ [0, 0, 0, NEG_INF, NEG_INF],
+ [0, 0, 0, 0, NEG_INF],
+ [0, 0, 0, 0, 0]]]],
+ bias)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/transformer/optimizer.py b/models/official/nlp/transformer/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..176b5eb8c6ffcea8a9bccbad5fbdef1d2106e106
--- /dev/null
+++ b/models/official/nlp/transformer/optimizer.py
@@ -0,0 +1,137 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Optimizer from addons and learning rate scheduler."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+K = tf.keras.backend
+
+
+class LearningRateSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
+ """Learning rate schedule."""
+
+ def __init__(self, initial_learning_rate, hidden_size, warmup_steps):
+ """Initialize configuration of the learning rate schedule.
+
+ Args:
+ initial_learning_rate: A float, the initial learning rate.
+ hidden_size: An integer, the model dimension in the hidden layers.
+ warmup_steps: An integer, the number of steps required for linear warmup.
+ """
+ super(LearningRateSchedule, self).__init__()
+ self.initial_learning_rate = initial_learning_rate
+ self.hidden_size = hidden_size
+ self.warmup_steps = tf.cast(warmup_steps, tf.float32)
+
+ def __call__(self, global_step):
+ """Calculate learning rate with linear warmup and rsqrt decay.
+
+ Args:
+ global_step: An integer, the current global step used for learning rate
+ calculation.
+
+ Returns:
+      A float, the learning rate to use for the current global step.
+ """
+ with tf.name_scope('learning_rate_schedule'):
+ global_step = tf.cast(global_step, tf.float32)
+ learning_rate = self.initial_learning_rate
+ learning_rate *= (self.hidden_size**-0.5)
+ # Apply linear warmup
+ learning_rate *= tf.minimum(1.0, global_step / self.warmup_steps)
+ # Apply rsqrt decay
+ learning_rate /= tf.sqrt(tf.maximum(global_step, self.warmup_steps))
+ return learning_rate
+
+ def get_config(self):
+ """Get the configuration of the learning rate schedule."""
+ return {
+ 'initial_learning_rate': self.initial_learning_rate,
+ 'hidden_size': self.hidden_size,
+ 'warmup_steps': self.warmup_steps,
+ }
+
+
+class LearningRateFn(object):
+ """Creates learning rate function."""
+
+ def __init__(self, learning_rate, hidden_size, warmup_steps):
+ self.learning_rate = learning_rate
+ self.hidden_size = hidden_size
+ self.warmup_steps = float(warmup_steps)
+
+ def __call__(self, global_step):
+ """Calculate learning rate with linear warmup and rsqrt decay."""
+ step = float(global_step)
+ learning_rate = self.learning_rate
+ learning_rate *= (self.hidden_size ** -0.5)
+ # Apply linear warmup
+ learning_rate *= np.minimum(1.0, step / self.warmup_steps)
+ # Apply rsqrt decay
+ learning_rate /= np.sqrt(np.maximum(step, self.warmup_steps))
+ return learning_rate
+
+
+class LearningRateScheduler(tf.keras.callbacks.Callback):
+ """Keras callback to schedule learning rate.
+
+ TODO(tianlin): Refactor this scheduler and LearningRateBatchScheduler in
+ official/resnet/keras/keras_common.py.
+ """
+
+ def __init__(self, schedule, init_steps=None, verbose=False):
+ super(LearningRateScheduler, self).__init__()
+ self.schedule = schedule
+ self.verbose = verbose
+ if init_steps is None:
+ init_steps = 0.0
+ self.steps = float(init_steps) # Total steps during training.
+
+ def on_epoch_begin(self, epoch, logs=None):
+ if not hasattr(self.model.optimizer, 'lr'):
+ raise ValueError('Optimizer must have a "lr" attribute.')
+ if not hasattr(self.model.optimizer, 'iterations'):
+ raise ValueError('Optimizer must have a "iterations" attribute.')
+
+ def on_train_batch_begin(self, batch, logs=None):
+ """Adjusts learning rate for each train batch."""
+ if self.verbose > 0:
+ iterations = K.get_value(self.model.optimizer.iterations)
+ print('Original iteration %d' % iterations)
+
+ self.steps += 1.0
+ try: # new API
+ lr = float(K.get_value(self.model.optimizer.lr))
+ lr = self.schedule(self.steps, lr)
+ except TypeError: # Support for old API for backward compatibility
+ lr = self.schedule(self.steps)
+ if not isinstance(lr, (float, np.float32, np.float64)):
+ raise ValueError('The output of the "schedule" function '
+ 'should be float.')
+ K.set_value(self.model.optimizer.lr, lr)
+ K.set_value(self.model.optimizer.iterations, self.steps)
+
+ if self.verbose > 0:
+ print('Batch %05d Step %05d: LearningRateScheduler setting learning '
+ 'rate to %s.' % (batch + 1, self.steps, lr))
+
+ def on_epoch_end(self, epoch, logs=None):
+ logs = logs or {}
+ logs['lr'] = K.get_value(self.model.optimizer.lr)
+ logs['steps'] = self.steps
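A hedged sketch of wiring LearningRateSchedule into Adam with the BASE_PARAMS hyperparameters quoted earlier (commentary, not part of the diff; the training script does the equivalent from flags):

import tensorflow as tf
from official.nlp.transformer import optimizer as transformer_opt

schedule = transformer_opt.LearningRateSchedule(
    initial_learning_rate=2.0, hidden_size=512, warmup_steps=16000)
adam = tf.keras.optimizers.Adam(
    learning_rate=schedule, beta_1=0.9, beta_2=0.997, epsilon=1e-9)
print(float(schedule(100)), float(schedule(100000)))  # linear warmup, then rsqrt decay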
diff --git a/models/official/nlp/transformer/transformer.py b/models/official/nlp/transformer/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..773e79449cdc493a96a5078ce85e801f8f9da250
--- /dev/null
+++ b/models/official/nlp/transformer/transformer.py
@@ -0,0 +1,565 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines the Transformer model in TF 2.0.
+
+Model paper: https://arxiv.org/pdf/1706.03762.pdf
+Transformer model code source: https://github.com/tensorflow/tensor2tensor
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from official.nlp.modeling.layers import position_embedding
+from official.nlp.transformer import attention_layer
+from official.nlp.transformer import beam_search
+from official.nlp.transformer import embedding_layer
+from official.nlp.transformer import ffn_layer
+from official.nlp.transformer import metrics
+from official.nlp.transformer import model_utils
+from official.nlp.transformer.utils.tokenizer import EOS_ID
+
+
+# Disable the not-callable lint error, since it claims many objects are not
+# callable when they actually are.
+# pylint: disable=not-callable
+
+
+def create_model(params, is_train):
+ """Creates transformer model."""
+ with tf.name_scope("model"):
+ if is_train:
+ inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
+ targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
+ internal_model = Transformer(params, name="transformer_v2")
+ logits = internal_model([inputs, targets], training=is_train)
+ vocab_size = params["vocab_size"]
+ label_smoothing = params["label_smoothing"]
+ if params["enable_metrics_in_training"]:
+ logits = metrics.MetricLayer(vocab_size)([logits, targets])
+ logits = tf.keras.layers.Lambda(lambda x: x, name="logits",
+ dtype=tf.float32)(logits)
+ model = tf.keras.Model([inputs, targets], logits)
+ # TODO(reedwm): Can we do this loss in float16 instead of float32?
+ loss = metrics.transformer_loss(
+ logits, targets, label_smoothing, vocab_size)
+ model.add_loss(loss)
+ return model
+
+ else:
+ inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
+ internal_model = Transformer(params, name="transformer_v2")
+ ret = internal_model([inputs], training=is_train)
+ outputs, scores = ret["outputs"], ret["scores"]
+ return tf.keras.Model(inputs, [outputs, scores])
+
+
+class Transformer(tf.keras.Model):
+ """Transformer model with Keras.
+
+ Implemented as described in: https://arxiv.org/pdf/1706.03762.pdf
+
+ The Transformer model consists of an encoder and decoder. The input is an int
+ sequence (or a batch of sequences). The encoder produces a continuous
+ representation, and the decoder uses the encoder output to generate
+ probabilities for the output sequence.
+ """
+
+ def __init__(self, params, name=None):
+ """Initialize layers to build Transformer model.
+
+ Args:
+ params: hyperparameter object defining layer sizes, dropout values, etc.
+ name: name of the model.
+ """
+ super(Transformer, self).__init__(name=name)
+ self.params = params
+ self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
+ params["vocab_size"], params["hidden_size"])
+ self.encoder_stack = EncoderStack(params)
+ self.decoder_stack = DecoderStack(params)
+ self.position_embedding = position_embedding.RelativePositionEmbedding(
+ hidden_size=self.params["hidden_size"])
+
+ def get_config(self):
+ return {
+ "params": self.params,
+ }
+
+ def call(self, inputs, training):
+ """Calculate target logits or inferred target sequences.
+
+ Args:
+ inputs: input tensor list of size 1 or 2.
+ First item, inputs: int tensor with shape [batch_size, input_length].
+ Second item (optional), targets: None or int tensor with shape
+ [batch_size, target_length].
+ training: boolean, whether in training mode or not.
+
+ Returns:
+ If targets is defined, then return logits for each word in the target
+ sequence. float tensor with shape [batch_size, target_length, vocab_size]
+      If targets is None, then generate the output sequence one token at a time.
+        Returns a dictionary {
+ outputs: [batch_size, decoded length]
+ scores: [batch_size, float]}
+ Even when float16 is used, the output tensor(s) are always float32.
+
+ Raises:
+ NotImplementedError: If try to use padded decode method on CPU/GPUs.
+ """
+ if len(inputs) == 2:
+ inputs, targets = inputs[0], inputs[1]
+ else:
+ # Decoding path.
+ inputs, targets = inputs[0], None
+ if self.params["padded_decode"]:
+ if not self.params["num_replicas"]:
+ raise NotImplementedError(
+ "Padded decoding on CPU/GPUs is not supported.")
+ decode_batch_size = int(self.params["decode_batch_size"] /
+ self.params["num_replicas"])
+ inputs.set_shape([
+ decode_batch_size, self.params["decode_max_length"]
+ ])
+
+ # Variance scaling is used here because it seems to work in many problems.
+ # Other reasonable initializers may also work just as well.
+ with tf.name_scope("Transformer"):
+ # Calculate attention bias for encoder self-attention and decoder
+ # multi-headed attention layers.
+ attention_bias = model_utils.get_padding_bias(inputs)
+
+ # Run the inputs through the encoder layer to map the symbol
+ # representations to continuous representations.
+ encoder_outputs = self.encode(inputs, attention_bias, training)
+ # Generate output sequence if targets is None, or return logits if target
+ # sequence is known.
+ if targets is None:
+ return self.predict(encoder_outputs, attention_bias, training)
+ else:
+ logits = self.decode(targets, encoder_outputs, attention_bias, training)
+ return logits
+
+ def encode(self, inputs, attention_bias, training):
+ """Generate continuous representation for inputs.
+
+ Args:
+ inputs: int tensor with shape [batch_size, input_length].
+ attention_bias: float tensor with shape [batch_size, 1, 1, input_length].
+ training: boolean, whether in training mode or not.
+
+ Returns:
+ float tensor with shape [batch_size, input_length, hidden_size]
+ """
+ with tf.name_scope("encode"):
+ # Prepare inputs to the layer stack by adding positional encodings and
+ # applying dropout.
+ embedded_inputs = self.embedding_softmax_layer(inputs)
+ embedded_inputs = tf.cast(embedded_inputs, self.params["dtype"])
+ inputs_padding = model_utils.get_padding(inputs)
+ attention_bias = tf.cast(attention_bias, self.params["dtype"])
+
+ with tf.name_scope("add_pos_encoding"):
+ pos_encoding = self.position_embedding(inputs=embedded_inputs)
+ pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
+ encoder_inputs = embedded_inputs + pos_encoding
+
+ if training:
+ encoder_inputs = tf.nn.dropout(
+ encoder_inputs, rate=self.params["layer_postprocess_dropout"])
+
+ return self.encoder_stack(
+ encoder_inputs, attention_bias, inputs_padding, training=training)
+
+ def decode(self, targets, encoder_outputs, attention_bias, training):
+ """Generate logits for each value in the target sequence.
+
+ Args:
+ targets: target values for the output sequence. int tensor with shape
+ [batch_size, target_length]
+ encoder_outputs: continuous representation of input sequence. float tensor
+ with shape [batch_size, input_length, hidden_size]
+ attention_bias: float tensor with shape [batch_size, 1, 1, input_length]
+ training: boolean, whether in training mode or not.
+
+ Returns:
+ float32 tensor with shape [batch_size, target_length, vocab_size]
+ """
+ with tf.name_scope("decode"):
+ # Prepare inputs to decoder layers by shifting targets, adding positional
+ # encoding and applying dropout.
+ decoder_inputs = self.embedding_softmax_layer(targets)
+ decoder_inputs = tf.cast(decoder_inputs, self.params["dtype"])
+ attention_bias = tf.cast(attention_bias, self.params["dtype"])
+ with tf.name_scope("shift_targets"):
+ # Shift targets to the right, and remove the last element
+ decoder_inputs = tf.pad(decoder_inputs,
+ [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
+ with tf.name_scope("add_pos_encoding"):
+ length = tf.shape(decoder_inputs)[1]
+ pos_encoding = self.position_embedding(decoder_inputs)
+ pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
+ decoder_inputs += pos_encoding
+ if training:
+ decoder_inputs = tf.nn.dropout(
+ decoder_inputs, rate=self.params["layer_postprocess_dropout"])
+
+ # Run values
+ decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
+ length, dtype=self.params["dtype"])
+ outputs = self.decoder_stack(
+ decoder_inputs,
+ encoder_outputs,
+ decoder_self_attention_bias,
+ attention_bias,
+ training=training)
+ logits = self.embedding_softmax_layer(outputs, mode="linear")
+ logits = tf.cast(logits, tf.float32)
+ return logits
+
+ def _get_symbols_to_logits_fn(self, max_decode_length, training):
+ """Returns a decoding function that calculates logits of the next tokens."""
+ timing_signal = self.position_embedding(
+ inputs=None, length=max_decode_length + 1)
+ timing_signal = tf.cast(timing_signal, self.params["dtype"])
+ decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
+ max_decode_length, dtype=self.params["dtype"])
+
+ # TODO(b/139770046): Refactor code with better naming of i.
+ def symbols_to_logits_fn(ids, i, cache):
+ """Generate logits for next potential IDs.
+
+ Args:
+ ids: Current decoded sequences. int tensor with shape [batch_size *
+ beam_size, i + 1].
+ i: Loop index.
+ cache: dictionary of values storing the encoder output, encoder-decoder
+ attention bias, and previous decoder attention values.
+
+ Returns:
+ Tuple of
+ (logits with shape [batch_size * beam_size, vocab_size],
+ updated cache values)
+ """
+ # Set decoder input to the last generated IDs
+ decoder_input = ids[:, -1:]
+
+ # Preprocess decoder input by getting embeddings and adding timing signal.
+ decoder_input = self.embedding_softmax_layer(decoder_input)
+
+ if self.params["padded_decode"]:
+ timing_signal_shape = timing_signal.shape.as_list()
+ decoder_input += tf.slice(timing_signal, [i, 0],
+ [1, timing_signal_shape[1]])
+
+ bias_shape = decoder_self_attention_bias.shape.as_list()
+ self_attention_bias = tf.slice(
+ decoder_self_attention_bias, [0, 0, i, 0],
+ [bias_shape[0], bias_shape[1], 1, bias_shape[3]])
+ else:
+ decoder_input += timing_signal[i:i + 1]
+
+ self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
+
+ decoder_outputs = self.decoder_stack(
+ decoder_input,
+ cache.get("encoder_outputs"),
+ self_attention_bias,
+ cache.get("encoder_decoder_attention_bias"),
+ training=training,
+ cache=cache,
+ decode_loop_step=i if self.params["padded_decode"] else None)
+ logits = self.embedding_softmax_layer(decoder_outputs, mode="linear")
+ logits = tf.squeeze(logits, axis=[1])
+ return logits, cache
+
+ return symbols_to_logits_fn
+
+ def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
+ """Return predicted sequence."""
+ encoder_outputs = tf.cast(encoder_outputs, self.params["dtype"])
+ if self.params["padded_decode"]:
+ batch_size = encoder_outputs.shape.as_list()[0]
+ input_length = encoder_outputs.shape.as_list()[1]
+ else:
+ batch_size = tf.shape(encoder_outputs)[0]
+ input_length = tf.shape(encoder_outputs)[1]
+ max_decode_length = input_length + self.params["extra_decode_length"]
+ encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
+ self.params["dtype"])
+
+ symbols_to_logits_fn = self._get_symbols_to_logits_fn(
+ max_decode_length, training)
+
+ # Create initial set of IDs that will be passed into symbols_to_logits_fn.
+ initial_ids = tf.zeros([batch_size], dtype=tf.int32)
+
+ # Create cache storing decoder attention values for each layer.
+ # pylint: disable=g-complex-comprehension
+ init_decode_length = (
+ max_decode_length if self.params["padded_decode"] else 0)
+ num_heads = self.params["num_heads"]
+ dim_per_head = self.params["hidden_size"] // num_heads
+ cache = {
+ "layer_%d" % layer: {
+ "k":
+ tf.zeros([
+ batch_size, init_decode_length, num_heads, dim_per_head
+ ],
+ dtype=self.params["dtype"]),
+ "v":
+ tf.zeros([
+ batch_size, init_decode_length, num_heads, dim_per_head
+ ],
+ dtype=self.params["dtype"])
+ } for layer in range(self.params["num_hidden_layers"])
+ }
+ # pylint: enable=g-complex-comprehension
+
+ # Add encoder output and attention bias to the cache.
+ cache["encoder_outputs"] = encoder_outputs
+ cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
+
+ # Use beam search to find the top beam_size sequences and scores.
+ decoded_ids, scores = beam_search.sequence_beam_search(
+ symbols_to_logits_fn=symbols_to_logits_fn,
+ initial_ids=initial_ids,
+ initial_cache=cache,
+ vocab_size=self.params["vocab_size"],
+ beam_size=self.params["beam_size"],
+ alpha=self.params["alpha"],
+ max_decode_length=max_decode_length,
+ eos_id=EOS_ID,
+ padded_decode=self.params["padded_decode"],
+ dtype=self.params["dtype"])
+
+ # Get the top sequence for each batch element
+ top_decoded_ids = decoded_ids[:, 0, 1:]
+ top_scores = scores[:, 0]
+
+ return {"outputs": top_decoded_ids, "scores": top_scores}
+
+
+class PrePostProcessingWrapper(tf.keras.layers.Layer):
+ """Wrapper class that applies layer pre-processing and post-processing."""
+
+ def __init__(self, layer, params):
+ super(PrePostProcessingWrapper, self).__init__()
+ self.layer = layer
+ self.params = params
+ self.postprocess_dropout = params["layer_postprocess_dropout"]
+
+ def build(self, input_shape):
+ # Create normalization layer
+ self.layer_norm = tf.keras.layers.LayerNormalization(
+ epsilon=1e-6, dtype="float32")
+ super(PrePostProcessingWrapper, self).build(input_shape)
+
+ def get_config(self):
+ return {
+ "params": self.params,
+ }
+
+ def call(self, x, *args, **kwargs):
+ """Calls wrapped layer with same parameters."""
+ # Preprocessing: apply layer normalization
+ training = kwargs["training"]
+
+ y = self.layer_norm(x)
+
+ # Get layer output
+ y = self.layer(y, *args, **kwargs)
+
+ # Postprocessing: apply dropout and residual connection
+ if training:
+ y = tf.nn.dropout(y, rate=self.postprocess_dropout)
+ return x + y
+
+
+class EncoderStack(tf.keras.layers.Layer):
+ """Transformer encoder stack.
+
+ The encoder stack is made up of N identical layers. Each layer is composed
+ of the sublayers:
+ 1. Self-attention layer
+ 2. Feedforward network (which is 2 fully-connected layers)
+ """
+
+ def __init__(self, params):
+ super(EncoderStack, self).__init__()
+ self.params = params
+ self.layers = []
+
+ def build(self, input_shape):
+ """Builds the encoder stack."""
+ params = self.params
+ for _ in range(params["num_hidden_layers"]):
+ # Create sublayers for each layer.
+ self_attention_layer = attention_layer.SelfAttention(
+ params["hidden_size"], params["num_heads"],
+ params["attention_dropout"])
+ feed_forward_network = ffn_layer.FeedForwardNetwork(
+ params["hidden_size"], params["filter_size"], params["relu_dropout"])
+
+ self.layers.append([
+ PrePostProcessingWrapper(self_attention_layer, params),
+ PrePostProcessingWrapper(feed_forward_network, params)
+ ])
+
+ # Create final layer normalization layer.
+ self.output_normalization = tf.keras.layers.LayerNormalization(
+ epsilon=1e-6, dtype="float32")
+ super(EncoderStack, self).build(input_shape)
+
+ def get_config(self):
+ return {
+ "params": self.params,
+ }
+
+ def call(self, encoder_inputs, attention_bias, inputs_padding, training):
+ """Return the output of the encoder layer stacks.
+
+ Args:
+ encoder_inputs: tensor with shape [batch_size, input_length, hidden_size]
+ attention_bias: bias for the encoder self-attention layer. [batch_size, 1,
+ 1, input_length]
+ inputs_padding: tensor with shape [batch_size, input_length], inputs with
+ zero paddings.
+ training: boolean, whether in training mode or not.
+
+ Returns:
+ Output of encoder layer stack.
+ float32 tensor with shape [batch_size, input_length, hidden_size]
+ """
+ for n, layer in enumerate(self.layers):
+ # Run inputs through the sublayers.
+ self_attention_layer = layer[0]
+ feed_forward_network = layer[1]
+
+ with tf.name_scope("layer_%d" % n):
+ with tf.name_scope("self_attention"):
+ encoder_inputs = self_attention_layer(
+ encoder_inputs, attention_bias, training=training)
+ with tf.name_scope("ffn"):
+ encoder_inputs = feed_forward_network(
+ encoder_inputs, training=training)
+
+ return self.output_normalization(encoder_inputs)
+
+
+class DecoderStack(tf.keras.layers.Layer):
+ """Transformer decoder stack.
+
+ Like the encoder stack, the decoder stack is made up of N identical layers.
+ Each layer is composed of the sublayers:
+ 1. Self-attention layer
+ 2. Multi-headed attention layer combining encoder outputs with results from
+ the previous self-attention layer.
+ 3. Feedforward network (2 fully-connected layers)
+ """
+
+ def __init__(self, params):
+ super(DecoderStack, self).__init__()
+ self.params = params
+ self.layers = []
+
+ def build(self, input_shape):
+ """Builds the decoder stack."""
+ params = self.params
+ for _ in range(params["num_hidden_layers"]):
+ self_attention_layer = attention_layer.SelfAttention(
+ params["hidden_size"], params["num_heads"],
+ params["attention_dropout"])
+ enc_dec_attention_layer = attention_layer.Attention(
+ params["hidden_size"], params["num_heads"],
+ params["attention_dropout"])
+ feed_forward_network = ffn_layer.FeedForwardNetwork(
+ params["hidden_size"], params["filter_size"], params["relu_dropout"])
+
+ self.layers.append([
+ PrePostProcessingWrapper(self_attention_layer, params),
+ PrePostProcessingWrapper(enc_dec_attention_layer, params),
+ PrePostProcessingWrapper(feed_forward_network, params)
+ ])
+ self.output_normalization = tf.keras.layers.LayerNormalization(
+ epsilon=1e-6, dtype="float32")
+ super(DecoderStack, self).build(input_shape)
+
+ def get_config(self):
+ return {
+ "params": self.params,
+ }
+
+ def call(self,
+ decoder_inputs,
+ encoder_outputs,
+ decoder_self_attention_bias,
+ attention_bias,
+ training,
+ cache=None,
+ decode_loop_step=None):
+ """Return the output of the decoder layer stacks.
+
+ Args:
+ decoder_inputs: A tensor with shape
+ [batch_size, target_length, hidden_size].
+ encoder_outputs: A tensor with shape
+ [batch_size, input_length, hidden_size]
+ decoder_self_attention_bias: A tensor with shape
+        [1, 1, target_length, target_length], the bias for the decoder
+        self-attention layer.
+ attention_bias: A tensor with shape [batch_size, 1, 1, input_length],
+ the bias for encoder-decoder attention layer.
+ training: A bool, whether in training mode or not.
+ cache: (Used for fast decoding) A nested dictionary storing previous
+ decoder self-attention values. The items are:
+ {layer_n: {"k": A tensor with shape [batch_size, i, key_channels],
+ "v": A tensor with shape [batch_size, i, value_channels]},
+ ...}
+ decode_loop_step: An integer, the step number of the decoding loop. Used
+ only for autoregressive inference on TPU.
+
+ Returns:
+ Output of decoder layer stack.
+ float32 tensor with shape [batch_size, target_length, hidden_size]
+ """
+ for n, layer in enumerate(self.layers):
+ self_attention_layer = layer[0]
+ enc_dec_attention_layer = layer[1]
+ feed_forward_network = layer[2]
+
+ # Run inputs through the sublayers.
+ layer_name = "layer_%d" % n
+ layer_cache = cache[layer_name] if cache is not None else None
+ with tf.name_scope(layer_name):
+ with tf.name_scope("self_attention"):
+ decoder_inputs = self_attention_layer(
+ decoder_inputs,
+ decoder_self_attention_bias,
+ training=training,
+ cache=layer_cache,
+ decode_loop_step=decode_loop_step)
+ with tf.name_scope("encdec_attention"):
+ decoder_inputs = enc_dec_attention_layer(
+ decoder_inputs,
+ encoder_outputs,
+ attention_bias,
+ training=training)
+ with tf.name_scope("ffn"):
+ decoder_inputs = feed_forward_network(
+ decoder_inputs, training=training)
+
+ return self.output_normalization(decoder_inputs)
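An end-to-end sketch of create_model in training mode using the TINY_PARAMS test configuration (commentary, not part of the diff). The dtype and metrics settings below are assumptions this snippet supplies; the real training script sets them from flags:

import tensorflow as tf
from official.nlp.transformer import model_params, transformer

params = model_params.TINY_PARAMS.copy()
params.update(dtype=tf.float32, enable_metrics_in_training=False)  # assumed here, normally set from flags
model = transformer.create_model(params, is_train=True)

inputs = tf.constant([[4, 7, 9, 0]], dtype=tf.int64)   # padded source ids
targets = tf.constant([[5, 2, 6, 0]], dtype=tf.int64)  # padded target ids
logits = model([inputs, targets], training=False)
print(logits.shape)  # (1, 4, vocab_size): per-position target logits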
diff --git a/models/official/nlp/transformer/transformer_layers_test.py b/models/official/nlp/transformer/transformer_layers_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..82d37259da2854fb83e086749fe7a8df2c22e955
--- /dev/null
+++ b/models/official/nlp/transformer/transformer_layers_test.py
@@ -0,0 +1,97 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for layers in Transformer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from official.nlp.transformer import attention_layer
+from official.nlp.transformer import embedding_layer
+from official.nlp.transformer import ffn_layer
+from official.nlp.transformer import metrics
+
+
+class TransformerLayersTest(tf.test.TestCase):
+
+ def test_attention_layer(self):
+ hidden_size = 64
+ num_heads = 4
+ dropout = 0.5
+ dim_per_head = hidden_size // num_heads
+ layer = attention_layer.SelfAttention(hidden_size, num_heads, dropout)
+ self.assertDictEqual(layer.get_config(), {
+ "hidden_size": hidden_size,
+ "num_heads": num_heads,
+ "attention_dropout": dropout,
+ })
+ length = 2
+ x = tf.ones([1, length, hidden_size])
+ bias = tf.ones([1])
+ cache = {
+ "k": tf.zeros([1, 0, num_heads, dim_per_head]),
+ "v": tf.zeros([1, 0, num_heads, dim_per_head]),
+ }
+ y = layer(x, bias, training=True, cache=cache)
+ self.assertEqual(y.shape, (1, length, 64,))
+ self.assertEqual(cache["k"].shape, (1, length, num_heads, dim_per_head,))
+ self.assertEqual(cache["v"].shape, (1, length, num_heads, dim_per_head,))
+
+ def test_embedding_shared_weights(self):
+ vocab_size = 50
+ hidden_size = 64
+ length = 2
+ layer = embedding_layer.EmbeddingSharedWeights(vocab_size, hidden_size)
+ self.assertDictEqual(layer.get_config(), {
+ "vocab_size": 50,
+ "hidden_size": 64,
+ })
+
+ idx = tf.ones([1, length], dtype="int32")
+ y = layer(idx)
+ self.assertEqual(y.shape, (1, length, hidden_size,))
+ x = tf.ones([1, length, hidden_size])
+ output = layer(x, "linear")
+ self.assertEqual(output.shape, (1, length, vocab_size,))
+
+ def test_feed_forward_network(self):
+ hidden_size = 64
+ filter_size = 32
+ relu_dropout = 0.5
+ layer = ffn_layer.FeedForwardNetwork(hidden_size, filter_size, relu_dropout)
+ self.assertDictEqual(layer.get_config(), {
+ "hidden_size": hidden_size,
+ "filter_size": filter_size,
+ "relu_dropout": relu_dropout,
+ })
+ length = 2
+ x = tf.ones([1, length, hidden_size])
+ y = layer(x, training=True)
+ self.assertEqual(y.shape, (1, length, hidden_size,))
+
+ def test_metric_layer(self):
+ vocab_size = 50
+ logits = tf.keras.layers.Input((None, vocab_size),
+ dtype="float32",
+ name="logits")
+ targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
+ output_logits = metrics.MetricLayer(vocab_size)([logits, targets])
+ self.assertEqual(output_logits.shape.as_list(), [None, None, vocab_size,])
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/transformer/transformer_main.py b/models/official/nlp/transformer/transformer_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..14177d8562b6ec4b190fe5d773998368ffc0b881
--- /dev/null
+++ b/models/official/nlp/transformer/transformer_main.py
@@ -0,0 +1,496 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Train and evaluate the Transformer model.
+
+See README for description of setting the training schedule and evaluating the
+BLEU score.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf
+
+from official.modeling import performance
+from official.nlp.transformer import compute_bleu
+from official.nlp.transformer import data_pipeline
+from official.nlp.transformer import metrics
+from official.nlp.transformer import misc
+from official.nlp.transformer import optimizer
+from official.nlp.transformer import transformer
+from official.nlp.transformer import translate
+from official.nlp.transformer.utils import tokenizer
+from official.utils.flags import core as flags_core
+from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
+
+INF = int(1e9)
+BLEU_DIR = "bleu"
+_SINGLE_SAMPLE = 1
+
+
+def translate_and_compute_bleu(model,
+ params,
+ subtokenizer,
+ bleu_source,
+ bleu_ref,
+ distribution_strategy=None):
+ """Translate file and report the cased and uncased bleu scores.
+
+ Args:
+ model: A Keras model, used to generate the translations.
+ params: A dictionary, containing the translation related parameters.
+ subtokenizer: A subtokenizer object, used for encoding and decoding source
+ and translated lines.
+ bleu_source: A file containing source sentences for translation.
+ bleu_ref: A file containing the reference for the translated sentences.
+ distribution_strategy: A platform distribution strategy, used for TPU based
+ translation.
+
+ Returns:
+ uncased_score: A float, the case insensitive BLEU score.
+ cased_score: A float, the case sensitive BLEU score.
+ """
+ # Create temporary file to store translation.
+ tmp = tempfile.NamedTemporaryFile(delete=False)
+ tmp_filename = tmp.name
+
+ translate.translate_file(
+ model,
+ params,
+ subtokenizer,
+ bleu_source,
+ output_file=tmp_filename,
+ print_all_translations=False,
+ distribution_strategy=distribution_strategy)
+
+ # Compute uncased and cased bleu scores.
+ uncased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, False)
+ cased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, True)
+ os.remove(tmp_filename)
+ return uncased_score, cased_score
+
+
+def evaluate_and_log_bleu(model,
+ params,
+ bleu_source,
+ bleu_ref,
+ vocab_file,
+ distribution_strategy=None):
+ """Calculate and record the BLEU score.
+
+ Args:
+ model: A Keras model, used to generate the translations.
+ params: A dictionary, containing the translation related parameters.
+ bleu_source: A file containing source sentences for translation.
+ bleu_ref: A file containing the reference for the translated sentences.
+ vocab_file: A file containing the vocabulary for translation.
+ distribution_strategy: A platform distribution strategy, used for TPU based
+ translation.
+
+ Returns:
+ uncased_score: A float, the case insensitive BLEU score.
+ cased_score: A float, the case sensitive BLEU score.
+ """
+ subtokenizer = tokenizer.Subtokenizer(vocab_file)
+
+ uncased_score, cased_score = translate_and_compute_bleu(
+ model, params, subtokenizer, bleu_source, bleu_ref, distribution_strategy)
+
+ logging.info("Bleu score (uncased): %s", uncased_score)
+ logging.info("Bleu score (cased): %s", cased_score)
+ return uncased_score, cased_score
+
+
+class TransformerTask(object):
+ """Main entry of Transformer model."""
+
+ def __init__(self, flags_obj):
+ """Init function of TransformerMain.
+
+ Args:
+ flags_obj: Object containing parsed flag values, i.e., FLAGS.
+
+ Raises:
+ ValueError: if not using static batch for input data on TPU.
+ """
+ self.flags_obj = flags_obj
+ self.predict_model = None
+
+ # Add flag-defined parameters to params object
+ num_gpus = flags_core.get_num_gpus(flags_obj)
+ self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus)
+
+ params["num_gpus"] = num_gpus
+ params["use_ctl"] = flags_obj.use_ctl
+ params["data_dir"] = flags_obj.data_dir
+ params["model_dir"] = flags_obj.model_dir
+ params["static_batch"] = flags_obj.static_batch
+ params["max_length"] = flags_obj.max_length
+ params["decode_batch_size"] = flags_obj.decode_batch_size
+ params["decode_max_length"] = flags_obj.decode_max_length
+ params["padded_decode"] = flags_obj.padded_decode
+ params["max_io_parallelism"] = (
+ flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)
+
+ params["use_synthetic_data"] = flags_obj.use_synthetic_data
+ params["batch_size"] = flags_obj.batch_size or params["default_batch_size"]
+ params["repeat_dataset"] = None
+ params["dtype"] = flags_core.get_tf_dtype(flags_obj)
+ params["enable_tensorboard"] = flags_obj.enable_tensorboard
+ params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
+ params["steps_between_evals"] = flags_obj.steps_between_evals
+ params["enable_checkpointing"] = flags_obj.enable_checkpointing
+
+ self.distribution_strategy = distribution_utils.get_distribution_strategy(
+ distribution_strategy=flags_obj.distribution_strategy,
+ num_gpus=num_gpus,
+ all_reduce_alg=flags_obj.all_reduce_alg,
+ num_packs=flags_obj.num_packs,
+ tpu_address=flags_obj.tpu or "")
+ if self.use_tpu:
+ params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
+ else:
+ logging.info("Running transformer with num_gpus = %d", num_gpus)
+
+ if self.distribution_strategy:
+ logging.info("For training, using distribution strategy: %s",
+ self.distribution_strategy)
+ else:
+ logging.info("Not using any distribution strategy.")
+
+ performance.set_mixed_precision_policy(
+ params["dtype"],
+ flags_core.get_loss_scale(flags_obj, default_for_fp16="dynamic"))
+
+ @property
+ def use_tpu(self):
+ if self.distribution_strategy:
+ return isinstance(self.distribution_strategy,
+ tf.distribute.experimental.TPUStrategy)
+ return False
+
+ def train(self):
+ """Trains the model."""
+ params = self.params
+ flags_obj = self.flags_obj
+ # Sets config options.
+ keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)
+
+ _ensure_dir(flags_obj.model_dir)
+ with distribution_utils.get_strategy_scope(self.distribution_strategy):
+ model = transformer.create_model(params, is_train=True)
+ opt = self._create_optimizer()
+
+ current_step = 0
+ checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
+ latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
+ if latest_checkpoint:
+ checkpoint.restore(latest_checkpoint)
+ logging.info("Loaded checkpoint %s", latest_checkpoint)
+ current_step = opt.iterations.numpy()
+
+ if params["use_ctl"]:
+ train_loss_metric = tf.keras.metrics.Mean(
+ "training_loss", dtype=tf.float32)
+ if params["enable_tensorboard"]:
+ summary_writer = tf.compat.v2.summary.create_file_writer(
+ flags_obj.model_dir)
+ else:
+ summary_writer = tf.compat.v2.summary.create_noop_writer()
+ train_metrics = [train_loss_metric]
+ if params["enable_metrics_in_training"]:
+ train_metrics = train_metrics + model.metrics
+ else:
+ model.compile(opt)
+
+ model.summary()
+
+ if self.use_tpu:
+ # Different from experimental_distribute_dataset,
+ # experimental_distribute_datasets_from_function requires
+ # per-replica/local batch size.
+ params["batch_size"] /= self.distribution_strategy.num_replicas_in_sync
+ train_ds = (
+ self.distribution_strategy
+ .experimental_distribute_datasets_from_function(
+ lambda ctx: data_pipeline.train_input_fn(params, ctx)))
+ else:
+ train_ds = data_pipeline.train_input_fn(params)
+ map_data_fn = data_pipeline.map_data_for_transformer_fn
+ train_ds = train_ds.map(
+ map_data_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ if params["use_ctl"]:
+ train_ds_iterator = iter(train_ds)
+
+ callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)
+
+ # Only TimeHistory callback is supported for CTL
+ if params["use_ctl"]:
+ callbacks = [cb for cb in callbacks
+ if isinstance(cb, keras_utils.TimeHistory)]
+
+ # TODO(b/139418525): Refactor the custom training loop logic.
+ @tf.function
+ def train_steps(iterator, steps):
+ """Training steps function for TPU runs.
+
+ Args:
+ iterator: The input iterator of the training dataset.
+ steps: An integer, the number of training steps.
+
+ Returns:
+ A float, the loss value.
+ """
+
+ def _step_fn(inputs):
+ """Per-replica step function."""
+ inputs, targets = inputs
+ with tf.GradientTape() as tape:
+ logits = model([inputs, targets], training=True)
+ loss = metrics.transformer_loss(logits, targets,
+ params["label_smoothing"],
+ params["vocab_size"])
+ # Scales the loss, which results in using the average loss across all
+ # of the replicas for backprop.
+ scaled_loss = loss / self.distribution_strategy.num_replicas_in_sync
+
+ # De-dupes variables due to keras tracking issues.
+ tvars = list({id(v): v for v in model.trainable_variables}.values())
+ grads = tape.gradient(scaled_loss, tvars)
+ opt.apply_gradients(zip(grads, tvars))
+ # For reporting, the metric takes the mean of losses.
+ train_loss_metric.update_state(loss)
+
+ for _ in tf.range(steps):
+ train_loss_metric.reset_states()
+ self.distribution_strategy.run(
+ _step_fn, args=(next(iterator),))
+
+ cased_score, uncased_score = None, None
+ cased_score_history, uncased_score_history = [], []
+ while current_step < flags_obj.train_steps:
+ remaining_steps = flags_obj.train_steps - current_step
+ train_steps_per_eval = (
+ remaining_steps if remaining_steps < flags_obj.steps_between_evals
+ else flags_obj.steps_between_evals)
+ current_iteration = current_step // flags_obj.steps_between_evals
+
+ logging.info(
+ "Start train iteration at global step:{}".format(current_step))
+ history = None
+ if params["use_ctl"]:
+ if not self.use_tpu:
+ raise NotImplementedError(
+ "Custom training loop on GPUs is not implemented.")
+
+ # Runs training steps.
+ with summary_writer.as_default():
+ for cb in callbacks:
+ cb.on_epoch_begin(current_iteration)
+ cb.on_batch_begin(0)
+
+ train_steps(
+ train_ds_iterator,
+ tf.convert_to_tensor(train_steps_per_eval, dtype=tf.int32))
+ current_step += train_steps_per_eval
+ train_loss = train_loss_metric.result().numpy().astype(float)
+ logging.info("Train Step: %d/%d / loss = %s", current_step,
+ flags_obj.train_steps, train_loss)
+
+ for cb in callbacks:
+ cb.on_batch_end(train_steps_per_eval - 1)
+ cb.on_epoch_end(current_iteration)
+
+ if params["enable_tensorboard"]:
+ for metric_obj in train_metrics:
+ tf.compat.v2.summary.scalar(metric_obj.name, metric_obj.result(),
+ current_step)
+ summary_writer.flush()
+
+ for cb in callbacks:
+ cb.on_train_end()
+
+ if flags_obj.enable_checkpointing:
+ # Avoid checkpointing when running for benchmarking.
+ checkpoint_name = checkpoint.save(
+ os.path.join(flags_obj.model_dir,
+ "ctl_step_{}.ckpt".format(current_step)))
+ logging.info("Saved checkpoint to %s", checkpoint_name)
+ else:
+ if self.use_tpu:
+ raise NotImplementedError(
+ "Keras model.fit on TPUs is not implemented.")
+ history = model.fit(
+ train_ds,
+ initial_epoch=current_iteration,
+ epochs=current_iteration + 1,
+ steps_per_epoch=train_steps_per_eval,
+ callbacks=callbacks,
+ # If TimeHistory is enabled, progress bar would be messy. Increase
+ # the verbose level to get rid of it.
+ verbose=(2 if flags_obj.enable_time_history else 1))
+ current_step += train_steps_per_eval
+ logging.info("Train history: {}".format(history.history))
+
+ logging.info("End train iteration at global step:{}".format(current_step))
+
+ if (flags_obj.bleu_source and flags_obj.bleu_ref):
+ uncased_score, cased_score = self.eval()
+ cased_score_history.append([current_iteration + 1, cased_score])
+ uncased_score_history.append([current_iteration + 1, uncased_score])
+
+ stats = ({
+ "loss": train_loss
+ } if history is None else {})
+ misc.update_stats(history, stats, callbacks)
+ if uncased_score and cased_score:
+ stats["bleu_uncased"] = uncased_score
+ stats["bleu_cased"] = cased_score
+ stats["bleu_uncased_history"] = uncased_score_history
+ stats["bleu_cased_history"] = cased_score_history
+ return stats
+
+ def eval(self):
+ """Evaluates the model."""
+ distribution_strategy = self.distribution_strategy if self.use_tpu else None
+
+ # We only want to create the model under DS scope for TPU case.
+ # When 'distribution_strategy' is None, a no-op DummyContextManager will
+ # be used.
+ with distribution_utils.get_strategy_scope(distribution_strategy):
+ if not self.predict_model:
+ self.predict_model = transformer.create_model(self.params, False)
+ self._load_weights_if_possible(
+ self.predict_model,
+ tf.train.latest_checkpoint(self.flags_obj.model_dir))
+ self.predict_model.summary()
+ return evaluate_and_log_bleu(
+ self.predict_model, self.params, self.flags_obj.bleu_source,
+ self.flags_obj.bleu_ref, self.flags_obj.vocab_file,
+ distribution_strategy)
+
+ def predict(self):
+ """Predicts result from the model."""
+ params = self.params
+ flags_obj = self.flags_obj
+
+ with tf.name_scope("model"):
+ model = transformer.create_model(params, is_train=False)
+ self._load_weights_if_possible(
+ model, tf.train.latest_checkpoint(self.flags_obj.model_dir))
+ model.summary()
+ subtokenizer = tokenizer.Subtokenizer(flags_obj.vocab_file)
+
+ ds = data_pipeline.eval_input_fn(params)
+ ds = ds.map(lambda x, y: x).take(_SINGLE_SAMPLE)
+ ret = model.predict(ds)
+ val_outputs, _ = ret
+ length = len(val_outputs)
+ for i in range(length):
+ translate.translate_from_input(val_outputs[i], subtokenizer)
+
+ def _create_callbacks(self, cur_log_dir, init_steps, params):
+ """Creates a list of callbacks."""
+ sfunc = optimizer.LearningRateFn(params["learning_rate"],
+ params["hidden_size"],
+ params["learning_rate_warmup_steps"])
+ scheduler_callback = optimizer.LearningRateScheduler(sfunc, init_steps)
+ callbacks = misc.get_callbacks()
+ callbacks.append(scheduler_callback)
+ if params["enable_checkpointing"]:
+ ckpt_full_path = os.path.join(cur_log_dir, "cp-{epoch:04d}.ckpt")
+ callbacks.append(
+ tf.keras.callbacks.ModelCheckpoint(
+ ckpt_full_path, save_weights_only=True))
+ return callbacks
+
+ def _load_weights_if_possible(self, model, init_weight_path=None):
+ """Loads model weights when it is provided."""
+ if init_weight_path:
+ logging.info("Load weights: {}".format(init_weight_path))
+ # TODO(b/139414977): Having the same variable restoring method for both
+ # TPU and GPU.
+ if self.use_tpu:
+ checkpoint = tf.train.Checkpoint(
+ model=model, optimizer=self._create_optimizer())
+ checkpoint.restore(init_weight_path)
+ else:
+ model.load_weights(init_weight_path)
+ else:
+ logging.info("Weights not loaded from path:{}".format(init_weight_path))
+
+ def _create_optimizer(self):
+ """Creates optimizer."""
+ params = self.params
+ lr_schedule = optimizer.LearningRateSchedule(
+ params["learning_rate"], params["hidden_size"],
+ params["learning_rate_warmup_steps"])
+ opt = tf.keras.optimizers.Adam(
+ lr_schedule if self.use_tpu else params["learning_rate"],
+ params["optimizer_adam_beta1"],
+ params["optimizer_adam_beta2"],
+ epsilon=params["optimizer_adam_epsilon"])
+
+ opt = performance.configure_optimizer(
+ opt,
+ use_float16=params["dtype"] == tf.float16,
+ use_graph_rewrite=self.flags_obj.fp16_implementation == "graph_rewrite",
+ loss_scale=flags_core.get_loss_scale(
+ self.flags_obj, default_for_fp16="dynamic"))
+
+ return opt
+
+
+def _ensure_dir(log_dir):
+ """Makes log dir if not existed."""
+ if not tf.io.gfile.exists(log_dir):
+ tf.io.gfile.makedirs(log_dir)
+
+
+def main(_):
+ flags_obj = flags.FLAGS
+ if flags_obj.enable_mlir_bridge:
+ tf.config.experimental.enable_mlir_bridge()
+ task = TransformerTask(flags_obj)
+
+ # Execute flag override logic for better model performance
+ if flags_obj.tf_gpu_thread_mode:
+ keras_utils.set_gpu_thread_mode_and_count(
+ per_gpu_thread_count=flags_obj.per_gpu_thread_count,
+ gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
+ num_gpus=flags_obj.num_gpus,
+ datasets_num_private_threads=flags_obj.datasets_num_private_threads)
+
+ if flags_obj.mode == "train":
+ task.train()
+ elif flags_obj.mode == "predict":
+ task.predict()
+ elif flags_obj.mode == "eval":
+ task.eval()
+ else:
+ raise ValueError("Invalid mode {}".format(flags_obj.mode))
+
+
+if __name__ == "__main__":
+ logging.set_verbosity(logging.INFO)
+ misc.define_transformer_flags()
+ app.run(main)
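+# Illustrative invocation (not part of the original file; flag names are the
+# ones referenced above and registered by misc.define_transformer_flags, the
+# paths are placeholders):
+#
+#   python transformer_main.py --mode=train --param_set=tiny \
+#       --use_synthetic_data=true --train_steps=10 --steps_between_evals=5 \
+#       --model_dir=/tmp/transformer_model
+#
+# Passing --bleu_source and --bleu_ref together with --vocab_file enables the
+# periodic BLEU evaluation performed by TransformerTask.eval().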
diff --git a/models/official/nlp/transformer/transformer_main_test.py b/models/official/nlp/transformer/transformer_main_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..a65cc4bcbf3a1c4281a36730a1ab60c496f3c7aa
--- /dev/null
+++ b/models/official/nlp/transformer/transformer_main_test.py
@@ -0,0 +1,191 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test Transformer model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import re
+import sys
+import unittest
+
+from absl import flags
+from absl.testing import flagsaver
+import tensorflow as tf
+from tensorflow.python.eager import context # pylint: disable=ungrouped-imports
+from official.nlp.transformer import misc
+from official.nlp.transformer import transformer_main
+from official.utils.misc import keras_utils
+
+FLAGS = flags.FLAGS
+FIXED_TIMESTAMP = 'my_time_stamp'
+WEIGHT_PATTERN = re.compile(r'weights-epoch-.+\.hdf5')
+
+
+def _generate_file(filepath, lines):
+ with open(filepath, 'w') as f:
+ for l in lines:
+ f.write('{}\n'.format(l))
+
+
+class TransformerTaskTest(tf.test.TestCase):
+ local_flags = None
+
+ def setUp(self):
+ temp_dir = self.get_temp_dir()
+ if TransformerTaskTest.local_flags is None:
+ misc.define_transformer_flags()
+ # Loads flags, array cannot be blank.
+ flags.FLAGS(['foo'])
+ TransformerTaskTest.local_flags = flagsaver.save_flag_values()
+ else:
+ flagsaver.restore_flag_values(TransformerTaskTest.local_flags)
+ FLAGS.model_dir = os.path.join(temp_dir, FIXED_TIMESTAMP)
+ FLAGS.param_set = 'tiny'
+ FLAGS.use_synthetic_data = True
+ FLAGS.steps_between_evals = 1
+ FLAGS.train_steps = 2
+ FLAGS.validation_steps = 1
+ FLAGS.batch_size = 8
+ FLAGS.max_length = 1
+ FLAGS.num_gpus = 1
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.dtype = 'fp32'
+ self.model_dir = FLAGS.model_dir
+ self.temp_dir = temp_dir
+ self.vocab_file = os.path.join(temp_dir, 'vocab')
+ self.vocab_size = misc.get_model_params(FLAGS.param_set, 0)['vocab_size']
+ self.bleu_source = os.path.join(temp_dir, 'bleu_source')
+ self.bleu_ref = os.path.join(temp_dir, 'bleu_ref')
+ self.orig_policy = (
+ tf.compat.v2.keras.mixed_precision.experimental.global_policy())
+
+ def tearDown(self):
+ tf.compat.v2.keras.mixed_precision.experimental.set_policy(self.orig_policy)
+
+ def _assert_exists(self, filepath):
+ self.assertTrue(os.path.exists(filepath))
+
+ def test_train_no_dist_strat(self):
+ if context.num_gpus() >= 2:
+ self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
+ t = transformer_main.TransformerTask(FLAGS)
+ t.train()
+
+ def test_train_static_batch(self):
+ if context.num_gpus() >= 2:
+ self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
+ FLAGS.distribution_strategy = 'one_device'
+ if tf.test.is_built_with_cuda():
+ FLAGS.num_gpus = 1
+ else:
+ FLAGS.num_gpus = 0
+ FLAGS.static_batch = True
+ t = transformer_main.TransformerTask(FLAGS)
+ t.train()
+
+ @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
+ def test_train_1_gpu_with_dist_strat(self):
+ FLAGS.distribution_strategy = 'one_device'
+ t = transformer_main.TransformerTask(FLAGS)
+ t.train()
+
+ @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
+ def test_train_fp16(self):
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.dtype = 'fp16'
+ t = transformer_main.TransformerTask(FLAGS)
+ t.train()
+
+ @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
+ def test_train_2_gpu(self):
+ if context.num_gpus() < 2:
+ self.skipTest(
+ '{} GPUs are not available for this test. {} GPUs are available'
+ .format(2, context.num_gpus()))
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.num_gpus = 2
+ FLAGS.param_set = 'base'
+ t = transformer_main.TransformerTask(FLAGS)
+ t.train()
+
+ @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
+ def test_train_2_gpu_fp16(self):
+ if context.num_gpus() < 2:
+ self.skipTest(
+ '{} GPUs are not available for this test. {} GPUs are available'
+ .format(2, context.num_gpus()))
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.num_gpus = 2
+ FLAGS.param_set = 'base'
+ FLAGS.dtype = 'fp16'
+ t = transformer_main.TransformerTask(FLAGS)
+ t.train()
+
+ def _prepare_files_and_flags(self, *extra_flags):
+ # Make log dir.
+ if not os.path.exists(self.temp_dir):
+ os.makedirs(self.temp_dir)
+
+ # Fake vocab, bleu_source and bleu_ref.
+ tokens = [
+ "''", "''", "'_'", "'a'", "'b'", "'c'", "'d'", "'a_'", "'b_'",
+ "'c_'", "'d_'"
+ ]
+ tokens += ["'{}'".format(i) for i in range(self.vocab_size - len(tokens))]
+ _generate_file(self.vocab_file, tokens)
+ _generate_file(self.bleu_source, ['a b', 'c d'])
+ _generate_file(self.bleu_ref, ['a b', 'd c'])
+
+ # Update flags.
+ update_flags = [
+ 'ignored_program_name',
+ '--vocab_file={}'.format(self.vocab_file),
+ '--bleu_source={}'.format(self.bleu_source),
+ '--bleu_ref={}'.format(self.bleu_ref),
+ ]
+ if extra_flags:
+ update_flags.extend(extra_flags)
+ FLAGS(update_flags)
+
+ def test_predict(self):
+ if context.num_gpus() >= 2:
+ self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
+ self._prepare_files_and_flags()
+ t = transformer_main.TransformerTask(FLAGS)
+ t.predict()
+
+ @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
+ def test_predict_fp16(self):
+ if context.num_gpus() >= 2:
+ self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
+ self._prepare_files_and_flags('--dtype=fp16')
+ t = transformer_main.TransformerTask(FLAGS)
+ t.predict()
+
+ def test_eval(self):
+ if context.num_gpus() >= 2:
+ self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
+ if 'test_xla' in sys.argv[0]:
+ self.skipTest('TODO(xla): Make this test faster under XLA.')
+ self._prepare_files_and_flags()
+ t = transformer_main.TransformerTask(FLAGS)
+ t.eval()
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/nlp/transformer/transformer_test.py b/models/official/nlp/transformer/transformer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..227b43dc6ff194ab74effc37214ae9253823310d
--- /dev/null
+++ b/models/official/nlp/transformer/transformer_test.py
@@ -0,0 +1,68 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test Transformer model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from official.nlp.transformer import model_params
+from official.nlp.transformer import transformer
+
+
+class TransformerV2Test(tf.test.TestCase):
+
+ def setUp(self):
+ self.params = params = model_params.TINY_PARAMS
+ params["batch_size"] = params["default_batch_size"] = 16
+ params["use_synthetic_data"] = True
+ params["hidden_size"] = 12
+ params["num_hidden_layers"] = 2
+ params["filter_size"] = 14
+ params["num_heads"] = 2
+ params["vocab_size"] = 41
+ params["extra_decode_length"] = 2
+ params["beam_size"] = 3
+ params["dtype"] = tf.float32
+
+ def test_create_model_train(self):
+ model = transformer.create_model(self.params, True)
+ inputs, outputs = model.inputs, model.outputs
+ self.assertEqual(len(inputs), 2)
+ self.assertEqual(len(outputs), 1)
+ self.assertEqual(inputs[0].shape.as_list(), [None, None])
+ self.assertEqual(inputs[0].dtype, tf.int64)
+ self.assertEqual(inputs[1].shape.as_list(), [None, None])
+ self.assertEqual(inputs[1].dtype, tf.int64)
+ self.assertEqual(outputs[0].shape.as_list(), [None, None, 41])
+ self.assertEqual(outputs[0].dtype, tf.float32)
+
+ def test_create_model_not_train(self):
+ model = transformer.create_model(self.params, False)
+ inputs, outputs = model.inputs, model.outputs
+ self.assertEqual(len(inputs), 1)
+ self.assertEqual(len(outputs), 2)
+ self.assertEqual(inputs[0].shape.as_list(), [None, None])
+ self.assertEqual(inputs[0].dtype, tf.int64)
+ self.assertEqual(outputs[0].shape.as_list(), [None, None])
+ self.assertEqual(outputs[0].dtype, tf.int32)
+ self.assertEqual(outputs[1].shape.as_list(), [None])
+ self.assertEqual(outputs[1].dtype, tf.float32)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/transformer/translate.py b/models/official/nlp/transformer/translate.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f92504142e08918a972dff10c422a58fcfbbd04
--- /dev/null
+++ b/models/official/nlp/transformer/translate.py
@@ -0,0 +1,199 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Translate text or files using trained transformer model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import logging
+import numpy as np
+import tensorflow as tf
+
+from official.nlp.transformer.utils import tokenizer
+
+_EXTRA_DECODE_LENGTH = 100
+_BEAM_SIZE = 4
+_ALPHA = 0.6
+
+
+def _get_sorted_inputs(filename):
+ """Read and sort lines from the file sorted by decreasing length.
+
+ Args:
+ filename: String name of file to read inputs from.
+ Returns:
+ Sorted list of inputs, and dictionary mapping original index->sorted index
+ of each element.
+ """
+ with tf.io.gfile.GFile(filename) as f:
+ records = f.read().split("\n")
+ inputs = [record.strip() for record in records]
+ if not inputs[-1]:
+ inputs.pop()
+
+ input_lens = [(i, len(line.split())) for i, line in enumerate(inputs)]
+ sorted_input_lens = sorted(input_lens, key=lambda x: x[1], reverse=True)
+
+ sorted_inputs = [None] * len(sorted_input_lens)
+ sorted_keys = [0] * len(sorted_input_lens)
+ for i, (index, _) in enumerate(sorted_input_lens):
+ sorted_inputs[i] = inputs[index]
+ sorted_keys[index] = i
+ return sorted_inputs, sorted_keys
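+# Illustrative behaviour (hypothetical file contents, not from the original
+# file): for a file holding the lines ["a b c", "x", "d e"], the function
+# returns (["a b c", "d e", "x"], [0, 2, 1]); sorted_keys[i] gives the position
+# of original line i in the sorted list, which translate_file later uses to
+# write translations back in the original order.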
+
+
+def _encode_and_add_eos(line, subtokenizer):
+ """Encode line with subtokenizer, and add EOS id to the end."""
+ return subtokenizer.encode(line) + [tokenizer.EOS_ID]
+
+
+def _trim_and_decode(ids, subtokenizer):
+ """Trim EOS and PAD tokens from ids, and decode to return a string."""
+ try:
+ index = list(ids).index(tokenizer.EOS_ID)
+ return subtokenizer.decode(ids[:index])
+ except ValueError: # No EOS found in sequence
+ return subtokenizer.decode(ids)
+
+
+def translate_file(model,
+ params,
+ subtokenizer,
+ input_file,
+ output_file=None,
+ print_all_translations=True,
+ distribution_strategy=None):
+ """Translate lines in file, and save to output file if specified.
+
+ Args:
+ model: A Keras model, used to generate the translations.
+ params: A dictionary, containing the translation related parameters.
+ subtokenizer: A subtokenizer object, used for encoding and decoding source
+ and translated lines.
+ input_file: A file containing lines to translate.
+ output_file: A file that stores the generated translations.
+ print_all_translations: A bool. If true, all translations are printed to
+ stdout.
+ distribution_strategy: A distribution strategy, used to perform inference
+ directly with tf.function instead of Keras model.predict().
+
+ Raises:
+ ValueError: if output file is invalid.
+ """
+ batch_size = params["decode_batch_size"]
+
+ # Read and sort inputs by length. Keep dictionary (original index-->new index
+ # in sorted list) to write translations in the original order.
+ sorted_inputs, sorted_keys = _get_sorted_inputs(input_file)
+ total_samples = len(sorted_inputs)
+ num_decode_batches = (total_samples - 1) // batch_size + 1
+
+ def input_generator():
+ """Yield encoded strings from sorted_inputs."""
+ for i in range(num_decode_batches):
+ lines = [
+ sorted_inputs[j + i * batch_size]
+ for j in range(batch_size)
+ if j + i * batch_size < total_samples
+ ]
+ lines = [_encode_and_add_eos(l, subtokenizer) for l in lines]
+ if distribution_strategy:
+ for j in range(batch_size - len(lines)):
+ lines.append([tokenizer.EOS_ID])
+ batch = tf.keras.preprocessing.sequence.pad_sequences(
+ lines,
+ maxlen=params["decode_max_length"],
+ dtype="int32",
+ padding="post")
+ logging.info("Decoding batch %d out of %d.", i, num_decode_batches)
+ yield batch
+
+ @tf.function
+ def predict_step(inputs):
+ """Decoding step function for TPU runs."""
+
+ def _step_fn(inputs):
+ """Per replica step function."""
+ tag = inputs[0]
+ val_inputs = inputs[1]
+ val_outputs, _ = model([val_inputs], training=False)
+ return tag, val_outputs
+
+ return distribution_strategy.run(_step_fn, args=(inputs,))
+
+ translations = []
+ if distribution_strategy:
+ num_replicas = distribution_strategy.num_replicas_in_sync
+ local_batch_size = params["decode_batch_size"] // num_replicas
+ for i, text in enumerate(input_generator()):
+ if distribution_strategy:
+ text = np.reshape(text, [num_replicas, local_batch_size, -1])
+ # Add tag to the input of each replica with the reordering logic after
+ # outputs, to ensure the output order matches the input order.
+ text = tf.constant(text)
+
+ @tf.function
+ def text_as_per_replica():
+ replica_context = tf.distribute.get_replica_context()
+ replica_id = replica_context.replica_id_in_sync_group
+ return replica_id, text[replica_id]
+
+ text = distribution_strategy.run(text_as_per_replica)
+ outputs = distribution_strategy.experimental_local_results(
+ predict_step(text))
+ tags, unordered_val_outputs = outputs[0]
+ # pylint: disable=protected-access
+ tags = [tag.numpy() for tag in tags._values]
+ unordered_val_outputs = [
+ val_output.numpy() for val_output in unordered_val_outputs._values]
+ # pylint: enable=protected-access
+ val_outputs = [None] * len(tags)
+ for k in range(len(tags)):
+ val_outputs[tags[k]] = unordered_val_outputs[k]
+ val_outputs = np.reshape(val_outputs, [params["decode_batch_size"], -1])
+ else:
+ val_outputs, _ = model.predict(text)
+
+ length = len(val_outputs)
+ for j in range(length):
+ if j + i * batch_size < total_samples:
+ translation = _trim_and_decode(val_outputs[j], subtokenizer)
+ translations.append(translation)
+ if print_all_translations:
+ logging.info("Translating:\n\tInput: %s\n\tOutput: %s",
+ sorted_inputs[j + i * batch_size], translation)
+
+ # Write translations in the order they appeared in the original file.
+ if output_file is not None:
+ if tf.io.gfile.isdir(output_file):
+ raise ValueError("File output is a directory, will not save outputs to "
+ "file.")
+ logging.info("Writing to file %s", output_file)
+ with tf.compat.v1.gfile.Open(output_file, "w") as f:
+ for i in sorted_keys:
+ f.write("%s\n" % translations[i])
+
+
+def translate_from_text(model, subtokenizer, txt):
+ encoded_txt = _encode_and_add_eos(txt, subtokenizer)
+ result = model.predict(encoded_txt)
+ outputs = result["outputs"]
+ logging.info("Original: \"%s\"", txt)
+ translate_from_input(outputs, subtokenizer)
+
+
+def translate_from_input(outputs, subtokenizer):
+ translation = _trim_and_decode(outputs, subtokenizer)
+ logging.info("Translation: \"%s\"", translation)
diff --git a/models/official/nlp/transformer/utils/__init__.py b/models/official/nlp/transformer/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/nlp/transformer/utils/metrics.py b/models/official/nlp/transformer/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..7900cf807768f81af7a8afeee1f467074b04189f
--- /dev/null
+++ b/models/official/nlp/transformer/utils/metrics.py
@@ -0,0 +1,490 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions for calculating loss, accuracy, and other model metrics.
+
+Metrics:
+ - Padded loss, accuracy, and negative log perplexity. Source:
+ https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/metrics.py
+ - BLEU approximation. Source:
+ https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py
+ - ROUGE score. Source:
+ https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/rouge.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import math
+
+import numpy as np
+import six
+from six.moves import xrange # pylint: disable=redefined-builtin
+import tensorflow.compat.v1 as tf
+
+
+def _pad_tensors_to_same_length(x, y):
+ """Pad x and y so that the results have the same length (second dimension)."""
+ with tf.name_scope("pad_to_same_length"):
+ x_length = tf.shape(x)[1]
+ y_length = tf.shape(y)[1]
+
+ max_length = tf.maximum(x_length, y_length)
+
+ x = tf.pad(x, [[0, 0], [0, max_length - x_length], [0, 0]])
+ y = tf.pad(y, [[0, 0], [0, max_length - y_length]])
+ return x, y
+
+
+def padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
+ """Calculate cross entropy loss while ignoring padding.
+
+ Args:
+ logits: Tensor of size [batch_size, length_logits, vocab_size]
+ labels: Tensor of size [batch_size, length_labels]
+ smoothing: Label smoothing constant, used to determine the on and off values
+ vocab_size: int size of the vocabulary
+ Returns:
+ Returns the cross entropy loss and weight tensors: float32 tensors with
+ shape [batch_size, max(length_logits, length_labels)]
+ """
+ with tf.name_scope("loss", values=[logits, labels]):
+ logits, labels = _pad_tensors_to_same_length(logits, labels)
+
+ # Calculate smoothing cross entropy
+ with tf.name_scope("smoothing_cross_entropy", values=[logits, labels]):
+ confidence = 1.0 - smoothing
+ low_confidence = (1.0 - confidence) / tf.to_float(vocab_size - 1)
+ soft_targets = tf.one_hot(
+ tf.cast(labels, tf.int32),
+ depth=vocab_size,
+ on_value=confidence,
+ off_value=low_confidence)
+ xentropy = tf.nn.softmax_cross_entropy_with_logits_v2(
+ logits=logits, labels=soft_targets)
+
+ # Calculate the best (lowest) possible value of cross entropy, and
+ # subtract from the cross entropy loss.
+ normalizing_constant = -(
+ confidence * tf.log(confidence) + tf.to_float(vocab_size - 1) *
+ low_confidence * tf.log(low_confidence + 1e-20))
+ xentropy -= normalizing_constant
+
+ weights = tf.to_float(tf.not_equal(labels, 0))
+ return xentropy * weights, weights
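+# Worked example of the label-smoothing targets used above (illustrative
+# numbers, not from the original file): with smoothing=0.1 and vocab_size=6,
+# confidence = 0.9 and low_confidence = 0.1 / 5 = 0.02, so each soft target row
+# places 0.9 on the true token id and 0.02 on every other vocabulary entry.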
+
+
+def _convert_to_eval_metric(metric_fn):
+ """Wrap a metric fn that returns scores and weights as an eval metric fn.
+
+ The input metric_fn returns values for the current batch. The wrapper
+ aggregates the return values collected over all of the batches evaluated.
+
+ Args:
+ metric_fn: function that returns scores and weights for the current batch's
+ logits and predicted labels.
+
+ Returns:
+ function that aggregates the scores and weights from metric_fn.
+ """
+ def problem_metric_fn(*args):
+ """Returns an aggregation of the metric_fn's returned values."""
+ (scores, weights) = metric_fn(*args)
+
+ # The tf.metrics.mean function assures correct aggregation.
+ return tf.metrics.mean(scores, weights)
+ return problem_metric_fn
+
+
+def get_eval_metrics(logits, labels, params):
+ """Return dictionary of model evaluation metrics."""
+ metrics = {
+ "accuracy": _convert_to_eval_metric(padded_accuracy)(logits, labels),
+ "accuracy_top5": _convert_to_eval_metric(padded_accuracy_top5)(
+ logits, labels),
+ "accuracy_per_sequence": _convert_to_eval_metric(
+ padded_sequence_accuracy)(logits, labels),
+ "neg_log_perplexity": _convert_to_eval_metric(padded_neg_log_perplexity)(
+ logits, labels, params["vocab_size"]),
+ }
+
+ if not params["use_tpu"]:
+ # TPU does not support tf.py_func
+ metrics.update({
+ "approx_bleu_score": _convert_to_eval_metric(
+ bleu_score)(logits, labels),
+ "rouge_2_fscore": _convert_to_eval_metric(
+ rouge_2_fscore)(logits, labels),
+ "rouge_L_fscore": _convert_to_eval_metric(
+ rouge_l_fscore)(logits, labels),
+ })
+
+ # Prefix each of the metric names with "metrics/". This allows the metric
+ # graphs to display under the "metrics" category in TensorBoard.
+ metrics = {"metrics/%s" % k: v for k, v in six.iteritems(metrics)}
+ return metrics
+
+
+def padded_accuracy(logits, labels):
+ """Percentage of times that predictions matches labels on non-0s."""
+ with tf.variable_scope("padded_accuracy", values=[logits, labels]):
+ logits, labels = _pad_tensors_to_same_length(logits, labels)
+ weights = tf.to_float(tf.not_equal(labels, 0))
+ outputs = tf.to_int32(tf.argmax(logits, axis=-1))
+ padded_labels = tf.to_int32(labels)
+ return tf.to_float(tf.equal(outputs, padded_labels)), weights
+
+
+def padded_accuracy_topk(logits, labels, k):
+ """Percentage of times that top-k predictions matches labels on non-0s."""
+ with tf.variable_scope("padded_accuracy_topk", values=[logits, labels]):
+ logits, labels = _pad_tensors_to_same_length(logits, labels)
+ weights = tf.to_float(tf.not_equal(labels, 0))
+ effective_k = tf.minimum(k, tf.shape(logits)[-1])
+ _, outputs = tf.nn.top_k(logits, k=effective_k)
+ outputs = tf.to_int32(outputs)
+ padded_labels = tf.to_int32(labels)
+ padded_labels = tf.expand_dims(padded_labels, axis=-1)
+ padded_labels += tf.zeros_like(outputs) # Pad to same shape.
+ same = tf.to_float(tf.equal(outputs, padded_labels))
+ same_topk = tf.reduce_sum(same, axis=-1)
+ return same_topk, weights
+
+
+def padded_accuracy_top5(logits, labels):
+ return padded_accuracy_topk(logits, labels, 5)
+
+
+def padded_sequence_accuracy(logits, labels):
+ """Percentage of times that predictions matches labels everywhere (non-0)."""
+ with tf.variable_scope("padded_sequence_accuracy", values=[logits, labels]):
+ logits, labels = _pad_tensors_to_same_length(logits, labels)
+ weights = tf.to_float(tf.not_equal(labels, 0))
+ outputs = tf.to_int32(tf.argmax(logits, axis=-1))
+ padded_labels = tf.to_int32(labels)
+ not_correct = tf.to_float(tf.not_equal(outputs, padded_labels)) * weights
+ axis = list(range(1, len(outputs.get_shape())))
+ correct_seq = 1.0 - tf.minimum(1.0, tf.reduce_sum(not_correct, axis=axis))
+ return correct_seq, tf.constant(1.0)
+
+
+def padded_neg_log_perplexity(logits, labels, vocab_size):
+ """Average log-perplexity excluding padding 0s. No smoothing."""
+ num, den = padded_cross_entropy_loss(logits, labels, 0, vocab_size)
+ return -num, den
+
+
+def bleu_score(logits, labels):
+ """Approximate BLEU score computation between labels and predictions.
+
+ An approximate BLEU scoring method since we do not glue word pieces or
+ decode the ids and tokenize the output. By default, we use ngram order of 4
+ and use brevity penalty. Also, this does not have beam search.
+
+ Args:
+ logits: Tensor of size [batch_size, length_logits, vocab_size]
+ labels: Tensor of size [batch_size, length_labels]
+
+ Returns:
+ bleu: float32 scalar, the approximate BLEU score.
+ """
+ predictions = tf.to_int32(tf.argmax(logits, axis=-1))
+ # TODO: Look into removing use of py_func
+ bleu = tf.py_func(compute_bleu, (labels, predictions), tf.float32)
+ return bleu, tf.constant(1.0)
+
+
+def _get_ngrams_with_counter(segment, max_order):
+ """Extracts all n-grams up to a given maximum order from an input segment.
+
+ Args:
+ segment: text segment from which n-grams will be extracted.
+ max_order: maximum length in tokens of the n-grams returned by this
+ method.
+
+ Returns:
+ The Counter containing all n-grams up to max_order in segment
+ with a count of how many times each n-gram occurred.
+ """
+ ngram_counts = collections.Counter()
+ for order in xrange(1, max_order + 1):
+ for i in xrange(0, len(segment) - order + 1):
+ ngram = tuple(segment[i:i + order])
+ ngram_counts[ngram] += 1
+ return ngram_counts
+
+
+def compute_bleu(reference_corpus, translation_corpus, max_order=4,
+ use_bp=True):
+ """Computes BLEU score of translated segments against one or more references.
+
+ Args:
+ reference_corpus: list of references for each translation. Each
+ reference should be tokenized into a list of tokens.
+ translation_corpus: list of translations to score. Each translation
+ should be tokenized into a list of tokens.
+ max_order: Maximum n-gram order to use when computing BLEU score.
+ use_bp: boolean, whether to apply brevity penalty.
+
+ Returns:
+ BLEU score.
+ """
+ reference_length = 0
+ translation_length = 0
+ bp = 1.0
+ geo_mean = 0
+
+ matches_by_order = [0] * max_order
+ possible_matches_by_order = [0] * max_order
+ precisions = []
+
+ for (references, translations) in zip(reference_corpus, translation_corpus):
+ reference_length += len(references)
+ translation_length += len(translations)
+ ref_ngram_counts = _get_ngrams_with_counter(references, max_order)
+ translation_ngram_counts = _get_ngrams_with_counter(translations, max_order)
+
+ overlap = dict((ngram,
+ min(count, translation_ngram_counts[ngram]))
+ for ngram, count in ref_ngram_counts.items())
+
+ for ngram in overlap:
+ matches_by_order[len(ngram) - 1] += overlap[ngram]
+ for ngram in translation_ngram_counts:
+ possible_matches_by_order[len(ngram) - 1] += translation_ngram_counts[
+ ngram]
+
+ precisions = [0] * max_order
+ smooth = 1.0
+
+ for i in xrange(0, max_order):
+ if possible_matches_by_order[i] > 0:
+ precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[i]
+ if matches_by_order[i] > 0:
+ precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[
+ i]
+ else:
+ smooth *= 2
+ precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
+ else:
+ precisions[i] = 0.0
+
+ if max(precisions) > 0:
+ p_log_sum = sum(math.log(p) for p in precisions if p)
+ geo_mean = math.exp(p_log_sum / max_order)
+
+ if use_bp:
+ ratio = translation_length / reference_length
+ bp = math.exp(1 - 1. / ratio) if ratio < 1.0 else 1.0
+ bleu = geo_mean * bp
+ return np.float32(bleu)
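+# Illustrative sanity check (toy token lists, not from the original file):
+#
+#   compute_bleu([["a", "b", "c", "d"]], [["a", "b", "c", "d"]])  # -> 1.0
+#   compute_bleu([["a", "b", "c", "d"]], [["a", "b", "x", "d"]])  # -> ~0.35
+#
+# An exact match yields 1.0; substituting one token lowers the modified n-gram
+# precisions (with smoothing applied to the empty 3- and 4-gram matches),
+# giving a score well below 1.0.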
+
+
+def rouge_2_fscore(logits, labels):
+ """ROUGE-2 F1 score computation between labels and predictions.
+
+ This is an approximate ROUGE scoring method since we do not glue word pieces
+ or decode the ids and tokenize the output.
+
+ Args:
+ logits: tensor, model predictions
+ labels: tensor, gold output.
+
+ Returns:
+ rouge2_fscore: approx rouge-2 f1 score.
+ """
+ predictions = tf.to_int32(tf.argmax(logits, axis=-1))
+ # TODO: Look into removing use of py_func
+ rouge_2_f_score = tf.py_func(rouge_n, (predictions, labels), tf.float32)
+ return rouge_2_f_score, tf.constant(1.0)
+
+
+def _get_ngrams(n, text):
+ """Calculates n-grams.
+
+ Args:
+ n: which n-grams to calculate
+ text: An array of tokens
+
+ Returns:
+ A set of n-grams
+ """
+ ngram_set = set()
+ text_length = len(text)
+ max_index_ngram_start = text_length - n
+ for i in range(max_index_ngram_start + 1):
+ ngram_set.add(tuple(text[i:i + n]))
+ return ngram_set
+
+
+def rouge_n(eval_sentences, ref_sentences, n=2):
+ """Computes ROUGE-N f1 score of two text collections of sentences.
+
+ Source: https://www.microsoft.com/en-us/research/publication/
+ rouge-a-package-for-automatic-evaluation-of-summaries/
+
+ Args:
+ eval_sentences: Predicted sentences.
+ ref_sentences: Sentences from the reference set
+ n: Size of ngram. Defaults to 2.
+
+ Returns:
+ f1 score for ROUGE-N
+ """
+ f1_scores = []
+ for eval_sentence, ref_sentence in zip(eval_sentences, ref_sentences):
+ eval_ngrams = _get_ngrams(n, eval_sentence)
+ ref_ngrams = _get_ngrams(n, ref_sentence)
+ ref_count = len(ref_ngrams)
+ eval_count = len(eval_ngrams)
+
+ # Count the overlapping ngrams between evaluated and reference
+ overlapping_ngrams = eval_ngrams.intersection(ref_ngrams)
+ overlapping_count = len(overlapping_ngrams)
+
+ # Handle edge case. This isn't mathematically correct, but it's good enough
+ if eval_count == 0:
+ precision = 0.0
+ else:
+ precision = float(overlapping_count) / eval_count
+ if ref_count == 0:
+ recall = 0.0
+ else:
+ recall = float(overlapping_count) / ref_count
+ f1_scores.append(2.0 * ((precision * recall) / (precision + recall + 1e-8)))
+
+ # return overlapping_count / reference_count
+ return np.mean(f1_scores, dtype=np.float32)
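+# Illustrative example (toy token lists, not from the original file): with
+# eval_sentences=[["a", "b", "c"]] and ref_sentences=[["a", "b", "d"]], the
+# bigram sets are {("a","b"), ("b","c")} and {("a","b"), ("b","d")}; one of two
+# bigrams overlaps, so precision = recall = 0.5 and rouge_n(...) ~= 0.5.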
+
+
+def rouge_l_fscore(predictions, labels):
+ """ROUGE scores computation between labels and predictions.
+
+ This is an approximate ROUGE scoring method since we do not glue word pieces
+ or decode the ids and tokenize the output.
+
+ Args:
+ predictions: tensor, model predictions
+ labels: tensor, gold output.
+
+ Returns:
+ rouge_l_fscore: approx rouge-l f1 score.
+ """
+ outputs = tf.to_int32(tf.argmax(predictions, axis=-1))
+ rouge_l_f_score = tf.py_func(rouge_l_sentence_level, (outputs, labels),
+ tf.float32)
+ return rouge_l_f_score, tf.constant(1.0)
+
+
+def rouge_l_sentence_level(eval_sentences, ref_sentences):
+ """Computes ROUGE-L (sentence level) of two collections of sentences.
+
+ Source: https://www.microsoft.com/en-us/research/publication/
+ rouge-a-package-for-automatic-evaluation-of-summaries/
+
+ Calculated according to:
+ R_lcs = LCS(X,Y)/m
+ P_lcs = LCS(X,Y)/n
+ F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)
+
+ where:
+ X = reference summary
+ Y = Candidate summary
+ m = length of reference summary
+ n = length of candidate summary
+
+ Args:
+ eval_sentences: The sentences that have been picked by the summarizer
+ ref_sentences: The sentences from the reference set
+
+ Returns:
+ A float: F_lcs
+ """
+
+ f1_scores = []
+ for eval_sentence, ref_sentence in zip(eval_sentences, ref_sentences):
+ m = float(len(ref_sentence))
+ n = float(len(eval_sentence))
+ lcs = _len_lcs(eval_sentence, ref_sentence)
+ f1_scores.append(_f_lcs(lcs, m, n))
+ return np.mean(f1_scores, dtype=np.float32)
+
+
+def _len_lcs(x, y):
+ """Returns the length of the Longest Common Subsequence between two seqs.
+
+ Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
+
+ Args:
+ x: sequence of words
+ y: sequence of words
+
+ Returns:
+ integer: Length of LCS between x and y
+ """
+ table = _lcs(x, y)
+ n, m = len(x), len(y)
+ return table[n, m]
+
+
+def _lcs(x, y):
+ """Computes the length of the LCS between two seqs.
+
+ The implementation below uses a dynamic programming algorithm and runs
+ in O(nm) time where n = len(x) and m = len(y).
+ Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
+
+ Args:
+ x: collection of words
+ y: collection of words
+
+ Returns:
+ A dictionary table mapping index pairs (i, j) to the LCS length of x[:i] and y[:j]
+ """
+ n, m = len(x), len(y)
+ table = dict()
+ for i in range(n + 1):
+ for j in range(m + 1):
+ if i == 0 or j == 0:
+ table[i, j] = 0
+ elif x[i - 1] == y[j - 1]:
+ table[i, j] = table[i - 1, j - 1] + 1
+ else:
+ table[i, j] = max(table[i - 1, j], table[i, j - 1])
+ return table
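+# Illustrative example (toy sequences, not from the original file):
+#
+#   _len_lcs(["a", "b", "c", "d"], ["b", "d", "c"])  # -> 2
+#
+# The longest common subsequences are ["b", "c"] and ["b", "d"], both of
+# length 2; _lcs builds the full (len(x)+1) x (len(y)+1) table and _len_lcs
+# reads the entry at (len(x), len(y)).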
+
+
+def _f_lcs(llcs, m, n):
+ """Computes the LCS-based F-measure score.
+
+ Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/
+ rouge-working-note-v1.3.1.pdf
+
+ Args:
+ llcs: Length of LCS
+ m: number of words in reference summary
+ n: number of words in candidate summary
+
+ Returns:
+ Float. LCS-based F-measure score
+ """
+ r_lcs = llcs / m
+ p_lcs = llcs / n
+ beta = p_lcs / (r_lcs + 1e-12)
+ num = (1 + (beta ** 2)) * r_lcs * p_lcs
+ denom = r_lcs + ((beta ** 2) * p_lcs)
+ f_lcs = num / (denom + 1e-12)
+ return f_lcs
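+# Worked example (illustrative numbers, not from the original file): with
+# llcs=2, m=3 reference words and n=4 candidate words, r_lcs = 2/3,
+# p_lcs = 0.5, beta = p_lcs / r_lcs = 0.75, and
+# F_lcs = (1 + 0.5625) * r_lcs * p_lcs / (r_lcs + 0.5625 * p_lcs) ~= 0.55.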
diff --git a/models/official/nlp/transformer/utils/tokenizer.py b/models/official/nlp/transformer/utils/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3749dfe9de6263a4cc185928b7f8967c56250216
--- /dev/null
+++ b/models/official/nlp/transformer/utils/tokenizer.py
@@ -0,0 +1,660 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines Subtokenizer class to encode and decode strings."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import re
+import sys
+import unicodedata
+from absl import logging
+
+import numpy as np
+import six
+from six.moves import xrange # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+
+# pylint: disable=g-complex-comprehension
+PAD = ""
+PAD_ID = 0
+EOS = ""
+EOS_ID = 1
+RESERVED_TOKENS = [PAD, EOS]
+
+# Set of characters that will be used in the function _escape_token() (see func
+# docstring for more details).
+# This set is added to the alphabet list to ensure that all escaped tokens can
+# be encoded.
+_ESCAPE_CHARS = set(u"\\_u;0123456789")
+# Regex for the function _unescape_token(), the inverse of _escape_token().
+# This is used to find "\u", "\\", and "\###;" substrings in the token.
+_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
+
+_UNDEFINED_UNICODE = u"\u3013"
+
+
+def alphanumeric_char_set():
+ return set(
+ six.unichr(i)
+ for i in xrange(sys.maxunicode)
+ if (unicodedata.category(six.unichr(i)).startswith("L") or
+ unicodedata.category(six.unichr(i)).startswith("N")))
+
+
+# Set contains all letter and number characters.
+_ALPHANUMERIC_CHAR_SET = alphanumeric_char_set()
+
+# min_count is the minimum number of times a subtoken must appear in the data
+# before it is added to the vocabulary. The value is found using binary
+# search to obtain the target vocabulary size.
+_MIN_MIN_COUNT = 1 # min value to use when binary searching for min_count
+_MAX_MIN_COUNT = 1000 # max value to use when binary searching for min_count
+
+
+class Subtokenizer(object):
+ """Encodes and decodes strings to/from integer IDs."""
+
+ def __init__(self, vocab_file, reserved_tokens=None, master_char_set=None):
+ """Initializes class, creating a vocab file if data_files is provided."""
+ logging.info("Initializing Subtokenizer from file %s.", vocab_file)
+
+ if master_char_set is None:
+ master_char_set = _ALPHANUMERIC_CHAR_SET
+
+ if reserved_tokens is None:
+ reserved_tokens = RESERVED_TOKENS
+
+ self.subtoken_list = _load_vocab_file(vocab_file, reserved_tokens)
+ self.alphabet = _generate_alphabet_dict(self.subtoken_list)
+ self.subtoken_to_id_dict = _list_to_index_dict(self.subtoken_list)
+
+ self.max_subtoken_length = 0
+ for subtoken in self.subtoken_list:
+ self.max_subtoken_length = max(self.max_subtoken_length, len(subtoken))
+
+ # Create cache to speed up subtokenization
+ self._cache_size = 2**20
+ self._cache = [(None, None)] * self._cache_size
+ self._master_char_set = master_char_set
+
+ @staticmethod
+ def init_from_files(vocab_file,
+ files,
+ target_vocab_size,
+ threshold,
+ min_count=None,
+ file_byte_limit=1e6,
+ reserved_tokens=None,
+ correct_strip=True,
+ master_char_set=None):
+ """Create subtoken vocabulary based on files, and save vocab to file.
+
+ Args:
+ vocab_file: String name of vocab file to store subtoken vocabulary.
+ files: List of file paths that will be used to generate vocabulary.
+ target_vocab_size: target vocabulary size to generate.
+ threshold: int threshold of vocabulary size to accept.
+ min_count: int minimum count to use for generating the vocabulary. The min
+ count is the minimum number of times a subtoken should appear in the
+ files before it is added to the vocabulary. If set to none, this value
+ is found using binary search.
+ file_byte_limit: (Default 1e6) Maximum number of bytes of sample text that
+ will be drawn from the files.
+ reserved_tokens: List of string tokens that are guaranteed to be at the
+ beginning of the subtoken vocabulary list.
+ correct_strip: Whether to convert text to unicode before stripping.
+ master_char_set: the char set.
+
+ Returns:
+ Subtokenizer object
+ """
+ if master_char_set is None:
+ master_char_set = _ALPHANUMERIC_CHAR_SET
+ if reserved_tokens is None:
+ reserved_tokens = RESERVED_TOKENS
+
+ if tf.io.gfile.exists(vocab_file):
+ logging.info("Vocab file already exists (%s)", vocab_file)
+ else:
+ logging.info("Begin steps to create subtoken vocabulary...")
+ token_counts = _count_tokens(files, file_byte_limit, correct_strip,
+ master_char_set)
+ alphabet = _generate_alphabet_dict(token_counts)
+ subtoken_list = _generate_subtokens_with_target_vocab_size(
+ token_counts, alphabet, target_vocab_size, threshold, min_count,
+ reserved_tokens)
+ logging.info("Generated vocabulary with %d subtokens.",
+ len(subtoken_list))
+ _save_vocab_file(vocab_file, subtoken_list)
+ return Subtokenizer(vocab_file, master_char_set=master_char_set)
+
+ def encode(self, raw_string, add_eos=False):
+ """Encodes a string into a list of int subtoken ids."""
+ ret = []
+ tokens = _split_string_to_tokens(
+ native_to_unicode(raw_string), self._master_char_set)
+ for token in tokens:
+ ret.extend(self._token_to_subtoken_ids(token))
+ if add_eos:
+ assert EOS in self.subtoken_list, \
+ "Can't append 'EOS' because it is not in list of known subtokens."
+ ret.append(EOS_ID)
+ return ret
+
+ def _token_to_subtoken_ids(self, token):
+ """Encode a single token into a list of subtoken ids."""
+ cache_location = hash(token) % self._cache_size
+ cache_key, cache_value = self._cache[cache_location]
+ if cache_key == token:
+ return cache_value
+
+ ret = _split_token_to_subtokens(
+ _escape_token(token, self.alphabet), self.subtoken_to_id_dict,
+ self.max_subtoken_length)
+ ret = [self.subtoken_to_id_dict[subtoken_id] for subtoken_id in ret]
+
+ self._cache[cache_location] = (token, ret)
+ return ret
+
+ def decode(self, subtokens):
+ """Converts list of int subtokens ids into a string."""
+ if isinstance(subtokens, np.ndarray):
+ # Note that list(subtokens) converts subtokens to a python list, but the
+ # items remain as np.int32. This converts both the array and its items.
+ subtokens = subtokens.tolist()
+
+ if not subtokens:
+ return ""
+
+ assert isinstance(subtokens, list) and isinstance(subtokens[0], int), (
+ "Subtokens argument passed into decode() must be a list of integers.")
+
+ return _unicode_to_native(
+ _join_tokens_to_string(
+ self._subtoken_ids_to_tokens(subtokens), self._master_char_set))
+
+ def _subtoken_ids_to_tokens(self, subtokens):
+ """Convert list of int subtoken ids to a list of string tokens."""
+ escaped_tokens = "".join([
+ self.subtoken_list[s] for s in subtokens if s < len(self.subtoken_list)
+ ])
+ escaped_tokens = escaped_tokens.split("_")
+
+ # All tokens in the vocabulary list have been escaped (see _escape_token())
+ # so each token must be unescaped when decoding.
+ ret = []
+ for token in escaped_tokens:
+ if token:
+ ret.append(_unescape_token(token))
+ return ret
+
+
+def _save_vocab_file(vocab_file, subtoken_list):
+ """Save subtokens to file."""
+ with tf.io.gfile.GFile(vocab_file, mode="w") as f:
+ for subtoken in subtoken_list:
+ f.write("'%s'\n" % _unicode_to_native(subtoken))
+
+
+def _load_vocab_file(vocab_file, reserved_tokens=None):
+ """Load vocabulary while ensuring reserved tokens are at the top."""
+ if reserved_tokens is None:
+ reserved_tokens = RESERVED_TOKENS
+
+ subtoken_list = []
+ with tf.io.gfile.GFile(vocab_file, mode="r") as f:
+ for line in f:
+ subtoken = native_to_unicode(line.strip())
+ subtoken = subtoken[1:-1] # Remove surrounding single-quotes
+ if subtoken in reserved_tokens:
+ continue
+ subtoken_list.append(native_to_unicode(subtoken))
+ return reserved_tokens + subtoken_list
+
+
+def native_to_unicode(s):
+ """Convert string to unicode (required in Python 2)."""
+ try: # Python 2
+ return s if isinstance(s, unicode) else s.decode("utf-8")
+ except NameError: # Python 3
+ return s
+
+
+def _unicode_to_native(s):
+ """Convert string from unicode to native format (required in Python 2)."""
+ try: # Python 2
+ return s.encode("utf-8") if isinstance(s, unicode) else s
+ except NameError: # Python 3
+ return s
+
+
+def _split_string_to_tokens(text, master_char_set):
+ """Splits text to a list of string tokens."""
+ if not text:
+ return []
+ ret = []
+ token_start = 0
+ # Classify each character in the input string
+ is_master = [c in master_char_set for c in text]
+ for pos in xrange(1, len(text)):
+ if is_master[pos] != is_master[pos - 1]:
+ token = text[token_start:pos]
+ if token != u" " or token_start == 0:
+ ret.append(token)
+ token_start = pos
+ final_token = text[token_start:]
+ ret.append(final_token)
+ return ret
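+
+# Illustrative example (added for clarity, mirroring tokenizer_test.py; not in
+# the upstream file): tokens alternate between runs of master-set characters
+# and other characters, and single spaces between words are dropped.
+#   _split_string_to_tokens(u"test? testing 123.", _ALPHANUMERIC_CHAR_SET)
+#   -> ["test", "? ", "testing", "123", "."]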
+
+
+def _join_tokens_to_string(tokens, master_char_set):
+ """Join a list of string tokens into a single string."""
+ token_is_master = [t[0] in master_char_set for t in tokens]
+ ret = []
+ for i, token in enumerate(tokens):
+ if i > 0 and token_is_master[i - 1] and token_is_master[i]:
+ ret.append(u" ")
+ ret.append(token)
+ return "".join(ret)
+
+
+def _escape_token(token, alphabet):
+ r"""Replace characters that aren't in the alphabet and append "_" to token.
+
+ Apply three transformations to the token:
+ 1. Replace underline character "_" with "\u", and backslash "\" with "\\".
+ 2. Replace characters outside of the alphabet with "\###;", where ### is the
+ character's Unicode code point.
+ 3. Append "_" to mark the end of a token.
+
+ Args:
+ token: unicode string to be escaped
+ alphabet: list of all known characters
+
+ Returns:
+ escaped string
+ """
+ token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u")
+ ret = [c if c in alphabet and c != u"\n" else r"\%d;" % ord(c) for c in token]
+ return u"".join(ret) + "_"
+
+
+def _unescape_token(token):
+ r"""Replaces escaped characters in the token with their unescaped versions.
+
+ Applies inverse transformations as _escape_token():
+ 1. Replace "\u" with "_", and "\\" with "\".
+ 2. Replace "\###;" with the unicode character the ### refers to.
+
+ Args:
+ token: escaped string
+
+ Returns:
+ unescaped string
+ """
+
+ def match(m):
+ r"""Returns replacement string for matched object.
+
+ Matched objects contain one of the strings that matches the regex pattern:
+ r"\\u|\\\\|\\([0-9]+);"
+ The strings can be '\u', '\\', or '\###;' (### is any digit number).
+
+ m.group(0) refers to the entire matched string ('\u', '\\', or '\###;').
+ m.group(1) refers to the first parenthesized subgroup ('###').
+
+ m.group(0) exists for all match objects, while m.group(1) exists only for
+ the string '\###;'.
+
+ This function looks to see if m.group(1) exists. If it doesn't, then the
+ matched string must be '\u' or '\\'. In this case, the corresponding
+ replacement ('_' or '\') is returned. Note that in Python, a single
+ backslash is written as '\\', and a double backslash as '\\\\'.
+
+ If m.group(1) exists, then use the integer in m.group(1) to return a
+ unicode character.
+
+ Args:
+ m: match object
+
+ Returns:
+ String to replace matched object with.
+ """
+ # Check if the matched strings are '\u' or '\\'.
+ if m.group(1) is None:
+ return u"_" if m.group(0) == u"\\u" else u"\\"
+
+ # If m.group(1) exists, try and return unicode character.
+ try:
+ return six.unichr(int(m.group(1)))
+ except (ValueError, OverflowError) as _:
+ return _UNDEFINED_UNICODE
+
+ # Use match function to replace escaped substrings in the token.
+ return _UNESCAPE_REGEX.sub(match, token)
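+
+# Illustrative example (added for clarity, mirroring tokenizer_test.py; not in
+# the upstream file): the inverse of _escape_token applied to an escaped string.
+#   _unescape_token(u"Underline: \\u, Backslash: \\\\, Unicode: \\52;")
+#   -> u"Underline: _, Backslash: \\, Unicode: 4"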
+
+
+def _count_tokens(files,
+ file_byte_limit=1e6,
+ correct_strip=True,
+ master_char_set=None):
+ """Return token counts of words in the files.
+
+ Samples file_byte_limit bytes from each file, and counts the words that appear
+ in the samples. The samples are semi-evenly distributed across the file.
+
+ Args:
+ files: List of filepaths
+ file_byte_limit: Max number of bytes that will be read from each file.
+ correct_strip: Whether to convert text to unicode before stripping. This
+ affects vocabulary generation for PY2. Set correct_strip to False in PY2 to
+ reproduce the previously published result; setting it to True gives PY2 and
+ PY3 a consistent vocabulary.
+ master_char_set: the char set.
+
+ Returns:
+ Dictionary mapping tokens to the number of times they appear in the sampled
+ lines from the files.
+ """
+ if master_char_set is None:
+ master_char_set = _ALPHANUMERIC_CHAR_SET
+
+ token_counts = collections.defaultdict(int)
+
+ for filepath in files:
+ with tf.io.gfile.GFile(filepath, mode="r") as reader:
+ file_byte_budget = file_byte_limit
+ counter = 0
+ lines_to_skip = int(reader.size() / (file_byte_budget * 2))
+ for line in reader:
+ if counter < lines_to_skip:
+ counter += 1
+ else:
+ if file_byte_budget < 0:
+ break
+ if correct_strip:
+ line = native_to_unicode(line)
+ line = line.strip()
+ file_byte_budget -= len(line)
+ counter = 0
+
+ # Add words to token counts
+ for token in _split_string_to_tokens(
+ native_to_unicode(line), master_char_set):
+ token_counts[token] += 1
+ return token_counts
+
+
+def _list_to_index_dict(lst):
+ """Create dictionary mapping list items to their indices in the list."""
+ return {item: n for n, item in enumerate(lst)}
+
+
+def _split_token_to_subtokens(token, subtoken_dict, max_subtoken_length):
+ """Splits a token into subtokens defined in the subtoken dict."""
+ ret = []
+ start = 0
+ token_len = len(token)
+ while start < token_len:
+ # Find the longest subtoken, so iterate backwards.
+ for end in xrange(min(token_len, start + max_subtoken_length), start, -1):
+ subtoken = token[start:end]
+ if subtoken in subtoken_dict:
+ ret.append(subtoken)
+ start = end
+ break
+ else: # Did not break
+ # If there is no possible encoding of the escaped token then one of the
+ # characters in the token is not in the alphabet. This should be
+ # impossible and would be indicative of a bug.
+ raise ValueError("Was unable to split token \"%s\" into subtokens." %
+ token)
+ return ret
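+
+# Illustrative example (added for clarity, mirroring tokenizer_test.py; not in
+# the upstream file): the split is greedy, always taking the longest matching
+# subtoken at the current position.
+#   _split_token_to_subtokens("abc", {"a": 0, "b": 1, "c": 2, "ab": 3}, 2)
+#   -> ["ab", "c"]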
+
+
+def _generate_subtokens_with_target_vocab_size(token_counts,
+ alphabet,
+ target_size,
+ threshold,
+ min_count=None,
+ reserved_tokens=None):
+ """Generate subtoken vocabulary close to the target size."""
+ if reserved_tokens is None:
+ reserved_tokens = RESERVED_TOKENS
+
+ if min_count is not None:
+ logging.info("Using min_count=%d to generate vocab with target size %d",
+ min_count, target_size)
+ return _generate_subtokens(
+ token_counts, alphabet, min_count, reserved_tokens=reserved_tokens)
+
+ def bisect(min_val, max_val):
+ """Recursive function to binary search for subtoken vocabulary."""
+ cur_count = (min_val + max_val) // 2
+ logging.info("Binary search: trying min_count=%d (%d %d)", cur_count,
+ min_val, max_val)
+ subtoken_list = _generate_subtokens(
+ token_counts, alphabet, cur_count, reserved_tokens=reserved_tokens)
+
+ val = len(subtoken_list)
+ logging.info("Binary search: min_count=%d resulted in %d tokens", cur_count,
+ val)
+
+ within_threshold = abs(val - target_size) < threshold
+ if within_threshold or min_val >= max_val or cur_count < 2:
+ return subtoken_list
+ if val > target_size:
+ other_subtoken_list = bisect(cur_count + 1, max_val)
+ else:
+ other_subtoken_list = bisect(min_val, cur_count - 1)
+
+ # Return vocabulary dictionary with the closest number of tokens.
+ other_val = len(other_subtoken_list)
+ if abs(other_val - target_size) < abs(val - target_size):
+ return other_subtoken_list
+ return subtoken_list
+
+ logging.info("Finding best min_count to get target size of %d", target_size)
+ return bisect(_MIN_MIN_COUNT, _MAX_MIN_COUNT)
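+
+# Illustrative note (added for clarity; not in the upstream file): with a
+# hypothetical target_size=32768 and threshold=327, any min_count whose
+# resulting vocabulary has between 32442 and 33094 subtokens stops the search;
+# otherwise the binary search recurses on the half of the min_count range that
+# moves the vocabulary size toward the target.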
+
+
+def _generate_alphabet_dict(iterable, reserved_tokens=None):
+ """Create set of characters that appear in any element in the iterable."""
+ if reserved_tokens is None:
+ reserved_tokens = RESERVED_TOKENS
+ alphabet = {c for token in iterable for c in token}
+ alphabet |= {c for token in reserved_tokens for c in token}
+ alphabet |= _ESCAPE_CHARS # Add escape characters to alphabet set.
+ return alphabet
+
+
+def _count_and_gen_subtokens(token_counts, alphabet, subtoken_dict,
+ max_subtoken_length):
+ """Count number of times subtokens appear, and generate new subtokens.
+
+ Args:
+ token_counts: dict mapping tokens to the number of times they appear in the
+ original files.
+ alphabet: list of allowed characters. Used to escape the tokens, which
+ guarantees that all tokens can be split into subtokens.
+ subtoken_dict: dict mapping subtokens to ids.
+ max_subtoken_length: maximum length of subtoken in subtoken_dict.
+
+ Returns:
+ A defaultdict mapping subtokens to the number of times they appear in the
+ tokens. The dict may contain new subtokens.
+ """
+ subtoken_counts = collections.defaultdict(int)
+ for token, count in six.iteritems(token_counts):
+ token = _escape_token(token, alphabet)
+ subtokens = _split_token_to_subtokens(token, subtoken_dict,
+ max_subtoken_length)
+
+ # Generate new subtokens by taking substrings from token.
+ start = 0
+ for subtoken in subtokens:
+ for end in xrange(start + 1, len(token) + 1):
+ new_subtoken = token[start:end]
+ subtoken_counts[new_subtoken] += count
+ start += len(subtoken)
+
+ return subtoken_counts
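+
+# Illustrative example (added for clarity, mirroring tokenizer_test.py; not in
+# the upstream file): with token_counts={"abc": 5}, alphabet=set("abc_"),
+# subtoken_dict={"a": 0, "b": 1, "c": 2, "_": 3} and max_subtoken_length=2,
+# every substring of the escaped token "abc_" that starts on a current
+# subtoken boundary is counted 5 times:
+#   {"a": 5, "b": 5, "c": 5, "_": 5, "ab": 5, "bc": 5, "c_": 5,
+#    "abc": 5, "bc_": 5, "abc_": 5}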
+
+
+def _filter_and_bucket_subtokens(subtoken_counts, min_count):
+ """Return a bucketed list of subtokens that are filtered by count.
+
+ Args:
+ subtoken_counts: defaultdict mapping subtokens to their counts
+ min_count: int count used to filter subtokens
+
+ Returns:
+ List of subtoken sets, where subtokens in set i have the same length=i.
+ """
+ # Create list of buckets, where subtokens in bucket i have length i.
+ subtoken_buckets = []
+ for subtoken, count in six.iteritems(subtoken_counts):
+ if count < min_count: # Filter out subtokens that don't appear enough
+ continue
+ while len(subtoken_buckets) <= len(subtoken):
+ subtoken_buckets.append(set())
+ subtoken_buckets[len(subtoken)].add(subtoken)
+ return subtoken_buckets
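+
+# Illustrative example (added for clarity, mirroring tokenizer_test.py; not in
+# the upstream file): with min_count=3 and counts
+# {"a": 2, "b": 4, "c": 1, "ab": 6, "ac": 3, "abbc": 5}, the buckets are
+#   [set(), {"b"}, {"ab", "ac"}, set(), {"abbc"}]
+# where bucket i holds the surviving subtokens of length i.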
+
+
+def _gen_new_subtoken_list(subtoken_counts,
+ min_count,
+ alphabet,
+ reserved_tokens=None):
+ """Generate candidate subtokens ordered by count, and new max subtoken length.
+
+ Add subtokens to the candidate list in order of length (longest subtokens
+ first). When a subtoken is added, the counts of each of its prefixes are
+ decreased. Prefixes that don't appear much outside the subtoken are not added
+ to the candidate list.
+
+ For example:
+ subtoken being added to candidate list: 'translate'
+ subtoken_counts: {'translate':10, 't':40, 'tr':16, 'tra':12, ...}
+ min_count: 5
+
+ When 'translate' is added, subtoken_counts is updated to:
+ {'translate':0, 't':30, 'tr':6, 'tra': 2, ...}
+
+ The subtoken 'tra' will not be added to the candidate list, because it appears
+ twice (less than min_count) outside of 'translate'.
+
+ Args:
+ subtoken_counts: defaultdict mapping str subtokens to int counts
+ min_count: int minimum count requirement for subtokens
+ alphabet: set of characters. Each character is added to the subtoken list to
+ guarantee that all tokens can be encoded.
+ reserved_tokens: list of tokens that will be added to the beginning of the
+ returned subtoken list.
+
+ Returns:
+ List of candidate subtokens in decreasing count order, and maximum subtoken
+ length
+ """
+ if reserved_tokens is None:
+ reserved_tokens = RESERVED_TOKENS
+
+ # Create a list of (count, subtoken) for each candidate subtoken.
+ subtoken_candidates = []
+
+ # Use bucketed list to iterate through subtokens in order of length.
+ # subtoken_buckets[i] = set(subtokens), where each subtoken has length i.
+ subtoken_buckets = _filter_and_bucket_subtokens(subtoken_counts, min_count)
+ max_subtoken_length = len(subtoken_buckets) - 1
+
+ # Go through the list in reverse order to consider longer subtokens first.
+ for subtoken_len in xrange(max_subtoken_length, 0, -1):
+ for subtoken in subtoken_buckets[subtoken_len]:
+ count = subtoken_counts[subtoken]
+
+ # Possible if this subtoken is a prefix of another token.
+ if count < min_count:
+ continue
+
+ # Ignore alphabet/reserved tokens, which will be added manually later.
+ if subtoken not in alphabet and subtoken not in reserved_tokens:
+ subtoken_candidates.append((count, subtoken))
+
+ # Decrement count of the subtoken's prefixes (if a longer subtoken is
+ # added, its prefixes lose priority to be added).
+ for end in xrange(1, subtoken_len):
+ subtoken_counts[subtoken[:end]] -= count
+
+ # Add alphabet subtokens (guarantees that all strings are encodable).
+ subtoken_candidates.extend((subtoken_counts.get(a, 0), a) for a in alphabet)
+
+ # Order subtoken candidates by decreasing count.
+ subtoken_list = [t for _, t in sorted(subtoken_candidates, reverse=True)]
+
+ # Add reserved tokens to beginning of the list.
+ subtoken_list = reserved_tokens + subtoken_list
+ return subtoken_list, max_subtoken_length
+
+
+def _generate_subtokens(token_counts,
+ alphabet,
+ min_count,
+ num_iterations=4,
+ reserved_tokens=None):
+ """Create a list of subtokens in decreasing order of frequency.
+
+ Args:
+ token_counts: dict mapping str tokens -> int count
+ alphabet: set of characters
+ min_count: int minimum number of times a subtoken must appear before it is
+ added to the vocabulary.
+ num_iterations: int number of iterations to generate new tokens.
+ reserved_tokens: list of tokens that will be added to the beginning to the
+ returned subtoken list.
+
+ Returns:
+ Sorted list of subtokens (most frequent first)
+ """
+ if reserved_tokens is None:
+ reserved_tokens = RESERVED_TOKENS
+
+ # Use alphabet set to create initial list of subtokens
+ subtoken_list = reserved_tokens + list(alphabet)
+ max_subtoken_length = 1
+
+ # On each iteration, segment all words using the subtokens defined in
+ # subtoken_dict, count how often the resulting subtokens appear, and update
+ # the dictionary with subtokens w/ high enough counts.
+ for i in xrange(num_iterations):
+ logging.info("\tGenerating subtokens: iteration %d", i)
+ # Generate new subtoken->id dictionary using the new subtoken list.
+ subtoken_dict = _list_to_index_dict(subtoken_list)
+
+ # Create dict mapping subtoken->count, with additional subtokens created
+ # from substrings taken from the tokens.
+ subtoken_counts = _count_and_gen_subtokens(token_counts, alphabet,
+ subtoken_dict,
+ max_subtoken_length)
+
+ # Generate new list of subtokens sorted by subtoken count.
+ subtoken_list, max_subtoken_length = _gen_new_subtoken_list(
+ subtoken_counts, min_count, alphabet, reserved_tokens)
+
+ logging.info("\tVocab size: %d", len(subtoken_list))
+ return subtoken_list
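+
+# Illustrative end-to-end usage sketch (added for clarity; not in the upstream
+# file). The file names and sizes below are hypothetical:
+#   subtokenizer = Subtokenizer.init_from_files(
+#       "vocab.subtokens", ["corpus.txt"],
+#       target_vocab_size=32768, threshold=327)
+#   ids = subtokenizer.encode("testing 123")
+#   text = subtokenizer.decode(ids)  # -> "testing 123"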
diff --git a/models/official/nlp/transformer/utils/tokenizer_test.py b/models/official/nlp/transformer/utils/tokenizer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..307398fd3aeaf55a5bec495006a1fb65ebadd639
--- /dev/null
+++ b/models/official/nlp/transformer/utils/tokenizer_test.py
@@ -0,0 +1,204 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test Subtokenizer and string helper methods."""
+
+import collections
+import tempfile
+
+import tensorflow as tf
+
+from official.nlp.transformer.utils import tokenizer
+
+
+class SubtokenizerTest(tf.test.TestCase):
+
+ def _init_subtokenizer(self, vocab_list):
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
+ with tf.io.gfile.GFile(temp_file.name, "w") as w:
+ for subtoken in vocab_list:
+ w.write("'%s'" % subtoken)
+ w.write("\n")
+ return tokenizer.Subtokenizer(temp_file.name, reserved_tokens=[])
+
+ def test_encode(self):
+ vocab_list = ["123_", "test", "ing_"]
+ subtokenizer = self._init_subtokenizer(vocab_list)
+ s = "testing 123"
+ encoded_list = subtokenizer.encode(s)
+ self.assertEqual([1, 2, 0], encoded_list)
+
+ def test_decode(self):
+ vocab_list = ["123_", "test", "ing_"]
+ subtokenizer = self._init_subtokenizer(vocab_list)
+ encoded_list = [1, 2, 0] # testing 123
+ decoded_str = subtokenizer.decode(encoded_list)
+ self.assertEqual("testing 123", decoded_str)
+
+ def test_subtoken_ids_to_tokens(self):
+ vocab_list = ["123_", "test", "ing_"]
+ subtokenizer = self._init_subtokenizer(vocab_list)
+ encoded_list = [1, 2, 0] # testing 123
+ token_list = subtokenizer._subtoken_ids_to_tokens(encoded_list)
+ self.assertEqual([u"testing", u"123"], token_list)
+
+
+class StringHelperTest(tf.test.TestCase):
+
+ def test_split_string_to_tokens(self):
+ text = "test? testing 123."
+
+ tokens = tokenizer._split_string_to_tokens(text,
+ tokenizer._ALPHANUMERIC_CHAR_SET)
+ self.assertEqual(["test", "? ", "testing", "123", "."], tokens)
+
+ def test_join_tokens_to_string(self):
+ tokens = ["test", "? ", "testing", "123", "."]
+
+ s = tokenizer._join_tokens_to_string(tokens,
+ tokenizer._ALPHANUMERIC_CHAR_SET)
+ self.assertEqual("test? testing 123.", s)
+
+ def test_escape_token(self):
+ token = u"abc_\\4"
+ alphabet = set("abc_\\u;")
+
+ escaped_token = tokenizer._escape_token(token, alphabet)
+ self.assertEqual("abc\\u\\\\\\52;_", escaped_token)
+
+ def test_unescape_token(self):
+ escaped_token = u"Underline: \\u, Backslash: \\\\, Unicode: \\52;"
+
+ unescaped_token = tokenizer._unescape_token(escaped_token)
+ self.assertEqual("Underline: _, Backslash: \\, Unicode: 4", unescaped_token)
+
+ def test_list_to_index_dict(self):
+ lst = ["test", "strings"]
+
+ d = tokenizer._list_to_index_dict(lst)
+ self.assertDictEqual({"test": 0, "strings": 1}, d)
+
+ def test_split_token_to_subtokens(self):
+ token = "abc"
+ subtoken_dict = {"a": 0, "b": 1, "c": 2, "ab": 3}
+ max_subtoken_length = 2
+
+ subtokens = tokenizer._split_token_to_subtokens(token, subtoken_dict,
+ max_subtoken_length)
+ self.assertEqual(["ab", "c"], subtokens)
+
+ def test_generate_alphabet_dict(self):
+ s = ["testing", "123"]
+ reserved_tokens = ["???"]
+
+ alphabet = tokenizer._generate_alphabet_dict(s, reserved_tokens)
+ self.assertIn("?", alphabet)
+ self.assertIn("t", alphabet)
+ self.assertIn("e", alphabet)
+ self.assertIn("s", alphabet)
+ self.assertIn("i", alphabet)
+ self.assertIn("n", alphabet)
+ self.assertIn("g", alphabet)
+ self.assertIn("1", alphabet)
+ self.assertIn("2", alphabet)
+ self.assertIn("3", alphabet)
+
+ def test_count_and_gen_subtokens(self):
+ token_counts = {"abc": 5}
+ alphabet = set("abc_")
+ subtoken_dict = {"a": 0, "b": 1, "c": 2, "_": 3}
+ max_subtoken_length = 2
+
+ subtoken_counts = tokenizer._count_and_gen_subtokens(
+ token_counts, alphabet, subtoken_dict, max_subtoken_length)
+
+ self.assertIsInstance(subtoken_counts, collections.defaultdict)
+ self.assertDictEqual(
+ {
+ "a": 5,
+ "b": 5,
+ "c": 5,
+ "_": 5,
+ "ab": 5,
+ "bc": 5,
+ "c_": 5,
+ "abc": 5,
+ "bc_": 5,
+ "abc_": 5
+ }, subtoken_counts)
+
+ def test_filter_and_bucket_subtokens(self):
+ subtoken_counts = collections.defaultdict(int, {
+ "a": 2,
+ "b": 4,
+ "c": 1,
+ "ab": 6,
+ "ac": 3,
+ "abbc": 5
+ })
+ min_count = 3
+
+ subtoken_buckets = tokenizer._filter_and_bucket_subtokens(
+ subtoken_counts, min_count)
+
+ self.assertEqual(len(subtoken_buckets[0]), 0)
+ self.assertEqual(set("b"), subtoken_buckets[1])
+ self.assertEqual(set(["ab", "ac"]), subtoken_buckets[2])
+ self.assertEqual(len(subtoken_buckets[3]), 0)
+ self.assertEqual(set(["abbc"]), subtoken_buckets[4])
+
+ def test_gen_new_subtoken_list(self):
+ subtoken_counts = collections.defaultdict(int, {
+ "translate": 10,
+ "t": 40,
+ "tr": 16,
+ "tra": 12
+ })
+ min_count = 5
+ alphabet = set("translate")
+ reserved_tokens = ["reserved", "tokens"]
+
+ subtoken_list, max_token_length = tokenizer._gen_new_subtoken_list(
+ subtoken_counts, min_count, alphabet, reserved_tokens)
+
+ # Check that "tra" isn't in the list (its count should be decremented to 2,
+ # so it should not be added to the candidate list).
+ self.assertNotIn("tra", subtoken_list)
+
+ self.assertIn("tr", subtoken_list)
+ self.assertIn("t", subtoken_list)
+
+ self.assertEqual(len("translate"), max_token_length)
+
+ def test_generate_subtokens(self):
+ token_counts = {"ab": 1, "bc": 3, "abc": 5}
+ alphabet = set("abc_")
+ min_count = 100
+ num_iterations = 1
+ reserved_tokens = ["reserved", "tokens"]
+
+ vocab_list = tokenizer._generate_subtokens(token_counts, alphabet,
+ min_count, num_iterations,
+ reserved_tokens)
+
+ # Check that reserved tokens are at the front of the list
+ self.assertEqual(vocab_list[:2], reserved_tokens)
+
+ # Check that each character in alphabet is in the vocab list
+ for c in alphabet:
+ self.assertIn(c, vocab_list)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/nlp/xlnet/README.md b/models/official/nlp/xlnet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9675f01a57fd26a83ed5103e116257b3664396cb
--- /dev/null
+++ b/models/official/nlp/xlnet/README.md
@@ -0,0 +1,16 @@
+# XLNet: Generalized Autoregressive Pretraining for Language Understanding
+
+The academic paper which describes XLNet in detail and provides full results on
+a number of tasks can be found here: https://arxiv.org/abs/1906.08237.
+
+**Instructions and user guide will be added soon.**
+
+XLNet is a generalized autoregressive BERT-like pretraining language model that
+enables learning bidirectional contexts by maximizing the expected likelihood
+over all permutations of the factorization order. It can learn dependency beyond
+a fixed length without disrupting temporal coherence by using the segment-level
+recurrence mechanism and relative positional encoding scheme introduced in
+[Transformer-XL](https://arxiv.org/pdf/1901.02860.pdf). XLNet outperforms BERT
+on 20 NLP benchmark tasks and achieves state-of-the-art results on 18 tasks
+including question answering, natural language inference, sentiment analysis,
+and document ranking.
diff --git a/models/official/nlp/xlnet/__init__.py b/models/official/nlp/xlnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/models/official/nlp/xlnet/__init__.py
@@ -0,0 +1 @@
+
diff --git a/models/official/nlp/xlnet/classifier_utils.py b/models/official/nlp/xlnet/classifier_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..64363e322633f7ae43d6ffc65c99ee1beff36827
--- /dev/null
+++ b/models/official/nlp/xlnet/classifier_utils.py
@@ -0,0 +1,162 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for pre-processing classification data."""
+from absl import logging
+
+from official.nlp.xlnet import data_utils
+
+SEG_ID_A = 0
+SEG_ID_B = 1
+
+
+class PaddingInputExample(object):
+ """Fake example so the num input examples is a multiple of the batch size.
+
+ When running eval/predict on the TPU, we need to pad the number of examples
+ to be a multiple of the batch size, because the TPU requires a fixed batch
+ size. The alternative is to drop the last batch, which is bad because it means
+ the entire output data won't be generated.
+ We use this class instead of `None` because treating `None` as padding
+ batches could cause silent errors.
+ """
+
+
+class InputFeatures(object):
+ """A single set of features of data."""
+
+ def __init__(self,
+ input_ids,
+ input_mask,
+ segment_ids,
+ label_id,
+ is_real_example=True):
+ self.input_ids = input_ids
+ self.input_mask = input_mask
+ self.segment_ids = segment_ids
+ self.label_id = label_id
+ self.is_real_example = is_real_example
+
+
+def _truncate_seq_pair(tokens_a, tokens_b, max_length):
+ """Truncates a sequence pair in place to the maximum length."""
+
+ # This is a simple heuristic which will always truncate the longer sequence
+ # one token at a time. This makes more sense than truncating an equal percent
+ # of tokens from each, since if one sequence is very short then each token
+ # that's truncated likely contains more information than a longer sequence.
+ while True:
+ total_length = len(tokens_a) + len(tokens_b)
+ if total_length <= max_length:
+ break
+ if len(tokens_a) > len(tokens_b):
+ tokens_a.pop()
+ else:
+ tokens_b.pop()
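+
+# Illustrative example (added for clarity; not in the upstream file): with
+# max_length=6, the longer list is popped from until the pair fits.
+#   tokens_a = ["a1", "a2", "a3", "a4", "a5"]
+#   tokens_b = ["b1", "b2", "b3"]
+#   _truncate_seq_pair(tokens_a, tokens_b, 6)
+#   # tokens_a -> ["a1", "a2", "a3"], tokens_b -> ["b1", "b2", "b3"]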
+
+
+def convert_single_example(example_index, example, label_list, max_seq_length,
+ tokenize_fn, use_bert_format):
+ """Converts a single `InputExample` into a single `InputFeatures`."""
+
+ if isinstance(example, PaddingInputExample):
+ return InputFeatures(
+ input_ids=[0] * max_seq_length,
+ input_mask=[1] * max_seq_length,
+ segment_ids=[0] * max_seq_length,
+ label_id=0,
+ is_real_example=False)
+
+ if label_list is not None:
+ label_map = {}
+ for (i, label) in enumerate(label_list):
+ label_map[label] = i
+
+ tokens_a = tokenize_fn(example.text_a)
+ tokens_b = None
+ if example.text_b:
+ tokens_b = tokenize_fn(example.text_b)
+
+ if tokens_b:
+ # Modifies `tokens_a` and `tokens_b` in place so that the total
+ # length is less than the specified length.
+ # Account for two [SEP] & one [CLS] with "- 3"
+ _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
+ else:
+ # Account for one [SEP] & one [CLS] with "- 2"
+ if len(tokens_a) > max_seq_length - 2:
+ tokens_a = tokens_a[:max_seq_length - 2]
+
+ tokens = []
+ segment_ids = []
+ for token in tokens_a:
+ tokens.append(token)
+ segment_ids.append(SEG_ID_A)
+ tokens.append(data_utils.SEP_ID)
+ segment_ids.append(SEG_ID_A)
+
+ if tokens_b:
+ for token in tokens_b:
+ tokens.append(token)
+ segment_ids.append(SEG_ID_B)
+ tokens.append(data_utils.SEP_ID)
+ segment_ids.append(SEG_ID_B)
+
+ if use_bert_format:
+ tokens.insert(0, data_utils.CLS_ID)
+ segment_ids.insert(0, data_utils.SEG_ID_CLS)
+ else:
+ tokens.append(data_utils.CLS_ID)
+ segment_ids.append(data_utils.SEG_ID_CLS)
+
+ input_ids = tokens
+
+ # The mask has 0 for real tokens and 1 for padding tokens. Only real
+ # tokens are attended to.
+ input_mask = [0] * len(input_ids)
+
+ # Zero-pad up to the sequence length.
+ if len(input_ids) < max_seq_length:
+ delta_len = max_seq_length - len(input_ids)
+ if use_bert_format:
+ input_ids = input_ids + [0] * delta_len
+ input_mask = input_mask + [1] * delta_len
+ segment_ids = segment_ids + [data_utils.SEG_ID_PAD] * delta_len
+ else:
+ input_ids = [0] * delta_len + input_ids
+ input_mask = [1] * delta_len + input_mask
+ segment_ids = [data_utils.SEG_ID_PAD] * delta_len + segment_ids
+
+ assert len(input_ids) == max_seq_length
+ assert len(input_mask) == max_seq_length
+ assert len(segment_ids) == max_seq_length
+
+ if label_list is not None:
+ label_id = label_map[example.label]
+ else:
+ label_id = example.label
+ if example_index < 5:
+ logging.info("*** Example ***")
+ logging.info("guid: %s", (example.guid))
+ logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
+ logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
+ logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
+ logging.info("label: %d (id = %d)", example.label, label_id)
+
+ feature = InputFeatures(
+ input_ids=input_ids,
+ input_mask=input_mask,
+ segment_ids=segment_ids,
+ label_id=label_id)
+ return feature
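+
+# Illustrative note (added for clarity; not in the upstream file): in the XLNet
+# layout (use_bert_format=False) the CLS id is appended after the last segment
+# and padding is prepended, while in the BERT layout the CLS id is inserted at
+# the front and padding is appended. In both layouts input_mask is 0 for real
+# tokens and 1 for padding.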
diff --git a/models/official/nlp/xlnet/common_flags.py b/models/official/nlp/xlnet/common_flags.py
new file mode 100644
index 0000000000000000000000000000000000000000..93d9499f19475b96095c409fb20a5efb35f3f9b5
--- /dev/null
+++ b/models/official/nlp/xlnet/common_flags.py
@@ -0,0 +1,146 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Common flags used in XLNet model."""
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+from absl import flags
+
+flags.DEFINE_string("master", default=None, help="master")
+flags.DEFINE_string(
+ "tpu",
+ default=None,
+ help="The Cloud TPU to use for training. This should be "
+ "either the name used when creating the Cloud TPU, or a "
+ "url like grpc://ip.address.of.tpu:8470.")
+flags.DEFINE_bool(
+ "use_tpu", default=True, help="Use TPUs rather than plain CPUs.")
+flags.DEFINE_string("tpu_topology", "2x2", help="TPU topology.")
+flags.DEFINE_integer(
+ "num_core_per_host", default=8, help="number of cores per host")
+
+flags.DEFINE_string("model_dir", default=None, help="Estimator model_dir.")
+flags.DEFINE_string(
+ "init_checkpoint",
+ default=None,
+ help="Checkpoint path for initializing the model.")
+flags.DEFINE_bool(
+ "init_from_transformerxl",
+ default=False,
+ help="Init from a transformerxl model checkpoint. Otherwise, init from the "
+ "entire model checkpoint.")
+
+# Optimization config
+flags.DEFINE_float("learning_rate", default=1e-4, help="Maximum learning rate.")
+flags.DEFINE_float("clip", default=1.0, help="Gradient clipping value.")
+flags.DEFINE_float("weight_decay_rate", default=0.0, help="Weight decay rate.")
+
+# lr decay
+flags.DEFINE_integer(
+ "warmup_steps", default=0, help="Number of steps for linear lr warmup.")
+flags.DEFINE_float("adam_epsilon", default=1e-8, help="Adam epsilon.")
+flags.DEFINE_float(
+ "lr_layer_decay_rate",
+ default=1.0,
+ help="Top layer: lr[L] = FLAGS.learning_rate."
+ "Lower layers: lr[l-1] = lr[l] * lr_layer_decay_rate.")
+flags.DEFINE_float(
+ "min_lr_ratio", default=0.0, help="Minimum ratio learning rate.")
+
+# Training config
+flags.DEFINE_integer(
+ "train_batch_size",
+ default=16,
+ help="Size of the train batch across all hosts.")
+flags.DEFINE_integer(
+ "train_steps", default=100000, help="Total number of training steps.")
+flags.DEFINE_integer(
+ "iterations", default=1000, help="Number of iterations per repeat loop.")
+
+# Data config
+flags.DEFINE_integer(
+ "seq_len", default=0, help="Sequence length for pretraining.")
+flags.DEFINE_integer(
+ "reuse_len",
+ default=0,
+ help="How many tokens to be reused in the next batch. "
+ "Could be half of `seq_len`.")
+flags.DEFINE_bool("uncased", False, help="Use uncased inputs or not.")
+flags.DEFINE_bool(
+ "bi_data",
+ default=False,
+ help="Use bidirectional data streams, "
+ "i.e., forward & backward.")
+flags.DEFINE_integer("n_token", 32000, help="Vocab size")
+
+# Model config
+flags.DEFINE_integer("mem_len", default=0, help="Number of steps to cache")
+flags.DEFINE_bool("same_length", default=False, help="Same length attention")
+flags.DEFINE_integer("clamp_len", default=-1, help="Clamp length")
+
+flags.DEFINE_integer("n_layer", default=6, help="Number of layers.")
+flags.DEFINE_integer("d_model", default=32, help="Dimension of the model.")
+flags.DEFINE_integer("d_embed", default=32, help="Dimension of the embeddings.")
+flags.DEFINE_integer("n_head", default=4, help="Number of attention heads.")
+flags.DEFINE_integer(
+ "d_head", default=8, help="Dimension of each attention head.")
+flags.DEFINE_integer(
+ "d_inner",
+ default=32,
+ help="Dimension of inner hidden size in positionwise "
+ "feed-forward.")
+flags.DEFINE_float("dropout", default=0.1, help="Dropout rate.")
+flags.DEFINE_float("dropout_att", default=0.1, help="Attention dropout rate.")
+flags.DEFINE_bool("untie_r", default=False, help="Untie r_w_bias and r_r_bias")
+flags.DEFINE_string(
+ "ff_activation",
+ default="relu",
+ help="Activation type used in position-wise feed-forward.")
+flags.DEFINE_string(
+ "strategy_type",
+ default="tpu",
+ help="Activation type used in position-wise feed-forward.")
+flags.DEFINE_bool("use_bfloat16", False, help="Whether to use bfloat16.")
+
+# Parameter initialization
+flags.DEFINE_enum(
+ "init_method",
+ default="normal",
+ enum_values=["normal", "uniform"],
+ help="Initialization method.")
+flags.DEFINE_float(
+ "init_std", default=0.02, help="Initialization std when init is normal.")
+flags.DEFINE_float(
+ "init_range", default=0.1, help="Initialization std when init is uniform.")
+
+flags.DEFINE_integer(
+ "test_data_size", default=12048, help="Number of test data samples.")
+flags.DEFINE_string(
+ "train_tfrecord_path",
+ default=None,
+ help="Path to preprocessed training set tfrecord.")
+flags.DEFINE_string(
+ "test_tfrecord_path",
+ default=None,
+ help="Path to preprocessed test set tfrecord.")
+flags.DEFINE_integer(
+ "test_batch_size",
+ default=16,
+ help="Size of the test batch across all hosts.")
+flags.DEFINE_integer(
+ "save_steps", default=1000, help="Number of steps for saving checkpoint.")
+FLAGS = flags.FLAGS
diff --git a/models/official/nlp/xlnet/data_utils.py b/models/official/nlp/xlnet/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1dfe5e7ffb06ff8d38c11271b5758db48c4c4cb
--- /dev/null
+++ b/models/official/nlp/xlnet/data_utils.py
@@ -0,0 +1,816 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities used for data preparation."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import collections
+import json
+import os
+from absl import logging
+
+import numpy as np
+import tensorflow as tf
+
+
+special_symbols = {
+ "<unk>": 0,
+ "<s>": 1,
+ "</s>": 2,
+ "<cls>": 3,
+ "<sep>": 4,
+ "<pad>": 5,
+ "<mask>": 6,
+ "<eod>": 7,
+ "<eop>": 8,
+}
+
+VOCAB_SIZE = 32000
+UNK_ID = special_symbols["<unk>"]
+CLS_ID = special_symbols["<cls>"]
+SEP_ID = special_symbols["<sep>"]
+MASK_ID = special_symbols["<mask>"]
+EOD_ID = special_symbols["<eod>"]
+SEG_ID_P = 0
+SEG_ID_Q = 1
+SEG_ID_CLS = 2
+SEG_ID_PAD = 3
+
+
+OnlineMaskingConfig = collections.namedtuple("OnlineMaskingConfig", [
+ "sample_strategy", "max_num_tokens", "min_num_tokens", "max_num_words",
+ "min_num_words"])
+
+
+def file_based_input_fn_builder(input_file, name_to_features, batch_size,
+ is_training):
+ """Creates an `input_fn` closure."""
+
+ logging.info("Input tfrecord file %s", input_file)
+
+ def _decode_record(record, name_to_features):
+ """Decodes a record to a TensorFlow example."""
+ example = tf.io.parse_single_example(record, name_to_features)
+
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+ # So cast all int64 to int32.
+ for name in list(example.keys()):
+ t = example[name]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ example[name] = t
+
+ return example
+
+ def input_fn():
+ """Returns dataset for training/evaluation."""
+ num_threads = 8
+ if isinstance(input_file, str):
+ d = tf.data.TFRecordDataset(input_file)
+ # For training, we want a lot of parallel reading and shuffling.
+ # For eval, we want no shuffling and parallel reading doesn't matter.
+ if is_training:
+ d = d.shuffle(2048)
+ d = d.repeat()
+ else:
+ cycle_length = min(num_threads, len(input_file))
+ d = tf.data.Dataset.from_tensor_slices(input_file)
+ # file level shuffle
+ d = d.shuffle(len(input_file)).repeat()
+
+ d = d.interleave(
+ tf.data.TFRecordDataset,
+ sloppy=is_training,
+ cycle_length=cycle_length)
+
+ if is_training:
+ # sample level shuffle
+ d = d.shuffle(buffer_size=2048)
+ d = d.map(
+ lambda record: _decode_record(record, name_to_features),
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ d = d.batch(batch_size, drop_remainder=is_training)
+
+ # When `input_file` is a path to a single file or a list
+ # containing a single path, disable auto sharding so that
+ # the same input file is sent to all workers.
+ if isinstance(input_file, str) or len(input_file) == 1:
+ options = tf.data.Options()
+ options.experimental_distribute.auto_shard_policy = (
+ tf.data.experimental.AutoShardPolicy.OFF)
+ d = d.with_options(options)
+
+ d = d.prefetch(tf.data.experimental.AUTOTUNE)
+ return d
+
+ return input_fn
+
+
+def create_classification_dataset(file_path, seq_length, batch_size,
+ is_training):
+ """Creates input dataset from (tf)records files for pretraining."""
+ name_to_features = {
+ "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
+ "input_mask": tf.io.FixedLenFeature([seq_length], tf.float32),
+ "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
+ "label_ids": tf.io.FixedLenFeature([], tf.int64),
+ "is_real_example": tf.io.FixedLenFeature([], tf.int64),
+ }
+
+ input_fn = file_based_input_fn_builder(file_path, name_to_features,
+ batch_size, is_training)
+ dataset = input_fn()
+ return dataset
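+
+# Illustrative usage sketch (added for clarity; not in the upstream file). The
+# file path, sequence length and batch size below are hypothetical:
+#   ds = create_classification_dataset(
+#       "data/train.tf_record", seq_length=128, batch_size=32,
+#       is_training=True)
+#   # Each element is a dict with "input_ids", "input_mask", "segment_ids",
+#   # "label_ids" and "is_real_example", batched along the first dimension.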
+
+
+def create_squad_dataset(file_path, seq_length, batch_size, is_training):
+ """Creates input dataset from (tf)records files for pretraining."""
+ name_to_features = {
+ "unique_ids": tf.io.FixedLenFeature([], tf.int64),
+ "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
+ "input_mask": tf.io.FixedLenFeature([seq_length], tf.float32),
+ "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
+ "cls_index": tf.io.FixedLenFeature([], tf.int64),
+ "p_mask": tf.io.FixedLenFeature([seq_length], tf.float32)
+ }
+
+ if is_training:
+ name_to_features["start_positions"] = tf.io.FixedLenFeature([], tf.int64)
+ name_to_features["end_positions"] = tf.io.FixedLenFeature([], tf.int64)
+ name_to_features["is_impossible"] = tf.io.FixedLenFeature([], tf.float32)
+
+ input_fn = file_based_input_fn_builder(file_path, name_to_features,
+ batch_size, is_training)
+ dataset = input_fn()
+ return dataset
+
+
+def get_input_iterator(input_fn, strategy):
+ """Returns distributed dataset iterator."""
+
+ # When training with TPU pods, datasets need to be cloned across workers.
+ # Since a Dataset instance cannot be cloned in eager mode, we instead pass a
+ # callable that returns a dataset.
+ input_data = input_fn()
+ if callable(input_data):
+ iterator = iter(
+ strategy.experimental_distribute_datasets_from_function(input_data))
+ else:
+ iterator = iter(strategy.experimental_distribute_dataset(input_data))
+ return iterator
+
+
+def get_classification_input_data(batch_size, seq_len, strategy, is_training,
+ file_path):
+ """Returns input dataset from input file string."""
+
+ # When using TPU pods, we need to clone the dataset across workers and pass
+ # in a function that returns the dataset, rather than passing the dataset
+ # instance itself.
+ use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+ if use_dataset_fn:
+ if batch_size % strategy.num_replicas_in_sync != 0:
+ raise ValueError(
+ "Batch size must be divisible by number of replicas : {}".format(
+ strategy.num_replicas_in_sync))
+
+ # As auto rebatching is not supported in
+ # `experimental_distribute_datasets_from_function()` API, which is
+ # required when cloning dataset to multiple workers in eager mode,
+ # we use per-replica batch size.
+ batch_size = int(batch_size / strategy.num_replicas_in_sync)
+
+ def _dataset_fn(ctx=None):
+ del ctx
+
+ train_dataset = create_classification_dataset(
+ file_path=file_path,
+ seq_length=seq_len,
+ batch_size=batch_size,
+ is_training=is_training)
+ return train_dataset
+
+ return _dataset_fn if use_dataset_fn else _dataset_fn()
+
+
+def get_squad_input_data(batch_size, seq_len, q_len, strategy, is_training,
+ file_path):
+ """Returns input dataset from input file string."""
+
+ # When using TPU pods, we need to clone the dataset across workers and pass
+ # in a function that returns the dataset, rather than passing the dataset
+ # instance itself.
+ use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+ if use_dataset_fn:
+ if batch_size % strategy.num_replicas_in_sync != 0:
+ raise ValueError(
+ "Batch size must be divisible by number of replicas : {}".format(
+ strategy.num_replicas_in_sync))
+
+ # As auto rebatching is not supported in
+ # `experimental_distribute_datasets_from_function()` API, which is
+ # required when cloning dataset to multiple workers in eager mode,
+ # we use per-replica batch size.
+ batch_size = int(batch_size / strategy.num_replicas_in_sync)
+
+ if is_training:
+ input_glob = os.path.join(
+ file_path,
+ "spiece.model.*.slen-{}.qlen-{}.train.tf_record".format(seq_len, q_len))
+
+ global_input_paths = tf.io.gfile.glob(input_glob)
+ else:
+ global_input_paths = file_path
+
+ def _dataset_fn(ctx=None):
+ del ctx
+
+ train_dataset = create_squad_dataset(
+ file_path=global_input_paths,
+ seq_length=seq_len,
+ batch_size=batch_size,
+ is_training=is_training)
+ return train_dataset
+
+ return _dataset_fn if use_dataset_fn else _dataset_fn()
+
+
+def _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len, num_predict):
+ """Turn beg and end indices into actual mask."""
+ non_func_mask = tf.logical_and(
+ tf.not_equal(inputs, SEP_ID),
+ tf.not_equal(inputs, CLS_ID))
+ all_indices = tf.where(
+ non_func_mask,
+ tf.range(tgt_len, dtype=tf.int64),
+ tf.constant(-1, shape=[tgt_len], dtype=tf.int64))
+ candidate_matrix = tf.cast(
+ tf.logical_and(
+ all_indices[None, :] >= beg_indices[:, None],
+ all_indices[None, :] < end_indices[:, None]),
+ tf.float32)
+ cumsum_matrix = tf.reshape(
+ tf.cumsum(tf.reshape(candidate_matrix, [-1])),
+ [-1, tgt_len])
+ masked_matrix = tf.cast(cumsum_matrix <= num_predict, tf.float32)
+ target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
+ is_masked = tf.cast(target_mask, tf.bool)
+
+ return is_masked, target_mask
+
+
+def _word_span_mask(inputs, tgt_len, num_predict, min_num_words,
+ max_num_words, boundary):
+ """Sample whole word spans as prediction targets."""
+ # Note: 1.2 is the token-to-word ratio
+ mask_alpha = tgt_len / num_predict / 1.2
+ round_to_int = lambda x: tf.cast(tf.round(x), tf.int64)
+
+ # Sample span lengths from a zipf distribution
+ span_len_seq = np.arange(min_num_words, max_num_words + 1)
+ probs = np.array([1.0 / (i + 1) for i in span_len_seq])
+ probs /= np.sum(probs)
+ logits = tf.constant(np.log(probs), dtype=tf.float32)
+
+ # Sample `num_predict` words here: note that this is over sampling
+ span_lens = tf.random.categorical(
+ logits=logits[None],
+ num_samples=num_predict,
+ dtype=tf.int64,
+ )[0] + min_num_words
+
+ # Sample the ratio [0.0, 1.0) of left context lengths
+ span_lens_float = tf.cast(span_lens, tf.float32)
+ left_ratio = tf.random.uniform(shape=[num_predict], minval=0.0, maxval=1.0)
+ left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
+
+ left_ctx_len = round_to_int(left_ctx_len)
+ right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
+
+ beg_indices = (tf.cumsum(left_ctx_len) +
+ tf.cumsum(right_offset, exclusive=True))
+ end_indices = beg_indices + span_lens
+
+ # Remove out of range indices
+ max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int64)
+ valid_idx_mask = end_indices < max_boundary_index
+ beg_indices = tf.boolean_mask(beg_indices, valid_idx_mask)
+ end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
+
+ beg_indices = tf.gather(boundary, beg_indices)
+ end_indices = tf.gather(boundary, end_indices)
+
+ # Shuffle valid indices
+ num_valid = tf.cast(tf.shape(beg_indices)[0], tf.int64)
+ order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int64))
+ beg_indices = tf.gather(beg_indices, order)
+ end_indices = tf.gather(end_indices, order)
+
+ return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
+ num_predict)
+
+
+def _token_span_mask(inputs, tgt_len, num_predict, min_num_tokens,
+ max_num_tokens):
+ """Sample token spans as prediction targets."""
+ mask_alpha = tgt_len / num_predict
+ round_to_int = lambda x: tf.cast(tf.round(x), tf.int64)
+
+ # Sample span lengths from a zipf distribution
+ span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
+ probs = np.array([1.0 / (i + 1) for i in span_len_seq])
+
+ probs /= np.sum(probs)
+ logits = tf.constant(np.log(probs), dtype=tf.float32)
+ span_lens = tf.random.categorical(
+ logits=logits[None],
+ num_samples=num_predict,
+ dtype=tf.int64,
+ )[0] + min_num_tokens
+
+ # Sample the ratio [0.0, 1.0) of left context lengths
+ span_lens_float = tf.cast(span_lens, tf.float32)
+ left_ratio = tf.random.uniform(shape=[num_predict], minval=0.0, maxval=1.0)
+ left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
+ left_ctx_len = round_to_int(left_ctx_len)
+
+ # Compute the offset from left start to the right end
+ right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
+
+ # Get the actual begin and end indices
+ beg_indices = (tf.cumsum(left_ctx_len) +
+ tf.cumsum(right_offset, exclusive=True))
+ end_indices = beg_indices + span_lens
+
+ # Remove out of range indices
+ valid_idx_mask = end_indices < tgt_len
+ beg_indices = tf.boolean_mask(beg_indices, valid_idx_mask)
+ end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
+
+ # Shuffle valid indices
+ num_valid = tf.cast(tf.shape(beg_indices)[0], tf.int64)
+ order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int64))
+ beg_indices = tf.gather(beg_indices, order)
+ end_indices = tf.gather(end_indices, order)
+
+ return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
+ num_predict)
+
+
+def _whole_word_mask(inputs, tgt_len, num_predict, boundary):
+ """Sample whole words as prediction targets."""
+ pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
+ cand_pair_indices = tf.random.shuffle(pair_indices)[:num_predict]
+ beg_indices = cand_pair_indices[:, 0]
+ end_indices = cand_pair_indices[:, 1]
+
+ return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
+ num_predict)
+
+
+def _single_token_mask(inputs, tgt_len, num_predict):
+ """Sample individual tokens as prediction targets."""
+ all_indices = tf.range(tgt_len, dtype=tf.int64)
+ non_func_mask = tf.logical_and(
+ tf.not_equal(inputs, SEP_ID),
+ tf.not_equal(inputs, CLS_ID))
+ non_func_indices = tf.boolean_mask(all_indices, non_func_mask)
+
+ masked_pos = tf.random.shuffle(non_func_indices)
+ masked_pos = tf.sort(masked_pos[:num_predict])
+ target_mask = tf.sparse_to_dense(
+ sparse_indices=masked_pos,
+ output_shape=[tgt_len],
+ sparse_values=1.0,
+ default_value=0.0)
+
+ is_masked = tf.cast(target_mask, tf.bool)
+
+ return is_masked, target_mask
+
+
+def _online_sample_masks(inputs, tgt_len, num_predict, online_masking_config,
+ boundary=None):
+ """Sample target positions to predict."""
+ logging.info("Online sample with strategy: `%s`.",
+ online_masking_config.sample_strategy)
+ if online_masking_config.sample_strategy == "single_token":
+ return _single_token_mask(inputs, tgt_len, num_predict)
+ elif online_masking_config.sample_strategy == "whole_word":
+ assert boundary is not None, "whole word sampling requires `boundary`"
+ return _whole_word_mask(inputs, tgt_len, num_predict, boundary)
+ elif online_masking_config.sample_strategy == "token_span":
+ return _token_span_mask(inputs, tgt_len, num_predict,
+ online_masking_config.min_num_tokens,
+ online_masking_config.max_num_tokens)
+ elif online_masking_config.sample_strategy == "word_span":
+ assert boundary is not None, "word span sampling requires `boundary`"
+ return _word_span_mask(inputs, tgt_len, num_predict,
+ online_masking_config.min_num_words,
+ online_masking_config.max_num_words,
+ boundary)
+ else:
+ raise NotImplementedError
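+
+# Illustrative usage sketch (added for clarity; not in the upstream file). The
+# field values below are hypothetical:
+#   config = OnlineMaskingConfig(sample_strategy="token_span",
+#                                max_num_tokens=5, min_num_tokens=1,
+#                                max_num_words=5, min_num_words=1)
+#   is_masked, target_mask = _online_sample_masks(
+#       inputs, tgt_len=512, num_predict=85, online_masking_config=config)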
+
+
+def create_pretrain_dataset(file_names,
+ bsz_per_core,
+ seq_len,
+ reuse_len,
+ perm_size,
+ leak_ratio,
+ online_masking_config,
+ num_predict=None,
+ input_pipeline_context=None):
+ """Creates pretrain dataset."""
+
+ def parser(record):
+ """Function used to parse tfrecord."""
+
+ record_spec = {
+ "input": tf.io.FixedLenFeature([seq_len], tf.int64),
+ "seg_id": tf.io.FixedLenFeature([seq_len], tf.int64),
+ "label": tf.io.FixedLenFeature([1], tf.int64),
+ }
+
+ if online_masking_config.sample_strategy in ["whole_word", "word_span"]:
+ logging.info("Add `boundary` spec for %s",
+ online_masking_config.sample_strategy)
+ record_spec["boundary"] = tf.io.VarLenFeature(tf.int64)
+
+ # retrieve serialized example
+ example = tf.io.parse_single_example(
+ serialized=record, features=record_spec)
+
+ inputs = example.pop("input")
+ if online_masking_config.sample_strategy in ["whole_word", "word_span"]:
+ boundary = tf.sparse.to_dense(example.pop("boundary"))
+ else:
+ boundary = None
+ is_masked, _ = _online_sample_masks(
+ inputs, seq_len, num_predict, online_masking_config, boundary=boundary)
+
+ if reuse_len > 0:
+ ##### Use memory
+ # permute the reuse and non-reuse parts separately
+ non_reuse_len = seq_len - reuse_len
+ assert reuse_len % perm_size == 0 and non_reuse_len % perm_size == 0
+
+ # Creates permutation mask and target mask for the first reuse_len tokens.
+ # The tokens in this part are reused from the last sequence.
+ perm_mask_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
+ inputs[:reuse_len], is_masked[:reuse_len], perm_size, reuse_len,
+ leak_ratio)
+
+ # Creates permutation mask and target mask for the rest of the tokens in the
+ # current example, which are the concatenation of two new segments.
+ perm_mask_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
+ inputs[reuse_len:], is_masked[reuse_len:], perm_size, non_reuse_len,
+ leak_ratio)
+
+ perm_mask_0 = tf.concat(
+ [perm_mask_0, tf.ones([reuse_len, non_reuse_len])], axis=1)
+ perm_mask_1 = tf.concat(
+ [tf.zeros([non_reuse_len, reuse_len]), perm_mask_1], axis=1)
+ perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
+ target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
+ input_k = tf.concat([input_k_0, input_k_1], axis=0)
+ input_q = tf.concat([input_q_0, input_q_1], axis=0)
+ else:
+ ##### Do not use memory
+ assert seq_len % perm_size == 0
+ # permute the entire sequence together
+ perm_mask, target_mask, input_k, input_q = _local_perm(
+ inputs, is_masked, perm_size, seq_len, leak_ratio)
+
+ # reshape back to fixed shape
+ example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
+ example["input_k"] = tf.reshape(input_k, [seq_len])
+ example["input_q"] = tf.reshape(input_q, [seq_len])
+
+ # Directly use raw inputs as the target
+ target = inputs
+
+ if num_predict is not None:
+ indices = tf.range(seq_len, dtype=tf.int64)
+ bool_target_mask = tf.cast(target_mask, tf.bool)
+ indices = tf.boolean_mask(indices, bool_target_mask)
+
+ ##### extra padding due to CLS/SEP introduced after prepro
+ actual_num_predict = tf.shape(indices)[0]
+ pad_len = num_predict - actual_num_predict
+
+ ##### target_mapping
+ target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32)
+ paddings = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype)
+ target_mapping = tf.concat([target_mapping, paddings], axis=0)
+ example["target_mapping"] = tf.reshape(target_mapping,
+ [num_predict, seq_len])
+
+ ##### target
+ target = tf.boolean_mask(target, bool_target_mask)
+ paddings = tf.zeros([pad_len], dtype=target.dtype)
+ target = tf.concat([target, paddings], axis=0)
+ example["target"] = tf.reshape(target, [num_predict])
+
+ ##### target mask
+ target_mask = tf.concat(
+ [tf.ones([actual_num_predict], dtype=tf.float32),
+ tf.zeros([pad_len], dtype=tf.float32)],
+ axis=0)
+ example["target_mask"] = tf.reshape(target_mask, [num_predict])
+ else:
+ example["target"] = tf.reshape(target, [seq_len])
+ example["target_mask"] = tf.reshape(target_mask, [seq_len])
+
+ for key in list(example.keys()):
+ val = example[key]
+ if tf.keras.backend.is_sparse(val):
+ val = tf.sparse.to_dense(val)
+ if val.dtype == tf.int64:
+ val = tf.cast(val, tf.int32)
+
+ example[key] = val
+
+ for k, v in example.items():
+ logging.info("%s: %s", k, v)
+
+ return example
+
+ dataset = parse_files_to_dataset(
+ parser=parser,
+ file_paths=file_names,
+ bsz_per_core=bsz_per_core,
+ sequential=reuse_len > 0,
+ input_pipeline_context=input_pipeline_context)
+
+ return dataset
+
+
+def format_filename(prefix, suffix, bsz_per_host, seq_len, reuse_len=None,
+ uncased=False):
+ """Generates input file name pattern."""
+ if reuse_len is not None and reuse_len > 0:
+ reuse_str = "reuse-{}.".format(reuse_len)
+ bsz_str = "hostbsz-{}.".format(bsz_per_host)
+ else:
+ reuse_str = ""
+ bsz_str = ""
+
+ if not uncased:
+ case_str = ""
+ else:
+ case_str = "uncased."
+
+ file_name = "{}.seq-{}.{}{}{}{}".format(
+ prefix, seq_len, reuse_str, bsz_str, case_str, suffix)
+
+ return file_name
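+
+# Illustrative example (added for clarity; not in the upstream file). The
+# values below are hypothetical:
+#   format_filename("meta.train.pass-*", "json*", bsz_per_host=8,
+#                   seq_len=512, reuse_len=256, uncased=False)
+#   -> "meta.train.pass-*.seq-512.reuse-256.hostbsz-8.json*"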
+
+
+def get_pretrain_input_data(batch_size,
+ seq_len,
+ strategy,
+ file_path,
+ reuse_len,
+ perm_size,
+ leak_ratio,
+ num_predict,
+ uncased,
+ online_masking_config,
+ num_hosts=1):
+ """Returns input dataset from input file string."""
+
+ # When using TPU pods, we need to clone the dataset across workers and pass
+ # in a function that returns the dataset, rather than passing the dataset
+ # instance itself.
+ use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+ split = "train"
+ bsz_per_host = int(batch_size / num_hosts)
+ record_glob_base = format_filename(
+ prefix="meta.{}.pass-*".format(split),
+ suffix="json*",
+ bsz_per_host=bsz_per_host,
+ seq_len=seq_len,
+ reuse_len=reuse_len,
+ uncased=uncased)
+
+ def _get_num_batch(info):
+ if "num_batch" in info:
+ return info["num_batch"]
+ elif "num_example" in info:
+ return info["num_example"] / bsz_per_host
+ else:
+ raise ValueError("Do not have sample info.")
+
+ if use_dataset_fn:
+ if batch_size % strategy.num_replicas_in_sync != 0:
+ raise ValueError(
+ "Batch size must be divisible by number of replicas : {}".format(
+ strategy.num_replicas_in_sync))
+
+ # As auto rebatching is not supported by the
+ # `experimental_distribute_datasets_from_function()` API, which is required
+ # when cloning the dataset to multiple workers in eager mode, we use the
+ # per-replica batch size.
+ batch_size = int(batch_size / strategy.num_replicas_in_sync)
+
+ record_info = {"num_batch": 0, "filenames": []}
+
+ tfrecord_dirs = file_path.split(",")
+ logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)
+
+ for idx, record_dir in enumerate(tfrecord_dirs):
+ record_glob = os.path.join(record_dir, record_glob_base)
+ logging.info("[%d] Record glob: %s", idx, record_glob)
+
+ record_paths = sorted(tf.io.gfile.glob(record_glob))
+ logging.info("[%d] Num of record info path: %d", idx, len(record_paths))
+
+ cur_record_info = {"num_batch": 0, "filenames": []}
+
+ for record_info_path in record_paths:
+ with tf.io.gfile.GFile(record_info_path, "r") as fp:
+ info = json.load(fp)
+ cur_record_info["num_batch"] += int(_get_num_batch(info))
+ cur_record_info["filenames"] += info["filenames"]
+
+ # overwrite directory for `cur_record_info`
+ new_filenames = []
+ for filename in cur_record_info["filenames"]:
+ basename = os.path.basename(filename)
+ new_filename = os.path.join(record_dir, basename)
+ new_filenames.append(new_filename)
+ cur_record_info["filenames"] = new_filenames
+
+ logging.info("[Dir %d] Number of chosen batches: %s", idx,
+ cur_record_info["num_batch"])
+ logging.info("[Dir %d] Number of chosen files: %s", idx,
+ len(cur_record_info["filenames"]))
+ logging.info(cur_record_info["filenames"])
+
+ # add `cur_record_info` to global `record_info`
+ record_info["num_batch"] += cur_record_info["num_batch"]
+ record_info["filenames"] += cur_record_info["filenames"]
+
+ logging.info("Total number of batches: %d", record_info["num_batch"])
+ logging.info("Total number of files: %d", len(record_info["filenames"]))
+ logging.info(record_info["filenames"])
+
+ def _dataset_fn(ctx=None):
+ """Function that can create a pretrain dataset."""
+
+ train_dataset = create_pretrain_dataset(
+ file_names=record_info["filenames"],
+ bsz_per_core=batch_size,
+ seq_len=seq_len,
+ reuse_len=reuse_len,
+ perm_size=perm_size,
+ leak_ratio=leak_ratio,
+ online_masking_config=online_masking_config,
+ num_predict=num_predict,
+ input_pipeline_context=ctx)
+ return train_dataset
+
+ return _dataset_fn if use_dataset_fn else _dataset_fn()
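+
+# Hedged usage sketch: under TPUStrategy the returned callable is typically
+# passed to strategy.experimental_distribute_datasets_from_function(_dataset_fn)
+# (see the comment above on per-replica batching); otherwise the concrete
+# tf.data.Dataset returned here is consumed directly by the training loop.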
+
+
+def parse_files_to_dataset(parser,
+ file_paths,
+ bsz_per_core,
+ sequential,
+ input_pipeline_context=None):
+ """Creates the dataset given file paths."""
+
+ dataset = tf.data.Dataset.from_tensor_slices(file_paths)
+
+ # Note: we cannot perform sample-level shuffle here because this will violate
+ # the consecutive requirement of data stream.
+
+ if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
+ dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
+ input_pipeline_context.input_pipeline_id)
+ # file-level shuffle
+ if len(file_paths) > 1:
+ dataset = dataset.shuffle(len(file_paths))
+
+ if sequential:
+ # Note: cannot perform sample-level shuffle here because this will violate
+ # the consecutive requirement of data stream.
+ dataset = tf.data.TFRecordDataset(dataset)
+ else:
+ # `cycle_length` is the number of parallel files that get read.
+ cycle_length = min(8, len(file_paths))
+ logging.info("Interleave %d files", cycle_length)
+
+ # `sloppy` mode means that the interleaving is not exact. This adds
+ # even more randomness to the training pipeline.
+ dataset = dataset.apply(
+ tf.data.experimental.parallel_interleave(
+ tf.data.TFRecordDataset,
+ sloppy=True,
+ cycle_length=cycle_length))
+ buffer_size = 2048
+ logging.info("Perform sample-level shuffle with size %d", buffer_size)
+ dataset = dataset.shuffle(buffer_size=buffer_size)
+
+ dataset = dataset.cache().repeat().map(parser)
+ dataset = dataset.batch(bsz_per_core, drop_remainder=True)
+ dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+
+ return dataset
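+
+# A roughly equivalent sketch using the non-deprecated core interleave API
+# (hedged; not what this file uses) would be:
+#   dataset.interleave(tf.data.TFRecordDataset, cycle_length=cycle_length,
+#                      num_parallel_calls=tf.data.experimental.AUTOTUNE,
+#                      deterministic=False)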
+
+
+def _local_perm(inputs, is_masked, perm_size, seq_len, leak_ratio):
+ """Samples a permutation of the factorization order.
+
+ Creates perm_mask and target_mask accordingly.
+
+ Args:
+ inputs: int64 Tensor in shape [seq_len], input ids.
+ is_masked: bool Tensor in shape [seq_len]. True means being selected for
+ partial prediction.
+ perm_size: the length of the longest permutation. Could be set to reuse_len.
+ Should not be larger than reuse_len or there will be data leakage.
+ seq_len: int, sequence length.
+ leak_ratio: float, percent of masked tokens that are leaked.
+
+ Returns:
+ perm_mask: float32 Tensor in shape [seq_len, seq_len] consisting of 0s and
+ 1s. If perm_mask[i][j] == 1, the ith token (in original order) cannot
+ attend to the jth token (in original order). This happens only when the
+ ith token's permuted position <= the jth token's permuted position and the
+ jth token is masked or is a functional token. If perm_mask[i][j] == 0, the
+ ith token (in original order) can attend to the jth token (in original
+ order). Note that non-masked tokens can be attended to by all other
+ tokens, which is different from the description in the original paper.
+ target_mask: float32 Tensor in shape [seq_len] consisting of 0s and 1s. If
+ target_mask[i] == 1, the ith token needs to be predicted, the mask is used
+ as its input, and it counts towards the loss. If target_mask[i] == 0, the
+ token (or [SEP], [CLS]) is used as input and does not count towards the
+ loss.
+ inputs_k: int64 Tensor in shape [seq_len], input ids.
+ inputs_q: float32 Tensor in shape [seq_len], the same as target_mask.
+
+ """
+
+ # Generate permutation indices
+ index = tf.range(seq_len, dtype=tf.int64)
+ index = tf.transpose(tf.reshape(index, [-1, perm_size]))
+ index = tf.random.shuffle(index)
+ index = tf.reshape(tf.transpose(index), [-1])
+
+ # non-functional tokens
+ non_func_tokens = tf.logical_not(tf.logical_or(
+ tf.equal(inputs, SEP_ID),
+ tf.equal(inputs, CLS_ID)))
+ masked_tokens = tf.logical_and(is_masked, non_func_tokens)
+ non_masked_or_func_tokens = tf.logical_not(masked_tokens)
+
+ smallest_index = -2 * tf.ones([seq_len], dtype=tf.int64)
+
+ # Similar to BERT, randomly leak some masked tokens
+ if leak_ratio > 0:
+ leak_tokens = tf.logical_and(
+ masked_tokens,
+ tf.random.uniform([seq_len], maxval=1.0) < leak_ratio)
+ can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
+ else:
+ can_attend_self = non_masked_or_func_tokens
+ to_index = tf.where(can_attend_self, smallest_index, index)
+ from_index = tf.where(can_attend_self, to_index + 1, to_index)
+
+ # For masked tokens, can attend if i > j
+ # For context tokens, always can attend each other
+ can_attend = from_index[:, None] > to_index[None, :]
+
+ # In modeling, 1 indicates cannot attend. Hence, reverse the value here.
+ perm_mask = 1.0 - tf.cast(can_attend, tf.float32)
+
+ # Only masked tokens are included in the loss
+ target_mask = tf.cast(masked_tokens, tf.float32)
+
+ # construct inputs_k
+ inputs_k = inputs
+
+ # construct inputs_q
+ inputs_q = masked_tokens
+
+ return perm_mask, target_mask, inputs_k, inputs_q
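+
+# Worked toy example (hedged, illustration only): with seq_len=4, perm_size=4,
+# leak_ratio=0 and no masked tokens, every position can attend to every other,
+# so perm_mask and target_mask are all zeros and nothing is predicted. Masking
+# exactly one token j instead sets target_mask[j] = 1 and hides j from every
+# position (column j of perm_mask becomes all 1s) while the other columns stay
+# 0, so only token j is predicted and contributes to the loss.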
diff --git a/models/official/nlp/xlnet/optimization.py b/models/official/nlp/xlnet/optimization.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d9031647faef79c7e4f722dfeca7e3c1fd7712f
--- /dev/null
+++ b/models/official/nlp/xlnet/optimization.py
@@ -0,0 +1,102 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions and classes related to optimization (weight updates)."""
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+from absl import logging
+import tensorflow as tf
+from official.nlp import optimization
+
+
+class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
+ """Applys a warmup schedule on a given learning rate decay schedule."""
+
+ def __init__(self,
+ initial_learning_rate,
+ decay_schedule_fn,
+ warmup_steps,
+ power=1.0,
+ name=None):
+ super(WarmUp, self).__init__()
+ self.initial_learning_rate = initial_learning_rate
+ self.warmup_steps = warmup_steps
+ self.power = power
+ self.decay_schedule_fn = decay_schedule_fn
+ self.name = name
+
+ def __call__(self, step):
+ with tf.name_scope(self.name or "WarmUp") as name:
+ # Implements polynomial warmup, i.e., if global_step < warmup_steps, the
+ # learning rate will be `global_step/num_warmup_steps * init_lr`.
+ global_step_float = tf.cast(step, tf.float32)
+ warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
+ warmup_percent_done = global_step_float / warmup_steps_float
+ warmup_learning_rate = (
+ self.initial_learning_rate *
+ tf.math.pow(warmup_percent_done, self.power))
+ return tf.cond(
+ global_step_float < warmup_steps_float,
+ lambda: warmup_learning_rate,
+ lambda: self.decay_schedule_fn(step - self.warmup_steps),
+ name=name)
+
+ def get_config(self):
+ return {
+ "initial_learning_rate": self.initial_learning_rate,
+ "decay_schedule_fn": self.decay_schedule_fn,
+ "warmup_steps": self.warmup_steps,
+ "power": self.power,
+ "name": self.name
+ }
+
+
+def create_optimizer(init_lr,
+ num_train_steps,
+ num_warmup_steps,
+ min_lr_ratio=0.0,
+ adam_epsilon=1e-8,
+ weight_decay_rate=0.0):
+ """Creates an optimizer with learning rate schedule."""
+ # Implements linear decay of the learning rate.
+ learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
+ initial_learning_rate=init_lr,
+ decay_steps=num_train_steps - num_warmup_steps,
+ end_learning_rate=init_lr * min_lr_ratio)
+ if num_warmup_steps:
+ learning_rate_fn = WarmUp(
+ initial_learning_rate=init_lr,
+ decay_schedule_fn=learning_rate_fn,
+ warmup_steps=num_warmup_steps)
+ if weight_decay_rate > 0.0:
+ logging.info(
+ "Using AdamWeightDecay with adam_epsilon=%.9f weight_decay_rate=%.3f",
+ adam_epsilon, weight_decay_rate)
+ optimizer = optimization.AdamWeightDecay(
+ learning_rate=learning_rate_fn,
+ weight_decay_rate=weight_decay_rate,
+ beta_1=0.9,
+ beta_2=0.999,
+ epsilon=adam_epsilon,
+ exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
+ include_in_weight_decay=["r_s_bias", "r_r_bias", "r_w_bias"])
+ else:
+ logging.info("Using Adam with adam_epsilon=%.9f", (adam_epsilon))
+ optimizer = tf.keras.optimizers.Adam(
+ learning_rate=learning_rate_fn, epsilon=adam_epsilon)
+
+ return optimizer, learning_rate_fn
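+
+# Hedged usage sketch (values are illustrative, not from any official config):
+# optimizer, lr_fn = create_optimizer(
+#     init_lr=1e-5, num_train_steps=10000, num_warmup_steps=1000,
+#     weight_decay_rate=0.01)
+# lr_fn(500) is still in the linear warmup phase, while lr_fn(5000) follows the
+# polynomial decay towards init_lr * min_lr_ratio.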
diff --git a/models/official/nlp/xlnet/preprocess_classification_data.py b/models/official/nlp/xlnet/preprocess_classification_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b34ffef7c7ed66a87b8386e1675e14c11b0791d
--- /dev/null
+++ b/models/official/nlp/xlnet/preprocess_classification_data.py
@@ -0,0 +1,457 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to pre-process classification data into tfrecords."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import csv
+import os
+
+from absl import app
+from absl import flags
+from absl import logging
+import numpy as np
+import tensorflow as tf
+
+import sentencepiece as spm
+from official.nlp.xlnet import classifier_utils
+from official.nlp.xlnet import preprocess_utils
+
+
+flags.DEFINE_bool(
+ "overwrite_data",
+ default=False,
+ help="If False, will use cached data if available.")
+flags.DEFINE_string("output_dir", default="", help="Output dir for TF records.")
+flags.DEFINE_string(
+ "spiece_model_file", default="", help="Sentence Piece model path.")
+flags.DEFINE_string("data_dir", default="", help="Directory for input data.")
+
+# task specific
+flags.DEFINE_string("eval_split", default="dev", help="could be dev or test")
+flags.DEFINE_string("task_name", default=None, help="Task name")
+flags.DEFINE_integer(
+ "eval_batch_size", default=64, help="batch size for evaluation")
+flags.DEFINE_integer("max_seq_length", default=128, help="Max sequence length")
+flags.DEFINE_integer(
+ "num_passes",
+ default=1,
+ help="Num passes for processing training data. "
+ "This is use to batch data without loss for TPUs.")
+flags.DEFINE_bool("uncased", default=False, help="Use uncased.")
+flags.DEFINE_bool(
+ "is_regression", default=False, help="Whether it's a regression task.")
+flags.DEFINE_bool(
+ "use_bert_format",
+ default=False,
+ help="Whether to use BERT format to arrange input data.")
+
+FLAGS = flags.FLAGS
+
+
+class InputExample(object):
+ """A single training/test example for simple sequence classification."""
+
+ def __init__(self, guid, text_a, text_b=None, label=None):
+ """Constructs a InputExample.
+
+ Args:
+ guid: Unique id for the example.
+ text_a: string. The untokenized text of the first sequence. For single
+ sequence tasks, only this sequence must be specified.
+ text_b: (Optional) string. The untokenized text of the second sequence.
+ Only must be specified for sequence pair tasks.
+ label: (Optional) string. The label of the example. This should be
+ specified for train and dev examples, but not for test examples.
+ """
+ self.guid = guid
+ self.text_a = text_a
+ self.text_b = text_b
+ self.label = label
+
+
+class DataProcessor(object):
+ """Base class for data converters for sequence classification data sets."""
+
+ def get_train_examples(self, data_dir):
+ """Gets a collection of `InputExample`s for the train set."""
+ raise NotImplementedError()
+
+ def get_dev_examples(self, data_dir):
+ """Gets a collection of `InputExample`s for the dev set."""
+ raise NotImplementedError()
+
+ def get_test_examples(self, data_dir):
+ """Gets a collection of `InputExample`s for prediction."""
+ raise NotImplementedError()
+
+ def get_labels(self):
+ """Gets the list of labels for this data set."""
+ raise NotImplementedError()
+
+ @classmethod
+ def _read_tsv(cls, input_file, quotechar=None):
+ """Reads a tab separated value file."""
+ with tf.io.gfile.GFile(input_file, "r") as f:
+ reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
+ lines = []
+ for line in reader:
+ # pylint: disable=g-explicit-length-test
+ if len(line) == 0:
+ continue
+ lines.append(line)
+ return lines
+
+
+class GLUEProcessor(DataProcessor):
+ """GLUEProcessor."""
+
+ def __init__(self):
+ self.train_file = "train.tsv"
+ self.dev_file = "dev.tsv"
+ self.test_file = "test.tsv"
+ self.label_column = None
+ self.text_a_column = None
+ self.text_b_column = None
+ self.contains_header = True
+ self.test_text_a_column = None
+ self.test_text_b_column = None
+ self.test_contains_header = True
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, self.train_file)), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, self.dev_file)), "dev")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ if self.test_text_a_column is None:
+ self.test_text_a_column = self.text_a_column
+ if self.test_text_b_column is None:
+ self.test_text_b_column = self.text_b_column
+
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, self.test_file)), "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ for (i, line) in enumerate(lines):
+ if i == 0 and self.contains_header and set_type != "test":
+ continue
+ if i == 0 and self.test_contains_header and set_type == "test":
+ continue
+ guid = "%s-%s" % (set_type, i)
+
+ a_column = (
+ self.text_a_column if set_type != "test" else self.test_text_a_column)
+ b_column = (
+ self.text_b_column if set_type != "test" else self.test_text_b_column)
+
+ # there are some incomplete lines in QNLI
+ if len(line) <= a_column:
+ logging.warning("Incomplete line, ignored.")
+ continue
+ text_a = line[a_column]
+
+ if b_column is not None:
+ if len(line) <= b_column:
+ logging.warning("Incomplete line, ignored.")
+ continue
+ text_b = line[b_column]
+ else:
+ text_b = None
+
+ if set_type == "test":
+ label = self.get_labels()[0]
+ else:
+ if len(line) <= self.label_column:
+ logging.warning("Incomplete line, ignored.")
+ continue
+ label = line[self.label_column]
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+
+class Yelp5Processor(DataProcessor):
+ """Yelp5Processor."""
+
+ def get_train_examples(self, data_dir):
+ return self._create_examples(os.path.join(data_dir, "train.csv"))
+
+ def get_dev_examples(self, data_dir):
+ return self._create_examples(os.path.join(data_dir, "test.csv"))
+
+ def get_labels(self):
+ """See base class."""
+ return ["1", "2", "3", "4", "5"]
+
+ def _create_examples(self, input_file):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ with tf.io.gfile.GFile(input_file) as f:
+ reader = csv.reader(f)
+ for i, line in enumerate(reader):
+
+ label = line[0]
+ text_a = line[1].replace('""', '"').replace('\\"', '"')
+ examples.append(
+ InputExample(guid=str(i), text_a=text_a, text_b=None, label=label))
+ return examples
+
+
+class ImdbProcessor(DataProcessor):
+ """ImdbProcessor."""
+
+ def get_labels(self):
+ return ["neg", "pos"]
+
+ def get_train_examples(self, data_dir):
+ return self._create_examples(os.path.join(data_dir, "train"))
+
+ def get_dev_examples(self, data_dir):
+ return self._create_examples(os.path.join(data_dir, "test"))
+
+ def _create_examples(self, data_dir):
+ """Creates examples."""
+ examples = []
+ for label in ["neg", "pos"]:
+ cur_dir = os.path.join(data_dir, label)
+ for filename in tf.io.gfile.listdir(cur_dir):
+ if not filename.endswith("txt"):
+ continue
+
+ if len(examples) % 1000 == 0:
+ logging.info("Loading dev example %d", len(examples))
+
+ path = os.path.join(cur_dir, filename)
+ with tf.io.gfile.GFile(path) as f:
+ text = f.read().strip().replace(" ", " ")
+ examples.append(
+ InputExample(
+ guid="unused_id", text_a=text, text_b=None, label=label))
+ return examples
+
+
+class MnliMatchedProcessor(GLUEProcessor):
+ """MnliMatchedProcessor."""
+
+ def __init__(self):
+ super(MnliMatchedProcessor, self).__init__()
+ self.dev_file = "dev_matched.tsv"
+ self.test_file = "test_matched.tsv"
+ self.label_column = -1
+ self.text_a_column = 8
+ self.text_b_column = 9
+
+ def get_labels(self):
+ return ["contradiction", "entailment", "neutral"]
+
+
+class MnliMismatchedProcessor(MnliMatchedProcessor):
+
+ def __init__(self):
+ super(MnliMismatchedProcessor, self).__init__()
+ self.dev_file = "dev_mismatched.tsv"
+ self.test_file = "test_mismatched.tsv"
+
+
+class StsbProcessor(GLUEProcessor):
+ """StsbProcessor."""
+
+ def __init__(self):
+ super(StsbProcessor, self).__init__()
+ self.label_column = 9
+ self.text_a_column = 7
+ self.text_b_column = 8
+
+ def get_labels(self):
+ return [0.0]
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ for (i, line) in enumerate(lines):
+ if i == 0 and self.contains_header and set_type != "test":
+ continue
+ if i == 0 and self.test_contains_header and set_type == "test":
+ continue
+ guid = "%s-%s" % (set_type, i)
+
+ a_column = (
+ self.text_a_column if set_type != "test" else self.test_text_a_column)
+ b_column = (
+ self.text_b_column if set_type != "test" else self.test_text_b_column)
+
+ # there are some incomplete lines in QNLI
+ if len(line) <= a_column:
+ logging.warning("Incomplete line, ignored.")
+ continue
+ text_a = line[a_column]
+
+ if b_column is not None:
+ if len(line) <= b_column:
+ logging.warning("Incomplete line, ignored.")
+ continue
+ text_b = line[b_column]
+ else:
+ text_b = None
+
+ if set_type == "test":
+ label = self.get_labels()[0]
+ else:
+ if len(line) <= self.label_column:
+ logging.warning("Incomplete line, ignored.")
+ continue
+ label = float(line[self.label_column])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+
+ return examples
+
+
+def file_based_convert_examples_to_features(examples,
+ label_list,
+ max_seq_length,
+ tokenize_fn,
+ output_file,
+ num_passes=1):
+ """Convert a set of `InputExample`s to a TFRecord file."""
+
+ # do not create duplicated records
+ if tf.io.gfile.exists(output_file) and not FLAGS.overwrite_data:
+ logging.info("Do not overwrite tfrecord %s exists.", output_file)
+ return
+
+ logging.info("Create new tfrecord %s.", output_file)
+
+ writer = tf.io.TFRecordWriter(output_file)
+
+ examples *= num_passes
+
+ for (ex_index, example) in enumerate(examples):
+ if ex_index % 10000 == 0:
+ logging.info("Writing example %d of %d", ex_index, len(examples))
+
+ feature = classifier_utils.convert_single_example(ex_index, example,
+ label_list,
+ max_seq_length,
+ tokenize_fn,
+ FLAGS.use_bert_format)
+
+ def create_int_feature(values):
+ f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
+ return f
+
+ def create_float_feature(values):
+ f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
+ return f
+
+ features = collections.OrderedDict()
+ features["input_ids"] = create_int_feature(feature.input_ids)
+ features["input_mask"] = create_float_feature(feature.input_mask)
+ features["segment_ids"] = create_int_feature(feature.segment_ids)
+ if label_list is not None:
+ features["label_ids"] = create_int_feature([feature.label_id])
+ else:
+ features["label_ids"] = create_float_feature([float(feature.label_id)])
+ features["is_real_example"] = create_int_feature(
+ [int(feature.is_real_example)])
+
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+ writer.write(tf_example.SerializeToString())
+ writer.close()
+
+
+def main(_):
+ logging.set_verbosity(logging.INFO)
+ processors = {
+ "mnli_matched": MnliMatchedProcessor,
+ "mnli_mismatched": MnliMismatchedProcessor,
+ "sts-b": StsbProcessor,
+ "imdb": ImdbProcessor,
+ "yelp5": Yelp5Processor
+ }
+
+ task_name = FLAGS.task_name.lower()
+
+ if task_name not in processors:
+ raise ValueError("Task not found: %s" % (task_name))
+
+ processor = processors[task_name]()
+ label_list = processor.get_labels() if not FLAGS.is_regression else None
+
+ sp = spm.SentencePieceProcessor()
+ sp.Load(FLAGS.spiece_model_file)
+
+ def tokenize_fn(text):
+ text = preprocess_utils.preprocess_text(text, lower=FLAGS.uncased)
+ return preprocess_utils.encode_ids(sp, text)
+
+ spm_basename = os.path.basename(FLAGS.spiece_model_file)
+
+ train_file_base = "{}.len-{}.train.tf_record".format(spm_basename,
+ FLAGS.max_seq_length)
+ train_file = os.path.join(FLAGS.output_dir, train_file_base)
+ logging.info("Use tfrecord file %s", train_file)
+
+ train_examples = processor.get_train_examples(FLAGS.data_dir)
+ np.random.shuffle(train_examples)
+ logging.info("Num of train samples: %d", len(train_examples))
+
+ file_based_convert_examples_to_features(train_examples, label_list,
+ FLAGS.max_seq_length, tokenize_fn,
+ train_file, FLAGS.num_passes)
+ if FLAGS.eval_split == "dev":
+ eval_examples = processor.get_dev_examples(FLAGS.data_dir)
+ else:
+ eval_examples = processor.get_test_examples(FLAGS.data_dir)
+
+ logging.info("Num of eval samples: %d", len(eval_examples))
+
+ # TPU requires a fixed batch size for all batches, therefore the number
+ # of examples must be a multiple of the batch size, or else examples
+ # will get dropped. So we pad with fake examples which are ignored
+ # later on. These do NOT count towards the metric (all tf.metrics
+ # support a per-instance weight, and these get a weight of 0.0).
+ #
+ # Modified in XL: We also adopt the same mechanism for GPUs.
+ while len(eval_examples) % FLAGS.eval_batch_size != 0:
+ eval_examples.append(classifier_utils.PaddingInputExample())
+
+ eval_file_base = "{}.len-{}.{}.eval.tf_record".format(spm_basename,
+ FLAGS.max_seq_length,
+ FLAGS.eval_split)
+ eval_file = os.path.join(FLAGS.output_dir, eval_file_base)
+
+ file_based_convert_examples_to_features(eval_examples, label_list,
+ FLAGS.max_seq_length, tokenize_fn,
+ eval_file)
+
+
+if __name__ == "__main__":
+ app.run(main)
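+
+# Hedged CLI sketch (paths are illustrative placeholders):
+# python preprocess_classification_data.py --task_name=imdb \
+#   --data_dir=/path/to/aclImdb --output_dir=/path/to/proc_data \
+#   --spiece_model_file=/path/to/spiece.model --max_seq_length=128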
diff --git a/models/official/nlp/xlnet/preprocess_pretrain_data.py b/models/official/nlp/xlnet/preprocess_pretrain_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bf5367611ca656e88c969e4711334911e9cedd0
--- /dev/null
+++ b/models/official/nlp/xlnet/preprocess_pretrain_data.py
@@ -0,0 +1,998 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to pre-process pre-training data into tfrecords."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+import random
+
+from absl import app
+from absl import flags
+import absl.logging as _logging # pylint: disable=unused-import
+
+import numpy as np
+
+
+import tensorflow.compat.v1 as tf  # TF1-style APIs (tf.logging, tf.gfile, tf.python_io) are used below
+from official.nlp.xlnet import preprocess_utils
+import sentencepiece as spm
+
+
+special_symbols = {
+ "" : 0,
+ "" : 1,
+ "" : 2,
+ "" : 3,
+ "" : 4,
+ "" : 5,
+ "" : 6,
+ "" : 7,
+ "" : 8,
+}
+
+VOCAB_SIZE = 32000
+UNK_ID = special_symbols["<unk>"]
+CLS_ID = special_symbols["<cls>"]
+SEP_ID = special_symbols["<sep>"]
+MASK_ID = special_symbols["<mask>"]
+EOD_ID = special_symbols["<eod>"]
+
+
+def _int64_feature(values):
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
+
+
+def _float_feature(values):
+ return tf.train.Feature(float_list=tf.train.FloatList(value=values))
+
+
+def format_filename(prefix, bsz_per_host, seq_len, bi_data, suffix,
+ mask_alpha=5, mask_beta=1, reuse_len=None, uncased=False,
+ fixed_num_predict=None):
+ """docs."""
+ if reuse_len is None:
+ reuse_len_str = ""
+ else:
+ reuse_len_str = "reuse-{}.".format(reuse_len)
+ if not uncased:
+ uncased_str = ""
+ else:
+ uncased_str = "uncased."
+ if bi_data:
+ bi_data_str = "bi"
+ else:
+ bi_data_str = "uni"
+ if fixed_num_predict is not None:
+ fnp_str = "fnp-{}.".format(fixed_num_predict)
+ else:
+ fnp_str = ""
+
+ file_name = "{}.bsz-{}.seqlen-{}.{}{}{}.alpha-{}.beta-{}.{}{}".format(
+ prefix, bsz_per_host, seq_len, reuse_len_str, uncased_str, bi_data_str,
+ mask_alpha, mask_beta, fnp_str, suffix)
+
+ return file_name
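+
+# Hedged example of the generated name (values are illustrative): with
+# prefix="record_info-train-0-0", bsz_per_host=32, seq_len=512, reuse_len=256,
+# bi_data=True, mask_alpha=6, mask_beta=1, fixed_num_predict=85 and
+# suffix="json", this returns
+# "record_info-train-0-0.bsz-32.seqlen-512.reuse-256.bi.alpha-6.beta-1.fnp-85.json".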
+
+
+def _create_data(idx, input_paths):
+ # Load sentence-piece model
+ sp = spm.SentencePieceProcessor()
+ sp.Load(FLAGS.sp_path)
+
+ input_shards = []
+ total_line_cnt = 0
+ for input_path in input_paths:
+ input_data, sent_ids = [], []
+ sent_id, line_cnt = True, 0
+ tf.logging.info("Processing %s", input_path)
+ for line in tf.gfile.Open(input_path):
+ if line_cnt % 100000 == 0:
+ tf.logging.info("Loading line %d", line_cnt)
+ line_cnt += 1
+
+ if not line.strip():
+ if FLAGS.use_eod:
+ sent_id = not sent_id
+ cur_sent = [EOD_ID]
+ else:
+ continue
+ else:
+ if FLAGS.from_raw_text:
+ cur_sent = preprocess_utils.preprocess_text(
+ line.strip(), lower=FLAGS.uncased)
+ cur_sent = preprocess_utils.encode_ids(sp, cur_sent)
+ else:
+ cur_sent = list(map(int, line.strip().split()))
+
+ input_data.extend(cur_sent)
+ sent_ids.extend([sent_id] * len(cur_sent))
+ sent_id = not sent_id
+
+ tf.logging.info("Finish with line %d", line_cnt)
+ if line_cnt == 0:
+ continue
+
+ input_data = np.array(input_data, dtype=np.int64)
+ sent_ids = np.array(sent_ids, dtype=bool)
+
+ total_line_cnt += line_cnt
+ input_shards.append((input_data, sent_ids))
+
+ tf.logging.info("[Task %d] Total number line: %d", idx, total_line_cnt)
+
+ tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")
+
+ filenames, num_batch = [], 0
+
+ # Randomly shuffle input shards (with a fixed but distinct random seed)
+ np.random.seed(100 * FLAGS.task + FLAGS.pass_id)
+
+ perm_indices = np.random.permutation(len(input_shards))
+ tf.logging.info("Using perm indices %s for pass %d",
+ perm_indices.tolist(), FLAGS.pass_id)
+
+ input_data_list, sent_ids_list = [], []
+ prev_sent_id = None
+ for perm_idx in perm_indices:
+ input_data, sent_ids = input_shards[perm_idx]
+ # make sure that `sent_ids[0] == not prev_sent_id`
+ if prev_sent_id is not None and sent_ids[0] == prev_sent_id:
+ sent_ids = np.logical_not(sent_ids)
+
+ # append to temporary list
+ input_data_list.append(input_data)
+ sent_ids_list.append(sent_ids)
+
+ # update `prev_sent_id`
+ prev_sent_id = sent_ids[-1]
+
+ input_data = np.concatenate(input_data_list)
+ sent_ids = np.concatenate(sent_ids_list)
+
+ file_name, cur_num_batch = create_tfrecords(
+ save_dir=tfrecord_dir,
+ basename="{}-{}-{}".format(FLAGS.split, idx, FLAGS.pass_id),
+ data=[input_data, sent_ids],
+ bsz_per_host=FLAGS.bsz_per_host,
+ seq_len=FLAGS.seq_len,
+ bi_data=FLAGS.bi_data,
+ sp=sp,
+ )
+
+ filenames.append(file_name)
+ num_batch += cur_num_batch
+
+ record_info = {
+ "filenames": filenames,
+ "num_batch": num_batch
+ }
+
+ return record_info
+
+
+def create_data(_):
+ # Validate FLAGS
+ assert FLAGS.bsz_per_host % FLAGS.num_core_per_host == 0
+ if not FLAGS.use_tpu:
+ FLAGS.num_core_per_host = 1 # forced to be one
+
+ # Make workdirs
+ if not tf.gfile.Exists(FLAGS.save_dir):
+ tf.gfile.MakeDirs(FLAGS.save_dir)
+
+ tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")
+ if not tf.gfile.Exists(tfrecord_dir):
+ tf.gfile.MakeDirs(tfrecord_dir)
+
+ # Create and dump corpus_info from task 0
+ if FLAGS.task == 0 and FLAGS.pass_id == 0:
+ corpus_info = {
+ "vocab_size": VOCAB_SIZE,
+ "bsz_per_host": FLAGS.bsz_per_host,
+ "num_core_per_host": FLAGS.num_core_per_host,
+ "seq_len": FLAGS.seq_len,
+ "reuse_len": FLAGS.reuse_len,
+ "uncased": FLAGS.uncased,
+ "bi_data": FLAGS.bi_data,
+ "mask_alpha": FLAGS.mask_alpha,
+ "mask_beta": FLAGS.mask_beta,
+ "num_predict": FLAGS.num_predict,
+ "use_eod": FLAGS.use_eod,
+ "sp_path": FLAGS.sp_path,
+ "input_glob": FLAGS.input_glob,
+ }
+ corpus_info_path = os.path.join(FLAGS.save_dir, "corpus_info.json")
+ with tf.gfile.Open(corpus_info_path, "w") as fp:
+ json.dump(corpus_info, fp)
+
+ # Split the work into FLAGS.num_task interleaved shards
+ file_paths = sorted(tf.gfile.Glob(FLAGS.input_glob))
+ tf.logging.info("Use glob: %s", FLAGS.input_glob)
+ tf.logging.info("Find %d files: %s", len(file_paths), file_paths)
+
+ task_file_paths = file_paths[FLAGS.task::FLAGS.num_task]
+ if not task_file_paths:
+ tf.logging.info("Exit: task %d has no file to process.", FLAGS.task)
+ return
+
+ tf.logging.info("Task %d process %d files: %s",
+ FLAGS.task, len(task_file_paths), task_file_paths)
+ record_info = _create_data(FLAGS.task, task_file_paths)
+
+ record_prefix = "record_info-{}-{}-{}".format(
+ FLAGS.split, FLAGS.task, FLAGS.pass_id)
+ record_name = format_filename(
+ prefix=record_prefix,
+ bsz_per_host=FLAGS.bsz_per_host,
+ seq_len=FLAGS.seq_len,
+ mask_alpha=FLAGS.mask_alpha,
+ mask_beta=FLAGS.mask_beta,
+ reuse_len=FLAGS.reuse_len,
+ bi_data=FLAGS.bi_data,
+ suffix="json",
+ uncased=FLAGS.uncased,
+ fixed_num_predict=FLAGS.num_predict)
+ record_info_path = os.path.join(tfrecord_dir, record_name)
+
+ with tf.gfile.Open(record_info_path, "w") as fp:
+ json.dump(record_info, fp)
+
+
+def batchify(data, bsz_per_host, sent_ids=None):
+ num_step = len(data) // bsz_per_host
+ data = data[:bsz_per_host * num_step]
+ data = data.reshape(bsz_per_host, num_step)
+ if sent_ids is not None:
+ sent_ids = sent_ids[:bsz_per_host * num_step]
+ sent_ids = sent_ids.reshape(bsz_per_host, num_step)
+
+ if sent_ids is not None:
+ return data, sent_ids
+ return data
+
+
+def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False):
+ """Split two segments from `data` starting from the index `begin_idx`."""
+
+ data_len = data.shape[0]
+ if begin_idx + tot_len >= data_len:
+ tf.logging.info("[_split_a_and_b] returns None: "
+ "begin_idx %d + tot_len %d >= data_len %d",
+ begin_idx, tot_len, data_len)
+ return None
+
+ end_idx = begin_idx + 1
+ cut_points = []
+ while end_idx < data_len:
+ if sent_ids[end_idx] != sent_ids[end_idx - 1]:
+ if end_idx - begin_idx >= tot_len: break
+ cut_points.append(end_idx)
+ end_idx += 1
+
+ a_begin = begin_idx
+ if len(cut_points) == 0 or random.random() < 0.5:
+ label = 0
+ if len(cut_points) == 0:
+ a_end = end_idx
+ else:
+ a_end = random.choice(cut_points)
+
+ b_len = max(1, tot_len - (a_end - a_begin))
+ # (zihangd): `data_len - 1` to account for extend_target
+ b_begin = random.randint(0, data_len - 1 - b_len)
+ b_end = b_begin + b_len
+ while b_begin > 0 and sent_ids[b_begin - 1] == sent_ids[b_begin]:
+ b_begin -= 1
+ # (zihangd): `data_len - 1` to account for extend_target
+ while b_end < data_len - 1 and sent_ids[b_end - 1] == sent_ids[b_end]:
+ b_end += 1
+
+ new_begin = a_end
+ else:
+ label = 1
+ a_end = random.choice(cut_points)
+ b_begin = a_end
+ b_end = end_idx
+
+ new_begin = b_end
+
+ while a_end - a_begin + b_end - b_begin > tot_len:
+ if a_end - a_begin > b_end - b_begin:
+ # delete the right side only for the LM objective
+ a_end -= 1
+ else:
+ b_end -= 1
+
+ ret = [data[a_begin: a_end], data[b_begin: b_end], label, new_begin]
+
+ if extend_target:
+ if a_end >= data_len or b_end >= data_len:
+ tf.logging.info("[_split_a_and_b] returns None: "
+ "a_end %d or b_end %d >= data_len %d",
+ a_end, b_end, data_len)
+ return None
+ a_target = data[a_begin + 1: a_end + 1]
+ b_target = data[b_begin: b_end + 1]
+ ret.extend([a_target, b_target])
+
+ return ret
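+
+# Hedged note on the return value: the list is
+# [a_data, b_data, label, new_begin_idx] (plus [a_target, b_target] when
+# extend_target=True); label == 1 means segment B directly continues segment A,
+# while label == 0 means B was sampled from elsewhere in the stream.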
+
+
+def _is_start_piece(piece):
+ special_pieces = set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~'))
+ if (piece.startswith("▁") or piece.startswith("<")
+ or piece in special_pieces):
+ return True
+ else:
+ return False
+
+
+def _sample_mask(sp, seg, reverse=False, max_gram=5, goal_num_predict=None):
+ """Sample `goal_num_predict` tokens for partial prediction.
+ About `mask_beta` tokens are chosen in a context of `mask_alpha` tokens."""
+
+ seg_len = len(seg)
+ mask = np.array([False] * seg_len, dtype=bool)
+
+ num_predict = 0
+
+ ngrams = np.arange(1, max_gram + 1, dtype=np.int64)
+ pvals = 1. / np.arange(1, max_gram + 1)
+ pvals /= pvals.sum(keepdims=True)
+
+ if reverse:
+ seg = np.flip(seg, 0)
+
+ cur_len = 0
+ while cur_len < seg_len:
+ if goal_num_predict is not None and num_predict >= goal_num_predict: break
+
+ n = np.random.choice(ngrams, p=pvals)
+ if goal_num_predict is not None:
+ n = min(n, goal_num_predict - num_predict)
+ ctx_size = (n * FLAGS.mask_alpha) // FLAGS.mask_beta
+ l_ctx = np.random.choice(ctx_size)
+ r_ctx = ctx_size - l_ctx
+
+ # Find the start position of a complete token
+ beg = cur_len + l_ctx
+ while beg < seg_len and not _is_start_piece(sp.IdToPiece(seg[beg].item())):
+ beg += 1
+ if beg >= seg_len:
+ break
+
+ # Find the end position of the n-gram (start pos of the n+1-th gram)
+ end = beg + 1
+ cnt_ngram = 1
+ while end < seg_len:
+ cnt_ngram += 1
+ if cnt_ngram > n:
+ break
+ end += 1
+ if end >= seg_len:
+ break
+
+ # Update
+ mask[beg:end] = True
+ num_predict += end - beg
+
+ cur_len = end + r_ctx
+
+ while goal_num_predict is not None and num_predict < goal_num_predict:
+ i = np.random.randint(seg_len)
+ if not mask[i]:
+ mask[i] = True
+ num_predict += 1
+
+ if reverse:
+ mask = np.flip(mask, 0)
+
+ return mask
+
+
+def _sample_mask_ngram(sp, seg, reverse=False, max_gram=5,
+ goal_num_predict=None):
+ """Sample `goal_num_predict` tokens for partial prediction.
+ About `mask_beta` tokens are chosen in a context of `mask_alpha` tokens."""
+
+ seg_len = len(seg)
+ mask = np.array([False] * seg_len, dtype=bool)
+
+ num_predict = 0
+
+ ngrams = np.arange(1, max_gram + 1, dtype=np.int64)
+ pvals = 1. / np.arange(1, max_gram + 1)
+ pvals /= pvals.sum(keepdims=True)
+
+ if reverse:
+ seg = np.flip(seg, 0)
+
+ cur_len = 0
+ while cur_len < seg_len:
+ if goal_num_predict is not None and num_predict >= goal_num_predict: break
+
+ n = np.random.choice(ngrams, p=pvals)
+ if goal_num_predict is not None:
+ n = min(n, goal_num_predict - num_predict)
+ ctx_size = (n * FLAGS.mask_alpha) // FLAGS.mask_beta
+ l_ctx = np.random.choice(ctx_size)
+ r_ctx = ctx_size - l_ctx
+
+ # Find the start position of a complete token
+ beg = cur_len + l_ctx
+ while beg < seg_len and not _is_start_piece(sp.IdToPiece(seg[beg].item())):
+ beg += 1
+ if beg >= seg_len:
+ break
+
+ # Find the end position of the n-gram (start pos of the n+1-th gram)
+ end = beg
+ cnt_ngram = 0
+ while end < seg_len:
+ if _is_start_piece(sp.IdToPiece(seg[end].item())):
+ cnt_ngram += 1
+ if cnt_ngram > n:
+ break
+
+ # select current piece
+ mask[end] = True
+
+ # update the end pointer and increment num_predict
+ end += 1
+ num_predict += 1
+
+ if goal_num_predict is not None and num_predict >= goal_num_predict:
+ break
+
+ cur_len = end + r_ctx
+
+ while goal_num_predict is not None and num_predict < goal_num_predict:
+ i = np.random.randint(seg_len)
+ if not mask[i]:
+ mask[i] = True
+ num_predict += 1
+
+ if reverse:
+ mask = np.flip(mask, 0)
+
+ return mask
+
+
+def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
+ bi_data, sp):
+ data, sent_ids = data[0], data[1]
+
+ num_core = FLAGS.num_core_per_host
+ bsz_per_core = bsz_per_host // num_core
+
+ if bi_data:
+ assert bsz_per_host % (2 * FLAGS.num_core_per_host) == 0
+ fwd_data, fwd_sent_ids = batchify(data, bsz_per_host // 2, sent_ids)
+
+ fwd_data = fwd_data.reshape(num_core, 1, bsz_per_core // 2, -1)
+ fwd_sent_ids = fwd_sent_ids.reshape(num_core, 1, bsz_per_core // 2, -1)
+
+ bwd_data = fwd_data[:, :, :, ::-1]
+ bwd_sent_ids = fwd_sent_ids[:, :, :, ::-1]
+
+ data = np.concatenate(
+ [fwd_data, bwd_data], 1).reshape(bsz_per_host, -1)
+ sent_ids = np.concatenate(
+ [fwd_sent_ids, bwd_sent_ids], 1).reshape(bsz_per_host, -1)
+ else:
+ data, sent_ids = batchify(data, bsz_per_host, sent_ids)
+
+ tf.logging.info("Raw data shape %s.", data.shape)
+
+ file_name = format_filename(
+ prefix=basename,
+ bsz_per_host=bsz_per_host,
+ seq_len=seq_len,
+ bi_data=bi_data,
+ suffix="tfrecords",
+ mask_alpha=FLAGS.mask_alpha,
+ mask_beta=FLAGS.mask_beta,
+ reuse_len=FLAGS.reuse_len,
+ uncased=FLAGS.uncased,
+ fixed_num_predict=FLAGS.num_predict
+ )
+ save_path = os.path.join(save_dir, file_name)
+ record_writer = tf.python_io.TFRecordWriter(save_path)
+ tf.logging.info("Start writing %s.", save_path)
+
+ num_batch = 0
+ reuse_len = FLAGS.reuse_len
+
+ # [sep] x 2 + [cls]
+ assert reuse_len < seq_len - 3
+
+ data_len = data.shape[1]
+ sep_array = np.array([SEP_ID], dtype=np.int64)
+ cls_array = np.array([CLS_ID], dtype=np.int64)
+
+ i = 0
+ while i + seq_len <= data_len:
+ if num_batch % 500 == 0:
+ tf.logging.info("Processing batch %d", num_batch)
+
+ all_ok = True
+ features = []
+ for idx in range(bsz_per_host):
+ inp = data[idx, i: i + reuse_len]
+ tgt = data[idx, i + 1: i + reuse_len + 1]
+
+ results = _split_a_and_b(
+ data[idx],
+ sent_ids[idx],
+ begin_idx=i + reuse_len,
+ tot_len=seq_len - reuse_len - 3,
+ extend_target=True)
+ if results is None:
+ tf.logging.info("Break out with seq idx %d", i)
+ all_ok = False
+ break
+
+ # unpack the results
+ (a_data, b_data, label, _, a_target, b_target) = tuple(results)
+
+ # sample ngram spans to predict
+ reverse = bi_data and (idx // (bsz_per_core // 2)) % 2 == 1
+ if FLAGS.num_predict is None:
+ num_predict_0 = num_predict_1 = None
+ else:
+ num_predict_1 = FLAGS.num_predict // 2
+ num_predict_0 = FLAGS.num_predict - num_predict_1
+ mask_0 = _sample_mask(sp, inp, reverse=reverse,
+ goal_num_predict=num_predict_0)
+ mask_1 = _sample_mask(sp, np.concatenate([a_data, sep_array, b_data,
+ sep_array, cls_array]),
+ reverse=reverse, goal_num_predict=num_predict_1)
+
+ # concatenate data
+ cat_data = np.concatenate([inp, a_data, sep_array, b_data,
+ sep_array, cls_array])
+ seg_id = ([0] * (reuse_len + a_data.shape[0]) + [0] +
+ [1] * b_data.shape[0] + [1] + [2])
+ assert cat_data.shape[0] == seq_len
+ assert mask_0.shape[0] == seq_len // 2
+ assert mask_1.shape[0] == seq_len // 2
+
+ # the last two CLS's are not used, just for padding purposes
+ tgt = np.concatenate([tgt, a_target, b_target, cls_array, cls_array])
+ assert tgt.shape[0] == seq_len
+
+ is_masked = np.concatenate([mask_0, mask_1], 0)
+ if FLAGS.num_predict is not None:
+ assert np.sum(is_masked) == FLAGS.num_predict
+
+ feature = {
+ "input": _int64_feature(cat_data),
+ "is_masked": _int64_feature(is_masked),
+ "target": _int64_feature(tgt),
+ "seg_id": _int64_feature(seg_id),
+ "label": _int64_feature([label]),
+ }
+ features.append(feature)
+
+ if all_ok:
+ assert len(features) == bsz_per_host
+ for feature in features:
+ example = tf.train.Example(features=tf.train.Features(feature=feature))
+ record_writer.write(example.SerializeToString())
+ num_batch += 1
+ else:
+ break
+
+ i += reuse_len
+
+ record_writer.close()
+ tf.logging.info("Done writing %s. Num of batches: %d", save_path, num_batch)
+
+ return save_path, num_batch
+
+
+################
+# get_input_fn #
+################
+def _convert_example(example, use_bfloat16):
+ """Cast int64 into int32 and float32 to bfloat16 if use_bfloat16."""
+ for key in list(example.keys()):
+ val = example[key]
+ if tf.keras.backend.is_sparse(val):
+ val = tf.sparse.to_dense(val)
+ if val.dtype == tf.int64:
+ val = tf.cast(val, tf.int32)
+ if use_bfloat16 and val.dtype == tf.float32:
+ val = tf.cast(val, tf.bfloat16)
+
+ example[key] = val
+
+
+def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
+ host_id, num_core_per_host, bsz_per_core):
+ # list of file paths
+ num_files = len(file_names)
+ num_files_per_host = num_files // num_hosts
+ my_start_file_id = host_id * num_files_per_host
+ my_end_file_id = (host_id + 1) * num_files_per_host
+ if host_id == num_hosts - 1:
+ my_end_file_id = num_files
+ file_paths = file_names[my_start_file_id: my_end_file_id]
+ tf.logging.info("Host %d handles %d files", host_id, len(file_paths))
+
+ assert split == "train"
+ dataset = tf.data.Dataset.from_tensor_slices(file_paths)
+
+ # file-level shuffle
+ if len(file_paths) > 1:
+ dataset = dataset.shuffle(len(file_paths))
+
+ # Note: we cannot perform sample-level shuffle here because this will violate
+ # the consecutive requirement of data stream.
+ dataset = tf.data.TFRecordDataset(dataset)
+
+ # Note: since we are doing online preprocessing, the parsed result of the
+ # same input will be different each time. Thus, caching processed data is
+ # not helpful; it would use a lot of memory and lead to container OOM.
+ # So, cache the non-parsed raw data instead.
+ dataset = dataset.cache().map(parser).repeat()
+ dataset = dataset.batch(bsz_per_core, drop_remainder=True)
+ dataset = dataset.prefetch(num_core_per_host * bsz_per_core)
+
+ return dataset
+
+
+def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
+ """
+ Sample a permutation of the factorization order, and create an
+ attention mask accordingly.
+
+ Args:
+ inputs: int64 Tensor in shape [seq_len], input ids.
+ targets: int64 Tensor in shape [seq_len], target ids.
+ is_masked: bool Tensor in shape [seq_len]. True means being selected
+ for partial prediction.
+ perm_size: the length of the longest permutation. Could be set to reuse_len.
+ Should not be larger than reuse_len or there will be data leakage.
+ seq_len: int, sequence length.
+ """
+
+ # Generate permutation indices
+ index = tf.range(seq_len, dtype=tf.int64)
+ index = tf.transpose(tf.reshape(index, [-1, perm_size]))
+ index = tf.random_shuffle(index)
+ index = tf.reshape(tf.transpose(index), [-1])
+
+ # `perm_mask` and `target_mask`
+ # non-functional tokens
+ non_func_tokens = tf.logical_not(tf.logical_or(
+ tf.equal(inputs, SEP_ID),
+ tf.equal(inputs, CLS_ID)))
+
+ non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
+ masked_or_func_tokens = tf.logical_not(non_mask_tokens)
+
+ # Set the permutation indices of non-masked (& non-functional) tokens to the
+ # smallest index (-1):
+ # (1) they can be seen by all other positions
+ # (2) they cannot see masked positions, so there won't be an information leak
+ smallest_index = -tf.ones([seq_len], dtype=tf.int64)
+ rev_index = tf.where(non_mask_tokens, smallest_index, index)
+
+ # Create `target_mask`: non-functional and masked tokens
+ # 1: use mask as input and have loss
+ # 0: use token (or [SEP], [CLS]) as input and do not have loss
+ target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
+ target_mask = tf.cast(target_tokens, tf.float32)
+
+ # Create `perm_mask`
+ # `target_tokens` cannot see themselves
+ self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)
+
+ # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
+ # 0: can attend if i > j or j is non-masked
+ perm_mask = tf.logical_and(
+ self_rev_index[:, None] <= rev_index[None, :],
+ masked_or_func_tokens)
+ perm_mask = tf.cast(perm_mask, tf.float32)
+
+ # new target: [next token] for LM and [curr token] (self) for PLM
+ new_targets = tf.concat([inputs[0: 1], targets[: -1]],
+ axis=0)
+
+ # construct inputs_k
+ inputs_k = inputs
+
+ # construct inputs_q
+ inputs_q = target_mask
+
+ return perm_mask, new_targets, target_mask, inputs_k, inputs_q
+
+
+def get_dataset(params, num_hosts, num_core_per_host, split, file_names,
+ num_batch, seq_len, reuse_len, perm_size, mask_alpha,
+ mask_beta, use_bfloat16=False, num_predict=None):
+
+ bsz_per_core = params["batch_size"]
+ if num_hosts > 1:
+ host_id = params["context"].current_host
+ else:
+ host_id = 0
+
+ #### Function used to parse tfrecord
+ def parser(record):
+ """function used to parse tfrecord."""
+
+ record_spec = {
+ "input": tf.FixedLenFeature([seq_len], tf.int64),
+ "target": tf.FixedLenFeature([seq_len], tf.int64),
+ "seg_id": tf.FixedLenFeature([seq_len], tf.int64),
+ "label": tf.FixedLenFeature([1], tf.int64),
+ "is_masked": tf.FixedLenFeature([seq_len], tf.int64),
+ }
+
+ # retrieve serialized example
+ example = tf.parse_single_example(
+ serialized=record,
+ features=record_spec)
+
+ inputs = example.pop("input")
+ target = example.pop("target")
+ is_masked = tf.cast(example.pop("is_masked"), tf.bool)
+
+ non_reuse_len = seq_len - reuse_len
+ assert perm_size <= reuse_len and perm_size <= non_reuse_len
+
+ perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
+ inputs[:reuse_len],
+ target[:reuse_len],
+ is_masked[:reuse_len],
+ perm_size,
+ reuse_len)
+
+ perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
+ inputs[reuse_len:],
+ target[reuse_len:],
+ is_masked[reuse_len:],
+ perm_size,
+ non_reuse_len)
+
+ perm_mask_0 = tf.concat([perm_mask_0, tf.ones([reuse_len, non_reuse_len])],
+ axis=1)
+ perm_mask_1 = tf.concat([tf.zeros([non_reuse_len, reuse_len]), perm_mask_1],
+ axis=1)
+ perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
+ target = tf.concat([target_0, target_1], axis=0)
+ target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
+ input_k = tf.concat([input_k_0, input_k_1], axis=0)
+ input_q = tf.concat([input_q_0, input_q_1], axis=0)
+
+ if num_predict is not None:
+ indices = tf.range(seq_len, dtype=tf.int64)
+ bool_target_mask = tf.cast(target_mask, tf.bool)
+ indices = tf.boolean_mask(indices, bool_target_mask)
+
+ ##### extra padding due to CLS/SEP introduced after prepro
+ actual_num_predict = tf.shape(indices)[0]
+ pad_len = num_predict - actual_num_predict
+
+ ##### target_mapping
+ target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32)
+ paddings = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype)
+ target_mapping = tf.concat([target_mapping, paddings], axis=0)
+ example["target_mapping"] = tf.reshape(target_mapping,
+ [num_predict, seq_len])
+
+ ##### target
+ target = tf.boolean_mask(target, bool_target_mask)
+ paddings = tf.zeros([pad_len], dtype=target.dtype)
+ target = tf.concat([target, paddings], axis=0)
+ example["target"] = tf.reshape(target, [num_predict])
+
+ ##### target mask
+ target_mask = tf.concat(
+ [tf.ones([actual_num_predict], dtype=tf.float32),
+ tf.zeros([pad_len], dtype=tf.float32)],
+ axis=0)
+ example["target_mask"] = tf.reshape(target_mask, [num_predict])
+ else:
+ example["target"] = tf.reshape(target, [seq_len])
+ example["target_mask"] = tf.reshape(target_mask, [seq_len])
+
+ # reshape back to fixed shape
+ example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
+ example["input_k"] = tf.reshape(input_k, [seq_len])
+ example["input_q"] = tf.reshape(input_q, [seq_len])
+
+ _convert_example(example, use_bfloat16)
+
+ for k, v in example.items():
+ tf.logging.info("%s: %s", k, v)
+
+ return example
+
+ # Get dataset
+ dataset = parse_files_to_dataset(
+ parser=parser,
+ file_names=file_names,
+ split=split,
+ num_batch=num_batch,
+ num_hosts=num_hosts,
+ host_id=host_id,
+ num_core_per_host=num_core_per_host,
+ bsz_per_core=bsz_per_core)
+
+ return dataset
+
+
+def get_input_fn(
+ tfrecord_dir,
+ split,
+ bsz_per_host,
+ seq_len,
+ reuse_len,
+ bi_data,
+ num_hosts=1,
+ num_core_per_host=1,
+ perm_size=None,
+ mask_alpha=None,
+ mask_beta=None,
+ uncased=False,
+ num_passes=None,
+ use_bfloat16=False,
+ num_predict=None):
+
+ # Merge all record infos into a single one
+ record_glob_base = format_filename(
+ prefix="record_info-{}-*".format(split),
+ bsz_per_host=bsz_per_host,
+ seq_len=seq_len,
+ bi_data=bi_data,
+ suffix="json",
+ mask_alpha=mask_alpha,
+ mask_beta=mask_beta,
+ reuse_len=reuse_len,
+ uncased=uncased,
+ fixed_num_predict=num_predict)
+
+ record_info = {"num_batch": 0, "filenames": []}
+
+ tfrecord_dirs = tfrecord_dir.split(",")
+ tf.logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)
+
+ for idx, record_dir in enumerate(tfrecord_dirs):
+ record_glob = os.path.join(record_dir, record_glob_base)
+ tf.logging.info("[%d] Record glob: %s", idx, record_glob)
+
+ record_paths = sorted(tf.gfile.Glob(record_glob))
+ tf.logging.info("[%d] Num of record info path: %d",
+ idx, len(record_paths))
+
+ cur_record_info = {"num_batch": 0, "filenames": []}
+
+ for record_info_path in record_paths:
+ if num_passes is not None:
+ record_info_name = os.path.basename(record_info_path)
+ fields = record_info_name.split(".")[0].split("-")
+ pass_id = int(fields[-1])
+ if len(fields) == 5 and pass_id >= num_passes:
+ tf.logging.info("Skip pass %d: %s", pass_id, record_info_name)
+ continue
+
+ with tf.gfile.Open(record_info_path, "r") as fp:
+ info = json.load(fp)
+ if num_passes is not None:
+ eff_num_passes = min(num_passes, len(info["filenames"]))
+ ratio = eff_num_passes / len(info["filenames"])
+ cur_record_info["num_batch"] += int(info["num_batch"] * ratio)
+ cur_record_info["filenames"] += info["filenames"][:eff_num_passes]
+ else:
+ cur_record_info["num_batch"] += info["num_batch"]
+ cur_record_info["filenames"] += info["filenames"]
+
+ # overwrite directory for `cur_record_info`
+ new_filenames = []
+ for filename in cur_record_info["filenames"]:
+ basename = os.path.basename(filename)
+ new_filename = os.path.join(record_dir, basename)
+ new_filenames.append(new_filename)
+ cur_record_info["filenames"] = new_filenames
+
+ tf.logging.info("[Dir %d] Number of chosen batches: %s",
+ idx, cur_record_info["num_batch"])
+ tf.logging.info("[Dir %d] Number of chosen files: %s",
+ idx, len(cur_record_info["filenames"]))
+ tf.logging.info(cur_record_info["filenames"])
+
+ # add `cur_record_info` to global `record_info`
+ record_info["num_batch"] += cur_record_info["num_batch"]
+ record_info["filenames"] += cur_record_info["filenames"]
+
+ tf.logging.info("Total number of batches: %d",
+ record_info["num_batch"])
+ tf.logging.info("Total number of files: %d",
+ len(record_info["filenames"]))
+ tf.logging.info(record_info["filenames"])
+
+ def input_fn(params):
+ """docs."""
+ assert params["batch_size"] * num_core_per_host == bsz_per_host
+
+ dataset = get_dataset(
+ params=params,
+ num_hosts=num_hosts,
+ num_core_per_host=num_core_per_host,
+ split=split,
+ file_names=record_info["filenames"],
+ num_batch=record_info["num_batch"],
+ seq_len=seq_len,
+ reuse_len=reuse_len,
+ perm_size=perm_size,
+ mask_alpha=mask_alpha,
+ mask_beta=mask_beta,
+ use_bfloat16=use_bfloat16,
+ num_predict=num_predict)
+
+ return dataset
+
+ return input_fn, record_info
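+
+# Hedged usage sketch (values are illustrative):
+# input_fn, record_info = get_input_fn(
+#     tfrecord_dir="proc_data/example/tfrecords", split="train",
+#     bsz_per_host=32, seq_len=512, reuse_len=256, bi_data=True,
+#     perm_size=256, mask_alpha=6, mask_beta=1, num_predict=85)
+# dataset = input_fn({"batch_size": 32})  # assuming num_core_per_host == 1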
+
+
+if __name__ == "__main__":
+ FLAGS = flags.FLAGS
+ flags.DEFINE_bool("use_tpu", True, help="whether to use TPUs")
+ flags.DEFINE_integer("bsz_per_host", 32, help="batch size per host.")
+ flags.DEFINE_integer("num_core_per_host", 8, help="num TPU cores per host.")
+
+ flags.DEFINE_integer("seq_len", 512,
+ help="Sequence length.")
+ flags.DEFINE_integer("reuse_len", 256,
+ help="Number of token that can be reused as memory. "
+ "Could be half of `seq_len`.")
+ flags.DEFINE_bool("uncased", False, help="Use uncased inputs or not.")
+ flags.DEFINE_bool("bi_data", True,
+ help="whether to create bidirectional data")
+ flags.DEFINE_integer("mask_alpha", default=6,
+ help="How many tokens to form a group.")
+ flags.DEFINE_integer("mask_beta", default=1,
+ help="How many tokens to mask within each group.")
+ flags.DEFINE_bool("use_eod", True,
+ help="whether to append EOD at the end of a doc.")
+ flags.DEFINE_bool("from_raw_text", True,
+ help="Whether the input is raw text or encoded ids.")
+ flags.DEFINE_integer("num_predict", default=85,
+ help="Num of tokens to predict.")
+
+ flags.DEFINE_string("input_glob", "data/example/*.txt",
+ help="Input file glob.")
+ flags.DEFINE_string("sp_path", "", help="Path to the sentence piece model.")
+ flags.DEFINE_string("save_dir", "proc_data/example",
+ help="Directory for saving the processed data.")
+ flags.DEFINE_enum("split", "train", ["train", "dev", "test"],
+ help="Save the data as which split.")
+
+ flags.DEFINE_integer("pass_id", 0, help="ID of the current pass."
+ "Different passes sample different negative segment.")
+ flags.DEFINE_integer("num_task", 1, help="Number of total tasks.")
+ flags.DEFINE_integer("task", 0, help="The Task ID. This value is used when "
+ "using multiple workers to identify each worker.")
+
+ tf.logging.set_verbosity(tf.logging.INFO)
+ app.run(create_data)
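+
+# Hedged CLI sketch (paths are illustrative placeholders):
+# python preprocess_pretrain_data.py --input_glob="data/example/*.txt" \
+#   --sp_path=/path/to/spiece.model --save_dir=proc_data/example \
+#   --seq_len=512 --reuse_len=256 --bi_data=True --num_predict=85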
diff --git a/models/official/nlp/xlnet/preprocess_squad_data.py b/models/official/nlp/xlnet/preprocess_squad_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..59c8944697348f12b185399463978c170b4ee46b
--- /dev/null
+++ b/models/official/nlp/xlnet/preprocess_squad_data.py
@@ -0,0 +1,110 @@
+# coding=utf-8
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to pre-process SQUAD data into tfrecords."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import random
+
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf
+
+import sentencepiece as spm
+from official.nlp.xlnet import squad_utils
+
+flags.DEFINE_integer(
+ "num_proc", default=1, help="Number of preprocessing processes.")
+flags.DEFINE_integer("proc_id", default=0, help="Process id for preprocessing.")
+
+# I/O paths
+flags.DEFINE_string("output_dir", default="", help="Output dir for TF records.")
+flags.DEFINE_string(
+ "spiece_model_file", default="", help="Sentence Piece model path.")
+flags.DEFINE_string("train_file", default="", help="Path of train file.")
+flags.DEFINE_string("predict_file", default="", help="Path of prediction file.")
+
+# Data preprocessing config
+flags.DEFINE_integer("max_seq_length", default=512, help="Max sequence length")
+flags.DEFINE_integer("max_query_length", default=64, help="Max query length")
+flags.DEFINE_integer("doc_stride", default=128, help="Doc stride")
+flags.DEFINE_bool("uncased", default=False, help="Use uncased data.")
+flags.DEFINE_bool(
+ "create_train_data", default=True, help="Whether to create training data.")
+flags.DEFINE_bool(
+ "create_eval_data", default=False, help="Whether to create eval data.")
+
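+# Illustrative invocation (hypothetical paths): preprocessing can be sharded
+# across processes, e.g. one of four workers would run
+#   python preprocess_squad_data.py --num_proc=4 --proc_id=0 \
+#     --spiece_model_file=/path/to/spiece.model \
+#     --train_file=/path/to/train-v2.0.json --output_dir=/path/to/tfrecords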
+FLAGS = flags.FLAGS
+
+
+def preprocess():
+ """Preprocesses SQUAD data."""
+ sp_model = spm.SentencePieceProcessor()
+ sp_model.Load(FLAGS.spiece_model_file)
+ spm_basename = os.path.basename(FLAGS.spiece_model_file)
+ if FLAGS.create_train_data:
+ train_rec_file = os.path.join(
+ FLAGS.output_dir,
+ "{}.{}.slen-{}.qlen-{}.train.tf_record".format(spm_basename,
+ FLAGS.proc_id,
+ FLAGS.max_seq_length,
+ FLAGS.max_query_length))
+
+ logging.info("Read examples from %s", FLAGS.train_file)
+ train_examples = squad_utils.read_squad_examples(
+ FLAGS.train_file, is_training=True)
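+ # Each process keeps every `num_proc`-th example, starting at its own
+ # `proc_id`, so the training set is partitioned across workers without
+ # overlap.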
+ train_examples = train_examples[FLAGS.proc_id::FLAGS.num_proc]
+
+ # Pre-shuffle the input to avoid having to make a very large shuffle
+ # buffer in the `input_fn`.
+ random.shuffle(train_examples)
+ logging.info("Write to %s", train_rec_file)
+ train_writer = squad_utils.FeatureWriter(
+ filename=train_rec_file, is_training=True)
+ squad_utils.convert_examples_to_features(
+ examples=train_examples,
+ sp_model=sp_model,
+ max_seq_length=FLAGS.max_seq_length,
+ doc_stride=FLAGS.doc_stride,
+ max_query_length=FLAGS.max_query_length,
+ is_training=True,
+ output_fn=train_writer.process_feature,
+ uncased=FLAGS.uncased)
+ train_writer.close()
+ if FLAGS.create_eval_data:
+ eval_examples = squad_utils.read_squad_examples(
+ FLAGS.predict_file, is_training=False)
+ squad_utils.create_eval_data(spm_basename, sp_model, eval_examples,
+ FLAGS.max_seq_length, FLAGS.max_query_length,
+ FLAGS.doc_stride, FLAGS.uncased,
+ FLAGS.output_dir)
+
+
+def main(_):
+ logging.set_verbosity(logging.INFO)
+
+ if not tf.io.gfile.exists(FLAGS.output_dir):
+ tf.io.gfile.mkdir(FLAGS.output_dir)
+
+ preprocess()
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/models/official/nlp/xlnet/preprocess_utils.py b/models/official/nlp/xlnet/preprocess_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0e8ae8398111ae73185a4594f1ab9d7dac7dd38
--- /dev/null
+++ b/models/official/nlp/xlnet/preprocess_utils.py
@@ -0,0 +1,125 @@
+# coding=utf-8
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for pre-processing."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import unicodedata
+
+import six
+
+
+SPIECE_UNDERLINE = '▁'
+
+
+def printable_text(text):
+ """Returns text encoded in a way suitable for print or `tf.logging`."""
+
+ # These functions want `str` for both Python2 and Python3, but in one case
+ # it's a Unicode string and in the other it's a byte string.
+ if six.PY3:
+ if isinstance(text, str):
+ return text
+ elif isinstance(text, bytes):
+ return text.decode('utf-8', 'ignore')
+ else:
+ raise ValueError('Unsupported string type: %s' % (type(text)))
+ elif six.PY2:
+ if isinstance(text, str):
+ return text
+ elif isinstance(text, unicode):
+ return text.encode('utf-8')
+ else:
+ raise ValueError('Unsupported string type: %s' % (type(text)))
+ else:
+ raise ValueError('Not running on Python2 or Python 3?')
+
+
+def print_(*args):
+ new_args = []
+ for arg in args:
+ if isinstance(arg, list):
+ s = [printable_text(i) for i in arg]
+ s = ' '.join(s)
+ new_args.append(s)
+ else:
+ new_args.append(printable_text(arg))
+ print(*new_args)
+
+
+def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False):
+ """Preprocesses texts."""
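+ # Illustrative: preprocess_text("  Héllo ``world''  ", lower=True) returns
+ # 'hello "world"' -- whitespace collapsed, quotes normalized, accents
+ # stripped (the default keep_accents=False) and the result lower-cased.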
+ if remove_space:
+ outputs = ' '.join(inputs.strip().split())
+ else:
+ outputs = inputs
+
+ outputs = outputs.replace('``', '"').replace("''", '"')
+
+ if six.PY2 and isinstance(outputs, str):
+ outputs = outputs.decode('utf-8')
+
+ if not keep_accents:
+ outputs = unicodedata.normalize('NFKD', outputs)
+ outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])
+ if lower:
+ outputs = outputs.lower()
+
+ return outputs
+
+
+def encode_pieces(sp_model, text, return_unicode=True, sample=False):
+ """Encodes pieces."""
+ # return_unicode is used only for py2
+
+ if six.PY2 and isinstance(text, unicode):
+ text = text.encode('utf-8')
+
+ if not sample:
+ pieces = sp_model.EncodeAsPieces(text)
+ else:
+ pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
+ new_pieces = []
+ for piece in pieces:
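+ # SentencePiece sometimes glues a trailing comma onto a number (e.g. '9,');
+ # re-encode the digits alone so the comma ends up as a separate piece.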
+ if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
+ cur_pieces = sp_model.EncodeAsPieces(
+ piece[:-1].replace(SPIECE_UNDERLINE, ''))
+ if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
+ if len(cur_pieces[0]) == 1:
+ cur_pieces = cur_pieces[1:]
+ else:
+ cur_pieces[0] = cur_pieces[0][1:]
+ cur_pieces.append(piece[-1])
+ new_pieces.extend(cur_pieces)
+ else:
+ new_pieces.append(piece)
+
+ # note(zhiliny): convert back to unicode for py2
+ if six.PY2 and return_unicode:
+ ret_pieces = []
+ for piece in new_pieces:
+ if isinstance(piece, str):
+ piece = piece.decode('utf-8')
+ ret_pieces.append(piece)
+ new_pieces = ret_pieces
+
+ return new_pieces
+
+
+def encode_ids(sp_model, text, sample=False):
+ pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample)
+ ids = [sp_model.PieceToId(piece) for piece in pieces]
+ return ids
diff --git a/models/official/nlp/xlnet/run_classifier.py b/models/official/nlp/xlnet/run_classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..79a27f244d87617ea3cb34913154e7725cc94b1f
--- /dev/null
+++ b/models/official/nlp/xlnet/run_classifier.py
@@ -0,0 +1,196 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""XLNet classification finetuning runner in tf2.0."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import functools
+from absl import app
+from absl import flags
+from absl import logging
+
+import numpy as np
+import tensorflow as tf
+# pylint: disable=unused-import
+from official.nlp.xlnet import common_flags
+from official.nlp.xlnet import data_utils
+from official.nlp.xlnet import optimization
+from official.nlp.xlnet import training_utils
+from official.nlp.xlnet import xlnet_config
+from official.nlp.xlnet import xlnet_modeling as modeling
+from official.utils.misc import tpu_lib
+
+flags.DEFINE_integer("n_class", default=2, help="Number of classes.")
+flags.DEFINE_string(
+ "summary_type",
+ default="last",
+ help="Method used to summarize a sequence into a vector.")
+
+FLAGS = flags.FLAGS
+
+
+def get_classificationxlnet_model(model_config,
+ run_config,
+ n_class,
+ summary_type="last"):
+ model = modeling.ClassificationXLNetModel(
+ model_config, run_config, n_class, summary_type, name="model")
+ return model
+
+
+def run_evaluation(strategy,
+ test_input_fn,
+ eval_steps,
+ model,
+ step,
+ eval_summary_writer=None):
+ """Run evaluation for classification task.
+
+ Args:
+ strategy: distribution strategy.
+ test_input_fn: input function for evaluation data.
+ eval_steps: total number of evaluation steps.
+ model: keras model object.
+ step: current train step.
+ eval_summary_writer: summary writer used to record evaluation metrics.
+ Because the validation set is padded with fake examples, a mask is used to
+ exclude them when computing accuracy. Since this produces dynamic-shape
+ tensors, logits, labels and masks are first collected from the TPU and the
+ accuracy is then computed locally with numpy.
+
+ Returns:
+ A float metric, accuracy.
+ """
+
+ def _test_step_fn(inputs):
+ """Replicated validation step."""
+
+ inputs["mems"] = None
+ _, logits = model(inputs, training=False)
+ return logits, inputs["label_ids"], inputs["is_real_example"]
+
+ @tf.function
+ def _run_evaluation(test_iterator):
+ """Runs validation steps."""
+ logits, labels, masks = strategy.run(
+ _test_step_fn, args=(next(test_iterator),))
+ return logits, labels, masks
+
+ test_iterator = data_utils.get_input_iterator(test_input_fn, strategy)
+ correct = 0
+ total = 0
+ for _ in range(eval_steps):
+ logits, labels, masks = _run_evaluation(test_iterator)
+ logits = strategy.experimental_local_results(logits)
+ labels = strategy.experimental_local_results(labels)
+ masks = strategy.experimental_local_results(masks)
+ merged_logits = []
+ merged_labels = []
+ merged_masks = []
+
+ for i in range(strategy.num_replicas_in_sync):
+ merged_logits.append(logits[i].numpy())
+ merged_labels.append(labels[i].numpy())
+ merged_masks.append(masks[i].numpy())
+ merged_logits = np.vstack(np.array(merged_logits))
+ merged_labels = np.hstack(np.array(merged_labels))
+ merged_masks = np.hstack(np.array(merged_masks))
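+ # Padded "fake" eval examples carry mask 0; keep only real examples when
+ # counting correct predictions.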
+ real_index = np.where(np.equal(merged_masks, 1))
+ correct += np.sum(
+ np.equal(
+ np.argmax(merged_logits[real_index], axis=-1),
+ merged_labels[real_index]))
+ total += np.shape(real_index)[-1]
+ accuracy = float(correct) / float(total)
+ logging.info("Train step: %d / acc = %d/%d = %f", step, correct, total,
+ accuracy)
+ if eval_summary_writer:
+ with eval_summary_writer.as_default():
+ tf.summary.scalar("eval_acc", float(correct) / float(total), step=step)
+ eval_summary_writer.flush()
+ return accuracy
+
+
+def get_metric_fn():
+ train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy(
+ "acc", dtype=tf.float32)
+ return train_acc_metric
+
+
+def main(unused_argv):
+ del unused_argv
+ if FLAGS.strategy_type == "mirror":
+ strategy = tf.distribute.MirroredStrategy()
+ elif FLAGS.strategy_type == "tpu":
+ cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
+ strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
+ else:
+ raise ValueError("The distribution strategy type is not supported: %s" %
+ FLAGS.strategy_type)
+ if strategy:
+ logging.info("***** Number of cores used : %d",
+ strategy.num_replicas_in_sync)
+ train_input_fn = functools.partial(data_utils.get_classification_input_data,
+ FLAGS.train_batch_size, FLAGS.seq_len,
+ strategy, True, FLAGS.train_tfrecord_path)
+ test_input_fn = functools.partial(data_utils.get_classification_input_data,
+ FLAGS.test_batch_size, FLAGS.seq_len,
+ strategy, False, FLAGS.test_tfrecord_path)
+
+ total_training_steps = FLAGS.train_steps
+ steps_per_loop = FLAGS.iterations
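+ # The test set is padded with fake examples (see run_evaluation) so that its
+ # size divides evenly by the test batch size.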
+ eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
+ eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
+ eval_steps)
+ optimizer, learning_rate_fn = optimization.create_optimizer(
+ FLAGS.learning_rate,
+ total_training_steps,
+ FLAGS.warmup_steps,
+ adam_epsilon=FLAGS.adam_epsilon)
+ model_config = xlnet_config.XLNetConfig(FLAGS)
+ run_config = xlnet_config.create_run_config(True, False, FLAGS)
+ model_fn = functools.partial(get_classificationxlnet_model, model_config,
+ run_config, FLAGS.n_class, FLAGS.summary_type)
+ input_meta_data = {}
+ input_meta_data["d_model"] = FLAGS.d_model
+ input_meta_data["mem_len"] = FLAGS.mem_len
+ input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
+ strategy.num_replicas_in_sync)
+ input_meta_data["n_layer"] = FLAGS.n_layer
+ input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
+ input_meta_data["n_class"] = FLAGS.n_class
+
+ training_utils.train(
+ strategy=strategy,
+ model_fn=model_fn,
+ input_meta_data=input_meta_data,
+ eval_fn=eval_fn,
+ metric_fn=get_metric_fn,
+ train_input_fn=train_input_fn,
+ init_checkpoint=FLAGS.init_checkpoint,
+ init_from_transformerxl=FLAGS.init_from_transformerxl,
+ total_training_steps=total_training_steps,
+ steps_per_loop=steps_per_loop,
+ optimizer=optimizer,
+ learning_rate_fn=learning_rate_fn,
+ model_dir=FLAGS.model_dir,
+ save_steps=FLAGS.save_steps)
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/models/official/nlp/xlnet/run_pretrain.py b/models/official/nlp/xlnet/run_pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..e136f4d12ab01d0b48c0d0765b8e3e8bbf8eedd7
--- /dev/null
+++ b/models/official/nlp/xlnet/run_pretrain.py
@@ -0,0 +1,156 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""XLNet pretraining runner in tf2.0."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import functools
+import os
+
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf
+# pylint: disable=unused-import
+from official.nlp.xlnet import common_flags
+from official.nlp.xlnet import data_utils
+from official.nlp.xlnet import optimization
+from official.nlp.xlnet import training_utils
+from official.nlp.xlnet import xlnet_config
+from official.nlp.xlnet import xlnet_modeling as modeling
+from official.utils.misc import tpu_lib
+
+flags.DEFINE_integer(
+ "num_predict",
+ default=None,
+ help="Number of tokens to predict in partial prediction.")
+
+# FLAGS for pretrain input preprocessing
+flags.DEFINE_integer("perm_size", 0, help="Window size of permutation.")
+flags.DEFINE_float("leak_ratio", default=0.1,
+ help="Percent of masked tokens that are leaked.")
+
+flags.DEFINE_enum("sample_strategy", default="token_span",
+ enum_values=["single_token", "whole_word", "token_span",
+ "word_span"],
+ help="Strategy used to sample prediction targets.")
+flags.DEFINE_integer("max_num_tokens", default=5,
+ help="Maximum number of tokens to sample in a span. "
+ "Effective when the token_span strategy is used.")
+flags.DEFINE_integer("min_num_tokens", default=1,
+ help="Minimum number of tokens to sample in a span. "
+ "Effective when the token_span strategy is used.")
+
+flags.DEFINE_integer("max_num_words", default=5,
+ help="Maximum number of whole words to sample in a span. "
+ "Effective when the word_span strategy is used.")
+flags.DEFINE_integer("min_num_words", default=1,
+ help="Minimum number of whole words to sample in a span. "
+ "Effective when the word_span strategy is used.")
+FLAGS = flags.FLAGS
+
+
+def get_pretrainxlnet_model(model_config, run_config):
+ return modeling.PretrainingXLNetModel(
+ use_proj=True,
+ xlnet_config=model_config,
+ run_config=run_config,
+ name="model")
+
+
+def main(unused_argv):
+ del unused_argv
+ num_hosts = 1
+ if FLAGS.strategy_type == "mirror":
+ strategy = tf.distribute.MirroredStrategy()
+ elif FLAGS.strategy_type == "tpu":
+ cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
+ strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
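+ # A TPU topology string "NxM" describes N*M chips, each with two cores,
+ # hence the factor of 2 below.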
+ topology = FLAGS.tpu_topology.split("x")
+ total_num_core = 2 * int(topology[0]) * int(topology[1])
+ num_hosts = total_num_core // FLAGS.num_core_per_host
+ else:
+ raise ValueError("The distribution strategy type is not supported: %s" %
+ FLAGS.strategy_type)
+ if strategy:
+ logging.info("***** Number of cores used : %d",
+ strategy.num_replicas_in_sync)
+ logging.info("***** Number of hosts used : %d", num_hosts)
+ online_masking_config = data_utils.OnlineMaskingConfig(
+ sample_strategy=FLAGS.sample_strategy,
+ max_num_tokens=FLAGS.max_num_tokens,
+ min_num_tokens=FLAGS.min_num_tokens,
+ max_num_words=FLAGS.max_num_words,
+ min_num_words=FLAGS.min_num_words)
+
+ train_input_fn = functools.partial(
+ data_utils.get_pretrain_input_data, FLAGS.train_batch_size, FLAGS.seq_len,
+ strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len, FLAGS.perm_size,
+ FLAGS.leak_ratio, FLAGS.num_predict, FLAGS.uncased, online_masking_config,
+ num_hosts)
+
+ total_training_steps = FLAGS.train_steps
+
+ steps_per_loop = FLAGS.iterations
+
+ optimizer, learning_rate_fn = optimization.create_optimizer(
+ init_lr=FLAGS.learning_rate,
+ num_train_steps=total_training_steps,
+ num_warmup_steps=FLAGS.warmup_steps,
+ min_lr_ratio=FLAGS.min_lr_ratio,
+ adam_epsilon=FLAGS.adam_epsilon,
+ weight_decay_rate=FLAGS.weight_decay_rate)
+
+ model_config = xlnet_config.XLNetConfig(FLAGS)
+ run_config = xlnet_config.create_run_config(True, False, FLAGS)
+ input_meta_data = {}
+ input_meta_data["d_model"] = FLAGS.d_model
+ input_meta_data["mem_len"] = FLAGS.mem_len
+ input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
+ strategy.num_replicas_in_sync)
+ input_meta_data["n_layer"] = FLAGS.n_layer
+ input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
+ model_fn = functools.partial(get_pretrainxlnet_model, model_config,
+ run_config)
+
+ model = training_utils.train(
+ strategy=strategy,
+ model_fn=model_fn,
+ input_meta_data=input_meta_data,
+ eval_fn=None,
+ metric_fn=None,
+ train_input_fn=train_input_fn,
+ init_checkpoint=FLAGS.init_checkpoint,
+ init_from_transformerxl=FLAGS.init_from_transformerxl,
+ total_training_steps=total_training_steps,
+ steps_per_loop=steps_per_loop,
+ optimizer=optimizer,
+ learning_rate_fn=learning_rate_fn,
+ model_dir=FLAGS.model_dir,
+ save_steps=FLAGS.save_steps)
+
+ # Export transformer-xl model checkpoint to be used in finetuning.
+ checkpoint = tf.train.Checkpoint(transformer_xl=model.transformerxl_model)
+ saved_path = checkpoint.save(
+ os.path.join(FLAGS.model_dir, "pretrained/transformer_xl.ckpt"))
+ logging.info("Exporting the transformer-xl model as a new TF checkpoint: %s",
+ saved_path)
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/models/official/nlp/xlnet/run_squad.py b/models/official/nlp/xlnet/run_squad.py
new file mode 100644
index 0000000000000000000000000000000000000000..013893f1a289bb446dd67f33d9178903f706b2c8
--- /dev/null
+++ b/models/official/nlp/xlnet/run_squad.py
@@ -0,0 +1,304 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""XLNet SQUAD finetuning runner in tf2.0."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import functools
+import json
+import os
+import pickle
+
+from absl import app
+from absl import flags
+from absl import logging
+
+import tensorflow as tf
+# pylint: disable=unused-import
+import sentencepiece as spm
+from official.nlp.xlnet import common_flags
+from official.nlp.xlnet import data_utils
+from official.nlp.xlnet import optimization
+from official.nlp.xlnet import squad_utils
+from official.nlp.xlnet import training_utils
+from official.nlp.xlnet import xlnet_config
+from official.nlp.xlnet import xlnet_modeling as modeling
+from official.utils.misc import tpu_lib
+
+flags.DEFINE_string(
+ "test_feature_path", default=None, help="Path to feature of test set.")
+flags.DEFINE_integer("query_len", default=64, help="Max query length.")
+flags.DEFINE_integer("start_n_top", default=5, help="Beam size for span start.")
+flags.DEFINE_integer("end_n_top", default=5, help="Beam size for span end.")
+flags.DEFINE_string(
+ "predict_dir", default=None, help="Path to write predictions.")
+flags.DEFINE_string(
+ "predict_file", default=None, help="Path to json file of test set.")
+flags.DEFINE_integer(
+ "n_best_size", default=5, help="n best size for predictions.")
+flags.DEFINE_integer("max_answer_length", default=64, help="Max answer length.")
+# Data preprocessing config
+flags.DEFINE_string(
+ "spiece_model_file", default=None, help="Sentence Piece model path.")
+flags.DEFINE_integer("max_seq_length", default=512, help="Max sequence length.")
+flags.DEFINE_integer("max_query_length", default=64, help="Max query length.")
+flags.DEFINE_integer("doc_stride", default=128, help="Doc stride.")
+
+FLAGS = flags.FLAGS
+
+
+class InputFeatures(object):
+ """A single set of features of data."""
+
+ def __init__(self,
+ unique_id,
+ example_index,
+ doc_span_index,
+ tok_start_to_orig_index,
+ tok_end_to_orig_index,
+ token_is_max_context,
+ input_ids,
+ input_mask,
+ p_mask,
+ segment_ids,
+ paragraph_len,
+ cls_index,
+ start_position=None,
+ end_position=None,
+ is_impossible=None):
+ self.unique_id = unique_id
+ self.example_index = example_index
+ self.doc_span_index = doc_span_index
+ self.tok_start_to_orig_index = tok_start_to_orig_index
+ self.tok_end_to_orig_index = tok_end_to_orig_index
+ self.token_is_max_context = token_is_max_context
+ self.input_ids = input_ids
+ self.input_mask = input_mask
+ self.p_mask = p_mask
+ self.segment_ids = segment_ids
+ self.paragraph_len = paragraph_len
+ self.cls_index = cls_index
+ self.start_position = start_position
+ self.end_position = end_position
+ self.is_impossible = is_impossible
+
+
+# pylint: disable=unused-argument
+def run_evaluation(strategy, test_input_fn, eval_examples, eval_features,
+ original_data, eval_steps, input_meta_data, model,
+ current_step, eval_summary_writer):
+ """Run evaluation for SQUAD task.
+
+ Args:
+ strategy: distribution strategy.
+ test_input_fn: input function for evaluation data.
+ eval_examples: tf.Examples of the evaluation set.
+ eval_features: Feature objects of the evaluation set.
+ original_data: The original json data for the evaluation set.
+ eval_steps: total number of evaluation steps.
+ input_meta_data: input meta data.
+ model: keras model object.
+ current_step: current training step.
+ eval_summary_writer: summary writer used to record evaluation metrics.
+
+ Returns:
+ A float metric, F1 score.
+ """
+
+ def _test_step_fn(inputs):
+ """Replicated validation step."""
+
+ inputs["mems"] = None
+ res = model(inputs, training=False)
+ return res, inputs["unique_ids"]
+
+ @tf.function
+ def _run_evaluation(test_iterator):
+ """Runs validation steps."""
+ res, unique_ids = strategy.run(
+ _test_step_fn, args=(next(test_iterator),))
+ return res, unique_ids
+
+ test_iterator = data_utils.get_input_iterator(test_input_fn, strategy)
+ cur_results = []
+ for _ in range(eval_steps):
+ results, unique_ids = _run_evaluation(test_iterator)
+ unique_ids = strategy.experimental_local_results(unique_ids)
+
+ for result_key in results:
+ results[result_key] = (
+ strategy.experimental_local_results(results[result_key]))
+ for core_i in range(strategy.num_replicas_in_sync):
+ bsz = int(input_meta_data["test_batch_size"] /
+ strategy.num_replicas_in_sync)
+ for j in range(bsz):
+ result = {}
+ for result_key in results:
+ result[result_key] = results[result_key][core_i].numpy()[j]
+ result["unique_ids"] = unique_ids[core_i].numpy()[j]
+ # A fake example was appended to the dev set so that its size is divisible
+ # by test_batch_size. Ignore this fake example during evaluation.
+ if result["unique_ids"] == 1000012047:
+ continue
+ unique_id = int(result["unique_ids"])
+
+ start_top_log_probs = ([
+ float(x) for x in result["start_top_log_probs"].flat
+ ])
+ start_top_index = [int(x) for x in result["start_top_index"].flat]
+ end_top_log_probs = ([
+ float(x) for x in result["end_top_log_probs"].flat
+ ])
+ end_top_index = [int(x) for x in result["end_top_index"].flat]
+
+ cls_logits = float(result["cls_logits"].flat[0])
+ cur_results.append(
+ squad_utils.RawResult(
+ unique_id=unique_id,
+ start_top_log_probs=start_top_log_probs,
+ start_top_index=start_top_index,
+ end_top_log_probs=end_top_log_probs,
+ end_top_index=end_top_index,
+ cls_logits=cls_logits))
+ if len(cur_results) % 1000 == 0:
+ logging.info("Processing example: %d", len(cur_results))
+
+ output_prediction_file = os.path.join(input_meta_data["predict_dir"],
+ "predictions.json")
+ output_nbest_file = os.path.join(input_meta_data["predict_dir"],
+ "nbest_predictions.json")
+ output_null_log_odds_file = os.path.join(input_meta_data["predict_dir"],
+ "null_odds.json")
+
+ results = squad_utils.write_predictions(
+ eval_examples, eval_features, cur_results, input_meta_data["n_best_size"],
+ input_meta_data["max_answer_length"], output_prediction_file,
+ output_nbest_file, output_null_log_odds_file, original_data,
+ input_meta_data["start_n_top"], input_meta_data["end_n_top"])
+
+ # Log current results.
+ log_str = "Result | "
+ for key, val in results.items():
+ log_str += "{} {} | ".format(key, val)
+ logging.info(log_str)
+ with eval_summary_writer.as_default():
+ tf.summary.scalar("best_f1", results["best_f1"], step=current_step)
+ tf.summary.scalar("best_exact", results["best_exact"], step=current_step)
+ eval_summary_writer.flush()
+ return results["best_f1"]
+
+
+def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top):
+ model = modeling.QAXLNetModel(
+ model_config,
+ run_config,
+ start_n_top=start_n_top,
+ end_n_top=end_n_top,
+ name="model")
+ return model
+
+
+def main(unused_argv):
+ del unused_argv
+ if FLAGS.strategy_type == "mirror":
+ strategy = tf.distribute.MirroredStrategy()
+ elif FLAGS.strategy_type == "tpu":
+ cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
+ strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
+ else:
+ raise ValueError("The distribution strategy type is not supported: %s" %
+ FLAGS.strategy_type)
+ if strategy:
+ logging.info("***** Number of cores used : %d",
+ strategy.num_replicas_in_sync)
+ train_input_fn = functools.partial(data_utils.get_squad_input_data,
+ FLAGS.train_batch_size, FLAGS.seq_len,
+ FLAGS.query_len, strategy, True,
+ FLAGS.train_tfrecord_path)
+
+ test_input_fn = functools.partial(data_utils.get_squad_input_data,
+ FLAGS.test_batch_size, FLAGS.seq_len,
+ FLAGS.query_len, strategy, False,
+ FLAGS.test_tfrecord_path)
+
+ total_training_steps = FLAGS.train_steps
+ steps_per_loop = FLAGS.iterations
+ eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
+
+ optimizer, learning_rate_fn = optimization.create_optimizer(
+ FLAGS.learning_rate,
+ total_training_steps,
+ FLAGS.warmup_steps,
+ adam_epsilon=FLAGS.adam_epsilon)
+ model_config = xlnet_config.XLNetConfig(FLAGS)
+ run_config = xlnet_config.create_run_config(True, False, FLAGS)
+ input_meta_data = {}
+ input_meta_data["start_n_top"] = FLAGS.start_n_top
+ input_meta_data["end_n_top"] = FLAGS.end_n_top
+ input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
+ input_meta_data["predict_dir"] = FLAGS.predict_dir
+ input_meta_data["n_best_size"] = FLAGS.n_best_size
+ input_meta_data["max_answer_length"] = FLAGS.max_answer_length
+ input_meta_data["test_batch_size"] = FLAGS.test_batch_size
+ input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
+ strategy.num_replicas_in_sync)
+ input_meta_data["mem_len"] = FLAGS.mem_len
+ model_fn = functools.partial(get_qaxlnet_model, model_config, run_config,
+ FLAGS.start_n_top, FLAGS.end_n_top)
+ eval_examples = squad_utils.read_squad_examples(
+ FLAGS.predict_file, is_training=False)
+ if FLAGS.test_feature_path:
+ logging.info("start reading pickle file...")
+ with tf.io.gfile.GFile(FLAGS.test_feature_path, "rb") as f:
+ eval_features = pickle.load(f)
+ logging.info("finishing reading pickle file...")
+ else:
+ sp_model = spm.SentencePieceProcessor()
+ sp_model.LoadFromSerializedProto(
+ tf.io.gfile.GFile(FLAGS.spiece_model_file, "rb").read())
+ spm_basename = os.path.basename(FLAGS.spiece_model_file)
+ eval_features = squad_utils.create_eval_data(
+ spm_basename, sp_model, eval_examples, FLAGS.max_seq_length,
+ FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.uncased)
+
+ with tf.io.gfile.GFile(FLAGS.predict_file) as f:
+ original_data = json.load(f)["data"]
+ eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
+ eval_examples, eval_features, original_data,
+ eval_steps, input_meta_data)
+
+ training_utils.train(
+ strategy=strategy,
+ model_fn=model_fn,
+ input_meta_data=input_meta_data,
+ eval_fn=eval_fn,
+ metric_fn=None,
+ train_input_fn=train_input_fn,
+ init_checkpoint=FLAGS.init_checkpoint,
+ init_from_transformerxl=FLAGS.init_from_transformerxl,
+ total_training_steps=total_training_steps,
+ steps_per_loop=steps_per_loop,
+ optimizer=optimizer,
+ learning_rate_fn=learning_rate_fn,
+ model_dir=FLAGS.model_dir,
+ save_steps=FLAGS.save_steps)
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/models/official/nlp/xlnet/squad_utils.py b/models/official/nlp/xlnet/squad_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..efab6da6f80658213317e13dee86b09b2cb94c63
--- /dev/null
+++ b/models/official/nlp/xlnet/squad_utils.py
@@ -0,0 +1,973 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# coding=utf-8
+"""Utilities used in SQUAD task."""
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import collections
+import gc
+import json
+import math
+import os
+import pickle
+import re
+import string
+
+from absl import logging
+import numpy as np
+import six
+import tensorflow as tf
+
+from official.nlp.xlnet import data_utils
+from official.nlp.xlnet import preprocess_utils
+
+SPIECE_UNDERLINE = u"▁"
+
+
+class InputFeatures(object):
+ """A single set of features of data."""
+
+ def __init__(self,
+ unique_id,
+ example_index,
+ doc_span_index,
+ tok_start_to_orig_index,
+ tok_end_to_orig_index,
+ token_is_max_context,
+ input_ids,
+ input_mask,
+ p_mask,
+ segment_ids,
+ paragraph_len,
+ cls_index,
+ start_position=None,
+ end_position=None,
+ is_impossible=None):
+ self.unique_id = unique_id
+ self.example_index = example_index
+ self.doc_span_index = doc_span_index
+ self.tok_start_to_orig_index = tok_start_to_orig_index
+ self.tok_end_to_orig_index = tok_end_to_orig_index
+ self.token_is_max_context = token_is_max_context
+ self.input_ids = input_ids
+ self.input_mask = input_mask
+ self.p_mask = p_mask
+ self.segment_ids = segment_ids
+ self.paragraph_len = paragraph_len
+ self.cls_index = cls_index
+ self.start_position = start_position
+ self.end_position = end_position
+ self.is_impossible = is_impossible
+
+
+def make_qid_to_has_ans(dataset):
+ qid_to_has_ans = {}
+ for article in dataset:
+ for p in article["paragraphs"]:
+ for qa in p["qas"]:
+ qid_to_has_ans[qa["id"]] = bool(qa["answers"])
+ return qid_to_has_ans
+
+
+def get_raw_scores(dataset, preds):
+ """Gets exact scores and f1 scores."""
+ exact_scores = {}
+ f1_scores = {}
+ for article in dataset:
+ for p in article["paragraphs"]:
+ for qa in p["qas"]:
+ qid = qa["id"]
+ gold_answers = [
+ a["text"] for a in qa["answers"] if normalize_answer(a["text"])
+ ]
+ if not gold_answers:
+ # For unanswerable questions, the only correct answer is the empty string.
+ gold_answers = [""]
+ if qid not in preds:
+ print("Missing prediction for %s" % qid)
+ continue
+ a_pred = preds[qid]
+ # Take max over all gold answers
+ exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
+ f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
+ return exact_scores, f1_scores
+
+
+def normalize_answer(s):
+ """Lower text and remove punctuation, articles and extra whitespace."""
+
+ def remove_articles(text):
+ regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
+ return re.sub(regex, " ", text)
+
+ def white_space_fix(text):
+ return " ".join(text.split())
+
+ def remove_punc(text):
+ exclude = set(string.punctuation)
+ return "".join(ch for ch in text if ch not in exclude)
+
+ def lower(text):
+ return text.lower()
+
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def compute_exact(a_gold, a_pred):
+ return int(normalize_answer(a_gold) == normalize_answer(a_pred))
+
+
+def get_tokens(s):
+ if not s:
+ return []
+ return normalize_answer(s).split()
+
+
+def compute_f1(a_gold, a_pred):
+ """Computes f1 score."""
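+ # Illustrative: gold "the cat sat" vs. prediction "the cat" share two
+ # tokens, so precision = 2/2, recall = 2/3 and F1 = 0.8.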
+ gold_toks = get_tokens(a_gold)
+ pred_toks = get_tokens(a_pred)
+ common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
+ num_same = sum(common.values())
+ # pylint: disable=g-explicit-length-test
+ if len(gold_toks) == 0 or len(pred_toks) == 0:
+ # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
+ return int(gold_toks == pred_toks)
+ if num_same == 0:
+ return 0
+ precision = 1.0 * num_same / len(pred_toks)
+ recall = 1.0 * num_same / len(gold_toks)
+ f1 = (2 * precision * recall) / (precision + recall)
+ return f1
+
+
+def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
+ """Finds best threshold."""
+ num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
+ cur_score = num_no_ans
+ best_score = cur_score
+ best_thresh = 0.0
+ qid_list = sorted(na_probs, key=lambda k: na_probs[k])
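+ # Sweep candidate thresholds in order of increasing null probability: start
+ # by answering nothing (score = num_no_ans) and accumulate each question's
+ # gain or loss as it flips to being answered.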
+ for qid in qid_list:
+ if qid not in scores:
+ continue
+ if qid_to_has_ans[qid]:
+ diff = scores[qid]
+ else:
+ if preds[qid]:
+ diff = -1
+ else:
+ diff = 0
+ cur_score += diff
+ if cur_score > best_score:
+ best_score = cur_score
+ best_thresh = na_probs[qid]
+
+ has_ans_score, has_ans_cnt = 0, 0
+ for qid in qid_list:
+ if not qid_to_has_ans[qid]:
+ continue
+ has_ans_cnt += 1
+
+ if qid not in scores:
+ continue
+ has_ans_score += scores[qid]
+
+ return 100.0 * best_score / len(
+ scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt
+
+
+def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs,
+ qid_to_has_ans):
+ """Finds all best threshold."""
+ best_exact, exact_thresh, has_ans_exact = find_best_thresh(
+ preds, exact_raw, na_probs, qid_to_has_ans)
+ best_f1, f1_thresh, has_ans_f1 = find_best_thresh(preds, f1_raw, na_probs,
+ qid_to_has_ans)
+ main_eval["best_exact"] = best_exact
+ main_eval["best_exact_thresh"] = exact_thresh
+ main_eval["best_f1"] = best_f1
+ main_eval["best_f1_thresh"] = f1_thresh
+ main_eval["has_ans_exact"] = has_ans_exact
+ main_eval["has_ans_f1"] = has_ans_f1
+
+
+_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
+ "PrelimPrediction", [
+ "feature_index", "start_index", "end_index", "start_log_prob",
+ "end_log_prob"
+ ])
+
+_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
+ "NbestPrediction", ["text", "start_log_prob", "end_log_prob"])
+RawResult = collections.namedtuple("RawResult", [
+ "unique_id", "start_top_log_probs", "start_top_index", "end_top_log_probs",
+ "end_top_index", "cls_logits"
+])
+
+
+def _compute_softmax(scores):
+ """Computes softmax probability over raw logits."""
+ if not scores:
+ return []
+
+ max_score = None
+ for score in scores:
+ if max_score is None or score > max_score:
+ max_score = score
+
+ exp_scores = []
+ total_sum = 0.0
+ for score in scores:
+ x = math.exp(score - max_score)
+ exp_scores.append(x)
+ total_sum += x
+
+ probs = []
+ for score in exp_scores:
+ probs.append(score / total_sum)
+ return probs
+
+
+class SquadExample(object):
+ """A single training/test example for simple sequence classification.
+
+ For examples without an answer, the start and end position are -1.
+ """
+
+ def __init__(self,
+ qas_id,
+ question_text,
+ paragraph_text,
+ orig_answer_text=None,
+ start_position=None,
+ is_impossible=False):
+ self.qas_id = qas_id
+ self.question_text = question_text
+ self.paragraph_text = paragraph_text
+ self.orig_answer_text = orig_answer_text
+ self.start_position = start_position
+ self.is_impossible = is_impossible
+
+ def __str__(self):
+ return self.__repr__()
+
+ def __repr__(self):
+ s = ""
+ s += "qas_id: %s" % (preprocess_utils.printable_text(self.qas_id))
+ s += ", question_text: %s" % (
+ preprocess_utils.printable_text(self.question_text))
+ s += ", paragraph_text: [%s]" % (" ".join(self.paragraph_text))
+ if self.start_position:
+ s += ", start_position: %d" % (self.start_position)
+ if self.start_position:
+ s += ", is_impossible: %r" % (self.is_impossible)
+ return s
+
+
+def write_predictions(all_examples, all_features, all_results, n_best_size,
+ max_answer_length, output_prediction_file,
+ output_nbest_file, output_null_log_odds_file, orig_data,
+ start_n_top, end_n_top):
+ """Writes final predictions to the json file and log-odds of null if needed."""
+ logging.info("Writing predictions to: %s", (output_prediction_file))
+
+ example_index_to_features = collections.defaultdict(list)
+ for feature in all_features:
+ example_index_to_features[feature.example_index].append(feature)
+
+ unique_id_to_result = {}
+ for result in all_results:
+ unique_id_to_result[result.unique_id] = result
+
+ all_predictions = collections.OrderedDict()
+ all_nbest_json = collections.OrderedDict()
+ scores_diff_json = collections.OrderedDict()
+
+ for (example_index, example) in enumerate(all_examples):
+ features = example_index_to_features[example_index]
+
+ prelim_predictions = []
+ # keep track of the minimum score of null start+end of position 0
+ score_null = 1000000 # large and positive
+
+ for (feature_index, feature) in enumerate(features):
+ result = unique_id_to_result[feature.unique_id]
+
+ cur_null_score = result.cls_logits
+
+ # if we could have irrelevant answers, get the min score of irrelevant
+ score_null = min(score_null, cur_null_score)
+
+ for i in range(start_n_top):
+ for j in range(end_n_top):
+ start_log_prob = result.start_top_log_probs[i]
+ start_index = result.start_top_index[i]
+
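+ # The end_top_* arrays are flattened [start_n_top, end_n_top] matrices, so
+ # the j-th end candidate for the i-th start lives at index i * end_n_top + j.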
+ j_index = i * end_n_top + j
+
+ end_log_prob = result.end_top_log_probs[j_index]
+ end_index = result.end_top_index[j_index]
+
+ # We could hypothetically create invalid predictions, e.g., predict
+ # that the start of the span is in the question. We throw out all
+ # invalid predictions.
+ if start_index >= feature.paragraph_len - 1:
+ continue
+ if end_index >= feature.paragraph_len - 1:
+ continue
+
+ if not feature.token_is_max_context.get(start_index, False):
+ continue
+ if end_index < start_index:
+ continue
+ length = end_index - start_index + 1
+ if length > max_answer_length:
+ continue
+
+ prelim_predictions.append(
+ _PrelimPrediction(
+ feature_index=feature_index,
+ start_index=start_index,
+ end_index=end_index,
+ start_log_prob=start_log_prob,
+ end_log_prob=end_log_prob))
+
+ prelim_predictions = sorted(
+ prelim_predictions,
+ key=lambda x: (x.start_log_prob + x.end_log_prob),
+ reverse=True)
+
+ seen_predictions = {}
+ nbest = []
+ for pred in prelim_predictions:
+ if len(nbest) >= n_best_size:
+ break
+ feature = features[pred.feature_index]
+
+ tok_start_to_orig_index = feature.tok_start_to_orig_index
+ tok_end_to_orig_index = feature.tok_end_to_orig_index
+ start_orig_pos = tok_start_to_orig_index[pred.start_index]
+ end_orig_pos = tok_end_to_orig_index[pred.end_index]
+
+ paragraph_text = example.paragraph_text
+ final_text = paragraph_text[start_orig_pos:end_orig_pos + 1].strip()
+
+ if final_text in seen_predictions:
+ continue
+
+ seen_predictions[final_text] = True
+
+ nbest.append(
+ _NbestPrediction(
+ text=final_text,
+ start_log_prob=pred.start_log_prob,
+ end_log_prob=pred.end_log_prob))
+
+ # In very rare edge cases we could have no valid predictions. So we
+ # just create a nonce prediction in this case to avoid failure.
+ if not nbest:
+ nbest.append(
+ _NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))
+
+ total_scores = []
+ best_non_null_entry = None
+ for entry in nbest:
+ total_scores.append(entry.start_log_prob + entry.end_log_prob)
+ if not best_non_null_entry:
+ best_non_null_entry = entry
+
+ probs = _compute_softmax(total_scores)
+
+ nbest_json = []
+ for (i, entry) in enumerate(nbest):
+ output = collections.OrderedDict()
+ output["text"] = entry.text
+ output["probability"] = probs[i]
+ output["start_log_prob"] = entry.start_log_prob
+ output["end_log_prob"] = entry.end_log_prob
+ nbest_json.append(output)
+
+ assert len(nbest_json) >= 1
+ assert best_non_null_entry is not None
+
+ score_diff = score_null
+ scores_diff_json[example.qas_id] = score_diff
+
+ all_predictions[example.qas_id] = best_non_null_entry.text
+
+ all_nbest_json[example.qas_id] = nbest_json
+
+ with tf.io.gfile.GFile(output_prediction_file, "w") as writer:
+ writer.write(json.dumps(all_predictions, indent=4) + "\n")
+
+ with tf.io.gfile.GFile(output_nbest_file, "w") as writer:
+ writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
+
+ with tf.io.gfile.GFile(output_null_log_odds_file, "w") as writer:
+ writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
+
+ qid_to_has_ans = make_qid_to_has_ans(orig_data)
+ exact_raw, f1_raw = get_raw_scores(orig_data, all_predictions)
+ out_eval = {}
+
+ find_all_best_thresh(out_eval, all_predictions, exact_raw, f1_raw,
+ scores_diff_json, qid_to_has_ans)
+
+ return out_eval
+
+
+def read_squad_examples(input_file, is_training):
+ """Reads a SQuAD json file into a list of SquadExample."""
+ with tf.io.gfile.GFile(input_file, "r") as reader:
+ input_data = json.load(reader)["data"]
+
+ examples = []
+ for entry in input_data:
+ for paragraph in entry["paragraphs"]:
+ paragraph_text = paragraph["context"]
+
+ for qa in paragraph["qas"]:
+ qas_id = qa["id"]
+ question_text = qa["question"]
+ start_position = None
+ orig_answer_text = None
+ is_impossible = False
+
+ if is_training:
+ is_impossible = qa["is_impossible"]
+ if (len(qa["answers"]) != 1) and (not is_impossible):
+ raise ValueError(
+ "For training, each question should have exactly 1 answer.")
+ if not is_impossible:
+ answer = qa["answers"][0]
+ orig_answer_text = answer["text"]
+ start_position = answer["answer_start"]
+ else:
+ start_position = -1
+ orig_answer_text = ""
+
+ example = SquadExample(
+ qas_id=qas_id,
+ question_text=question_text,
+ paragraph_text=paragraph_text,
+ orig_answer_text=orig_answer_text,
+ start_position=start_position,
+ is_impossible=is_impossible)
+ examples.append(example)
+
+ return examples
+
+
+# pylint: disable=invalid-name
+def _convert_index(index, pos, M=None, is_start=True):
+ """Converts `pos` via `index`, falling back to the nearest mapped neighbor."""
+ if index[pos] is not None:
+ return index[pos]
+ N = len(index)
+ rear = pos
+ while rear < N - 1 and index[rear] is None:
+ rear += 1
+ front = pos
+ while front > 0 and index[front] is None:
+ front -= 1
+ assert index[front] is not None or index[rear] is not None
+ if index[front] is None:
+ if index[rear] >= 1:
+ if is_start:
+ return 0
+ else:
+ return index[rear] - 1
+ return index[rear]
+ if index[rear] is None:
+ if M is not None and index[front] < M - 1:
+ if is_start:
+ return index[front] + 1
+ else:
+ return M - 1
+ return index[front]
+ if is_start:
+ if index[rear] > index[front] + 1:
+ return index[front] + 1
+ else:
+ return index[rear]
+ else:
+ if index[rear] > index[front] + 1:
+ return index[rear] - 1
+ else:
+ return index[front]
+
+
+def convert_examples_to_features(examples, sp_model, max_seq_length, doc_stride,
+ max_query_length, is_training, output_fn,
+ uncased):
+ """Loads a data file into a list of `InputBatch`s."""
+
+ cnt_pos, cnt_neg = 0, 0
+ unique_id = 1000000000
+ max_N, max_M = 1024, 1024
+ f = np.zeros((max_N, max_M), dtype=np.float32)
+
+ for (example_index, example) in enumerate(examples):
+ # pylint: disable=logging-format-interpolation
+ if example_index % 100 == 0:
+ logging.info("Converting {}/{} pos {} neg {}".format(
+ example_index, len(examples), cnt_pos, cnt_neg))
+
+ query_tokens = preprocess_utils.encode_ids(
+ sp_model,
+ preprocess_utils.preprocess_text(example.question_text, lower=uncased))
+
+ if len(query_tokens) > max_query_length:
+ query_tokens = query_tokens[0:max_query_length]
+
+ paragraph_text = example.paragraph_text
+ para_tokens = preprocess_utils.encode_pieces(
+ sp_model,
+ preprocess_utils.preprocess_text(example.paragraph_text, lower=uncased))
+
+ chartok_to_tok_index = []
+ tok_start_to_chartok_index = []
+ tok_end_to_chartok_index = []
+ char_cnt = 0
+ for i, token in enumerate(para_tokens):
+ chartok_to_tok_index.extend([i] * len(token))
+ tok_start_to_chartok_index.append(char_cnt)
+ char_cnt += len(token)
+ tok_end_to_chartok_index.append(char_cnt - 1)
+
+ tok_cat_text = "".join(para_tokens).replace(SPIECE_UNDERLINE, " ")
+ N, M = len(paragraph_text), len(tok_cat_text)
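+ # A character-level LCS between the original paragraph and the detokenized
+ # piece string is computed below to build char <-> token position maps,
+ # since tokenization may drop or normalize characters.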
+
+ if N > max_N or M > max_M:
+ max_N = max(N, max_N)
+ max_M = max(M, max_M)
+ f = np.zeros((max_N, max_M), dtype=np.float32)
+ gc.collect()
+
+ g = {}
+
+ # pylint: disable=cell-var-from-loop
+ def _lcs_match(max_dist):
+ """LCS match."""
+ f.fill(0)
+ g.clear()
+
+ ### longest common subsequence
+ # f[i, j] = max(f[i - 1, j], f[i, j - 1], f[i - 1, j - 1] + match(i, j))
+ for i in range(N):
+
+ # note(zhiliny):
+ # unlike standard LCS, this is specifically optimized for the setting
+ # because the mismatch between sentence pieces and original text will
+ # be small
+ for j in range(i - max_dist, i + max_dist):
+ if j >= M or j < 0:
+ continue
+
+ if i > 0:
+ g[(i, j)] = 0
+ f[i, j] = f[i - 1, j]
+
+ if j > 0 and f[i, j - 1] > f[i, j]:
+ g[(i, j)] = 1
+ f[i, j] = f[i, j - 1]
+
+ f_prev = f[i - 1, j - 1] if i > 0 and j > 0 else 0
+ if (preprocess_utils.preprocess_text(
+ paragraph_text[i], lower=uncased,
+ remove_space=False) == tok_cat_text[j] and f_prev + 1 > f[i, j]):
+ g[(i, j)] = 2
+ f[i, j] = f_prev + 1
+
+ max_dist = abs(N - M) + 5
+ for _ in range(2):
+ _lcs_match(max_dist)
+ if f[N - 1, M - 1] > 0.8 * N:
+ break
+ max_dist *= 2
+
+ orig_to_chartok_index = [None] * N
+ chartok_to_orig_index = [None] * M
+ i, j = N - 1, M - 1
+ while i >= 0 and j >= 0:
+ if (i, j) not in g:
+ break
+ if g[(i, j)] == 2:
+ orig_to_chartok_index[i] = j
+ chartok_to_orig_index[j] = i
+ i, j = i - 1, j - 1
+ elif g[(i, j)] == 1:
+ j = j - 1
+ else:
+ i = i - 1
+
+ if all(
+ v is None for v in orig_to_chartok_index) or f[N - 1, M - 1] < 0.8 * N:
+ print("MISMATCH DETECTED!")
+ continue
+
+ tok_start_to_orig_index = []
+ tok_end_to_orig_index = []
+ for i in range(len(para_tokens)):
+ start_chartok_pos = tok_start_to_chartok_index[i]
+ end_chartok_pos = tok_end_to_chartok_index[i]
+ start_orig_pos = _convert_index(
+ chartok_to_orig_index, start_chartok_pos, N, is_start=True)
+ end_orig_pos = _convert_index(
+ chartok_to_orig_index, end_chartok_pos, N, is_start=False)
+
+ tok_start_to_orig_index.append(start_orig_pos)
+ tok_end_to_orig_index.append(end_orig_pos)
+
+ if not is_training:
+ tok_start_position = tok_end_position = None
+
+ if is_training and example.is_impossible:
+ tok_start_position = -1
+ tok_end_position = -1
+
+ if is_training and not example.is_impossible:
+ start_position = example.start_position
+ end_position = start_position + len(example.orig_answer_text) - 1
+
+ start_chartok_pos = _convert_index(
+ orig_to_chartok_index, start_position, is_start=True)
+ tok_start_position = chartok_to_tok_index[start_chartok_pos]
+
+ end_chartok_pos = _convert_index(
+ orig_to_chartok_index, end_position, is_start=False)
+ tok_end_position = chartok_to_tok_index[end_chartok_pos]
+ assert tok_start_position <= tok_end_position
+
+ def _piece_to_id(x):
+ if six.PY2 and isinstance(x, unicode):
+ x = x.encode("utf-8")
+ return sp_model.PieceToId(x)
+
+ all_doc_tokens = list(map(_piece_to_id, para_tokens))
+
+ # The -3 accounts for [CLS], [SEP] and [SEP]
+ max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
+
+ # We can have documents that are longer than the maximum sequence length.
+ # To deal with this we do a sliding window approach, where we take chunks
+ # of up to our max length with a stride of `doc_stride`.
+ _DocSpan = collections.namedtuple( # pylint: disable=invalid-name
+ "DocSpan", ["start", "length"])
+ doc_spans = []
+ start_offset = 0
+ while start_offset < len(all_doc_tokens):
+ length = len(all_doc_tokens) - start_offset
+ if length > max_tokens_for_doc:
+ length = max_tokens_for_doc
+ doc_spans.append(_DocSpan(start=start_offset, length=length))
+ if start_offset + length == len(all_doc_tokens):
+ break
+ start_offset += min(length, doc_stride)
+
+ for (doc_span_index, doc_span) in enumerate(doc_spans):
+ tokens = []
+ token_is_max_context = {}
+ segment_ids = []
+ p_mask = []
+
+ cur_tok_start_to_orig_index = []
+ cur_tok_end_to_orig_index = []
+
+ for i in range(doc_span.length):
+ split_token_index = doc_span.start + i
+
+ cur_tok_start_to_orig_index.append(
+ tok_start_to_orig_index[split_token_index])
+ cur_tok_end_to_orig_index.append(
+ tok_end_to_orig_index[split_token_index])
+
+ is_max_context = _check_is_max_context(doc_spans, doc_span_index,
+ split_token_index)
+ token_is_max_context[len(tokens)] = is_max_context
+ tokens.append(all_doc_tokens[split_token_index])
+ segment_ids.append(data_utils.SEG_ID_P)
+ p_mask.append(0)
+
+ paragraph_len = len(tokens)
+
+ tokens.append(data_utils.SEP_ID)
+ segment_ids.append(data_utils.SEG_ID_P)
+ p_mask.append(1)
+
+ # note(zhiliny): we put P before Q
+ # because during pretraining, B is always shorter than A
+ for token in query_tokens:
+ tokens.append(token)
+ segment_ids.append(data_utils.SEG_ID_Q)
+ p_mask.append(1)
+ tokens.append(data_utils.SEP_ID)
+ segment_ids.append(data_utils.SEG_ID_Q)
+ p_mask.append(1)
+
+ cls_index = len(segment_ids)
+ tokens.append(data_utils.CLS_ID)
+ segment_ids.append(data_utils.SEG_ID_CLS)
+ p_mask.append(0)
+
+ input_ids = tokens
+
+ # The mask has 0 for real tokens and 1 for padding tokens. Only real
+ # tokens are attended to.
+ input_mask = [0] * len(input_ids)
+
+ # Zero-pad up to the sequence length.
+ while len(input_ids) < max_seq_length:
+ input_ids.append(0)
+ input_mask.append(1)
+ segment_ids.append(data_utils.SEG_ID_PAD)
+ p_mask.append(1)
+
+ assert len(input_ids) == max_seq_length
+ assert len(input_mask) == max_seq_length
+ assert len(segment_ids) == max_seq_length
+ assert len(p_mask) == max_seq_length
+
+ span_is_impossible = example.is_impossible
+ start_position = None
+ end_position = None
+ if is_training and not span_is_impossible:
+ # For training, if our document chunk does not contain an annotation
+ # we throw it out, since there is nothing to predict.
+ doc_start = doc_span.start
+ doc_end = doc_span.start + doc_span.length - 1
+ out_of_span = False
+ if not (tok_start_position >= doc_start and
+ tok_end_position <= doc_end):
+ out_of_span = True
+ if out_of_span:
+ # continue
+ start_position = 0
+ end_position = 0
+ span_is_impossible = True
+ else:
+ # note: we put P before Q, so doc_offset should be zero.
+ # doc_offset = len(query_tokens) + 2
+ doc_offset = 0
+ start_position = tok_start_position - doc_start + doc_offset
+ end_position = tok_end_position - doc_start + doc_offset
+
+ if is_training and span_is_impossible:
+ start_position = cls_index
+ end_position = cls_index
+
+ if example_index < 20:
+ logging.info("*** Example ***")
+ logging.info("unique_id: %s", unique_id)
+ logging.info("example_index: %s", example_index)
+ logging.info("doc_span_index: %s", doc_span_index)
+ logging.info("tok_start_to_orig_index: %s",
+ " ".join([str(x) for x in cur_tok_start_to_orig_index]))
+ logging.info("tok_end_to_orig_index: %s",
+ " ".join([str(x) for x in cur_tok_end_to_orig_index]))
+ logging.info(
+ "token_is_max_context: %s", " ".join([
+ "%d:%s" % (x, y)
+ for (x, y) in six.iteritems(token_is_max_context)
+ ]))
+ logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
+ logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
+ logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
+
+ if is_training and span_is_impossible:
+ logging.info("impossible example span")
+
+ if is_training and not span_is_impossible:
+ pieces = [
+ sp_model.IdToPiece(token)
+ for token in tokens[start_position:(end_position + 1)]
+ ]
+ answer_text = sp_model.DecodePieces(pieces)
+ logging.info("start_position: %d", start_position)
+ logging.info("end_position: %d", end_position)
+ logging.info("answer: %s",
+ preprocess_utils.printable_text(answer_text))
+
+ # With multiprocessing, the example_index is actually the index within the
+ # current process, so we set example_index=None to avoid it being used
+ # later. The current code does not use the example_index of training data.
+ if is_training:
+ feat_example_index = None
+ else:
+ feat_example_index = example_index
+
+ feature = InputFeatures(
+ unique_id=unique_id,
+ example_index=feat_example_index,
+ doc_span_index=doc_span_index,
+ tok_start_to_orig_index=cur_tok_start_to_orig_index,
+ tok_end_to_orig_index=cur_tok_end_to_orig_index,
+ token_is_max_context=token_is_max_context,
+ input_ids=input_ids,
+ input_mask=input_mask,
+ p_mask=p_mask,
+ segment_ids=segment_ids,
+ paragraph_len=paragraph_len,
+ cls_index=cls_index,
+ start_position=start_position,
+ end_position=end_position,
+ is_impossible=span_is_impossible)
+
+ # Run callback
+ output_fn(feature)
+
+ unique_id += 1
+ if span_is_impossible:
+ cnt_neg += 1
+ else:
+ cnt_pos += 1
+
+ logging.info("Total number of instances: %d = pos %d + neg %d",
+ cnt_pos + cnt_neg, cnt_pos, cnt_neg)
+
+
+def _check_is_max_context(doc_spans, cur_span_index, position):
+ """Check if this is the "max context" doc span for the token."""
+
+ # Because of the sliding window approach taken to scoring documents, a single
+ # token can appear in multiple documents. E.g.
+ # Doc: the man went to the store and bought a gallon of milk
+ # Span A: the man went to the
+ # Span B: to the store and bought
+ # Span C: and bought a gallon of
+ # ...
+ #
+ # Now the word "bought" will have two scores from spans B and C. We only
+ # want to consider the score with "maximum context", which we define as
+ # the *minimum* of its left and right context (the *sum* of left and
+ # right context will always be the same, of course).
+ #
+ # In the example the maximum context for "bought" would be span C since
+ # it has 1 left context and 3 right context, while span B has 4 left context
+ # and 0 right context.
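+  #
+  # The loop below implements this: each candidate span is scored by
+  # min(left context, right context), with a small 0.01 * span_length term so
+  # that ties are broken in favour of the longer span.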
+ best_score = None
+ best_span_index = None
+ for (span_index, doc_span) in enumerate(doc_spans):
+ end = doc_span.start + doc_span.length - 1
+ if position < doc_span.start:
+ continue
+ if position > end:
+ continue
+ num_left_context = position - doc_span.start
+ num_right_context = end - position
+ score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
+ if best_score is None or score > best_score:
+ best_score = score
+ best_span_index = span_index
+
+ return cur_span_index == best_span_index
+
+
+class FeatureWriter(object):
+ """Writes InputFeature to TF example file."""
+
+ def __init__(self, filename, is_training):
+ self.filename = filename
+ self.is_training = is_training
+ self.num_features = 0
+ self._writer = tf.io.TFRecordWriter(filename)
+
+ def process_feature(self, feature):
+ """Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
+ self.num_features += 1
+
+ def create_int_feature(values):
+ feature = tf.train.Feature(
+ int64_list=tf.train.Int64List(value=list(values)))
+ return feature
+
+ def create_float_feature(values):
+ f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
+ return f
+
+ features = collections.OrderedDict()
+ features["unique_ids"] = create_int_feature([feature.unique_id])
+ features["input_ids"] = create_int_feature(feature.input_ids)
+ features["input_mask"] = create_float_feature(feature.input_mask)
+ features["p_mask"] = create_float_feature(feature.p_mask)
+ features["segment_ids"] = create_int_feature(feature.segment_ids)
+
+ features["cls_index"] = create_int_feature([feature.cls_index])
+
+ if self.is_training:
+ features["start_positions"] = create_int_feature([feature.start_position])
+ features["end_positions"] = create_int_feature([feature.end_position])
+ impossible = 0
+ if feature.is_impossible:
+ impossible = 1
+ features["is_impossible"] = create_float_feature([impossible])
+
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+ self._writer.write(tf_example.SerializeToString())
+
+ def close(self):
+ self._writer.close()
+
+
+def create_eval_data(spm_basename,
+ sp_model,
+ eval_examples,
+ max_seq_length,
+ max_query_length,
+ doc_stride,
+ uncased,
+ output_dir=None):
+ """Creates evaluation tfrecords."""
+ eval_features = []
+ eval_writer = None
+ if output_dir:
+ eval_rec_file = os.path.join(
+ output_dir,
+ "{}.slen-{}.qlen-{}.eval.tf_record".format(spm_basename, max_seq_length,
+ max_query_length))
+ eval_feature_file = os.path.join(
+ output_dir,
+ "{}.slen-{}.qlen-{}.eval.features.pkl".format(spm_basename,
+ max_seq_length,
+ max_query_length))
+
+ eval_writer = FeatureWriter(filename=eval_rec_file, is_training=False)
+
+ def append_feature(feature):
+ eval_features.append(feature)
+ if eval_writer:
+ eval_writer.process_feature(feature)
+
+ convert_examples_to_features(
+ examples=eval_examples,
+ sp_model=sp_model,
+ max_seq_length=max_seq_length,
+ doc_stride=doc_stride,
+ max_query_length=max_query_length,
+ is_training=False,
+ output_fn=append_feature,
+ uncased=uncased)
+
+ if eval_writer:
+ eval_writer.close()
+ with tf.io.gfile.GFile(eval_feature_file, "wb") as fout:
+ pickle.dump(eval_features, fout)
+
+ return eval_features
diff --git a/models/official/nlp/xlnet/training_utils.py b/models/official/nlp/xlnet/training_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..293e4633d8f4ae0f00fc5fbabb3a3996827ced81
--- /dev/null
+++ b/models/official/nlp/xlnet/training_utils.py
@@ -0,0 +1,310 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""XLNet training utils."""
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import os
+import re
+
+from absl import logging
+
+# pytype: disable=attribute-error
+# pylint: disable=g-bare-generic,unused-import
+import tensorflow as tf
+from typing import Any, Callable, Dict, Text, Optional
+
+from official.nlp.bert import model_training_utils
+from official.nlp.xlnet import data_utils
+from official.nlp.xlnet import xlnet_modeling as modeling
+
+_MIN_SUMMARY_STEPS = 10
+
+
+def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
+ """Saves model to with provided checkpoint prefix."""
+
+ checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
+ saved_path = checkpoint.save(checkpoint_path)
+ logging.info("Saving model as TF checkpoint: %s", saved_path)
+ return
+
+
+def _float_metric_value(metric):
+ """Gets the value of a float-value keras metric."""
+ return metric.result().numpy().astype(float)
+
+
+def train(
+ strategy: tf.distribute.Strategy,
+ model_fn: Callable,
+ input_meta_data: Dict,
+ train_input_fn: Callable,
+ total_training_steps: int,
+ steps_per_loop: int,
+ optimizer: tf.keras.optimizers.Optimizer,
+ learning_rate_fn: tf.keras.optimizers.schedules.LearningRateSchedule,
+ eval_fn: Optional[Callable[[tf.keras.Model, int, tf.summary.SummaryWriter],
+ Any]] = None,
+ metric_fn: Optional[Callable[[], tf.keras.metrics.Metric]] = None,
+ init_checkpoint: Optional[Text] = None,
+ init_from_transformerxl: Optional[bool] = False,
+ model_dir: Optional[Text] = None,
+ save_steps: Optional[int] = None,
+ run_eagerly: Optional[bool] = False):
+ """Runs customized training.
+
+ Args:
+ strategy: Distribution strategy on which to run low level training loop.
+    model_fn: A function that returns a keras.Model.
+ input_meta_data: A dictionary of params: `mem_len`, `lr_layer_decay_rate`,
+ `n_layer`, `batch_size_per_core` and `d_model`.
+    train_input_fn: A function that returns a tf.data.Dataset used for
+      training.
+ total_training_steps: Number of steps to train in total.
+ steps_per_loop: Number of steps per graph-mode loop. In order to reduce
+ communication in eager context, training logs are printed every
+ steps_per_loop.
+    optimizer: The optimizer for the model.
+    learning_rate_fn: The learning rate schedule.
+    eval_fn: An optional evaluation callback that takes a keras.Model, the
+      current step and an evaluation summary writer.
+    metric_fn: A function that returns a Keras Metric object used to record
+      evaluation results on the evaluation or training dataset after every
+      epoch.
+    init_checkpoint: Optional checkpoint to load into the model returned by
+      `model_fn`.
+    init_from_transformerxl: Whether to load the checkpoint into the
+      `transformerxl_model` sub-model of the model returned by `model_fn`.
+ model_dir: The directory of model (checkpoints, summaries).
+    save_steps: The checkpoint frequency; a model checkpoint is saved every
+      save_steps steps, and evaluation is run at the same time if an
+      evaluation function is provided.
+ run_eagerly: Whether to run training eagerly.
+
+ Returns:
+    The trained model.
+ Raises:
+ TypeError: if model directory is not specified.
+ """
+ required_arguments = [
+ train_input_fn, total_training_steps, steps_per_loop, optimizer,
+ learning_rate_fn, save_steps
+ ]
+ if [arg for arg in required_arguments if arg is None]:
+ raise ValueError("`train_input_fn`, `total_training_steps`, "
+ "`steps_per_loop`, `optimizer`, `save_steps` and "
+ "`learning_rate_fn` are required parameters.")
+ if not model_dir:
+ raise TypeError("Model directory must be specified.")
+ train_iterator = data_utils.get_input_iterator(train_input_fn, strategy)
+ if not tf.io.gfile.exists(model_dir):
+ tf.io.gfile.mkdir(model_dir)
+ # Create summary writers
+ summary_dir = os.path.join(model_dir, "summaries")
+ if not tf.io.gfile.exists(summary_dir):
+ tf.io.gfile.mkdir(summary_dir)
+ train_summary_writer = None
+ eval_summary_writer = None
+ if eval_fn:
+ eval_summary_writer = tf.summary.create_file_writer(
+ os.path.join(summary_dir, "eval"))
+ if steps_per_loop >= _MIN_SUMMARY_STEPS:
+ # Only writes summary when the stats are collected sufficiently over
+ # enough steps.
+ train_summary_writer = tf.summary.create_file_writer(
+ os.path.join(summary_dir, "train"))
+
+ with strategy.scope():
+ model = model_fn()
+
+ if init_checkpoint:
+ logging.info("restore from %s", init_checkpoint)
+ if init_from_transformerxl:
+ checkpoint = tf.train.Checkpoint(
+ transformer_xl=model.transformerxl_model)
+ else:
+ checkpoint = tf.train.Checkpoint(model=model)
+ checkpoint.restore(init_checkpoint)
+
+ model.optimizer = optimizer
+
+ if not hasattr(model, "optimizer"):
+ raise ValueError("User should set optimizer attribute to model.")
+
+ train_loss_metric = tf.keras.metrics.Mean("training_loss", dtype=tf.float32)
+ train_metric = None
+ if metric_fn:
+ train_metric = metric_fn()
+
+ def _replicated_step(inputs, mem=None):
+ """Replicated training step."""
+
+ inputs["mems"] = mem
+ with tf.GradientTape() as tape:
+ mem, logits = model(inputs, training=True)
+ loss = model.losses
+ train_loss_metric.update_state(loss)
+ if train_metric:
+ train_metric.update_state(inputs["label_ids"], logits)
+ scaled_loss = loss[0] * 1.0 / float(strategy.num_replicas_in_sync)
+
+ # Collects training variables.
+ tvars = model.trainable_variables
+ grads = tape.gradient(scaled_loss, tvars)
+ clipped, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)
+
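+      # Layer-wise learning-rate decay: the gradient of transformer layer l is
+      # scaled by lr_layer_decay_rate ** (n_layer - 1 - l), so layers further
+      # from the output are updated more conservatively.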
+ if input_meta_data["lr_layer_decay_rate"] != 1.0:
+ n_layer = 0
+ for i in range(len(clipped)):
+ m = re.search(r"model/transformer/layer_(\d+?)/", tvars[i].name)
+ if not m:
+ continue
+ n_layer = max(n_layer, int(m.group(1)) + 1)
+
+ for i in range(len(clipped)):
+ for l in range(n_layer):
+ if "model/transformer/layer_{}/".format(l) in tvars[i].name:
+ abs_rate = input_meta_data["lr_layer_decay_rate"]**(
+ n_layer - 1 - l)
+ clipped[i] *= abs_rate
+ logging.info("Apply mult {:.4f} to layer-{} grad of {}".format(
+ abs_rate, l, tvars[i].name))
+ break
+
+ optimizer.apply_gradients(zip(clipped, tvars))
+ if input_meta_data["mem_len"] > 0:
+ return mem
+
+ def train_steps(iterator, steps):
+ """Performs distributed training steps in a loop.
+
+ Args:
+ iterator: the distributed iterator of training datasets.
+      steps: a tf.int32 scalar tensor specifying the number of steps to run
+        inside the host training loop.
+
+ Raises:
+ ValueError: Any of the arguments or tensor shapes are invalid.
+
+ Returns:
+ logits: logits computed.
+ """
+ if not isinstance(steps, tf.Tensor):
+ raise ValueError("steps should be an Tensor. Python object may cause "
+ "retracing.")
+
+ def cache_fn():
+ """Initializes memory tensor used in XLNet pretraining."""
+ mems = []
+ if input_meta_data["mem_len"] > 0:
+ for _ in range(input_meta_data["n_layer"]):
+ zeros = tf.zeros([
+ input_meta_data["mem_len"],
+ input_meta_data["batch_size_per_core"],
+ input_meta_data["d_model"]
+ ],
+ dtype=tf.float32)
+ mems.append(zeros)
+ return mems
+
+ if input_meta_data["mem_len"] > 0:
+ mem = strategy.run(cache_fn)
+ for _ in tf.range(steps):
+ mem = strategy.run(
+ _replicated_step, args=(
+ next(iterator),
+ mem,
+ ))
+ else:
+ for _ in tf.range(steps):
+ strategy.run(_replicated_step, args=(next(iterator),))
+
+ if not run_eagerly:
+ train_steps = tf.function(train_steps)
+
+ logging.info("Start training...")
+ checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
+ latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
+ if latest_checkpoint_file:
+ logging.info("Checkpoint file %s found and restoring from checkpoint",
+ latest_checkpoint_file)
+ checkpoint.restore(latest_checkpoint_file)
+ logging.info("Loading from checkpoint file completed")
+
+ current_step = optimizer.iterations.numpy()
+ checkpoint_name = "xlnet_step_{step}.ckpt"
+
+ while current_step < total_training_steps:
+ train_loss_metric.reset_states()
+ if train_metric:
+ train_metric.reset_states()
+
+ steps = model_training_utils.steps_to_run(current_step, save_steps,
+ steps_per_loop)
+ train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
+ current_step += steps
+ train_loss = _float_metric_value(train_loss_metric)
+ log_stream = "Train step: %d/%d / lr = %.9f / loss = %.7f" % (
+ current_step, total_training_steps, learning_rate_fn(current_step),
+ train_loss)
+ if train_metric:
+ log_stream += " / %s = %f" % (train_metric.name,
+ _float_metric_value(train_metric))
+ logging.info(log_stream)
+ if train_summary_writer:
+ with train_summary_writer.as_default():
+ tf.summary.scalar(
+ "learning_rate",
+ learning_rate_fn(current_step),
+ step=current_step)
+ tf.summary.scalar(
+ train_loss_metric.name, train_loss, step=current_step)
+ if train_metric:
+ tf.summary.scalar(
+ train_metric.name,
+ _float_metric_value(train_metric),
+ step=current_step)
+ train_summary_writer.flush()
+ if model_dir and current_step % save_steps == 0:
+ _save_checkpoint(checkpoint, model_dir,
+ checkpoint_name.format(step=current_step))
+
+ if eval_fn and current_step % save_steps == 0:
+
+ logging.info("Running evaluation after step: %s.", current_step)
+
+ eval_fn(model, current_step, eval_summary_writer)
+ if model_dir:
+ _save_checkpoint(checkpoint, model_dir,
+ checkpoint_name.format(step=current_step))
+ if eval_fn:
+ logging.info("Running final evaluation after training is complete.")
+ eval_metric = eval_fn(model, current_step, eval_summary_writer)
+
+ training_summary = {
+ "total_training_steps": total_training_steps,
+ "train_loss": _float_metric_value(train_loss_metric),
+ }
+ if train_metric:
+ training_summary["last_train_metrics"] = _float_metric_value(train_metric)
+ if eval_fn:
+ # eval_metric is supposed to be a float.
+ training_summary["eval_metrics"] = eval_metric
+
+ model_training_utils.write_txt_summary(training_summary, summary_dir)
+
+ return model
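+
+
+# A minimal usage sketch (assumed caller code, not part of this module); the
+# `build_classifier_model`, `train_dataset_fn` and `lr_schedule` names are
+# hypothetical placeholders:
+#
+#   strategy = tf.distribute.MirroredStrategy()
+#   trained_model = train(
+#       strategy=strategy,
+#       model_fn=build_classifier_model,
+#       input_meta_data={"mem_len": 0, "lr_layer_decay_rate": 1.0,
+#                        "n_layer": 12, "batch_size_per_core": 8,
+#                        "d_model": 768},
+#       train_input_fn=train_dataset_fn,
+#       total_training_steps=10000,
+#       steps_per_loop=100,
+#       optimizer=tf.keras.optimizers.Adam(),
+#       learning_rate_fn=lr_schedule,
+#       model_dir="/tmp/xlnet",
+#       save_steps=1000)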
diff --git a/models/official/nlp/xlnet/xlnet_config.py b/models/official/nlp/xlnet/xlnet_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7852eadf469476b4772533dce563366cd3478317
--- /dev/null
+++ b/models/official/nlp/xlnet/xlnet_config.py
@@ -0,0 +1,181 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions used in XLNet model."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import json
+import os
+
+import tensorflow as tf
+
+
+def create_run_config(is_training, is_finetune, flags):
+ """Helper function for creating RunConfig."""
+ kwargs = dict(
+ is_training=is_training,
+ use_tpu=flags.use_tpu,
+ dropout=flags.dropout,
+ dropout_att=flags.dropout_att,
+ init_method=flags.init_method,
+ init_range=flags.init_range,
+ init_std=flags.init_std,
+ clamp_len=flags.clamp_len)
+
+ if not is_finetune:
+ kwargs.update(dict(
+ mem_len=flags.mem_len,
+ reuse_len=flags.reuse_len,
+ bi_data=flags.bi_data,
+ clamp_len=flags.clamp_len,
+ same_length=flags.same_length))
+
+ return RunConfig(**kwargs)
+
+
+# TODO(hongkuny): refactor XLNetConfig and RunConfig.
+class XLNetConfig(object):
+ """Configs for XLNet model.
+
+ XLNetConfig contains hyperparameters that are specific to a model checkpoint;
+ i.e., these hyperparameters should be the same between
+ pretraining and finetuning.
+
+ The following hyperparameters are defined:
+ n_layer: int, the number of layers.
+ d_model: int, the hidden size.
+ n_head: int, the number of attention heads.
+ d_head: int, the dimension size of each attention head.
+ d_inner: int, the hidden size in feed-forward layers.
+ ff_activation: str, "relu" or "gelu".
+ untie_r: bool, whether to untie the biases in attention.
+ n_token: int, the vocab size.
+ """
+
+ def __init__(self, FLAGS=None, json_path=None, args_dict=None):
+ """Constructing an XLNetConfig.
+
+ One of FLAGS or json_path should be provided.
+
+ Args:
+ FLAGS: An FLAGS instance.
+ json_path: A path to a json config file.
+ args_dict: A dict for args.
+ """
+
+ assert FLAGS is not None or json_path is not None or args_dict is not None
+
+ self.keys = ['n_layer', 'd_model', 'n_head', 'd_head', 'd_inner',
+ 'ff_activation', 'untie_r', 'n_token']
+
+ if FLAGS is not None:
+ self.init_from_flags(FLAGS)
+
+ if json_path is not None:
+ self.init_from_json(json_path)
+
+ if args_dict is not None:
+ self.init_from_dict(args_dict)
+
+ def init_from_dict(self, args_dict):
+ """Constructs a `BertConfig` from a Python dictionary of parameters."""
+ for key in self.keys:
+ setattr(self, key, args_dict[key])
+
+ def init_from_flags(self, flags):
+ for key in self.keys:
+ setattr(self, key, getattr(flags, key))
+
+ def init_from_json(self, json_path):
+ with tf.io.gfile.GFile(json_path) as f:
+ json_data = json.load(f)
+ self.init_from_dict(json_data)
+
+ def to_json(self, json_path):
+ """Save XLNetConfig to a json file."""
+ json_data = {}
+ for key in self.keys:
+ json_data[key] = getattr(self, key)
+
+ json_dir = os.path.dirname(json_path)
+ if not tf.io.gfile.exists(json_dir):
+ tf.io.gfile.makedirs(json_dir)
+ with tf.io.gfile.GFile(json_path, 'w') as f:
+ json.dump(json_data, f, indent=4, sort_keys=True)
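+
+  # A minimal usage sketch (assumed file paths): a config can be loaded from a
+  # checkpoint's JSON file and re-serialized elsewhere:
+  #   config = XLNetConfig(json_path='xlnet_config.json')
+  #   config.to_json('/tmp/xlnet_config.json')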
+
+
+class RunConfig(object):
+ """Class of RunConfig.
+
+ RunConfig contains hyperparameters that could be different
+ between pretraining and finetuning.
+ These hyperparameters can also be changed from run to run.
+ We store them separately from XLNetConfig for flexibility.
+ """
+
+ def __init__(self,
+ is_training,
+ use_tpu,
+ dropout,
+ dropout_att,
+ init_method='normal',
+ init_range=0.1,
+ init_std=0.02,
+ mem_len=None,
+ reuse_len=None,
+ bi_data=False,
+ clamp_len=-1,
+ same_length=False,
+ use_cls_mask=True):
+ """Initializes RunConfig.
+
+ Args:
+ is_training: bool, whether in training mode.
+ use_tpu: bool, whether TPUs are used.
+ dropout: float, dropout rate.
+ dropout_att: float, dropout rate on attention probabilities.
+ init_method: str, the initialization scheme, either "normal" or "uniform".
+ init_range: float, initialize the parameters with a uniform distribution
+ in [-init_range, init_range]. Only effective when init="uniform".
+ init_std: float, initialize the parameters with a normal distribution
+ with mean 0 and stddev init_std. Only effective when init="normal".
+ mem_len: int, the number of tokens to cache.
+      reuse_len: int, the number of tokens in the current batch to be cached
+ and reused in the future.
+ bi_data: bool, whether to use bidirectional input pipeline.
+ Usually set to True during pretraining and False during finetuning.
+ clamp_len: int, clamp all relative distances larger than clamp_len.
+ -1 means no clamping.
+ same_length: bool, whether to use the same attention length
+ for each token.
+ use_cls_mask: bool, whether to introduce cls mask.
+ """
+
+ self.init_method = init_method
+ self.init_range = init_range
+ self.init_std = init_std
+ self.is_training = is_training
+ self.dropout = dropout
+ self.dropout_att = dropout_att
+ self.use_tpu = use_tpu
+ self.mem_len = mem_len
+ self.reuse_len = reuse_len
+ self.bi_data = bi_data
+ self.clamp_len = clamp_len
+ self.same_length = same_length
+ self.use_cls_mask = use_cls_mask
diff --git a/models/official/nlp/xlnet/xlnet_modeling.py b/models/official/nlp/xlnet/xlnet_modeling.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e16af8e9930ba4dabb8e92743769cf1ebb48585
--- /dev/null
+++ b/models/official/nlp/xlnet/xlnet_modeling.py
@@ -0,0 +1,1290 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras layers of XLNet model in TF 2.0."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import copy
+import numpy as np
+
+import tensorflow as tf
+from official.nlp.xlnet import data_utils
+
+
+def gelu(x):
+ """Gaussian Error Linear Unit.
+
+  This is a smoother version of ReLU.
+ Original paper: https://arxiv.org/abs/1606.08415
+ Args:
+ x: float Tensor to perform activation.
+
+ Returns:
+ `x` with the GELU activation applied.
+ """
+ cdf = 0.5 * (1.0 + tf.tanh(
+ (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
+ return x * cdf
+
+
+def rel_shift(x, klen=-1):
+ """Performs relative shift to form the relative attention score."""
+ x_size = tf.shape(x)
+
+ x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
+ x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
+ x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
+ x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1])
+
+ return x
+
+
+def _get_initializer(flags):
+ """Get variable intializer."""
+ if flags.init_method == 'uniform':
+ initializer = tf.keras.initializers.RandomUniform(
+ minval=-flags.init_range, maxval=flags.init_range)
+ elif flags.init_method == 'normal':
+ initializer = tf.keras.initializers.RandomNormal(stddev=flags.init_std)
+ else:
+ raise ValueError('Initializer {} not supported'.format(flags.init_method))
+ return initializer
+
+
+def _create_mask(qlen, mlen, dtype=tf.float32, same_length=False):
+ """Creates attention mask when single-side context allowed only."""
+ attn_mask = tf.ones([qlen, qlen], dtype=dtype)
+ mask_u = tf.linalg.band_part(attn_mask, 0, -1)
+ mask_dia = tf.linalg.band_part(attn_mask, 0, 0)
+ attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
+ ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
+ if same_length:
+ mask_l = tf.linalg.band_part(attn_mask, -1, 0)
+ ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1)
+
+ return ret
+
+
+def _cache_mem(curr_out, prev_mem, mem_len, reuse_len=None):
+ """cache hidden states into memory."""
+
+ if mem_len is None or mem_len == 0:
+ return None
+ else:
+ if reuse_len is not None and reuse_len > 0:
+ curr_out = curr_out[:reuse_len]
+
+ if prev_mem is None:
+ new_mem = curr_out[-mem_len:]
+ else:
+ new_mem = tf.concat([prev_mem, curr_out], 0)[-mem_len:]
+
+ return tf.keras.backend.stop_gradient(new_mem)
+
+
+def is_special_none_tensor(tensor):
+ """Checks if a tensor is a special None Tensor."""
+ return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
+
+
+class PositionalEmbedding(tf.keras.layers.Layer):
+ """Generates relative positional embeddings used in Transformer-XL and XLNet."""
+
+ def __init__(self, dim, **kwargs):
+ super(PositionalEmbedding, self).__init__(**kwargs)
+ self.dim = dim
+
+ def build(self, unused_input_shapes):
+ """Constructs inversed frequency vector for positional embedding layer."""
+ self.inv_freq = 1.0 / (10000.0**(tf.range(0, self.dim, 2.0) / self.dim))
+ super(PositionalEmbedding, self).build(unused_input_shapes)
+
+ def call(self, pos_seq, batch_size):
+ """Implements call() for the layer."""
+ sinusoid_inp = tf.einsum('i,d->id', pos_seq, self.inv_freq)
+ pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
+ pos_emb = pos_emb[:, None, :]
+
+ if batch_size is not None:
+ pos_emb = tf.tile(pos_emb, [1, batch_size, 1])
+
+ return pos_emb
+
+
+class RelativeAttention(tf.keras.layers.Layer):
+ """Core calculations for relative attention."""
+
+ def __init__(self, dropout_att, scale):
+ super(RelativeAttention, self).__init__()
+ self.scale = scale
+ self.dropout_att = dropout_att
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+
+ self.attention_probs_dropout = tf.keras.layers.Dropout(
+ rate=self.dropout_att)
+
+ super(RelativeAttention, self).build(unused_input_shapes)
+
+ def call(self, q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
+ r_w_bias, r_r_bias, r_s_bias, attn_mask):
+ """Implements call() for the layer."""
+
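+    # Relative attention decomposes the score into a content term (ac), a
+    # position term (bd, realigned by rel_shift) and an optional segment term
+    # (ef), following Transformer-XL / XLNet.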
+ # content based attention score
+ ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)
+
+ # position based attention score
+ bd = tf.einsum('ibnd,jbnd->ijbn', q_head + r_r_bias, k_head_r)
+ bd = rel_shift(bd, klen=tf.shape(ac)[1])
+
+ # segment-based attention score
+ if seg_mat is None:
+ ef = 0
+ else:
+ ef = tf.einsum('ibnd,snd->isbn', q_head + r_s_bias, seg_embed)
+ tgt_shape = tf.shape(bd)
+ ef = tf.where(
+ tf.broadcast_to(tf.expand_dims(seg_mat, 3), tgt_shape),
+ tf.broadcast_to(ef[:, 1:, :, :], tgt_shape),
+ tf.broadcast_to(ef[:, :1, :, :], tgt_shape))
+
+ # merges attention scores and performs masking
+ attn_score = (ac + bd + ef) * self.scale
+ if attn_mask is not None:
+ attn_score = attn_score - 1e30 * attn_mask
+
+ # attention probability
+ attn_prob = tf.nn.softmax(attn_score, 1)
+ attn_prob = self.attention_probs_dropout(attn_prob)
+
+ # attention output
+ attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h)
+
+ return attn_vec
+
+
+class PositionwiseFF(tf.keras.layers.Layer):
+ """Positionwise feed-forward layer."""
+
+ def __init__(self, d_model, d_inner, dropout, kernel_initializer,
+ activation_type, **kwargs):
+ super(PositionwiseFF, self).__init__(**kwargs)
+ self.d_model = d_model
+ self.d_inner = d_inner
+ self.dropout = dropout
+ self.activation_type = activation_type
+ self.kernel_initializer = kernel_initializer
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+ if self.activation_type == 'relu':
+ activation = tf.nn.relu
+ elif self.activation_type == 'gelu':
+ activation = gelu
+ else:
+      raise ValueError('Unsupported activation type {}'.format(
+          self.activation_type))
+ self.inner_projection_layer = (
+ tf.keras.layers.Dense(
+ units=self.d_inner,
+ activation=activation,
+ kernel_initializer=self.kernel_initializer,
+ name='layer_1'))
+ self.output_projection_layer = (
+ tf.keras.layers.Dense(
+ units=self.d_model,
+ kernel_initializer=self.kernel_initializer,
+ name='layer_2'))
+ self.output_dropout = tf.keras.layers.Dropout(
+ rate=self.dropout, name='drop_2')
+ self.output_layer_norm = (
+ tf.keras.layers.LayerNormalization(
+ name='LayerNorm', axis=-1, epsilon=1e-12))
+ super(PositionwiseFF, self).build(unused_input_shapes)
+
+ def call(self, inp):
+ """Implements call() for the layer."""
+
+ output = self.inner_projection_layer(inp)
+ output = self.output_projection_layer(output)
+ output = self.output_dropout(output)
+ output = self.output_layer_norm(output + inp)
+ return output
+
+
+class EmbeddingLookup(tf.keras.layers.Layer):
+ """Looks up words embeddings for id tensor."""
+
+ def __init__(self, n_token, d_embed, initializer, **kwargs):
+ super(EmbeddingLookup, self).__init__(**kwargs)
+ self.n_token = n_token
+ self.d_embed = d_embed
+ self.initializer = initializer
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+ self.lookup_table = self.add_weight(
+ 'lookup_table',
+ shape=[self.n_token, self.d_embed],
+ initializer=self.initializer,
+ dtype=self.dtype)
+
+ super(EmbeddingLookup, self).build(unused_input_shapes)
+
+ def call(self, inputs):
+ return tf.nn.embedding_lookup(self.lookup_table, inputs)
+
+
+class RelativeMultiheadAttention(tf.keras.layers.Layer):
+ """Multi-head attention with relative embedding."""
+
+ def __init__(self, d_model, n_head, d_head, dropout, dropout_att,
+ kernel_initializer, **kwargs):
+ super(RelativeMultiheadAttention, self).__init__(**kwargs)
+ self.d_model = d_model
+ self.n_head = n_head
+ self.d_head = d_head
+ self.dropout = dropout
+ self.dropout_att = dropout_att
+ self.initializer = kernel_initializer
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+ self.scale = 1.0 / (self.d_head**0.5)
+
+ self.output_layer_norm = tf.keras.layers.LayerNormalization(
+ name='LayerNorm', axis=-1, epsilon=1e-12)
+
+ self.kh_projection_layer = self.add_weight(
+ 'k/kernel',
+ shape=[self.d_model, self.n_head, self.d_head],
+ initializer=self.initializer)
+ self.vh_projection_layer = self.add_weight(
+ 'v/kernel',
+ shape=[self.d_model, self.n_head, self.d_head],
+ initializer=self.initializer)
+ self.kr_projection_layer = self.add_weight(
+ 'r/kernel',
+ shape=[self.d_model, self.n_head, self.d_head],
+ initializer=self.initializer)
+ self.qh_projection_layer = self.add_weight(
+ 'q/kernel',
+ shape=[self.d_model, self.n_head, self.d_head],
+ initializer=self.initializer)
+
+ self.relative_attention_layer = RelativeAttention(
+ dropout_att=self.dropout_att, scale=self.scale)
+
+ self.proj_o = self.add_weight(
+ 'o/kernel',
+ shape=[self.d_model, self.n_head, self.d_head],
+ initializer=self.initializer)
+
+ self.attention_dropout = tf.keras.layers.Dropout(rate=self.dropout)
+
+ super(RelativeMultiheadAttention, self).build(unused_input_shapes)
+
+ def call(self, h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
+ attn_mask_h, attn_mask_g, mems, target_mapping):
+ """Implements call() for the layer."""
+
+ if mems is not None and mems.shape.ndims > 1:
+ cat = tf.concat([mems, h], 0)
+ else:
+ cat = h
+
+ # content heads
+ q_head_h = tf.einsum('ibh,hnd->ibnd', h, self.qh_projection_layer)
+ k_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.kh_projection_layer)
+ v_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.vh_projection_layer)
+
+ # positional heads
+ k_head_r = tf.einsum('ibh,hnd->ibnd', r, self.kr_projection_layer)
+
+ # core attention ops
+ attn_vec_h = self.relative_attention_layer(q_head_h, k_head_h, v_head_h,
+ k_head_r, seg_embed, seg_mat,
+ r_w_bias, r_r_bias, r_s_bias,
+ attn_mask_h)
+
+ # post processing
+ output_h = tf.einsum('ibnd,hnd->ibh', attn_vec_h, self.proj_o)
+ output_h = self.attention_dropout(output_h)
+ output_h = self.output_layer_norm(output_h + h)
+
+ output_g = None
+ if g is not None: # enable two-stream attention
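+      # Query ("g") stream used for permutation-LM targets: unlike the content
+      # ("h") stream it must not see the content embedding of the position
+      # being predicted, which is enforced through attn_mask_g.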
+ # g-stream
+ q_head_g = tf.einsum('ibh,hnd->ibnd', g, self.qh_projection_layer)
+ if target_mapping is not None:
+ q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
+ attn_vec_g = self.relative_attention_layer(q_head_g, k_head_h, v_head_h,
+ k_head_r, seg_embed, seg_mat,
+ r_w_bias, r_r_bias, r_s_bias,
+ attn_mask_g)
+ attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
+
+ else:
+ attn_vec_g = self.relative_attention_layer(q_head_g, k_head_h, v_head_h,
+ k_head_r, seg_embed, seg_mat,
+ r_w_bias, r_r_bias, r_s_bias,
+ attn_mask_g)
+
+ # post processing
+ output_g = tf.einsum('ibnd,hnd->ibh', attn_vec_g, self.proj_o)
+ output_g = self.attention_dropout(output_g)
+ output_g = self.output_layer_norm(output_g + g)
+
+ return (output_h, output_g)
+
+
+class TransformerXLModel(tf.keras.layers.Layer):
+ """Defines a Transformer-XL computation graph with additional support for XLNet."""
+
+ def __init__(self,
+ n_token,
+ n_layer,
+ d_model,
+ n_head,
+ d_head,
+ d_inner,
+ dropout,
+ dropout_att,
+ attn_type,
+ bi_data,
+ is_training,
+ initializer,
+ mem_len=None,
+ same_length=False,
+ clamp_len=-1,
+ untie_r=False,
+ use_tpu=True,
+ reuse_len=None,
+ ff_activation='relu',
+ use_cls_mask=False,
+ **kwargs):
+ """Initializes TransformerXLModel.
+
+ Args:
+ n_token: int, the number of tokens in vocabulary.
+ n_layer: int, the number of layers.
+ d_model: int, the hidden size.
+ n_head: int, the number of attention heads.
+ d_head: int, the dimension size of each attention head.
+ d_inner: int, the hidden size in feed-forward layers.
+ dropout: float, dropout rate.
+ dropout_att: float, dropout rate on attention probabilities.
+ attn_type: str, "uni" or "bi".
+ bi_data: bool, whether to use bidirectional input pipeline. Usually set to
+ True during pretraining and False during finetuning.
+ is_training: bool, whether in training mode.
+ initializer: A tf initializer.
+ mem_len: int, the number of tokens to cache.
+ same_length: bool, whether to use the same attention length for each
+ token.
+ clamp_len: int, clamp all relative distances larger than clamp_len. -1
+ means no clamping.
+ untie_r: bool, whether to untie the biases in attention.
+ use_tpu: bool, whether TPUs are used.
+      reuse_len: int, the number of tokens in the current batch to be cached and
+ reused in the future.
+ ff_activation: str, "relu" or "gelu".
+ use_cls_mask: bool, whether to introduce cls mask.
+ **kwargs: Other parameters.
+ """
+
+ super(TransformerXLModel, self).__init__(**kwargs)
+
+ self.n_token = n_token
+ self.initializer = initializer
+ self.attn_type = attn_type
+ self.n_layer = n_layer
+ self.d_model = d_model
+ self.n_head = n_head
+ self.d_head = d_head
+ self.d_inner = d_inner
+ self.ff_activation = ff_activation
+ self.untie_r = untie_r
+ self.use_tpu = use_tpu
+ self.dropout = dropout
+ self.dropout_att = dropout_att
+
+ self.mem_len = mem_len
+ self.reuse_len = reuse_len
+ self.bi_data = bi_data
+ self.clamp_len = clamp_len
+ self.same_length = same_length
+ self.use_cls_mask = use_cls_mask
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+ self.tf_float = tf.float32
+
+ self.embedding_lookup = EmbeddingLookup(
+ n_token=self.n_token,
+ d_embed=self.d_model,
+ initializer=self.initializer,
+ dtype=self.tf_float,
+ name='word_embedding')
+
+ self.h_dropout = tf.keras.layers.Dropout(rate=self.dropout)
+ self.g_dropout = tf.keras.layers.Dropout(rate=self.dropout)
+
+ if self.untie_r:
+ self.r_w_bias = (
+ self.add_weight(
+ 'r_w_bias',
+ shape=[self.n_layer, self.n_head, self.d_head],
+ dtype=self.tf_float,
+ initializer=self.initializer))
+ self.r_r_bias = (
+ self.add_weight(
+ 'r_r_bias',
+ shape=[self.n_layer, self.n_head, self.d_head],
+ dtype=self.tf_float,
+ initializer=self.initializer))
+ self.r_s_bias = (
+ self.add_weight(
+ 'r_s_bias',
+ shape=[self.n_layer, self.n_head, self.d_head],
+ dtype=self.tf_float,
+ initializer=self.initializer))
+ else:
+ self.r_w_bias = (
+ self.add_weight(
+ 'r_w_bias',
+ shape=[self.n_head, self.d_head],
+ dtype=self.tf_float,
+ initializer=self.initializer))
+ self.r_r_bias = (
+ self.add_weight(
+ 'r_r_bias',
+ shape=[self.n_head, self.d_head],
+ dtype=self.tf_float,
+ initializer=self.initializer))
+ self.r_s_bias = (
+ self.add_weight(
+ 'r_s_bias', [self.n_head, self.d_head],
+ dtype=self.tf_float,
+ initializer=self.initializer))
+
+ self.seg_embed = self.add_weight(
+ 'seg_embed', [self.n_layer, 2, self.n_head, self.d_head],
+ dtype=self.tf_float,
+ initializer=self.initializer)
+
+ self.mask_emb = self.add_weight(
+ 'mask_emb/mask_emb', shape=[1, 1, self.d_model], dtype=self.tf_float)
+
+ self.emb_dropout = tf.keras.layers.Dropout(rate=self.dropout)
+ self.fwd_position_embedding = PositionalEmbedding(self.d_model)
+ self.bwd_position_embedding = PositionalEmbedding(self.d_model)
+
+ self.rel_multihead_layers = []
+ self.h_positionwise_ffn_layers = []
+ for i in range(self.n_layer):
+ self.rel_multihead_layers.append(
+ RelativeMultiheadAttention(
+ d_model=self.d_model,
+ dropout=self.dropout,
+ n_head=self.n_head,
+ d_head=self.d_head,
+ dropout_att=self.dropout_att,
+ kernel_initializer=self.initializer,
+ name='layer_%d/rel_attn' % (i)))
+ self.h_positionwise_ffn_layers.append(
+ PositionwiseFF(
+ d_model=self.d_model,
+ d_inner=self.d_inner,
+ dropout=self.dropout,
+ kernel_initializer=self.initializer,
+ activation_type=self.ff_activation,
+ name='layer_%d/ff' % (i)))
+
+ self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout)
+
+ super(TransformerXLModel, self).build(unused_input_shapes)
+
+ def __call__(self,
+ inp_k,
+ seg_id=None,
+ input_mask=None,
+ mems=None,
+ perm_mask=None,
+ target_mapping=None,
+ inp_q=None,
+ **kwargs):
+ # Uses dict to feed inputs into call() in order to keep mems as a python
+ # list.
+ inputs = {
+ 'inp_k': inp_k,
+ 'seg_id': seg_id,
+ 'input_mask': input_mask,
+ 'mems': mems,
+ 'perm_mask': perm_mask,
+ 'target_mapping': target_mapping,
+ 'inp_q': inp_q
+ }
+ return super(TransformerXLModel, self).__call__(inputs, **kwargs)
+
+ def call(self, inputs):
+ """Implements call() for the layer."""
+ inp_k = inputs['inp_k']
+ seg_id = inputs['seg_id']
+ input_mask = inputs['input_mask']
+ mems = inputs['mems']
+ perm_mask = inputs['perm_mask']
+ target_mapping = inputs['target_mapping']
+ inp_q = inputs['inp_q']
+
+ new_mems = []
+
+ bsz = tf.shape(inp_k)[1]
+
+ qlen = inp_k.shape.as_list()[0]
+
+ mlen = mems[0].shape.as_list()[0] if mems is not None else 0
+ klen = mlen + qlen
+
+ ##### Attention mask
+ # causal attention mask
+ if self.attn_type == 'uni':
+ attn_mask = _create_mask(qlen, mlen, self.tf_float, self.same_length)
+ # pylint: enable=protected-access
+ attn_mask = attn_mask[:, :, None, None]
+ elif self.attn_type == 'bi':
+ attn_mask = None
+ else:
+ raise ValueError('Unsupported attention type: {}'.format(self.attn_type))
+
+ # data mask: input mask & perm mask
+ if input_mask is not None and perm_mask is not None:
+ data_mask = input_mask[None] + perm_mask
+
+ elif input_mask is not None and perm_mask is None:
+ data_mask = input_mask[None]
+ elif input_mask is None and perm_mask is not None:
+ data_mask = perm_mask
+ else:
+ data_mask = None
+
+ if data_mask is not None:
+ # all mems can be attended to
+ mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz],
+ dtype=self.tf_float)
+ data_mask = tf.concat([mems_mask, data_mask], 1)
+ if attn_mask is None:
+ attn_mask = data_mask[:, :, :, None]
+ else:
+ attn_mask += data_mask[:, :, :, None]
+
+ if attn_mask is not None:
+ attn_mask = tf.cast(attn_mask > 0, dtype=self.tf_float)
+
+ if attn_mask is not None:
+ non_tgt_mask = -tf.eye(qlen, dtype=self.tf_float)
+ non_tgt_mask = tf.concat(
+ [tf.zeros([qlen, mlen], dtype=self.tf_float), non_tgt_mask], axis=-1)
+ non_tgt_mask = tf.cast(
+ (attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=self.tf_float)
+ else:
+ non_tgt_mask = None
+
+ word_emb_k = self.embedding_lookup(inp_k)
+
+ if inp_q is not None:
+ if target_mapping is not None:
+ word_emb_q = tf.tile(self.mask_emb,
+ [tf.shape(target_mapping)[0], bsz, 1])
+ else:
+ inp_q_ext = inp_q[:, :, None]
+ word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
+
+ output_h = self.h_dropout(word_emb_k)
+ output_g = None
+ if inp_q is not None:
+ output_g = self.g_dropout(word_emb_q)
+
+ ##### Segment embedding
+ if seg_id is not None:
+
+ # Convert `seg_id` to one-hot `seg_mat`
+
+ mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
+
+ cat_id = tf.concat([mem_pad, seg_id], 0)
+
+ if self.use_cls_mask:
+ # `1` indicates not in the same segment [qlen x klen x bsz]
+ # seg_id: [qlen x bsz] & cat_id: [klen x bsz]
+ cls_mat = tf.logical_or(
+ tf.equal(seg_id, tf.constant([data_utils.SEG_ID_CLS]))[:, None],
+ tf.equal(cat_id, tf.constant([data_utils.SEG_ID_CLS]))[None, :])
+ seg_mat = tf.equal(seg_id[:, None], cat_id[None, :])
+ seg_mat = tf.logical_or(cls_mat, seg_mat)
+ else:
+ seg_mat = tf.logical_not(tf.equal(seg_id[:, None], cat_id[None, :]))
+ else:
+ seg_mat = None
+
+ dtype = self.tf_float
+ freq_seq = tf.range(0, self.d_model, 2.0)
+ if dtype is not None and dtype != tf.float32:
+ freq_seq = tf.cast(freq_seq, dtype=self.dtype)
+
+ if self.attn_type == 'bi':
+ beg, end = klen, -qlen
+ elif self.attn_type == 'uni':
+ beg, end = klen, -1
+ else:
+ raise ValueError('Unknown `attn_type` {}.'.format(self.attn_type))
+
+ if self.bi_data:
+ fwd_pos_seq = tf.range(beg, end, -1.0)
+ bwd_pos_seq = tf.range(-beg, -end, 1.0)
+
+ if dtype is not None and dtype != tf.float32:
+ fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
+ bwd_pos_seq = tf.cast(bwd_pos_seq, dtype=dtype)
+
+ if self.clamp_len > 0:
+ fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len,
+ self.clamp_len)
+ bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len,
+ self.clamp_len)
+
+ if bsz is not None:
+ fwd_pos_emb = self.fwd_position_embedding(fwd_pos_seq, bsz // 2)
+ bwd_pos_emb = self.bwd_position_embedding(bwd_pos_seq, bsz // 2)
+ else:
+ fwd_pos_emb = self.fwd_position_embedding(fwd_pos_seq, None)
+ bwd_pos_emb = self.bwd_position_embedding(bwd_pos_seq, None)
+
+ pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
+ else:
+ fwd_pos_seq = tf.range(beg, end, -1.0)
+ if dtype is not None and dtype != tf.float32:
+ fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
+ if self.clamp_len > 0:
+ fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len,
+                                         self.clamp_len)
+
+ pos_emb = self.fwd_position_embedding(fwd_pos_seq, bsz)
+
+ pos_emb = self.emb_dropout(pos_emb)
+
+ if mems is None:
+ mems = [None] * self.n_layer
+ for i in range(self.n_layer):
+ # cache new mems
+ new_mems.append(
+ _cache_mem(output_h, mems[i], self.mem_len, self.reuse_len))
+ # pylint: enable=protected-access
+
+ # segment bias
+ if seg_id is None:
+ r_s_bias_i = None
+ seg_embed_i = None
+ else:
+ r_s_bias_i = self.r_s_bias if not self.untie_r else self.r_s_bias[i]
+ seg_embed_i = self.seg_embed[i]
+
+ ffn_layer = self.h_positionwise_ffn_layers[i]
+ attention_layer = self.rel_multihead_layers[i]
+ output_h, output_g = attention_layer(
+ h=output_h,
+ g=output_g,
+ r=pos_emb,
+ r_w_bias=self.r_w_bias if not self.untie_r else self.r_w_bias[i],
+ r_r_bias=self.r_r_bias if not self.untie_r else self.r_r_bias[i],
+ seg_mat=seg_mat,
+ r_s_bias=r_s_bias_i,
+ seg_embed=seg_embed_i,
+ attn_mask_h=non_tgt_mask,
+ attn_mask_g=attn_mask,
+ mems=mems[i],
+ target_mapping=target_mapping)
+ output_h = ffn_layer(output_h)
+ if output_g is not None:
+ output_g = ffn_layer(output_g)
+
+ if inp_q is not None:
+ output = output_g
+ else:
+ output = output_h
+
+ return output, new_mems, None
+
+
+class PretrainingXLNetModel(tf.keras.Model):
+ """XLNet keras model combined with pretraining LM loss layer.
+
+ See the original paper: https://arxiv.org/pdf/1906.08237.pdf
+
+ """
+
+ def __init__(self, use_proj, xlnet_config, run_config, **kwargs):
+ super(PretrainingXLNetModel, self).__init__(**kwargs)
+ self.run_config = run_config
+ self.initializer = _get_initializer(run_config)
+ self.xlnet_config = copy.deepcopy(xlnet_config)
+
+ self.transformerxl_model = TransformerXLModel(
+ n_token=self.xlnet_config.n_token,
+ initializer=self.initializer,
+ attn_type='bi',
+ n_layer=self.xlnet_config.n_layer,
+ d_model=self.xlnet_config.d_model,
+ n_head=self.xlnet_config.n_head,
+ d_head=self.xlnet_config.d_head,
+ d_inner=self.xlnet_config.d_inner,
+ ff_activation=self.xlnet_config.ff_activation,
+ untie_r=self.xlnet_config.untie_r,
+ is_training=self.run_config.is_training,
+ use_tpu=self.run_config.use_tpu,
+ dropout=self.run_config.dropout,
+ dropout_att=self.run_config.dropout_att,
+ mem_len=self.run_config.mem_len,
+ reuse_len=self.run_config.reuse_len,
+ bi_data=self.run_config.bi_data,
+ clamp_len=self.run_config.clamp_len,
+ same_length=self.run_config.same_length,
+ use_cls_mask=self.run_config.use_cls_mask,
+ name='transformer')
+ self.lmloss_layer = LMLossLayer(
+ n_token=self.xlnet_config.n_token,
+ d_model=self.xlnet_config.d_model,
+ initializer=self.initializer,
+ tie_weight=True,
+ bi_data=self.run_config.bi_data,
+ use_tpu=self.run_config.use_tpu,
+ use_proj=use_proj,
+ name='lm_loss')
+
+ def call(self, features):
+ """Implements call() for the layer."""
+
+ input_ids = tf.transpose(features['input_k'], [1, 0])
+ inp_q = tf.transpose(features['input_q'], [1, 0])
+
+ seg_ids = tf.transpose(features['seg_id'], [1, 0])
+
+ perm_mask = tf.transpose(features['perm_mask'], [1, 2, 0])
+
+ target_mapping = tf.transpose(features['target_mapping'], [1, 2, 0])
+
+ # target for LM loss
+ target = tf.transpose(features['target'], [1, 0])
+
+ # target mask for LM loss
+ tgt_mask = tf.transpose(features['target_mask'], [1, 0])
+
+ mems = features.get('mems', None)
+
+ transformerxl_output, self.new_mems, self.lookup_table = self.transformerxl_model(
+ input_ids,
+ seg_id=seg_ids,
+ input_mask=None,
+ mems=mems,
+ perm_mask=perm_mask,
+ target_mapping=target_mapping,
+ inp_q=inp_q)
+ lm_loss, _ = self.lmloss_layer(
+ hidden=transformerxl_output,
+ target=target,
+ lookup_table=self.transformerxl_model.embedding_lookup.lookup_table,
+ target_mask=tgt_mask)
+ self.add_loss(lm_loss)
+ return self.new_mems, transformerxl_output
+
+
+class ClassificationXLNetModel(tf.keras.Model):
+ """XLNet keras model combined with classification loss layer.
+
+ See the original paper: https://arxiv.org/pdf/1906.08237.pdf
+
+ """
+
+ def __init__(self, xlnet_config, run_config, n_class, summary_type, **kwargs):
+ super(ClassificationXLNetModel, self).__init__(**kwargs)
+ self.run_config = run_config
+ self.initializer = _get_initializer(run_config)
+ self.xlnet_config = copy.deepcopy(xlnet_config)
+
+ self.transformerxl_model = TransformerXLModel(
+ n_token=self.xlnet_config.n_token,
+ initializer=self.initializer,
+ attn_type='bi',
+ n_layer=self.xlnet_config.n_layer,
+ d_model=self.xlnet_config.d_model,
+ n_head=self.xlnet_config.n_head,
+ d_head=self.xlnet_config.d_head,
+ d_inner=self.xlnet_config.d_inner,
+ ff_activation=self.xlnet_config.ff_activation,
+ untie_r=self.xlnet_config.untie_r,
+ is_training=self.run_config.is_training,
+ use_tpu=self.run_config.use_tpu,
+ dropout=self.run_config.dropout,
+ dropout_att=self.run_config.dropout_att,
+ mem_len=self.run_config.mem_len,
+ reuse_len=self.run_config.reuse_len,
+ bi_data=self.run_config.bi_data,
+ clamp_len=self.run_config.clamp_len,
+ same_length=self.run_config.same_length,
+ name='transformer')
+
+ self.summarization_layer = Summarization(
+ d_model=self.xlnet_config.d_model,
+ n_head=self.xlnet_config.n_head,
+ d_head=self.xlnet_config.d_head,
+ dropout=self.run_config.dropout,
+ dropout_att=self.run_config.dropout_att,
+ initializer=self.initializer,
+ use_proj=True,
+ summary_type=summary_type,
+ name='sequence_summary')
+
+ self.cl_loss_layer = ClassificationLossLayer(
+ n_class=n_class, initializer=self.initializer, name='classification')
+
+ def call(self, features):
+ """Implements call() for the layer."""
+ bsz_per_core = tf.shape(features['input_ids'])[0]
+
+ input_ids = tf.transpose(features['input_ids'], [1, 0])
+ seg_ids = tf.transpose(features['segment_ids'], [1, 0])
+ input_mask = tf.transpose(features['input_mask'], [1, 0])
+
+ label = tf.reshape(features['label_ids'], [bsz_per_core])
+
+ mems = features.get('mems', None)
+
+ transformerxl_output, new_mems, self.lookup_table = (
+ self.transformerxl_model(input_ids, seg_ids, input_mask, mems))
+
+ summary = self.summarization_layer(transformerxl_output)
+ per_example_loss, logits = self.cl_loss_layer(hidden=summary, labels=label)
+ self.add_loss(tf.keras.backend.mean(per_example_loss))
+ return new_mems, logits
+
+
+class LMLossLayer(tf.keras.layers.Layer):
+ """Layer computing cross entropy loss for language modeling."""
+
+ def __init__(self,
+ n_token,
+ d_model,
+ initializer,
+ tie_weight=False,
+ bi_data=True,
+ use_tpu=False,
+ use_proj=False,
+ **kwargs):
+ """Constructs LMLoss layer.
+
+ Args:
+ n_token: Number of tokens in vocabulary.
+ d_model: The dimension of model hidden state.
+ initializer: Initializer used for parameters.
+ tie_weight: Whether to share weights between embedding lookup layer and
+ next-token prediction layer.
+ bi_data: Whether to use bidirectional input pipeline. Usually set to True
+ during pretraining and False during finetuning.
+ use_tpu: bool, whether to use TPU.
+ use_proj: bool, whether to add a projection layer before LM prediction.
+ **kwargs: Other parameters.
+ """
+ super(LMLossLayer, self).__init__(**kwargs)
+ self.n_token = n_token
+ self.d_model = d_model
+ self.initializer = initializer
+
+ self.tie_weight = tie_weight
+ self.bi_data = bi_data
+ self.use_tpu = use_tpu
+ self.use_proj = use_proj
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+ if self.use_proj:
+ self.proj_layer = tf.keras.layers.Dense(
+ units=self.d_model,
+ kernel_initializer=self.initializer,
+ activation=gelu,
+ name='lm_projection/dense')
+ self.proj_layer_norm = tf.keras.layers.LayerNormalization(
+ axis=-1, epsilon=1e-12, name='lm_projection/LayerNorm')
+ if not self.tie_weight:
+ self.softmax_w = self.add_weight(
+ 'weight',
+ shape=[self.n_token, self.d_model],
+ initializer=self.initializer)
+
+ self.softmax_b = self.add_weight(
+ 'bias', shape=[self.n_token], initializer=tf.zeros_initializer())
+
+ super(LMLossLayer, self).build(unused_input_shapes)
+
+ def call(self, hidden, target, lookup_table, target_mask):
+ """Implements call() for the layer."""
+ if self.use_proj:
+ hidden = self.proj_layer_norm(self.proj_layer(hidden))
+ if self.tie_weight:
+ logits = tf.einsum('ibd,nd->ibn', hidden, lookup_table) + self.softmax_b
+ else:
+ logits = tf.einsum('ibd,nd->ibn', hidden, self.softmax_w) + self.softmax_b
+
+ if self.use_tpu:
+ one_hot_target = tf.one_hot(target, self.n_token, dtype=logits.dtype)
+ loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)
+ else:
+ loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ labels=target, logits=logits)
+
+ total_loss = tf.reduce_sum(loss * target_mask) / tf.reduce_sum(target_mask)
+
+ return total_loss, logits
+
+
+class Summarization(tf.keras.layers.Layer):
+ """The layer to pool the output from XLNet model into a vector."""
+
+ def __init__(self,
+ d_model,
+ n_head,
+ d_head,
+ dropout,
+ dropout_att,
+ initializer,
+ use_proj=True,
+ summary_type='last',
+ **kwargs):
+ """Constructs Summarization layer.
+
+ Args:
+ d_model: int, the dimension of model hidden state.
+ n_head: int, the number of attention heads.
+ d_head: int, the dimension size of each attention head.
+ dropout: float, dropout rate.
+ dropout_att: float, dropout rate on attention probabilities.
+ initializer: Initializer used for parameters.
+ use_proj: bool, whether to use projection layer for summarization.
+ summary_type: Method used to summarize a sequence into a compact vector.
+ **kwargs: Other parameters.
+ """
+ super(Summarization, self).__init__(**kwargs)
+ self.d_model = d_model
+ self.n_head = n_head
+ self.d_head = d_head
+ self.initializer = initializer
+
+ self.dropout = dropout
+ self.dropout_att = dropout_att
+ self.use_proj = use_proj
+ self.summary_type = summary_type
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+ if self.use_proj:
+ self.proj_layer = tf.keras.layers.Dense(
+ units=self.d_model,
+ kernel_initializer=self.initializer,
+ activation=tf.nn.tanh,
+ name='summary')
+ self.dropout_layer = tf.keras.layers.Dropout(rate=self.dropout)
+
+ super(Summarization, self).build(unused_input_shapes)
+
+ def call(self, inputs):
+ """Implements call() for the layer."""
+ if self.summary_type == 'last':
+ summary = inputs[-1]
+ elif self.summary_type == 'first':
+ summary = inputs[0]
+ else:
+ raise ValueError('Invalid summary type provided: %s' % self.summary_type)
+ if self.use_proj:
+ summary = self.proj_layer(summary)
+ summary = self.dropout_layer(summary)
+ return summary
+
+
+class ClassificationLossLayer(tf.keras.layers.Layer):
+ """Layer computing cross entropy loss for classification task."""
+
+ def __init__(self, n_class, initializer, **kwargs):
+ """Constructs Summarization layer.
+
+ Args:
+ n_class: Number of tokens in vocabulary.
+ initializer: Initializer used for parameters.
+ **kwargs: Other parameters.
+ """
+ super(ClassificationLossLayer, self).__init__(**kwargs)
+
+ self.n_class = n_class
+ self.initializer = initializer
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+ self.proj_layer = tf.keras.layers.Dense(
+ units=self.n_class, kernel_initializer=self.initializer, name='logit')
+
+ super(ClassificationLossLayer, self).build(unused_input_shapes)
+
+ def call(self, hidden, labels):
+ """Implements call() for the layer."""
+
+ logits = self.proj_layer(hidden)
+ one_hot_target = tf.one_hot(labels, self.n_class, dtype=hidden.dtype) # pytype: disable=attribute-error
+ loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)
+
+ return loss, logits
+
+
+class QAXLNetModel(tf.keras.Model):
+ """XLNet keras model combined with question answering loss layer.
+
+ See the original paper: https://arxiv.org/pdf/1906.08237.pdf
+
+ """
+
+ def __init__(self, xlnet_config, run_config, start_n_top, end_n_top,
+ **kwargs):
+ super(QAXLNetModel, self).__init__(**kwargs)
+ self.run_config = run_config
+ self.initializer = _get_initializer(run_config)
+ self.xlnet_config = copy.deepcopy(xlnet_config)
+
+ self.transformerxl_model = TransformerXLModel(
+ n_token=self.xlnet_config.n_token,
+ initializer=self.initializer,
+ attn_type='bi',
+ n_layer=self.xlnet_config.n_layer,
+ d_model=self.xlnet_config.d_model,
+ n_head=self.xlnet_config.n_head,
+ d_head=self.xlnet_config.d_head,
+ d_inner=self.xlnet_config.d_inner,
+ ff_activation=self.xlnet_config.ff_activation,
+ untie_r=self.xlnet_config.untie_r,
+ is_training=self.run_config.is_training,
+ use_tpu=self.run_config.use_tpu,
+ dropout=self.run_config.dropout,
+ dropout_att=self.run_config.dropout_att,
+ mem_len=self.run_config.mem_len,
+ reuse_len=self.run_config.reuse_len,
+ bi_data=self.run_config.bi_data,
+ clamp_len=self.run_config.clamp_len,
+ same_length=self.run_config.same_length,
+ name='transformer')
+
+ self.qa_loss_layer = QALossLayer(
+ d_model=self.xlnet_config.d_model,
+ start_n_top=start_n_top,
+ end_n_top=end_n_top,
+ initializer=self.initializer,
+ dropout=self.run_config.dropout)
+
+ def call(self, features, training=False):
+ """Implements call() for the layer."""
+
+ input_ids = tf.transpose(features['input_ids'], [1, 0])
+ seg_ids = tf.transpose(features['segment_ids'], [1, 0])
+ input_mask = tf.transpose(features['input_mask'], [1, 0])
+
+ cls_index = tf.reshape(features['cls_index'], [-1])
+ p_mask = features['p_mask']
+
+ transformerxl_output, new_mems, self.lookup_table = (
+ self.transformerxl_model(input_ids, seg_ids, input_mask))
+
+ if training:
+ loss, logits = self.qa_loss_layer(
+ hidden=transformerxl_output,
+ p_mask=p_mask,
+ cls_index=cls_index,
+ start_positions=features['start_positions'],
+ end_positions=features['end_positions'],
+ is_impossible=features['is_impossible'])
+ self.add_loss(loss)
+ return new_mems, logits
+ else:
+ results = self.qa_loss_layer(
+ hidden=transformerxl_output, p_mask=p_mask, cls_index=cls_index)
+ return results
+
+
+class QALossLayer(tf.keras.layers.Layer):
+ """Layer computing position and regression loss for question answering task."""
+
+ def __init__(self, d_model, start_n_top, end_n_top, initializer, dropout,
+ **kwargs):
+ """Constructs Summarization layer.
+
+ Args:
+ d_model: Int, the hidden size.
+ start_n_top: Beam size for span start.
+ end_n_top: Beam size for span end.
+ initializer: Initializer used for parameters.
+ dropout: float, dropout rate.
+ **kwargs: Other parameters.
+ """
+ super(QALossLayer, self).__init__(**kwargs)
+ self.d_model = d_model
+ self.start_n_top = start_n_top
+ self.end_n_top = end_n_top
+ self.initializer = initializer
+ self.dropout = dropout
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+ self.start_logits_proj_layer = tf.keras.layers.Dense(
+ units=1, kernel_initializer=self.initializer, name='start_logits/dense')
+ self.end_logits_proj_layer0 = tf.keras.layers.Dense(
+ units=self.d_model,
+ kernel_initializer=self.initializer,
+ activation=tf.nn.tanh,
+ name='end_logits/dense_0')
+ self.end_logits_proj_layer1 = tf.keras.layers.Dense(
+ units=1, kernel_initializer=self.initializer, name='end_logits/dense_1')
+ self.end_logits_layer_norm = tf.keras.layers.LayerNormalization(
+ axis=-1, epsilon=1e-12, name='end_logits/LayerNorm')
+ self.answer_class_proj_layer0 = tf.keras.layers.Dense(
+ units=self.d_model,
+ kernel_initializer=self.initializer,
+ activation=tf.nn.tanh,
+ name='answer_class/dense_0')
+ self.answer_class_proj_layer1 = tf.keras.layers.Dense(
+ units=1,
+ kernel_initializer=self.initializer,
+ use_bias=False,
+ name='answer_class/dense_1')
+ self.ans_feature_dropout = tf.keras.layers.Dropout(rate=self.dropout)
+ super(QALossLayer, self).build(unused_input_shapes)
+
+ def __call__(self, hidden, p_mask, cls_index, **kwargs):
+ return super(QALossLayer, self).__call__(
+ (hidden, p_mask, cls_index, kwargs))
+
+ def call(self, inputs, training=False):
+ """Implements call() for the layer."""
+ hidden, p_mask, cls_index, kwargs = inputs
+ return_dict = {}
+ seq_len = tf.shape(hidden)[0]
+
+ start_logits = self.start_logits_proj_layer(hidden)
+ start_logits = tf.transpose(tf.squeeze(start_logits, -1), [1, 0])
+ start_logits_masked = start_logits * (1 - p_mask) - 1e30 * p_mask
+ start_log_probs = tf.nn.log_softmax(start_logits_masked, -1)
+ if training:
+ start_positions = kwargs['start_positions']
+ end_positions = kwargs['end_positions']
+ is_impossible = kwargs['is_impossible']
+ start_positions = tf.reshape(start_positions, [-1])
+ start_index = tf.one_hot(
+ start_positions, depth=seq_len, axis=-1, dtype=tf.float32)
+ start_features = tf.einsum('lbh,bl->bh', hidden, start_index)
+ start_features = tf.tile(start_features[None], [seq_len, 1, 1])
+ end_logits = self.end_logits_proj_layer0(
+ tf.concat([hidden, start_features], axis=-1))
+
+ end_logits = self.end_logits_layer_norm(end_logits)
+
+ end_logits = self.end_logits_proj_layer1(end_logits)
+ end_logits = tf.transpose(tf.squeeze(end_logits, -1), [1, 0])
+ end_logits_masked = end_logits * (1 - p_mask) - 1e30 * p_mask
+ end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
+ else:
+ # during inference, compute the end logits based on beam search
+
+ start_top_log_probs, start_top_index = tf.nn.top_k(
+ start_log_probs, k=self.start_n_top)
+ start_index = tf.one_hot(
+ start_top_index, depth=seq_len, axis=-1, dtype=tf.float32)
+ start_features = tf.einsum('lbh,bkl->bkh', hidden, start_index)
+ end_input = tf.tile(hidden[:, :, None], [1, 1, self.start_n_top, 1])
+ start_features = tf.tile(start_features[None], [seq_len, 1, 1, 1])
+ end_input = tf.concat([end_input, start_features], axis=-1)
+ end_logits = self.end_logits_proj_layer0(end_input)
+ end_logits = tf.reshape(end_logits, [seq_len, -1, self.d_model])
+ end_logits = self.end_logits_layer_norm(end_logits)
+
+ end_logits = tf.reshape(end_logits,
+ [seq_len, -1, self.start_n_top, self.d_model])
+
+ end_logits = self.end_logits_proj_layer1(end_logits)
+ end_logits = tf.reshape(end_logits, [seq_len, -1, self.start_n_top])
+ end_logits = tf.transpose(end_logits, [1, 2, 0])
+ end_logits_masked = end_logits * (
+ 1 - p_mask[:, None]) - 1e30 * p_mask[:, None]
+ end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
+ end_top_log_probs, end_top_index = tf.nn.top_k(
+ end_log_probs, k=self.end_n_top)
+ end_top_log_probs = tf.reshape(end_top_log_probs,
+ [-1, self.start_n_top * self.end_n_top])
+ end_top_index = tf.reshape(end_top_index,
+ [-1, self.start_n_top * self.end_n_top])
+
+ if training:
+ return_dict['start_log_probs'] = start_log_probs
+ return_dict['end_log_probs'] = end_log_probs
+ else:
+ return_dict['start_top_log_probs'] = start_top_log_probs
+ return_dict['start_top_index'] = start_top_index
+ return_dict['end_top_log_probs'] = end_top_log_probs
+ return_dict['end_top_index'] = end_top_index
+ # an additional layer to predict answerability
+
+ # get the representation of CLS
+ cls_index = tf.one_hot(cls_index, seq_len, axis=-1, dtype=tf.float32)
+ cls_feature = tf.einsum('lbh,bl->bh', hidden, cls_index)
+
+ # get the representation of START
+ start_p = tf.nn.softmax(start_logits_masked, axis=-1, name='softmax_start')
+ start_feature = tf.einsum('lbh,bl->bh', hidden, start_p)
+
+ ans_feature = tf.concat([start_feature, cls_feature], -1)
+ ans_feature = self.answer_class_proj_layer0(ans_feature)
+ ans_feature = self.ans_feature_dropout(ans_feature)
+ cls_logits = self.answer_class_proj_layer1(ans_feature)
+ cls_logits = tf.squeeze(cls_logits, -1)
+ return_dict['cls_logits'] = cls_logits
+
+ if not training:
+ return return_dict
+
+ def compute_loss(log_probs, positions):
+ one_hot_positions = tf.one_hot(positions, depth=seq_len, dtype=tf.float32)
+
+ loss = -tf.reduce_sum(one_hot_positions * log_probs, axis=-1)
+ loss = tf.reduce_mean(loss)
+ return loss
+
+ start_loss = compute_loss(start_log_probs, start_positions)
+ end_loss = compute_loss(end_log_probs, end_positions)
+
+ total_loss = (start_loss + end_loss) * 0.5
+
+ is_impossible = tf.reshape(is_impossible, [-1])
+ regression_loss = tf.nn.sigmoid_cross_entropy_with_logits(
+ labels=is_impossible, logits=cls_logits)
+ regression_loss = tf.reduce_mean(regression_loss)
+
+ total_loss += regression_loss * 0.5
+ return total_loss, cls_logits
diff --git a/models/official/nlp/xlnet/xlnet_modeling_test.py b/models/official/nlp/xlnet/xlnet_modeling_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..dce887aebd77c75999091af9ec112f8d0d336eee
--- /dev/null
+++ b/models/official/nlp/xlnet/xlnet_modeling_test.py
@@ -0,0 +1,52 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import logging
+import numpy as np
+import tensorflow as tf
+
+from official.nlp.xlnet import xlnet_modeling
+
+
+class PositionalEmbeddingLayerTest(tf.test.TestCase):
+
+ def test_positional_embedding(self):
+ """A low-dimensional example is tested.
+
+ With len(pos_seq)=2 and d_model=4:
+
+ pos_seq = [[1.], [0.]]
+ inv_freq = [1., 0.01]
+ pos_seq x inv_freq = [[1, 0.01], [0., 0.]]
+ pos_emb = [[sin(1.), sin(0.01), cos(1.), cos(0.01)],
+ [sin(0.), sin(0.), cos(0.), cos(0.)]]
+ = [[0.84147096, 0.00999983, 0.54030228, 0.99994999],
+ [0., 0., 1., 1.]]
+ """
+ target = np.array([[[0.84147096, 0.00999983, 0.54030228, 0.99994999]],
+ [[0., 0., 1., 1.]]])
+ d_model = 4
+ pos_seq = tf.range(1, -1, -1.0) # [1., 0.]
+ pos_emb_layer = xlnet_modeling.PositionalEmbedding(d_model)
+ pos_emb = pos_emb_layer(pos_seq, batch_size=None).numpy().astype(float)
+
+ logging.info(pos_emb)
+ self.assertAllClose(pos_emb, target)
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/pip_package/setup.py b/models/official/pip_package/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..8433f25321c6f28c1c0c51797154633aa1a3ec71
--- /dev/null
+++ b/models/official/pip_package/setup.py
@@ -0,0 +1,89 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Sets up TensorFlow Official Models."""
+import datetime
+import os
+import sys
+
+from setuptools import find_packages
+from setuptools import setup
+
+version = '2.3.0'
+
+project_name = 'tf-models-official'
+
+long_description = """The TensorFlow official models are a collection of
+models that use TensorFlow's high-level APIs.
+They are intended to be well-maintained, tested, and kept up to date with the
+latest TensorFlow API. They should also be reasonably optimized for fast
+performance while still being easy to read."""
+
+if '--project_name' in sys.argv:
+ project_name_idx = sys.argv.index('--project_name')
+ project_name = sys.argv[project_name_idx + 1]
+ sys.argv.remove('--project_name')
+ sys.argv.pop(project_name_idx)
+
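+# Illustrative invocation (not part of the original file; shown only to explain
+# the --project_name handling above):
+#   python setup.py --project_name tf-models-nightly sdist
+# would strip the flag and its value from sys.argv and build the package under
+# the nightly name.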
+
+def _get_requirements():
+ """Parses requirements.txt file."""
+ install_requires_tmp = []
+ dependency_links_tmp = []
+ with open(
+ os.path.join(os.path.dirname(__file__), '../requirements.txt'), 'r') as f:
+ for line in f:
+ package_name = line.strip()
+ if package_name.startswith('-e '):
+ dependency_links_tmp.append(package_name[3:].strip())
+ else:
+ install_requires_tmp.append(package_name)
+ return install_requires_tmp, dependency_links_tmp
+
+install_requires, dependency_links = _get_requirements()
+
+if project_name == 'tf-models-nightly':
+ version += '.dev' + datetime.datetime.now().strftime('%Y%m%d')
+ install_requires.append('tf-nightly')
+else:
+ install_requires.append('tensorflow>=2.3.0')
+
+print('install_requires: ', install_requires)
+print('dependency_links: ', dependency_links)
+
+setup(
+ name=project_name,
+ version=version,
+ description='TensorFlow Official Models',
+ long_description=long_description,
+ author='Google Inc.',
+ author_email='no-reply@google.com',
+ url='https://github.com/tensorflow/models',
+ license='Apache 2.0',
+ packages=find_packages(exclude=[
+ 'research*',
+ 'tutorials*',
+ 'samples*',
+ 'official.r1*',
+ 'official.pip_package*',
+ 'official.benchmark*',
+ 'official.colab*',
+ ]),
+ exclude_package_data={
+ '': ['*_test.py',],
+ },
+ install_requires=install_requires,
+ dependency_links=dependency_links,
+ python_requires='>=3.6',
+)
diff --git a/models/official/recommendation/README.md b/models/official/recommendation/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..441bc128681c3189b53f7909b22c70fccf564414
--- /dev/null
+++ b/models/official/recommendation/README.md
@@ -0,0 +1,72 @@
+# Recommendation Model
+## Overview
+This is an implementation of the Neural Collaborative Filtering (NCF) framework with the Neural Matrix Factorization (NeuMF) model, as described in the [Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031) paper. The current implementation is based on the authors' [NCF code](https://github.com/hexiangnan/neural_collaborative_filtering) and the Stanford implementation in the [MLPerf Repo](https://github.com/mlperf/reference/tree/master/recommendation/pytorch).
+
+NCF is a general framework for collaborative filtering of recommendations in which a neural network architecture is used to model user-item interactions. Unlike traditional models, NCF does not resort to Matrix Factorization (MF) with an inner product on latent features of users and items. It replaces the inner product with a multi-layer perceptron that can learn an arbitrary function from data.
+
+Two instantiations of NCF are Generalized Matrix Factorization (GMF) and Multi-Layer Perceptron (MLP). GMF applies a linear kernel to model the latent feature interactions, and MLP uses a nonlinear kernel to learn the interaction function from data. NeuMF is a fused model of GMF and MLP that better models complex user-item interactions by unifying the linearity of MF with the non-linearity of MLP for modeling the user-item latent structures. NeuMF allows GMF and MLP to learn separate embeddings, and combines the two models by concatenating their last hidden layers. [neumf_model.py](neumf_model.py) defines the architecture details.
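+
+For orientation, the following is a minimal Keras sketch of the NeuMF idea (illustrative layer sizes and names only; see [neumf_model.py](neumf_model.py) for the actual architecture used here):
+
+```python
+import tensorflow as tf
+
+def build_neumf(num_users, num_items, mf_dim=8, mlp_units=(64, 32, 16)):
+  """Toy NeuMF: a GMF branch and an MLP branch, concatenated at the end."""
+  user = tf.keras.Input(shape=(1,), dtype=tf.int32, name="user_id")
+  item = tf.keras.Input(shape=(1,), dtype=tf.int32, name="item_id")
+
+  # Separate embeddings for the GMF and MLP branches.
+  gmf_u = tf.keras.layers.Flatten()(tf.keras.layers.Embedding(num_users, mf_dim)(user))
+  gmf_i = tf.keras.layers.Flatten()(tf.keras.layers.Embedding(num_items, mf_dim)(item))
+  mlp_u = tf.keras.layers.Flatten()(tf.keras.layers.Embedding(num_users, mlp_units[0] // 2)(user))
+  mlp_i = tf.keras.layers.Flatten()(tf.keras.layers.Embedding(num_items, mlp_units[0] // 2)(item))
+
+  # GMF branch: element-wise product of the latent vectors (linear kernel).
+  gmf_vector = tf.keras.layers.Multiply()([gmf_u, gmf_i])
+
+  # MLP branch: learn the interaction function from concatenated embeddings.
+  mlp_vector = tf.keras.layers.Concatenate()([mlp_u, mlp_i])
+  for units in mlp_units:
+    mlp_vector = tf.keras.layers.Dense(units, activation="relu")(mlp_vector)
+
+  # NeuMF: concatenate both branches and predict the interaction logit.
+  logits = tf.keras.layers.Dense(1)(
+      tf.keras.layers.Concatenate()([gmf_vector, mlp_vector]))
+  return tf.keras.Model([user, item], logits)
+```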
+
+Some abbreviations used in the code base include:
+ - NCF: Neural Collaborative Filtering
+ - NeuMF: Neural Matrix Factorization
+ - GMF: Generalized Matrix Factorization
+ - MLP: Multi-Layer Perceptron
+ - HR: Hit Ratio
+ - NDCG: Normalized Discounted Cumulative Gain
+ - ml-1m: MovieLens 1 million dataset
+ - ml-20m: MovieLens 20 million dataset
+
+## Dataset
+The [MovieLens datasets](http://files.grouplens.org/datasets/movielens/) are used for model training and evaluation. Specifically, we use two datasets: **ml-1m** (short for MovieLens 1 million) and **ml-20m** (short for MovieLens 20 million).
+
+### ml-1m
+The ml-1m dataset contains 1,000,209 anonymous ratings of approximately 3,706 movies made by 6,040 users who joined MovieLens in 2000. All ratings are contained in the file "ratings.dat", without a header row, in the following format:
+```
+ UserID::MovieID::Rating::Timestamp
+```
+ - UserIDs range between 1 and 6040.
+ - MovieIDs range between 1 and 3952.
+ - Ratings are made on a 5-star scale (whole-star ratings only).
+
+### ml-20m
+The ml-20m dataset contains 20,000,263 ratings of 26,744 movies by 138,493 users. All ratings are contained in the file "ratings.csv". Each line of this file after the header row represents one rating of one movie by one user, and has the following format:
+```
+userId,movieId,rating,timestamp
+```
+ - The lines within this file are ordered first by userId, then, within user, by movieId.
+ - Ratings are made on a 5-star scale, with half-star increments (0.5 stars - 5.0 stars).
+
+In both datasets, the timestamp is represented in seconds since midnight Coordinated Universal Time (UTC) of January 1, 1970. Each user has at least 20 ratings.
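+
+As a quick sanity check, both raw formats can be loaded with pandas (illustrative paths, assuming the archives have been extracted locally):
+
+```python
+import pandas as pd
+
+# ml-1m: "::"-separated, no header row.
+ml1m = pd.read_csv("ml-1m/ratings.dat", sep="::", engine="python", header=None,
+                   names=["userId", "movieId", "rating", "timestamp"])
+
+# ml-20m: regular CSV with a header row.
+ml20m = pd.read_csv("ml-20m/ratings.csv")
+```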
+
+## Running Code
+
+### Download and preprocess dataset
+To download the dataset, first install the Pandas package, then issue the following command:
+```
+python movielens.py
+```
+Arguments:
+ * `--data_dir`: Directory in which to download and save the preprocessed data. By default, it is `/tmp/movielens-data/`.
+ * `--dataset`: The dataset name to be downloaded and preprocessed. By default, it is `ml-1m`.
+
+Use the `--help` or `-h` flag to get a full list of possible arguments.
+
+Note that the ml-20m dataset is large (the ratings file is ~500 MB), and preprocessing it may take a few minutes (~2 minutes).
+Both the ml-1m and ml-20m datasets will be coerced into a common format when downloaded.
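+
+For example, to download and preprocess ml-20m into a custom directory (illustrative path):
+```
+python movielens.py --data_dir /data/movielens --dataset ml-20m
+```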
+
+### Train and evaluate model
+
+[ncf_keras_main.py](ncf_keras_main.py) is the Keras trainer built on TF 2.x features. The model can be trained on both GPUs and TPUs.
+
+To train and evaluate the model, issue the following command:
+```
+python ncf_keras_main.py
+```
+Arguments:
+ * `--model_dir`: Directory to save model training checkpoints. By default, it is `/tmp/ncf/`.
+ * `--data_dir`: This should be set to the same directory given to `movielens.py`'s `--data_dir` argument above.
+ * `--dataset`: The dataset name to be downloaded and preprocessed. By default, it is `ml-1m`.
+ * `--num_gpus`: The number of GPUs used for training/evaluation of the model. Use CPU if this flag is 0. By default, it is 1.
+
+There are other arguments about models and training process. Refer to the [Flags package](https://abseil.io/docs/python/guides/flags) documentation or use the `--helpfull` flag to get a full list of possible arguments with detailed descriptions.
diff --git a/models/official/recommendation/__init__.py b/models/official/recommendation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/recommendation/constants.py b/models/official/recommendation/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e313bfa66a2133862e79dbad89f03421fee39c5
--- /dev/null
+++ b/models/official/recommendation/constants.py
@@ -0,0 +1,79 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Central location for NCF specific values."""
+
+import sys
+
+import numpy as np
+
+from official.recommendation import movielens
+
+# ==============================================================================
+# == Main Thread Data Processing ===============================================
+# ==============================================================================
+
+# Keys for data shards
+TRAIN_USER_KEY = "train_{}".format(movielens.USER_COLUMN)
+TRAIN_ITEM_KEY = "train_{}".format(movielens.ITEM_COLUMN)
+TRAIN_LABEL_KEY = "train_labels"
+MASK_START_INDEX = "mask_start_index"
+VALID_POINT_MASK = "valid_point_mask"
+EVAL_USER_KEY = "eval_{}".format(movielens.USER_COLUMN)
+EVAL_ITEM_KEY = "eval_{}".format(movielens.ITEM_COLUMN)
+
+USER_MAP = "user_map"
+ITEM_MAP = "item_map"
+
+USER_DTYPE = np.int32
+ITEM_DTYPE = np.int32
+
+# In both datasets, each user has at least 20 ratings.
+MIN_NUM_RATINGS = 20
+
+# The number of negative examples attached with a positive example
+# when performing evaluation.
+NUM_EVAL_NEGATIVES = 999
+
+# keys for evaluation metrics
+TOP_K = 10 # Top-k list for evaluation
+HR_KEY = "HR"
+NDCG_KEY = "NDCG"
+DUPLICATE_MASK = "duplicate_mask"
+
+# Metric names
+HR_METRIC_NAME = "HR_METRIC"
+NDCG_METRIC_NAME = "NDCG_METRIC"
+
+# Trying to load a cache created in py2 when running in py3 will cause an
+# error due to differences in unicode handling.
+RAW_CACHE_FILE = "raw_data_cache_py{}.pickle".format(sys.version_info[0])
+CACHE_INVALIDATION_SEC = 3600 * 24
+
+# ==============================================================================
+# == Data Generation ===========================================================
+# ==============================================================================
+CYCLES_TO_BUFFER = 3 # The number of train cycles worth of data to "run ahead"
+ # of the main training loop.
+
+# Number of batches to run per epoch when using synthetic data. At high batch
+# sizes, we run for more batches than with real data, which is good since
+# running more batches reduces noise when measuring the average batches/second.
+SYNTHETIC_BATCHES_PER_EPOCH = 2000
+
+# Only used when StreamingFilesDataset is used.
+NUM_FILE_SHARDS = 16
+TRAIN_FOLDER_TEMPLATE = "training_cycle_{}"
+EVAL_FOLDER = "eval_data"
+SHARD_TEMPLATE = "shard_{}.tfrecords"
diff --git a/models/official/recommendation/create_ncf_data.py b/models/official/recommendation/create_ncf_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..60267bcd5f77ec7cb2036cb2037efe9360d692ba
--- /dev/null
+++ b/models/official/recommendation/create_ncf_data.py
@@ -0,0 +1,117 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Binary to generate training/evaluation dataset for NCF model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+
+# pylint: disable=g-bad-import-order
+from absl import app
+from absl import flags
+import tensorflow.compat.v2 as tf
+# pylint: enable=g-bad-import-order
+
+from official.recommendation import movielens
+from official.recommendation import data_preprocessing
+
+flags.DEFINE_string(
+ "data_dir", None,
+ "The input data dir at which training and evaluation tf record files "
+ "will be saved.")
+flags.DEFINE_string("meta_data_file_path", None,
+ "The path in which input meta data will be written.")
+flags.DEFINE_enum("dataset", "ml-20m", ["ml-1m", "ml-20m"],
+ "Dataset to be trained/evaluated.")
+flags.DEFINE_enum(
+ "constructor_type", "bisection", ["bisection", "materialized"],
+ "Strategy to use for generating false negatives. materialized has a "
+ "precompute that scales badly, but a faster per-epoch construction "
+ "time and can be faster on very large systems.")
+flags.DEFINE_integer("num_train_epochs", 14,
+ "Total number of training epochs to generate.")
+flags.DEFINE_integer(
+ "num_negative_samples", 4,
+ "Number of negative instances to pair with positive instance.")
+flags.DEFINE_integer(
+ "train_prebatch_size", 99000,
+ "Batch size to be used for prebatching the dataset "
+ "for training.")
+flags.DEFINE_integer(
+ "eval_prebatch_size", 99000,
+ "Batch size to be used for prebatching the dataset "
+ "for training.")
+
+FLAGS = flags.FLAGS
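+
+# Illustrative invocation (paths are examples only; the flags are defined above):
+#   python create_ncf_data.py \
+#     --data_dir /tmp/movielens-data \
+#     --meta_data_file_path /tmp/movielens-data/ncf_meta.json \
+#     --dataset ml-20m \
+#     --constructor_type bisection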
+
+
+def prepare_raw_data(flag_obj):
+ """Downloads and prepares raw data for data generation."""
+ movielens.download(flag_obj.dataset, flag_obj.data_dir)
+
+ data_processing_params = {
+ "train_epochs": flag_obj.num_train_epochs,
+ "batch_size": flag_obj.train_prebatch_size,
+ "eval_batch_size": flag_obj.eval_prebatch_size,
+ "batches_per_step": 1,
+ "stream_files": True,
+ "num_neg": flag_obj.num_negative_samples,
+ }
+
+ num_users, num_items, producer = data_preprocessing.instantiate_pipeline(
+ dataset=flag_obj.dataset,
+ data_dir=flag_obj.data_dir,
+ params=data_processing_params,
+ constructor_type=flag_obj.constructor_type,
+ epoch_dir=flag_obj.data_dir,
+ generate_data_offline=True)
+
+ # pylint: disable=protected-access
+ input_metadata = {
+ "num_users": num_users,
+ "num_items": num_items,
+ "constructor_type": flag_obj.constructor_type,
+ "num_train_elements": producer._elements_in_epoch,
+ "num_eval_elements": producer._eval_elements_in_epoch,
+ "num_train_epochs": flag_obj.num_train_epochs,
+ "train_prebatch_size": flag_obj.train_prebatch_size,
+ "eval_prebatch_size": flag_obj.eval_prebatch_size,
+ "num_train_steps": producer.train_batches_per_epoch,
+ "num_eval_steps": producer.eval_batches_per_epoch,
+ }
+ # pylint: enable=protected-access
+
+ return producer, input_metadata
+
+
+def generate_data():
+ """Creates NCF train/eval dataset and writes input metadata as a file."""
+ producer, input_metadata = prepare_raw_data(FLAGS)
+ producer.run()
+
+ with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
+ writer.write(json.dumps(input_metadata, indent=4) + "\n")
+
+
+def main(_):
+ generate_data()
+
+
+if __name__ == "__main__":
+ flags.mark_flag_as_required("data_dir")
+ flags.mark_flag_as_required("meta_data_file_path")
+ app.run(main)
diff --git a/models/official/recommendation/data_pipeline.py b/models/official/recommendation/data_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b4dd33afe25df2468cdfcbb2c146392d7bec76e
--- /dev/null
+++ b/models/official/recommendation/data_pipeline.py
@@ -0,0 +1,959 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Asynchronous data producer for the NCF pipeline."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import atexit
+import functools
+import os
+import sys
+import tempfile
+import threading
+import time
+import timeit
+import traceback
+import typing
+
+import numpy as np
+import six
+from six.moves import queue
+import tensorflow as tf
+from absl import logging
+
+from official.recommendation import constants as rconst
+from official.recommendation import movielens
+from official.recommendation import popen_helper
+from official.recommendation import stat_utils
+from tensorflow.python.tpu.datasets import StreamingFilesDataset
+
+
+SUMMARY_TEMPLATE = """General:
+{spacer}Num users: {num_users}
+{spacer}Num items: {num_items}
+
+Training:
+{spacer}Positive count: {train_pos_ct}
+{spacer}Batch size: {train_batch_size} {multiplier}
+{spacer}Batch count per epoch: {train_batch_ct}
+
+Eval:
+{spacer}Positive count: {eval_pos_ct}
+{spacer}Batch size: {eval_batch_size} {multiplier}
+{spacer}Batch count per epoch: {eval_batch_ct}"""
+
+
+class DatasetManager(object):
+ """Helper class for handling TensorFlow specific data tasks.
+
+ This class takes the (relatively) framework agnostic work done by the data
+ constructor classes and handles the TensorFlow specific portions (TFRecord
+ management, tf.Dataset creation, etc.).
+ """
+
+ def __init__(self,
+ is_training,
+ stream_files,
+ batches_per_epoch,
+ shard_root=None,
+ deterministic=False,
+ num_train_epochs=None):
+ # type: (bool, bool, int, typing.Optional[str], bool, int) -> None
+ """Constructs a `DatasetManager` instance.
+ Args:
+ is_training: Boolean of whether the data provided is training or
+ evaluation data. This determines whether to reuse the data
+ (if is_training=False) and the exact structure to use when storing and
+ yielding data.
+ stream_files: Boolean indicating whether data should be serialized and
+ written to file shards.
+ batches_per_epoch: The number of batches in a single epoch.
+ shard_root: The base directory to be used when stream_files=True.
+      deterministic: If True, forgo non-deterministic speedups (i.e. use
+        sloppy=False when reading file shards).
+ num_train_epochs: Number of epochs to generate. If None, then each
+ call to `get_dataset()` increments the number of epochs requested.
+ """
+ self._is_training = is_training
+ self._deterministic = deterministic
+ self._stream_files = stream_files
+ self._writers = []
+ self._write_locks = [threading.RLock() for _ in
+ range(rconst.NUM_FILE_SHARDS)] if stream_files else []
+ self._batches_per_epoch = batches_per_epoch
+ self._epochs_completed = 0
+ self._epochs_requested = num_train_epochs if num_train_epochs else 0
+ self._shard_root = shard_root
+
+ self._result_queue = queue.Queue()
+ self._result_reuse = []
+
+ @property
+ def current_data_root(self):
+ subdir = (rconst.TRAIN_FOLDER_TEMPLATE.format(self._epochs_completed)
+ if self._is_training else rconst.EVAL_FOLDER)
+ return os.path.join(self._shard_root, subdir)
+
+ def buffer_reached(self):
+ # Only applicable for training.
+ return (self._epochs_completed - self._epochs_requested >=
+ rconst.CYCLES_TO_BUFFER and self._is_training)
+
+ @staticmethod
+ def serialize(data):
+ """Convert NumPy arrays into a TFRecords entry."""
+
+ def create_int_feature(values):
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
+
+ feature_dict = {
+ k: create_int_feature(v.astype(np.int64)) for k, v in data.items()
+ }
+
+ return tf.train.Example(
+ features=tf.train.Features(feature=feature_dict)).SerializeToString()
+
+ @staticmethod
+ def deserialize(serialized_data, batch_size=None, is_training=True):
+ """Convert serialized TFRecords into tensors.
+
+ Args:
+ serialized_data: A tensor containing serialized records.
+ batch_size: The data arrives pre-batched, so batch size is needed to
+ deserialize the data.
+ is_training: Boolean, whether data to deserialize to training data
+ or evaluation data.
+ """
+
+ def _get_feature_map(batch_size, is_training=True):
+ """Returns data format of the serialized tf record file."""
+
+ if is_training:
+ return {
+ movielens.USER_COLUMN:
+ tf.io.FixedLenFeature([batch_size, 1], dtype=tf.int64),
+ movielens.ITEM_COLUMN:
+ tf.io.FixedLenFeature([batch_size, 1], dtype=tf.int64),
+ rconst.VALID_POINT_MASK:
+ tf.io.FixedLenFeature([batch_size, 1], dtype=tf.int64),
+ "labels":
+ tf.io.FixedLenFeature([batch_size, 1], dtype=tf.int64)
+ }
+ else:
+ return {
+ movielens.USER_COLUMN:
+ tf.io.FixedLenFeature([batch_size, 1], dtype=tf.int64),
+ movielens.ITEM_COLUMN:
+ tf.io.FixedLenFeature([batch_size, 1], dtype=tf.int64),
+ rconst.DUPLICATE_MASK:
+ tf.io.FixedLenFeature([batch_size, 1], dtype=tf.int64)
+ }
+
+ features = tf.io.parse_single_example(
+ serialized_data, _get_feature_map(batch_size, is_training=is_training))
+ users = tf.cast(features[movielens.USER_COLUMN], rconst.USER_DTYPE)
+ items = tf.cast(features[movielens.ITEM_COLUMN], rconst.ITEM_DTYPE)
+
+ if is_training:
+ valid_point_mask = tf.cast(features[rconst.VALID_POINT_MASK], tf.bool)
+ fake_dup_mask = tf.zeros_like(users)
+ return {
+ movielens.USER_COLUMN: users,
+ movielens.ITEM_COLUMN: items,
+ rconst.VALID_POINT_MASK: valid_point_mask,
+ rconst.TRAIN_LABEL_KEY:
+ tf.reshape(tf.cast(features["labels"], tf.bool),
+ (batch_size, 1)),
+ rconst.DUPLICATE_MASK: fake_dup_mask
+ }
+ else:
+ labels = tf.cast(tf.zeros_like(users), tf.bool)
+ fake_valid_pt_mask = tf.cast(tf.zeros_like(users), tf.bool)
+ return {
+ movielens.USER_COLUMN:
+ users,
+ movielens.ITEM_COLUMN:
+ items,
+ rconst.DUPLICATE_MASK:
+ tf.cast(features[rconst.DUPLICATE_MASK], tf.bool),
+ rconst.VALID_POINT_MASK:
+ fake_valid_pt_mask,
+ rconst.TRAIN_LABEL_KEY:
+ labels
+ }
+
+ def put(self, index, data):
+ # type: (int, dict) -> None
+ """Store data for later consumption.
+
+ Because there are several paths for storing and yielding data (queues,
+ lists, files) the data producer simply provides the data in a standard
+ format at which point the dataset manager handles storing it in the correct
+ form.
+
+ Args:
+ index: Used to select shards when writing to files.
+ data: A dict of the data to be stored. This method mutates data, and
+ therefore expects to be the only consumer.
+ """
+ if self._is_training:
+ mask_start_index = data.pop(rconst.MASK_START_INDEX)
+ batch_size = data[movielens.ITEM_COLUMN].shape[0]
+ data[rconst.VALID_POINT_MASK] = np.expand_dims(
+ np.less(np.arange(batch_size), mask_start_index), -1)
+
+ if self._stream_files:
+ example_bytes = self.serialize(data)
+ with self._write_locks[index % rconst.NUM_FILE_SHARDS]:
+ self._writers[index % rconst.NUM_FILE_SHARDS].write(example_bytes)
+
+ else:
+ self._result_queue.put((
+ data, data.pop("labels")) if self._is_training else data)
+
+ def start_construction(self):
+ if self._stream_files:
+ tf.io.gfile.makedirs(self.current_data_root)
+ template = os.path.join(self.current_data_root, rconst.SHARD_TEMPLATE)
+ self._writers = [tf.io.TFRecordWriter(template.format(i))
+ for i in range(rconst.NUM_FILE_SHARDS)]
+
+ def end_construction(self):
+ if self._stream_files:
+ [writer.close() for writer in self._writers]
+ self._writers = []
+ self._result_queue.put(self.current_data_root)
+
+ self._epochs_completed += 1
+
+ def data_generator(self, epochs_between_evals):
+ """Yields examples during local training."""
+ assert not self._stream_files
+ assert self._is_training or epochs_between_evals == 1
+
+ if self._is_training:
+ for _ in range(self._batches_per_epoch * epochs_between_evals):
+ yield self._result_queue.get(timeout=300)
+
+ else:
+ if self._result_reuse:
+ assert len(self._result_reuse) == self._batches_per_epoch
+
+ for i in self._result_reuse:
+ yield i
+ else:
+ # First epoch.
+ for _ in range(self._batches_per_epoch * epochs_between_evals):
+ result = self._result_queue.get(timeout=300)
+ self._result_reuse.append(result)
+ yield result
+
+ def increment_request_epoch(self):
+ self._epochs_requested += 1
+
+ def get_dataset(self, batch_size, epochs_between_evals):
+ """Construct the dataset to be used for training and eval.
+
+ For local training, data is provided through Dataset.from_generator. For
+ remote training (TPUs) the data is first serialized to files and then sent
+ to the TPU through a StreamingFilesDataset.
+
+ Args:
+ batch_size: The per-replica batch size of the dataset.
+ epochs_between_evals: How many epochs worth of data to yield.
+ (Generator mode only.)
+ """
+ self.increment_request_epoch()
+ if self._stream_files:
+ if epochs_between_evals > 1:
+ raise ValueError("epochs_between_evals > 1 not supported for file "
+ "based dataset.")
+ epoch_data_dir = self._result_queue.get(timeout=300)
+ if not self._is_training:
+ self._result_queue.put(epoch_data_dir) # Eval data is reused.
+
+ file_pattern = os.path.join(
+ epoch_data_dir, rconst.SHARD_TEMPLATE.format("*"))
+ dataset = StreamingFilesDataset(
+ files=file_pattern, worker_job=popen_helper.worker_job(),
+ num_parallel_reads=rconst.NUM_FILE_SHARDS, num_epochs=1,
+ sloppy=not self._deterministic)
+ map_fn = functools.partial(
+ self.deserialize,
+ batch_size=batch_size,
+ is_training=self._is_training)
+ dataset = dataset.map(map_fn, num_parallel_calls=16)
+
+ else:
+ types = {movielens.USER_COLUMN: rconst.USER_DTYPE,
+ movielens.ITEM_COLUMN: rconst.ITEM_DTYPE}
+ shapes = {
+ movielens.USER_COLUMN: tf.TensorShape([batch_size, 1]),
+ movielens.ITEM_COLUMN: tf.TensorShape([batch_size, 1])
+ }
+
+ if self._is_training:
+ types[rconst.VALID_POINT_MASK] = np.bool
+ shapes[rconst.VALID_POINT_MASK] = tf.TensorShape([batch_size, 1])
+
+ types = (types, np.bool)
+ shapes = (shapes, tf.TensorShape([batch_size, 1]))
+
+ else:
+ types[rconst.DUPLICATE_MASK] = np.bool
+ shapes[rconst.DUPLICATE_MASK] = tf.TensorShape([batch_size, 1])
+
+ data_generator = functools.partial(
+ self.data_generator, epochs_between_evals=epochs_between_evals)
+ dataset = tf.data.Dataset.from_generator(
+ generator=data_generator, output_types=types,
+ output_shapes=shapes)
+
+ return dataset.prefetch(16)
+
+ def make_input_fn(self, batch_size):
+ """Create an input_fn which checks for batch size consistency."""
+
+ def input_fn(params):
+ """Returns batches for training."""
+
+ # Estimator passes batch_size during training and eval_batch_size during
+ # eval.
+ param_batch_size = (params["batch_size"] if self._is_training else
+ params.get("eval_batch_size") or params["batch_size"])
+ if batch_size != param_batch_size:
+ raise ValueError("producer batch size ({}) differs from params batch "
+ "size ({})".format(batch_size, param_batch_size))
+
+ epochs_between_evals = (params.get("epochs_between_evals", 1)
+ if self._is_training else 1)
+ return self.get_dataset(batch_size=batch_size,
+ epochs_between_evals=epochs_between_evals)
+
+ return input_fn
+
+
+class BaseDataConstructor(threading.Thread):
+ """Data constructor base class.
+
+ This class manages the control flow for constructing data. It is not meant
+ to be used directly, but instead subclasses should implement the following
+ two methods:
+
+ self.construct_lookup_variables
+ self.lookup_negative_items
+
+ """
+
+ def __init__(
+ self,
+ maximum_number_epochs, # type: int
+ num_users, # type: int
+ num_items, # type: int
+ user_map, # type: dict
+ item_map, # type: dict
+ train_pos_users, # type: np.ndarray
+ train_pos_items, # type: np.ndarray
+ train_batch_size, # type: int
+ batches_per_train_step, # type: int
+ num_train_negatives, # type: int
+ eval_pos_users, # type: np.ndarray
+ eval_pos_items, # type: np.ndarray
+ eval_batch_size, # type: int
+ batches_per_eval_step, # type: int
+ stream_files, # type: bool
+ deterministic=False, # type: bool
+ epoch_dir=None, # type: str
+ num_train_epochs=None, # type: int
+ create_data_offline=False # type: bool
+ ):
+ # General constants
+ self._maximum_number_epochs = maximum_number_epochs
+ self._num_users = num_users
+ self._num_items = num_items
+ self.user_map = user_map
+ self.item_map = item_map
+ self._train_pos_users = train_pos_users
+ self._train_pos_items = train_pos_items
+ self.train_batch_size = train_batch_size
+ self._num_train_negatives = num_train_negatives
+ self._batches_per_train_step = batches_per_train_step
+ self._eval_pos_users = eval_pos_users
+ self._eval_pos_items = eval_pos_items
+ self.eval_batch_size = eval_batch_size
+ self.num_train_epochs = num_train_epochs
+ self.create_data_offline = create_data_offline
+
+ # Training
+ if self._train_pos_users.shape != self._train_pos_items.shape:
+ raise ValueError(
+ "User positives ({}) is different from item positives ({})".format(
+ self._train_pos_users.shape, self._train_pos_items.shape))
+
+ (self._train_pos_count,) = self._train_pos_users.shape
+ self._elements_in_epoch = (1 + num_train_negatives) * self._train_pos_count
+ self.train_batches_per_epoch = self._count_batches(
+ self._elements_in_epoch, train_batch_size, batches_per_train_step)
+
+ # Evaluation
+ if eval_batch_size % (1 + rconst.NUM_EVAL_NEGATIVES):
+ raise ValueError("Eval batch size {} is not divisible by {}".format(
+ eval_batch_size, 1 + rconst.NUM_EVAL_NEGATIVES))
+ self._eval_users_per_batch = int(
+ eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
+ self._eval_elements_in_epoch = num_users * (1 + rconst.NUM_EVAL_NEGATIVES)
+ self.eval_batches_per_epoch = self._count_batches(
+ self._eval_elements_in_epoch, eval_batch_size, batches_per_eval_step)
+
+ # Intermediate artifacts
+ self._current_epoch_order = np.empty(shape=(0,))
+ self._shuffle_iterator = None
+
+ self._shuffle_with_forkpool = not stream_files
+ if stream_files:
+ self._shard_root = epoch_dir or tempfile.mkdtemp(prefix="ncf_")
+ if not create_data_offline:
+ atexit.register(tf.io.gfile.rmtree, self._shard_root)
+ else:
+ self._shard_root = None
+
+ self._train_dataset = DatasetManager(True, stream_files,
+ self.train_batches_per_epoch,
+ self._shard_root, deterministic,
+ num_train_epochs)
+ self._eval_dataset = DatasetManager(False, stream_files,
+ self.eval_batches_per_epoch,
+ self._shard_root, deterministic,
+ num_train_epochs)
+
+ # Threading details
+ super(BaseDataConstructor, self).__init__()
+ self.daemon = True
+ self._stop_loop = False
+ self._fatal_exception = None
+ self.deterministic = deterministic
+
+ def __str__(self):
+ multiplier = ("(x{} devices)".format(self._batches_per_train_step)
+ if self._batches_per_train_step > 1 else "")
+ summary = SUMMARY_TEMPLATE.format(
+ spacer=" ", num_users=self._num_users, num_items=self._num_items,
+ train_pos_ct=self._train_pos_count,
+ train_batch_size=self.train_batch_size,
+ train_batch_ct=self.train_batches_per_epoch,
+ eval_pos_ct=self._num_users, eval_batch_size=self.eval_batch_size,
+ eval_batch_ct=self.eval_batches_per_epoch, multiplier=multiplier)
+ return super(BaseDataConstructor, self).__str__() + "\n" + summary
+
+ @staticmethod
+ def _count_batches(example_count, batch_size, batches_per_step):
+ """Determine the number of batches, rounding up to fill all devices."""
+ x = (example_count + batch_size - 1) // batch_size
+ return (x + batches_per_step - 1) // batches_per_step * batches_per_step
+
+ def stop_loop(self):
+ self._stop_loop = True
+
+ def construct_lookup_variables(self):
+ """Perform any one time pre-compute work."""
+ raise NotImplementedError
+
+ def lookup_negative_items(self, **kwargs):
+ """Randomly sample negative items for given users."""
+ raise NotImplementedError
+
+ def _run(self):
+ atexit.register(self.stop_loop)
+ self._start_shuffle_iterator()
+ self.construct_lookup_variables()
+ self._construct_training_epoch()
+ self._construct_eval_epoch()
+ for _ in range(self._maximum_number_epochs - 1):
+ self._construct_training_epoch()
+ self.stop_loop()
+
+ def run(self):
+ try:
+ self._run()
+ except Exception as e:
+ # The Thread base class swallows stack traces, so unfortunately it is
+ # necessary to catch and re-raise to get debug output
+ traceback.print_exc()
+ self._fatal_exception = e
+ sys.stderr.flush()
+ raise
+
+ def _start_shuffle_iterator(self):
+ if self._shuffle_with_forkpool:
+ pool = popen_helper.get_forkpool(3, closing=False)
+ else:
+ pool = popen_helper.get_threadpool(1, closing=False)
+ atexit.register(pool.close)
+ args = [(self._elements_in_epoch, stat_utils.random_int32())
+ for _ in range(self._maximum_number_epochs)]
+ imap = pool.imap if self.deterministic else pool.imap_unordered
+ self._shuffle_iterator = imap(stat_utils.permutation, args)
+
+ def _get_training_batch(self, i):
+ """Construct a single batch of training data.
+
+ Args:
+ i: The index of the batch. This is used when stream_files=True to assign
+ data to file shards.
+ """
+ batch_indices = self._current_epoch_order[i * self.train_batch_size:
+ (i + 1) * self.train_batch_size]
+ (mask_start_index,) = batch_indices.shape
+
+ batch_ind_mod = np.mod(batch_indices, self._train_pos_count)
+ users = self._train_pos_users[batch_ind_mod]
+
+ negative_indices = np.greater_equal(batch_indices, self._train_pos_count)
+ negative_users = users[negative_indices]
+
+ negative_items = self.lookup_negative_items(negative_users=negative_users)
+
+ items = self._train_pos_items[batch_ind_mod]
+ items[negative_indices] = negative_items
+
+ labels = np.logical_not(negative_indices)
+
+ # Pad last partial batch
+ pad_length = self.train_batch_size - mask_start_index
+ if pad_length:
+ # We pad with arange rather than zeros because the network will still
+ # compute logits for padded examples, and padding with zeros would create
+ # a very "hot" embedding key which can have performance implications.
+ user_pad = np.arange(pad_length, dtype=users.dtype) % self._num_users
+ item_pad = np.arange(pad_length, dtype=items.dtype) % self._num_items
+ label_pad = np.zeros(shape=(pad_length,), dtype=labels.dtype)
+ users = np.concatenate([users, user_pad])
+ items = np.concatenate([items, item_pad])
+ labels = np.concatenate([labels, label_pad])
+
+ self._train_dataset.put(
+ i, {
+ movielens.USER_COLUMN:
+ np.reshape(users, (self.train_batch_size, 1)),
+ movielens.ITEM_COLUMN:
+ np.reshape(items, (self.train_batch_size, 1)),
+ rconst.MASK_START_INDEX:
+ np.array(mask_start_index, dtype=np.int32),
+ "labels":
+ np.reshape(labels, (self.train_batch_size, 1)),
+ })
+
+ def _wait_to_construct_train_epoch(self):
+ count = 0
+ while self._train_dataset.buffer_reached() and not self._stop_loop:
+ time.sleep(0.01)
+ count += 1
+ if count >= 100 and np.log10(count) == np.round(np.log10(count)):
+ logging.info(
+ "Waited {} times for training data to be consumed".format(count))
+
+ def _construct_training_epoch(self):
+ """Loop to construct a batch of training data."""
+ if not self.create_data_offline:
+ self._wait_to_construct_train_epoch()
+
+ start_time = timeit.default_timer()
+ if self._stop_loop:
+ return
+
+ self._train_dataset.start_construction()
+ map_args = list(range(self.train_batches_per_epoch))
+ self._current_epoch_order = next(self._shuffle_iterator)
+
+ get_pool = (popen_helper.get_fauxpool if self.deterministic else
+ popen_helper.get_threadpool)
+ with get_pool(6) as pool:
+ pool.map(self._get_training_batch, map_args)
+ self._train_dataset.end_construction()
+
+ logging.info("Epoch construction complete. Time: {:.1f} seconds".format(
+ timeit.default_timer() - start_time))
+
+ @staticmethod
+ def _assemble_eval_batch(users, positive_items, negative_items,
+ users_per_batch):
+ """Construct duplicate_mask and structure data accordingly.
+
+ The positive items should be last so that they lose ties. However, they
+ should not be masked out if the true eval positive happens to be
+ selected as a negative. So instead, the positive is placed in the first
+ position, and then switched with the last element after the duplicate
+ mask has been computed.
+
+ Args:
+ users: An array of users in a batch. (should be identical along axis 1)
+ positive_items: An array (batch_size x 1) of positive item indices.
+ negative_items: An array of negative item indices.
+ users_per_batch: How many users should be in the batch. This is passed
+ as an argument so that ncf_test.py can use this method.
+
+ Returns:
+ User, item, and duplicate_mask arrays.
+ """
+ items = np.concatenate([positive_items, negative_items], axis=1)
+
+ # We pad the users and items here so that the duplicate mask calculation
+ # will include padding. The metric function relies on all padded elements
+ # except the positive being marked as duplicate to mask out padded points.
+ if users.shape[0] < users_per_batch:
+ pad_rows = users_per_batch - users.shape[0]
+ padding = np.zeros(shape=(pad_rows, users.shape[1]), dtype=np.int32)
+ users = np.concatenate([users, padding.astype(users.dtype)], axis=0)
+ items = np.concatenate([items, padding.astype(items.dtype)], axis=0)
+
+ duplicate_mask = stat_utils.mask_duplicates(items, axis=1).astype(np.bool)
+
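+    # Swap the positive item (placed in column 0 above) with the last column so
+    # that, now that the duplicate mask has been computed, the positive loses
+    # ties as described in the docstring.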
+ items[:, (0, -1)] = items[:, (-1, 0)]
+ duplicate_mask[:, (0, -1)] = duplicate_mask[:, (-1, 0)]
+
+ assert users.shape == items.shape == duplicate_mask.shape
+ return users, items, duplicate_mask
+
+ def _get_eval_batch(self, i):
+ """Construct a single batch of evaluation data.
+
+ Args:
+ i: The index of the batch.
+ """
+ low_index = i * self._eval_users_per_batch
+ high_index = (i + 1) * self._eval_users_per_batch
+ users = np.repeat(self._eval_pos_users[low_index:high_index, np.newaxis],
+ 1 + rconst.NUM_EVAL_NEGATIVES, axis=1)
+ positive_items = self._eval_pos_items[low_index:high_index, np.newaxis]
+ negative_items = (self.lookup_negative_items(negative_users=users[:, :-1])
+ .reshape(-1, rconst.NUM_EVAL_NEGATIVES))
+
+ users, items, duplicate_mask = self._assemble_eval_batch(
+ users, positive_items, negative_items, self._eval_users_per_batch)
+
+ self._eval_dataset.put(
+ i, {
+ movielens.USER_COLUMN:
+ np.reshape(users.flatten(), (self.eval_batch_size, 1)),
+ movielens.ITEM_COLUMN:
+ np.reshape(items.flatten(), (self.eval_batch_size, 1)),
+ rconst.DUPLICATE_MASK:
+ np.reshape(duplicate_mask.flatten(), (self.eval_batch_size, 1)),
+ })
+
+ def _construct_eval_epoch(self):
+ """Loop to construct data for evaluation."""
+ if self._stop_loop:
+ return
+
+ start_time = timeit.default_timer()
+
+ self._eval_dataset.start_construction()
+ map_args = [i for i in range(self.eval_batches_per_epoch)]
+
+ get_pool = (popen_helper.get_fauxpool if self.deterministic else
+ popen_helper.get_threadpool)
+ with get_pool(6) as pool:
+ pool.map(self._get_eval_batch, map_args)
+ self._eval_dataset.end_construction()
+
+ logging.info("Eval construction complete. Time: {:.1f} seconds".format(
+ timeit.default_timer() - start_time))
+
+ def make_input_fn(self, is_training):
+ # It isn't feasible to provide a foolproof check, so this is designed to
+ # catch most failures rather than provide an exhaustive guard.
+ if self._fatal_exception is not None:
+ raise ValueError("Fatal exception in the data production loop: {}"
+ .format(self._fatal_exception))
+
+ return (
+ self._train_dataset.make_input_fn(self.train_batch_size) if is_training
+ else self._eval_dataset.make_input_fn(self.eval_batch_size))
+
+ def increment_request_epoch(self):
+ self._train_dataset.increment_request_epoch()
+
+
+class DummyConstructor(threading.Thread):
+ """Class for running with synthetic data."""
+
+ def __init__(self, *args, **kwargs):
+ super(DummyConstructor, self).__init__(*args, **kwargs)
+ self.train_batches_per_epoch = rconst.SYNTHETIC_BATCHES_PER_EPOCH
+ self.eval_batches_per_epoch = rconst.SYNTHETIC_BATCHES_PER_EPOCH
+
+ def run(self):
+ pass
+
+ def stop_loop(self):
+ pass
+
+ def increment_request_epoch(self):
+ pass
+
+ @staticmethod
+ def make_input_fn(is_training):
+ """Construct training input_fn that uses synthetic data."""
+
+ def input_fn(params):
+ """Returns dummy input batches for training."""
+
+ # Estimator passes batch_size during training and eval_batch_size during
+ # eval.
+ batch_size = (params["batch_size"] if is_training else
+ params.get("eval_batch_size") or params["batch_size"])
+ num_users = params["num_users"]
+ num_items = params["num_items"]
+
+ users = tf.random.uniform([batch_size, 1],
+ dtype=tf.int32,
+ minval=0,
+ maxval=num_users)
+ items = tf.random.uniform([batch_size, 1],
+ dtype=tf.int32,
+ minval=0,
+ maxval=num_items)
+
+ if is_training:
+ valid_point_mask = tf.cast(
+ tf.random.uniform([batch_size, 1],
+ dtype=tf.int32,
+ minval=0,
+ maxval=2), tf.bool)
+ labels = tf.cast(
+ tf.random.uniform([batch_size, 1],
+ dtype=tf.int32,
+ minval=0,
+ maxval=2), tf.bool)
+ data = {
+ movielens.USER_COLUMN: users,
+ movielens.ITEM_COLUMN: items,
+ rconst.VALID_POINT_MASK: valid_point_mask,
+ }, labels
+ else:
+ dupe_mask = tf.cast(
+ tf.random.uniform([batch_size, 1],
+ dtype=tf.int32,
+ minval=0,
+ maxval=2), tf.bool)
+ data = {
+ movielens.USER_COLUMN: users,
+ movielens.ITEM_COLUMN: items,
+ rconst.DUPLICATE_MASK: dupe_mask,
+ }
+
+ dataset = tf.data.Dataset.from_tensors(data).repeat(
+ rconst.SYNTHETIC_BATCHES_PER_EPOCH * params["batches_per_step"])
+ dataset = dataset.prefetch(32)
+ return dataset
+
+ return input_fn
+
+
+class MaterializedDataConstructor(BaseDataConstructor):
+ """Materialize a table of negative examples for fast negative generation.
+
+ This class creates a table (num_users x num_items) containing all of the
+ negative examples for each user. This table is conceptually ragged; that is to
+ say the items dimension will have a number of unused elements at the end equal
+ to the number of positive elements for a given user. For instance:
+
+ num_users = 3
+ num_items = 5
+ positives = [[1, 3], [0], [1, 2, 3, 4]]
+
+ will generate a negative table:
+ [
+ [0 2 4 int32max int32max],
+ [1 2 3 4 int32max],
+ [0 int32max int32max int32max int32max],
+ ]
+
+ and a vector of per-user negative counts, which in this case would be:
+ [3, 4, 1]
+
+ When sampling negatives, integers are (nearly) uniformly selected from the
+ range [0, per_user_neg_count[user]) which gives a column_index, at which
+ point the negative can be selected as:
+ negative_table[user, column_index]
+
+ This technique will not scale; however MovieLens is small enough that even
+ a pre-compute which is quadratic in problem size will still fit in memory. A
+ more scalable lookup method is in the works.
+ """
+ def __init__(self, *args, **kwargs):
+ super(MaterializedDataConstructor, self).__init__(*args, **kwargs)
+ self._negative_table = None
+ self._per_user_neg_count = None
+
+ def construct_lookup_variables(self):
+ # Materialize negatives for fast lookup sampling.
+ start_time = timeit.default_timer()
+ inner_bounds = np.argwhere(self._train_pos_users[1:] -
+ self._train_pos_users[:-1])[:, 0] + 1
+ (upper_bound,) = self._train_pos_users.shape
+ index_bounds = [0] + inner_bounds.tolist() + [upper_bound]
+ self._negative_table = np.zeros(shape=(self._num_users, self._num_items),
+ dtype=rconst.ITEM_DTYPE)
+
+ # Set the table to the max value to make sure the embedding lookup will fail
+ # if we go out of bounds, rather than just overloading item zero.
+ self._negative_table += np.iinfo(rconst.ITEM_DTYPE).max
+ assert self._num_items < np.iinfo(rconst.ITEM_DTYPE).max
+
+ # Reuse arange during generation. np.delete will make a copy.
+ full_set = np.arange(self._num_items, dtype=rconst.ITEM_DTYPE)
+
+ self._per_user_neg_count = np.zeros(
+ shape=(self._num_users,), dtype=np.int32)
+
+ # Threading does not improve this loop. For some reason, the np.delete
+ # call does not parallelize well. Multiprocessing incurs too much
+ # serialization overhead to be worthwhile.
+ for i in range(self._num_users):
+ positives = self._train_pos_items[index_bounds[i]:index_bounds[i+1]]
+ negatives = np.delete(full_set, positives)
+ self._per_user_neg_count[i] = self._num_items - positives.shape[0]
+ self._negative_table[i, :self._per_user_neg_count[i]] = negatives
+
+ logging.info("Negative sample table built. Time: {:.1f} seconds".format(
+ timeit.default_timer() - start_time))
+
+ def lookup_negative_items(self, negative_users, **kwargs):
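+    # Sample a column index in [0, per_user_neg_count[user]) for each negative
+    # user and read the corresponding item from the materialized table.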
+ negative_item_choice = stat_utils.very_slightly_biased_randint(
+ self._per_user_neg_count[negative_users])
+ return self._negative_table[negative_users, negative_item_choice]
+
+
+class BisectionDataConstructor(BaseDataConstructor):
+ """Use bisection to index within positive examples.
+
+ This class tallies the number of negative items which appear before each
+ positive item for a user. This means that in order to select the ith negative
+ item for a user, it only needs to determine which two positive items bound
+  it, at which point the item id for the ith negative is a simple algebraic
+ expression.
+ """
+ def __init__(self, *args, **kwargs):
+ super(BisectionDataConstructor, self).__init__(*args, **kwargs)
+ self.index_bounds = None
+ self._sorted_train_pos_items = None
+ self._total_negatives = None
+
+ def _index_segment(self, user):
+ lower, upper = self.index_bounds[user:user+2]
+ items = self._sorted_train_pos_items[lower:upper]
+
+ negatives_since_last_positive = np.concatenate(
+ [items[0][np.newaxis], items[1:] - items[:-1] - 1])
+
+ return np.cumsum(negatives_since_last_positive)
+
+ def construct_lookup_variables(self):
+ start_time = timeit.default_timer()
+ inner_bounds = np.argwhere(self._train_pos_users[1:] -
+ self._train_pos_users[:-1])[:, 0] + 1
+ (upper_bound,) = self._train_pos_users.shape
+ self.index_bounds = np.array([0] + inner_bounds.tolist() + [upper_bound])
+
+ # Later logic will assume that the users are in sequential ascending order.
+ assert np.array_equal(self._train_pos_users[self.index_bounds[:-1]],
+ np.arange(self._num_users))
+
+ self._sorted_train_pos_items = self._train_pos_items.copy()
+
+ for i in range(self._num_users):
+ lower, upper = self.index_bounds[i:i+2]
+ self._sorted_train_pos_items[lower:upper].sort()
+
+ self._total_negatives = np.concatenate([
+ self._index_segment(i) for i in range(self._num_users)])
+
+ logging.info("Negative total vector built. Time: {:.1f} seconds".format(
+ timeit.default_timer() - start_time))
+
+ def lookup_negative_items(self, negative_users, **kwargs):
+ output = np.zeros(shape=negative_users.shape, dtype=rconst.ITEM_DTYPE) - 1
+
+ left_index = self.index_bounds[negative_users]
+ right_index = self.index_bounds[negative_users + 1] - 1
+
+ num_positives = right_index - left_index + 1
+ num_negatives = self._num_items - num_positives
+ neg_item_choice = stat_utils.very_slightly_biased_randint(num_negatives)
+
+ # Shortcuts:
+ # For points where the negative is greater than or equal to the tally before
+ # the last positive point there is no need to bisect. Instead the item id
+ # corresponding to the negative item choice is simply:
+    #   last_positive_index + 1 + (neg_choice - last_negative_tally)
+ # Similarly, if the selection is less than the tally at the first positive
+ # then the item_id is simply the selection.
+ #
+ # Because MovieLens organizes popular movies into low integers (which is
+ # preserved through the preprocessing), the first shortcut is very
+ # efficient, allowing ~60% of samples to bypass the bisection. For the same
+ # reason, the second shortcut is rarely triggered (<0.02%) and is therefore
+ # not worth implementing.
+ use_shortcut = neg_item_choice >= self._total_negatives[right_index]
+ output[use_shortcut] = (
+ self._sorted_train_pos_items[right_index] + 1 +
+ (neg_item_choice - self._total_negatives[right_index])
+ )[use_shortcut]
+
+ if np.all(use_shortcut):
+ # The bisection code is ill-posed when there are no elements.
+ return output
+
+ not_use_shortcut = np.logical_not(use_shortcut)
+ left_index = left_index[not_use_shortcut]
+ right_index = right_index[not_use_shortcut]
+ neg_item_choice = neg_item_choice[not_use_shortcut]
+
+ num_loops = np.max(
+ np.ceil(np.log2(num_positives[not_use_shortcut])).astype(np.int32))
+
+ for i in range(num_loops):
+ mid_index = (left_index + right_index) // 2
+ right_criteria = self._total_negatives[mid_index] > neg_item_choice
+ left_criteria = np.logical_not(right_criteria)
+
+ right_index[right_criteria] = mid_index[right_criteria]
+ left_index[left_criteria] = mid_index[left_criteria]
+
+ # Expected state after bisection pass:
+ # The right index is the smallest index whose tally is greater than the
+ # negative item choice index.
+
+ assert np.all((right_index - left_index) <= 1)
+
+ output[not_use_shortcut] = (
+ self._sorted_train_pos_items[right_index] -
+ (self._total_negatives[right_index] - neg_item_choice)
+ )
+
+ assert np.all(output >= 0)
+
+ return output
+
+
+def get_constructor(name):
+ if name == "bisection":
+ return BisectionDataConstructor
+ if name == "materialized":
+ return MaterializedDataConstructor
+ raise ValueError("Unrecognized constructor: {}".format(name))
diff --git a/models/official/recommendation/data_preprocessing.py b/models/official/recommendation/data_preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d7a3f856a7d8de45ff00ff3a0e1a6e6eacadd3a
--- /dev/null
+++ b/models/official/recommendation/data_preprocessing.py
@@ -0,0 +1,265 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Preprocess dataset and construct any necessary artifacts."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import os
+import pickle
+import time
+import timeit
+
+# pylint: disable=wrong-import-order
+from absl import logging
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+import typing
+from typing import Dict, Text, Tuple
+# pylint: enable=wrong-import-order
+
+from official.recommendation import constants as rconst
+from official.recommendation import data_pipeline
+from official.recommendation import movielens
+
+
+_EXPECTED_CACHE_KEYS = (
+ rconst.TRAIN_USER_KEY, rconst.TRAIN_ITEM_KEY, rconst.EVAL_USER_KEY,
+ rconst.EVAL_ITEM_KEY, rconst.USER_MAP, rconst.ITEM_MAP)
+
+
+def read_dataframe(
+ raw_rating_path: Text
+) -> Tuple[Dict[int, int], Dict[int, int], pd.DataFrame]:
+ """Read in data CSV, and output DataFrame for downstream processing.
+
+ This function reads in the raw CSV of positive items, and performs three
+ preprocessing transformations:
+
+ 1) Filter out all users who have not rated at least a certain number
+ of items. (Typically 20 items)
+
+ 2) Zero index the users and items such that the largest user_id is
+ `num_users - 1` and the largest item_id is `num_items - 1`
+
+ 3) Sort the dataframe by user_id, with timestamp as a secondary sort key.
+ This allows the dataframe to be sliced by user in-place, and for the last
+ item to be selected simply by calling the `-1` index of a user's slice.
+
+ Args:
+ raw_rating_path: The path to the CSV which contains the raw dataset.
+
+ Returns:
+ A dict mapping raw user IDs to regularized user IDs, a dict mapping raw
+ item IDs to regularized item IDs, and a filtered, zero-index remapped,
+ sorted dataframe.
+ """
+ with tf.io.gfile.GFile(raw_rating_path) as f:
+ df = pd.read_csv(f)
+
+ # Keep only the users who have rated at least MIN_NUM_RATINGS (typically 20) items.
+ grouped = df.groupby(movielens.USER_COLUMN)
+ df = grouped.filter(
+ lambda x: len(x) >= rconst.MIN_NUM_RATINGS) # type: pd.DataFrame
+
+ original_users = df[movielens.USER_COLUMN].unique()
+ original_items = df[movielens.ITEM_COLUMN].unique()
+
+ # Map user and item ids to zero-based indices for subsequent processing.
+ logging.info("Generating user_map and item_map...")
+ user_map = {user: index for index, user in enumerate(original_users)}
+ item_map = {item: index for index, item in enumerate(original_items)}
+
+ df[movielens.USER_COLUMN] = df[movielens.USER_COLUMN].apply(
+ lambda user: user_map[user])
+ df[movielens.ITEM_COLUMN] = df[movielens.ITEM_COLUMN].apply(
+ lambda item: item_map[item])
+
+ num_users = len(original_users)
+ num_items = len(original_items)
+
+ assert num_users <= np.iinfo(rconst.USER_DTYPE).max
+ assert num_items <= np.iinfo(rconst.ITEM_DTYPE).max
+ assert df[movielens.USER_COLUMN].max() == num_users - 1
+ assert df[movielens.ITEM_COLUMN].max() == num_items - 1
+
+ # This sort is used to shard the dataframe by user, and later to select
+ # the last item for a user to be used in validation.
+ logging.info("Sorting by user, timestamp...")
+
+ # This sort is equivalent to
+ # df.sort_values([movielens.USER_COLUMN, movielens.TIMESTAMP_COLUMN],
+ # inplace=True)
+ # except that the order of items with the same user and timestamp are
+ # sometimes different. For some reason, this sort results in a better
+ # hit-rate during evaluation, matching the performance of the MLPerf
+ # reference implementation.
+ df.sort_values(by=movielens.TIMESTAMP_COLUMN, inplace=True)
+ df.sort_values([movielens.USER_COLUMN, movielens.TIMESTAMP_COLUMN],
+ inplace=True,
+ kind="mergesort")
+
+ # The sort and filter steps do not rebuild the index, so reset it here.
+ return user_map, item_map, df.reset_index()
+
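+# A minimal usage sketch (the path is hypothetical): given a MovieLens-style
+# ratings CSV, the call returns the raw-to-regularized id maps and the
+# filtered, remapped dataframe sorted by user and timestamp.
+#
+#   user_map, item_map, df = read_dataframe(
+#       "/tmp/movielens-data/ml-1m/ratings.csv")
+#   assert df[movielens.USER_COLUMN].max() == len(user_map) - 1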
+
+def _filter_index_sort(raw_rating_path: Text,
+ cache_path: Text) -> Tuple[pd.DataFrame, bool]:
+ """Read in data CSV, and output structured data.
+
+ This function reads in the raw CSV of positive items, and performs three
+ preprocessing transformations:
+
+ 1) Filter out all users who have not rated at least a certain number
+ of items. (Typically 20 items)
+
+ 2) Zero index the users and items such that the largest user_id is
+ `num_users - 1` and the largest item_id is `num_items - 1`
+
+ 3) Sort the dataframe by user_id, with timestamp as a secondary sort key.
+ This allows the dataframe to be sliced by user in-place, and for the last
+ item to be selected simply by calling the `-1` index of a user's slice.
+
+ While all of these transformations are performed by Pandas (and are therefore
+ single-threaded), they only take ~2 minutes, and the overhead to apply a
+ MapReduce pattern to process the dataset in parallel adds significant
+ complexity for no computational gain. For a larger dataset, parallelizing
+ this preprocessing could yield speedups. (Also, this preprocessing step is
+ only performed once per run.)
+
+ Args:
+ raw_rating_path: The path to the CSV which contains the raw dataset.
+ cache_path: The path to the file where results of this function are saved.
+
+ Returns:
+ A dict containing the filtered, zero-index remapped, sorted training and
+ evaluation arrays together with the raw-to-regularized user and item id
+ maps, and a boolean indicating whether a valid cache was read.
+ """
+ valid_cache = tf.io.gfile.exists(cache_path)
+ if valid_cache:
+ with tf.io.gfile.GFile(cache_path, "rb") as f:
+ cached_data = pickle.load(f)
+
+ # (nnigania) disabled this check as the dataset is not expected to change
+ # cache_age = time.time() - cached_data.get("create_time", 0)
+ # if cache_age > rconst.CACHE_INVALIDATION_SEC:
+ # valid_cache = False
+
+ for key in _EXPECTED_CACHE_KEYS:
+ if key not in cached_data:
+ valid_cache = False
+
+ if not valid_cache:
+ logging.info("Removing stale raw data cache file.")
+ tf.io.gfile.remove(cache_path)
+
+ if valid_cache:
+ data = cached_data
+ else:
+ user_map, item_map, df = read_dataframe(raw_rating_path)
+
+ grouped = df.groupby(movielens.USER_COLUMN, group_keys=False)
+ eval_df, train_df = grouped.tail(1), grouped.apply(lambda x: x.iloc[:-1])
+
+ data = {
+ rconst.TRAIN_USER_KEY: train_df[movielens.USER_COLUMN]
+ .values.astype(rconst.USER_DTYPE),
+ rconst.TRAIN_ITEM_KEY: train_df[movielens.ITEM_COLUMN]
+ .values.astype(rconst.ITEM_DTYPE),
+ rconst.EVAL_USER_KEY: eval_df[movielens.USER_COLUMN]
+ .values.astype(rconst.USER_DTYPE),
+ rconst.EVAL_ITEM_KEY: eval_df[movielens.ITEM_COLUMN]
+ .values.astype(rconst.ITEM_DTYPE),
+ rconst.USER_MAP: user_map,
+ rconst.ITEM_MAP: item_map,
+ "create_time": time.time(),
+ }
+
+ logging.info("Writing raw data cache.")
+ with tf.io.gfile.GFile(cache_path, "wb") as f:
+ pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+ # TODO(robieta): MLPerf cache clear.
+ return data, valid_cache
+
+
+def instantiate_pipeline(dataset,
+ data_dir,
+ params,
+ constructor_type=None,
+ deterministic=False,
+ epoch_dir=None,
+ generate_data_offline=False):
+ # type: (str, str, dict, typing.Optional[str], bool, typing.Optional[str], bool) -> (int, int, data_pipeline.BaseDataConstructor)
+ """Load and digest data CSV into a usable form.
+
+ Args:
+ dataset: The name of the dataset to be used.
+ data_dir: The root directory of the dataset.
+ params: dict of parameters for the run.
+ constructor_type: The name of the constructor subclass that should be used
+ for the input pipeline.
+ deterministic: Tell the data constructor to produce deterministically.
+ epoch_dir: Directory in which to store the training epochs.
+ generate_data_offline: Boolean, whether the data is generated offline ahead
+ of time rather than online during training.
+ """
+ logging.info("Beginning data preprocessing.")
+
+ st = timeit.default_timer()
+ raw_rating_path = os.path.join(data_dir, dataset, movielens.RATINGS_FILE)
+ cache_path = os.path.join(data_dir, dataset, rconst.RAW_CACHE_FILE)
+
+ raw_data, _ = _filter_index_sort(raw_rating_path, cache_path)
+ user_map, item_map = raw_data["user_map"], raw_data["item_map"]
+ num_users, num_items = movielens.DATASET_TO_NUM_USERS_AND_ITEMS[dataset]
+
+ if num_users != len(user_map):
+ raise ValueError("Expected to find {} users, but found {}".format(
+ num_users, len(user_map)))
+ if num_items != len(item_map):
+ raise ValueError("Expected to find {} items, but found {}".format(
+ num_items, len(item_map)))
+
+ producer = data_pipeline.get_constructor(constructor_type or "materialized")(
+ maximum_number_epochs=params["train_epochs"],
+ num_users=num_users,
+ num_items=num_items,
+ user_map=user_map,
+ item_map=item_map,
+ train_pos_users=raw_data[rconst.TRAIN_USER_KEY],
+ train_pos_items=raw_data[rconst.TRAIN_ITEM_KEY],
+ train_batch_size=params["batch_size"],
+ batches_per_train_step=params["batches_per_step"],
+ num_train_negatives=params["num_neg"],
+ eval_pos_users=raw_data[rconst.EVAL_USER_KEY],
+ eval_pos_items=raw_data[rconst.EVAL_ITEM_KEY],
+ eval_batch_size=params["eval_batch_size"],
+ batches_per_eval_step=params["batches_per_step"],
+ stream_files=params["stream_files"],
+ deterministic=deterministic,
+ epoch_dir=epoch_dir,
+ create_data_offline=generate_data_offline)
+
+ run_time = timeit.default_timer() - st
+ logging.info("Data preprocessing complete. Time: {:.1f} sec."
+ .format(run_time))
+
+ print(producer)
+ return num_users, num_items, producer
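+
+# A minimal usage sketch (parameter values are hypothetical and only the keys
+# read by this module are shown): build the pipeline for ml-1m and start the
+# producer before requesting input functions from it.
+#
+#   params = {"train_epochs": 2, "batch_size": 99000, "eval_batch_size": 100000,
+#             "batches_per_step": 1, "num_neg": 4, "stream_files": False}
+#   num_users, num_items, producer = instantiate_pipeline(
+#       dataset="ml-1m", data_dir="/tmp/movielens-data/", params=params)
+#   producer.start()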
diff --git a/models/official/recommendation/data_test.py b/models/official/recommendation/data_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..9541ee3f8bb4c65fb1f69070fa3876ee51b6c191
--- /dev/null
+++ b/models/official/recommendation/data_test.py
@@ -0,0 +1,355 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test NCF data pipeline."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import defaultdict
+import hashlib
+import os
+
+import mock
+import numpy as np
+import scipy.stats
+import tensorflow as tf
+
+from official.recommendation import constants as rconst
+from official.recommendation import data_preprocessing
+from official.recommendation import movielens
+from official.recommendation import popen_helper
+
+
+DATASET = "ml-test"
+NUM_USERS = 1000
+NUM_ITEMS = 2000
+NUM_PTS = 50000
+BATCH_SIZE = 2048
+EVAL_BATCH_SIZE = 4000
+NUM_NEG = 4
+
+
+END_TO_END_TRAIN_MD5 = "b218738e915e825d03939c5e305a2698"
+END_TO_END_EVAL_MD5 = "d753d0f3186831466d6e218163a9501e"
+FRESH_RANDOMNESS_MD5 = "63d0dff73c0e5f1048fbdc8c65021e22"
+
+
+def mock_download(*args, **kwargs):
+ return
+
+
+# The forkpool used by data producers interacts badly with the threading
+# used by TestCase. Without this patch tests will hang, and no amount
+# of diligent closing and joining within the producer will prevent it.
+@mock.patch.object(popen_helper, "get_forkpool", popen_helper.get_fauxpool)
+class BaseTest(tf.test.TestCase):
+
+ def setUp(self):
+ tf.compat.v1.disable_eager_execution()
+ self.temp_data_dir = self.get_temp_dir()
+ ratings_folder = os.path.join(self.temp_data_dir, DATASET)
+ tf.io.gfile.makedirs(ratings_folder)
+ np.random.seed(0)
+ raw_user_ids = np.arange(NUM_USERS * 3)
+ np.random.shuffle(raw_user_ids)
+ raw_user_ids = raw_user_ids[:NUM_USERS]
+
+ raw_item_ids = np.arange(NUM_ITEMS * 3)
+ np.random.shuffle(raw_item_ids)
+ raw_item_ids = raw_item_ids[:NUM_ITEMS]
+
+ users = np.random.choice(raw_user_ids, NUM_PTS)
+ items = np.random.choice(raw_item_ids, NUM_PTS)
+ scores = np.random.randint(low=0, high=5, size=NUM_PTS)
+ times = np.random.randint(low=1000000000, high=1200000000, size=NUM_PTS)
+
+ self.rating_file = os.path.join(ratings_folder, movielens.RATINGS_FILE)
+ self.seen_pairs = set()
+ self.holdout = {}
+ with tf.io.gfile.GFile(self.rating_file, "w") as f:
+ f.write("user_id,item_id,rating,timestamp\n")
+ for usr, itm, scr, ts in zip(users, items, scores, times):
+ pair = (usr, itm)
+ if pair in self.seen_pairs:
+ continue
+ self.seen_pairs.add(pair)
+ if usr not in self.holdout or (ts, itm) > self.holdout[usr]:
+ self.holdout[usr] = (ts, itm)
+
+ f.write("{},{},{},{}\n".format(usr, itm, scr, ts))
+
+ movielens.download = mock_download
+ movielens.NUM_RATINGS[DATASET] = NUM_PTS
+ movielens.DATASET_TO_NUM_USERS_AND_ITEMS[DATASET] = (NUM_USERS, NUM_ITEMS)
+
+ def make_params(self, train_epochs=1):
+ return {
+ "train_epochs": train_epochs,
+ "batches_per_step": 1,
+ "use_seed": False,
+ "batch_size": BATCH_SIZE,
+ "eval_batch_size": EVAL_BATCH_SIZE,
+ "num_neg": NUM_NEG,
+ "match_mlperf": True,
+ "use_tpu": False,
+ "use_xla_for_gpu": False,
+ "stream_files": False,
+ }
+
+ def test_preprocessing(self):
+ # For the most part the necessary checks are performed within
+ # _filter_index_sort()
+
+ cache_path = os.path.join(self.temp_data_dir, "test_cache.pickle")
+ data, valid_cache = data_preprocessing._filter_index_sort(
+ self.rating_file, cache_path=cache_path)
+
+ assert len(data[rconst.USER_MAP]) == NUM_USERS
+ assert len(data[rconst.ITEM_MAP]) == NUM_ITEMS
+
+ def drain_dataset(self, dataset, g):
+ # type: (tf.data.Dataset, tf.Graph) -> list
+ with self.session(graph=g) as sess:
+ with g.as_default():
+ batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
+ output = []
+ while True:
+ try:
+ output.append(sess.run(batch))
+ except tf.errors.OutOfRangeError:
+ break
+ return output
+
+ def _test_end_to_end(self, constructor_type):
+ params = self.make_params(train_epochs=1)
+ _, _, producer = data_preprocessing.instantiate_pipeline(
+ dataset=DATASET, data_dir=self.temp_data_dir, params=params,
+ constructor_type=constructor_type, deterministic=True)
+
+ producer.start()
+ producer.join()
+ assert producer._fatal_exception is None
+
+ user_inv_map = {v: k for k, v in producer.user_map.items()}
+ item_inv_map = {v: k for k, v in producer.item_map.items()}
+
+ # ==========================================================================
+ # == Training Data =========================================================
+ # ==========================================================================
+ g = tf.Graph()
+ with g.as_default():
+ input_fn = producer.make_input_fn(is_training=True)
+ dataset = input_fn(params)
+
+ first_epoch = self.drain_dataset(dataset=dataset, g=g)
+
+ counts = defaultdict(int)
+ train_examples = {
+ True: set(),
+ False: set(),
+ }
+
+ md5 = hashlib.md5()
+ for features, labels in first_epoch:
+ data_list = [
+ features[movielens.USER_COLUMN].flatten(),
+ features[movielens.ITEM_COLUMN].flatten(),
+ features[rconst.VALID_POINT_MASK].flatten(),
+ labels.flatten()
+ ]
+ for i in data_list:
+ md5.update(i.tobytes())
+
+ for u, i, v, l in zip(*data_list):
+ if not v:
+ continue # ignore padding
+
+ u_raw = user_inv_map[u]
+ i_raw = item_inv_map[i]
+ if ((u_raw, i_raw) in self.seen_pairs) != l:
+ # The evaluation item is not considered during false negative
+ # generation, so it will occasionally appear as a negative example
+ # during training.
+ assert not l
+ self.assertEqual(i_raw, self.holdout[u_raw][1])
+ train_examples[l].add((u_raw, i_raw))
+ counts[(u_raw, i_raw)] += 1
+
+ self.assertRegexpMatches(md5.hexdigest(), END_TO_END_TRAIN_MD5)
+
+ num_positives_seen = len(train_examples[True])
+ self.assertEqual(producer._train_pos_users.shape[0], num_positives_seen)
+
+ # This check is more heuristic because negatives are sampled with
+ # replacement. It only checks that negative generation is reasonably random.
+ self.assertGreater(
+ len(train_examples[False]) / NUM_NEG / num_positives_seen, 0.9)
+
+ # This checks that the samples produced are independent by checking the
+ # number of duplicate entries. If workers are not properly independent there
+ # will be lots of repeated pairs.
+ self.assertLess(np.mean(list(counts.values())), 1.1)
+
+ # ==========================================================================
+ # == Eval Data =============================================================
+ # ==========================================================================
+ with g.as_default():
+ input_fn = producer.make_input_fn(is_training=False)
+ dataset = input_fn(params)
+
+ eval_data = self.drain_dataset(dataset=dataset, g=g)
+
+ current_user = None
+ md5 = hashlib.md5()
+ for features in eval_data:
+ data_list = [
+ features[movielens.USER_COLUMN].flatten(),
+ features[movielens.ITEM_COLUMN].flatten(),
+ features[rconst.DUPLICATE_MASK].flatten()
+ ]
+ for i in data_list:
+ md5.update(i.tobytes())
+
+ for idx, (u, i, d) in enumerate(zip(*data_list)):
+ u_raw = user_inv_map[u]
+ i_raw = item_inv_map[i]
+ if current_user is None:
+ current_user = u
+
+ # Ensure that users appear in blocks, as the evaluation logic expects
+ # this structure.
+ self.assertEqual(u, current_user)
+
+ # The structure of evaluation data is 999 negative examples followed
+ # by the holdout positive.
+ if not (idx + 1) % (rconst.NUM_EVAL_NEGATIVES + 1):
+ # Check that the last element in each chunk is the holdout item.
+ self.assertEqual(i_raw, self.holdout[u_raw][1])
+ current_user = None
+
+ elif i_raw == self.holdout[u_raw][1]:
+ # Because the holdout item is not given to the negative generation
+ # process, it can appear as a negative. In that case, it should be
+ # masked out as a duplicate. (Since the true positive is placed at
+ # the end and would therefore lose the tie.)
+ assert d
+
+ else:
+ # Otherwise check that the other 999 points for a user are selected
+ # from the negatives.
+ assert (u_raw, i_raw) not in self.seen_pairs
+
+ self.assertRegexpMatches(md5.hexdigest(), END_TO_END_EVAL_MD5)
+
+ def _test_fresh_randomness(self, constructor_type):
+ train_epochs = 5
+ params = self.make_params(train_epochs=train_epochs)
+ _, _, producer = data_preprocessing.instantiate_pipeline(
+ dataset=DATASET, data_dir=self.temp_data_dir, params=params,
+ constructor_type=constructor_type, deterministic=True)
+
+ producer.start()
+
+ results = []
+ g = tf.Graph()
+ with g.as_default():
+ for _ in range(train_epochs):
+ input_fn = producer.make_input_fn(is_training=True)
+ dataset = input_fn(params)
+ results.extend(self.drain_dataset(dataset=dataset, g=g))
+
+ producer.join()
+ assert producer._fatal_exception is None
+
+ positive_counts, negative_counts = defaultdict(int), defaultdict(int)
+ md5 = hashlib.md5()
+ for features, labels in results:
+ data_list = [
+ features[movielens.USER_COLUMN].flatten(),
+ features[movielens.ITEM_COLUMN].flatten(),
+ features[rconst.VALID_POINT_MASK].flatten(),
+ labels.flatten()
+ ]
+ for i in data_list:
+ md5.update(i.tobytes())
+
+ for u, i, v, l in zip(*data_list):
+ if not v:
+ continue # ignore padding
+
+ if l:
+ positive_counts[(u, i)] += 1
+ else:
+ negative_counts[(u, i)] += 1
+
+ self.assertRegexpMatches(md5.hexdigest(), FRESH_RANDOMNESS_MD5)
+
+ # The positive examples should appear exactly once each epoch
+ self.assertAllEqual(list(positive_counts.values()),
+ [train_epochs for _ in positive_counts])
+
+ # The threshold for the negatives is heuristic: repeats are expected in
+ # general, but they should not appear too frequently.
+
+ pair_cardinality = NUM_USERS * NUM_ITEMS
+ neg_pair_cardinality = pair_cardinality - len(self.seen_pairs)
+
+ # Approximation of the expected number of times that a particular
+ # negative will appear in a given epoch. Implicit in this calculation is the
+ # treatment of all negative pairs as equally likely. Normally that is not
+ # necessarily reasonable; however the generation in self.setUp() will
+ # approximate this behavior sufficiently for heuristic testing.
+ e_sample = len(self.seen_pairs) * NUM_NEG / neg_pair_cardinality
+
+ # The frequency of occurrence of a given negative pair should follow an
+ # approximately binomial distribution in the limit that the cardinality of
+ # the negative pair set >> number of samples per epoch.
+ approx_pdf = scipy.stats.binom.pmf(k=np.arange(train_epochs+1),
+ n=train_epochs, p=e_sample)
+
+ # Tally the actual observed counts.
+ count_distribution = [0 for _ in range(train_epochs + 1)]
+ for i in negative_counts.values():
+ i = min([i, train_epochs]) # round down tail for simplicity.
+ count_distribution[i] += 1
+ count_distribution[0] = neg_pair_cardinality - sum(count_distribution[1:])
+
+ # Check that the frequency of negative pairs is approximately binomial.
+ for i in range(train_epochs + 1):
+ if approx_pdf[i] < 0.05:
+ continue # Variance will be high at the tails.
+
+ observed_fraction = count_distribution[i] / neg_pair_cardinality
+ deviation = (2 * abs(observed_fraction - approx_pdf[i]) /
+ (observed_fraction + approx_pdf[i]))
+
+ self.assertLess(deviation, 0.2)
+
+ def test_end_to_end_materialized(self):
+ self._test_end_to_end("materialized")
+
+ def test_end_to_end_bisection(self):
+ self._test_end_to_end("bisection")
+
+ def test_fresh_randomness_materialized(self):
+ self._test_fresh_randomness("materialized")
+
+ def test_fresh_randomness_bisection(self):
+ self._test_fresh_randomness("bisection")
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/recommendation/movielens.py b/models/official/recommendation/movielens.py
new file mode 100644
index 0000000000000000000000000000000000000000..576519a316bb3e05d786ac737da19cb44d2b61c4
--- /dev/null
+++ b/models/official/recommendation/movielens.py
@@ -0,0 +1,317 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Download and extract the MovieLens dataset from GroupLens website.
+
+Download the dataset, and perform basic preprocessing.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import tempfile
+import zipfile
+
+# pylint: disable=g-bad-import-order
+import numpy as np
+import pandas as pd
+import six
+from six.moves import urllib # pylint: disable=redefined-builtin
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf
+# pylint: enable=g-bad-import-order
+
+from official.utils.flags import core as flags_core
+
+
+ML_1M = "ml-1m"
+ML_20M = "ml-20m"
+DATASETS = [ML_1M, ML_20M]
+
+RATINGS_FILE = "ratings.csv"
+MOVIES_FILE = "movies.csv"
+
+# URL to download dataset
+_DATA_URL = "http://files.grouplens.org/datasets/movielens/"
+
+GENRE_COLUMN = "genres"
+ITEM_COLUMN = "item_id" # movies
+RATING_COLUMN = "rating"
+TIMESTAMP_COLUMN = "timestamp"
+TITLE_COLUMN = "titles"
+USER_COLUMN = "user_id"
+
+GENRES = [
+ 'Action', 'Adventure', 'Animation', "Children", 'Comedy', 'Crime',
+ 'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror', "IMAX", 'Musical',
+ 'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'War', 'Western'
+]
+N_GENRE = len(GENRES)
+
+RATING_COLUMNS = [USER_COLUMN, ITEM_COLUMN, RATING_COLUMN, TIMESTAMP_COLUMN]
+MOVIE_COLUMNS = [ITEM_COLUMN, TITLE_COLUMN, GENRE_COLUMN]
+
+# Note: Users are indexed [1, k], not [0, k-1]
+NUM_USER_IDS = {
+ ML_1M: 6040,
+ ML_20M: 138493,
+}
+
+# Note: Movies are indexed [1, k], not [0, k-1]
+# Both the 1m and 20m datasets use the same movie set.
+NUM_ITEM_IDS = 3952
+
+MAX_RATING = 5
+
+NUM_RATINGS = {
+ ML_1M: 1000209,
+ ML_20M: 20000263
+}
+
+DATASET_TO_NUM_USERS_AND_ITEMS = {ML_1M: (6040, 3706), ML_20M: (138493, 26744)}
+
+
+def _download_and_clean(dataset, data_dir):
+ """Download MovieLens dataset in a standard format.
+
+ This function downloads the specified MovieLens format and coerces it into a
+ standard format. The only difference between the ml-1m and ml-20m datasets
+ after this point (other than size, of course) is that the 1m dataset uses
+ whole number ratings while the 20m dataset allows half integer ratings.
+ """
+ if dataset not in DATASETS:
+ raise ValueError("dataset {} is not in {{{}}}".format(
+ dataset, ",".join(DATASETS)))
+
+ data_subdir = os.path.join(data_dir, dataset)
+
+ expected_files = ["{}.zip".format(dataset), RATINGS_FILE, MOVIES_FILE]
+
+ tf.io.gfile.makedirs(data_subdir)
+ if set(expected_files).intersection(
+ tf.io.gfile.listdir(data_subdir)) == set(expected_files):
+ logging.info("Dataset {} has already been downloaded".format(dataset))
+ return
+
+ url = "{}{}.zip".format(_DATA_URL, dataset)
+
+ temp_dir = tempfile.mkdtemp()
+ try:
+ zip_path = os.path.join(temp_dir, "{}.zip".format(dataset))
+ zip_path, _ = urllib.request.urlretrieve(url, zip_path)
+ statinfo = os.stat(zip_path)
+ # Print a newline to clear the carriage return left by the download
+ # progress output; logging.info is not suitable here.
+ print()
+ logging.info(
+ "Successfully downloaded {} {} bytes".format(
+ zip_path, statinfo.st_size))
+
+ zipfile.ZipFile(zip_path, "r").extractall(temp_dir)
+
+ if dataset == ML_1M:
+ _regularize_1m_dataset(temp_dir)
+ else:
+ _regularize_20m_dataset(temp_dir)
+
+ for fname in tf.io.gfile.listdir(temp_dir):
+ if not tf.io.gfile.exists(os.path.join(data_subdir, fname)):
+ tf.io.gfile.copy(os.path.join(temp_dir, fname),
+ os.path.join(data_subdir, fname))
+ else:
+ logging.info("Skipping copy of {}, as it already exists in the "
+ "destination folder.".format(fname))
+
+ finally:
+ tf.io.gfile.rmtree(temp_dir)
+
+
+def _transform_csv(input_path, output_path, names, skip_first, separator=","):
+ """Transform csv to a regularized format.
+
+ Args:
+ input_path: The path of the raw csv.
+ output_path: The path of the cleaned csv.
+ names: The csv column names.
+ skip_first: Boolean of whether to skip the first line of the raw csv.
+ separator: Character used to separate fields in the raw csv.
+ """
+ if six.PY2:
+ names = [six.ensure_text(n, "utf-8") for n in names]
+
+ with tf.io.gfile.GFile(output_path, "wb") as f_out, \
+ tf.io.gfile.GFile(input_path, "rb") as f_in:
+
+ # Write column names to the csv.
+ f_out.write(",".join(names).encode("utf-8"))
+ f_out.write(b"\n")
+ for i, line in enumerate(f_in):
+ if i == 0 and skip_first:
+ continue # ignore existing labels in the csv
+
+ line = six.ensure_text(line, "utf-8", errors="ignore")
+ fields = line.split(separator)
+ if separator != ",":
+ fields = ['"{}"'.format(field) if "," in field else field
+ for field in fields]
+ f_out.write(",".join(fields).encode("utf-8"))
+
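+# Illustrative sketch of the transformation above (hypothetical input line):
+# a raw ml-1m ratings line "1::1193::5::978300760" becomes
+# "1,1193,5,978300760" in the regularized CSV, under the header row written
+# from `names`.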
+
+def _regularize_1m_dataset(temp_dir):
+ """
+ ratings.dat
+ The file has no header row, and each line is in the following format:
+ UserID::MovieID::Rating::Timestamp
+ - UserIDs range from 1 to 6040
+ - MovieIDs range from 1 to 3952
+ - Ratings are made on a 5-star scale (whole-star ratings only)
+ - Timestamp is represented in seconds since midnight Coordinated Universal
+ Time (UTC) of January 1, 1970.
+ - Each user has at least 20 ratings
+
+ movies.dat
+ Each line has the following format:
+ MovieID::Title::Genres
+ - MovieIDs range from 1 to 3952
+ """
+ working_dir = os.path.join(temp_dir, ML_1M)
+
+ _transform_csv(
+ input_path=os.path.join(working_dir, "ratings.dat"),
+ output_path=os.path.join(temp_dir, RATINGS_FILE),
+ names=RATING_COLUMNS, skip_first=False, separator="::")
+
+ _transform_csv(
+ input_path=os.path.join(working_dir, "movies.dat"),
+ output_path=os.path.join(temp_dir, MOVIES_FILE),
+ names=MOVIE_COLUMNS, skip_first=False, separator="::")
+
+ tf.io.gfile.rmtree(working_dir)
+
+
+def _regularize_20m_dataset(temp_dir):
+ """
+ ratings.csv
+ Each line of this file after the header row represents one rating of one
+ movie by one user, and has the following format:
+ userId,movieId,rating,timestamp
+ - The lines within this file are ordered first by userId, then, within user,
+ by movieId.
+ - Ratings are made on a 5-star scale, with half-star increments
+ (0.5 stars - 5.0 stars).
+ - Timestamps represent seconds since midnight Coordinated Universal Time
+ (UTC) of January 1, 1970.
+ - All users have rated at least 20 movies.
+
+ movies.csv
+ Each line has the following format:
+ MovieID,Title,Genres
+ - MovieIDs range from 1 to 3952
+ """
+ working_dir = os.path.join(temp_dir, ML_20M)
+
+ _transform_csv(
+ input_path=os.path.join(working_dir, "ratings.csv"),
+ output_path=os.path.join(temp_dir, RATINGS_FILE),
+ names=RATING_COLUMNS, skip_first=True, separator=",")
+
+ _transform_csv(
+ input_path=os.path.join(working_dir, "movies.csv"),
+ output_path=os.path.join(temp_dir, MOVIES_FILE),
+ names=MOVIE_COLUMNS, skip_first=True, separator=",")
+
+ tf.io.gfile.rmtree(working_dir)
+
+
+def download(dataset, data_dir):
+ if dataset:
+ _download_and_clean(dataset, data_dir)
+ else:
+ _ = [_download_and_clean(d, data_dir) for d in DATASETS]
+
+
+def ratings_csv_to_dataframe(data_dir, dataset):
+ with tf.io.gfile.GFile(os.path.join(data_dir, dataset, RATINGS_FILE)) as f:
+ return pd.read_csv(f, encoding="utf-8")
+
+
+def csv_to_joint_dataframe(data_dir, dataset):
+ ratings = ratings_csv_to_dataframe(data_dir, dataset)
+
+ with tf.io.gfile.GFile(os.path.join(data_dir, dataset, MOVIES_FILE)) as f:
+ movies = pd.read_csv(f, encoding="utf-8")
+
+ df = ratings.merge(movies, on=ITEM_COLUMN)
+ df[RATING_COLUMN] = df[RATING_COLUMN].astype(np.float32)
+
+ return df
+
+
+def integerize_genres(dataframe):
+ """Replace genre string with a binary vector.
+
+ Args:
+ dataframe: a pandas dataframe of movie data.
+
+ Returns:
+ The transformed dataframe.
+ """
+ def _map_fn(entry):
+ entry.replace("Children's", "Children") # naming difference.
+ movie_genres = entry.split("|")
+ output = np.zeros((len(GENRES),), dtype=np.int64)
+ for i, genre in enumerate(GENRES):
+ if genre in movie_genres:
+ output[i] = 1
+ return output
+
+ dataframe[GENRE_COLUMN] = dataframe[GENRE_COLUMN].apply(_map_fn)
+
+ return dataframe
+
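+# Illustrative sketch (hypothetical row): a genres entry of "Action|Comedy"
+# maps to a length-19 vector with ones at the GENRES positions of 'Action'
+# and 'Comedy' and zeros elsewhere.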
+
+def define_flags():
+ """Add flags specifying data usage arguments."""
+ flags.DEFINE_enum(
+ name="dataset",
+ default=None,
+ enum_values=DATASETS,
+ case_sensitive=False,
+ help=flags_core.help_wrap("Dataset to be trained and evaluated."))
+
+
+def define_data_download_flags():
+ """Add flags specifying data download and usage arguments."""
+ flags.DEFINE_string(
+ name="data_dir", default="/tmp/movielens-data/",
+ help=flags_core.help_wrap(
+ "Directory to download and extract data."))
+
+ define_flags()
+
+
+def main(_):
+ """Download and extract the data from GroupLens website."""
+ download(flags.FLAGS.dataset, flags.FLAGS.data_dir)
+
+
+if __name__ == "__main__":
+ define_data_download_flags()
+ FLAGS = flags.FLAGS
+ app.run(main)
diff --git a/models/official/recommendation/ncf_common.py b/models/official/recommendation/ncf_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..8abc927bfa29c52d6c151023d281d7e4f6f52100
--- /dev/null
+++ b/models/official/recommendation/ncf_common.py
@@ -0,0 +1,327 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Common functionalities used by both Keras and Estimator implementations.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+
+# pylint: disable=g-bad-import-order
+import numpy as np
+from absl import flags
+from absl import logging
+import tensorflow as tf
+# pylint: enable=g-bad-import-order
+
+from official.recommendation import constants as rconst
+from official.recommendation import data_pipeline
+from official.recommendation import data_preprocessing
+from official.recommendation import movielens
+from official.utils.flags import core as flags_core
+from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
+
+FLAGS = flags.FLAGS
+
+
+def get_inputs(params):
+ """Returns some parameters used by the model."""
+ if FLAGS.download_if_missing and not FLAGS.use_synthetic_data:
+ movielens.download(FLAGS.dataset, FLAGS.data_dir)
+
+ if FLAGS.seed is not None:
+ np.random.seed(FLAGS.seed)
+
+ if FLAGS.use_synthetic_data:
+ producer = data_pipeline.DummyConstructor()
+ num_users, num_items = movielens.DATASET_TO_NUM_USERS_AND_ITEMS[
+ FLAGS.dataset]
+ num_train_steps = rconst.SYNTHETIC_BATCHES_PER_EPOCH
+ num_eval_steps = rconst.SYNTHETIC_BATCHES_PER_EPOCH
+ else:
+ num_users, num_items, producer = data_preprocessing.instantiate_pipeline(
+ dataset=FLAGS.dataset, data_dir=FLAGS.data_dir, params=params,
+ constructor_type=FLAGS.constructor_type,
+ deterministic=FLAGS.seed is not None)
+ num_train_steps = producer.train_batches_per_epoch
+ num_eval_steps = producer.eval_batches_per_epoch
+
+ return num_users, num_items, num_train_steps, num_eval_steps, producer
+
+
+def parse_flags(flags_obj):
+ """Convenience function to turn flags into params."""
+ num_gpus = flags_core.get_num_gpus(flags_obj)
+
+ batch_size = flags_obj.batch_size
+ eval_batch_size = flags_obj.eval_batch_size or flags_obj.batch_size
+
+ return {
+ "train_epochs": flags_obj.train_epochs,
+ "batches_per_step": 1,
+ "use_seed": flags_obj.seed is not None,
+ "batch_size": batch_size,
+ "eval_batch_size": eval_batch_size,
+ "learning_rate": flags_obj.learning_rate,
+ "mf_dim": flags_obj.num_factors,
+ "model_layers": [int(layer) for layer in flags_obj.layers],
+ "mf_regularization": flags_obj.mf_regularization,
+ "mlp_reg_layers": [float(reg) for reg in flags_obj.mlp_regularization],
+ "num_neg": flags_obj.num_neg,
+ "distribution_strategy": flags_obj.distribution_strategy,
+ "num_gpus": num_gpus,
+ "use_tpu": flags_obj.tpu is not None,
+ "tpu": flags_obj.tpu,
+ "tpu_zone": flags_obj.tpu_zone,
+ "tpu_gcp_project": flags_obj.tpu_gcp_project,
+ "beta1": flags_obj.beta1,
+ "beta2": flags_obj.beta2,
+ "epsilon": flags_obj.epsilon,
+ "match_mlperf": flags_obj.ml_perf,
+ "epochs_between_evals": FLAGS.epochs_between_evals,
+ "keras_use_ctl": flags_obj.keras_use_ctl,
+ "hr_threshold": flags_obj.hr_threshold,
+ "stream_files": flags_obj.tpu is not None,
+ "train_dataset_path": flags_obj.train_dataset_path,
+ "eval_dataset_path": flags_obj.eval_dataset_path,
+ "input_meta_data_path": flags_obj.input_meta_data_path,
+ }
+
+
+def get_v1_distribution_strategy(params):
+ """Returns the distribution strategy to use."""
+ if params["use_tpu"]:
+ # Some of the networking libraries are quite chatty.
+ for name in ["googleapiclient.discovery", "googleapiclient.discovery_cache",
+ "oauth2client.transport"]:
+ logging.getLogger(name).setLevel(logging.ERROR)
+
+ tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
+ tpu=params["tpu"],
+ zone=params["tpu_zone"],
+ project=params["tpu_gcp_project"],
+ coordinator_name="coordinator"
+ )
+
+ logging.info("Issuing reset command to TPU to ensure a clean state.")
+ tf.Session.reset(tpu_cluster_resolver.get_master())
+
+ # Estimator looks at the master it connects to for MonitoredTrainingSession
+ # by reading the `TF_CONFIG` environment variable, and the coordinator
+ # is used by StreamingFilesDataset.
+ tf_config_env = {
+ "session_master": tpu_cluster_resolver.get_master(),
+ "eval_session_master": tpu_cluster_resolver.get_master(),
+ "coordinator": tpu_cluster_resolver.cluster_spec()
+ .as_dict()["coordinator"]
+ }
+ os.environ["TF_CONFIG"] = json.dumps(tf_config_env)
+
+ distribution = tf.distribute.experimental.TPUStrategy(
+ tpu_cluster_resolver, steps_per_run=100)
+
+ else:
+ distribution = distribution_utils.get_distribution_strategy(
+ num_gpus=params["num_gpus"])
+
+ return distribution
+
+
+def define_ncf_flags():
+ """Add flags for running ncf_main."""
+ # Add common flags
+ flags_core.define_base(model_dir=True, clean=True, train_epochs=True,
+ epochs_between_evals=True, export_dir=False,
+ run_eagerly=True, stop_threshold=True, num_gpu=True,
+ distribution_strategy=True)
+ flags_core.define_performance(
+ synthetic_data=True,
+ dtype=True,
+ fp16_implementation=True,
+ loss_scale=True,
+ dynamic_loss_scale=True,
+ enable_xla=True,
+ )
+ flags_core.define_device(tpu=True)
+ flags_core.define_benchmark()
+
+ flags.adopt_module_key_flags(flags_core)
+
+ movielens.define_flags()
+
+ flags_core.set_defaults(
+ model_dir="/tmp/ncf/",
+ data_dir="/tmp/movielens-data/",
+ dataset=movielens.ML_1M,
+ train_epochs=2,
+ batch_size=99000,
+ tpu=None
+ )
+
+ # Add ncf-specific flags
+ flags.DEFINE_boolean(
+ name="download_if_missing", default=True, help=flags_core.help_wrap(
+ "Download data to data_dir if it is not already present."))
+
+ flags.DEFINE_integer(
+ name="eval_batch_size", default=None, help=flags_core.help_wrap(
+ "The batch size used for evaluation. This should generally be larger"
+ "than the training batch size as the lack of back propagation during"
+ "evaluation can allow for larger batch sizes to fit in memory. If not"
+ "specified, the training batch size (--batch_size) will be used."))
+
+ flags.DEFINE_integer(
+ name="num_factors", default=8,
+ help=flags_core.help_wrap("The Embedding size of MF model."))
+
+ # Set the default as a list of strings to be consistent with input arguments
+ flags.DEFINE_list(
+ name="layers", default=["64", "32", "16", "8"],
+ help=flags_core.help_wrap(
+ "The sizes of hidden layers for MLP. Example "
+ "to specify different sizes of MLP layers: --layers=32,16,8,4"))
+
+ flags.DEFINE_float(
+ name="mf_regularization", default=0.,
+ help=flags_core.help_wrap(
+ "The regularization factor for MF embeddings. The factor is used by "
+ "regularizer which allows to apply penalties on layer parameters or "
+ "layer activity during optimization."))
+
+ flags.DEFINE_list(
+ name="mlp_regularization", default=["0.", "0.", "0.", "0."],
+ help=flags_core.help_wrap(
+ "The regularization factor for each MLP layer. See mf_regularization "
+ "help for more info about regularization factor."))
+
+ flags.DEFINE_integer(
+ name="num_neg", default=4,
+ help=flags_core.help_wrap(
+ "The Number of negative instances to pair with a positive instance."))
+
+ flags.DEFINE_float(
+ name="learning_rate", default=0.001,
+ help=flags_core.help_wrap("The learning rate."))
+
+ flags.DEFINE_float(
+ name="beta1", default=0.9,
+ help=flags_core.help_wrap("beta1 hyperparameter for the Adam optimizer."))
+
+ flags.DEFINE_float(
+ name="beta2", default=0.999,
+ help=flags_core.help_wrap("beta2 hyperparameter for the Adam optimizer."))
+
+ flags.DEFINE_float(
+ name="epsilon", default=1e-8,
+ help=flags_core.help_wrap("epsilon hyperparameter for the Adam "
+ "optimizer."))
+
+ flags.DEFINE_float(
+ name="hr_threshold", default=1.0,
+ help=flags_core.help_wrap(
+ "If passed, training will stop when the evaluation metric HR is "
+ "greater than or equal to hr_threshold. For dataset ml-1m, the "
+ "desired hr_threshold is 0.68 which is the result from the paper; "
+ "For dataset ml-20m, the threshold can be set as 0.95 which is "
+ "achieved by MLPerf implementation."))
+
+ flags.DEFINE_enum(
+ name="constructor_type", default="bisection",
+ enum_values=["bisection", "materialized"], case_sensitive=False,
+ help=flags_core.help_wrap(
+ "Strategy to use for generating false negatives. materialized has a"
+ "precompute that scales badly, but a faster per-epoch construction"
+ "time and can be faster on very large systems."))
+
+ flags.DEFINE_string(
+ name="train_dataset_path",
+ default=None,
+ help=flags_core.help_wrap("Path to training data."))
+
+ flags.DEFINE_string(
+ name="eval_dataset_path",
+ default=None,
+ help=flags_core.help_wrap("Path to evaluation data."))
+
+ flags.DEFINE_string(
+ name="input_meta_data_path",
+ default=None,
+ help=flags_core.help_wrap("Path to input meta data file."))
+
+ flags.DEFINE_bool(
+ name="ml_perf", default=False,
+ help=flags_core.help_wrap(
+ "If set, changes the behavior of the model slightly to match the "
+ "MLPerf reference implementations here: \n"
+ "https://github.com/mlperf/reference/tree/master/recommendation/"
+ "pytorch\n"
+ "The two changes are:\n"
+ "1. When computing the HR and NDCG during evaluation, remove "
+ "duplicate user-item pairs before the computation. This results in "
+ "better HRs and NDCGs.\n"
+ "2. Use a different soring algorithm when sorting the input data, "
+ "which performs better due to the fact the sorting algorithms are "
+ "not stable."))
+
+ flags.DEFINE_bool(
+ name="output_ml_perf_compliance_logging", default=False,
+ help=flags_core.help_wrap(
+ "If set, output the MLPerf compliance logging. This is only useful "
+ "if one is running the model for MLPerf. See "
+ "https://github.com/mlperf/policies/blob/master/training_rules.adoc"
+ "#submission-compliance-logs for details. This uses sudo and so may "
+ "ask for your password, as root access is needed to clear the system "
+ "caches, which is required for MLPerf compliance."
+ )
+ )
+
+ flags.DEFINE_integer(
+ name="seed", default=None, help=flags_core.help_wrap(
+ "This value will be used to seed both NumPy and TensorFlow."))
+
+ @flags.validator("eval_batch_size", "eval_batch_size must be at least {}"
+ .format(rconst.NUM_EVAL_NEGATIVES + 1))
+ def eval_size_check(eval_batch_size):
+ return (eval_batch_size is None or
+ int(eval_batch_size) > rconst.NUM_EVAL_NEGATIVES)
+
+ flags.DEFINE_bool(
+ name="early_stopping",
+ default=False,
+ help=flags_core.help_wrap(
+ "If True, we stop the training when it reaches hr_threshold"))
+
+ flags.DEFINE_bool(
+ name="keras_use_ctl",
+ default=False,
+ help=flags_core.help_wrap(
+ "If True, we use a custom training loop for keras."))
+
+
+def convert_to_softmax_logits(logits):
+ """Convert the logits returned by the base model to softmax logits.
+
+ Args:
+ logits: used to create softmax.
+
+ Returns:
+ Logits with a zero column prepended; taking their softmax is equivalent
+ to taking the sigmoid of the original logits.
+ """
+ softmax_logits = tf.concat([logits * 0, logits], axis=1)
+ return softmax_logits
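+
+# A small worked example (hypothetical values): for logits [[2.0]] the result
+# is [[0.0, 2.0]], and softmax([0.0, 2.0])[1] = 1 / (1 + exp(-2.0)), i.e. the
+# sigmoid of the original logit, which is why the zero column is prepended.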
diff --git a/models/official/recommendation/ncf_input_pipeline.py b/models/official/recommendation/ncf_input_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dab86c43bfde14eb5adfc82e52b30b315060217
--- /dev/null
+++ b/models/official/recommendation/ncf_input_pipeline.py
@@ -0,0 +1,200 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""NCF model input pipeline."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+# pylint: disable=g-bad-import-order
+import tensorflow.compat.v2 as tf
+# pylint: enable=g-bad-import-order
+
+from official.recommendation import constants as rconst
+from official.recommendation import movielens
+from official.recommendation import data_pipeline
+
+NUM_SHARDS = 16
+
+
+def create_dataset_from_tf_record_files(input_file_pattern,
+ pre_batch_size,
+ batch_size,
+ is_training=True,
+ rebatch=False):
+ """Creates dataset from (tf)records files for training/evaluation."""
+
+ files = tf.data.Dataset.list_files(input_file_pattern, shuffle=is_training)
+
+ def make_dataset(files_dataset, shard_index):
+ """Returns dataset for sharded tf record files."""
+ if pre_batch_size != batch_size:
+ raise ValueError("Pre-batch ({}) size is not equal to batch "
+ "size ({})".format(pre_batch_size, batch_size))
+ files_dataset = files_dataset.shard(NUM_SHARDS, shard_index)
+ dataset = files_dataset.interleave(
+ tf.data.TFRecordDataset,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ decode_fn = functools.partial(
+ data_pipeline.DatasetManager.deserialize,
+ batch_size=pre_batch_size,
+ is_training=is_training)
+ dataset = dataset.map(
+ decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+ dataset = tf.data.Dataset.range(NUM_SHARDS)
+ map_fn = functools.partial(make_dataset, files)
+ dataset = dataset.interleave(
+ map_fn,
+ cycle_length=NUM_SHARDS,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ if rebatch:
+ # A workaround for TPU Pod evaluation dataset.
+ # TODO (b/162341937) remove once it's fixed.
+ dataset = dataset.unbatch()
+ dataset = dataset.batch(pre_batch_size)
+
+ dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+ return dataset
+
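+# A minimal usage sketch (the file pattern and sizes are hypothetical): build
+# a training dataset from pre-generated TFRecord shards whose records were
+# serialized with the same batch size requested here.
+#
+#   train_ds = create_dataset_from_tf_record_files(
+#       "/tmp/ncf_records/training_cycle_0/*", pre_batch_size=99000,
+#       batch_size=99000, is_training=True)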
+
+def create_dataset_from_data_producer(producer, params):
+ """Return dataset online-generating data."""
+
+ def preprocess_train_input(features, labels):
+ """Pre-process the training data.
+
+ This is needed because
+ - The label needs to be folded into the features so it can be used in the
+ loss fn.
+ - Training and eval need the same input structure, so a fake
+ DUPLICATE_MASK input is added to the training data.
+
+ Args:
+ features: Dictionary of features for training.
+ labels: Training labels.
+
+ Returns:
+ Processed training features.
+ """
+ fake_dup_mask = tf.zeros_like(features[movielens.USER_COLUMN])
+ features[rconst.DUPLICATE_MASK] = fake_dup_mask
+ features[rconst.TRAIN_LABEL_KEY] = labels
+ return features
+
+ train_input_fn = producer.make_input_fn(is_training=True)
+ train_input_dataset = train_input_fn(params).map(preprocess_train_input)
+
+ def preprocess_eval_input(features):
+ """Pre-process the eval data.
+
+ This is needed because:
+ - A label field needs to be present so the same loss fn can be used.
+ - Training and eval need the same input structure, so a fake
+ VALID_POINT_MASK input is added to the eval data.
+
+ Args:
+ features: Dictionary of features for evaluation.
+
+ Returns:
+ Processed evaluation features.
+ """
+ labels = tf.cast(tf.zeros_like(features[movielens.USER_COLUMN]), tf.bool)
+ fake_valid_pt_mask = tf.cast(
+ tf.zeros_like(features[movielens.USER_COLUMN]), tf.bool)
+ features[rconst.VALID_POINT_MASK] = fake_valid_pt_mask
+ features[rconst.TRAIN_LABEL_KEY] = labels
+ return features
+
+ eval_input_fn = producer.make_input_fn(is_training=False)
+ eval_input_dataset = eval_input_fn(params).map(preprocess_eval_input)
+
+ return train_input_dataset, eval_input_dataset
+
+
+def create_ncf_input_data(params,
+ producer=None,
+ input_meta_data=None,
+ strategy=None):
+ """Creates NCF training/evaluation dataset.
+
+ Args:
+ params: Dictionary containing parameters for train/evaluation data.
+ producer: Instance of BaseDataConstructor that generates data online. Must
+ not be None when params['train_dataset_path'] or
+ params['eval_dataset_path'] is not specified.
+ input_meta_data: A dictionary of input metadata to be used when reading data
+ from tf record files. Must be specified when params["train_dataset_path"]
+ is specified.
+ strategy: Distribution strategy used for distributed training. If specified,
+ used to assert that evaluation batch size is correctly a multiple of
+ total number of devices used.
+
+ Returns:
+ (training dataset, evaluation dataset, train steps per epoch,
+ eval steps per epoch)
+
+ Raises:
+ ValueError: If data is being generated online while using TPUs.
+ """
+ # NCF evaluation metric calculation logic assumes that evaluation data
+ # sample sizes are multiples of (1 + number of negative samples in
+ # evaluation) for each device. As such, the evaluation batch size must be a
+ # multiple of (number of replicas * (1 + number of negative samples)).
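+ # For example (hypothetical replica count): with 4 replicas and
+ # rconst.NUM_EVAL_NEGATIVES == 999, the evaluation batch size must be a
+ # multiple of 4 * 1000 = 4000.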
+ num_devices = strategy.num_replicas_in_sync if strategy else 1
+ if (params["eval_batch_size"] % (num_devices *
+ (1 + rconst.NUM_EVAL_NEGATIVES))):
+ raise ValueError("Evaluation batch size must be divisible by {} "
+ "times {}".format(num_devices,
+ (1 + rconst.NUM_EVAL_NEGATIVES)))
+
+ if params["train_dataset_path"]:
+ assert params["eval_dataset_path"]
+
+ train_dataset = create_dataset_from_tf_record_files(
+ params["train_dataset_path"],
+ input_meta_data["train_prebatch_size"],
+ params["batch_size"],
+ is_training=True,
+ rebatch=False)
+
+ # Re-batch evaluation dataset for TPU Pods.
+ # TODO (b/162341937) remove once it's fixed.
+ eval_rebatch = (params["use_tpu"] and strategy.num_replicas_in_sync > 8)
+ eval_dataset = create_dataset_from_tf_record_files(
+ params["eval_dataset_path"],
+ input_meta_data["eval_prebatch_size"],
+ params["eval_batch_size"],
+ is_training=False,
+ rebatch=eval_rebatch)
+
+ num_train_steps = int(input_meta_data["num_train_steps"])
+ num_eval_steps = int(input_meta_data["num_eval_steps"])
+ else:
+ if params["use_tpu"]:
+ raise ValueError("TPU training does not support data producer yet. "
+ "Use pre-processed data.")
+
+ assert producer
+ # Start retrieving data from producer.
+ train_dataset, eval_dataset = create_dataset_from_data_producer(
+ producer, params)
+ num_train_steps = producer.train_batches_per_epoch
+ num_eval_steps = producer.eval_batches_per_epoch
+
+ return train_dataset, eval_dataset, num_train_steps, num_eval_steps
diff --git a/models/official/recommendation/ncf_keras_main.py b/models/official/recommendation/ncf_keras_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..c850539d4bf24e159cbf04a2c029c1e2bf4d5c26
--- /dev/null
+++ b/models/official/recommendation/ncf_keras_main.py
@@ -0,0 +1,567 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""NCF framework to train and evaluate the NeuMF model.
+
+The NeuMF model assembles both MF and MLP models under the NCF framework. Check
+`neumf_model.py` for more details about the models.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+
+# pylint: disable=g-bad-import-order
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow.compat.v2 as tf
+# pylint: enable=g-bad-import-order
+
+from official.recommendation import constants as rconst
+from official.recommendation import movielens
+from official.recommendation import ncf_common
+from official.recommendation import ncf_input_pipeline
+from official.recommendation import neumf_model
+from official.utils.flags import core as flags_core
+from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
+from official.utils.misc import model_helpers
+
+
+FLAGS = flags.FLAGS
+
+
+def metric_fn(logits, dup_mask, match_mlperf):
+ dup_mask = tf.cast(dup_mask, tf.float32)
+ logits = tf.slice(logits, [0, 1], [-1, -1])
+ in_top_k, _, metric_weights, _ = neumf_model.compute_top_k_and_ndcg(
+ logits,
+ dup_mask,
+ match_mlperf)
+ metric_weights = tf.cast(metric_weights, tf.float32)
+ return in_top_k, metric_weights
+
+
+class MetricLayer(tf.keras.layers.Layer):
+ """Custom layer of metrics for NCF model."""
+
+ def __init__(self, match_mlperf):
+ super(MetricLayer, self).__init__()
+ self.match_mlperf = match_mlperf
+
+ def get_config(self):
+ return {"match_mlperf": self.match_mlperf}
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ return cls(**config)
+
+ def call(self, inputs, training=False):
+ logits, dup_mask = inputs
+
+ if training:
+ hr_sum = 0.0
+ hr_count = 0.0
+ else:
+ metric, metric_weights = metric_fn(logits, dup_mask, self.match_mlperf)
+ hr_sum = tf.reduce_sum(metric * metric_weights)
+ hr_count = tf.reduce_sum(metric_weights)
+
+ self.add_metric(hr_sum, name="hr_sum", aggregation="mean")
+ self.add_metric(hr_count, name="hr_count", aggregation="mean")
+ return logits
+
+
+class LossLayer(tf.keras.layers.Layer):
+ """Pass-through loss layer for NCF model."""
+
+ def __init__(self, loss_normalization_factor):
+ # The loss may overflow in float16, so we use float32 instead.
+ super(LossLayer, self).__init__(dtype="float32")
+ self.loss_normalization_factor = loss_normalization_factor
+ self.loss = tf.keras.losses.SparseCategoricalCrossentropy(
+ from_logits=True, reduction="sum")
+
+ def get_config(self):
+ return {"loss_normalization_factor": self.loss_normalization_factor}
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ return cls(**config)
+
+ def call(self, inputs):
+ logits, labels, valid_pt_mask_input = inputs
+ loss = self.loss(
+ y_true=labels, y_pred=logits, sample_weight=valid_pt_mask_input)
+ loss = loss * (1.0 / self.loss_normalization_factor)
+ self.add_loss(loss)
+ return logits
+
+
+class IncrementEpochCallback(tf.keras.callbacks.Callback):
+ """A callback to increase the requested epoch for the data producer.
+
+ This is needed because only a limited amount of data can be buffered, so a
+ moving window represents the buffer; this callback advances one of the
+ window's boundaries at the start of each epoch.
+ """
+
+ def __init__(self, producer):
+ self._producer = producer
+
+ def on_epoch_begin(self, epoch, logs=None):
+ self._producer.increment_request_epoch()
+
+
+class CustomEarlyStopping(tf.keras.callbacks.Callback):
+ """Stop training has reached a desired hit rate."""
+
+ def __init__(self, monitor, desired_value):
+ super(CustomEarlyStopping, self).__init__()
+
+ self.monitor = monitor
+ self.desired = desired_value
+ self.stopped_epoch = 0
+
+ def on_epoch_end(self, epoch, logs=None):
+ current = self.get_monitor_value(logs)
+ if current and current >= self.desired:
+ self.stopped_epoch = epoch
+ self.model.stop_training = True
+
+ def on_train_end(self, logs=None):
+ if self.stopped_epoch > 0:
+ print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
+
+ def get_monitor_value(self, logs):
+ logs = logs or {}
+ monitor_value = logs.get(self.monitor)
+ if monitor_value is None:
+ logging.warning("Early stopping conditioned on metric `%s` "
+ "which is not available. Available metrics are: %s",
+ self.monitor, ",".join(list(logs.keys())))
+ return monitor_value
+
+
+def _get_keras_model(params):
+ """Constructs and returns the model."""
+ batch_size = params["batch_size"]
+
+ user_input = tf.keras.layers.Input(
+ shape=(1,), name=movielens.USER_COLUMN, dtype=tf.int32)
+
+ item_input = tf.keras.layers.Input(
+ shape=(1,), name=movielens.ITEM_COLUMN, dtype=tf.int32)
+
+ valid_pt_mask_input = tf.keras.layers.Input(
+ shape=(1,), name=rconst.VALID_POINT_MASK, dtype=tf.bool)
+
+ dup_mask_input = tf.keras.layers.Input(
+ shape=(1,), name=rconst.DUPLICATE_MASK, dtype=tf.int32)
+
+ label_input = tf.keras.layers.Input(
+ shape=(1,), name=rconst.TRAIN_LABEL_KEY, dtype=tf.bool)
+
+ base_model = neumf_model.construct_model(user_input, item_input, params)
+
+ logits = base_model.output
+
+ zeros = tf.keras.layers.Lambda(
+ lambda x: x * 0)(logits)
+
+ softmax_logits = tf.keras.layers.concatenate(
+ [zeros, logits],
+ axis=-1)
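+
+  # Prepending a zeros column and taking a two-way softmax is equivalent to a
+  # sigmoid on the raw logit: softmax([0, x])[1] = exp(x) / (1 + exp(x)).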
+
+ # Custom training loop calculates loss and metric as a part of
+ # training/evaluation step function.
+ if not params["keras_use_ctl"]:
+ softmax_logits = MetricLayer(
+ params["match_mlperf"])([softmax_logits, dup_mask_input])
+ # TODO(b/134744680): Use model.add_loss() instead once the API is well
+ # supported.
+ softmax_logits = LossLayer(batch_size)(
+ [softmax_logits, label_input, valid_pt_mask_input])
+
+ keras_model = tf.keras.Model(
+ inputs={
+ movielens.USER_COLUMN: user_input,
+ movielens.ITEM_COLUMN: item_input,
+ rconst.VALID_POINT_MASK: valid_pt_mask_input,
+ rconst.DUPLICATE_MASK: dup_mask_input,
+ rconst.TRAIN_LABEL_KEY: label_input},
+ outputs=softmax_logits)
+
+ keras_model.summary()
+ return keras_model
+
+
+def run_ncf(_):
+ """Run NCF training and eval with Keras."""
+
+ keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)
+
+ if FLAGS.seed is not None:
+ print("Setting tf seed")
+ tf.random.set_seed(FLAGS.seed)
+
+ model_helpers.apply_clean(FLAGS)
+
+ if FLAGS.dtype == "fp16" and FLAGS.fp16_implementation == "keras":
+ policy = tf.keras.mixed_precision.experimental.Policy(
+ "mixed_float16",
+ loss_scale=flags_core.get_loss_scale(FLAGS, default_for_fp16="dynamic"))
+ tf.keras.mixed_precision.experimental.set_policy(policy)
+
+ strategy = distribution_utils.get_distribution_strategy(
+ distribution_strategy=FLAGS.distribution_strategy,
+ num_gpus=FLAGS.num_gpus,
+ tpu_address=FLAGS.tpu)
+
+ params = ncf_common.parse_flags(FLAGS)
+ params["distribute_strategy"] = strategy
+ params["use_tpu"] = (FLAGS.distribution_strategy == "tpu")
+
+ if params["use_tpu"] and not params["keras_use_ctl"]:
+ logging.error("Custom training loop must be used when using TPUStrategy.")
+ return
+
+ batch_size = params["batch_size"]
+ time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
+ callbacks = [time_callback]
+
+ producer, input_meta_data = None, None
+ generate_input_online = params["train_dataset_path"] is None
+
+ if generate_input_online:
+ # Start data producing thread.
+ num_users, num_items, _, _, producer = ncf_common.get_inputs(params)
+ producer.start()
+ per_epoch_callback = IncrementEpochCallback(producer)
+ callbacks.append(per_epoch_callback)
+ else:
+ assert params["eval_dataset_path"] and params["input_meta_data_path"]
+ with tf.io.gfile.GFile(params["input_meta_data_path"], "rb") as reader:
+ input_meta_data = json.loads(reader.read().decode("utf-8"))
+ num_users = input_meta_data["num_users"]
+ num_items = input_meta_data["num_items"]
+
+ params["num_users"], params["num_items"] = num_users, num_items
+
+ if FLAGS.early_stopping:
+ early_stopping_callback = CustomEarlyStopping(
+ "val_HR_METRIC", desired_value=FLAGS.hr_threshold)
+ callbacks.append(early_stopping_callback)
+
+ (train_input_dataset, eval_input_dataset,
+ num_train_steps, num_eval_steps) = \
+ (ncf_input_pipeline.create_ncf_input_data(
+ params, producer, input_meta_data, strategy))
+ steps_per_epoch = None if generate_input_online else num_train_steps
+
+ with distribution_utils.get_strategy_scope(strategy):
+ keras_model = _get_keras_model(params)
+ optimizer = tf.keras.optimizers.Adam(
+ learning_rate=params["learning_rate"],
+ beta_1=params["beta1"],
+ beta_2=params["beta2"],
+ epsilon=params["epsilon"])
+ if FLAGS.fp16_implementation == "graph_rewrite":
+ optimizer = \
+ tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
+ optimizer,
+ loss_scale=flags_core.get_loss_scale(FLAGS,
+ default_for_fp16="dynamic"))
+ elif FLAGS.dtype == "fp16" and params["keras_use_ctl"]:
+      # When keras_use_ctl is False, Model.fit() applies loss scaling
+      # automatically, so we do not need to create a LossScaleOptimizer there.
+ optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
+ optimizer,
+ tf.keras.mixed_precision.experimental.global_policy().loss_scale)
+
+ if params["keras_use_ctl"]:
+ train_loss, eval_results = run_ncf_custom_training(
+ params,
+ strategy,
+ keras_model,
+ optimizer,
+ callbacks,
+ train_input_dataset,
+ eval_input_dataset,
+ num_train_steps,
+ num_eval_steps,
+ generate_input_online=generate_input_online)
+ else:
+ keras_model.compile(optimizer=optimizer, run_eagerly=FLAGS.run_eagerly)
+
+ if not FLAGS.ml_perf:
+ # Create Tensorboard summary and checkpoint callbacks.
+ summary_dir = os.path.join(FLAGS.model_dir, "summaries")
+ summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
+ checkpoint_path = os.path.join(FLAGS.model_dir, "checkpoint")
+ checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
+ checkpoint_path, save_weights_only=True)
+
+ callbacks += [summary_callback, checkpoint_callback]
+
+ history = keras_model.fit(
+ train_input_dataset,
+ epochs=FLAGS.train_epochs,
+ steps_per_epoch=steps_per_epoch,
+ callbacks=callbacks,
+ validation_data=eval_input_dataset,
+ validation_steps=num_eval_steps,
+ verbose=2)
+
+ logging.info("Training done. Start evaluating")
+
+ eval_loss_and_metrics = keras_model.evaluate(
+ eval_input_dataset, steps=num_eval_steps, verbose=2)
+
+ logging.info("Keras evaluation is done.")
+
+ # Keras evaluate() API returns scalar loss and metric values from
+ # evaluation as a list. Here, the returned list would contain
+ # [evaluation loss, hr sum, hr count].
+ eval_hit_rate = eval_loss_and_metrics[1] / eval_loss_and_metrics[2]
+
+ # Format evaluation result into [eval loss, eval hit accuracy].
+ eval_results = [eval_loss_and_metrics[0], eval_hit_rate]
+
+ if history and history.history:
+ train_history = history.history
+ train_loss = train_history["loss"][-1]
+
+ stats = build_stats(train_loss, eval_results, time_callback)
+ return stats
+
+
+def run_ncf_custom_training(params,
+ strategy,
+ keras_model,
+ optimizer,
+ callbacks,
+ train_input_dataset,
+ eval_input_dataset,
+ num_train_steps,
+ num_eval_steps,
+ generate_input_online=True):
+ """Runs custom training loop.
+
+ Args:
+ params: Dictionary containing training parameters.
+ strategy: Distribution strategy to be used for distributed training.
+ keras_model: Model used for training.
+ optimizer: Optimizer used for training.
+ callbacks: Callbacks to be invoked between batches/epochs.
+ train_input_dataset: tf.data.Dataset used for training.
+ eval_input_dataset: tf.data.Dataset used for evaluation.
+ num_train_steps: Total number of steps to run for training.
+ num_eval_steps: Total number of steps to run for evaluation.
+    generate_input_online: Whether input data was generated by the data
+      producer. When data is generated by the data producer, the train dataset
+      must be re-initialized after every epoch.
+
+ Returns:
+ A tuple of train loss and a list of training and evaluation results.
+ """
+ loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
+ reduction="sum", from_logits=True)
+ train_input_iterator = iter(
+ strategy.experimental_distribute_dataset(train_input_dataset))
+
+ def train_step(train_iterator):
+ """Called once per step to train the model."""
+
+ def step_fn(features):
+ """Computes loss and applied gradient per replica."""
+ with tf.GradientTape() as tape:
+ softmax_logits = keras_model(features)
+ # The loss can overflow in float16, so we cast to float32.
+ softmax_logits = tf.cast(softmax_logits, "float32")
+ labels = features[rconst.TRAIN_LABEL_KEY]
+ loss = loss_object(
+ labels,
+ softmax_logits,
+ sample_weight=features[rconst.VALID_POINT_MASK])
+ loss *= (1.0 / params["batch_size"])
+ if FLAGS.dtype == "fp16":
+ loss = optimizer.get_scaled_loss(loss)
+
+ grads = tape.gradient(loss, keras_model.trainable_variables)
+ if FLAGS.dtype == "fp16":
+ grads = optimizer.get_unscaled_gradients(grads)
+ # Converting gradients to dense form helps in perf on GPU for NCF
+ grads = neumf_model.sparse_to_dense_grads(
+ list(zip(grads, keras_model.trainable_variables)))
+ optimizer.apply_gradients(grads)
+ return loss
+
+ per_replica_losses = strategy.run(
+ step_fn, args=(next(train_iterator),))
+ mean_loss = strategy.reduce(
+ tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
+ return mean_loss
+
+ def eval_step(eval_iterator):
+ """Called once per eval step to compute eval metrics."""
+
+ def step_fn(features):
+ """Computes eval metrics per replica."""
+ softmax_logits = keras_model(features)
+ in_top_k, metric_weights = metric_fn(softmax_logits,
+ features[rconst.DUPLICATE_MASK],
+ params["match_mlperf"])
+ hr_sum = tf.reduce_sum(in_top_k * metric_weights)
+ hr_count = tf.reduce_sum(metric_weights)
+ return hr_sum, hr_count
+
+ per_replica_hr_sum, per_replica_hr_count = (
+ strategy.run(
+ step_fn, args=(next(eval_iterator),)))
+ hr_sum = strategy.reduce(
+ tf.distribute.ReduceOp.SUM, per_replica_hr_sum, axis=None)
+ hr_count = strategy.reduce(
+ tf.distribute.ReduceOp.SUM, per_replica_hr_count, axis=None)
+ return hr_sum, hr_count
+
+ if not FLAGS.run_eagerly:
+ train_step = tf.function(train_step)
+ eval_step = tf.function(eval_step)
+
+ for callback in callbacks:
+ callback.on_train_begin()
+
+ # Not writing tensorboard summaries if running in MLPerf.
+ if FLAGS.ml_perf:
+ eval_summary_writer, train_summary_writer = None, None
+ else:
+ summary_dir = os.path.join(FLAGS.model_dir, "summaries")
+ eval_summary_writer = tf.summary.create_file_writer(
+ os.path.join(summary_dir, "eval"))
+ train_summary_writer = tf.summary.create_file_writer(
+ os.path.join(summary_dir, "train"))
+
+ train_loss = 0
+ for epoch in range(FLAGS.train_epochs):
+ for cb in callbacks:
+ cb.on_epoch_begin(epoch)
+
+    # Because the NCF dataset is sampled with randomness, not repeating
+    # data elements in each epoch has a significant impact on
+    # convergence. For that reason, the offline-generated TFRecord files
+    # contain data for every epoch, and we do not need to re-initialize
+    # the dataset when reading from TFRecord files.
+ if generate_input_online:
+ train_input_iterator = iter(
+ strategy.experimental_distribute_dataset(train_input_dataset))
+
+ train_loss = 0
+ for step in range(num_train_steps):
+ current_step = step + epoch * num_train_steps
+ for c in callbacks:
+ c.on_batch_begin(current_step)
+
+ train_loss += train_step(train_input_iterator)
+
+ # Write train loss once in every 1000 steps.
+ if train_summary_writer and step % 1000 == 0:
+ with train_summary_writer.as_default():
+ tf.summary.scalar("training_loss", train_loss/(step + 1),
+ step=current_step)
+
+ for c in callbacks:
+ c.on_batch_end(current_step)
+
+ train_loss /= num_train_steps
+ logging.info("Done training epoch %s, epoch loss=%.3f", epoch + 1,
+ train_loss)
+
+ eval_input_iterator = iter(
+ strategy.experimental_distribute_dataset(eval_input_dataset))
+
+ hr_sum = 0.0
+ hr_count = 0.0
+ for _ in range(num_eval_steps):
+ step_hr_sum, step_hr_count = eval_step(eval_input_iterator)
+ hr_sum += step_hr_sum
+ hr_count += step_hr_count
+
+ logging.info("Done eval epoch %s, hit_rate=%.3f", epoch + 1,
+ hr_sum / hr_count)
+ if eval_summary_writer:
+ with eval_summary_writer.as_default():
+ tf.summary.scalar("hit_rate", hr_sum / hr_count, step=current_step)
+
+ if (FLAGS.early_stopping and
+ float(hr_sum / hr_count) > params["hr_threshold"]):
+ break
+
+ for c in callbacks:
+ c.on_train_end()
+
+ # Saving the model at the end of training.
+ if not FLAGS.ml_perf:
+ checkpoint = tf.train.Checkpoint(model=keras_model, optimizer=optimizer)
+ checkpoint_path = os.path.join(FLAGS.model_dir, "ctl_checkpoint")
+ checkpoint.save(checkpoint_path)
+ logging.info("Saving model as TF checkpoint: %s", checkpoint_path)
+
+ return train_loss, [None, hr_sum / hr_count]
+
+
+def build_stats(loss, eval_result, time_callback):
+ """Normalizes and returns dictionary of stats.
+
+ Args:
+ loss: The final loss at training time.
+ eval_result: Output of the eval step. Assumes first value is eval_loss and
+      second value is the hit rate (accuracy_top_1).
+ time_callback: Time tracking callback likely used during keras.fit.
+
+ Returns:
+ Dictionary of normalized results.
+ """
+ stats = {}
+ if loss:
+ stats["loss"] = loss
+
+ if eval_result:
+ stats["eval_loss"] = eval_result[0]
+ stats["eval_hit_rate"] = eval_result[1]
+
+ if time_callback:
+ timestamp_log = time_callback.timestamp_log
+ stats["step_timestamp_log"] = timestamp_log
+ stats["train_finish_time"] = time_callback.train_finish_time
+ if len(timestamp_log) > 1:
+ stats["avg_exp_per_second"] = (
+ time_callback.batch_size * time_callback.log_steps *
+ (len(time_callback.timestamp_log)-1) /
+ (timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
+
+ return stats
+
+
+def main(_):
+ logging.info("Result is %s", run_ncf(FLAGS))
+
+
+if __name__ == "__main__":
+ ncf_common.define_ncf_flags()
+ app.run(main)
diff --git a/models/official/recommendation/ncf_test.py b/models/official/recommendation/ncf_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..5103283e0aa617b0042ca75f5d2e9572cecb1b68
--- /dev/null
+++ b/models/official/recommendation/ncf_test.py
@@ -0,0 +1,111 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests NCF."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+import unittest.mock  # needed for unittest.mock.patch.object below
+
+import tensorflow as tf
+from tensorflow.python.eager import context # pylint: disable=ungrouped-imports
+from official.recommendation import constants as rconst
+from official.recommendation import ncf_common
+from official.recommendation import ncf_keras_main
+from official.utils.testing import integration
+
+NUM_TRAIN_NEG = 4
+
+
+class NcfTest(tf.test.TestCase):
+
+ @classmethod
+ def setUpClass(cls): # pylint: disable=invalid-name
+ super(NcfTest, cls).setUpClass()
+ ncf_common.define_ncf_flags()
+
+ def setUp(self):
+ self.top_k_old = rconst.TOP_K
+ self.num_eval_negatives_old = rconst.NUM_EVAL_NEGATIVES
+ rconst.NUM_EVAL_NEGATIVES = 2
+
+ def tearDown(self):
+ rconst.NUM_EVAL_NEGATIVES = self.num_eval_negatives_old
+ rconst.TOP_K = self.top_k_old
+
+ _BASE_END_TO_END_FLAGS = ['-batch_size', '1044', '-train_epochs', '1']
+
+ @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
+ def test_end_to_end_keras_no_dist_strat(self):
+ integration.run_synthetic(
+ ncf_keras_main.main, tmp_root=self.get_temp_dir(),
+ extra_flags=self._BASE_END_TO_END_FLAGS +
+ ['-distribution_strategy', 'off'])
+
+ @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
+ def test_end_to_end_keras_dist_strat(self):
+ integration.run_synthetic(
+ ncf_keras_main.main, tmp_root=self.get_temp_dir(),
+ extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0'])
+
+ @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
+ def test_end_to_end_keras_dist_strat_ctl(self):
+ flags = (self._BASE_END_TO_END_FLAGS +
+ ['-num_gpus', '0'] +
+ ['-keras_use_ctl', 'True'])
+ integration.run_synthetic(
+ ncf_keras_main.main, tmp_root=self.get_temp_dir(),
+ extra_flags=flags)
+
+ @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
+ def test_end_to_end_keras_1_gpu_dist_strat_fp16(self):
+ if context.num_gpus() < 1:
+ self.skipTest(
+ "{} GPUs are not available for this test. {} GPUs are available".
+ format(1, context.num_gpus()))
+
+ integration.run_synthetic(
+ ncf_keras_main.main, tmp_root=self.get_temp_dir(),
+ extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '1',
+ '--dtype', 'fp16'])
+
+ @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
+ def test_end_to_end_keras_1_gpu_dist_strat_ctl_fp16(self):
+ if context.num_gpus() < 1:
+ self.skipTest(
+ '{} GPUs are not available for this test. {} GPUs are available'.
+ format(1, context.num_gpus()))
+
+ integration.run_synthetic(
+ ncf_keras_main.main, tmp_root=self.get_temp_dir(),
+ extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '1',
+ '--dtype', 'fp16',
+ '--keras_use_ctl'])
+
+ @unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
+ def test_end_to_end_keras_2_gpu_fp16(self):
+ if context.num_gpus() < 2:
+ self.skipTest(
+ "{} GPUs are not available for this test. {} GPUs are available".
+ format(2, context.num_gpus()))
+
+ integration.run_synthetic(
+ ncf_keras_main.main, tmp_root=self.get_temp_dir(),
+ extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '2',
+ '--dtype', 'fp16'])
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/recommendation/neumf_model.py b/models/official/recommendation/neumf_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..48b09293af065a19db2dbfb1d44023439c2b9765
--- /dev/null
+++ b/models/official/recommendation/neumf_model.py
@@ -0,0 +1,431 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines NeuMF model for NCF framework.
+
+Some abbreviations used in the code base:
+NeuMF: Neural Matrix Factorization
+NCF: Neural Collaborative Filtering
+GMF: Generalized Matrix Factorization
+MLP: Multi-Layer Perceptron
+
+GMF applies a linear kernel to model the latent feature interactions, and MLP
+uses a nonlinear kernel to learn the interaction function from data. NeuMF model
+is a fused model of GMF and MLP to better model the complex user-item
+interactions, and unifies the strengths of linearity of MF and non-linearity of
+MLP for modeling the user-item latent structures.
+
+The NeuMF model allows GMF and MLP to learn separate embeddings, and combines
+the two models by concatenating their last hidden layers.
+"""
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import sys
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+import tensorflow as tf
+from typing import Any, Dict, Text
+
+from official.recommendation import constants as rconst
+from official.recommendation import movielens
+from official.recommendation import ncf_common
+from official.recommendation import stat_utils
+
+
+def sparse_to_dense_grads(grads_and_vars):
+ """Convert sparse gradients to dense gradients.
+
+ All sparse gradients, which are represented as instances of tf.IndexedSlices,
+  are converted to dense Tensors. Dense gradients, which are represented as
+ Tensors, are unchanged.
+
+ The purpose of this conversion is that for small embeddings, which are used by
+ this model, applying dense gradients with the AdamOptimizer is faster than
+ applying sparse gradients.
+
+  Args:
+ grads_and_vars: A list of (gradient, variable) tuples. Each gradient can
+ be a Tensor or an IndexedSlices. Tensors are unchanged, and IndexedSlices
+ are converted to dense Tensors.
+ Returns:
+ The same list of (gradient, variable) as `grads_and_vars`, except each
+ IndexedSlices gradient is converted to a Tensor.
+ """
+
+ # Calling convert_to_tensor changes IndexedSlices into Tensors, and leaves
+ # Tensors unchanged.
+ return [(tf.convert_to_tensor(g), v) for g, v in grads_and_vars]
+
+
+def neumf_model_fn(features, labels, mode, params):
+ """Model Function for NeuMF estimator."""
+ if params.get("use_seed"):
+ tf.set_random_seed(stat_utils.random_int32())
+
+ users = features[movielens.USER_COLUMN]
+ items = features[movielens.ITEM_COLUMN]
+
+ user_input = tf.keras.layers.Input(tensor=users)
+ item_input = tf.keras.layers.Input(tensor=items)
+ logits = construct_model(user_input, item_input, params).output
+
+ # Softmax with the first column of zeros is equivalent to sigmoid.
+ softmax_logits = ncf_common.convert_to_softmax_logits(logits)
+
+ if mode == tf.estimator.ModeKeys.EVAL:
+ duplicate_mask = tf.cast(features[rconst.DUPLICATE_MASK], tf.float32)
+ return _get_estimator_spec_with_metrics(
+ logits,
+ softmax_logits,
+ duplicate_mask,
+ params["num_neg"],
+ params["match_mlperf"],
+ use_tpu_spec=params["use_tpu"])
+
+ elif mode == tf.estimator.ModeKeys.TRAIN:
+ labels = tf.cast(labels, tf.int32)
+ valid_pt_mask = features[rconst.VALID_POINT_MASK]
+
+ optimizer = tf.compat.v1.train.AdamOptimizer(
+ learning_rate=params["learning_rate"],
+ beta1=params["beta1"],
+ beta2=params["beta2"],
+ epsilon=params["epsilon"])
+ if params["use_tpu"]:
+ optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)
+
+ loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(
+ labels=labels,
+ logits=softmax_logits,
+ weights=tf.cast(valid_pt_mask, tf.float32)
+ )
+
+ tf.identity(loss, name="cross_entropy")
+
+ global_step = tf.compat.v1.train.get_global_step()
+ tvars = tf.compat.v1.trainable_variables()
+ gradients = optimizer.compute_gradients(
+ loss, tvars, colocate_gradients_with_ops=True)
+ gradients = sparse_to_dense_grads(gradients)
+ minimize_op = optimizer.apply_gradients(
+ gradients, global_step=global_step, name="train")
+ update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
+ train_op = tf.group(minimize_op, update_ops)
+
+ return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
+
+ else:
+ raise NotImplementedError
+
+
+def _strip_first_and_last_dimension(x, batch_size):
+ return tf.reshape(x[0, :], (batch_size,))
+
+
+def construct_model(user_input: tf.Tensor, item_input: tf.Tensor,
+ params: Dict[Text, Any]) -> tf.keras.Model:
+ """Initialize NeuMF model.
+
+ Args:
+ user_input: keras input layer for users
+ item_input: keras input layer for items
+ params: Dict of hyperparameters.
+
+ Raises:
+    ValueError: if the size of the first model layer is not even.
+ Returns:
+ model: a keras Model for computing the logits
+ """
+ num_users = params["num_users"]
+ num_items = params["num_items"]
+
+ model_layers = params["model_layers"]
+
+ mf_regularization = params["mf_regularization"]
+ mlp_reg_layers = params["mlp_reg_layers"]
+
+ mf_dim = params["mf_dim"]
+
+ if model_layers[0] % 2 != 0:
+ raise ValueError("The first layer size should be multiple of 2!")
+
+ # Initializer for embedding layers
+ embedding_initializer = "glorot_uniform"
+
+ def mf_slice_fn(x):
+ x = tf.squeeze(x, [1])
+ return x[:, :mf_dim]
+
+ def mlp_slice_fn(x):
+ x = tf.squeeze(x, [1])
+ return x[:, mf_dim:]
+
+  # It turns out to be significantly more efficient to store the MF and MLP
+ # embedding portions in the same table, and then slice as needed.
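+  # For example, with mf_dim=64 and model_layers[0]=256 each embedding row has
+  # 64 + 256 // 2 = 192 dimensions: the first 64 feed the GMF branch
+  # (mf_slice_fn) and the remaining 128 feed the MLP branch (mlp_slice_fn), so
+  # the concatenated user/item MLP input is 256 wide, matching the first MLP
+  # layer. This is also why the first layer size must be even.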
+ embedding_user = tf.keras.layers.Embedding(
+ num_users,
+ mf_dim + model_layers[0] // 2,
+ embeddings_initializer=embedding_initializer,
+ embeddings_regularizer=tf.keras.regularizers.l2(mf_regularization),
+ input_length=1,
+ name="embedding_user")(
+ user_input)
+
+ embedding_item = tf.keras.layers.Embedding(
+ num_items,
+ mf_dim + model_layers[0] // 2,
+ embeddings_initializer=embedding_initializer,
+ embeddings_regularizer=tf.keras.regularizers.l2(mf_regularization),
+ input_length=1,
+ name="embedding_item")(
+ item_input)
+
+ # GMF part
+ mf_user_latent = tf.keras.layers.Lambda(
+ mf_slice_fn, name="embedding_user_mf")(embedding_user)
+ mf_item_latent = tf.keras.layers.Lambda(
+ mf_slice_fn, name="embedding_item_mf")(embedding_item)
+
+ # MLP part
+ mlp_user_latent = tf.keras.layers.Lambda(
+ mlp_slice_fn, name="embedding_user_mlp")(embedding_user)
+ mlp_item_latent = tf.keras.layers.Lambda(
+ mlp_slice_fn, name="embedding_item_mlp")(embedding_item)
+
+ # Element-wise multiply
+ mf_vector = tf.keras.layers.multiply([mf_user_latent, mf_item_latent])
+
+ # Concatenation of two latent features
+ mlp_vector = tf.keras.layers.concatenate([mlp_user_latent, mlp_item_latent])
+
+ num_layer = len(model_layers) # Number of layers in the MLP
+ for layer in xrange(1, num_layer):
+ model_layer = tf.keras.layers.Dense(
+ model_layers[layer],
+ kernel_regularizer=tf.keras.regularizers.l2(mlp_reg_layers[layer]),
+ activation="relu")
+ mlp_vector = model_layer(mlp_vector)
+
+ # Concatenate GMF and MLP parts
+ predict_vector = tf.keras.layers.concatenate([mf_vector, mlp_vector])
+
+ # Final prediction layer
+ logits = tf.keras.layers.Dense(
+ 1, activation=None, kernel_initializer="lecun_uniform",
+ name=movielens.RATING_COLUMN)(predict_vector)
+
+ # Print model topology.
+ model = tf.keras.models.Model([user_input, item_input], logits)
+ model.summary()
+ sys.stdout.flush()
+
+ return model
+
+
+def _get_estimator_spec_with_metrics(logits: tf.Tensor,
+ softmax_logits: tf.Tensor,
+ duplicate_mask: tf.Tensor,
+ num_training_neg: int,
+ match_mlperf: bool = False,
+ use_tpu_spec: bool = False):
+ """Returns a EstimatorSpec that includes the metrics."""
+ cross_entropy, \
+ metric_fn, \
+ in_top_k, \
+ ndcg, \
+ metric_weights = compute_eval_loss_and_metrics_helper(
+ logits,
+ softmax_logits,
+ duplicate_mask,
+ num_training_neg,
+ match_mlperf)
+
+ if use_tpu_spec:
+ return tf.estimator.tpu.TPUEstimatorSpec(
+ mode=tf.estimator.ModeKeys.EVAL,
+ loss=cross_entropy,
+ eval_metrics=(metric_fn, [in_top_k, ndcg, metric_weights]))
+
+ return tf.estimator.EstimatorSpec(
+ mode=tf.estimator.ModeKeys.EVAL,
+ loss=cross_entropy,
+ eval_metric_ops=metric_fn(in_top_k, ndcg, metric_weights)
+ )
+
+
+def compute_eval_loss_and_metrics_helper(logits: tf.Tensor,
+ softmax_logits: tf.Tensor,
+ duplicate_mask: tf.Tensor,
+ num_training_neg: int,
+ match_mlperf: bool = False):
+ """Model evaluation with HR and NDCG metrics.
+
+  The evaluation protocol ranks the test item the user interacted with (the
+  ground-truth item) among 999 randomly chosen items that the user has not
+  interacted with.
+ The performance of the ranked list is judged by Hit Ratio (HR) and Normalized
+ Discounted Cumulative Gain (NDCG).
+
+ For evaluation, the ranked list is truncated at 10 for both metrics. As such,
+ the HR intuitively measures whether the test item is present on the top-10
+ list, and the NDCG accounts for the position of the hit by assigning higher
+ scores to hits at top ranks. Both metrics are calculated for each test user,
+ and the average scores are reported.
+
+ If `match_mlperf` is True, then the HR and NDCG computations are done in a
+ slightly unusual way to match the MLPerf reference implementation.
+ Specifically, if the evaluation negatives contain duplicate items, it will be
+ treated as if the item only appeared once. Effectively, for duplicate items in
+ a row, the predicted score for all but one of the items will be set to
+  -infinity.
+
+  For example, suppose we have the following inputs:
+ logits_by_user: [[ 2, 3, 3],
+ [ 5, 4, 4]]
+
+ items_by_user: [[10, 20, 20],
+ [30, 40, 40]]
+
+ # Note: items_by_user is not explicitly present. Instead the relevant \
+ information is contained within `duplicate_mask`
+
+ top_k: 2
+
+ Then with match_mlperf=True, the HR would be 2/2 = 1.0. With
+ match_mlperf=False, the HR would be 1/2 = 0.5. This is because each user has
+ predicted scores for only 2 unique items: 10 and 20 for the first user, and 30
+ and 40 for the second. Therefore, with match_mlperf=True, it's guaranteed the
+ first item's score is in the top 2. With match_mlperf=False, this function
+  would determine that the first user's first item is not in the top 2, because
+  item 20 has a higher score and occurs twice.
+
+ Args:
+ logits: A tensor containing the predicted logits for each user. The shape of
+      logits is (num_users_per_batch * (1 + NUM_EVAL_NEGATIVES),). Logits for a
+ user are grouped, and the last element of the group is the true element.
+    softmax_logits: The same tensor, but with a column of zeros prepended.
+ duplicate_mask: A vector with the same shape as logits, with a value of 1 if
+ the item corresponding to the logit at that position has already appeared
+ for that user.
+ num_training_neg: The number of negatives per positive during training.
+ match_mlperf: Use the MLPerf reference convention for computing rank.
+
+ Returns:
+ cross_entropy: the loss
+ metric_fn: the metrics function
+ in_top_k: hit rate metric
+ ndcg: ndcg metric
+ metric_weights: metric weights
+ """
+ in_top_k, ndcg, metric_weights, logits_by_user = compute_top_k_and_ndcg(
+ logits, duplicate_mask, match_mlperf)
+
+ # Examples are provided by the eval Dataset in a structured format, so eval
+ # labels can be reconstructed on the fly.
+ eval_labels = tf.reshape(shape=(-1,), tensor=tf.one_hot(
+ tf.zeros(shape=(logits_by_user.shape[0],), dtype=tf.int32) +
+ rconst.NUM_EVAL_NEGATIVES, logits_by_user.shape[1], dtype=tf.int32))
+
+ eval_labels_float = tf.cast(eval_labels, tf.float32)
+
+ # During evaluation, the ratio of negatives to positives is much higher
+ # than during training. (Typically 999 to 1 vs. 4 to 1) By adjusting the
+ # weights for the negative examples we compute a loss which is consistent with
+ # the training data. (And provides apples-to-apples comparison)
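+  # For example, with num_training_neg=4 and NUM_EVAL_NEGATIVES=999 the scaling
+  # below gives the positive a weight of 1000 / 5 = 200 and each negative a
+  # weight of (4 / 999) * 200 ~= 0.8, so the total negative weight per user
+  # (~800) is four times the positive weight, matching the 4:1 training ratio,
+  # while the total weight per user stays 1000, the size of one eval group.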
+ negative_scale_factor = num_training_neg / rconst.NUM_EVAL_NEGATIVES
+ example_weights = (
+ (eval_labels_float + (1 - eval_labels_float) * negative_scale_factor) *
+ (1 + rconst.NUM_EVAL_NEGATIVES) / (1 + num_training_neg))
+
+ # Tile metric weights back to logit dimensions
+ expanded_metric_weights = tf.reshape(tf.tile(
+ metric_weights[:, tf.newaxis], (1, rconst.NUM_EVAL_NEGATIVES + 1)), (-1,))
+
+ # ignore padded examples
+ example_weights *= tf.cast(expanded_metric_weights, tf.float32)
+
+ cross_entropy = tf.compat.v1.losses.sparse_softmax_cross_entropy(
+ logits=softmax_logits, labels=eval_labels, weights=example_weights)
+
+ def metric_fn(top_k_tensor, ndcg_tensor, weight_tensor):
+ return {
+ rconst.HR_KEY: tf.compat.v1.metrics.mean(top_k_tensor,
+ weights=weight_tensor,
+ name=rconst.HR_METRIC_NAME),
+ rconst.NDCG_KEY: tf.compat.v1.metrics.mean(ndcg_tensor,
+ weights=weight_tensor,
+ name=rconst.NDCG_METRIC_NAME)
+ }
+
+ return cross_entropy, metric_fn, in_top_k, ndcg, metric_weights
+
+
+def compute_top_k_and_ndcg(logits: tf.Tensor,
+ duplicate_mask: tf.Tensor,
+ match_mlperf: bool = False):
+ """Compute inputs of metric calculation.
+
+ Args:
+ logits: A tensor containing the predicted logits for each user. The shape of
+      logits is (num_users_per_batch * (1 + NUM_EVAL_NEGATIVES),). Logits for a
+      user are grouped, and the last element of the group is the true element.
+ duplicate_mask: A vector with the same shape as logits, with a value of 1 if
+ the item corresponding to the logit at that position has already appeared
+ for that user.
+ match_mlperf: Use the MLPerf reference convention for computing rank.
+
+ Returns:
+    in_top_k, ndcg and weights, all of which have size (num_users_in_batch,),
+    and logits_by_user, which has size
+    (num_users_in_batch, rconst.NUM_EVAL_NEGATIVES + 1).
+ """
+ logits_by_user = tf.reshape(logits, (-1, rconst.NUM_EVAL_NEGATIVES + 1))
+ duplicate_mask_by_user = tf.cast(
+ tf.reshape(duplicate_mask, (-1, rconst.NUM_EVAL_NEGATIVES + 1)),
+ logits_by_user.dtype)
+
+ if match_mlperf:
+ # Set duplicate logits to the min value for that dtype. The MLPerf
+ # reference dedupes during evaluation.
+ logits_by_user *= (1 - duplicate_mask_by_user)
+ logits_by_user += duplicate_mask_by_user * logits_by_user.dtype.min
+
+ # Determine the location of the first element in each row after the elements
+ # are sorted.
+ sort_indices = tf.argsort(
+ logits_by_user, axis=1, direction="DESCENDING")
+
+ # Use matrix multiplication to extract the position of the true item from the
+ # tensor of sorted indices. This approach is chosen because both GPUs and TPUs
+ # perform matrix multiplications very quickly. This is similar to np.argwhere.
+ # However this is a special case because the target will only appear in
+ # sort_indices once.
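+  #
+  # Worked example (assuming NUM_EVAL_NEGATIVES=2, so each user has 3 logits
+  # and the true item sits in the last column): for a row [2, 3, 3],
+  #   sort_indices     = [1, 2, 0]   (column indices sorted by logit, descending)
+  #   one_hot_position = [0, 1, 0]   (where sort_indices == 2)
+  #   position_vector  = 0*0 + 1*1 + 0*2 = 1, i.e. the true item ranks second,
+  # so in_top_k is 1.0 whenever TOP_K >= 2 and the NDCG contribution is
+  # log(2) / log(1 + 2) ~= 0.63.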
+ one_hot_position = tf.cast(tf.equal(sort_indices, rconst.NUM_EVAL_NEGATIVES),
+ tf.int32)
+ sparse_positions = tf.multiply(
+ one_hot_position, tf.range(logits_by_user.shape[1])[tf.newaxis, :])
+ position_vector = tf.reduce_sum(sparse_positions, axis=1)
+
+ in_top_k = tf.cast(tf.less(position_vector, rconst.TOP_K), tf.float32)
+ ndcg = tf.math.log(2.) / tf.math.log(
+ tf.cast(position_vector, tf.float32) + 2)
+ ndcg *= in_top_k
+
+ # If a row is a padded row, all but the first element will be a duplicate.
+ metric_weights = tf.not_equal(tf.reduce_sum(duplicate_mask_by_user, axis=1),
+ rconst.NUM_EVAL_NEGATIVES)
+
+ return in_top_k, ndcg, metric_weights, logits_by_user
diff --git a/models/official/recommendation/popen_helper.py b/models/official/recommendation/popen_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcdca4ced8e0b45294023c4675d16efd875694b7
--- /dev/null
+++ b/models/official/recommendation/popen_helper.py
@@ -0,0 +1,64 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Helper file for running the async data generation process in OSS."""
+
+import contextlib
+import multiprocessing
+import multiprocessing.pool
+
+
+def get_forkpool(num_workers, init_worker=None, closing=True):
+ pool = multiprocessing.Pool(processes=num_workers, initializer=init_worker)
+ return contextlib.closing(pool) if closing else pool
+
+
+def get_threadpool(num_workers, init_worker=None, closing=True):
+ pool = multiprocessing.pool.ThreadPool(processes=num_workers,
+ initializer=init_worker)
+ return contextlib.closing(pool) if closing else pool
+
+
+class FauxPool(object):
+ """Mimic a pool using for loops.
+
+ This class is used in place of proper pools when true determinism is desired
+ for testing or debugging.
+ """
+ def __init__(self, *args, **kwargs):
+ pass
+
+ def map(self, func, iterable, chunksize=None):
+ return [func(i) for i in iterable]
+
+ def imap(self, func, iterable, chunksize=1):
+ for i in iterable:
+ yield func(i)
+
+ def close(self):
+ pass
+
+ def terminate(self):
+ pass
+
+ def join(self):
+ pass
+
+def get_fauxpool(num_workers, init_worker=None, closing=True):
+ pool = FauxPool(processes=num_workers, initializer=init_worker)
+ return contextlib.closing(pool) if closing else pool
+
+
+def worker_job():
+ return "worker"
diff --git a/models/official/recommendation/run.sh b/models/official/recommendation/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b8e1143a38ba0cc26e97be6bad20a5ae6c13be65
--- /dev/null
+++ b/models/official/recommendation/run.sh
@@ -0,0 +1,101 @@
+#!/bin/bash
+set -e
+
+if [ `id -u` != 0 ]; then
+ echo "Calling sudo to gain root for this shell. (Needed to clear caches.)"
+ sudo echo "Success"
+fi
+
+SCRIPT_DIR=`dirname "$BASH_SOURCE"`
+export PYTHONPATH="${SCRIPT_DIR}/../../"
+MAIN_SCRIPT="ncf_estimator_main.py"
+
+DATASET="ml-20m"
+
+BUCKET=${BUCKET:-""}
+ROOT_DIR="${BUCKET:-/tmp}/MLPerf_NCF"
+echo "Root directory: ${ROOT_DIR}"
+
+if [[ -z ${BUCKET} ]]; then
+ LOCAL_ROOT=${ROOT_DIR}
+else
+ LOCAL_ROOT="/tmp/MLPerf_NCF"
+ mkdir -p ${LOCAL_ROOT}
+ echo "Local root (for files which cannot use GCS): ${LOCAL_ROOT}"
+fi
+
+DATE=$(date '+%Y-%m-%d_%H:%M:%S')
+TEST_DIR="${ROOT_DIR}/${DATE}"
+LOCAL_TEST_DIR="${LOCAL_ROOT}/${DATE}"
+mkdir -p ${LOCAL_TEST_DIR}
+
+TPU=${TPU:-""}
+if [[ -z ${TPU} ]]; then
+ DEVICE_FLAG="--num_gpus -1" # --use_xla_for_gpu"
+else
+ DEVICE_FLAG="--tpu ${TPU} --num_gpus 0"
+fi
+
+DATA_DIR="${ROOT_DIR}/movielens_data"
+python "${SCRIPT_DIR}/movielens.py" --data_dir ${DATA_DIR} --dataset ${DATASET}
+
+if [ "$1" == "keras" ]
+then
+ MAIN_SCRIPT="ncf_keras_main.py"
+ BATCH_SIZE=99000
+ DEVICE_FLAG="--num_gpus 1"
+else
+ BATCH_SIZE=98340
+fi
+
+{
+
+for i in `seq 0 4`;
+do
+ START_TIME=$(date +%s)
+ MODEL_DIR="${TEST_DIR}/model_dir_${i}"
+
+ RUN_LOG="${LOCAL_TEST_DIR}/run_${i}.log"
+ export COMPLIANCE_FILE="${LOCAL_TEST_DIR}/run_${i}_compliance_raw.log"
+ export STITCHED_COMPLIANCE_FILE="${LOCAL_TEST_DIR}/run_${i}_compliance_submission.log"
+ echo ""
+ echo "Beginning run ${i}"
+ echo " Complete output logs are in ${RUN_LOG}"
+ echo " Compliance logs: (submission log is created after run.)"
+ echo " ${COMPLIANCE_FILE}"
+ echo " ${STITCHED_COMPLIANCE_FILE}"
+
+ # To reduce variation set the seed flag:
+ # --seed ${i}
+
+ python -u "${SCRIPT_DIR}/${MAIN_SCRIPT}" \
+ --model_dir ${MODEL_DIR} \
+ --data_dir ${DATA_DIR} \
+ --dataset ${DATASET} --hooks "" \
+ ${DEVICE_FLAG} \
+ --clean \
+ --train_epochs 14 \
+ --batch_size ${BATCH_SIZE} \
+ --eval_batch_size 160000 \
+ --learning_rate 0.00382059 \
+ --beta1 0.783529 \
+ --beta2 0.909003 \
+ --epsilon 1.45439e-07 \
+ --layers 256,256,128,64 --num_factors 64 \
+ --hr_threshold 0.635 \
+ --ml_perf \
+ |& tee ${RUN_LOG} \
+ | grep --line-buffered -E --regexp="(Iteration [0-9]+: HR = [0-9\.]+, NDCG = [0-9\.]+, Loss = [0-9\.]+)|(pipeline_hash)|(MLPerf time:)"
+
+ END_TIME=$(date +%s)
+ echo "Run ${i} complete: $(( $END_TIME - $START_TIME )) seconds."
+
+ # Don't fill up the local hard drive.
+ if [[ -z ${BUCKET} ]]; then
+ echo "Removing model directory to save space."
+ rm -r ${MODEL_DIR}
+ fi
+
+done
+
+} |& tee "${LOCAL_TEST_DIR}/summary.log"
diff --git a/models/official/recommendation/stat_utils.py b/models/official/recommendation/stat_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..658a2721e98a88d71dc2ac4562366283ffd2fc47
--- /dev/null
+++ b/models/official/recommendation/stat_utils.py
@@ -0,0 +1,92 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Statistics utility functions of NCF."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+
+
+def random_int32():
+ return np.random.randint(low=0, high=np.iinfo(np.int32).max, dtype=np.int32)
+
+
+def permutation(args):
+ """Fork safe permutation function.
+
+ This function can be called within a multiprocessing worker and give
+ appropriately random results.
+
+ Args:
+    args: A size two tuple that will be unpacked into the size of the permutation
+ and the random seed. This form is used because starmap is not universally
+ available.
+
+  Returns:
+ A NumPy array containing a random permutation.
+ """
+ x, seed = args
+
+ # If seed is None NumPy will seed randomly.
+ state = np.random.RandomState(seed=seed) # pylint: disable=no-member
+ output = np.arange(x, dtype=np.int32)
+ state.shuffle(output)
+ return output
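+
+
+# Illustrative usage (a sketch, not exercised in this module): because
+# `permutation` takes a single tuple argument, it can be driven by plain
+# `Pool.map` even where `starmap` is unavailable, e.g.
+#   pool.map(permutation, [(10, 0), (10, 1)])
+# returns two independent permutations of np.arange(10).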
+
+
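+# The "very slightly biased" name below refers to modulo bias: reducing a
+# uniform uint64 sample mod max_val slightly favors smaller residues, but the
+# effect is negligible because max_val is tiny compared to 2**64.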
+def very_slightly_biased_randint(max_val_vector):
+ sample_dtype = np.uint64
+ out_dtype = max_val_vector.dtype
+ samples = np.random.randint(low=0, high=np.iinfo(sample_dtype).max,
+ size=max_val_vector.shape, dtype=sample_dtype)
+ return np.mod(samples, max_val_vector.astype(sample_dtype)).astype(out_dtype)
+
+
+def mask_duplicates(x, axis=1): # type: (np.ndarray, int) -> np.ndarray
+ """Identify duplicates from sampling with replacement.
+
+ Args:
+ x: A 2D NumPy array of samples
+ axis: The axis along which to de-dupe.
+
+ Returns:
+    A NumPy array with the same shape as x, containing one where an element has
+    already appeared earlier along axis 1, and zero otherwise.
+ """
+ if axis != 1:
+ raise NotImplementedError
+
+ x_sort_ind = np.argsort(x, axis=1, kind="mergesort")
+ sorted_x = x[np.arange(x.shape[0])[:, np.newaxis], x_sort_ind]
+
+ # compute the indices needed to map values back to their original position.
+ inv_x_sort_ind = np.argsort(x_sort_ind, axis=1, kind="mergesort")
+
+ # Compute the difference of adjacent sorted elements.
+ diffs = sorted_x[:, :-1] - sorted_x[:, 1:]
+
+ # We are only interested in whether an element is zero. Therefore left padding
+ # with ones to restore the original shape is sufficient.
+ diffs = np.concatenate(
+ [np.ones((diffs.shape[0], 1), dtype=diffs.dtype), diffs], axis=1)
+
+ # Duplicate values will have a difference of zero. By definition the first
+ # element is never a duplicate.
+ return np.where(diffs[np.arange(x.shape[0])[:, np.newaxis],
+ inv_x_sort_ind], 0, 1)
diff --git a/models/official/requirements.txt b/models/official/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5d492baaac5e1f2b25ea4db08f4d28075626f4c3
--- /dev/null
+++ b/models/official/requirements.txt
@@ -0,0 +1,23 @@
+six
+google-api-python-client>=1.6.7
+google-cloud-bigquery>=0.31.0
+kaggle>=1.3.9
+numpy>=1.15.4
+pandas>=0.22.0
+psutil>=5.4.3
+py-cpuinfo>=3.3.0
+scipy>=0.19.1
+tensorflow-hub>=0.6.0
+tensorflow-model-optimization>=0.2.1
+tensorflow-datasets
+tensorflow-addons
+dataclasses
+gin-config
+tf_slim>=1.1.0
+sentencepiece
+Cython
+matplotlib
+opencv-python-headless
+pyyaml
+Pillow
+-e git+https://github.com/cocodataset/cocoapi#egg=pycocotools&subdirectory=PythonAPI
diff --git a/models/official/staging/__init__.py b/models/official/staging/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/staging/training/__init__.py b/models/official/staging/training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..931c2ef11db4a949e6c2e95bca44e36bac1241e9
--- /dev/null
+++ b/models/official/staging/training/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/models/official/staging/training/controller.py b/models/official/staging/training/controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..a07be66329ad49ba07dff300d66f153552e1c78f
--- /dev/null
+++ b/models/official/staging/training/controller.py
@@ -0,0 +1,337 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A light weight utilities to train TF2 models."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import time
+
+from absl import logging
+
+import tensorflow.compat.v2 as tf
+from typing import Callable, Dict, Optional, Text
+
+from official.staging.training import utils
+
+
+class Controller(object):
+ """Class that facilitates training and evaluation of models."""
+
+ def __init__(
+ self,
+ strategy: Optional[tf.distribute.Strategy] = None,
+ train_fn: Optional[Callable[[tf.Tensor],
+ Optional[Dict[Text, tf.Tensor]]]] = None,
+ eval_fn: Optional[Callable[[tf.Tensor],
+ Optional[Dict[Text, tf.Tensor]]]] = None,
+ global_step: Optional[tf.Variable] = None,
+ # Train related
+ train_steps: Optional[int] = None,
+ steps_per_loop: Optional[int] = None,
+ summary_dir: Optional[Text] = None,
+ checkpoint_manager: Optional[tf.train.CheckpointManager] = None,
+ # summary related
+ summary_interval: Optional[int] = None,
+ # Evaluation related
+ eval_summary_dir: Optional[Text] = None,
+ eval_steps: Optional[int] = None,
+ eval_interval: Optional[int] = None):
+ """Constructs a `Controller` instance.
+
+ Args:
+ strategy: An instance of `tf.distribute.Strategy`.
+      train_fn: A callable defined as `def train_fn(num_steps)`, where
+        `num_steps` indicates the number of steps to run for each loop.
+      eval_fn: A callable defined as `def eval_fn(num_steps)`, where
+        `num_steps` indicates the number of steps for one evaluation.
+ global_step: An integer `tf.Variable` indicating the global training step
+ number. Usually this can be obtained from `iterations` property of the
+ model's optimizer (e.g. `self.optimizer.iterations`), or users can
+ create their own global step variable as well. If the users create their
+ own global step variable, it is recommended to create the `tf.Variable`
+ inside strategy scope, and with
+ `aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA`.
+ train_steps: The total (maximum) number of training steps to perform.
+ steps_per_loop: The number of steps to run in each "inner loop" of
+ training (passed to the `num_steps` parameter of `train_fn`).
+ summary_dir: The directory to restore and write checkpoints and summaries.
+ If None, it will be set to `checkpoint_manager.directory`.
+ checkpoint_manager: An instance of `tf.train.CheckpointManager`.
+ summary_interval: Step interval for training summaries. Note that this
+ argument only applies to the summaries outside the training loop. If the
+ value is None, then training summaries are not enabled.
+ eval_summary_dir: The directory to write eval summaries. If None, it will
+ be set to `summary_dir`.
+ eval_steps: Number of steps to run evaluation.
+      eval_interval: Step interval for evaluation. If None, evaluation is
+        skipped in the middle of training. Note that evaluation only happens
+        outside the training loop, whose iteration count is specified by the
+        `steps_per_loop` parameter.
+
+ Raises:
+ ValueError: If both `train_fn` and `eval_fn` are None.
+ ValueError: If `train_fn` is not None and `train_steps` is None.
+ ValueError: If `steps_per_loop` is None when `train_fn` is provided.
+ ValueError: If `steps_per_loop` is not a positive integer.
+ """
+ if train_fn is None and eval_fn is None:
+ raise ValueError("`train_fn` and `eval_fn` should not both be None")
+
+ # TODO(rxsang): Support training until exhaustion by passing
+ # `train_steps=-1`. Currently it cannot be supported with a host training
+ # loop because break statements are not supported with distributed dataset.
+ if train_fn is not None:
+ if train_steps is None:
+ raise ValueError("`train_steps` is required when `train_fn` is "
+ "provided.")
+ if steps_per_loop is None:
+ raise ValueError("`steps_per_loop` is required when `train_fn is "
+ "provided.")
+ if not isinstance(steps_per_loop, int) or steps_per_loop < 1:
+ raise ValueError("`steps_per_loop` should be a positive integer")
+ if summary_interval is not None and summary_interval <= 0:
+ raise ValueError("`summary_interval` should be larger than 0")
+
+ self.strategy = strategy or tf.distribute.get_strategy()
+
+ self.train_fn = train_fn
+ self.eval_fn = eval_fn
+ self.global_step = global_step
+ self.checkpoint_manager = checkpoint_manager
+
+ if self.train_fn is not None:
+ self.train_steps = train_steps
+ self.steps_per_loop = steps_per_loop
+ if summary_dir:
+ self.summary_dir = summary_dir
+ elif checkpoint_manager:
+ self.summary_dir = checkpoint_manager.directory
+ else:
+ self.summary_dir = None
+
+ self.summary_interval = summary_interval
+ if self.summary_dir and self.summary_interval:
+ summary_writer = tf.summary.create_file_writer(self.summary_dir)
+ else:
+ summary_writer = None
+ # TODO(rxsang): Consider pass SummaryManager directly into Controller for
+ # maximum customizability.
+ self.summary_manager = utils.SummaryManager(
+ summary_writer,
+ tf.summary.scalar,
+ global_step=self.global_step,
+ summary_interval=self.summary_interval)
+
+ if self.eval_fn is not None:
+ eval_summary_dir = eval_summary_dir or self.summary_dir
+ eval_summary_writer = tf.summary.create_file_writer(
+ eval_summary_dir) if eval_summary_dir else None
+ self.eval_summary_manager = utils.SummaryManager(
+ eval_summary_writer, tf.summary.scalar, global_step=self.global_step)
+
+ self.eval_steps = eval_steps
+ self.eval_interval = eval_interval
+
+ # Creates and initializes the interval triggers.
+ self.eval_trigger = utils.IntervalTrigger(self.eval_interval,
+ self.global_step.numpy()) # pytype: disable=attribute-error
+
+ if self.global_step:
+ tf.summary.experimental.set_step(self.global_step)
+
+ # Restores the model if needed.
+ if self.checkpoint_manager is not None:
+ model_restored = self._restore_model()
+ if not model_restored and self.checkpoint_manager.checkpoint_interval:
+ # If the model is not restored from a checkpoint, save an initial
+ # checkpoint.
+ ckpt_path = self.checkpoint_manager.save(
+ checkpoint_number=self.global_step)
+ logging.info("Saved checkpoins in %s", ckpt_path)
+
+ def _restore_model(self, checkpoint_path=None):
+ """Restore or initialize the model.
+
+ Args:
+      checkpoint_path: An optional string indicating the checkpoint path to
+        restore from. If None, restores from `self.checkpoint_manager`.
+
+ Returns:
+ True if the latest checkpoint is found or restored. Otherwise False.
+ """
+ with self.strategy.scope():
+ # Checkpoint restoring should be inside scope. b/139450638
+ if checkpoint_path is not None:
+ self.checkpoint_manager.checkpoint.restore(checkpoint_path)
+ return True
+ return self.checkpoint_manager.restore_or_initialize()
+
+ def _evaluate_once(self, current_step):
+ """Runs the evaluation once."""
+ logging.info("Start evaluation at step: %s", current_step)
+
+ with self.eval_summary_manager.summary_writer.as_default():
+ eval_outputs = self.eval_fn(self.eval_steps)
+
+ if eval_outputs:
+ eval_outputs = tf.nest.map_structure(lambda x: x.numpy(), eval_outputs)
+
+ info = "step: {} evaluation metric: {}".format(
+ current_step, eval_outputs)
+ self._log_info(info)
+
+ self.eval_summary_manager.write_summaries(eval_outputs)
+ self.eval_summary_manager.flush()
+
+ def _maybe_save_checkpoints(self, current_step, force_trigger=False):
+ if self.checkpoint_manager and self.checkpoint_manager.checkpoint_interval:
+ ckpt_path = self.checkpoint_manager.save(
+ checkpoint_number=current_step, check_interval=not force_trigger)
+ if ckpt_path is not None:
+ logging.info("Saved checkpoins in %s", ckpt_path)
+
+ def _maybe_evaluate(self, current_step, force_trigger=False):
+ if self.eval_trigger(current_step, force_trigger):
+ self._evaluate_once(current_step)
+
+ def _log_info(self, message):
+ """Logs `message` to the `info` log, and also prints to stdout."""
+ logging.info(message)
+ print(message)
+
+ def train(self, evaluate=True):
+ """Runs the training, with optional evaluation.
+
+ This handles evaluation, gathering summaries, and saving checkpoints.
+
+ Args:
+      evaluate: A boolean indicating whether to perform evaluation during
+ training.
+
+ Raises:
+ RuntimeError: If `global_step` is not updated correctly in `train_fn`.
+ """
+ if self.train_fn is None:
+ raise ValueError("`self.train_fn` is required when calling `train` "
+ "method.")
+ if self.global_step is None:
+ raise ValueError("`self.global_step` is required when calling `train` "
+ "method.")
+ if evaluate and self.eval_fn is None:
+ raise ValueError("`self.eval_fn` is required when calling `train` method "
+ "with `evaluate=True`")
+
+ step_timer = _StepTimer(self.global_step)
+ current_step = self.global_step.numpy()
+ logging.info("Train at step %s of %s", current_step, self.train_steps)
+ while current_step < self.train_steps:
+ # Calculates steps to run for the next train loop.
+ steps_per_loop = min(self.train_steps - current_step, self.steps_per_loop)
+ logging.info("Entering training loop with %s steps, at step %s of %s",
+ steps_per_loop, current_step, self.train_steps)
+ current_step += steps_per_loop
+ steps_per_loop = tf.convert_to_tensor(steps_per_loop, dtype=tf.int32)
+
+ with self.summary_manager.summary_writer.as_default():
+ train_outputs = self.train_fn(steps_per_loop)
+
+ # Updates and verifies the current step after a training loop finishes.
+ if current_step != self.global_step.numpy():
+ raise RuntimeError("`self.train_fn` is not updating `global_step` "
+ "correctly, expected: %s, actual: %s" %
+ (current_step, self.global_step.numpy()))
+
+ # Print information like metrics and steps_per_second after a training
+ # loop.
+ if train_outputs:
+ train_outputs = tf.nest.map_structure(
+ lambda x: x.numpy(), train_outputs)
+ steps_per_second = step_timer.steps_per_second()
+ info = "step: {} steps_per_second: {:.2f} {}".format(
+ current_step, steps_per_second, train_outputs)
+ self._log_info(info)
+
+ train_outputs = train_outputs or {}
+ train_outputs["steps_per_second"] = steps_per_second
+ self.summary_manager.write_summaries(train_outputs)
+
+ self._maybe_save_checkpoints(current_step)
+
+ if evaluate:
+ self._maybe_evaluate(current_step)
+
+ self.summary_manager.write_summaries(train_outputs, always_write=True)
+ self.summary_manager.flush()
+ self._maybe_save_checkpoints(current_step, force_trigger=True)
+ if evaluate:
+ self._maybe_evaluate(current_step, force_trigger=True)
+
+ def evaluate(self, continuous=False, timeout_fn=None):
+ """Runs the evaluation.
+
+ Args:
+      continuous: If `True`, will continuously monitor the checkpoint directory
+ to evaluate on the latest checkpoint. If `False`, will do the evaluation
+ once.
+ timeout_fn: Optional callable to call after a timeout. If the function
+ returns True, then it means that no new checkpoints will be generated
+ and the iterator will exit.
+
+ Raises:
+ ValueError: If no checkpoint found in `self.checkpoint_manager.directory`.
+ """
+ if self.eval_fn is None:
+ raise ValueError("`self.eval_fn` should not be None to call "
+ "`evaluate()` method.")
+
+ if not continuous and timeout_fn is not None:
+ raise ValueError("`timeout_fn` can be only passed when `continuous` is "
+ "True")
+
+ if continuous:
+ for checkpoint_path in tf.train.checkpoints_iterator(
+ self.checkpoint_manager.directory, timeout_fn=timeout_fn):
+ self._restore_model(checkpoint_path)
+ self._evaluate_once(self.global_step.numpy())
+ return
+
+ latest_checkpoint = self.checkpoint_manager.latest_checkpoint
+ if not latest_checkpoint:
+ raise ValueError("no checkpoint found in dir %s" %
+ self.checkpoint_manager.directory)
+ self._restore_model()
+ self._evaluate_once(self.global_step.numpy())
+
+
+class _StepTimer(object):
+ """Utility class for measuring steps/second."""
+
+ def __init__(self, step):
+ self.step = step
+ self.start()
+
+ def start(self):
+ self.last_iteration = self.step.numpy()
+ self.last_time = time.time()
+
+ def steps_per_second(self, restart=True):
+ value = ((self.step.numpy() - self.last_iteration) /
+ (time.time() - self.last_time))
+ if restart:
+ self.start()
+ return value
diff --git a/models/official/staging/training/controller_test.py b/models/official/staging/training/controller_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..eeaa191c04d40fcc108ed7b00dec86d30d5a2a0b
--- /dev/null
+++ b/models/official/staging/training/controller_test.py
@@ -0,0 +1,308 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for official.staging.training.controller."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from official.staging.training import controller
+from official.staging.training import standard_runnable
+
+
+def all_strategy_combinations():
+ """Gets combinations of distribution strategies."""
+ return combinations.combine(
+ strategy=[
+ strategy_combinations.one_device_strategy,
+ strategy_combinations.tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ ],
+ mode="eager",
+ )
+
+
+def create_model():
+ x = tf.keras.layers.Input(shape=(3,), name="input")
+ y = tf.keras.layers.Dense(4, name="dense")(x)
+ model = tf.keras.Model(x, y)
+ return model
+
+
+def summaries_with_matching_keyword(keyword, summary_dir):
+ """Yields summary protos matching given keyword from event file."""
+ event_paths = tf.io.gfile.glob(os.path.join(summary_dir, "events*"))
+ for event in tf.compat.v1.train.summary_iterator(event_paths[-1]):
+ if event.summary is not None:
+ for value in event.summary.value:
+ if keyword in value.tag:
+ tf.compat.v1.logging.error(event)
+ yield event.summary
+
+
+def check_eventfile_for_keyword(keyword, summary_dir):
+ """Checks event files for the keyword."""
+ return any(summaries_with_matching_keyword(keyword, summary_dir))
+
+
+def dataset_fn(ctx):
+ del ctx
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10, drop_remainder=True)
+ return dataset
+
+
+class TestRunnable(standard_runnable.StandardTrainable,
+ standard_runnable.StandardEvaluable):
+ """Implements the training and evaluation APIs for the test model."""
+
+ def __init__(self):
+ standard_runnable.StandardTrainable.__init__(self)
+ standard_runnable.StandardEvaluable.__init__(self)
+ self.strategy = tf.distribute.get_strategy()
+ self.model = create_model()
+ self.optimizer = tf.keras.optimizers.RMSprop()
+ self.global_step = self.optimizer.iterations
+ self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
+ self.eval_loss = tf.keras.metrics.Mean("eval_loss", dtype=tf.float32)
+
+ def build_train_dataset(self):
+ return self.strategy.experimental_distribute_datasets_from_function(
+ dataset_fn)
+
+ def train_step(self, iterator):
+
+ def _replicated_step(inputs):
+ """Replicated training step."""
+ inputs, targets = inputs
+ with tf.GradientTape() as tape:
+ outputs = self.model(inputs)
+ loss = tf.math.reduce_sum(outputs - targets)
+ grads = tape.gradient(loss, self.model.variables)
+ self.optimizer.apply_gradients(zip(grads, self.model.variables))
+ self.train_loss.update_state(loss)
+
+ self.strategy.run(_replicated_step, args=(next(iterator),))
+
+ def train_loop_end(self):
+ return {
+ "loss": self.train_loss.result(),
+ }
+
+ def build_eval_dataset(self):
+ return self.strategy.experimental_distribute_datasets_from_function(
+ dataset_fn)
+
+ def eval_begin(self):
+ self.eval_loss.reset_states()
+
+ def eval_step(self, iterator):
+
+ def _replicated_step(inputs):
+ """Replicated evaluation step."""
+ inputs, targets = inputs
+ outputs = self.model(inputs)
+ loss = tf.math.reduce_sum(outputs - targets)
+ self.eval_loss.update_state(loss)
+
+ self.strategy.run(_replicated_step, args=(next(iterator),))
+
+ def eval_end(self):
+ return {
+ "eval_loss": self.eval_loss.result(),
+ }
+
+
+class ControllerTest(tf.test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ super(ControllerTest, self).setUp()
+ self.model_dir = self.get_temp_dir()
+
+ def test_no_checkpoint(self):
+ test_runnable = TestRunnable()
+ # No checkpoint manager and no strategy.
+ test_controller = controller.Controller(
+ train_fn=test_runnable.train,
+ eval_fn=test_runnable.evaluate,
+ global_step=test_runnable.global_step,
+ train_steps=10,
+ steps_per_loop=2,
+ summary_dir=os.path.join(self.model_dir, "summaries/train"),
+ summary_interval=2,
+ eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
+ eval_steps=2,
+ eval_interval=5)
+ test_controller.train(evaluate=True)
+ self.assertEqual(test_runnable.global_step.numpy(), 10)
+ # Loss and accuracy values should be written into summaries.
+ self.assertNotEmpty(
+ tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
+ self.assertTrue(
+ check_eventfile_for_keyword(
+ "loss", os.path.join(self.model_dir, "summaries/train")))
+ self.assertNotEmpty(
+ tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
+ self.assertTrue(
+ check_eventfile_for_keyword(
+ "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
+ # No checkpoint, so global step starts from 0.
+ test_runnable.global_step.assign(0)
+ test_controller.train(evaluate=True)
+ self.assertEqual(test_runnable.global_step.numpy(), 10)
+
+ def test_no_checkpoint_and_summaries(self):
+ test_runnable = TestRunnable()
+    # Neither a checkpoint manager nor summary directories.
+ test_controller = controller.Controller(
+ train_fn=test_runnable.train,
+ eval_fn=test_runnable.evaluate,
+ global_step=test_runnable.global_step,
+ train_steps=10,
+ steps_per_loop=2,
+ eval_steps=2,
+ eval_interval=5)
+ test_controller.train(evaluate=True)
+ self.assertEqual(test_runnable.global_step.numpy(), 10)
+
+ @combinations.generate(all_strategy_combinations())
+ def test_train_and_evaluate(self, strategy):
+ with strategy.scope():
+ test_runnable = TestRunnable()
+
+ checkpoint = tf.train.Checkpoint(
+ model=test_runnable.model, optimizer=test_runnable.optimizer)
+ checkpoint_manager = tf.train.CheckpointManager(
+ checkpoint,
+ self.model_dir,
+ max_to_keep=None,
+ step_counter=test_runnable.global_step,
+ checkpoint_interval=10)
+ test_controller = controller.Controller(
+ strategy=strategy,
+ train_fn=test_runnable.train,
+ eval_fn=test_runnable.evaluate,
+ global_step=test_runnable.global_step,
+ train_steps=10,
+ steps_per_loop=2,
+ summary_dir=os.path.join(self.model_dir, "summaries/train"),
+ summary_interval=2,
+ checkpoint_manager=checkpoint_manager,
+ eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
+ eval_steps=2,
+ eval_interval=5)
+ test_controller.train(evaluate=True)
+
+ # Checkpoints are saved.
+ self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
+
+ # Loss and accuracy values should be written into summaries.
+ self.assertNotEmpty(
+ tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
+ self.assertTrue(
+ check_eventfile_for_keyword(
+ "loss", os.path.join(self.model_dir, "summaries/train")))
+ self.assertNotEmpty(
+ tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
+ self.assertTrue(
+ check_eventfile_for_keyword(
+ "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
+
+ @combinations.generate(all_strategy_combinations())
+ def test_train_only(self, strategy):
+ with strategy.scope():
+ test_runnable = TestRunnable()
+
+ checkpoint = tf.train.Checkpoint(
+ model=test_runnable.model, optimizer=test_runnable.optimizer)
+ checkpoint_manager = tf.train.CheckpointManager(
+ checkpoint,
+ self.model_dir,
+ max_to_keep=None,
+ step_counter=test_runnable.global_step,
+ checkpoint_interval=10)
+ test_controller = controller.Controller(
+ strategy=strategy,
+ train_fn=test_runnable.train,
+ global_step=test_runnable.global_step,
+ train_steps=10,
+ steps_per_loop=2,
+ summary_dir=os.path.join(self.model_dir, "summaries/train"),
+ summary_interval=2,
+ checkpoint_manager=checkpoint_manager,
+ eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
+ )
+ test_controller.train(evaluate=False)
+
+ # Checkpoints are saved.
+ self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
+
+ # Only train summaries are written.
+ self.assertNotEmpty(
+ tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
+ self.assertTrue(
+ check_eventfile_for_keyword(
+ "loss", os.path.join(self.model_dir, "summaries/train")))
+ self.assertFalse(
+ tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))
+
+ @combinations.generate(all_strategy_combinations())
+ def test_evaluate_only(self, strategy):
+ with strategy.scope():
+ test_runnable = TestRunnable()
+
+ checkpoint = tf.train.Checkpoint(model=test_runnable.model)
+ checkpoint.save(os.path.join(self.model_dir, "ckpt"))
+
+ checkpoint_manager = tf.train.CheckpointManager(
+ checkpoint,
+ self.model_dir,
+ max_to_keep=None,
+ step_counter=test_runnable.global_step)
+ test_controller = controller.Controller(
+ strategy=strategy,
+ eval_fn=test_runnable.evaluate,
+ global_step=test_runnable.global_step,
+ checkpoint_manager=checkpoint_manager,
+ summary_dir=os.path.join(self.model_dir, "summaries/train"),
+ eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
+ eval_steps=2,
+ eval_interval=5)
+ test_controller.evaluate()
+
+ # Only eval summaries are written
+ self.assertFalse(
+ tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/train")))
+ self.assertNotEmpty(
+ tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
+ self.assertTrue(
+ check_eventfile_for_keyword(
+ "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/staging/training/grad_utils.py b/models/official/staging/training/grad_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..efda2e7616e5ca841dae0877f951982371a44bba
--- /dev/null
+++ b/models/official/staging/training/grad_utils.py
@@ -0,0 +1,143 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Some gradient util functions to help users write custom training loops."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+from absl import logging
+
+import tensorflow.compat.v2 as tf
+
+
+def _filter_grads(grads_and_vars):
+  """Filters out (gradient, variable) pairs whose gradient is None."""
+ grads_and_vars = tuple(grads_and_vars)
+ if not grads_and_vars:
+ return grads_and_vars
+ filtered = []
+ vars_with_empty_grads = []
+ for grad, var in grads_and_vars:
+ if grad is None:
+ vars_with_empty_grads.append(var)
+ else:
+ filtered.append((grad, var))
+ filtered = tuple(filtered)
+ if not filtered:
+ raise ValueError("No gradients provided for any variable: %s." %
+ ([v.name for _, v in grads_and_vars],))
+ if vars_with_empty_grads:
+ logging.warning(
+ ("Gradients do not exist for variables %s when minimizing the loss."),
+ ([v.name for v in vars_with_empty_grads]))
+ return filtered
+
+
+def _filter_and_allreduce_gradients(grads_and_vars,
+ allreduce_precision="float32"):
+ """Filter None grads and then allreduce gradients in specified precision.
+
+  This util function is used when users intend to explicitly allreduce
+  gradients and customize gradient operations before and after the allreduce.
+ The allreduced gradients are then passed to optimizer.apply_gradients(
+ experimental_aggregate_gradients=False).
+
+ Arguments:
+ grads_and_vars: gradients and variables pairs.
+ allreduce_precision: Whether to allreduce gradients in float32 or float16.
+
+ Returns:
+ pairs of allreduced non-None gradients and variables.
+ """
+ filtered_grads_and_vars = _filter_grads(grads_and_vars)
+ (grads, variables) = zip(*filtered_grads_and_vars)
+ if allreduce_precision == "float16":
+ grads = [tf.cast(grad, "float16") for grad in grads]
+ allreduced_grads = tf.distribute.get_replica_context().all_reduce(
+ tf.distribute.ReduceOp.SUM, grads)
+ if allreduce_precision == "float16":
+ allreduced_grads = [tf.cast(grad, "float32") for grad in allreduced_grads]
+ return allreduced_grads, variables
+
+
+def _run_callbacks(callbacks, grads_and_vars):
+ for callback in callbacks:
+ grads_and_vars = callback(grads_and_vars)
+ return grads_and_vars
+
+
+def minimize_using_explicit_allreduce(tape,
+ optimizer,
+ loss,
+ trainable_variables,
+ pre_allreduce_callbacks=None,
+ post_allreduce_callbacks=None):
+ """Minimizes loss for one step by updating `trainable_variables`.
+
+ Minimizes loss for one step by updating `trainable_variables`.
+ This explicitly performs gradient allreduce, instead of relying on implicit
+ allreduce in optimizer.apply_gradients(). If training using FP16 mixed
+ precision, explicit allreduce will aggregate gradients in FP16 format.
+ For TPU and GPU training using FP32, explicit allreduce will aggregate
+ gradients in FP32 format.
+
+ Arguments:
+ tape: An instance of `tf.GradientTape`.
+ optimizer: An instance of `tf.keras.optimizers.Optimizer`.
+ loss: the loss tensor.
+ trainable_variables: A list of model Variables.
+    pre_allreduce_callbacks: A list of callback functions that take pairs of
+      gradients and model variables as input, manipulate them, and return new
+      gradient and model variable pairs. The callback functions will be
+      invoked in list order, before gradients are allreduced. With mixed
+      precision training, the pre_allreduce_callbacks are applied to the
+      scaled gradients. Default is no callbacks.
+    post_allreduce_callbacks: A list of callback functions that take pairs of
+      gradients and model variables as input, manipulate them, and return new
+      gradient and model variable pairs. The callback functions will be
+      invoked in list order, right before the gradients are applied to
+      variables for updates. Default is no callbacks.
+ """
+ if isinstance(optimizer,
+ tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+ # FP16 GPU code path
+ with tape:
+ scaled_loss = optimizer.get_scaled_loss(loss)
+ scaled_grads = tape.gradient(scaled_loss, trainable_variables)
+ grads_and_vars = zip(scaled_grads, trainable_variables)
+ if pre_allreduce_callbacks:
+ grads_and_vars = _run_callbacks(pre_allreduce_callbacks, grads_and_vars)
+ (allreduced_scaled_grads,
+ filtered_training_vars) = _filter_and_allreduce_gradients(
+ grads_and_vars, allreduce_precision="float16")
+ allreduced_unscaled_grads = optimizer.get_unscaled_gradients(
+ allreduced_scaled_grads)
+ grads_and_vars = zip(allreduced_unscaled_grads, filtered_training_vars)
+ else:
+ # TPU or FP32 GPU code path
+ grads = tape.gradient(loss, trainable_variables)
+ grads_and_vars = zip(grads, trainable_variables)
+ if pre_allreduce_callbacks:
+ grads_and_vars = _run_callbacks(pre_allreduce_callbacks, grads_and_vars)
+ (allreduced_grads,
+ filtered_training_vars) = _filter_and_allreduce_gradients(
+ grads_and_vars, allreduce_precision="float32")
+ grads_and_vars = zip(allreduced_grads, filtered_training_vars)
+ if post_allreduce_callbacks:
+ grads_and_vars = _run_callbacks(post_allreduce_callbacks, grads_and_vars)
+ optimizer.apply_gradients(
+ grads_and_vars, experimental_aggregate_gradients=False)
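+
+
+# Example usage (a minimal sketch): calling `minimize_using_explicit_allreduce`
+# from a replicated train step. `model`, `optimizer`, `compute_loss`, `strategy`
+# and `iterator` are assumed to be defined by the surrounding training program.
+#
+#   def train_step(inputs):
+#     features, labels = inputs
+#     with tf.GradientTape() as tape:
+#       loss = compute_loss(model(features, training=True), labels)
+#     minimize_using_explicit_allreduce(
+#         tape, optimizer, loss, model.trainable_variables)
+#
+#   strategy.run(train_step, args=(next(iterator),))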
diff --git a/models/official/staging/training/runnable.py b/models/official/staging/training/runnable.py
new file mode 100644
index 0000000000000000000000000000000000000000..1af6eca06a337506a68d6329e0da16c9ca095e0a
--- /dev/null
+++ b/models/official/staging/training/runnable.py
@@ -0,0 +1,79 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""An abstraction that helps users easily handle their custom training loops."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import abc
+import six
+import tensorflow.compat.v2 as tf
+from typing import Dict, Optional, Text
+
+
+@six.add_metaclass(abc.ABCMeta)
+class AbstractTrainable(tf.Module):
+ """An abstract class defining the APIs required for training."""
+
+ @abc.abstractmethod
+ def train(self,
+ num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
+ """Implements model training with multiple steps.
+
+ In training, it is common to break the total training steps into several
+ training loops, so users can do checkpointing, write summaries and run some
+ python callbacks. This is necessary for getting good performance in TPU
+ training, as the overhead for launching a multi worker tf.function may be
+ large in Eager mode. It is usually encouraged to create a host training loop
+ (e.g. using a `tf.range` wrapping `strategy.run` inside a
+    `tf.function`) in the TPU case. For cases that don't require a host
+    training loop to achieve peak performance, users can just implement a simple
+ python loop to drive each step.
+
+ Args:
+ num_steps: A guideline for how many training steps to run. Note that it is
+ up to the model what constitutes a "step" (this may involve more than
+ one update to model parameters, e.g. if training a GAN).
+
+ Returns:
+ The function may return a dictionary of `Tensors`, which will be
+ written to logs and as TensorBoard summaries.
+ """
+ pass
+
+
+@six.add_metaclass(abc.ABCMeta)
+class AbstractEvaluable(tf.Module):
+ """An abstract class defining the APIs required for evaluation."""
+
+ @abc.abstractmethod
+ def evaluate(
+ self, num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
+ """Implements model evaluation.
+
+ Args:
+ num_steps: A guideline for how many evaluation steps to run. Note that it
+ is up to the model what constitutes a "step". Generally, it may be
+ desirable to support both a limited number of eval steps and iterating
+ over a full dataset (however many steps are required) when `num_steps`
+ is `None`.
+
+ Returns:
+ The function may return a dictionary of `Tensors`, which will be
+ written to logs and as TensorBoard summaries.
+ """
+ pass
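+
+
+# Illustrative sketch (a minimal example, not part of this module's API) of the
+# "host training loop" pattern mentioned in `AbstractTrainable.train`: a
+# `tf.range` loop inside a `tf.function`, so TPU training launches one
+# multi-step function per outer loop. `replicated_step` is assumed to call
+# `strategy.run` internally.
+#
+#   @tf.function
+#   def host_training_loop(iterator, num_steps):
+#     for _ in tf.range(num_steps):
+#       replicated_step(iterator)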
diff --git a/models/official/staging/training/standard_runnable.py b/models/official/staging/training/standard_runnable.py
new file mode 100644
index 0000000000000000000000000000000000000000..20dd66f28e44f7b799dff4af826dcb22bb13595a
--- /dev/null
+++ b/models/official/staging/training/standard_runnable.py
@@ -0,0 +1,181 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""An abstraction that helps users easily handle their custom training loops."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import abc
+import six
+import tensorflow.compat.v2 as tf
+from typing import Dict, Optional, Text
+
+from official.staging.training import runnable
+from official.staging.training import utils
+
+
+@six.add_metaclass(abc.ABCMeta)
+class StandardTrainable(runnable.AbstractTrainable):
+ """Implements the standard functionality of AbstractTrainable APIs."""
+
+ def __init__(self, use_tf_while_loop=True, use_tf_function=True):
+ if use_tf_while_loop and not use_tf_function:
+ raise ValueError("`use_tf_while_loop=True` and `use_tf_function=False` "
+ "is not supported")
+ self.use_tf_while_loop = use_tf_while_loop
+ self.use_tf_function = use_tf_function
+ self.train_dataset = None
+ self.train_iter = None
+ self.train_loop_fn = None
+
+ @abc.abstractmethod
+ def build_train_dataset(self):
+ """Builds the training datasets.
+
+ Returns:
+ A tf.nest-compatible structure of tf.data.Dataset or DistributedDataset.
+ """
+ pass
+
+ def train(self,
+ num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
+ """See base class."""
+ if self.train_dataset is None:
+ # Build train input dataset
+ self.train_dataset = self.build_train_dataset()
+ self.train_iter = tf.nest.map_structure(iter, self.train_dataset)
+
+ if self.train_loop_fn is None:
+ train_fn = self.train_step
+ if self.use_tf_while_loop:
+ self.train_loop_fn = utils.create_tf_while_loop_fn(train_fn)
+ else:
+ if self.use_tf_function:
+ train_fn = tf.function(train_fn)
+ self.train_loop_fn = utils.create_loop_fn(train_fn)
+
+ self.train_loop_begin()
+ self.train_loop_fn(self.train_iter, num_steps)
+ return self.train_loop_end()
+
+ def train_loop_begin(self):
+ """Called once at the beginning of the training loop.
+
+ This is a good place to reset metrics that accumulate values over multiple
+ steps of training.
+ """
+ pass
+
+ @abc.abstractmethod
+ def train_step(self, iterator):
+ """Implements one step of training.
+
+ What a "step" consists of is up to the implementer. If using distribution
+ strategies, the call to this method should take place in the "cross-replica
+ context" for generality, to allow e.g. multiple iterator dequeues and calls
+ to `strategy.run`.
+
+ Args:
+ iterator: A tf.nest-compatible structure of tf.data Iterator or
+ DistributedIterator.
+ """
+ pass
+
+ def train_loop_end(self) -> Optional[Dict[Text, tf.Tensor]]:
+ """Called at the end of the training loop.
+
+ This is a good place to get metric results. The value returned from this
+ function will be returned as-is from the train() method.
+
+ Returns:
+ The function may return a dictionary of `Tensors`, which will be
+ written to logs and as TensorBoard summaries.
+ """
+ pass
+
+
+@six.add_metaclass(abc.ABCMeta)
+class StandardEvaluable(runnable.AbstractEvaluable):
+ """Implements the standard functionality of AbstractEvaluable APIs."""
+
+ def __init__(self, use_tf_function=True):
+ self.eval_use_tf_function = use_tf_function
+ self.eval_dataset = None
+ self.eval_loop_fn = None
+
+ @abc.abstractmethod
+ def build_eval_dataset(self):
+ """Builds the evaluation datasets.
+
+ Returns:
+ A tf.nest-compatible structure of tf.data.Dataset or DistributedDataset.
+ """
+ pass
+
+ def evaluate(
+ self, num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
+ """See base class."""
+ if self.eval_dataset is None:
+ # Build train input dataset
+ self.eval_dataset = self.build_eval_dataset()
+
+ if self.eval_loop_fn is None:
+ eval_fn = self.eval_step
+ if self.eval_use_tf_function:
+ eval_fn = tf.function(eval_fn)
+ self.eval_loop_fn = utils.create_loop_fn(eval_fn)
+
+ eval_iter = tf.nest.map_structure(iter, self.eval_dataset)
+
+ self.eval_begin()
+ self.eval_loop_fn(eval_iter, num_steps)
+ return self.eval_end()
+
+ def eval_begin(self):
+ """Called once at the beginning of the evaluation.
+
+ This is a good place to reset metrics that accumulate values over the entire
+ evaluation.
+ """
+ pass
+
+ @abc.abstractmethod
+ def eval_step(self, iterator):
+ """Implements one step of evaluation.
+
+ What a "step" consists of is up to the implementer. If using distribution
+ strategies, the call to this method should take place in the "cross-replica
+ context" for generality, to allow e.g. multiple iterator dequeues and calls
+ to `strategy.run`.
+
+ Args:
+ iterator: A tf.nest-compatible structure of tf.data Iterator or
+ DistributedIterator.
+ """
+ pass
+
+ def eval_end(self) -> Optional[Dict[Text, tf.Tensor]]:
+ """Called at the end of the evaluation.
+
+ This is a good place to get metric results. The value returned from this
+ function will be returned as-is from the evaluate() method.
+
+ Returns:
+ The function may return a dictionary of `Tensors`, which will be
+ written to logs and as TensorBoard summaries.
+ """
+ pass
diff --git a/models/official/staging/training/utils.py b/models/official/staging/training/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..33fa368b7b966e449c8309e523cd31db73efb978
--- /dev/null
+++ b/models/official/staging/training/utils.py
@@ -0,0 +1,342 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Some layered modules/functions to help users write custom training loops."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import abc
+import inspect
+import six
+
+import tensorflow.compat.v2 as tf
+
+
+def create_loop_fn(step_fn):
+  """Creates a multi-step function driven by a Python while loop.
+
+ Args:
+ step_fn: A function which takes `iterator` as input.
+
+ Returns:
+    A callable defined as the `loop_fn` definition below.
+ """
+
+ def loop_fn(iterator, num_steps, state=None, reduce_fn=None):
+ """A loop function with multiple steps.
+
+ Args:
+ iterator: A nested structure of tf.data `Iterator` or
+ `DistributedIterator`.
+ num_steps: The number of steps in the loop. If `num_steps==-1`, will
+        iterate until exhausting the iterator.
+ state: An optional initial state before running the loop.
+ reduce_fn: a callable defined as `def reduce_fn(state, value)`, where
+ `value` is the outputs from `step_fn`.
+
+ Returns:
+ The updated state.
+ """
+ try:
+ step = 0
+ # To make sure the OutOfRangeError exception can be handled well with
+      # async remote eager, we need to wrap the loop body in an `async_scope`.
+ with tf.experimental.async_scope():
+ while (num_steps == -1 or step < num_steps):
+ outputs = step_fn(iterator)
+ if reduce_fn is not None:
+ state = reduce_fn(state, outputs)
+ step += 1
+ return state
+ except (StopIteration, tf.errors.OutOfRangeError):
+ tf.experimental.async_clear_error()
+ return state
+
+ return loop_fn
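+
+
+# Example (minimal sketch): accumulating per-step outputs with `reduce_fn`.
+# `eval_step` and `dataset` are assumed to come from the surrounding program;
+# `num_steps=-1` iterates until the iterator is exhausted.
+#
+#   loop_fn = create_loop_fn(eval_step)
+#   outputs = loop_fn(iter(dataset), num_steps=-1, state=[],
+#                     reduce_fn=lambda state, value: state + [value])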
+
+
+def create_tf_while_loop_fn(step_fn):
+  """Creates a multi-step function driven by a tf.while_loop on the host.
+
+ Args:
+ step_fn: A function which takes `iterator` as input.
+
+ Returns:
+    A callable defined as the `loop_fn` definition below.
+ """
+
+ @tf.function
+ def loop_fn(iterator, num_steps):
+ """A loop function with multiple steps.
+
+ Args:
+ iterator: A nested structure of tf.data `Iterator` or
+ `DistributedIterator`.
+ num_steps: The number of steps in the loop. Must be a tf.Tensor.
+ """
+ if not isinstance(num_steps, tf.Tensor):
+      raise ValueError("`num_steps` should be a `tf.Tensor`. Passing a Python "
+                       "object may cause retracing.")
+
+ for _ in tf.range(num_steps):
+ step_fn(iterator)
+
+ return loop_fn
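+
+
+# Example (minimal sketch): `num_steps` is passed as a `tf.Tensor` so the
+# wrapping `tf.function` is not retraced for every new Python value.
+# `train_step` and `train_iterator` are assumed to come from the caller.
+#
+#   train_loop = create_tf_while_loop_fn(train_step)
+#   train_loop(train_iterator, tf.convert_to_tensor(1000, dtype=tf.int32))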
+
+
+def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
+ """A helper function to create distributed dataset.
+
+ Args:
+ strategy: An instance of `tf.distribute.Strategy`.
+    dataset_or_fn: An instance of `tf.data.Dataset`, or a function that takes a
+      `tf.distribute.InputContext` as input and returns a `tf.data.Dataset`. If
+      it is a function, it may optionally have an argument named
+      `input_context`, which is of `tf.distribute.InputContext` type.
+ *args: The list of arguments to be passed to dataset_or_fn.
+ **kwargs: Any keyword arguments to be passed.
+
+ Returns:
+ A distributed Dataset.
+ """
+ if strategy is None:
+ strategy = tf.distribute.get_strategy()
+
+ if isinstance(dataset_or_fn, tf.data.Dataset):
+ return strategy.experimental_distribute_dataset(dataset_or_fn)
+
+ if not callable(dataset_or_fn):
+ raise ValueError("`dataset_or_fn` should be either callable or an instance "
+ "of `tf.data.Dataset`")
+
+ def dataset_fn(ctx):
+    """Wrapped dataset function for creating a distributed dataset."""
+
+ # If `dataset_or_fn` is a function and has `input_context` as argument
+ # names, pass `ctx` as the value of `input_context` when calling
+ # `dataset_or_fn`. Otherwise `ctx` will not be used when calling
+ # `dataset_or_fn`.
+ if six.PY3:
+ argspec = inspect.getfullargspec(dataset_or_fn)
+ else:
+ argspec = inspect.getargspec(dataset_or_fn)
+ args_names = argspec.args
+
+ if "input_context" in args_names:
+ kwargs["input_context"] = ctx
+ ds = dataset_or_fn(*args, **kwargs)
+ return ds
+
+ return strategy.experimental_distribute_datasets_from_function(dataset_fn)
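+
+
+# Example (minimal sketch): passing a per-replica dataset function whose
+# optional `input_context` argument is filled in automatically. `strategy` is
+# assumed to be an existing `tf.distribute.Strategy`.
+#
+#   def my_dataset_fn(input_context):
+#     batch_size = input_context.get_per_replica_batch_size(64)
+#     return tf.data.Dataset.range(1000).batch(batch_size, drop_remainder=True)
+#
+#   dist_dataset = make_distributed_dataset(strategy, my_dataset_fn)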
+
+
+class SummaryManager(object):
+  """A class that manages writing summaries."""
+
+ def __init__(self,
+ summary_writer,
+ summary_fn,
+ global_step=None,
+ summary_interval=None):
+ """Construct a summary manager object.
+
+ Args:
+ summary_writer: A `tf.summary.SummaryWriter` instance for writing
+ summaries.
+ summary_fn: A callable defined as `def summary_fn(name, tensor,
+ step=None)`, which describes the summary operation.
+ global_step: A `tf.Variable` instance for checking the current global step
+ value, in case users want to save summaries every N steps.
+      summary_interval: An integer indicating the minimum step interval between
+ two summaries.
+ """
+ if summary_writer is not None:
+ self._summary_writer = summary_writer
+ self._enabled = True
+ else:
+ self._summary_writer = tf.summary.create_noop_writer()
+ self._enabled = False
+ self._summary_fn = summary_fn
+
+ if global_step is None:
+ self._global_step = tf.summary.experimental.get_step()
+ else:
+ self._global_step = global_step
+
+ if summary_interval is not None:
+ if self._global_step is None:
+ raise ValueError("`summary_interval` is not None, but no `global_step` "
+ "can be obtained ")
+ self._last_summary_step = self._global_step.numpy()
+ self._summary_interval = summary_interval
+
+ @property
+ def summary_interval(self):
+ return self._summary_interval
+
+ @property
+ def summary_writer(self):
+ """Returns the underlying summary writer."""
+ return self._summary_writer
+
+ def flush(self):
+ """Flush the underlying summary writer."""
+ if self._enabled:
+ tf.summary.flush(self._summary_writer)
+
+ def write_summaries(self, items, always_write=True):
+ """Write a bulk of summaries.
+
+ Args:
+ items: a dictionary of `Tensors` for writing summaries.
+ always_write: An optional boolean. If `True`, the manager will always
+ write summaries unless the summaries have been written for the same
+ step. Otherwise the manager will only write the summaries if the
+        interval between summaries is larger than `summary_interval`.
+
+ Returns:
+      A boolean indicating whether the summaries were written.
+ """
+ # TODO(rxsang): Support writing summaries with nested structure, so users
+ # can split the summaries into different directories for nicer visualization
+ # in Tensorboard, like train and eval metrics.
+ if not self._enabled:
+ return False
+
+ if self._summary_interval is not None:
+ current_step = self._global_step.numpy()
+ if current_step == self._last_summary_step:
+ return False
+ if not always_write and current_step < (self._last_summary_step +
+ self._summary_interval):
+ return False
+ self._last_summary_step = current_step
+
+ with self._summary_writer.as_default():
+ for name, tensor in items.items():
+ self._summary_fn(name, tensor, step=self._global_step)
+ return True
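+
+
+# Example (minimal sketch): writing training metrics at most every 100 steps.
+# `global_step` and `loss_value` are assumed to come from the training loop.
+#
+#   manager = SummaryManager(
+#       summary_writer=tf.summary.create_file_writer("/tmp/train_summaries"),
+#       summary_fn=tf.summary.scalar,
+#       global_step=global_step,
+#       summary_interval=100)
+#   manager.write_summaries({"loss": loss_value}, always_write=False)
+#   manager.flush()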
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Trigger(object):
+ """An abstract class representing a "trigger" for some event."""
+
+ @abc.abstractmethod
+ def __call__(self, value: float, force_trigger=False):
+ """Maybe trigger the event based on the given value.
+
+ Args:
+ value: the value for triggering.
+      force_trigger: Whether to force the trigger to fire.
+
+ Returns:
+ `True` if the trigger is triggered on the given `value`, and
+ `False` otherwise.
+ """
+
+ @abc.abstractmethod
+ def reset(self):
+ """Reset states in the trigger."""
+
+
+class IntervalTrigger(Trigger):
+ """Triggers on every fixed interval."""
+
+ def __init__(self, interval, start=0):
+ """Constructs the IntervalTrigger.
+
+ Args:
+ interval: The triggering interval.
+ start: An initial value for the trigger.
+ """
+ self._interval = interval
+ self._last_trigger_value = start
+
+ def __call__(self, value, force_trigger=False):
+ """Maybe trigger the event based on the given value.
+
+ Args:
+ value: the value for triggering.
+      force_trigger: If True, the trigger fires regardless of the interval,
+        unless the last trigger value is equal to `value`.
+
+ Returns:
+ `True` if the trigger is triggered on the given `value`, and
+ `False` otherwise.
+ """
+ if force_trigger and value != self._last_trigger_value:
+ self._last_trigger_value = value
+ return True
+
+ if self._interval and self._interval > 0:
+ if value >= self._last_trigger_value + self._interval:
+ self._last_trigger_value = value
+ return True
+ return False
+
+ def reset(self):
+ """See base class."""
+ self._last_trigger_value = 0
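+
+
+# Example (minimal sketch): a trigger that fires once every 100 steps.
+#
+#   eval_trigger = IntervalTrigger(interval=100, start=0)
+#   eval_trigger(50)                       # False: interval not reached yet.
+#   eval_trigger(100)                      # True: 100 >= 0 + 100.
+#   eval_trigger(150)                      # False: next trigger is at 200.
+#   eval_trigger(150, force_trigger=True)  # True: forced.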
+
+
+class EpochHelper(object):
+  """A helper class to handle epochs in a custom training loop."""
+
+ def __init__(self, epoch_steps, global_step):
+ """Constructs the EpochHelper.
+
+ Args:
+      epoch_steps: An integer indicating how many steps are in an epoch.
+      global_step: A `tf.Variable` instance indicating the current global step.
+ """
+ self._epoch_steps = epoch_steps
+ self._global_step = global_step
+ self._current_epoch = None
+ self._epoch_start_step = None
+ self._in_epoch = False
+
+ def epoch_begin(self):
+ """Returns whether a new epoch should begin."""
+ if self._in_epoch:
+ return False
+ current_step = self._global_step.numpy()
+ self._epoch_start_step = current_step
+ self._current_epoch = current_step // self._epoch_steps
+ self._in_epoch = True
+ return True
+
+ def epoch_end(self):
+ """Returns whether the current epoch should end."""
+ if not self._in_epoch:
+ raise ValueError("`epoch_end` can only be called inside an epoch")
+ current_step = self._global_step.numpy()
+ epoch = current_step // self._epoch_steps
+
+ if epoch > self._current_epoch:
+ self._in_epoch = False
+ return True
+ return False
+
+ @property
+ def batch_index(self):
+ """Index of the next batch within the current epoch."""
+ return self._global_step.numpy() - self._epoch_start_step
+
+ @property
+ def current_epoch(self):
+ return self._current_epoch
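+
+
+# Example (minimal sketch): tracking epoch boundaries in a custom training loop.
+# `global_step`, `steps_per_epoch`, `total_steps` and `train_one_step` are
+# assumed to be defined by the surrounding program.
+#
+#   epoch_helper = EpochHelper(steps_per_epoch, global_step)
+#   while global_step.numpy() < total_steps:
+#     if epoch_helper.epoch_begin():
+#       print("Starting epoch", epoch_helper.current_epoch)
+#     train_one_step()
+#     if epoch_helper.epoch_end():
+#       print("Finished epoch", epoch_helper.current_epoch)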
diff --git a/models/official/utils/__init__.py b/models/official/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/utils/flags/README.md b/models/official/utils/flags/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..18160f780a0928a2f28ab9a8e66433938179d581
--- /dev/null
+++ b/models/official/utils/flags/README.md
@@ -0,0 +1,97 @@
+# Adding Abseil (absl) flags quickstart
+## Defining a flag
+absl flag definitions are similar to argparse, although they are defined on a global namespace.
+
+For instance defining a string flag looks like:
+```python
+from absl import flags
+flags.DEFINE_string(
+ name="my_flag",
+ default="a_sensible_default",
+ help="Here is what this flag does."
+)
+```
+
+All three arguments are required, but default may be `None`. A common optional argument is
+short_name for defining abbreviations. Certain `DEFINE_*` methods will have other required arguments.
+For instance `DEFINE_enum` requires the `enum_values` argument to be specified.
+
+## Key Flags
+absl has the concept of a key flag. Any flag defined in `__main__` is considered a key flag by
+default. Key flags are displayed in `--help`, others only appear in `--helpfull`. In order to
+handle key flags that are defined outside the module in question, absl provides the
+`flags.adopt_module_key_flags()` method. This adds the key flags of a different module to one's own
+key flags. For example:
+```python
+File: flag_source.py
+---------------------------------------
+
+from absl import flags
+flags.DEFINE_string(name="my_flag", default="abc", help="a flag.")
+```
+
+```python
+File: my_module.py
+---------------------------------------
+
+from absl import app as absl_app
+from absl import flags
+
+import flag_source
+
+flags.adopt_module_key_flags(flag_source)
+
+def main(_):
+ pass
+
+absl_app.run(main, [__file__, "-h"])
+```
+
+When `my_module.py` is run, it will show the help text for `my_flag`. Because not all flags defined
+in a file are equally important, `official/utils/flags/core.py` (generally imported as flags_core)
+provides an abstraction for handling key flag declaration in an easy way through the
+`register_key_flags_in_core()` function, which allows a module to make a single
+`adopt_key_flags(flags_core)` call when using the util flag declaration functions.
+
+## Validators
+Often the constraints on a flag are complicated. absl provides the validator decorator to allow
+one to mark a function as a flag validation function. Suppose we want users to provide a flag
+which is a palindrome.
+
+```python
+from absl import flags
+
+flags.DEFINE_string(name="pal_flag", short_name="pf", default="", help="Give me a palindrome")
+
+@flags.validator("pal_flag")
+def _check_pal(provided_pal_flag):
+ return provided_pal_flag == provided_pal_flag[::-1]
+
+```
+
+Validators follow the convention that returning True (or any truthy value) passes, and
+anything else (False, None, or a raised exception) fails.
+
+## Testing
+To test using absl, simply declare flags in the setUpClass method of TensorFlow's TestCase.
+
+```python
+from absl import flags
+import tensorflow as tf
+
+from official.utils.flags import core as flags_core
+
+
+def define_flags():
+  flags.DEFINE_string(name="test_flag", default="abc", help="an example flag")
+
+
+class BaseTester(tf.test.TestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    super(BaseTester, cls).setUpClass()
+    define_flags()
+
+  def test_trivial(self):
+    flags_core.parse_flags([__file__, "--test_flag", "def"])
+    self.assertEqual(flags.FLAGS.test_flag, "def")
+
+```
diff --git a/models/official/utils/flags/__init__.py b/models/official/utils/flags/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/utils/flags/_base.py b/models/official/utils/flags/_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..a143e078200eac12b0836eb32fa7c7d0416a8e66
--- /dev/null
+++ b/models/official/utils/flags/_base.py
@@ -0,0 +1,157 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Flags which will be nearly universal across models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import flags
+import tensorflow as tf
+from official.utils.flags._conventions import help_wrap
+
+
+def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
+ epochs_between_evals=False, stop_threshold=False,
+ batch_size=True, num_gpu=False, hooks=False, export_dir=False,
+ distribution_strategy=False, run_eagerly=False):
+ """Register base flags.
+
+ Args:
+ data_dir: Create a flag for specifying the input data directory.
+ model_dir: Create a flag for specifying the model file directory.
+ clean: Create a flag for removing the model_dir.
+ train_epochs: Create a flag to specify the number of training epochs.
+ epochs_between_evals: Create a flag to specify the frequency of testing.
+ stop_threshold: Create a flag to specify a threshold accuracy or other
+ eval metric which should trigger the end of training.
+ batch_size: Create a flag to specify the batch size.
+ num_gpu: Create a flag to specify the number of GPUs used.
+ hooks: Create a flag to specify hooks for logging.
+ export_dir: Create a flag to specify where a SavedModel should be exported.
+ distribution_strategy: Create a flag to specify which Distribution Strategy
+ to use.
+ run_eagerly: Create a flag to specify to run eagerly op by op.
+ Returns:
+    A list of flags for core.py to mark as key flags.
+ """
+ key_flags = []
+
+ if data_dir:
+ flags.DEFINE_string(
+ name="data_dir", short_name="dd", default="/tmp",
+ help=help_wrap("The location of the input data."))
+ key_flags.append("data_dir")
+
+ if model_dir:
+ flags.DEFINE_string(
+ name="model_dir", short_name="md", default="/tmp",
+ help=help_wrap("The location of the model checkpoint files."))
+ key_flags.append("model_dir")
+
+ if clean:
+ flags.DEFINE_boolean(
+ name="clean", default=False,
+ help=help_wrap("If set, model_dir will be removed if it exists."))
+ key_flags.append("clean")
+
+ if train_epochs:
+ flags.DEFINE_integer(
+ name="train_epochs", short_name="te", default=1,
+ help=help_wrap("The number of epochs used to train."))
+ key_flags.append("train_epochs")
+
+ if epochs_between_evals:
+ flags.DEFINE_integer(
+ name="epochs_between_evals", short_name="ebe", default=1,
+ help=help_wrap("The number of training epochs to run between "
+ "evaluations."))
+ key_flags.append("epochs_between_evals")
+
+ if stop_threshold:
+ flags.DEFINE_float(
+ name="stop_threshold", short_name="st",
+ default=None,
+ help=help_wrap("If passed, training will stop at the earlier of "
+ "train_epochs and when the evaluation metric is "
+ "greater than or equal to stop_threshold."))
+
+ if batch_size:
+ flags.DEFINE_integer(
+ name="batch_size", short_name="bs", default=32,
+ help=help_wrap("Batch size for training and evaluation. When using "
+ "multiple gpus, this is the global batch size for "
+ "all devices. For example, if the batch size is 32 "
+ "and there are 4 GPUs, each GPU will get 8 examples on "
+ "each step."))
+ key_flags.append("batch_size")
+
+ if num_gpu:
+ flags.DEFINE_integer(
+ name="num_gpus", short_name="ng",
+ default=1,
+ help=help_wrap(
+ "How many GPUs to use at each worker with the "
+ "DistributionStrategies API. The default is 1."))
+
+ if run_eagerly:
+ flags.DEFINE_boolean(
+ name="run_eagerly", default=False,
+ help="Run the model op by op without building a model function.")
+
+ if hooks:
+ flags.DEFINE_list(
+ name="hooks", short_name="hk", default="LoggingTensorHook",
+ help=help_wrap(
+ u"A list of (case insensitive) strings to specify the names of "
+ u"training hooks. Example: `--hooks ProfilerHook,"
+ u"ExamplesPerSecondHook`\n See hooks_helper "
+ u"for details.")
+ )
+ key_flags.append("hooks")
+
+ if export_dir:
+ flags.DEFINE_string(
+ name="export_dir", short_name="ed", default=None,
+ help=help_wrap("If set, a SavedModel serialization of the model will "
+ "be exported to this directory at the end of training. "
+ "See the README for more details and relevant links.")
+ )
+ key_flags.append("export_dir")
+
+ if distribution_strategy:
+ flags.DEFINE_string(
+ name="distribution_strategy", short_name="ds", default="mirrored",
+ help=help_wrap("The Distribution Strategy to use for training. "
+ "Accepted values are 'off', 'one_device', "
+ "'mirrored', 'parameter_server', 'collective', "
+ "case insensitive. 'off' means not to use "
+ "Distribution Strategy; 'default' means to choose "
+ "from `MirroredStrategy` or `OneDeviceStrategy` "
+ "according to the number of GPUs.")
+ )
+
+
+ return key_flags
+
+
+def get_num_gpus(flags_obj):
+ """Treat num_gpus=-1 as 'use all'."""
+ if flags_obj.num_gpus != -1:
+ return flags_obj.num_gpus
+
+ from tensorflow.python.client import device_lib # pylint: disable=g-import-not-at-top
+ local_device_protos = device_lib.list_local_devices()
+ return sum([1 for d in local_device_protos if d.device_type == "GPU"])
diff --git a/models/official/utils/flags/_benchmark.py b/models/official/utils/flags/_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..5aa01421c5f5c7fede94b971d6674267f232b6da
--- /dev/null
+++ b/models/official/utils/flags/_benchmark.py
@@ -0,0 +1,108 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Flags for benchmarking models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import flags
+
+from official.utils.flags._conventions import help_wrap
+
+
+def define_log_steps():
+ flags.DEFINE_integer(
+ name="log_steps", default=100,
+ help="Frequency with which to log timing information with TimeHistory.")
+
+ return []
+
+
+def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
+ """Register benchmarking flags.
+
+ Args:
+ benchmark_log_dir: Create a flag to specify location for benchmark logging.
+ bigquery_uploader: Create flags for uploading results to BigQuery.
+
+ Returns:
+    A list of flags for core.py to mark as key flags.
+ """
+
+ key_flags = []
+
+ flags.DEFINE_enum(
+ name="benchmark_logger_type", default="BaseBenchmarkLogger",
+ enum_values=["BaseBenchmarkLogger", "BenchmarkFileLogger"],
+ help=help_wrap("The type of benchmark logger to use. Defaults to using "
+ "BaseBenchmarkLogger which logs to STDOUT. Different "
+ "loggers will require other flags to be able to work."))
+ flags.DEFINE_string(
+ name="benchmark_test_id", short_name="bti", default=None,
+ help=help_wrap("The unique test ID of the benchmark run. It could be the "
+ "combination of key parameters. It is hardware "
+ "independent and could be used compare the performance "
+                     "independent and could be used to compare the performance "
+ "human consumption, and does not have any impact within "
+ "the system."))
+
+ define_log_steps()
+
+ if benchmark_log_dir:
+ flags.DEFINE_string(
+ name="benchmark_log_dir", short_name="bld", default=None,
+ help=help_wrap("The location of the benchmark logging.")
+ )
+
+ if bigquery_uploader:
+ flags.DEFINE_string(
+ name="gcp_project", short_name="gp", default=None,
+ help=help_wrap(
+ "The GCP project name where the benchmark will be uploaded."))
+
+ flags.DEFINE_string(
+ name="bigquery_data_set", short_name="bds", default="test_benchmark",
+ help=help_wrap(
+ "The Bigquery dataset name where the benchmark will be uploaded."))
+
+ flags.DEFINE_string(
+ name="bigquery_run_table", short_name="brt", default="benchmark_run",
+ help=help_wrap("The Bigquery table name where the benchmark run "
+ "information will be uploaded."))
+
+ flags.DEFINE_string(
+ name="bigquery_run_status_table", short_name="brst",
+ default="benchmark_run_status",
+ help=help_wrap("The Bigquery table name where the benchmark run "
+ "status information will be uploaded."))
+
+ flags.DEFINE_string(
+ name="bigquery_metric_table", short_name="bmt",
+ default="benchmark_metric",
+ help=help_wrap("The Bigquery table name where the benchmark metric "
+ "information will be uploaded."))
+
+ @flags.multi_flags_validator(
+ ["benchmark_logger_type", "benchmark_log_dir"],
+      message="--benchmark_logger_type=BenchmarkFileLogger requires "
+              "--benchmark_log_dir to be set")
+ def _check_benchmark_log_dir(flags_dict):
+ benchmark_logger_type = flags_dict["benchmark_logger_type"]
+ if benchmark_logger_type == "BenchmarkFileLogger":
+ return flags_dict["benchmark_log_dir"]
+ return True
+
+ return key_flags
diff --git a/models/official/utils/flags/_conventions.py b/models/official/utils/flags/_conventions.py
new file mode 100644
index 0000000000000000000000000000000000000000..e04448ab81fc6db7fd8ba1650b427320ff00c05e
--- /dev/null
+++ b/models/official/utils/flags/_conventions.py
@@ -0,0 +1,54 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Central location for shared argparse convention definitions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+import codecs
+import functools
+
+from absl import app as absl_app
+from absl import flags
+
+
+# This codifies help string conventions and makes it easy to update them if
+# necessary. Currently the only major effect is that help bodies start on the
+# line after flags are listed. All flag definitions should wrap the text bodies
+# with help wrap when calling DEFINE_*.
+_help_wrap = functools.partial(flags.text_wrap, length=80, indent="",
+ firstline_indent="\n")
+
+
+# Pretty formatting causes issues when utf-8 is not installed on a system.
+def _stdout_utf8():
+ try:
+ codecs.lookup("utf-8")
+ except LookupError:
+ return False
+ return getattr(sys.stdout, "encoding", "") == "UTF-8"
+
+
+if _stdout_utf8():
+ help_wrap = _help_wrap
+else:
+ def help_wrap(text, *args, **kwargs):
+ return _help_wrap(text, *args, **kwargs).replace(u"\ufeff", u"")
+
+
+# Replace None with h to also allow -h
+absl_app.HelpshortFlag.SHORT_NAME = "h"
diff --git a/models/official/utils/flags/_device.py b/models/official/utils/flags/_device.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8974fc48d1fc77d227745191579df16b2e46bcc
--- /dev/null
+++ b/models/official/utils/flags/_device.py
@@ -0,0 +1,85 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Flags for managing compute devices. Currently only contains TPU flags."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import flags
+from absl import logging
+
+from official.utils.flags._conventions import help_wrap
+
+
+def require_cloud_storage(flag_names):
+ """Register a validator to check directory flags.
+ Args:
+ flag_names: An iterable of strings containing the names of flags to be
+ checked.
+ """
+ msg = "TPU requires GCS path for {}".format(", ".join(flag_names))
+ @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
+ def _path_check(flag_values): # pylint: disable=missing-docstring
+ if flag_values["tpu"] is None:
+ return True
+
+ valid_flags = True
+ for key in flag_names:
+ if not flag_values[key].startswith("gs://"):
+ logging.error("%s must be a GCS path.", key)
+ valid_flags = False
+
+ return valid_flags
+
+
+def define_device(tpu=True):
+ """Register device specific flags.
+ Args:
+ tpu: Create flags to specify TPU operation.
+ Returns:
+    A list of flags for core.py to mark as key flags.
+ """
+
+ key_flags = []
+
+ if tpu:
+ flags.DEFINE_string(
+ name="tpu", default=None,
+ help=help_wrap(
+ "The Cloud TPU to use for training. This should be either the name "
+ "used when creating the Cloud TPU, or a "
+ "grpc://ip.address.of.tpu:8470 url. Passing `local` will use the"
+ "CPU of the local instance instead. (Good for debugging.)"))
+ key_flags.append("tpu")
+
+ flags.DEFINE_string(
+ name="tpu_zone", default=None,
+ help=help_wrap(
+          "[Optional] GCE zone where the Cloud TPU is located. If not "
+          "specified, we will attempt to automatically detect the GCE "
+          "zone from metadata."))
+
+ flags.DEFINE_string(
+ name="tpu_gcp_project", default=None,
+ help=help_wrap(
+ "[Optional] Project name for the Cloud TPU-enabled project. If not "
+ "specified, we will attempt to automatically detect the GCE "
+ "project from metadata."))
+
+ flags.DEFINE_integer(name="num_tpu_shards", default=8,
+ help=help_wrap("Number of shards (TPU chips)."))
+
+ return key_flags
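As a usage sketch (not part of the patch), the TPU flags above are typically registered together with a GCS-path check on the directory flags. The `data_dir`/`model_dir` flags below are defined only for the example.

```python
# Usage sketch only -- not part of the patch. data_dir/model_dir are example
# flags defined here so require_cloud_storage has something to validate.
from absl import app
from absl import flags

from official.utils.flags import _device

flags.DEFINE_string("data_dir", None, "Input data directory.")
flags.DEFINE_string("model_dir", None, "Directory for checkpoints/summaries.")

_device.define_device(tpu=True)
# Parsing fails when --tpu is set but the directories are not gs:// paths.
_device.require_cloud_storage(["data_dir", "model_dir"])


def main(_):
  print("tpu =", flags.FLAGS.tpu, "zone =", flags.FLAGS.tpu_zone)


if __name__ == "__main__":
  app.run(main)
```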
diff --git a/models/official/utils/flags/_distribution.py b/models/official/utils/flags/_distribution.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca331bf24affed5185273a19752d28a491ea3711
--- /dev/null
+++ b/models/official/utils/flags/_distribution.py
@@ -0,0 +1,54 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Flags related to distributed execution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import flags
+import tensorflow as tf
+
+from official.utils.flags._conventions import help_wrap
+
+
+def define_distribution(worker_hosts=True, task_index=True):
+ """Register distributed execution flags.
+
+ Args:
+ worker_hosts: Create a flag for specifying comma-separated list of workers.
+ task_index: Create a flag for specifying index of task.
+
+ Returns:
+    A list of flags for core.py to mark as key flags.
+ """
+ key_flags = []
+
+ if worker_hosts:
+ flags.DEFINE_string(
+ name='worker_hosts', default=None,
+ help=help_wrap(
+ 'Comma-separated list of worker ip:port pairs for running '
+ 'multi-worker models with DistributionStrategy. The user would '
+ 'start the program on each host with identical value for this '
+ 'flag.'))
+
+ if task_index:
+ flags.DEFINE_integer(
+ name='task_index', default=-1,
+ help=help_wrap('If multi-worker training, the task_index of this '
+ 'worker.'))
+
+ return key_flags
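A hedged sketch (not part of the patch) of how these flags feed the `configure_cluster` helper added later in this diff in `distribution_utils.py`; the worker addresses are placeholders.

```python
# Sketch only -- not part of the patch. Worker addresses are placeholders, and
# the modules added elsewhere in this diff are assumed to be importable.
import os

from absl import flags

from official.utils.flags import _distribution
from official.utils.misc import distribution_utils

_distribution.define_distribution()
flags.FLAGS(["prog",
             "--worker_hosts=10.0.0.1:2222,10.0.0.2:2222",
             "--task_index=0"])

# Writes a TF_CONFIG suitable for MultiWorkerMirroredStrategy and returns the
# number of workers it describes.
num_workers = distribution_utils.configure_cluster(
    flags.FLAGS.worker_hosts, flags.FLAGS.task_index)
print(num_workers, os.environ["TF_CONFIG"])
```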
diff --git a/models/official/utils/flags/_misc.py b/models/official/utils/flags/_misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6fa24b5ae7e29827967c5c6a1b78dc3613d40fe
--- /dev/null
+++ b/models/official/utils/flags/_misc.py
@@ -0,0 +1,50 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Misc flags."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import flags
+
+from official.utils.flags._conventions import help_wrap
+
+
+def define_image(data_format=True):
+ """Register image specific flags.
+
+ Args:
+ data_format: Create a flag to specify image axis convention.
+
+ Returns:
+    A list of flags for core.py to mark as key flags.
+ """
+
+ key_flags = []
+
+ if data_format:
+ flags.DEFINE_enum(
+ name="data_format", short_name="df", default=None,
+ enum_values=["channels_first", "channels_last"],
+ help=help_wrap(
+ "A flag to override the data format used in the model. "
+ "channels_first provides a performance boost on GPU but is not "
+ "always compatible with CPU. If left unspecified, the data format "
+ "will be chosen automatically based on whether TensorFlow was "
+ "built for CPU or GPU."))
+ key_flags.append("data_format")
+
+ return key_flags
diff --git a/models/official/utils/flags/_performance.py b/models/official/utils/flags/_performance.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc5840f95e1ea26697951d1b78fe847526d5859b
--- /dev/null
+++ b/models/official/utils/flags/_performance.py
@@ -0,0 +1,289 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Register flags for optimizing performance."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import multiprocessing
+
+from absl import flags # pylint: disable=g-bad-import-order
+import tensorflow as tf # pylint: disable=g-bad-import-order
+
+from official.utils.flags._conventions import help_wrap
+
+
+# Map string to TensorFlow dtype
+DTYPE_MAP = {
+ "fp16": tf.float16,
+ "bf16": tf.bfloat16,
+ "fp32": tf.float32,
+}
+
+
+def get_tf_dtype(flags_obj):
+ if getattr(flags_obj, "fp16_implementation", None) == "graph_rewrite":
+ # If the graph_rewrite is used, we build the graph with fp32, and let the
+ # graph rewrite change ops to fp16.
+ return tf.float32
+ return DTYPE_MAP[flags_obj.dtype]
+
+
+def get_loss_scale(flags_obj, default_for_fp16):
+ dtype = get_tf_dtype(flags_obj)
+ if flags_obj.loss_scale == "dynamic":
+ return flags_obj.loss_scale
+ elif flags_obj.loss_scale is not None:
+ return float(flags_obj.loss_scale)
+ elif dtype == tf.float32 or dtype == tf.bfloat16:
+ return 1 # No loss scaling is needed for fp32
+ else:
+ assert dtype == tf.float16
+ return default_for_fp16
+
+
+def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
+ synthetic_data=False, max_train_steps=False, dtype=False,
+ all_reduce_alg=False, num_packs=False,
+ tf_gpu_thread_mode=False,
+ datasets_num_private_threads=False,
+ datasets_num_parallel_batches=False,
+ dynamic_loss_scale=False, fp16_implementation=False,
+ loss_scale=False,
+ tf_data_experimental_slack=False, enable_xla=False,
+ training_dataset_cache=False):
+ """Register flags for specifying performance tuning arguments.
+
+ Args:
+ num_parallel_calls: Create a flag to specify parallelism of data loading.
+ inter_op: Create a flag to allow specification of inter op threads.
+ intra_op: Create a flag to allow specification of intra op threads.
+ synthetic_data: Create a flag to allow the use of synthetic data.
+ max_train_steps: Create a flags to allow specification of maximum number
+ of training steps
+ dtype: Create flags for specifying dtype.
+ all_reduce_alg: If set forces a specific algorithm for multi-gpu.
+ num_packs: If set provides number of packs for MirroredStrategy's cross
+ device ops.
+    tf_gpu_thread_mode: gpu_private triggers use of a private thread pool.
+ datasets_num_private_threads: Number of private threads for datasets.
+ datasets_num_parallel_batches: Determines how many batches to process in
+ parallel when using map and batch from tf.data.
+ dynamic_loss_scale: Allow the "loss_scale" flag to take on the value
+ "dynamic". Only valid if `dtype` is True.
+ fp16_implementation: Create fp16_implementation flag.
+ loss_scale: Controls the loss scaling, normally for mixed-precision
+ training. Can only be turned on if dtype is also True.
+ tf_data_experimental_slack: Determines whether to enable tf.data's
+ `experimental_slack` option.
+ enable_xla: Determines if XLA (auto clustering) is turned on.
+ training_dataset_cache: Whether to cache the training dataset on workers.
+ Typically used to improve training performance when training data is in
+ remote storage and can fit into worker memory.
+
+ Returns:
+    A list of flags for core.py to mark as key flags.
+ """
+
+ key_flags = []
+ if num_parallel_calls:
+ flags.DEFINE_integer(
+ name="num_parallel_calls", short_name="npc",
+ default=multiprocessing.cpu_count(),
+ help=help_wrap("The number of records that are processed in parallel "
+ "during input processing. This can be optimized per "
+ "data set but for generally homogeneous data sets, "
+ "should be approximately the number of available CPU "
+ "cores. (default behavior)"))
+
+ if inter_op:
+ flags.DEFINE_integer(
+ name="inter_op_parallelism_threads", short_name="inter", default=0,
+ help=help_wrap("Number of inter_op_parallelism_threads to use for CPU. "
+ "See TensorFlow config.proto for details.")
+ )
+
+ if intra_op:
+ flags.DEFINE_integer(
+ name="intra_op_parallelism_threads", short_name="intra", default=0,
+ help=help_wrap("Number of intra_op_parallelism_threads to use for CPU. "
+ "See TensorFlow config.proto for details."))
+
+ if synthetic_data:
+ flags.DEFINE_bool(
+ name="use_synthetic_data", short_name="synth", default=False,
+ help=help_wrap(
+ "If set, use fake data (zeroes) instead of a real dataset. "
+ "This mode is useful for performance debugging, as it removes "
+ "input processing steps, but will not learn anything."))
+
+ if max_train_steps:
+ flags.DEFINE_integer(
+ name="max_train_steps", short_name="mts", default=None, help=help_wrap(
+ "The model will stop training if the global_step reaches this "
+ "value. If not set, training will run until the specified number "
+ "of epochs have run as usual. It is generally recommended to set "
+ "--train_epochs=1 when using this flag."
+ ))
+
+ if dtype:
+ flags.DEFINE_enum(
+ name="dtype", short_name="dt", default="fp32",
+ enum_values=DTYPE_MAP.keys(),
+ help=help_wrap("The TensorFlow datatype used for calculations. "
+ "Variables may be cast to a higher precision on a "
+ "case-by-case basis for numerical stability."))
+
+ loss_scale_help_text = (
+ "The amount to scale the loss by when the model is run. {}. Before "
+ "gradients are computed, the loss is multiplied by the loss scale, "
+ "making all gradients loss_scale times larger. To adjust for this, "
+ "gradients are divided by the loss scale before being applied to "
+ "variables. This is mathematically equivalent to training without "
+ "a loss scale, but the loss scale helps avoid some intermediate "
+ "gradients from underflowing to zero. If not provided the default "
+ "for fp16 is 128 and 1 for all other dtypes.{}"
+ )
+ if dynamic_loss_scale:
+ loss_scale_help_text = loss_scale_help_text.format(
+ "This can be an int/float or the string 'dynamic'",
+ " The string 'dynamic' can be used to dynamically determine the "
+ "optimal loss scale during training, but currently this "
+ "significantly slows down performance")
+ loss_scale_validation_msg = ("loss_scale should be a positive int/float "
+ "or the string 'dynamic'.")
+ else:
+ loss_scale_help_text = loss_scale_help_text.format(
+ "This must be an int/float", "")
+ loss_scale_validation_msg = "loss_scale should be a positive int/float."
+ if loss_scale:
+ flags.DEFINE_string(
+ name="loss_scale", short_name="ls", default=None,
+ help=help_wrap(loss_scale_help_text))
+
+ @flags.validator(flag_name="loss_scale",
+ message=loss_scale_validation_msg)
+ def _check_loss_scale(loss_scale): # pylint: disable=unused-variable
+ """Validator to check the loss scale flag is valid."""
+ if loss_scale is None:
+ return True # null case is handled in get_loss_scale()
+
+ if loss_scale == "dynamic" and dynamic_loss_scale:
+ return True
+
+ try:
+ loss_scale = float(loss_scale)
+ except ValueError:
+ return False
+
+ return loss_scale > 0
+
+ if fp16_implementation:
+ flags.DEFINE_enum(
+ name="fp16_implementation", default="keras",
+ enum_values=("keras', 'graph_rewrite"),
+ help=help_wrap(
+ "When --dtype=fp16, how fp16 should be implemented. This has no "
+ "impact on correctness. 'keras' uses the "
+ "tf.keras.mixed_precision API. 'graph_rewrite' uses the "
+ "tf.train.experimental.enable_mixed_precision_graph_rewrite "
+ "API."))
+
+ @flags.multi_flags_validator(["fp16_implementation", "dtype",
+ "loss_scale"])
+ def _check_fp16_implementation(flags_dict):
+ """Validator to check fp16_implementation flag is valid."""
+ if (flags_dict["fp16_implementation"] == "graph_rewrite" and
+ flags_dict["dtype"] != "fp16"):
+ raise flags.ValidationError("--fp16_implementation should not be "
+ "specified unless --dtype=fp16")
+ return True
+
+ if all_reduce_alg:
+ flags.DEFINE_string(
+ name="all_reduce_alg", short_name="ara", default=None,
+ help=help_wrap("Defines the algorithm to use for performing all-reduce."
+ "When specified with MirroredStrategy for single "
+ "worker, this controls "
+ "tf.contrib.distribute.AllReduceCrossTowerOps. When "
+ "specified with MultiWorkerMirroredStrategy, this "
+ "controls "
+ "tf.distribute.experimental.CollectiveCommunication; "
+ "valid options are `ring` and `nccl`."))
+
+ if num_packs:
+ flags.DEFINE_integer(
+ name="num_packs", default=1,
+ help=help_wrap("Sets `num_packs` in the cross device ops used in "
+ "MirroredStrategy. For details, see "
+ "tf.distribute.NcclAllReduce."))
+
+ if tf_gpu_thread_mode:
+ flags.DEFINE_string(
+ name="tf_gpu_thread_mode", short_name="gt_mode", default=None,
+ help=help_wrap(
+ "Whether and how the GPU device uses its own threadpool.")
+ )
+
+ flags.DEFINE_integer(
+ name="per_gpu_thread_count", short_name="pgtc", default=0,
+ help=help_wrap(
+ "The number of threads to use for GPU. Only valid when "
+ "tf_gpu_thread_mode is not global.")
+ )
+
+ if datasets_num_private_threads:
+ flags.DEFINE_integer(
+ name="datasets_num_private_threads",
+ default=None,
+ help=help_wrap(
+ "Number of threads for a private threadpool created for all"
+ "datasets computation..")
+ )
+
+ if datasets_num_parallel_batches:
+ flags.DEFINE_integer(
+ name="datasets_num_parallel_batches",
+ default=None,
+ help=help_wrap(
+ "Determines how many batches to process in parallel when using "
+ "map and batch from tf.data.")
+ )
+
+ if training_dataset_cache:
+ flags.DEFINE_boolean(
+ name="training_dataset_cache",
+ default=False,
+ help=help_wrap(
+ "Determines whether to cache the training dataset on workers. "
+ "Typically used to improve training performance when training "
+ "data is in remote storage and can fit into worker memory.")
+ )
+
+ if tf_data_experimental_slack:
+ flags.DEFINE_boolean(
+ name="tf_data_experimental_slack",
+ default=False,
+ help=help_wrap(
+ "Whether to enable tf.data's `experimental_slack` option.")
+ )
+
+ if enable_xla:
+ flags.DEFINE_boolean(
+ name="enable_xla", default=False,
+ help="Whether to enable XLA auto jit compilation")
+
+ return key_flags
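To make the dtype and loss-scale plumbing concrete, a small sketch (not part of the patch) showing how parsed flags resolve through `get_tf_dtype` and `get_loss_scale`.

```python
# Sketch only -- not part of the patch.
from absl import flags

from official.utils.flags import _performance

_performance.define_performance(dtype=True, loss_scale=True,
                                 dynamic_loss_scale=True)

# fp16 with the string "dynamic": the string is passed through unchanged so the
# caller can construct a dynamic loss scale.
flags.FLAGS(["prog", "--dtype=fp16", "--loss_scale=dynamic"])
print(_performance.get_tf_dtype(flags.FLAGS))                           # float16
print(_performance.get_loss_scale(flags.FLAGS, default_for_fp16=128))   # dynamic

# fp32 ignores the fp16 default and resolves to a loss scale of 1.
flags.FLAGS.unparse_flags()
flags.FLAGS(["prog", "--dtype=fp32"])
print(_performance.get_loss_scale(flags.FLAGS, default_for_fp16=128))   # 1
```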
diff --git a/models/official/utils/flags/core.py b/models/official/utils/flags/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa36944893a579fe5d4a65af9262651db0abc1ba
--- /dev/null
+++ b/models/official/utils/flags/core.py
@@ -0,0 +1,133 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Public interface for flag definition.
+
+See _example.py for detailed instructions on defining flags.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+from six.moves import shlex_quote
+
+from absl import app as absl_app
+from absl import flags
+
+from official.utils.flags import _base
+from official.utils.flags import _benchmark
+from official.utils.flags import _conventions
+from official.utils.flags import _device
+from official.utils.flags import _distribution
+from official.utils.flags import _misc
+from official.utils.flags import _performance
+
+
+def set_defaults(**kwargs):
+ for key, value in kwargs.items():
+ flags.FLAGS.set_default(name=key, value=value)
+
+
+def parse_flags(argv=None):
+ """Reset flags and reparse. Currently only used in testing."""
+ flags.FLAGS.unparse_flags()
+ absl_app.parse_flags_with_usage(argv or sys.argv)
+
+
+def register_key_flags_in_core(f):
+ """Defines a function in core.py, and registers its key flags.
+
+ absl uses the location of a flags.declare_key_flag() to determine the context
+ in which a flag is key. By making all declares in core, this allows model
+ main functions to call flags.adopt_module_key_flags() on core and correctly
+ chain key flags.
+
+ Args:
+ f: The function to be wrapped
+
+ Returns:
+ The "core-defined" version of the input function.
+ """
+
+ def core_fn(*args, **kwargs):
+ key_flags = f(*args, **kwargs)
+ [flags.declare_key_flag(fl) for fl in key_flags] # pylint: disable=expression-not-assigned
+ return core_fn
+
+
+define_base = register_key_flags_in_core(_base.define_base)
+# We have define_base_eager for compatibility, since it used to be a separate
+# function from define_base.
+define_base_eager = define_base
+define_log_steps = register_key_flags_in_core(_benchmark.define_log_steps)
+define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
+define_device = register_key_flags_in_core(_device.define_device)
+define_image = register_key_flags_in_core(_misc.define_image)
+define_performance = register_key_flags_in_core(_performance.define_performance)
+define_distribution = register_key_flags_in_core(
+ _distribution.define_distribution)
+
+
+help_wrap = _conventions.help_wrap
+
+
+get_num_gpus = _base.get_num_gpus
+get_tf_dtype = _performance.get_tf_dtype
+get_loss_scale = _performance.get_loss_scale
+DTYPE_MAP = _performance.DTYPE_MAP
+require_cloud_storage = _device.require_cloud_storage
+
+def _get_nondefault_flags_as_dict():
+ """Returns the nondefault flags as a dict from flag name to value."""
+ nondefault_flags = {}
+ for flag_name in flags.FLAGS:
+ flag_value = getattr(flags.FLAGS, flag_name)
+ if (flag_name != flags.FLAGS[flag_name].short_name and
+ flag_value != flags.FLAGS[flag_name].default):
+ nondefault_flags[flag_name] = flag_value
+ return nondefault_flags
+
+
+def get_nondefault_flags_as_str():
+ """Returns flags as a string that can be passed as command line arguments.
+
+ E.g., returns: "--batch_size=256 --use_synthetic_data" for the following code
+ block:
+
+ ```
+ flags.FLAGS.batch_size = 256
+ flags.FLAGS.use_synthetic_data = True
+ print(get_nondefault_flags_as_str())
+ ```
+
+ Only flags with nondefault values are returned, as passing default flags as
+ command line arguments has no effect.
+
+ Returns:
+ A string with the flags, that can be passed as command line arguments to a
+ program to use the flags.
+ """
+ nondefault_flags = _get_nondefault_flags_as_dict()
+ flag_strings = []
+ for name, value in sorted(nondefault_flags.items()):
+ if isinstance(value, bool):
+ flag_str = '--{}'.format(name) if value else '--no{}'.format(name)
+ elif isinstance(value, list):
+ flag_str = '--{}={}'.format(name, ','.join(value))
+ else:
+ flag_str = '--{}={}'.format(name, value)
+ flag_strings.append(flag_str)
+ return ' '.join(shlex_quote(flag_str) for flag_str in flag_strings)
diff --git a/models/official/utils/flags/flags_test.py b/models/official/utils/flags/flags_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e11a1642242bf134f3a9f1df0908f29b00cecf74
--- /dev/null
+++ b/models/official/utils/flags/flags_test.py
@@ -0,0 +1,162 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import unittest
+
+from absl import flags
+import tensorflow as tf
+
+from official.utils.flags import core as flags_core # pylint: disable=g-bad-import-order
+
+
+def define_flags():
+ flags_core.define_base(clean=True, num_gpu=False, stop_threshold=True,
+ hooks=True, train_epochs=True,
+ epochs_between_evals=True)
+ flags_core.define_performance(
+ num_parallel_calls=True, inter_op=True, intra_op=True,
+ dynamic_loss_scale=True, loss_scale=True, synthetic_data=True,
+ dtype=True)
+ flags_core.define_image()
+ flags_core.define_benchmark()
+
+
+class BaseTester(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ super(BaseTester, cls).setUpClass()
+ define_flags()
+
+ def test_default_setting(self):
+ """Test to ensure fields exist and defaults can be set.
+ """
+
+ defaults = dict(
+ data_dir="dfgasf",
+ model_dir="dfsdkjgbs",
+ train_epochs=534,
+ epochs_between_evals=15,
+ batch_size=256,
+ hooks=["LoggingTensorHook"],
+ num_parallel_calls=18,
+ inter_op_parallelism_threads=5,
+ intra_op_parallelism_threads=10,
+ data_format="channels_first"
+ )
+
+ flags_core.set_defaults(**defaults)
+ flags_core.parse_flags()
+
+ for key, value in defaults.items():
+ assert flags.FLAGS.get_flag_value(name=key, default=None) == value
+
+ def test_benchmark_setting(self):
+ defaults = dict(
+ hooks=["LoggingMetricHook"],
+ benchmark_log_dir="/tmp/12345",
+ gcp_project="project_abc",
+ )
+
+ flags_core.set_defaults(**defaults)
+ flags_core.parse_flags()
+
+ for key, value in defaults.items():
+ assert flags.FLAGS.get_flag_value(name=key, default=None) == value
+
+ def test_booleans(self):
+ """Test to ensure boolean flags trigger as expected.
+ """
+
+ flags_core.parse_flags([__file__, "--use_synthetic_data"])
+
+ assert flags.FLAGS.use_synthetic_data
+
+ def test_parse_dtype_info(self):
+ flags_core.parse_flags([__file__, "--dtype", "fp16"])
+ self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float16)
+ self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
+ default_for_fp16=2), 2)
+
+ flags_core.parse_flags(
+ [__file__, "--dtype", "fp16", "--loss_scale", "5"])
+ self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
+ default_for_fp16=2), 5)
+
+ flags_core.parse_flags(
+ [__file__, "--dtype", "fp16", "--loss_scale", "dynamic"])
+ self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
+ default_for_fp16=2), "dynamic")
+
+ flags_core.parse_flags([__file__, "--dtype", "fp32"])
+ self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float32)
+ self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
+ default_for_fp16=2), 1)
+
+ flags_core.parse_flags([__file__, "--dtype", "fp32", "--loss_scale", "5"])
+ self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
+ default_for_fp16=2), 5)
+
+
+ with self.assertRaises(SystemExit):
+ flags_core.parse_flags([__file__, "--dtype", "int8"])
+
+ with self.assertRaises(SystemExit):
+ flags_core.parse_flags([__file__, "--dtype", "fp16",
+ "--loss_scale", "abc"])
+
+ def test_get_nondefault_flags_as_str(self):
+ defaults = dict(
+ clean=True,
+ data_dir="abc",
+ hooks=["LoggingTensorHook"],
+ stop_threshold=1.5,
+ use_synthetic_data=False
+ )
+ flags_core.set_defaults(**defaults)
+ flags_core.parse_flags()
+
+ expected_flags = ""
+ self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)
+
+ flags.FLAGS.clean = False
+ expected_flags += "--noclean"
+ self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)
+
+ flags.FLAGS.data_dir = "xyz"
+ expected_flags += " --data_dir=xyz"
+ self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)
+
+ flags.FLAGS.hooks = ["aaa", "bbb", "ccc"]
+ expected_flags += " --hooks=aaa,bbb,ccc"
+ self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)
+
+ flags.FLAGS.stop_threshold = 3.
+ expected_flags += " --stop_threshold=3.0"
+ self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)
+
+ flags.FLAGS.use_synthetic_data = True
+ expected_flags += " --use_synthetic_data"
+ self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)
+
+    # Assert that explicitly setting a flag to its default value does not cause it
+ # to appear in the string
+ flags.FLAGS.use_synthetic_data = False
+ expected_flags = expected_flags[:-len(" --use_synthetic_data")]
+ self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/models/official/utils/flags/guidelines.md b/models/official/utils/flags/guidelines.md
new file mode 100644
index 0000000000000000000000000000000000000000..db963aabebccad8614a1b59ea7ff9b828bcee3b4
--- /dev/null
+++ b/models/official/utils/flags/guidelines.md
@@ -0,0 +1,65 @@
+# Using flags in official models
+
+1. **All common flags must be incorporated in the models.**
+
+   Common flags (e.g. batch_size, model_dir) are provided by various flag definition functions,
+   and channeled through `official.utils.flags.core`. For instance, to define common supervised
+   learning parameters one could use the following code:
+
+   ```python
+ from absl import app as absl_app
+ from absl import flags
+
+ from official.utils.flags import core as flags_core
+
+
+ def define_flags():
+ flags_core.define_base()
+     flags.adopt_module_key_flags(flags_core)
+
+
+ def main(_):
+ flags_obj = flags.FLAGS
+ print(flags_obj)
+
+
+   if __name__ == "__main__":
+ absl_app.run(main)
+ ```
+2. **Validate flag values.**
+
+ See the [Validators](#validators) section for implementation details.
+
+   Validators in the official model repo should not access the file system (for example, to verify
+   that files exist), due to the strict ordering requirements of flag parsing.
+
+3. **Flag values should not be mutated.**
+
+ Instead of mutating flag values, use getter functions to return the desired values. An example
+   getter function is the `get_tf_dtype` function below:
+
+ ```
+ # Map string to TensorFlow dtype
+ DTYPE_MAP = {
+ "fp16": tf.float16,
+ "fp32": tf.float32,
+ }
+
+ def get_tf_dtype(flags_obj):
+ if getattr(flags_obj, "fp16_implementation", None) == "graph_rewrite":
+ # If the graph_rewrite is used, we build the graph with fp32, and let the
+ # graph rewrite change ops to fp16.
+ return tf.float32
+ return DTYPE_MAP[flags_obj.dtype]
+
+
+ def main(_):
+     flags_obj = flags.FLAGS
+
+ # Do not mutate flags_obj
+ # if flags_obj.fp16_implementation == "graph_rewrite":
+ # flags_obj.dtype = "float32" # Don't do this
+
+ print(get_tf_dtype(flags_obj))
+ ...
+ ```
\ No newline at end of file
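For reference (not part of the patch or of guidelines.md): a minimal validator sketch in the spirit of point 2 above; the flag name and constraint are made up for the example.

```python
# Illustrative sketch only -- not part of the patch. "batch_size" is an example.
from absl import flags

flags.DEFINE_integer("batch_size", 32, "Per-replica batch size.")


@flags.validator("batch_size", message="batch_size must be a positive integer.")
def _check_batch_size(value):
  # Runs at parse time; returning False rejects the supplied value.
  return value > 0
```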
diff --git a/models/official/utils/hyperparams_flags.py b/models/official/utils/hyperparams_flags.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b8150677e43b68a68b9234dd852f6df894ea849
--- /dev/null
+++ b/models/official/utils/hyperparams_flags.py
@@ -0,0 +1,128 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Common flags for importing hyperparameters."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+from absl import flags
+from official.utils.flags import core as flags_core
+
+FLAGS = flags.FLAGS
+
+
+def define_gin_flags():
+ """Define common gin configurable flags."""
+ flags.DEFINE_multi_string('gin_file', None,
+ 'List of paths to the config files.')
+ flags.DEFINE_multi_string(
+ 'gin_param', None, 'Newline separated list of Gin parameter bindings.')
+
+
+def define_common_hparams_flags():
+ """Define the common flags across models."""
+
+ flags.DEFINE_string(
+ 'model_dir',
+ default=None,
+      help=('The directory where the model and training/evaluation summaries '
+ 'are stored.'))
+
+ flags.DEFINE_integer(
+ 'train_batch_size', default=None, help='Batch size for training.')
+
+ flags.DEFINE_integer(
+ 'eval_batch_size', default=None, help='Batch size for evaluation.')
+
+ flags.DEFINE_string(
+ 'precision',
+ default=None,
+ help=('Precision to use; one of: {bfloat16, float32}'))
+
+ flags.DEFINE_string(
+ 'config_file',
+ default=None,
+ help=('A YAML file which specifies overrides. Note that this file can be '
+ 'used as an override template to override the default parameters '
+ 'specified in Python. If the same parameter is specified in both '
+ '`--config_file` and `--params_override`, the one in '
+           '`--params_override` takes precedence.'))
+
+ flags.DEFINE_string(
+ 'params_override',
+ default=None,
+ help=('a YAML/JSON string or a YAML file which specifies additional '
+ 'overrides over the default parameters and those specified in '
+ '`--config_file`. Note that this is supposed to be used only to '
+ 'override the model parameters, but not the parameters like TPU '
+ 'specific flags. One canonical use case of `--config_file` and '
+ '`--params_override` is users first define a template config file '
+ 'using `--config_file`, then use `--params_override` to adjust the '
+ 'minimal set of tuning parameters, for example setting up different'
+ ' `train_batch_size`. '
+ 'The final override order of parameters: default_model_params --> '
+ 'params from config_file --> params in params_override.'
+ 'See also the help message of `--config_file`.'))
+ flags.DEFINE_integer('save_checkpoint_freq', None,
+ 'Number of steps to save checkpoint.')
+
+
+def initialize_common_flags():
+ """Define the common flags across models."""
+ define_common_hparams_flags()
+
+ flags_core.define_device(tpu=True)
+ flags_core.define_base(
+ num_gpu=True, model_dir=False, data_dir=False, batch_size=False)
+ flags_core.define_distribution(worker_hosts=True, task_index=True)
+ flags_core.define_performance(all_reduce_alg=True, num_packs=True)
+
+ # Reset the default value of num_gpus to zero.
+ FLAGS.num_gpus = 0
+
+ flags.DEFINE_string(
+      'strategy_type', 'mirrored', 'Type of distribution strategy. '
+      'One of mirrored, tpu and multiworker.')
+
+
+def strategy_flags_dict():
+ """Returns TPU and/or GPU related flags in a dictionary."""
+ return {
+ 'distribution_strategy': FLAGS.strategy_type,
+ # TPUStrategy related flags.
+ 'tpu': FLAGS.tpu,
+ # MultiWorkerMirroredStrategy related flags.
+ 'all_reduce_alg': FLAGS.all_reduce_alg,
+ 'worker_hosts': FLAGS.worker_hosts,
+ 'task_index': FLAGS.task_index,
+ # MirroredStrategy and OneDeviceStrategy
+ 'num_gpus': FLAGS.num_gpus,
+ 'num_packs': FLAGS.num_packs,
+ }
+
+
+def hparam_flags_dict():
+ """Returns model params related flags in a dictionary."""
+ return {
+ 'data_dir': FLAGS.data_dir,
+ 'model_dir': FLAGS.model_dir,
+ 'train_batch_size': FLAGS.train_batch_size,
+ 'eval_batch_size': FLAGS.eval_batch_size,
+ 'precision': FLAGS.precision,
+ 'config_file': FLAGS.config_file,
+ 'params_override': FLAGS.params_override,
+ }
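A usage sketch (not part of the patch) of the common hyperparameter flags and the strategy dict helper. Note that `hparam_flags_dict()` additionally reads `--data_dir`, which `initialize_common_flags()` does not define, so the model is expected to define that flag separately.

```python
# Sketch only -- not part of the patch.
from absl import app

from official.utils import hyperparams_flags


def main(_):
  # Bundles the distribution-strategy-related flag values for downstream code.
  print(hyperparams_flags.strategy_flags_dict())


if __name__ == "__main__":
  hyperparams_flags.initialize_common_flags()
  app.run(main)
```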
diff --git a/models/official/utils/misc/__init__.py b/models/official/utils/misc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/utils/misc/callstack_sampler.py b/models/official/utils/misc/callstack_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..984f133e9c68a73569717bff47154110c718e3ce
--- /dev/null
+++ b/models/official/utils/misc/callstack_sampler.py
@@ -0,0 +1,62 @@
+"""A simple Python callstack sampler."""
+
+import contextlib
+import datetime
+import signal
+import traceback
+
+
+class CallstackSampler(object):
+ """A simple signal-based Python callstack sampler.
+ """
+
+ def __init__(self, interval=None):
+ self.stacks = []
+ self.interval = 0.001 if interval is None else interval
+
+ def _sample(self, signum, frame):
+ """Samples the current stack."""
+ del signum
+ stack = traceback.extract_stack(frame)
+ formatted_stack = []
+ formatted_stack.append(datetime.datetime.utcnow())
+ for filename, lineno, function_name, text in stack:
+ formatted_frame = '{}:{}({})({})'.format(filename, lineno, function_name,
+ text)
+ formatted_stack.append(formatted_frame)
+ self.stacks.append(formatted_stack)
+ signal.setitimer(signal.ITIMER_VIRTUAL, self.interval, 0)
+
+ @contextlib.contextmanager
+ def profile(self):
+ signal.signal(signal.SIGVTALRM, self._sample)
+ signal.setitimer(signal.ITIMER_VIRTUAL, self.interval, 0)
+ try:
+ yield
+ finally:
+ signal.setitimer(signal.ITIMER_VIRTUAL, 0)
+
+ def save(self, fname):
+ with open(fname, 'w') as f:
+ for s in self.stacks:
+ for l in s:
+ f.write('%s\n' % l)
+ f.write('\n')
+
+
+@contextlib.contextmanager
+def callstack_sampling(filename, interval=None):
+ """Periodically samples the Python callstack.
+
+ Args:
+    filename: path of the file to which the sampled call stacks are written.
+ interval: the sampling interval, in seconds. Defaults to 0.001.
+
+ Yields:
+ nothing
+ """
+ sampler = CallstackSampler(interval=interval)
+ with sampler.profile():
+ yield
+ sampler.save(filename)
+
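A minimal usage sketch (not part of the patch). The sampler relies on SIGVTALRM / ITIMER_VIRTUAL, so it is Unix-only and should be entered from the main thread; the output path is arbitrary.

```python
# Sketch only -- not part of the patch. Unix-only (uses SIGVTALRM).
from official.utils.misc.callstack_sampler import callstack_sampling


def busy_work(n=2000000):
  total = 0
  for i in range(n):
    total += i * i
  return total


# Samples the Python call stack every 5 ms of CPU time while the block runs,
# then writes the collected stacks to the given file.
with callstack_sampling("/tmp/callstacks.txt", interval=0.005):
  busy_work()
```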
diff --git a/models/official/utils/misc/distribution_utils.py b/models/official/utils/misc/distribution_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4823a9b1e6f5cb8d1ff4d7d86340d8656934a6e
--- /dev/null
+++ b/models/official/utils/misc/distribution_utils.py
@@ -0,0 +1,205 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Helper functions for running models in a distributed setting."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+import random
+import string
+
+from absl import logging
+import tensorflow.compat.v2 as tf
+
+from official.utils.misc import tpu_lib
+
+
+def _collective_communication(all_reduce_alg):
+ """Return a CollectiveCommunication based on all_reduce_alg.
+
+ Args:
+ all_reduce_alg: a string specifying which collective communication to pick,
+ or None.
+
+ Returns:
+ tf.distribute.experimental.CollectiveCommunication object
+
+ Raises:
+ ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"]
+ """
+ collective_communication_options = {
+ None: tf.distribute.experimental.CollectiveCommunication.AUTO,
+ "ring": tf.distribute.experimental.CollectiveCommunication.RING,
+ "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
+ }
+ if all_reduce_alg not in collective_communication_options:
+ raise ValueError(
+ "When used with `multi_worker_mirrored`, valid values for "
+ "all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
+ all_reduce_alg))
+ return collective_communication_options[all_reduce_alg]
+
+
+def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
+ """Return a CrossDeviceOps based on all_reduce_alg and num_packs.
+
+ Args:
+ all_reduce_alg: a string specifying which cross device op to pick, or None.
+ num_packs: an integer specifying number of packs for the cross device op.
+
+ Returns:
+ tf.distribute.CrossDeviceOps object or None.
+
+ Raises:
+ ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
+ """
+ if all_reduce_alg is None:
+ return None
+ mirrored_all_reduce_options = {
+ "nccl": tf.distribute.NcclAllReduce,
+ "hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce
+ }
+ if all_reduce_alg not in mirrored_all_reduce_options:
+ raise ValueError(
+ "When used with `mirrored`, valid values for all_reduce_alg are "
+ "[`nccl`, `hierarchical_copy`]. Supplied value: {}".format(
+ all_reduce_alg))
+ cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
+ return cross_device_ops_class(num_packs=num_packs)
+
+
+def get_distribution_strategy(distribution_strategy="mirrored",
+ num_gpus=0,
+ all_reduce_alg=None,
+ num_packs=1,
+ tpu_address=None):
+ """Return a DistributionStrategy for running the model.
+
+ Args:
+ distribution_strategy: a string specifying which distribution strategy to
+ use. Accepted values are "off", "one_device", "mirrored",
+ "parameter_server", "multi_worker_mirrored", and "tpu" -- case insensitive.
+ "off" means not to use Distribution Strategy; "tpu" means to use
+ TPUStrategy using `tpu_address`.
+ num_gpus: Number of GPUs to run this model.
+ all_reduce_alg: Optional. Specifies which algorithm to use when performing
+ all-reduce. For `MirroredStrategy`, valid values are "nccl" and
+ "hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
+ "ring" and "nccl". If None, DistributionStrategy will choose based on
+ device topology.
+ num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
+ or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
+ tpu_address: Optional. String that represents TPU to connect to. Must not
+ be None if `distribution_strategy` is set to `tpu`.
+ Returns:
+    tf.distribute.DistributionStrategy object.
+ Raises:
+ ValueError: if `distribution_strategy` is "off" or "one_device" and
+ `num_gpus` is larger than 1; or `num_gpus` is negative or if
+ `distribution_strategy` is `tpu` but `tpu_address` is not specified.
+ """
+ if num_gpus < 0:
+ raise ValueError("`num_gpus` can not be negative.")
+
+ distribution_strategy = distribution_strategy.lower()
+ if distribution_strategy == "off":
+ if num_gpus > 1:
+ raise ValueError(
+ "When {} GPUs are specified, distribution_strategy "
+ "flag cannot be set to `off`.".format(num_gpus))
+ return None
+
+ if distribution_strategy == "tpu":
+ # When tpu_address is an empty string, we communicate with local TPUs.
+ cluster_resolver = tpu_lib.tpu_initialize(tpu_address)
+ return tf.distribute.experimental.TPUStrategy(cluster_resolver)
+
+ if distribution_strategy == "multi_worker_mirrored":
+ return tf.distribute.experimental.MultiWorkerMirroredStrategy(
+ communication=_collective_communication(all_reduce_alg))
+
+ if distribution_strategy == "one_device":
+ if num_gpus == 0:
+ return tf.distribute.OneDeviceStrategy("device:CPU:0")
+ if num_gpus > 1:
+ raise ValueError("`OneDeviceStrategy` can not be used for more than "
+ "one device.")
+ return tf.distribute.OneDeviceStrategy("device:GPU:0")
+
+ if distribution_strategy == "mirrored":
+ if num_gpus == 0:
+ devices = ["device:CPU:0"]
+ else:
+ devices = ["device:GPU:%d" % i for i in range(num_gpus)]
+ return tf.distribute.MirroredStrategy(
+ devices=devices,
+ cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))
+
+ if distribution_strategy == "parameter_server":
+ return tf.distribute.experimental.ParameterServerStrategy()
+
+ raise ValueError(
+ "Unrecognized Distribution Strategy: %r" % distribution_strategy)
+
+
+def configure_cluster(worker_hosts=None, task_index=-1):
+ """Set multi-worker cluster spec in TF_CONFIG environment variable.
+
+ Args:
+    worker_hosts: comma-separated list of worker ip:port pairs.
+    task_index: index of this task among the workers; must be specified when
+      there is more than one worker.
+
+ Returns:
+ Number of workers in the cluster.
+ """
+ tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
+ if tf_config:
+ num_workers = (len(tf_config["cluster"].get("chief", [])) +
+ len(tf_config["cluster"].get("worker", [])))
+ elif worker_hosts:
+ workers = worker_hosts.split(",")
+ num_workers = len(workers)
+ if num_workers > 1 and task_index < 0:
+ raise ValueError("Must specify task_index when number of workers > 1")
+ task_index = 0 if num_workers == 1 else task_index
+ os.environ["TF_CONFIG"] = json.dumps({
+ "cluster": {
+ "worker": workers
+ },
+ "task": {"type": "worker", "index": task_index}
+ })
+ else:
+ num_workers = 1
+ return num_workers
+
+
+def get_strategy_scope(strategy):
+ if strategy:
+ strategy_scope = strategy.scope()
+ else:
+ strategy_scope = DummyContextManager()
+
+ return strategy_scope
+
+
+class DummyContextManager(object):
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, *args):
+ pass
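A hedged sketch (not part of the patch) of the typical call pattern: resolve a strategy from flag-like values, then build and compile the model under `get_strategy_scope`.

```python
# Sketch only -- not part of the patch.
import tensorflow as tf

from official.utils.misc import distribution_utils

# With num_gpus=0, "mirrored" falls back to a single CPU device.
strategy = distribution_utils.get_distribution_strategy(
    distribution_strategy="mirrored", num_gpus=0)

with distribution_utils.get_strategy_scope(strategy):
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
  model.compile(optimizer="sgd", loss="mse")
```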
diff --git a/models/official/utils/misc/distribution_utils_test.py b/models/official/utils/misc/distribution_utils_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fd7bff09daaf2f5c85af2a0e7b7efbd00dc42c1
--- /dev/null
+++ b/models/official/utils/misc/distribution_utils_test.py
@@ -0,0 +1,49 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+""" Tests for distribution util functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.compat.v2 as tf
+
+from official.utils.misc import distribution_utils
+
+
+class GetDistributionStrategyTest(tf.test.TestCase):
+ """Tests for get_distribution_strategy."""
+ def test_one_device_strategy_cpu(self):
+ ds = distribution_utils.get_distribution_strategy(num_gpus=0)
+ self.assertEquals(ds.num_replicas_in_sync, 1)
+ self.assertEquals(len(ds.extended.worker_devices), 1)
+ self.assertIn('CPU', ds.extended.worker_devices[0])
+
+ def test_one_device_strategy_gpu(self):
+ ds = distribution_utils.get_distribution_strategy(num_gpus=1)
+ self.assertEquals(ds.num_replicas_in_sync, 1)
+ self.assertEquals(len(ds.extended.worker_devices), 1)
+ self.assertIn('GPU', ds.extended.worker_devices[0])
+
+ def test_mirrored_strategy(self):
+ ds = distribution_utils.get_distribution_strategy(num_gpus=5)
+ self.assertEquals(ds.num_replicas_in_sync, 5)
+ self.assertEquals(len(ds.extended.worker_devices), 5)
+ for device in ds.extended.worker_devices:
+ self.assertIn('GPU', device)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/utils/misc/keras_utils.py b/models/official/utils/misc/keras_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cca51f1d24701802b0fd7cfc62a84306eedded2
--- /dev/null
+++ b/models/official/utils/misc/keras_utils.py
@@ -0,0 +1,199 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Helper functions for the Keras implementations of models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import multiprocessing
+import os
+import time
+
+from absl import logging
+import tensorflow as tf
+
+
+class BatchTimestamp(object):
+ """A structure to store batch time stamp."""
+
+ def __init__(self, batch_index, timestamp):
+ self.batch_index = batch_index
+ self.timestamp = timestamp
+
+ def __repr__(self):
+ return "'BatchTimestamp'".format(
+ self.batch_index, self.timestamp)
+
+
+class TimeHistory(tf.keras.callbacks.Callback):
+ """Callback for Keras models."""
+
+ def __init__(self, batch_size, log_steps, initial_step=0, logdir=None):
+ """Callback for logging performance.
+
+ Args:
+ batch_size: Total batch size.
+ log_steps: Interval of steps between logging of batch level stats.
+ initial_step: Optional, initial step.
+ logdir: Optional directory to write TensorBoard summaries.
+ """
+ # TODO(wcromar): remove this parameter and rely on `logs` parameter of
+ # on_train_batch_end()
+ self.batch_size = batch_size
+ super(TimeHistory, self).__init__()
+ self.log_steps = log_steps
+ self.last_log_step = initial_step
+ self.steps_before_epoch = initial_step
+ self.steps_in_epoch = 0
+ self.start_time = None
+
+ if logdir:
+ self.summary_writer = tf.summary.create_file_writer(logdir)
+ else:
+ self.summary_writer = None
+
+ # Logs start of step 1 then end of each step based on log_steps interval.
+ self.timestamp_log = []
+
+ # Records the time each epoch takes to run from start to finish of epoch.
+ self.epoch_runtime_log = []
+
+ @property
+ def global_steps(self):
+ """The current 1-indexed global step."""
+ return self.steps_before_epoch + self.steps_in_epoch
+
+ @property
+ def average_steps_per_second(self):
+ """The average training steps per second across all epochs."""
+ return self.global_steps / sum(self.epoch_runtime_log)
+
+ @property
+ def average_examples_per_second(self):
+ """The average number of training examples per second across all epochs."""
+ return self.average_steps_per_second * self.batch_size
+
+ def get_examples_per_sec(self, warmup=1):
+ """Calculates examples/sec through timestamp_log and skip warmup period."""
+    # First entry in timestamp_log is the start of step 1. The rest of the
+ # entries are the end of each step recorded.
+ time_log = self.timestamp_log
+ seconds = time_log[-1].timestamp - time_log[warmup].timestamp
+ steps = time_log[-1].batch_index - time_log[warmup].batch_index
+ return self.batch_size * steps / seconds
+
+ def get_startup_time(self, start_time_sec):
+ return self.timestamp_log[0].timestamp - start_time_sec
+
+ def on_train_end(self, logs=None):
+ self.train_finish_time = time.time()
+
+ if self.summary_writer:
+ self.summary_writer.flush()
+
+ def on_epoch_begin(self, epoch, logs=None):
+ self.epoch_start = time.time()
+
+ def on_batch_begin(self, batch, logs=None):
+ if not self.start_time:
+ self.start_time = time.time()
+
+ # Record the timestamp of the first global step
+ if not self.timestamp_log:
+ self.timestamp_log.append(BatchTimestamp(self.global_steps,
+ self.start_time))
+
+ def on_batch_end(self, batch, logs=None):
+ """Records elapse time of the batch and calculates examples per second."""
+ self.steps_in_epoch = batch + 1
+ steps_since_last_log = self.global_steps - self.last_log_step
+ if steps_since_last_log >= self.log_steps:
+ now = time.time()
+ elapsed_time = now - self.start_time
+ steps_per_second = steps_since_last_log / elapsed_time
+ examples_per_second = steps_per_second * self.batch_size
+
+ self.timestamp_log.append(BatchTimestamp(self.global_steps, now))
+ logging.info(
+ 'TimeHistory: %.2f seconds, %.2f examples/second between steps %d '
+ 'and %d', elapsed_time, examples_per_second, self.last_log_step,
+ self.global_steps)
+
+ if self.summary_writer:
+ with self.summary_writer.as_default():
+ tf.summary.scalar('steps_per_second', steps_per_second,
+ self.global_steps)
+ tf.summary.scalar('examples_per_second', examples_per_second,
+ self.global_steps)
+
+ self.last_log_step = self.global_steps
+ self.start_time = None
+
+ def on_epoch_end(self, epoch, logs=None):
+ epoch_run_time = time.time() - self.epoch_start
+ self.epoch_runtime_log.append(epoch_run_time)
+
+ self.steps_before_epoch += self.steps_in_epoch
+ self.steps_in_epoch = 0
+
+
+class SimpleCheckpoint(tf.keras.callbacks.Callback):
+ """Keras callback to save tf.train.Checkpoints."""
+
+ def __init__(self, checkpoint_manager):
+ super(SimpleCheckpoint, self).__init__()
+ self.checkpoint_manager = checkpoint_manager
+
+ def on_epoch_end(self, epoch, logs=None):
+ step_counter = self.checkpoint_manager._step_counter.numpy() # pylint: disable=protected-access
+ self.checkpoint_manager.save(checkpoint_number=step_counter)
+
+
+def set_session_config(enable_xla=False):
+ """Sets the session config."""
+ if enable_xla:
+ tf.config.optimizer.set_jit(True)
+
+# TODO(hongkuny): remove set_config_v2 globally.
+set_config_v2 = set_session_config
+
+
+def set_gpu_thread_mode_and_count(gpu_thread_mode,
+ datasets_num_private_threads,
+ num_gpus, per_gpu_thread_count):
+ """Set GPU thread mode and count, and adjust dataset threads count."""
+ cpu_count = multiprocessing.cpu_count()
+ logging.info('Logical CPU cores: %s', cpu_count)
+
+ # Allocate private thread pool for each GPU to schedule and launch kernels
+ per_gpu_thread_count = per_gpu_thread_count or 2
+ os.environ['TF_GPU_THREAD_MODE'] = gpu_thread_mode
+ os.environ['TF_GPU_THREAD_COUNT'] = str(per_gpu_thread_count)
+ logging.info('TF_GPU_THREAD_COUNT: %s',
+ os.environ['TF_GPU_THREAD_COUNT'])
+ logging.info('TF_GPU_THREAD_MODE: %s',
+ os.environ['TF_GPU_THREAD_MODE'])
+
+ # Limit data preprocessing threadpool to CPU cores minus number of total GPU
+ # private threads and memory copy threads.
+ total_gpu_thread_count = per_gpu_thread_count * num_gpus
+ num_runtime_threads = num_gpus
+ if not datasets_num_private_threads:
+ datasets_num_private_threads = min(
+ cpu_count - total_gpu_thread_count - num_runtime_threads,
+ num_gpus * 8)
+ logging.info('Set datasets_num_private_threads to %s',
+ datasets_num_private_threads)
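A small sketch (not part of the patch) of wiring `TimeHistory` into `model.fit`; the toy data and model are placeholders.

```python
# Sketch only -- not part of the patch. Toy data and model for illustration.
import numpy as np
import tensorflow as tf

from official.utils.misc.keras_utils import TimeHistory

x = np.random.rand(256, 4).astype("float32")
y = np.random.rand(256, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="sgd", loss="mse")

# Logs examples/second every 4 steps; summaries also go to TensorBoard when a
# logdir is supplied.
time_callback = TimeHistory(batch_size=32, log_steps=4, logdir=None)
model.fit(x, y, batch_size=32, epochs=2, callbacks=[time_callback])

print("average examples/sec:", time_callback.average_examples_per_second)
```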
diff --git a/models/official/utils/misc/model_helpers.py b/models/official/utils/misc/model_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a44e50ac46162821dcbfacc55b5b1e5c30eba8f
--- /dev/null
+++ b/models/official/utils/misc/model_helpers.py
@@ -0,0 +1,95 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Miscellaneous functions that can be called by models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numbers
+
+from absl import logging
+import tensorflow as tf
+
+from tensorflow.python.util import nest
+# pylint:disable=logging-format-interpolation
+
+
+def past_stop_threshold(stop_threshold, eval_metric):
+ """Return a boolean representing whether a model should be stopped.
+
+ Args:
+ stop_threshold: float, the threshold above which a model should stop
+ training.
+ eval_metric: float, the current value of the relevant metric to check.
+
+ Returns:
+ True if training should stop, False otherwise.
+
+ Raises:
+ ValueError: if either stop_threshold or eval_metric is not a number
+ """
+ if stop_threshold is None:
+ return False
+
+ if not isinstance(stop_threshold, numbers.Number):
+ raise ValueError("Threshold for checking stop conditions must be a number.")
+ if not isinstance(eval_metric, numbers.Number):
+ raise ValueError("Eval metric being checked against stop conditions "
+ "must be a number.")
+
+ if eval_metric >= stop_threshold:
+ logging.info("Stop threshold of {} was passed with metric value {}.".format(
+ stop_threshold, eval_metric))
+ return True
+
+ return False
+
+
+def generate_synthetic_data(
+ input_shape, input_value=0, input_dtype=None, label_shape=None,
+ label_value=0, label_dtype=None):
+ """Create a repeating dataset with constant values.
+
+ Args:
+ input_shape: a tf.TensorShape object or nested tf.TensorShapes. The shape of
+ the input data.
+ input_value: Value of each input element.
+ input_dtype: Input dtype. If None, will be inferred by the input value.
+ label_shape: a tf.TensorShape object or nested tf.TensorShapes. The shape of
+ the label data.
+    label_value: Value of each label element.
+    label_dtype: Label dtype. If None, will be inferred by the label value.
+
+ Returns:
+ Dataset of tensors or tuples of tensors (if label_shape is set).
+ """
+ # TODO(kathywu): Replace with SyntheticDataset once it is in contrib.
+ element = input_element = nest.map_structure(
+ lambda s: tf.constant(input_value, input_dtype, s), input_shape)
+
+ if label_shape:
+ label_element = nest.map_structure(
+ lambda s: tf.constant(label_value, label_dtype, s), label_shape)
+ element = (input_element, label_element)
+
+ return tf.data.Dataset.from_tensors(element).repeat()
+
+
+def apply_clean(flags_obj):
+ if flags_obj.clean and tf.io.gfile.exists(flags_obj.model_dir):
+ logging.info("--clean flag set. Removing existing model dir:"
+ " {}".format(flags_obj.model_dir))
+ tf.io.gfile.rmtree(flags_obj.model_dir)
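A short sketch (not part of the patch) exercising two of the helpers above; shapes and values are arbitrary.

```python
# Sketch only -- not part of the patch.
import tensorflow as tf

from official.utils.misc import model_helpers

# Constant-valued (features, label) pairs that repeat forever -- handy as a
# drop-in replacement for a real input pipeline when profiling.
dataset = model_helpers.generate_synthetic_data(
    input_shape=tf.TensorShape([8]), input_value=1.0, input_dtype=tf.float32,
    label_shape=tf.TensorShape([]), label_value=0, label_dtype=tf.int32)
for features, label in dataset.take(2):
  print(features.shape, label.numpy())

# True once the monitored metric reaches the threshold, so training can stop.
print(model_helpers.past_stop_threshold(stop_threshold=0.90, eval_metric=0.93))
```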
diff --git a/models/official/utils/misc/model_helpers_test.py b/models/official/utils/misc/model_helpers_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f2487e4223e7b46854db918114d2507fc891155
--- /dev/null
+++ b/models/official/utils/misc/model_helpers_test.py
@@ -0,0 +1,125 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Model Helper functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf # pylint: disable=g-bad-import-order
+
+from official.utils.misc import model_helpers
+
+
+class PastStopThresholdTest(tf.test.TestCase):
+ """Tests for past_stop_threshold."""
+
+ def setUp(self):
+ super(PastStopThresholdTest, self).setUp()
+ tf.compat.v1.disable_eager_execution()
+
+ def test_past_stop_threshold(self):
+ """Tests for normal operating conditions."""
+ self.assertTrue(model_helpers.past_stop_threshold(0.54, 1))
+ self.assertTrue(model_helpers.past_stop_threshold(54, 100))
+ self.assertFalse(model_helpers.past_stop_threshold(0.54, 0.1))
+ self.assertFalse(model_helpers.past_stop_threshold(-0.54, -1.5))
+ self.assertTrue(model_helpers.past_stop_threshold(-0.54, 0))
+ self.assertTrue(model_helpers.past_stop_threshold(0, 0))
+ self.assertTrue(model_helpers.past_stop_threshold(0.54, 0.54))
+
+ def test_past_stop_threshold_none_false(self):
+ """Tests that check None returns false."""
+ self.assertFalse(model_helpers.past_stop_threshold(None, -1.5))
+ self.assertFalse(model_helpers.past_stop_threshold(None, None))
+ self.assertFalse(model_helpers.past_stop_threshold(None, 1.5))
+ # Zero should be okay, though.
+ self.assertTrue(model_helpers.past_stop_threshold(0, 1.5))
+
+ def test_past_stop_threshold_not_number(self):
+ """Tests for error conditions."""
+ with self.assertRaises(ValueError):
+ model_helpers.past_stop_threshold("str", 1)
+
+ with self.assertRaises(ValueError):
+ model_helpers.past_stop_threshold("str", tf.constant(5))
+
+ with self.assertRaises(ValueError):
+ model_helpers.past_stop_threshold("str", "another")
+
+ with self.assertRaises(ValueError):
+ model_helpers.past_stop_threshold(0, None)
+
+ with self.assertRaises(ValueError):
+ model_helpers.past_stop_threshold(0.7, "str")
+
+ with self.assertRaises(ValueError):
+ model_helpers.past_stop_threshold(tf.constant(4), None)
+
+
+class SyntheticDataTest(tf.test.TestCase):
+ """Tests for generate_synthetic_data."""
+
+  def test_generate_synthetic_data(self):
+ input_element, label_element = tf.compat.v1.data.make_one_shot_iterator(
+ model_helpers.generate_synthetic_data(input_shape=tf.TensorShape([5]),
+ input_value=123,
+ input_dtype=tf.float32,
+ label_shape=tf.TensorShape([]),
+ label_value=456,
+ label_dtype=tf.int32)).get_next()
+
+ with self.session() as sess:
+ for n in range(5):
+ inp, lab = sess.run((input_element, label_element))
+ self.assertAllClose(inp, [123., 123., 123., 123., 123.])
+ self.assertEquals(lab, 456)
+
+ def test_generate_only_input_data(self):
+ d = model_helpers.generate_synthetic_data(
+ input_shape=tf.TensorShape([4]),
+ input_value=43.5,
+ input_dtype=tf.float32)
+
+ element = tf.compat.v1.data.make_one_shot_iterator(d).get_next()
+ self.assertFalse(isinstance(element, tuple))
+
+ with self.session() as sess:
+ inp = sess.run(element)
+ self.assertAllClose(inp, [43.5, 43.5, 43.5, 43.5])
+
+ def test_generate_nested_data(self):
+ d = model_helpers.generate_synthetic_data(
+ input_shape={'a': tf.TensorShape([2]),
+ 'b': {'c': tf.TensorShape([3]), 'd': tf.TensorShape([])}},
+ input_value=1.1)
+
+ element = tf.compat.v1.data.make_one_shot_iterator(d).get_next()
+ self.assertIn('a', element)
+ self.assertIn('b', element)
+ self.assertEquals(len(element['b']), 2)
+ self.assertIn('c', element['b'])
+ self.assertIn('d', element['b'])
+ self.assertNotIn('c', element)
+
+ with self.session() as sess:
+ inp = sess.run(element)
+ self.assertAllClose(inp['a'], [1.1, 1.1])
+ self.assertAllClose(inp['b']['c'], [1.1, 1.1, 1.1])
+ self.assertAllClose(inp['b']['d'], 1.1)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/utils/misc/tpu_lib.py b/models/official/utils/misc/tpu_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d4cddb1c6b015091ed2da57df49277e3008c252
--- /dev/null
+++ b/models/official/utils/misc/tpu_lib.py
@@ -0,0 +1,34 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Initializes TPU system for TF 2.0."""
+
+import tensorflow as tf
+
+
+def tpu_initialize(tpu_address):
+ """Initializes TPU for TF 2.0 training.
+
+ Args:
+ tpu_address: string, bns address of master TPU worker.
+
+ Returns:
+ A TPUClusterResolver.
+ """
+ cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
+ tpu=tpu_address)
+ if tpu_address not in ('', 'local'):
+ tf.config.experimental_connect_to_cluster(cluster_resolver)
+ tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
+ return cluster_resolver
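+
+# Typical usage sketch (assumes the TF 2.1-era API; the strategy class moved to
+# `tf.distribute.TPUStrategy` in later releases, and `build_model` is a
+# hypothetical user function):
+#
+#   resolver = tpu_initialize(FLAGS.tpu)
+#   strategy = tf.distribute.experimental.TPUStrategy(resolver)
+#   with strategy.scope():
+#     model = build_model()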
diff --git a/models/official/utils/registry.py b/models/official/utils/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aff59813f11b1085860faac8c62ca8ce9e0a1f1
--- /dev/null
+++ b/models/official/utils/registry.py
@@ -0,0 +1,98 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Registry utility."""
+
+
+def register(registered_collection, reg_key):
+ """Register decorated function or class to collection.
+
+ Register decorated function or class into registered_collection, in a
+ hierarchical order. For example, when reg_key="my_model/my_exp/my_config_0"
+ the decorated function or class is stored under
+ registered_collection["my_model"]["my_exp"]["my_config_0"].
+ This decorator is supposed to be used together with the lookup() function in
+ this file.
+
+ Args:
+ registered_collection: a dictionary. The decorated function or class will be
+ put into this collection.
+ reg_key: The key for retrieving the registered function or class. If reg_key
+ is a string, it can be hierarchical like my_model/my_exp/my_config_0
+ Returns:
+ A decorator function
+ Raises:
+ KeyError: when function or class to register already exists.
+ """
+ def decorator(fn_or_cls):
+ """Put fn_or_cls in the dictionary."""
+ if isinstance(reg_key, str):
+ hierarchy = reg_key.split("/")
+ collection = registered_collection
+ for h_idx, entry_name in enumerate(hierarchy[:-1]):
+ if entry_name not in collection:
+ collection[entry_name] = {}
+ collection = collection[entry_name]
+ if not isinstance(collection, dict):
+ raise KeyError(
+ "Collection path {} at position {} already registered as "
+ "a function or class.".format(entry_name, h_idx))
+ leaf_reg_key = hierarchy[-1]
+ else:
+ collection = registered_collection
+ leaf_reg_key = reg_key
+
+ if leaf_reg_key in collection:
+ raise KeyError("Function or class {} registered multiple times.".format(
+ leaf_reg_key))
+
+ collection[leaf_reg_key] = fn_or_cls
+ return fn_or_cls
+ return decorator
+
+
+def lookup(registered_collection, reg_key):
+ """Lookup and return decorated function or class in the collection.
+
+ Lookup decorated function or class in registered_collection, in a
+ hierarchical order. For example, when
+ reg_key="my_model/my_exp/my_config_0",
+ this function will return
+ registered_collection["my_model"]["my_exp"]["my_config_0"].
+
+ Args:
+ registered_collection: a dictionary. The decorated function or class will be
+ retrieved from this collection.
+ reg_key: The key for retrieving the registered function or class. If reg_key
+ is a string, it can be hierarchical like my_model/my_exp/my_config_0
+ Returns:
+ The registered function or class.
+ Raises:
+ LookupError: when reg_key cannot be found.
+ """
+ if isinstance(reg_key, str):
+ hierarchy = reg_key.split("/")
+ collection = registered_collection
+ for h_idx, entry_name in enumerate(hierarchy):
+ if entry_name not in collection:
+ raise LookupError(
+ "collection path {} at position {} never registered.".format(
+ entry_name, h_idx))
+ collection = collection[entry_name]
+ return collection
+ else:
+ if reg_key not in registered_collection:
+ raise LookupError("registration key {} never registered.".format(reg_key))
+ return registered_collection[reg_key]
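+
+# Minimal usage sketch (illustrative; `_REGISTRY` and `my_config` are
+# hypothetical and not part of this module):
+#
+#   _REGISTRY = {}
+#
+#   @register(_REGISTRY, 'my_model/my_exp/my_config_0')
+#   def my_config():
+#     return {'learning_rate': 0.1}
+#
+#   config_fn = lookup(_REGISTRY, 'my_model/my_exp/my_config_0')
+#   assert config_fn is my_config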
diff --git a/models/official/utils/registry_test.py b/models/official/utils/registry_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cb230c75891aaebb8306bb84a235e2d2ecd70e5
--- /dev/null
+++ b/models/official/utils/registry_test.py
@@ -0,0 +1,85 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for registry."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from official.utils import registry
+
+
+class RegistryTest(tf.test.TestCase):
+
+ def test_register(self):
+ collection = {}
+
+ @registry.register(collection, 'functions/func_0')
+ def func_test():
+ pass
+ self.assertEqual(
+ registry.lookup(collection, 'functions/func_0'), func_test)
+
+ @registry.register(collection, 'classes/cls_0')
+ class ClassRegistryKey:
+ pass
+ self.assertEqual(
+ registry.lookup(collection, 'classes/cls_0'), ClassRegistryKey)
+
+ @registry.register(collection, ClassRegistryKey)
+ class ClassRegistryValue:
+ pass
+ self.assertEqual(
+ registry.lookup(collection, ClassRegistryKey), ClassRegistryValue)
+
+ def test_register_hierarchy(self):
+ collection = {}
+
+ @registry.register(collection, 'functions/func_0')
+ def func_test0():
+ pass
+ @registry.register(collection, 'func_1')
+ def func_test1():
+ pass
+ @registry.register(collection, func_test1)
+ def func_test2():
+ pass
+ expected_collection = {
+ 'functions': {
+ 'func_0': func_test0,
+ },
+ 'func_1': func_test1,
+ func_test1: func_test2,
+ }
+ self.assertEqual(collection, expected_collection)
+
+ def test_register_error(self):
+ collection = {}
+
+ @registry.register(collection, 'functions/func_0')
+ def func_test0(): # pylint: disable=unused-variable
+ pass
+ with self.assertRaises(KeyError):
+ @registry.register(collection, 'functions/func_0/sub_func')
+ def func_test1(): # pylint: disable=unused-variable
+ pass
+ with self.assertRaises(LookupError):
+ registry.lookup(collection, 'non-exist')
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/utils/testing/__init__.py b/models/official/utils/testing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/utils/testing/integration.py b/models/official/utils/testing/integration.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4809a4815cd76c637e2b319352a1d15ab89b87b
--- /dev/null
+++ b/models/official/utils/testing/integration.py
@@ -0,0 +1,71 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Helper code to run complete models from within python.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+import sys
+import tempfile
+
+from absl import flags
+from absl.testing import flagsaver
+
+from official.utils.flags import core as flags_core
+
+
+@flagsaver.flagsaver
+def run_synthetic(main, tmp_root, extra_flags=None, synth=True, train_epochs=1,
+ epochs_between_evals=1):
+ """Performs a minimal run of a model.
+
+ This function is intended to test for syntax errors throughout a model. A
+ very limited run is performed using synthetic data.
+
+ Args:
+    main: The primary function used to exercise a code path. Generally this
+      function is "<module>.main(argv)".
+ tmp_root: Root path for the temp directory created by the test class.
+ extra_flags: Additional flags passed by the caller of this function.
+ synth: Use synthetic data.
+ train_epochs: Value of the --train_epochs flag.
+ epochs_between_evals: Value of the --epochs_between_evals flag.
+ """
+
+ extra_flags = [] if extra_flags is None else extra_flags
+
+ model_dir = tempfile.mkdtemp(dir=tmp_root)
+
+ args = [sys.argv[0], "--model_dir", model_dir] + extra_flags
+
+ if synth:
+ args.append("--use_synthetic_data")
+
+ if train_epochs is not None:
+ args.extend(["--train_epochs", str(train_epochs)])
+
+ if epochs_between_evals is not None:
+ args.extend(["--epochs_between_evals", str(epochs_between_evals)])
+
+ try:
+ flags_core.parse_flags(argv=args)
+ main(flags.FLAGS)
+ finally:
+ if os.path.exists(model_dir):
+ shutil.rmtree(model_dir)
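+
+# Example usage sketch from a model's smoke test (names are hypothetical; the
+# model module must define its flags via official.utils.flags.core before
+# parse_flags is called):
+#
+#   class MyModelTest(tf.test.TestCase):
+#
+#     def test_end_to_end(self):
+#       integration.run_synthetic(
+#           main=my_model.main, tmp_root=self.get_temp_dir(),
+#           extra_flags=['--batch_size', '8'])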
diff --git a/models/official/utils/testing/pylint.rcfile b/models/official/utils/testing/pylint.rcfile
new file mode 100644
index 0000000000000000000000000000000000000000..b872802a81187b63e82ead282dd38fad1d1b5ded
--- /dev/null
+++ b/models/official/utils/testing/pylint.rcfile
@@ -0,0 +1,168 @@
+[MESSAGES CONTROL]
+disable=R,W,bad-option-value,trailing-newlines,no-name-in-module
+
+[REPORTS]
+# Tells whether to display a full report or only the messages
+reports=no
+
+# Activate the evaluation score.
+score=no
+
+[BASIC]
+
+# Regular expression matching correct argument names
+argument-rgx=^[a-z][a-z0-9_]*$
+
+# Regular expression matching correct attribute names
+attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
+
+# Regular expression matching correct class attribute names
+class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
+
+# Regular expression matching correct class names
+class-rgx=^_?[A-Z][a-zA-Z0-9]*$
+
+# Regular expression matching correct constant names
+const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=10
+
+# Regular expression matching correct function names
+function-rgx=^(?:(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$
+
+# Good variable names which should always be accepted, separated by a comma
+good-names=main,_
+
+# Regular expression matching correct inline iteration names
+inlinevar-rgx=^[a-z][a-z0-9_]*$
+
+# Regular expression matching correct method names
+method-rgx=^(?:(?P__[a-z0-9_]+__|next)|(?P_{0,2}[A-Z][a-zA-Z0-9]*)|(?P_{0,2}[a-z][a-z0-9_]*)|(setUp|tearDown))$
+
+# Regular expression matching correct module names
+module-rgx=^(_?[a-z][a-z0-9_]*)|__init__|PRESUBMIT|PRESUBMIT_unittest$
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=(__.*__|main|.*ArgParser)
+
+# Naming hint for variable names
+variable-name-hint=[a-z_][a-z0-9_]{2,30}$
+
+# Regular expression matching correct variable names
+variable-rgx=^[a-z][a-z0-9_]*$
+
+[TYPECHECK]
+
+# List of module names for which member attributes should not be checked
+# (useful for modules/projects where namespaces are manipulated during runtime
+# and thus existing member attributes cannot be deduced by static analysis. It
+# supports qualified module names, as well as Unix pattern matching.
+ignored-modules=absl, absl.*, official, official.*, tensorflow, tensorflow.*, LazyLoader, google, google.cloud.*
+
+
+[CLASSES]
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,__new__,setUp
+
+# List of member names, which should be excluded from the protected access
+# warning.
+exclude-protected=_asdict,_fields,_replace,_source,_make
+
+# This is deprecated, because it is not used anymore.
+#ignore-iface-methods=
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls,class_
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=mcs
+
+
+[DESIGN]
+
+# Argument names that match this expression will be ignored. Default to name
+# with leading underscore
+ignored-argument-names=_.*
+
+# Maximum number of arguments for function / method
+max-args=5
+
+# Maximum number of attributes for a class (see R0902).
+max-attributes=7
+
+# Maximum number of branch for function / method body
+max-branches=12
+
+# Maximum number of locals for function / method body
+max-locals=15
+
+# Maximum number of parents for a class (see R0901).
+max-parents=7
+
+# Maximum number of public methods for a class (see R0904).
+max-public-methods=20
+
+# Maximum number of return / yield for function / method body
+max-returns=6
+
+# Maximum number of statements in function / method body
+max-statements=50
+
+# Minimum number of public methods for a class (see R0903).
+min-public-methods=2
+
+
+[EXCEPTIONS]
+
+# Exceptions that will emit a warning when being caught. Defaults to
+# "Exception"
+overgeneral-exceptions=StandardError,Exception,BaseException
+
+
+[FORMAT]
+
+# Number of spaces of indent required inside a hanging or continued line.
+indent-after-paren=4
+
+# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
+# tab).
+indent-string=' '
+
+# Maximum number of characters on a single line.
+max-line-length=80
+
+# Maximum number of lines in a module
+max-module-lines=99999
+
+# List of optional constructs for which whitespace checking is disabled
+no-space-check=
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=yes
+
+# Allow URLs and comment type annotations to exceed the max line length as neither can be easily
+# split across lines.
+ignore-long-lines=^\s*(?:(# )?<?https?://\S+>?$|# type:)
+
+
+[VARIABLES]
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid to define new builtins when possible.
+additional-builtins=
+
+# List of strings which can identify a callback function by name. A callback
+# name must start or end with one of those strings.
+callbacks=cb_,_cb
+
+# A regular expression matching the name of dummy variables (i.e. expectedly
+# not used).
+dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
diff --git a/models/official/utils/testing/scripts/builds_common.sh b/models/official/utils/testing/scripts/builds_common.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3cf08bb510d2a8ba0b06b1d38ccd1294b159ce15
--- /dev/null
+++ b/models/official/utils/testing/scripts/builds_common.sh
@@ -0,0 +1,64 @@
+#!/usr/bin/env bash
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Common Bash functions used by build scripts
+
+COLOR_NC='\033[0m'
+COLOR_BOLD='\033[1m'
+COLOR_LIGHT_GRAY='\033[0;37m'
+COLOR_GREEN='\033[0;32m'
+COLOR_RED='\033[0;31m'
+
+die() {
+ # Print a message and exit with code 1.
+ #
+  # Usage: die <error message>
+ # e.g., die "Something bad happened."
+
+ echo $@
+ exit 1
+}
+
+num_cpus() {
+ # Get the number of CPUs
+ N_CPUS=$(grep -c ^processor /proc/cpuinfo)
+ if [[ -z ${N_CPUS} ]]; then
+ die "ERROR: Unable to determine the number of CPUs"
+ fi
+
+ echo ${N_CPUS}
+}
+
+# List files changed (i.e., added, or revised) from
+# the common ancestor of HEAD and the latest master branch.
+# Usage: get_changed_files_from_master_branch
+get_changed_files_from_master_branch() {
+ ANCESTOR=$(git merge-base HEAD master origin/master)
+ git diff ${ANCESTOR} --diff-filter=d --name-only "$@"
+}
+
+# List python files changed that still exist,
+# i.e., not removed.
+# Usage: get_py_files_to_check [--incremental]
+get_py_files_to_check() {
+ if [[ "$1" == "--incremental" ]]; then
+ get_changed_files_from_master_branch -- '*.py'
+ elif [[ -z "$1" ]]; then
+ find official/ -name '*.py'
+ else
+ die "Found unsupported args: $@ for get_py_files_to_check."
+ fi
+}
diff --git a/models/official/utils/testing/scripts/ci_sanity.sh b/models/official/utils/testing/scripts/ci_sanity.sh
new file mode 100644
index 0000000000000000000000000000000000000000..97d6bc290eff327f340088b960f910af2afa626b
--- /dev/null
+++ b/models/official/utils/testing/scripts/ci_sanity.sh
@@ -0,0 +1,132 @@
+#!/bin/bash
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Sanity check script that runs tests and lint under local environment.
+# Make sure that tensorflow and pylint are installed.
+# usage: models >: ./official/utils/testing/scripts/ci_sanity.sh do_pylint --incremental
+set +x
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+source "${SCRIPT_DIR}/builds_common.sh"
+cd "$SCRIPT_DIR/../../../.."
+MODEL_ROOT="$(pwd)"
+
+export PYTHONPATH="$PYTHONPATH:${MODEL_ROOT}"
+
+# Run pylint
+do_pylint() {
+ # Usage: do_pylint [--incremental]
+ #
+ # Options:
+ # --incremental Performs check on only the python files changed in the
+ # last non-merge git commit.
+
+ # Use this list to whitelist pylint errors
+ ERROR_WHITELIST=""
+
+ echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\""
+
+ PYLINT_BIN="python3 -m pylint"
+
+ PYTHON_SRC_FILES=$(get_py_files_to_check $1)
+ if [[ -z ${PYTHON_SRC_FILES} ]]; then
+ echo "do_pylint found no Python files to check. Returning."
+ return 0
+ fi
+
+ PYLINTRC_FILE="official/utils/testing/pylint.rcfile"
+
+ if [[ ! -f "${PYLINTRC_FILE}" ]]; then
+ die "ERROR: Cannot find pylint rc file at ${PYLINTRC_FILE}"
+ fi
+
+ NUM_SRC_FILES=$(echo ${PYTHON_SRC_FILES} | wc -w)
+ NUM_CPUS=$(num_cpus)
+
+ echo "Running pylint on ${NUM_SRC_FILES} files with ${NUM_CPUS} "\
+ "parallel jobs..."
+ echo ""
+
+ PYLINT_START_TIME=$(date +'%s')
+ OUTPUT_FILE="$(mktemp)_pylint_output.log"
+ ERRORS_FILE="$(mktemp)_pylint_errors.log"
+ NONWL_ERRORS_FILE="$(mktemp)_pylint_nonwl_errors.log"
+
+ rm -rf ${OUTPUT_FILE}
+ rm -rf ${ERRORS_FILE}
+ rm -rf ${NONWL_ERRORS_FILE}
+ touch ${NONWL_ERRORS_FILE}
+
+ ${PYLINT_BIN} --rcfile="${PYLINTRC_FILE}" --output-format=parseable \
+ --jobs=${NUM_CPUS} ${PYTHON_SRC_FILES} > ${OUTPUT_FILE} 2>&1
+ PYLINT_END_TIME=$(date +'%s')
+
+ echo ""
+ echo "pylint took $((PYLINT_END_TIME - PYLINT_START_TIME)) s"
+ echo ""
+
+ # Report only what we care about
+ # Ref https://pylint.readthedocs.io/en/latest/technical_reference/features.html
+ # E: all errors
+ # W0311 bad-indentation
+ # W0312 mixed-indentation
+ # C0330 bad-continuation
+ # C0301 line-too-long
+ # C0326 bad-whitespace
+ # W0611 unused-import
+ # W0622 redefined-builtin
+ grep -E '(\[E|\[W0311|\[W0312|\[C0330|\[C0301|\[C0326|\[W0611|\[W0622)' ${OUTPUT_FILE} > ${ERRORS_FILE}
+
+ N_ERRORS=0
+ while read -r LINE; do
+ IS_WHITELISTED=0
+ for WL_REGEX in ${ERROR_WHITELIST}; do
+ if echo ${LINE} | grep -q "${WL_REGEX}"; then
+ echo "Found a whitelisted error:"
+ echo " ${LINE}"
+ IS_WHITELISTED=1
+ fi
+ done
+
+ if [[ ${IS_WHITELISTED} == "0" ]]; then
+ echo "${LINE}" >> ${NONWL_ERRORS_FILE}
+ echo "" >> ${NONWL_ERRORS_FILE}
+ ((N_ERRORS++))
+ fi
+ done <${ERRORS_FILE}
+
+ echo "Raw lint output file: ${OUTPUT_FILE}"
+
+ echo ""
+ if [[ ${N_ERRORS} != 0 ]]; then
+    echo "FAIL: Found ${N_ERRORS} non-whitelisted pylint errors:"
+ cat "${NONWL_ERRORS_FILE}"
+ return 1
+ else
+ echo "PASS: No non-whitelisted pylint errors were found."
+ return 0
+ fi
+}
+
+test_result=0
+
+TESTS="$@"
+
+for t in "${TESTS}"; do
+ ${t} || test_result=$?
+done
+
+exit "${test_result}"
diff --git a/models/official/utils/testing/scripts/presubmit.sh b/models/official/utils/testing/scripts/presubmit.sh
new file mode 100644
index 0000000000000000000000000000000000000000..954d96df7f8c5f95546fb642ce6f9597f935cb3c
--- /dev/null
+++ b/models/official/utils/testing/scripts/presubmit.sh
@@ -0,0 +1,73 @@
+#!/bin/bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Presubmit script that runs tests and lint under local environment.
+# Make sure that tensorflow and pylint are installed.
+# usage: models >: ./official/utils/testing/scripts/presubmit.sh
+# usage: models >: ./official/utils/testing/scripts/presubmit.sh lint py2_test py3_test
+set +x
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+cd "$SCRIPT_DIR/../../../.."
+MODEL_ROOT="$(pwd)"
+
+export PYTHONPATH="$PYTHONPATH:${MODEL_ROOT}"
+
+py_test() {
+ local PY_BINARY="$1"
+ local exit_code=0
+
+ echo "===========Running Python test============"
+
+ for test_file in `find official/ -name '*test.py' -print`
+ do
+ echo "####=======Testing ${test_file}=======####"
+ ${PY_BINARY} "${test_file}"
+ _exit_code=$?
+ if [[ $_exit_code != 0 ]]; then
+ exit_code=$_exit_code
+ echo "FAIL: ${test_file}"
+ fi
+ done
+
+ return "${exit_code}"
+}
+
+py2_test() {
+ local PY_BINARY=$(which python2)
+ py_test "$PY_BINARY"
+ return $?
+}
+
+py3_test() {
+ local PY_BINARY=$(which python3)
+ py_test "$PY_BINARY"
+ return $?
+}
+
+test_result=0
+
+if [ "$#" -eq 0 ]; then
+ TESTS="lint py2_test py3_test"
+else
+ TESTS="$@"
+fi
+
+for t in "${TESTS}"; do
+ ${t} || test_result=$?
+done
+
+exit "${test_result}"
diff --git a/models/official/vision/__init__.py b/models/official/vision/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/vision/detection/README.md b/models/official/vision/detection/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..53134ec553f8bb5bbd4d299a69f0e8fbb4176083
--- /dev/null
+++ b/models/official/vision/detection/README.md
@@ -0,0 +1,395 @@
+# Object Detection Models on TensorFlow 2
+
+**Note**: This repository is still under construction.
+More features and instructions will be added soon.
+
+## Prerequisite
+To get started, download the code from the TensorFlow models GitHub repository or
+use the pre-installed Google Cloud VM.
+
+```bash
+git clone https://github.com/tensorflow/models.git
+```
+
+Next, make sure to use TensorFlow 2.1+ on Google Cloud. Also, here are
+a few packages you need to install to get started:
+
+```bash
+sudo apt-get install -y python-tk && \
+pip3 install -r ~/models/official/requirements.txt
+```
+
+## Train RetinaNet on TPU
+
+### Train a vanilla ResNet-50 based RetinaNet.
+
+```bash
+TPU_NAME=""
+MODEL_DIR=""
+RESNET_CHECKPOINT=""
+TRAIN_FILE_PATTERN=""
+EVAL_FILE_PATTERN=""
+VAL_JSON_FILE=""
+python3 ~/models/official/vision/detection/main.py \
+ --strategy_type=tpu \
+ --tpu="${TPU_NAME?}" \
+ --model_dir="${MODEL_DIR?}" \
+ --mode=train \
+ --params_override="{ type: retinanet, train: { checkpoint: { path: ${RESNET_CHECKPOINT?}, prefix: resnet50/ }, train_file_pattern: ${TRAIN_FILE_PATTERN?} }, eval: { val_json_file: ${VAL_JSON_FILE?}, eval_file_pattern: ${EVAL_FILE_PATTERN?} } }"
+```
+
+The pre-trained ResNet-50 checkpoint can be downloaded [here](https://storage.cloud.google.com/cloud-tpu-checkpoints/model-garden-vision/detection/resnet50-2018-02-07.tar.gz).
+
+Note: The ResNet implementation under
+[detection/](https://github.com/tensorflow/models/tree/master/official/vision/detection)
+is currently different from the one under
+[classification/](https://github.com/tensorflow/models/tree/master/official/vision/image_classification),
+so the checkpoints are not compatible.
+We will unify the implementation soon.
+
+
+
+### Train a custom RetinaNet using the config file.
+
+First, create a YAML config file, e.g. *my_retinanet.yaml*. This file specifies
+the parameters to be overridden, which should at least include the following
+fields.
+
+```YAML
+# my_retinanet.yaml
+type: 'retinanet'
+train:
+ train_file_pattern:
+eval:
+ eval_file_pattern:
+ val_json_file:
+```
+
+Once the YAML config file is created, you can launch the training using the
+following command.
+
+```bash
+TPU_NAME=""
+MODEL_DIR=""
+python3 ~/models/official/vision/detection/main.py \
+ --strategy_type=tpu \
+ --tpu="${TPU_NAME?}" \
+ --model_dir="${MODEL_DIR?}" \
+ --mode=train \
+ --config_file="my_retinanet.yaml"
+```
+
+## Train RetinaNet on GPU
+
+Training on GPU is similar to that on TPU. The major change is the strategy
+type (use "[mirrored](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy)" for multiple GPUs and
+"[one_device](https://www.tensorflow.org/api_docs/python/tf/distribute/OneDeviceStrategy)" for a single GPU).
+
+Multi-GPU example (assuming there are 8 GPUs connected to the host):
+
+```bash
+MODEL_DIR=""
+python3 ~/models/official/vision/detection/main.py \
+ --strategy_type=mirrored \
+ --num_gpus=8 \
+ --model_dir="${MODEL_DIR?}" \
+ --mode=train \
+ --config_file="my_retinanet.yaml"
+```
+
+A single GPU example:
+
+```bash
+MODEL_DIR=""
+python3 ~/models/official/vision/detection/main.py \
+ --strategy_type=one_device \
+ --num_gpus=1 \
+ --model_dir="${MODEL_DIR?}" \
+ --mode=train \
+ --config_file="my_retinanet.yaml"
+```
+
+An example with inline configuration (YAML or JSON format):
+
+```
+python3 ~/models/official/vision/detection/main.py \
+ --model_dir= \
+ --strategy_type=one_device \
+ --num_gpus=1 \
+ --mode=train \
+ --params_override="eval:
+ eval_file_pattern:
+ batch_size: 8
+ val_json_file:
+predict:
+ predict_batch_size: 8
+architecture:
+ use_bfloat16: False
+train:
+ total_steps: 1
+ batch_size: 8
+ train_file_pattern:
+use_tpu: False
+"
+```
+
+---
+
+## Train Mask R-CNN on TPU
+
+### Train a vanilla ResNet-50 based Mask R-CNN.
+
+```bash
+TPU_NAME=""
+MODEL_DIR=""
+RESNET_CHECKPOINT=""
+TRAIN_FILE_PATTERN=""
+EVAL_FILE_PATTERN=""
+VAL_JSON_FILE=""
+python3 ~/models/official/vision/detection/main.py \
+ --strategy_type=tpu \
+ --tpu=${TPU_NAME} \
+ --model_dir=${MODEL_DIR} \
+ --mode=train \
+ --model=mask_rcnn \
+ --params_override="{train: { checkpoint: { path: ${RESNET_CHECKPOINT}, prefix: resnet50/ }, train_file_pattern: ${TRAIN_FILE_PATTERN} }, eval: { val_json_file: ${VAL_JSON_FILE}, eval_file_pattern: ${EVAL_FILE_PATTERN} } }"
+```
+
+The pre-trained ResNet-50 checkpoint can be downloaded [here](https://storage.cloud.google.com/cloud-tpu-checkpoints/model-garden-vision/detection/resnet50-2018-02-07.tar.gz).
+
+Note: The ResNet implementation under
+[detection/](https://github.com/tensorflow/models/tree/master/official/vision/detection)
+is currently different from the one under
+[classification/](https://github.com/tensorflow/models/tree/master/official/vision/image_classification),
+so the checkpoints are not compatible.
+We will unify the implementation soon.
+
+
+### Train a custom Mask R-CNN using the config file.
+
+First, create a YAML config file, e.g. *my_maskrcnn.yaml*.
+This file specifies the parameters to be overridden,
+which should at least include the following fields.
+
+```YAML
+# my_maskrcnn.yaml
+train:
+ train_file_pattern:
+eval:
+ eval_file_pattern:
+ val_json_file:
+```
+
+Once the YAML config file is created, you can launch the training using the
+following command.
+
+```bash
+TPU_NAME=""
+MODEL_DIR=""
+python3 ~/models/official/vision/detection/main.py \
+ --strategy_type=tpu \
+ --tpu=${TPU_NAME} \
+ --model_dir=${MODEL_DIR} \
+ --mode=train \
+ --model=mask_rcnn \
+ --config_file="my_maskrcnn.yaml"
+```
+
+## Train Mask R-CNN on GPU
+
+Training on GPU is similar to that on TPU. The major change is the strategy type
+(use
+"[mirrored](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy)"
+for multiple GPUs and
+"[one_device](https://www.tensorflow.org/api_docs/python/tf/distribute/OneDeviceStrategy)"
+for a single GPU).
+
+Multi-GPU example (assuming there are 8 GPUs connected to the host):
+
+```bash
+MODEL_DIR=""
+python3 ~/models/official/vision/detection/main.py \
+ --strategy_type=mirrored \
+ --num_gpus=8 \
+ --model_dir=${MODEL_DIR} \
+ --mode=train \
+ --model=mask_rcnn \
+ --config_file="my_maskrcnn.yaml"
+```
+
+A single GPU example:
+
+```bash
+MODEL_DIR=""
+python3 ~/models/official/vision/detection/main.py \
+ --strategy_type=one_device \
+ --num_gpus=1 \
+ --model_dir=${MODEL_DIR} \
+ --mode=train \
+ --model=mask_rcnn \
+ --config_file="my_maskrcnn.yaml"
+```
+
+An example with inline configuration (YAML or JSON format):
+
+```
+python3 ~/models/official/vision/detection/main.py \
+ --model_dir= \
+ --strategy_type=one_device \
+ --num_gpus=1 \
+ --mode=train \
+ --model=mask_rcnn \
+ --params_override="eval:
+ eval_file_pattern:
+ batch_size: 8
+ val_json_file:
+predict:
+ predict_batch_size: 8
+architecture:
+ use_bfloat16: False
+train:
+ total_steps: 1000
+ batch_size: 8
+ train_file_pattern:
+use_tpu: False
+"
+```
+
+## Train ShapeMask on TPU
+
+### Train a ResNet-50 based ShapeMask.
+
+```bash
+TPU_NAME=""
+MODEL_DIR=""
+RESNET_CHECKPOINT=""
+TRAIN_FILE_PATTERN=""
+EVAL_FILE_PATTERN=""
+VAL_JSON_FILE=""
+SHAPE_PRIOR_PATH=""
+python3 ~/models/official/vision/detection/main.py \
+ --strategy_type=tpu \
+ --tpu=${TPU_NAME} \
+ --model_dir=${MODEL_DIR} \
+ --mode=train \
+ --model=shapemask \
+  --params_override="{train: { checkpoint: { path: ${RESNET_CHECKPOINT}, prefix: resnet50/ }, train_file_pattern: ${TRAIN_FILE_PATTERN} }, eval: { val_json_file: ${VAL_JSON_FILE}, eval_file_pattern: ${EVAL_FILE_PATTERN} }, shapemask_head: {use_category_for_mask: true, shape_prior_path: ${SHAPE_PRIOR_PATH}} }"
+```
+
+The pre-trained ResNet-50 checkpoint can be downloaded [here](https://storage.cloud.google.com/cloud-tpu-checkpoints/model-garden-vision/detection/resnet50-2018-02-07.tar.gz).
+
+The shape priors can be downloaded
+[here](https://storage.googleapis.com/cloud-tpu-checkpoints/shapemask/kmeans_class_priors_91x20x32x32.npy).
+
+
+### Train a custom ShapeMask using the config file.
+
+First, create a YAML config file, e.g. *my_shapemask.yaml*.
+This file specifies the parameters to be overridden:
+
+```YAML
+# my_shapemask.yaml
+train:
+ train_file_pattern:
+ total_steps:
+ batch_size:
+eval:
+ eval_file_pattern:
+ val_json_file:
+ batch_size:
+shapemask_head:
+ shape_prior_path:
+```
+
+Once the YAML config file is created, you can launch the training using the
+following command.
+
+```bash
+TPU_NAME=""
+MODEL_DIR=""
+python3 ~/models/official/vision/detection/main.py \
+ --strategy_type=tpu \
+ --tpu=${TPU_NAME} \
+ --model_dir=${MODEL_DIR} \
+ --mode=train \
+ --model=shapemask \
+ --config_file="my_shapemask.yaml"
+```
+
+## Train ShapeMask on GPU
+
+Training on GPU is similar to that on TPU. The major change is the strategy type
+(use
+"[mirrored](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy)"
+for multiple GPUs and
+"[one_device](https://www.tensorflow.org/api_docs/python/tf/distribute/OneDeviceStrategy)"
+for a single GPU).
+
+Multi-GPU example (assuming there are 8 GPUs connected to the host):
+
+```bash
+MODEL_DIR=""
+python3 ~/models/official/vision/detection/main.py \
+ --strategy_type=mirrored \
+ --num_gpus=8 \
+ --model_dir=${MODEL_DIR} \
+ --mode=train \
+ --model=shapemask \
+ --config_file="my_shapemask.yaml"
+```
+
+A single GPU example
+
+```bash
+MODEL_DIR=""
+python3 ~/models/official/vision/detection/main.py \
+ --strategy_type=one_device \
+ --num_gpus=1 \
+ --model_dir=${MODEL_DIR} \
+ --mode=train \
+ --model=shapemask \
+ --config_file="my_shapemask.yaml"
+```
+
+
+An example with inline configuration (YAML or JSON format):
+
+```
+python3 ~/models/official/vision/detection/main.py \
+ --model_dir= \
+ --strategy_type=one_device \
+ --num_gpus=1 \
+ --mode=train \
+ --model=shapemask \
+ --params_override="eval:
+ eval_file_pattern:
+ batch_size: 8
+ val_json_file:
+train:
+ total_steps: 1000
+ batch_size: 8
+ train_file_pattern:
+use_tpu: False
+"
+```
+
+
+### Run the evaluation (after training)
+
+```
+python3 /usr/share/models/official/vision/detection/main.py \
+ --strategy_type=tpu \
+ --tpu=${TPU_NAME} \
+ --model_dir=${MODEL_DIR} \
+ --mode=eval \
+ --model=shapemask \
+ --params_override="{eval: { val_json_file: ${VAL_JSON_FILE}, eval_file_pattern: ${EVAL_FILE_PATTERN}, eval_samples: 5000 } }"
+```
+
+`MODEL_DIR` needs to point to the directory of the trained ShapeMask model.
+To run the evaluation on a GPU, change to `strategy_type=mirrored` and `num_gpus=1`.
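+
+For example (an illustrative variant of the command above; substitute your own
+model directory and annotation files):
+
+```
+python3 /usr/share/models/official/vision/detection/main.py \
+  --strategy_type=mirrored \
+  --num_gpus=1 \
+  --model_dir=${MODEL_DIR} \
+  --mode=eval \
+  --model=shapemask \
+  --params_override="{eval: { val_json_file: ${VAL_JSON_FILE}, eval_file_pattern: ${EVAL_FILE_PATTERN}, eval_samples: 5000 } }"
+```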
+
+Note: The JSON groundtruth file is useful for the [COCO dataset](http://cocodataset.org/#home) and can be
+downloaded from the [COCO website](http://cocodataset.org/#download). For a custom dataset, it is unnecessary because the groundtruth can be included in the TFRecord files.
+
+## References
+
+1. [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002).
+ Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, and Piotr Dollár. IEEE
+ International Conference on Computer Vision (ICCV), 2017.
diff --git a/models/official/vision/detection/__init__.py b/models/official/vision/detection/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/vision/detection/configs/__init__.py b/models/official/vision/detection/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..931c2ef11db4a949e6c2e95bca44e36bac1241e9
--- /dev/null
+++ b/models/official/vision/detection/configs/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/models/official/vision/detection/configs/base_config.py b/models/official/vision/detection/configs/base_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a4e2f5fbf001039a88bed6d834835348807719c
--- /dev/null
+++ b/models/official/vision/detection/configs/base_config.py
@@ -0,0 +1,135 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base config template."""
+
+
+BACKBONES = [
+ 'resnet',
+]
+
+MULTILEVEL_FEATURES = [
+ 'fpn',
+]
+
+# pylint: disable=line-too-long
+# For ResNet, this freezes the variables of the first conv1 and conv2_x
+# layers [1], which leads to higher training speed and slightly better testing
+# accuracy. The intuition is that the low-level architecture (e.g., ResNet-50)
+# is able to capture low-level features such as edges; therefore, it does not
+# need to be fine-tuned for the detection task.
+# Note that we need the trailing `/` to avoid an incorrect match.
+# [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
+RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
+REGULARIZATION_VAR_REGEX = r'.*(kernel|weight):0$'
+
+BASE_CFG = {
+ 'model_dir': '',
+ 'use_tpu': True,
+ 'strategy_type': 'tpu',
+ 'isolate_session_state': False,
+ 'train': {
+ 'iterations_per_loop': 100,
+ 'batch_size': 64,
+ 'total_steps': 22500,
+ 'num_cores_per_replica': None,
+ 'input_partition_dims': None,
+ 'optimizer': {
+ 'type': 'momentum',
+ 'momentum': 0.9,
+ 'nesterov': True, # `False` is better for TPU v3-128.
+ },
+ 'learning_rate': {
+ 'type': 'step',
+ 'warmup_learning_rate': 0.0067,
+ 'warmup_steps': 500,
+ 'init_learning_rate': 0.08,
+ 'learning_rate_levels': [0.008, 0.0008],
+ 'learning_rate_steps': [15000, 20000],
+ },
+ 'checkpoint': {
+ 'path': '',
+ 'prefix': '',
+ },
+ # One can use 'RESNET_FROZEN_VAR_PREFIX' to speed up ResNet training
+ # when loading from the checkpoint.
+ 'frozen_variable_prefix': '',
+ 'train_file_pattern': '',
+ 'train_dataset_type': 'tfrecord',
+ # TODO(b/142174042): Support transpose_input option.
+ 'transpose_input': False,
+ 'regularization_variable_regex': REGULARIZATION_VAR_REGEX,
+ 'l2_weight_decay': 0.0001,
+ 'gradient_clip_norm': 0.0,
+ 'input_sharding': False,
+ },
+ 'eval': {
+ 'input_sharding': True,
+ 'batch_size': 8,
+ 'eval_samples': 5000,
+ 'min_eval_interval': 180,
+ 'eval_timeout': None,
+ 'num_steps_per_eval': 1000,
+ 'type': 'box',
+ 'use_json_file': True,
+ 'val_json_file': '',
+ 'eval_file_pattern': '',
+ 'eval_dataset_type': 'tfrecord',
+ # When visualizing images, set evaluation batch size to 40 to avoid
+ # potential OOM.
+ 'num_images_to_visualize': 0,
+ },
+ 'predict': {
+ 'batch_size': 8,
+ },
+ 'architecture': {
+ 'backbone': 'resnet',
+ 'min_level': 3,
+ 'max_level': 7,
+ 'multilevel_features': 'fpn',
+ 'use_bfloat16': True,
+ # Note that `num_classes` is the total number of classes including
+    # one background class whose index is 0.
+ 'num_classes': 91,
+ },
+ 'anchor': {
+ 'num_scales': 3,
+ 'aspect_ratios': [1.0, 2.0, 0.5],
+ 'anchor_size': 4.0,
+ },
+ 'norm_activation': {
+ 'activation': 'relu',
+ 'batch_norm_momentum': 0.997,
+ 'batch_norm_epsilon': 1e-4,
+ 'batch_norm_trainable': True,
+ 'use_sync_bn': False,
+ },
+ 'resnet': {
+ 'resnet_depth': 50,
+ },
+ 'fpn': {
+ 'fpn_feat_dims': 256,
+ 'use_separable_conv': False,
+ 'use_batch_norm': True,
+ },
+ 'postprocess': {
+ 'use_batched_nms': False,
+ 'max_total_size': 100,
+ 'nms_iou_threshold': 0.5,
+ 'score_threshold': 0.05,
+ 'pre_nms_num_boxes': 5000,
+ },
+ 'enable_summary': False,
+}
+# pylint: enable=line-too-long
diff --git a/models/official/vision/detection/configs/factory.py b/models/official/vision/detection/configs/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..d60ea1e01133fdfffd76ad54daf4ee20ed1e46e0
--- /dev/null
+++ b/models/official/vision/detection/configs/factory.py
@@ -0,0 +1,37 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Factory to provide model configs."""
+
+from official.modeling.hyperparams import params_dict
+from official.vision.detection.configs import maskrcnn_config
+from official.vision.detection.configs import retinanet_config
+from official.vision.detection.configs import shapemask_config
+
+
+def config_generator(model):
+ """Model function generator."""
+ if model == 'retinanet':
+ default_config = retinanet_config.RETINANET_CFG
+ restrictions = retinanet_config.RETINANET_RESTRICTIONS
+ elif model == 'mask_rcnn':
+ default_config = maskrcnn_config.MASKRCNN_CFG
+ restrictions = maskrcnn_config.MASKRCNN_RESTRICTIONS
+ elif model == 'shapemask':
+ default_config = shapemask_config.SHAPEMASK_CFG
+ restrictions = shapemask_config.SHAPEMASK_RESTRICTIONS
+ else:
+ raise ValueError('Model %s is not supported.' % model)
+
+ return params_dict.ParamsDict(default_config, restrictions)
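+
+# Usage sketch (illustrative; the override values below are arbitrary examples):
+#
+#   params = config_generator('retinanet')
+#   params.override({'train': {'batch_size': 8, 'total_steps': 100}},
+#                   is_strict=True)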
diff --git a/models/official/vision/detection/configs/maskrcnn_config.py b/models/official/vision/detection/configs/maskrcnn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..70c9b31448d3d83754c439c87ce9f0d0a04f88c9
--- /dev/null
+++ b/models/official/vision/detection/configs/maskrcnn_config.py
@@ -0,0 +1,116 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Config template to train Mask R-CNN."""
+
+from official.modeling.hyperparams import params_dict
+from official.vision.detection.configs import base_config
+
+
+# pylint: disable=line-too-long
+MASKRCNN_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
+MASKRCNN_CFG.override({
+ 'type': 'mask_rcnn',
+ 'eval': {
+ 'type': 'box_and_mask',
+ 'num_images_to_visualize': 0,
+ },
+ 'architecture': {
+ 'parser': 'maskrcnn_parser',
+ 'min_level': 2,
+ 'max_level': 6,
+ 'include_mask': True,
+ 'mask_target_size': 28,
+ },
+ 'maskrcnn_parser': {
+ 'output_size': [1024, 1024],
+ 'num_channels': 3,
+ 'rpn_match_threshold': 0.7,
+ 'rpn_unmatched_threshold': 0.3,
+ 'rpn_batch_size_per_im': 256,
+ 'rpn_fg_fraction': 0.5,
+ 'aug_rand_hflip': True,
+ 'aug_scale_min': 1.0,
+ 'aug_scale_max': 1.0,
+ 'skip_crowd_during_training': True,
+ 'max_num_instances': 100,
+ 'mask_crop_size': 112,
+ },
+ 'anchor': {
+ 'num_scales': 1,
+ 'anchor_size': 8,
+ },
+ 'rpn_head': {
+ 'anchors_per_location': 3,
+ 'num_convs': 2,
+ 'num_filters': 256,
+ 'use_separable_conv': False,
+ 'use_batch_norm': False,
+ },
+ 'frcnn_head': {
+ 'num_convs': 0,
+ 'num_filters': 256,
+ 'use_separable_conv': False,
+ 'num_fcs': 2,
+ 'fc_dims': 1024,
+ 'use_batch_norm': False,
+ },
+ 'mrcnn_head': {
+ 'num_convs': 4,
+ 'num_filters': 256,
+ 'use_separable_conv': False,
+ 'use_batch_norm': False,
+ },
+ 'rpn_score_loss': {
+ 'rpn_batch_size_per_im': 256,
+ },
+ 'rpn_box_loss': {
+ 'huber_loss_delta': 1.0 / 9.0,
+ },
+ 'frcnn_box_loss': {
+ 'huber_loss_delta': 1.0,
+ },
+ 'roi_proposal': {
+ 'rpn_pre_nms_top_k': 2000,
+ 'rpn_post_nms_top_k': 1000,
+ 'rpn_nms_threshold': 0.7,
+ 'rpn_score_threshold': 0.0,
+ 'rpn_min_size_threshold': 0.0,
+ 'test_rpn_pre_nms_top_k': 1000,
+ 'test_rpn_post_nms_top_k': 1000,
+ 'test_rpn_nms_threshold': 0.7,
+ 'test_rpn_score_threshold': 0.0,
+ 'test_rpn_min_size_threshold': 0.0,
+ 'use_batched_nms': False,
+ },
+ 'roi_sampling': {
+ 'num_samples_per_image': 512,
+ 'fg_fraction': 0.25,
+ 'fg_iou_thresh': 0.5,
+ 'bg_iou_thresh_hi': 0.5,
+ 'bg_iou_thresh_lo': 0.0,
+ 'mix_gt_boxes': True,
+ },
+ 'mask_sampling': {
+ 'num_mask_samples_per_image': 128, # Typically = `num_samples_per_image` * `fg_fraction`.
+ },
+ 'postprocess': {
+ 'pre_nms_num_boxes': 1000,
+ },
+}, is_strict=False)
+
+
+MASKRCNN_RESTRICTIONS = [
+]
+# pylint: enable=line-too-long
diff --git a/models/official/vision/detection/configs/retinanet_config.py b/models/official/vision/detection/configs/retinanet_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..579e30d083aacf138a2f9baffe1be7713ad21583
--- /dev/null
+++ b/models/official/vision/detection/configs/retinanet_config.py
@@ -0,0 +1,59 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Config template to train Retinanet."""
+
+from official.modeling.hyperparams import params_dict
+from official.vision.detection.configs import base_config
+
+
+# pylint: disable=line-too-long
+RETINANET_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
+RETINANET_CFG.override({
+ 'type': 'retinanet',
+ 'architecture': {
+ 'parser': 'retinanet_parser',
+ },
+ 'retinanet_parser': {
+ 'output_size': [640, 640],
+ 'num_channels': 3,
+ 'match_threshold': 0.5,
+ 'unmatched_threshold': 0.5,
+ 'aug_rand_hflip': True,
+ 'aug_scale_min': 1.0,
+ 'aug_scale_max': 1.0,
+ 'use_autoaugment': False,
+ 'autoaugment_policy_name': 'v0',
+ 'skip_crowd_during_training': True,
+ 'max_num_instances': 100,
+ },
+ 'retinanet_head': {
+ 'anchors_per_location': 9,
+ 'num_convs': 4,
+ 'num_filters': 256,
+ 'use_separable_conv': False,
+ },
+ 'retinanet_loss': {
+ 'focal_loss_alpha': 0.25,
+ 'focal_loss_gamma': 1.5,
+ 'huber_loss_delta': 0.1,
+ 'box_loss_weight': 50,
+ },
+ 'enable_summary': True,
+}, is_strict=False)
+
+RETINANET_RESTRICTIONS = [
+]
+
+# pylint: enable=line-too-long
diff --git a/models/official/vision/detection/configs/shapemask_config.py b/models/official/vision/detection/configs/shapemask_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0914c492e15f65e5ba66701f27ca0d88d13698ff
--- /dev/null
+++ b/models/official/vision/detection/configs/shapemask_config.py
@@ -0,0 +1,98 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Config to train shapemask on COCO."""
+
+from official.modeling.hyperparams import params_dict
+from official.vision.detection.configs import base_config
+
+SHAPEMASK_RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+/)conv2d(|_([1-9]|10))\/'
+
+SHAPEMASK_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
+SHAPEMASK_CFG.override({
+ 'type': 'shapemask',
+ 'architecture': {
+ 'parser': 'shapemask_parser',
+ 'backbone': 'resnet',
+ 'multilevel_features': 'fpn',
+ 'outer_box_scale': 1.25,
+ },
+ 'train': {
+ 'total_steps': 45000,
+ 'learning_rate': {
+ 'learning_rate_steps': [30000, 40000],
+ },
+ 'frozen_variable_prefix': SHAPEMASK_RESNET_FROZEN_VAR_PREFIX,
+ 'regularization_variable_regex': None,
+ },
+ 'eval': {
+ 'type': 'shapemask_box_and_mask',
+ 'mask_eval_class': 'all', # 'all', 'voc', or 'nonvoc'.
+ },
+ 'shapemask_parser': {
+ 'output_size': [640, 640],
+ 'num_channels': 3,
+ 'match_threshold': 0.5,
+ 'unmatched_threshold': 0.5,
+ 'aug_rand_hflip': True,
+ 'aug_scale_min': 0.8,
+ 'aug_scale_max': 1.2,
+ 'skip_crowd_during_training': True,
+ 'max_num_instances': 100,
+ # Shapemask specific parameters
+ 'mask_train_class': 'all', # 'all', 'voc', or 'nonvoc'.
+ 'use_category': True,
+ 'outer_box_scale': 1.25,
+ 'num_sampled_masks': 8,
+ 'mask_crop_size': 32,
+ 'mask_min_level': 3,
+ 'mask_max_level': 5,
+ 'box_jitter_scale': 0.025,
+ 'upsample_factor': 4,
+ },
+ 'retinanet_head': {
+ 'anchors_per_location': 9,
+ 'num_convs': 4,
+ 'num_filters': 256,
+ 'use_separable_conv': False,
+ 'use_batch_norm': True,
+ },
+ 'shapemask_head': {
+ 'num_downsample_channels': 128,
+ 'mask_crop_size': 32,
+ 'use_category_for_mask': True,
+ 'num_convs': 4,
+ 'upsample_factor': 4,
+ 'shape_prior_path': '',
+ },
+ 'retinanet_loss': {
+ 'focal_loss_alpha': 0.4,
+ 'focal_loss_gamma': 1.5,
+ 'huber_loss_delta': 0.15,
+ 'box_loss_weight': 50,
+ },
+ 'shapemask_loss': {
+ 'shape_prior_loss_weight': 0.1,
+ 'coarse_mask_loss_weight': 1.0,
+ 'fine_mask_loss_weight': 1.0,
+ },
+}, is_strict=False)
+
+SHAPEMASK_RESTRICTIONS = [
+ 'shapemask_head.mask_crop_size == shapemask_parser.mask_crop_size',
+ 'shapemask_head.upsample_factor == shapemask_parser.upsample_factor',
+ 'shapemask_parser.outer_box_scale == architecture.outer_box_scale',
+]
+
+# pylint: enable=line-too-long
diff --git a/models/official/vision/detection/dataloader/__init__.py b/models/official/vision/detection/dataloader/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..931c2ef11db4a949e6c2e95bca44e36bac1241e9
--- /dev/null
+++ b/models/official/vision/detection/dataloader/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/models/official/vision/detection/dataloader/anchor.py b/models/official/vision/detection/dataloader/anchor.py
new file mode 100644
index 0000000000000000000000000000000000000000..f46f7480062e75cec55d48ff683dcad8301e4994
--- /dev/null
+++ b/models/official/vision/detection/dataloader/anchor.py
@@ -0,0 +1,292 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Anchor box and labeler definition."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import tensorflow as tf
+from official.vision.detection.utils.object_detection import argmax_matcher
+from official.vision.detection.utils.object_detection import balanced_positive_negative_sampler
+from official.vision.detection.utils.object_detection import box_list
+from official.vision.detection.utils.object_detection import faster_rcnn_box_coder
+from official.vision.detection.utils.object_detection import region_similarity_calculator
+from official.vision.detection.utils.object_detection import target_assigner
+
+
+class Anchor(object):
+ """Anchor class for anchor-based object detectors."""
+
+ def __init__(self,
+ min_level,
+ max_level,
+ num_scales,
+ aspect_ratios,
+ anchor_size,
+ image_size):
+ """Constructs multiscale anchors.
+
+ Args:
+ min_level: integer number of minimum level of the output feature pyramid.
+ max_level: integer number of maximum level of the output feature pyramid.
+      num_scales: integer number representing intermediate scales added
+        on each level. For instance, num_scales=2 adds one additional
+        intermediate anchor scale [2^0, 2^0.5] on each level.
+      aspect_ratios: list of float numbers representing the aspect ratio anchors
+        added on each level. The number indicates the ratio of width to height.
+        For instance, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors on each
+        scale level.
+      anchor_size: float number representing the scale of size of the base
+        anchor to the feature stride 2^level.
+      image_size: a list of integer numbers or Tensors representing
+        [height, width] of the input image size. The image_size should be
+        divisible by the largest feature stride 2^max_level.
+ """
+ self.min_level = min_level
+ self.max_level = max_level
+ self.num_scales = num_scales
+ self.aspect_ratios = aspect_ratios
+ self.anchor_size = anchor_size
+ self.image_size = image_size
+ self.boxes = self._generate_boxes()
+
+ def _generate_boxes(self):
+ """Generates multiscale anchor boxes.
+
+ Returns:
+      a Tensor of shape [N, 4], representing anchor boxes of all levels
+ concatenated together.
+ """
+ boxes_all = []
+ for level in range(self.min_level, self.max_level + 1):
+ boxes_l = []
+ for scale in range(self.num_scales):
+ for aspect_ratio in self.aspect_ratios:
+ stride = 2 ** level
+          intermediate_scale = 2 ** (scale / float(self.num_scales))
+          base_anchor_size = self.anchor_size * stride * intermediate_scale
+ aspect_x = aspect_ratio ** 0.5
+ aspect_y = aspect_ratio ** -0.5
+ half_anchor_size_x = base_anchor_size * aspect_x / 2.0
+ half_anchor_size_y = base_anchor_size * aspect_y / 2.0
+ x = tf.range(stride / 2, self.image_size[1], stride)
+ y = tf.range(stride / 2, self.image_size[0], stride)
+ xv, yv = tf.meshgrid(x, y)
+ xv = tf.cast(tf.reshape(xv, [-1]), dtype=tf.float32)
+ yv = tf.cast(tf.reshape(yv, [-1]), dtype=tf.float32)
+ # Tensor shape Nx4.
+ boxes = tf.stack([yv - half_anchor_size_y, xv - half_anchor_size_x,
+ yv + half_anchor_size_y, xv + half_anchor_size_x],
+ axis=1)
+ boxes_l.append(boxes)
+ # Concat anchors on the same level to tensor shape NxAx4.
+ boxes_l = tf.stack(boxes_l, axis=1)
+ boxes_l = tf.reshape(boxes_l, [-1, 4])
+ boxes_all.append(boxes_l)
+ return tf.concat(boxes_all, axis=0)
+
+ def unpack_labels(self, labels):
+ """Unpacks an array of labels into multiscales labels."""
+ unpacked_labels = collections.OrderedDict()
+ count = 0
+ for level in range(self.min_level, self.max_level + 1):
+ feat_size_y = tf.cast(self.image_size[0] / 2 ** level, tf.int32)
+ feat_size_x = tf.cast(self.image_size[1] / 2 ** level, tf.int32)
+ steps = feat_size_y * feat_size_x * self.anchors_per_location
+ unpacked_labels[level] = tf.reshape(
+ labels[count:count + steps], [feat_size_y, feat_size_x, -1])
+ count += steps
+ return unpacked_labels
+
+ @property
+ def anchors_per_location(self):
+ return self.num_scales * len(self.aspect_ratios)
+
+ @property
+ def multilevel_boxes(self):
+ return self.unpack_labels(self.boxes)
+
+
+class AnchorLabeler(object):
+ """Labeler for dense object detector."""
+
+ def __init__(self,
+ anchor,
+ match_threshold=0.5,
+ unmatched_threshold=0.5):
+ """Constructs anchor labeler to assign labels to anchors.
+
+ Args:
+ anchor: an instance of class Anchors.
+ match_threshold: a float number between 0 and 1 representing the
+ lower-bound threshold to assign positive labels for anchors. An anchor
+ with a score over the threshold is labeled positive.
+ unmatched_threshold: a float number between 0 and 1 representing the
+ upper-bound threshold to assign negative labels for anchors. An anchor
+ with a score below the threshold is labeled negative.
+ """
+ similarity_calc = region_similarity_calculator.IouSimilarity()
+ matcher = argmax_matcher.ArgMaxMatcher(
+ match_threshold,
+ unmatched_threshold=unmatched_threshold,
+ negatives_lower_than_unmatched=True,
+ force_match_for_each_row=True)
+ box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder()
+
+ self._target_assigner = target_assigner.TargetAssigner(
+ similarity_calc, matcher, box_coder)
+ self._anchor = anchor
+ self._match_threshold = match_threshold
+ self._unmatched_threshold = unmatched_threshold
+
+ def label_anchors(self, gt_boxes, gt_labels):
+ """Labels anchors with ground truth inputs.
+
+ Args:
+ gt_boxes: A float tensor with shape [N, 4] representing groundtruth boxes.
+ For each row, it stores [y0, x0, y1, x1] for four corners of a box.
+      gt_labels: An integer tensor with shape [N, 1] representing groundtruth
+ classes.
+ Returns:
+ cls_targets_dict: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, num_anchors_per_location]. The height_l and
+ width_l represent the dimension of class logits at l-th level.
+ box_targets_dict: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, num_anchors_per_location * 4]. The height_l
+ and width_l represent the dimension of bounding box regression output at
+ l-th level.
+ num_positives: scalar tensor storing number of positives in an image.
+ """
+ gt_box_list = box_list.BoxList(gt_boxes)
+ anchor_box_list = box_list.BoxList(self._anchor.boxes)
+
+ # The cls_weights, box_weights are not used.
+ cls_targets, _, box_targets, _, matches = self._target_assigner.assign(
+ anchor_box_list, gt_box_list, gt_labels)
+
+ # Labels definition in matches.match_results:
+ # (1) match_results[i]>=0, meaning that column i is matched with row
+ # match_results[i].
+ # (2) match_results[i]=-1, meaning that column i is not matched.
+ # (3) match_results[i]=-2, meaning that column i is ignored.
+ match_results = tf.expand_dims(matches.match_results, axis=1)
+ cls_targets = tf.cast(cls_targets, tf.int32)
+ cls_targets = tf.where(
+ tf.equal(match_results, -1), -tf.ones_like(cls_targets), cls_targets)
+ cls_targets = tf.where(
+ tf.equal(match_results, -2), -2 * tf.ones_like(cls_targets),
+ cls_targets)
+
+ # Unpacks labels into multi-level representations.
+ cls_targets_dict = self._anchor.unpack_labels(cls_targets)
+ box_targets_dict = self._anchor.unpack_labels(box_targets)
+ num_positives = tf.reduce_sum(
+ input_tensor=tf.cast(tf.greater(matches.match_results, -1), tf.float32))
+
+ return cls_targets_dict, box_targets_dict, num_positives
+
+
+class RpnAnchorLabeler(AnchorLabeler):
+ """Labeler for Region Proposal Network."""
+
+ def __init__(self, anchor, match_threshold=0.7,
+ unmatched_threshold=0.3, rpn_batch_size_per_im=256,
+ rpn_fg_fraction=0.5):
+    AnchorLabeler.__init__(self, anchor, match_threshold=match_threshold,
+                           unmatched_threshold=unmatched_threshold)
+ self._rpn_batch_size_per_im = rpn_batch_size_per_im
+ self._rpn_fg_fraction = rpn_fg_fraction
+
+ def _get_rpn_samples(self, match_results):
+ """Computes anchor labels.
+
+ This function performs subsampling for foreground (fg) and background (bg)
+ anchors.
+ Args:
+      match_results: An integer tensor with shape [N] representing the
+ matching results of anchors. (1) match_results[i]>=0,
+ meaning that column i is matched with row match_results[i].
+ (2) match_results[i]=-1, meaning that column i is not matched.
+ (3) match_results[i]=-2, meaning that column i is ignored.
+ Returns:
+      score_targets: an integer tensor with a shape of [N].
+        (1) score_targets[i]=1, the anchor is a positive sample.
+        (2) score_targets[i]=0, negative. (3) score_targets[i]=-1, the anchor is
+        ignored (don't care).
+ """
+ sampler = (
+ balanced_positive_negative_sampler.BalancedPositiveNegativeSampler(
+ positive_fraction=self._rpn_fg_fraction, is_static=False))
+ # indicator includes both positive and negative labels.
+    # labels includes only positive labels.
+ # positives = indicator & labels.
+ # negatives = indicator & !labels.
+ # ignore = !indicator.
+ indicator = tf.greater(match_results, -2)
+ labels = tf.greater(match_results, -1)
+
+ samples = sampler.subsample(
+ indicator, self._rpn_batch_size_per_im, labels)
+ positive_labels = tf.where(
+ tf.logical_and(samples, labels),
+ tf.constant(2, dtype=tf.int32, shape=match_results.shape),
+ tf.constant(0, dtype=tf.int32, shape=match_results.shape))
+ negative_labels = tf.where(
+ tf.logical_and(samples, tf.logical_not(labels)),
+ tf.constant(1, dtype=tf.int32, shape=match_results.shape),
+ tf.constant(0, dtype=tf.int32, shape=match_results.shape))
+ ignore_labels = tf.fill(match_results.shape, -1)
+
+ return (ignore_labels + positive_labels + negative_labels,
+ positive_labels, negative_labels)
+
+ def label_anchors(self, gt_boxes, gt_labels):
+ """Labels anchors with ground truth inputs.
+
+ Args:
+ gt_boxes: A float tensor with shape [N, 4] representing groundtruth boxes.
+ For each row, it stores [y0, x0, y1, x1] for four corners of a box.
+      gt_labels: An integer tensor with shape [N, 1] representing groundtruth
+ classes.
+ Returns:
+ score_targets_dict: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, num_anchors]. The height_l and width_l
+ represent the dimension of class logits at l-th level.
+ box_targets_dict: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, num_anchors * 4]. The height_l and
+ width_l represent the dimension of bounding box regression output at
+ l-th level.
+ """
+ gt_box_list = box_list.BoxList(gt_boxes)
+ anchor_box_list = box_list.BoxList(self._anchor.boxes)
+
+ # cls_targets, cls_weights, box_weights are not used.
+ _, _, box_targets, _, matches = self._target_assigner.assign(
+ anchor_box_list, gt_box_list, gt_labels)
+
+ # score_targets contains the subsampled positive and negative anchors.
+ score_targets, _, _ = self._get_rpn_samples(matches.match_results)
+
+ # Unpacks labels.
+ score_targets_dict = self._anchor.unpack_labels(score_targets)
+ box_targets_dict = self._anchor.unpack_labels(box_targets)
+
+ return score_targets_dict, box_targets_dict
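A short usage sketch of the two classes above, assuming `models/` is on `sys.path` as in fun_advaitbert.py; the box coordinates and thresholds are illustrative only.

import tensorflow as tf
from official.vision.detection.dataloader import anchor

# Multiscale anchors for a 640x640 input, levels 3-7, 3 scales x 3 aspect
# ratios = 9 anchors per location.
input_anchor = anchor.Anchor(
    min_level=3, max_level=7, num_scales=3,
    aspect_ratios=[1.0, 2.0, 0.5], anchor_size=4.0, image_size=(640, 640))
print(input_anchor.anchors_per_location)       # 9
print(input_anchor.boxes.shape)                # [total_anchors, 4]

# Assign targets for two dummy boxes given in pixel coordinates of the
# (already scaled) image; labels are float with shape [N, 1] as in the parsers.
labeler = anchor.AnchorLabeler(
    input_anchor, match_threshold=0.5, unmatched_threshold=0.5)
gt_boxes = tf.constant([[100., 100., 300., 260.], [20., 40., 120., 200.]])
gt_labels = tf.constant([[3.], [17.]])
cls_targets, box_targets, num_positives = labeler.label_anchors(
    gt_boxes, gt_labels)
print(sorted(cls_targets.keys()))              # [3, 4, 5, 6, 7]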
diff --git a/models/official/vision/detection/dataloader/factory.py b/models/official/vision/detection/dataloader/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e13aec222f529d97ee9c502d408648b9d091e5b
--- /dev/null
+++ b/models/official/vision/detection/dataloader/factory.py
@@ -0,0 +1,103 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Model architecture factory."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from official.vision.detection.dataloader import maskrcnn_parser
+from official.vision.detection.dataloader import retinanet_parser
+from official.vision.detection.dataloader import shapemask_parser
+
+
+def parser_generator(params, mode):
+ """Generator function for various dataset parser."""
+ if params.architecture.parser == 'retinanet_parser':
+ anchor_params = params.anchor
+ parser_params = params.retinanet_parser
+ parser_fn = retinanet_parser.Parser(
+ output_size=parser_params.output_size,
+ min_level=params.architecture.min_level,
+ max_level=params.architecture.max_level,
+ num_scales=anchor_params.num_scales,
+ aspect_ratios=anchor_params.aspect_ratios,
+ anchor_size=anchor_params.anchor_size,
+ match_threshold=parser_params.match_threshold,
+ unmatched_threshold=parser_params.unmatched_threshold,
+ aug_rand_hflip=parser_params.aug_rand_hflip,
+ aug_scale_min=parser_params.aug_scale_min,
+ aug_scale_max=parser_params.aug_scale_max,
+ use_autoaugment=parser_params.use_autoaugment,
+ autoaugment_policy_name=parser_params.autoaugment_policy_name,
+ skip_crowd_during_training=parser_params.skip_crowd_during_training,
+ max_num_instances=parser_params.max_num_instances,
+ use_bfloat16=params.architecture.use_bfloat16,
+ mode=mode)
+ elif params.architecture.parser == 'maskrcnn_parser':
+ anchor_params = params.anchor
+ parser_params = params.maskrcnn_parser
+ parser_fn = maskrcnn_parser.Parser(
+ output_size=parser_params.output_size,
+ min_level=params.architecture.min_level,
+ max_level=params.architecture.max_level,
+ num_scales=anchor_params.num_scales,
+ aspect_ratios=anchor_params.aspect_ratios,
+ anchor_size=anchor_params.anchor_size,
+ rpn_match_threshold=parser_params.rpn_match_threshold,
+ rpn_unmatched_threshold=parser_params.rpn_unmatched_threshold,
+ rpn_batch_size_per_im=parser_params.rpn_batch_size_per_im,
+ rpn_fg_fraction=parser_params.rpn_fg_fraction,
+ aug_rand_hflip=parser_params.aug_rand_hflip,
+ aug_scale_min=parser_params.aug_scale_min,
+ aug_scale_max=parser_params.aug_scale_max,
+ skip_crowd_during_training=parser_params.skip_crowd_during_training,
+ max_num_instances=parser_params.max_num_instances,
+ include_mask=params.architecture.include_mask,
+ mask_crop_size=parser_params.mask_crop_size,
+ use_bfloat16=params.architecture.use_bfloat16,
+ mode=mode)
+ elif params.architecture.parser == 'shapemask_parser':
+ anchor_params = params.anchor
+ parser_params = params.shapemask_parser
+ parser_fn = shapemask_parser.Parser(
+ output_size=parser_params.output_size,
+ min_level=params.architecture.min_level,
+ max_level=params.architecture.max_level,
+ num_scales=anchor_params.num_scales,
+ aspect_ratios=anchor_params.aspect_ratios,
+ anchor_size=anchor_params.anchor_size,
+ use_category=parser_params.use_category,
+ outer_box_scale=parser_params.outer_box_scale,
+ box_jitter_scale=parser_params.box_jitter_scale,
+ num_sampled_masks=parser_params.num_sampled_masks,
+ mask_crop_size=parser_params.mask_crop_size,
+ mask_min_level=parser_params.mask_min_level,
+ mask_max_level=parser_params.mask_max_level,
+ upsample_factor=parser_params.upsample_factor,
+ match_threshold=parser_params.match_threshold,
+ unmatched_threshold=parser_params.unmatched_threshold,
+ aug_rand_hflip=parser_params.aug_rand_hflip,
+ aug_scale_min=parser_params.aug_scale_min,
+ aug_scale_max=parser_params.aug_scale_max,
+ skip_crowd_during_training=parser_params.skip_crowd_during_training,
+ max_num_instances=parser_params.max_num_instances,
+ use_bfloat16=params.architecture.use_bfloat16,
+ mask_train_class=parser_params.mask_train_class,
+ mode=mode)
+ else:
+ raise ValueError('Parser %s is not supported.' % params.architecture.parser)
+
+ return parser_fn
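A hedged sketch of how `parser_generator` is typically called; it assumes the ShapeMask config module shown earlier in this diff.

from official.modeling.hyperparams import params_dict
from official.vision.detection.configs import shapemask_config
from official.vision.detection.dataloader import factory
from official.vision.detection.dataloader import mode_keys

params = params_dict.ParamsDict(shapemask_config.SHAPEMASK_CFG)
parser_fn = factory.parser_generator(params, mode_keys.TRAIN)
# parser_fn maps one serialized tf.Example string to training features and
# labels, and is normally applied with dataset.map(parser_fn).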
diff --git a/models/official/vision/detection/dataloader/input_reader.py b/models/official/vision/detection/dataloader/input_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e65243f6863ccadb45704b3ed487aec3b8ab21a
--- /dev/null
+++ b/models/official/vision/detection/dataloader/input_reader.py
@@ -0,0 +1,107 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Data loader and input processing."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+
+from typing import Text, Optional
+from official.modeling.hyperparams import params_dict
+from official.vision.detection.dataloader import factory
+from official.vision.detection.dataloader import mode_keys as ModeKeys
+
+
+class InputFn(object):
+ """Input function that creates dataset from files."""
+
+ def __init__(self,
+ file_pattern: Text,
+ params: params_dict.ParamsDict,
+ mode: Text,
+ batch_size: int,
+ num_examples: Optional[int] = -1):
+ """Initialize.
+
+ Args:
+ file_pattern: the file pattern for the data example (TFRecords).
+ params: the parameter object for constructing example parser and model.
+      mode: ModeKeys.TRAIN or ModeKeys.EVAL.
+ batch_size: the data batch size.
+      num_examples: If positive, only takes this number of examples and raises
+ tf.errors.OutOfRangeError after that. If non-positive, it will be
+ ignored.
+ """
+ assert file_pattern is not None
+ assert mode is not None
+ assert batch_size is not None
+ self._file_pattern = file_pattern
+ self._mode = mode
+ self._is_training = (mode == ModeKeys.TRAIN)
+ self._batch_size = batch_size
+ self._num_examples = num_examples
+ self._parser_fn = factory.parser_generator(params, mode)
+ self._dataset_fn = tf.data.TFRecordDataset
+
+ self._input_sharding = (not self._is_training)
+ try:
+ if self._is_training:
+ self._input_sharding = params.train.input_sharding
+ else:
+ self._input_sharding = params.eval.input_sharding
+ except AttributeError:
+ pass
+
+ def __call__(self, ctx=None, batch_size: int = None):
+ """Provides tf.data.Dataset object.
+
+ Args:
+ ctx: context object.
+      batch_size: expected batch size of the input data.
+
+ Returns:
+ tf.data.Dataset object.
+ """
+ if not batch_size:
+ batch_size = self._batch_size
+ assert batch_size is not None
+ dataset = tf.data.Dataset.list_files(
+ self._file_pattern, shuffle=self._is_training)
+
+ if self._input_sharding and ctx and ctx.num_input_pipelines > 1:
+ dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
+ dataset = dataset.cache()
+
+ if self._is_training:
+ dataset = dataset.repeat()
+
+ dataset = dataset.interleave(
+ map_func=self._dataset_fn, cycle_length=32,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ if self._is_training:
+ dataset = dataset.shuffle(1000)
+ if self._num_examples > 0:
+ dataset = dataset.take(self._num_examples)
+
+ # Parses the fetched records to input tensors for model function.
+ dataset = dataset.map(
+ self._parser_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ dataset = dataset.batch(batch_size, drop_remainder=True)
+ dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+ return dataset
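A minimal sketch of constructing InputFn; the TFRecord pattern is a placeholder and the config module is assumed from earlier in this diff.

from official.modeling.hyperparams import params_dict
from official.vision.detection.configs import shapemask_config
from official.vision.detection.dataloader import input_reader
from official.vision.detection.dataloader import mode_keys

params = params_dict.ParamsDict(shapemask_config.SHAPEMASK_CFG)
train_input_fn = input_reader.InputFn(
    file_pattern='/path/to/coco/train-*.tfrecord',  # placeholder pattern
    params=params,
    mode=mode_keys.TRAIN,
    batch_size=8)
dataset = train_input_fn()  # tf.data.Dataset yielding batched (inputs, labels)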
diff --git a/models/official/vision/detection/dataloader/maskrcnn_parser.py b/models/official/vision/detection/dataloader/maskrcnn_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..933e1b75c04ee04e4fbb60eaeb1ac9a48412a970
--- /dev/null
+++ b/models/official/vision/detection/dataloader/maskrcnn_parser.py
@@ -0,0 +1,385 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Data parser and processing for Mask R-CNN."""
+
+import tensorflow as tf
+
+from official.vision.detection.dataloader import anchor
+from official.vision.detection.dataloader import mode_keys as ModeKeys
+from official.vision.detection.dataloader import tf_example_decoder
+from official.vision.detection.utils import box_utils
+from official.vision.detection.utils import dataloader_utils
+from official.vision.detection.utils import input_utils
+
+
+class Parser(object):
+ """Parser to parse an image and its annotations into a dictionary of tensors."""
+
+ def __init__(self,
+ output_size,
+ min_level,
+ max_level,
+ num_scales,
+ aspect_ratios,
+ anchor_size,
+ rpn_match_threshold=0.7,
+ rpn_unmatched_threshold=0.3,
+ rpn_batch_size_per_im=256,
+ rpn_fg_fraction=0.5,
+ aug_rand_hflip=False,
+ aug_scale_min=1.0,
+ aug_scale_max=1.0,
+ skip_crowd_during_training=True,
+ max_num_instances=100,
+ include_mask=False,
+ mask_crop_size=112,
+ use_bfloat16=True,
+ mode=None):
+ """Initializes parameters for parsing annotations in the dataset.
+
+ Args:
+      output_size: `Tensor` or `list` for [height, width] of output image. The
+        output_size should be divisible by the largest feature stride 2^max_level.
+      min_level: `int` number of minimum level of the output feature pyramid.
+      max_level: `int` number of maximum level of the output feature pyramid.
+      num_scales: `int` number representing intermediate scales added
+        on each level. For instance, num_scales=2 adds one additional
+        intermediate anchor scale [2^0, 2^0.5] on each level.
+      aspect_ratios: `list` of float numbers representing the aspect ratio
+        anchors added on each level. The number indicates the ratio of width to
+        height. For instance, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors
+        on each scale level.
+      anchor_size: `float` number representing the scale of size of the base
+        anchor to the feature stride 2^level.
+      rpn_match_threshold: `float` IoU threshold above which an anchor is
+        assigned a positive RPN label.
+      rpn_unmatched_threshold: `float` IoU threshold below which an anchor is
+        assigned a negative RPN label.
+      rpn_batch_size_per_im: `int` number of anchors sampled per image for
+        computing the RPN loss.
+      rpn_fg_fraction: `float` desired fraction of positive anchors in the
+        sampled RPN batch.
+      aug_rand_hflip: `bool`, if True, augment training with random
+        horizontal flip.
+      aug_scale_min: `float`, the minimum scale applied to `output_size` for
+        data augmentation during training.
+      aug_scale_max: `float`, the maximum scale applied to `output_size` for
+        data augmentation during training.
+      skip_crowd_during_training: `bool`, if True, skip annotations labeled with
+        `is_crowd` equal to 1.
+      max_num_instances: `int` maximum number of instances in an
+        image. The groundtruth data will be padded to `max_num_instances`.
+      include_mask: a bool indicating whether to parse mask groundtruth.
+      mask_crop_size: the size to which the groundtruth mask is cropped.
+ use_bfloat16: `bool`, if True, cast output image to tf.bfloat16.
+ mode: a ModeKeys. Specifies if this is training, evaluation, prediction
+ or prediction with groundtruths in the outputs.
+ """
+ self._mode = mode
+ self._max_num_instances = max_num_instances
+ self._skip_crowd_during_training = skip_crowd_during_training
+ self._is_training = (mode == ModeKeys.TRAIN)
+
+ self._example_decoder = tf_example_decoder.TfExampleDecoder(
+ include_mask=include_mask)
+
+ # Anchor.
+ self._output_size = output_size
+ self._min_level = min_level
+ self._max_level = max_level
+ self._num_scales = num_scales
+ self._aspect_ratios = aspect_ratios
+ self._anchor_size = anchor_size
+
+ # Target assigning.
+ self._rpn_match_threshold = rpn_match_threshold
+ self._rpn_unmatched_threshold = rpn_unmatched_threshold
+ self._rpn_batch_size_per_im = rpn_batch_size_per_im
+ self._rpn_fg_fraction = rpn_fg_fraction
+
+ # Data augmentation.
+ self._aug_rand_hflip = aug_rand_hflip
+ self._aug_scale_min = aug_scale_min
+ self._aug_scale_max = aug_scale_max
+
+ # Mask.
+ self._include_mask = include_mask
+ self._mask_crop_size = mask_crop_size
+
+ # Device.
+ self._use_bfloat16 = use_bfloat16
+
+ # Data is parsed depending on the model Modekey.
+ if mode == ModeKeys.TRAIN:
+ self._parse_fn = self._parse_train_data
+ elif mode == ModeKeys.EVAL:
+ self._parse_fn = self._parse_eval_data
+ elif mode == ModeKeys.PREDICT or mode == ModeKeys.PREDICT_WITH_GT:
+ self._parse_fn = self._parse_predict_data
+ else:
+ raise ValueError('mode is not defined.')
+
+ def __call__(self, value):
+ """Parses data to an image and associated training labels.
+
+ Args:
+ value: a string tensor holding a serialized tf.Example proto.
+
+ Returns:
+      image, labels: if mode == ModeKeys.TRAIN. See _parse_train_data.
+ {'images': image, 'labels': labels}: if mode == ModeKeys.PREDICT
+ or ModeKeys.PREDICT_WITH_GT.
+ """
+ with tf.name_scope('parser'):
+ data = self._example_decoder.decode(value)
+ return self._parse_fn(data)
+
+ def _parse_train_data(self, data):
+ """Parses data for training.
+
+ Args:
+ data: the decoded tensor dictionary from TfExampleDecoder.
+
+ Returns:
+      image: image tensor that is preprocessed to have normalized value and
+ dimension [output_size[0], output_size[1], 3]
+ labels: a dictionary of tensors used for training. The following describes
+ {key: value} pairs in the dictionary.
+ image_info: a 2D `Tensor` that encodes the information of the image and
+ the applied preprocessing. It is in the format of
+        [[original_height, original_width], [scaled_height, scaled_width],
+        [y_scale, x_scale], [y_offset, x_offset]].
+ anchor_boxes: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, 4] representing anchor boxes at each level.
+ rpn_score_targets: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, anchors_per_location]. The height_l and
+ width_l represent the dimension of class logits at l-th level.
+ rpn_box_targets: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, anchors_per_location * 4]. The height_l and
+ width_l represent the dimension of bounding box regression output at
+ l-th level.
+ gt_boxes: Groundtruth bounding box annotations. The box is represented
+          in [y1, x1, y2, x2] format. The coordinates are w.r.t. the scaled
+          image that is fed to the network. The tensor is padded with -1 to
+          the fixed dimension [self._max_num_instances, 4].
+        gt_classes: Groundtruth classes annotations. The tensor is padded
+          with -1 to the fixed dimension [self._max_num_instances].
+        gt_masks: groundtruth masks cropped by the bounding box and
+ resized to a fixed size determined by mask_crop_size.
+ """
+ classes = data['groundtruth_classes']
+ boxes = data['groundtruth_boxes']
+ if self._include_mask:
+ masks = data['groundtruth_instance_masks']
+
+ is_crowds = data['groundtruth_is_crowd']
+ # Skips annotations with `is_crowd` = True.
+ if self._skip_crowd_during_training and self._is_training:
+      num_groundtruths = tf.shape(classes)[0]
+      with tf.control_dependencies([num_groundtruths, is_crowds]):
+        indices = tf.cond(
+            tf.greater(tf.size(is_crowds), 0),
+            lambda: tf.where(tf.logical_not(is_crowds))[:, 0],
+            lambda: tf.cast(tf.range(num_groundtruths), tf.int64))
+ classes = tf.gather(classes, indices)
+ boxes = tf.gather(boxes, indices)
+ if self._include_mask:
+ masks = tf.gather(masks, indices)
+
+ # Gets original image and its size.
+ image = data['image']
+ image_shape = tf.shape(image)[0:2]
+
+ # Normalizes image with mean and std pixel values.
+ image = input_utils.normalize_image(image)
+
+ # Flips image randomly during training.
+ if self._aug_rand_hflip:
+ if self._include_mask:
+ image, boxes, masks = input_utils.random_horizontal_flip(
+ image, boxes, masks)
+ else:
+ image, boxes = input_utils.random_horizontal_flip(
+ image, boxes)
+
+ # Converts boxes from normalized coordinates to pixel coordinates.
+ # Now the coordinates of boxes are w.r.t. the original image.
+ boxes = box_utils.denormalize_boxes(boxes, image_shape)
+
+ # Resizes and crops image.
+ image, image_info = input_utils.resize_and_crop_image(
+ image,
+ self._output_size,
+ padded_size=input_utils.compute_padded_size(
+ self._output_size, 2 ** self._max_level),
+ aug_scale_min=self._aug_scale_min,
+ aug_scale_max=self._aug_scale_max)
+ image_height, image_width, _ = image.get_shape().as_list()
+
+ # Resizes and crops boxes.
+ # Now the coordinates of boxes are w.r.t the scaled image.
+ image_scale = image_info[2, :]
+ offset = image_info[3, :]
+ boxes = input_utils.resize_and_crop_boxes(
+ boxes, image_scale, image_info[1, :], offset)
+
+ # Filters out ground truth boxes that are all zeros.
+ indices = box_utils.get_non_empty_box_indices(boxes)
+ boxes = tf.gather(boxes, indices)
+ classes = tf.gather(classes, indices)
+ if self._include_mask:
+ masks = tf.gather(masks, indices)
+ # Transfer boxes to the original image space and do normalization.
+ cropped_boxes = boxes + tf.tile(tf.expand_dims(offset, axis=0), [1, 2])
+ cropped_boxes /= tf.tile(tf.expand_dims(image_scale, axis=0), [1, 2])
+ cropped_boxes = box_utils.normalize_boxes(cropped_boxes, image_shape)
+ num_masks = tf.shape(masks)[0]
+ masks = tf.image.crop_and_resize(
+ tf.expand_dims(masks, axis=-1),
+ cropped_boxes,
+ box_indices=tf.range(num_masks, dtype=tf.int32),
+ crop_size=[self._mask_crop_size, self._mask_crop_size],
+ method='bilinear')
+ masks = tf.squeeze(masks, axis=-1)
+
+ # Assigns anchor targets.
+ # Note that after the target assignment, box targets are absolute pixel
+ # offsets w.r.t. the scaled image.
+ input_anchor = anchor.Anchor(
+ self._min_level,
+ self._max_level,
+ self._num_scales,
+ self._aspect_ratios,
+ self._anchor_size,
+ (image_height, image_width))
+ anchor_labeler = anchor.RpnAnchorLabeler(
+ input_anchor,
+ self._rpn_match_threshold,
+ self._rpn_unmatched_threshold,
+ self._rpn_batch_size_per_im,
+ self._rpn_fg_fraction)
+ rpn_score_targets, rpn_box_targets = anchor_labeler.label_anchors(
+ boxes, tf.cast(tf.expand_dims(classes, axis=-1), dtype=tf.float32))
+
+ # If bfloat16 is used, casts input image to tf.bfloat16.
+ if self._use_bfloat16:
+ image = tf.cast(image, dtype=tf.bfloat16)
+
+ inputs = {
+ 'image': image,
+ 'image_info': image_info,
+ }
+ # Packs labels for model_fn outputs.
+ labels = {
+ 'anchor_boxes': input_anchor.multilevel_boxes,
+ 'image_info': image_info,
+ 'rpn_score_targets': rpn_score_targets,
+ 'rpn_box_targets': rpn_box_targets,
+ }
+ inputs['gt_boxes'] = input_utils.pad_to_fixed_size(boxes,
+ self._max_num_instances,
+ -1)
+ inputs['gt_classes'] = input_utils.pad_to_fixed_size(
+ classes, self._max_num_instances, -1)
+ if self._include_mask:
+ inputs['gt_masks'] = input_utils.pad_to_fixed_size(
+ masks, self._max_num_instances, -1)
+
+ return inputs, labels
+
+ def _parse_eval_data(self, data):
+ """Parses data for evaluation."""
+ raise NotImplementedError('Not implemented!')
+
+ def _parse_predict_data(self, data):
+ """Parses data for prediction.
+
+ Args:
+ data: the decoded tensor dictionary from TfExampleDecoder.
+
+ Returns:
+ A dictionary of {'images': image, 'labels': labels} where
+      image: image tensor that is preprocessed to have normalized value and
+ dimension [output_size[0], output_size[1], 3]
+ labels: a dictionary of tensors used for training. The following
+ describes {key: value} pairs in the dictionary.
+ source_ids: Source image id. Default value -1 if the source id is
+ empty in the groundtruth annotation.
+ image_info: a 2D `Tensor` that encodes the information of the image
+ and the applied preprocessing. It is in the format of
+          [[original_height, original_width], [scaled_height, scaled_width],
+          [y_scale, x_scale], [y_offset, x_offset]].
+ anchor_boxes: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, 4] representing anchor boxes at each
+ level.
+ """
+ # Gets original image and its size.
+ image = data['image']
+ image_shape = tf.shape(image)[0:2]
+
+ # Normalizes image with mean and std pixel values.
+ image = input_utils.normalize_image(image)
+
+ # Resizes and crops image.
+ image, image_info = input_utils.resize_and_crop_image(
+ image,
+ self._output_size,
+ padded_size=input_utils.compute_padded_size(
+ self._output_size, 2 ** self._max_level),
+ aug_scale_min=1.0,
+ aug_scale_max=1.0)
+ image_height, image_width, _ = image.get_shape().as_list()
+
+ # If bfloat16 is used, casts input image to tf.bfloat16.
+ if self._use_bfloat16:
+ image = tf.cast(image, dtype=tf.bfloat16)
+
+ # Compute Anchor boxes.
+ input_anchor = anchor.Anchor(
+ self._min_level,
+ self._max_level,
+ self._num_scales,
+ self._aspect_ratios,
+ self._anchor_size,
+ (image_height, image_width))
+
+ labels = {
+ 'image_info': image_info,
+ }
+
+ if self._mode == ModeKeys.PREDICT_WITH_GT:
+ # Converts boxes from normalized coordinates to pixel coordinates.
+ boxes = box_utils.denormalize_boxes(
+ data['groundtruth_boxes'], image_shape)
+ groundtruths = {
+ 'source_id': data['source_id'],
+ 'height': data['height'],
+ 'width': data['width'],
+ 'num_detections': tf.shape(data['groundtruth_classes']),
+ 'boxes': boxes,
+ 'classes': data['groundtruth_classes'],
+ 'areas': data['groundtruth_area'],
+ 'is_crowds': tf.cast(data['groundtruth_is_crowd'], tf.int32),
+ }
+ groundtruths['source_id'] = dataloader_utils.process_source_id(
+ groundtruths['source_id'])
+ groundtruths = dataloader_utils.pad_groundtruths_to_fixed_size(
+ groundtruths, self._max_num_instances)
+      # TODO(yeqing): Remove the `groundtruths` layer key (no longer needed).
+ labels['groundtruths'] = groundtruths
+ inputs = {
+ 'image': image,
+ 'image_info': image_info,
+ }
+
+ return inputs, labels
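The sketch below shows how this parser could be applied directly to a TFRecord dataset (InputFn above does the same wiring internally); the file pattern and anchor settings are illustrative assumptions, not values taken from this diff.

import tensorflow as tf
from official.vision.detection.dataloader import maskrcnn_parser
from official.vision.detection.dataloader import mode_keys

parser = maskrcnn_parser.Parser(
    output_size=[1024, 1024], min_level=2, max_level=6,
    num_scales=1, aspect_ratios=[1.0, 2.0, 0.5], anchor_size=8.0,
    include_mask=True, use_bfloat16=False, mode=mode_keys.TRAIN)

files = tf.data.Dataset.list_files('/path/to/coco/train-*.tfrecord')  # placeholder
dataset = (files.interleave(tf.data.TFRecordDataset, cycle_length=4)
           .map(parser, num_parallel_calls=tf.data.experimental.AUTOTUNE)
           .batch(2, drop_remainder=True))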
diff --git a/models/official/vision/detection/dataloader/mode_keys.py b/models/official/vision/detection/dataloader/mode_keys.py
new file mode 100644
index 0000000000000000000000000000000000000000..020382b2486ca25a41f0c3eb88b1f2038c538e7e
--- /dev/null
+++ b/models/official/vision/detection/dataloader/mode_keys.py
@@ -0,0 +1,33 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Standard names for input dataloader modes.
+
+The following standard keys are defined:
+
+* `TRAIN`: training mode.
+* `EVAL`: evaluation mode.
+* `PREDICT`: prediction mode.
+* `PREDICT_WITH_GT`: prediction mode with groundtruths in returned variables.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+TRAIN = 'train'
+EVAL = 'eval'
+PREDICT = 'predict'
+PREDICT_WITH_GT = 'predict_with_gt'
diff --git a/models/official/vision/detection/dataloader/retinanet_parser.py b/models/official/vision/detection/dataloader/retinanet_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..d226a6da7e2fc2e650ad6ecdfb5a431d13df97a3
--- /dev/null
+++ b/models/official/vision/detection/dataloader/retinanet_parser.py
@@ -0,0 +1,422 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Data parser and processing.
+
+Parse image and ground truths in a dataset to training targets and package them
+into (image, labels) tuple for RetinaNet.
+
+T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Dollar
+Focal Loss for Dense Object Detection. arXiv:1708.02002
+"""
+
+import tensorflow as tf
+
+from official.vision.detection.dataloader import anchor
+from official.vision.detection.dataloader import mode_keys as ModeKeys
+from official.vision.detection.dataloader import tf_example_decoder
+from official.vision.detection.utils import box_utils
+from official.vision.detection.utils import input_utils
+
+
+def process_source_id(source_id):
+ """Processes source_id to the right format."""
+ if source_id.dtype == tf.string:
+ source_id = tf.cast(tf.strings.to_number(source_id), tf.int32)
+ with tf.control_dependencies([source_id]):
+ source_id = tf.cond(
+ pred=tf.equal(tf.size(input=source_id), 0),
+ true_fn=lambda: tf.cast(tf.constant(-1), tf.int32),
+ false_fn=lambda: tf.identity(source_id))
+ return source_id
+
+
+def pad_groundtruths_to_fixed_size(gt, n):
+ """Pads the first dimension of groundtruths labels to the fixed size."""
+ gt['boxes'] = input_utils.pad_to_fixed_size(gt['boxes'], n, -1)
+ gt['is_crowds'] = input_utils.pad_to_fixed_size(gt['is_crowds'], n, 0)
+ gt['areas'] = input_utils.pad_to_fixed_size(gt['areas'], n, -1)
+ gt['classes'] = input_utils.pad_to_fixed_size(gt['classes'], n, -1)
+ return gt
+
+
+class Parser(object):
+ """Parser to parse an image and its annotations into a dictionary of tensors."""
+
+ def __init__(self,
+ output_size,
+ min_level,
+ max_level,
+ num_scales,
+ aspect_ratios,
+ anchor_size,
+ match_threshold=0.5,
+ unmatched_threshold=0.5,
+ aug_rand_hflip=False,
+ aug_scale_min=1.0,
+ aug_scale_max=1.0,
+ use_autoaugment=False,
+ autoaugment_policy_name='v0',
+ skip_crowd_during_training=True,
+ max_num_instances=100,
+ use_bfloat16=True,
+ mode=None):
+ """Initializes parameters for parsing annotations in the dataset.
+
+ Args:
+ output_size: `Tensor` or `list` for [height, width] of output image. The
+        output_size should be divisible by the largest feature stride 2^max_level.
+ min_level: `int` number of minimum level of the output feature pyramid.
+ max_level: `int` number of maximum level of the output feature pyramid.
+ num_scales: `int` number representing intermediate scales added
+        on each level. For instance, num_scales=2 adds one additional
+        intermediate anchor scale [2^0, 2^0.5] on each level.
+      aspect_ratios: `list` of float numbers representing the aspect ratio
+        anchors added on each level. The number indicates the ratio of width to
+        height. For instance, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors
+ on each scale level.
+ anchor_size: `float` number representing the scale of size of the base
+ anchor to the feature stride 2^level.
+ match_threshold: `float` number between 0 and 1 representing the
+ lower-bound threshold to assign positive labels for anchors. An anchor
+ with a score over the threshold is labeled positive.
+ unmatched_threshold: `float` number between 0 and 1 representing the
+ upper-bound threshold to assign negative labels for anchors. An anchor
+ with a score below the threshold is labeled negative.
+ aug_rand_hflip: `bool`, if True, augment training with random
+ horizontal flip.
+ aug_scale_min: `float`, the minimum scale applied to `output_size` for
+ data augmentation during training.
+ aug_scale_max: `float`, the maximum scale applied to `output_size` for
+ data augmentation during training.
+ use_autoaugment: `bool`, if True, use the AutoAugment augmentation policy
+ during training.
+ autoaugment_policy_name: `string` that specifies the name of the
+ AutoAugment policy that will be used during training.
+ skip_crowd_during_training: `bool`, if True, skip annotations labeled with
+        `is_crowd` equal to 1.
+      max_num_instances: `int` maximum number of instances in an
+ image. The groundtruth data will be padded to `max_num_instances`.
+ use_bfloat16: `bool`, if True, cast output image to tf.bfloat16.
+ mode: a ModeKeys. Specifies if this is training, evaluation, prediction
+ or prediction with groundtruths in the outputs.
+ """
+ self._mode = mode
+ self._max_num_instances = max_num_instances
+ self._skip_crowd_during_training = skip_crowd_during_training
+ self._is_training = (mode == ModeKeys.TRAIN)
+
+ self._example_decoder = tf_example_decoder.TfExampleDecoder(
+ include_mask=False)
+
+ # Anchor.
+ self._output_size = output_size
+ self._min_level = min_level
+ self._max_level = max_level
+ self._num_scales = num_scales
+ self._aspect_ratios = aspect_ratios
+ self._anchor_size = anchor_size
+ self._match_threshold = match_threshold
+ self._unmatched_threshold = unmatched_threshold
+
+ # Data augmentation.
+ self._aug_rand_hflip = aug_rand_hflip
+ self._aug_scale_min = aug_scale_min
+ self._aug_scale_max = aug_scale_max
+
+ # Data Augmentation with AutoAugment.
+ self._use_autoaugment = use_autoaugment
+ self._autoaugment_policy_name = autoaugment_policy_name
+
+ # Device.
+ self._use_bfloat16 = use_bfloat16
+
+ # Data is parsed depending on the model Modekey.
+ if mode == ModeKeys.TRAIN:
+ self._parse_fn = self._parse_train_data
+ elif mode == ModeKeys.EVAL:
+ self._parse_fn = self._parse_eval_data
+ elif mode == ModeKeys.PREDICT or mode == ModeKeys.PREDICT_WITH_GT:
+ self._parse_fn = self._parse_predict_data
+ else:
+ raise ValueError('mode is not defined.')
+
+ def __call__(self, value):
+ """Parses data to an image and associated training labels.
+
+ Args:
+ value: a string tensor holding a serialized tf.Example proto.
+
+ Returns:
+      image: image tensor that is preprocessed to have normalized value and
+ dimension [output_size[0], output_size[1], 3]
+ labels:
+ cls_targets: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, anchors_per_location]. The height_l and
+ width_l represent the dimension of class logits at l-th level.
+ box_targets: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, anchors_per_location * 4]. The height_l and
+ width_l represent the dimension of bounding box regression output at
+ l-th level.
+ num_positives: number of positive anchors in the image.
+ anchor_boxes: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, 4] representing anchor boxes at each level.
+ image_info: a 2D `Tensor` that encodes the information of the image and
+ the applied preprocessing. It is in the format of
+ [[original_height, original_width], [scaled_height, scaled_width],
+ [y_scale, x_scale], [y_offset, x_offset]].
+ groundtruths:
+ source_id: source image id. Default value -1 if the source id is empty
+ in the groundtruth annotation.
+ boxes: groundtruth bounding box annotations. The box is represented in
+          [y1, x1, y2, x2] format. The tensor is padded with -1 to the fixed
+          dimension [self._max_num_instances, 4].
+        classes: groundtruth classes annotations. The tensor is padded with
+          -1 to the fixed dimension [self._max_num_instances].
+        areas: groundtruth areas annotations. The tensor is padded with -1
+          to the fixed dimension [self._max_num_instances].
+        is_crowds: groundtruth annotations to indicate if an annotation
+          represents a group of instances by value {0, 1}. The tensor is
+ padded with 0 to the fixed dimension [self._max_num_instances].
+ """
+ with tf.name_scope('parser'):
+ data = self._example_decoder.decode(value)
+ return self._parse_fn(data)
+
+ def _parse_train_data(self, data):
+ """Parses data for training and evaluation."""
+ classes = data['groundtruth_classes']
+ boxes = data['groundtruth_boxes']
+ is_crowds = data['groundtruth_is_crowd']
+ # Skips annotations with `is_crowd` = True.
+ if self._skip_crowd_during_training and self._is_training:
+      num_groundtruths = tf.shape(input=classes)[0]
+      with tf.control_dependencies([num_groundtruths, is_crowds]):
+        indices = tf.cond(
+            pred=tf.greater(tf.size(input=is_crowds), 0),
+            true_fn=lambda: tf.where(tf.logical_not(is_crowds))[:, 0],
+            false_fn=lambda: tf.cast(tf.range(num_groundtruths), tf.int64))
+ classes = tf.gather(classes, indices)
+ boxes = tf.gather(boxes, indices)
+
+ # Gets original image and its size.
+ image = data['image']
+
+ image_shape = tf.shape(input=image)[0:2]
+
+ # Normalizes image with mean and std pixel values.
+ image = input_utils.normalize_image(image)
+
+ # Flips image randomly during training.
+ if self._aug_rand_hflip:
+ image, boxes = input_utils.random_horizontal_flip(image, boxes)
+
+ # Converts boxes from normalized coordinates to pixel coordinates.
+ boxes = box_utils.denormalize_boxes(boxes, image_shape)
+
+ # Resizes and crops image.
+ image, image_info = input_utils.resize_and_crop_image(
+ image,
+ self._output_size,
+ padded_size=input_utils.compute_padded_size(
+ self._output_size, 2 ** self._max_level),
+ aug_scale_min=self._aug_scale_min,
+ aug_scale_max=self._aug_scale_max)
+ image_height, image_width, _ = image.get_shape().as_list()
+
+ # Resizes and crops boxes.
+ image_scale = image_info[2, :]
+ offset = image_info[3, :]
+ boxes = input_utils.resize_and_crop_boxes(
+ boxes, image_scale, image_info[1, :], offset)
+ # Filters out ground truth boxes that are all zeros.
+ indices = box_utils.get_non_empty_box_indices(boxes)
+ boxes = tf.gather(boxes, indices)
+ classes = tf.gather(classes, indices)
+
+ # Assigns anchors.
+ input_anchor = anchor.Anchor(
+ self._min_level, self._max_level, self._num_scales,
+ self._aspect_ratios, self._anchor_size, (image_height, image_width))
+ anchor_labeler = anchor.AnchorLabeler(
+ input_anchor, self._match_threshold, self._unmatched_threshold)
+ (cls_targets, box_targets, num_positives) = anchor_labeler.label_anchors(
+ boxes,
+ tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
+
+ # If bfloat16 is used, casts input image to tf.bfloat16.
+ if self._use_bfloat16:
+ image = tf.cast(image, dtype=tf.bfloat16)
+
+ # Packs labels for model_fn outputs.
+ labels = {
+ 'cls_targets': cls_targets,
+ 'box_targets': box_targets,
+ 'anchor_boxes': input_anchor.multilevel_boxes,
+ 'num_positives': num_positives,
+ 'image_info': image_info,
+ }
+ return image, labels
+
+ def _parse_eval_data(self, data):
+ """Parses data for training and evaluation."""
+ groundtruths = {}
+ classes = data['groundtruth_classes']
+ boxes = data['groundtruth_boxes']
+
+ # Gets original image and its size.
+ image = data['image']
+ image_shape = tf.shape(input=image)[0:2]
+
+ # Normalizes image with mean and std pixel values.
+ image = input_utils.normalize_image(image)
+
+ # Converts boxes from normalized coordinates to pixel coordinates.
+ boxes = box_utils.denormalize_boxes(boxes, image_shape)
+
+ # Resizes and crops image.
+ image, image_info = input_utils.resize_and_crop_image(
+ image,
+ self._output_size,
+ padded_size=input_utils.compute_padded_size(
+ self._output_size, 2 ** self._max_level),
+ aug_scale_min=1.0,
+ aug_scale_max=1.0)
+ image_height, image_width, _ = image.get_shape().as_list()
+
+ # Resizes and crops boxes.
+ image_scale = image_info[2, :]
+ offset = image_info[3, :]
+ boxes = input_utils.resize_and_crop_boxes(
+ boxes, image_scale, image_info[1, :], offset)
+ # Filters out ground truth boxes that are all zeros.
+ indices = box_utils.get_non_empty_box_indices(boxes)
+ boxes = tf.gather(boxes, indices)
+ classes = tf.gather(classes, indices)
+
+ # Assigns anchors.
+ input_anchor = anchor.Anchor(
+ self._min_level, self._max_level, self._num_scales,
+ self._aspect_ratios, self._anchor_size, (image_height, image_width))
+ anchor_labeler = anchor.AnchorLabeler(
+ input_anchor, self._match_threshold, self._unmatched_threshold)
+ (cls_targets, box_targets, num_positives) = anchor_labeler.label_anchors(
+ boxes,
+ tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
+
+ # If bfloat16 is used, casts input image to tf.bfloat16.
+ if self._use_bfloat16:
+ image = tf.cast(image, dtype=tf.bfloat16)
+
+ # Sets up groundtruth data for evaluation.
+ groundtruths = {
+ 'source_id': data['source_id'],
+ 'num_groundtrtuhs': tf.shape(data['groundtruth_classes']),
+ 'image_info': image_info,
+ 'boxes': box_utils.denormalize_boxes(
+ data['groundtruth_boxes'], image_shape),
+ 'classes': data['groundtruth_classes'],
+ 'areas': data['groundtruth_area'],
+ 'is_crowds': tf.cast(data['groundtruth_is_crowd'], tf.int32),
+ }
+ groundtruths['source_id'] = process_source_id(groundtruths['source_id'])
+ groundtruths = pad_groundtruths_to_fixed_size(
+ groundtruths, self._max_num_instances)
+
+ # Packs labels for model_fn outputs.
+ labels = {
+ 'cls_targets': cls_targets,
+ 'box_targets': box_targets,
+ 'anchor_boxes': input_anchor.multilevel_boxes,
+ 'num_positives': num_positives,
+ 'image_info': image_info,
+ 'groundtruths': groundtruths,
+ }
+ return image, labels
+
+ def _parse_predict_data(self, data):
+ """Parses data for prediction."""
+ # Gets original image and its size.
+ image = data['image']
+ image_shape = tf.shape(input=image)[0:2]
+
+ # Normalizes image with mean and std pixel values.
+ image = input_utils.normalize_image(image)
+
+ # Resizes and crops image.
+ image, image_info = input_utils.resize_and_crop_image(
+ image,
+ self._output_size,
+ padded_size=input_utils.compute_padded_size(
+ self._output_size, 2 ** self._max_level),
+ aug_scale_min=1.0,
+ aug_scale_max=1.0)
+ image_height, image_width, _ = image.get_shape().as_list()
+
+ # If bfloat16 is used, casts input image to tf.bfloat16.
+ if self._use_bfloat16:
+ image = tf.cast(image, dtype=tf.bfloat16)
+
+ # Compute Anchor boxes.
+ input_anchor = anchor.Anchor(
+ self._min_level, self._max_level, self._num_scales,
+ self._aspect_ratios, self._anchor_size, (image_height, image_width))
+
+ labels = {
+ 'anchor_boxes': input_anchor.multilevel_boxes,
+ 'image_info': image_info,
+ }
+ # If mode is PREDICT_WITH_GT, returns groundtruths and training targets
+ # in labels.
+ if self._mode == ModeKeys.PREDICT_WITH_GT:
+ # Converts boxes from normalized coordinates to pixel coordinates.
+ boxes = box_utils.denormalize_boxes(
+ data['groundtruth_boxes'], image_shape)
+ groundtruths = {
+ 'source_id': data['source_id'],
+ 'num_detections': tf.shape(data['groundtruth_classes']),
+ 'boxes': boxes,
+ 'classes': data['groundtruth_classes'],
+ 'areas': data['groundtruth_area'],
+ 'is_crowds': tf.cast(data['groundtruth_is_crowd'], tf.int32),
+ }
+ groundtruths['source_id'] = process_source_id(groundtruths['source_id'])
+ groundtruths = pad_groundtruths_to_fixed_size(
+ groundtruths, self._max_num_instances)
+ labels['groundtruths'] = groundtruths
+
+ # Computes training objective for evaluation loss.
+ classes = data['groundtruth_classes']
+
+ image_scale = image_info[2, :]
+ offset = image_info[3, :]
+ boxes = input_utils.resize_and_crop_boxes(
+ boxes, image_scale, image_info[1, :], offset)
+ # Filters out ground truth boxes that are all zeros.
+ indices = box_utils.get_non_empty_box_indices(boxes)
+ boxes = tf.gather(boxes, indices)
+
+ # Assigns anchors.
+ anchor_labeler = anchor.AnchorLabeler(
+ input_anchor, self._match_threshold, self._unmatched_threshold)
+ (cls_targets, box_targets, num_positives) = anchor_labeler.label_anchors(
+ boxes,
+ tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
+ labels['cls_targets'] = cls_targets
+ labels['box_targets'] = box_targets
+ labels['num_positives'] = num_positives
+ return image, labels
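A quick, illustrative check of the two module-level helpers defined at the top of this file; the input values are made up.

import tensorflow as tf
from official.vision.detection.dataloader import retinanet_parser

# String ids are converted to int32; empty ids become -1.
sid = retinanet_parser.process_source_id(tf.constant('397133'))
print(int(sid))  # 397133

# Groundtruth fields are padded along the first dimension to a fixed size.
gt = {
    'boxes': tf.constant([[10., 10., 50., 80.]]),
    'classes': tf.constant([3]),
    'areas': tf.constant([2800.]),
    'is_crowds': tf.constant([0]),
}
gt = retinanet_parser.pad_groundtruths_to_fixed_size(gt, 100)
print(gt['boxes'].shape)  # (100, 4)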
diff --git a/models/official/vision/detection/dataloader/shapemask_parser.py b/models/official/vision/detection/dataloader/shapemask_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bc368c0ef290291405157b772ed523f3725e0a3
--- /dev/null
+++ b/models/official/vision/detection/dataloader/shapemask_parser.py
@@ -0,0 +1,522 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Data parser and processing.
+
+Parse image and ground truths in a dataset to training targets and package them
+into (image, labels) tuple for ShapeMask.
+
+Weicheng Kuo, Anelia Angelova, Jitendra Malik, Tsung-Yi Lin
+ShapeMask: Learning to Segment Novel Objects by Refining Shape Priors.
+arXiv:1904.03239.
+"""
+
+import tensorflow as tf
+
+from official.vision.detection.dataloader import anchor
+from official.vision.detection.dataloader import mode_keys as ModeKeys
+from official.vision.detection.dataloader import tf_example_decoder
+from official.vision.detection.utils import box_utils
+from official.vision.detection.utils import class_utils
+from official.vision.detection.utils import dataloader_utils
+from official.vision.detection.utils import input_utils
+
+
+def pad_to_size(input_tensor, size):
+ """Pads data with zeros to a given length at the first dimension if needed.
+
+ Args:
+ input_tensor: `Tensor` with any dimension.
+ size: `int` number for the first dimension of output Tensor.
+
+ Returns:
+    `Tensor` with the first dimension padded to `size` if the first dimension
+ is less than `size`, otherwise no padding.
+ """
+ input_shape = tf.shape(input_tensor)
+ padding_shape = []
+
+ # Computes the padding length on the first dimension.
+ padding_length = tf.maximum(0, size - tf.shape(input_tensor)[0])
+ assert_length = tf.Assert(
+ tf.greater_equal(padding_length, 0), [padding_length])
+ with tf.control_dependencies([assert_length]):
+ padding_shape.append(padding_length)
+
+ # Copies shapes of the rest of input shape dimensions.
+ for i in range(1, len(input_shape)):
+ padding_shape.append(tf.shape(input=input_tensor)[i])
+
+ # Pads input tensor to the fixed first dimension.
+ paddings = tf.cast(tf.zeros(padding_shape), input_tensor.dtype)
+ padded_tensor = tf.concat([input_tensor, paddings], axis=0)
+ return padded_tensor
+
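A small aside (not part of the file) illustrating pad_to_size; it only ever pads the leading dimension with zeros and never truncates.

import tensorflow as tf
from official.vision.detection.dataloader.shapemask_parser import pad_to_size

x = tf.random.uniform([5, 4])
print(pad_to_size(x, 8).shape)  # (8, 4): three all-zero rows appended
print(pad_to_size(x, 3).shape)  # (5, 4): already longer than 3, left unchanged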
+
+class Parser(object):
+ """ShapeMask Parser to parse an image and its annotations into a dictionary of tensors."""
+
+ def __init__(self,
+ output_size,
+ min_level,
+ max_level,
+ num_scales,
+ aspect_ratios,
+ anchor_size,
+ use_category=True,
+ outer_box_scale=1.0,
+ box_jitter_scale=0.025,
+ num_sampled_masks=8,
+ mask_crop_size=32,
+ mask_min_level=3,
+ mask_max_level=5,
+ upsample_factor=4,
+ match_threshold=0.5,
+ unmatched_threshold=0.5,
+ aug_rand_hflip=False,
+ aug_scale_min=1.0,
+ aug_scale_max=1.0,
+ skip_crowd_during_training=True,
+ max_num_instances=100,
+ use_bfloat16=True,
+ mask_train_class='all',
+ mode=None):
+ """Initializes parameters for parsing annotations in the dataset.
+
+ Args:
+      output_size: `Tensor` or `list` for [height, width] of output image. The
+        output_size should be divisible by the largest feature stride
+        2^max_level.
+ min_level: `int` number of minimum level of the output feature pyramid.
+ max_level: `int` number of maximum level of the output feature pyramid.
+      num_scales: `int` number representing intermediate scales added
+        on each level. For instance, num_scales=2 adds one additional
+        intermediate anchor scale [2^0, 2^0.5] on each level.
+      aspect_ratios: `list` of float numbers representing the aspect ratios of
+        anchors added on each level. The number indicates the ratio of width to
+        height. For instance, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors
+        on each scale level.
+ anchor_size: `float` number representing the scale of size of the base
+ anchor to the feature stride 2^level.
+ use_category: if `False`, treat all object in all classes in one
+ foreground category.
+ outer_box_scale: `float` number in a range of [1.0, inf) representing
+ the scale from object box to outer box. The mask branch predicts
+ instance mask enclosed in outer box.
+ box_jitter_scale: `float` number representing the noise magnitude to
+ jitter the training groundtruth boxes for mask branch.
+ num_sampled_masks: `int` number of sampled masks for training.
+      mask_crop_size: `int` side length of the square output training masks.
+ mask_min_level: `int` number indicating the minimum feature level to
+ obtain instance features.
+ mask_max_level: `int` number indicating the maximum feature level to
+ obtain instance features.
+ upsample_factor: `int` factor of upsampling the fine mask predictions.
+ match_threshold: `float` number between 0 and 1 representing the
+ lower-bound threshold to assign positive labels for anchors. An anchor
+ with a score over the threshold is labeled positive.
+ unmatched_threshold: `float` number between 0 and 1 representing the
+ upper-bound threshold to assign negative labels for anchors. An anchor
+ with a score below the threshold is labeled negative.
+ aug_rand_hflip: `bool`, if True, augment training with random
+ horizontal flip.
+ aug_scale_min: `float`, the minimum scale applied to `output_size` for
+ data augmentation during training.
+ aug_scale_max: `float`, the maximum scale applied to `output_size` for
+ data augmentation during training.
+      skip_crowd_during_training: `bool`, if True, skip annotations labeled
+        with `is_crowd` equal to 1.
+      max_num_instances: `int`, the maximum number of instances in an image.
+        The groundtruth data will be padded to `max_num_instances`.
+ use_bfloat16: `bool`, if True, cast output image to tf.bfloat16.
+ mask_train_class: a string of experiment mode: `all`, `voc` or `nonvoc`.
+ mode: a ModeKeys. Specifies if this is training, evaluation, prediction
+ or prediction with groundtruths in the outputs.
+ """
+ self._mode = mode
+ self._mask_train_class = mask_train_class
+ self._max_num_instances = max_num_instances
+ self._skip_crowd_during_training = skip_crowd_during_training
+ self._is_training = (mode == ModeKeys.TRAIN)
+
+ self._example_decoder = tf_example_decoder.TfExampleDecoder(
+ include_mask=True)
+
+ # Anchor.
+ self._output_size = output_size
+ self._min_level = min_level
+ self._max_level = max_level
+ self._num_scales = num_scales
+ self._aspect_ratios = aspect_ratios
+ self._anchor_size = anchor_size
+ self._match_threshold = match_threshold
+ self._unmatched_threshold = unmatched_threshold
+
+ # Data augmentation.
+ self._aug_rand_hflip = aug_rand_hflip
+ self._aug_scale_min = aug_scale_min
+ self._aug_scale_max = aug_scale_max
+
+ # Device.
+ self._use_bfloat16 = use_bfloat16
+
+ # ShapeMask specific.
+ # Control of which category to use.
+ self._use_category = use_category
+ self._num_sampled_masks = num_sampled_masks
+ self._mask_crop_size = mask_crop_size
+ self._mask_min_level = mask_min_level
+ self._mask_max_level = mask_max_level
+ self._outer_box_scale = outer_box_scale
+ self._box_jitter_scale = box_jitter_scale
+ self._up_sample_factor = upsample_factor
+
+ # Data is parsed depending on the model Modekey.
+ if mode == ModeKeys.TRAIN:
+ self._parse_fn = self._parse_train_data
+ elif mode == ModeKeys.EVAL:
+ self._parse_fn = self._parse_eval_data
+ elif mode == ModeKeys.PREDICT or mode == ModeKeys.PREDICT_WITH_GT:
+ self._parse_fn = self._parse_predict_data
+ else:
+ raise ValueError('mode is not defined.')
+
+ def __call__(self, value):
+ """Parses data to an image and associated training labels.
+
+ Args:
+ value: a string tensor holding a serialized tf.Example proto.
+
+ Returns:
+ inputs:
+        image: image tensor that is preprocessed to have normalized value and
+          dimension [output_size[0], output_size[1], 3].
+ mask_boxes: sampled boxes that tightly enclose the training masks. The
+ box is represented in [y1, x1, y2, x2] format. The tensor is sampled
+ to the fixed dimension [self._num_sampled_masks, 4].
+        mask_outer_boxes: loose boxes that enclose the sampled tight boxes. The
+ box is represented in [y1, x1, y2, x2] format. The tensor is sampled
+ to the fixed dimension [self._num_sampled_masks, 4].
+ mask_classes: the class ids of sampled training masks. The tensor has
+ shape [self._num_sampled_masks].
+ labels:
+ cls_targets: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, anchors_per_location]. The height_l and
+ width_l represent the dimension of class logits at l-th level.
+ box_targets: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, anchors_per_location * 4]. The height_l and
+ width_l represent the dimension of bounding box regression output at
+ l-th level.
+ num_positives: number of positive anchors in the image.
+ anchor_boxes: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, 4] representing anchor boxes at each level.
+ image_scale: 2D float `Tensor` representing scale factors that apply
+ to [height, width] of input image.
+ mask_targets: training binary mask targets. The tensor has shape
+ [self._num_sampled_masks, self._mask_crop_size, self._mask_crop_size].
+ mask_is_valid: the binary tensor to indicate if the sampled masks are
+          valid. The sampled masks are invalid when no mask annotations are
+ included in the image. The tensor has shape [1].
+ groundtruths:
+ source_id: source image id. Default value -1 if the source id is empty
+ in the groundtruth annotation.
+ boxes: groundtruth bounding box annotations. The box is represented in
+ [y1, x1, y2, x2] format. The tensor is padded with -1 to the fixed
+ dimension [self._max_num_instances, 4].
+ classes: groundtruth classes annotations. The tensor is padded with
+ -1 to the fixed dimension [self._max_num_instances].
+ areas: groundtruth areas annotations. The tensor is padded with -1
+ to the fixed dimension [self._max_num_instances].
+ is_crowds: groundtruth annotations to indicate if an annotation
+ represents a group of instances by value {0, 1}. The tensor is
+ padded with 0 to the fixed dimension [self._max_num_instances].
+ """
+ with tf.name_scope('parser'):
+ data = self._example_decoder.decode(value)
+ return self._parse_fn(data)
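+  # A minimal illustrative sketch (not part of the original file): the parser
+  # is meant to be mapped over a dataset of serialized tf.Example protos. The
+  # file name and parameter values below are assumptions, not defaults.
+  #
+  #   parser = Parser(output_size=[640, 640], min_level=3, max_level=7,
+  #                   num_scales=3, aspect_ratios=[1.0, 2.0, 0.5],
+  #                   anchor_size=4.0, mode=ModeKeys.TRAIN)
+  #   dataset = tf.data.TFRecordDataset('train.tfrecord').map(parser)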
+
+ def _parse_train_data(self, data):
+ """Parse data for ShapeMask training."""
+ classes = data['groundtruth_classes']
+ boxes = data['groundtruth_boxes']
+ masks = data['groundtruth_instance_masks']
+ is_crowds = data['groundtruth_is_crowd']
+ # Skips annotations with `is_crowd` = True.
+ if self._skip_crowd_during_training and self._is_training:
+      num_groundtruths = tf.shape(classes)[0]
+      with tf.control_dependencies([num_groundtruths, is_crowds]):
+        indices = tf.cond(
+            tf.greater(tf.size(is_crowds), 0),
+            lambda: tf.where(tf.logical_not(is_crowds))[:, 0],
+            lambda: tf.cast(tf.range(num_groundtruths), tf.int64))
+ classes = tf.gather(classes, indices)
+ boxes = tf.gather(boxes, indices)
+ masks = tf.gather(masks, indices)
+
+ # Gets original image and its size.
+ image = data['image']
+ image_shape = tf.shape(image)[0:2]
+
+ # If not using category, makes all categories with id = 0.
+ if not self._use_category:
+ classes = tf.cast(tf.greater(classes, 0), dtype=tf.float32)
+
+ # Normalizes image with mean and std pixel values.
+ image = input_utils.normalize_image(image)
+
+ # Flips image randomly during training.
+ if self._aug_rand_hflip:
+ image, boxes, masks = input_utils.random_horizontal_flip(
+ image, boxes, masks)
+
+ # Converts boxes from normalized coordinates to pixel coordinates.
+ boxes = box_utils.denormalize_boxes(boxes, image_shape)
+
+ # Resizes and crops image.
+ image, image_info = input_utils.resize_and_crop_image(
+ image,
+ self._output_size,
+ self._output_size,
+ aug_scale_min=self._aug_scale_min,
+ aug_scale_max=self._aug_scale_max)
+ image_scale = image_info[2, :]
+ offset = image_info[3, :]
+
+ # Resizes and crops boxes and masks.
+ boxes = input_utils.resize_and_crop_boxes(
+ boxes, image_scale, image_info[1, :], offset)
+
+ # Filters out ground truth boxes that are all zeros.
+ indices = box_utils.get_non_empty_box_indices(boxes)
+ boxes = tf.gather(boxes, indices)
+ classes = tf.gather(classes, indices)
+ masks = tf.gather(masks, indices)
+
+ # Assigns anchors.
+ input_anchor = anchor.Anchor(
+ self._min_level, self._max_level, self._num_scales,
+ self._aspect_ratios, self._anchor_size, self._output_size)
+ anchor_labeler = anchor.AnchorLabeler(
+ input_anchor, self._match_threshold, self._unmatched_threshold)
+ (cls_targets,
+ box_targets,
+ num_positives) = anchor_labeler.label_anchors(
+ boxes,
+ tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
+
+ # Sample groundtruth masks/boxes/classes for mask branch.
+ num_masks = tf.shape(masks)[0]
+ mask_shape = tf.shape(masks)[1:3]
+
+ # Pad sampled boxes/masks/classes to a constant batch size.
+ padded_boxes = pad_to_size(boxes, self._num_sampled_masks)
+ padded_classes = pad_to_size(classes, self._num_sampled_masks)
+ padded_masks = pad_to_size(masks, self._num_sampled_masks)
+
+    # Randomly sample groundtruth masks for mask branch training. For images
+    # without groundtruth masks, the dummy padded tensors are sampled instead.
+ rand_indices = tf.random.shuffle(
+ tf.range(tf.maximum(num_masks, self._num_sampled_masks)))
+ rand_indices = tf.math.mod(rand_indices, tf.maximum(num_masks, 1))
+ rand_indices = rand_indices[0:self._num_sampled_masks]
+ rand_indices = tf.reshape(rand_indices, [self._num_sampled_masks])
+
+ sampled_boxes = tf.gather(padded_boxes, rand_indices)
+ sampled_classes = tf.gather(padded_classes, rand_indices)
+ sampled_masks = tf.gather(padded_masks, rand_indices)
+ # Jitter the sampled boxes to mimic the noisy detections.
+ sampled_boxes = box_utils.jitter_boxes(
+ sampled_boxes, noise_scale=self._box_jitter_scale)
+ sampled_boxes = box_utils.clip_boxes(sampled_boxes, self._output_size)
+ # Compute mask targets in feature crop. A feature crop fully contains a
+ # sampled box.
+ mask_outer_boxes = box_utils.compute_outer_boxes(
+ sampled_boxes, tf.shape(image)[0:2], scale=self._outer_box_scale)
+ mask_outer_boxes = box_utils.clip_boxes(mask_outer_boxes, self._output_size)
+    # Compensate for the offset of mask_outer_boxes to map them back to the
+    # original image scale.
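+    # The sampled instance masks are still at the original image resolution,
+    # while mask_outer_boxes live in the resized output image frame. Inverting
+    # resize_and_crop_boxes (add the crop offset back, then divide by the
+    # resize scale) expresses the outer boxes in the original frame so they can
+    # be used by tf.image.crop_and_resize on the raw masks below.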
+ mask_outer_boxes_ori = mask_outer_boxes
+ mask_outer_boxes_ori += tf.tile(tf.expand_dims(offset, axis=0), [1, 2])
+ mask_outer_boxes_ori /= tf.tile(tf.expand_dims(image_scale, axis=0), [1, 2])
+ norm_mask_outer_boxes_ori = box_utils.normalize_boxes(
+ mask_outer_boxes_ori, mask_shape)
+
+ # Set sampled_masks shape to [batch_size, height, width, 1].
+ sampled_masks = tf.cast(tf.expand_dims(sampled_masks, axis=-1), tf.float32)
+ mask_targets = tf.image.crop_and_resize(
+ sampled_masks,
+ norm_mask_outer_boxes_ori,
+ box_indices=tf.range(self._num_sampled_masks),
+ crop_size=[self._mask_crop_size, self._mask_crop_size],
+ method='bilinear',
+ extrapolation_value=0,
+ name='train_mask_targets')
+ mask_targets = tf.where(tf.greater_equal(mask_targets, 0.5),
+ tf.ones_like(mask_targets),
+ tf.zeros_like(mask_targets))
+ mask_targets = tf.squeeze(mask_targets, axis=-1)
+ if self._up_sample_factor > 1:
+ fine_mask_targets = tf.image.crop_and_resize(
+ sampled_masks,
+ norm_mask_outer_boxes_ori,
+ box_indices=tf.range(self._num_sampled_masks),
+ crop_size=[
+ self._mask_crop_size * self._up_sample_factor,
+ self._mask_crop_size * self._up_sample_factor
+ ],
+ method='bilinear',
+ extrapolation_value=0,
+ name='train_mask_targets')
+ fine_mask_targets = tf.where(
+ tf.greater_equal(fine_mask_targets, 0.5),
+ tf.ones_like(fine_mask_targets), tf.zeros_like(fine_mask_targets))
+ fine_mask_targets = tf.squeeze(fine_mask_targets, axis=-1)
+ else:
+ fine_mask_targets = mask_targets
+
+ # If bfloat16 is used, casts input image to tf.bfloat16.
+ if self._use_bfloat16:
+ image = tf.cast(image, dtype=tf.bfloat16)
+
+ valid_image = tf.cast(tf.not_equal(num_masks, 0), tf.int32)
+ if self._mask_train_class == 'all':
+ mask_is_valid = valid_image * tf.ones_like(sampled_classes, tf.int32)
+ else:
+ # Get the intersection of sampled classes with training splits.
+ mask_valid_classes = tf.cast(
+ tf.expand_dims(
+ class_utils.coco_split_class_ids(self._mask_train_class), 1),
+ sampled_classes.dtype)
+ match = tf.reduce_any(
+ tf.equal(tf.expand_dims(sampled_classes, 0), mask_valid_classes), 0)
+ mask_is_valid = valid_image * tf.cast(match, tf.int32)
+
+ # Packs labels for model_fn outputs.
+ labels = {
+ 'cls_targets': cls_targets,
+ 'box_targets': box_targets,
+ 'anchor_boxes': input_anchor.multilevel_boxes,
+ 'num_positives': num_positives,
+ 'image_info': image_info,
+ # For ShapeMask.
+ 'mask_targets': mask_targets,
+ 'fine_mask_targets': fine_mask_targets,
+ 'mask_is_valid': mask_is_valid,
+ }
+
+ inputs = {
+ 'image': image,
+ 'image_info': image_info,
+ 'mask_boxes': sampled_boxes,
+ 'mask_outer_boxes': mask_outer_boxes,
+ 'mask_classes': sampled_classes,
+ }
+ return inputs, labels
+
+ def _parse_predict_data(self, data):
+ """Parse data for ShapeMask training."""
+ classes = data['groundtruth_classes']
+ boxes = data['groundtruth_boxes']
+ masks = data['groundtruth_instance_masks']
+
+ # Gets original image and its size.
+ image = data['image']
+ image_shape = tf.shape(image)[0:2]
+
+ # If not using category, makes all categories with id = 0.
+ if not self._use_category:
+ classes = tf.cast(tf.greater(classes, 0), dtype=tf.float32)
+
+ # Normalizes image with mean and std pixel values.
+ image = input_utils.normalize_image(image)
+
+ # Converts boxes from normalized coordinates to pixel coordinates.
+ boxes = box_utils.denormalize_boxes(boxes, image_shape)
+
+ # Resizes and crops image.
+ image, image_info = input_utils.resize_and_crop_image(
+ image,
+ self._output_size,
+ self._output_size,
+ aug_scale_min=1.0,
+ aug_scale_max=1.0)
+ image_scale = image_info[2, :]
+ offset = image_info[3, :]
+
+ # Resizes and crops boxes and masks.
+ boxes = input_utils.resize_and_crop_boxes(
+ boxes, image_scale, image_info[1, :], offset)
+ masks = input_utils.resize_and_crop_masks(
+ tf.expand_dims(masks, axis=-1), image_scale, self._output_size, offset)
+
+ # Filters out ground truth boxes that are all zeros.
+ indices = box_utils.get_non_empty_box_indices(boxes)
+ boxes = tf.gather(boxes, indices)
+ classes = tf.gather(classes, indices)
+
+ # Assigns anchors.
+ input_anchor = anchor.Anchor(
+ self._min_level, self._max_level, self._num_scales,
+ self._aspect_ratios, self._anchor_size, self._output_size)
+ anchor_labeler = anchor.AnchorLabeler(
+ input_anchor, self._match_threshold, self._unmatched_threshold)
+
+ # If bfloat16 is used, casts input image to tf.bfloat16.
+ if self._use_bfloat16:
+ image = tf.cast(image, dtype=tf.bfloat16)
+
+ labels = {
+ 'anchor_boxes': input_anchor.multilevel_boxes,
+ 'image_info': image_info,
+ }
+ if self._mode == ModeKeys.PREDICT_WITH_GT:
+ # Converts boxes from normalized coordinates to pixel coordinates.
+ groundtruths = {
+ 'source_id': data['source_id'],
+ 'height': data['height'],
+ 'width': data['width'],
+ 'num_detections': tf.shape(data['groundtruth_classes']),
+ 'boxes': box_utils.denormalize_boxes(
+ data['groundtruth_boxes'], image_shape),
+ 'classes': data['groundtruth_classes'],
+ # 'masks': tf.squeeze(masks, axis=-1),
+ 'areas': data['groundtruth_area'],
+ 'is_crowds': tf.cast(data['groundtruth_is_crowd'], tf.int32),
+ }
+ groundtruths['source_id'] = dataloader_utils.process_source_id(
+ groundtruths['source_id'])
+ groundtruths = dataloader_utils.pad_groundtruths_to_fixed_size(
+ groundtruths, self._max_num_instances)
+ # Computes training labels.
+ (cls_targets,
+ box_targets,
+ num_positives) = anchor_labeler.label_anchors(
+ boxes,
+ tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
+ # Packs labels for model_fn outputs.
+ labels.update({
+ 'cls_targets': cls_targets,
+ 'box_targets': box_targets,
+ 'num_positives': num_positives,
+ 'groundtruths': groundtruths,
+ })
+
+ inputs = {
+ 'image': image,
+ 'image_info': image_info,
+ }
+
+ return inputs, labels
diff --git a/models/official/vision/detection/dataloader/tf_example_decoder.py b/models/official/vision/detection/dataloader/tf_example_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f719a9168a4d3106600fffcc47c14cc90f3cadc7
--- /dev/null
+++ b/models/official/vision/detection/dataloader/tf_example_decoder.py
@@ -0,0 +1,156 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tensorflow Example proto decoder for object detection.
+
+A decoder to decode string tensors containing serialized tensorflow.Example
+protos for object detection.
+"""
+import tensorflow as tf
+
+
+class TfExampleDecoder(object):
+ """Tensorflow Example proto decoder."""
+
+ def __init__(self, include_mask=False):
+ self._include_mask = include_mask
+ self._keys_to_features = {
+ 'image/encoded':
+ tf.io.FixedLenFeature((), tf.string),
+ 'image/source_id':
+ tf.io.FixedLenFeature((), tf.string),
+ 'image/height':
+ tf.io.FixedLenFeature((), tf.int64),
+ 'image/width':
+ tf.io.FixedLenFeature((), tf.int64),
+ 'image/object/bbox/xmin':
+ tf.io.VarLenFeature(tf.float32),
+ 'image/object/bbox/xmax':
+ tf.io.VarLenFeature(tf.float32),
+ 'image/object/bbox/ymin':
+ tf.io.VarLenFeature(tf.float32),
+ 'image/object/bbox/ymax':
+ tf.io.VarLenFeature(tf.float32),
+ 'image/object/class/label':
+ tf.io.VarLenFeature(tf.int64),
+ 'image/object/area':
+ tf.io.VarLenFeature(tf.float32),
+ 'image/object/is_crowd':
+ tf.io.VarLenFeature(tf.int64),
+ }
+ if include_mask:
+ self._keys_to_features.update({
+ 'image/object/mask':
+ tf.io.VarLenFeature(tf.string),
+ })
+
+ def _decode_image(self, parsed_tensors):
+ """Decodes the image and set its static shape."""
+ image = tf.io.decode_image(parsed_tensors['image/encoded'], channels=3)
+ image.set_shape([None, None, 3])
+ return image
+
+ def _decode_boxes(self, parsed_tensors):
+ """Concat box coordinates in the format of [ymin, xmin, ymax, xmax]."""
+ xmin = parsed_tensors['image/object/bbox/xmin']
+ xmax = parsed_tensors['image/object/bbox/xmax']
+ ymin = parsed_tensors['image/object/bbox/ymin']
+ ymax = parsed_tensors['image/object/bbox/ymax']
+ return tf.stack([ymin, xmin, ymax, xmax], axis=-1)
+
+ def _decode_masks(self, parsed_tensors):
+ """Decode a set of PNG masks to the tf.float32 tensors."""
+ def _decode_png_mask(png_bytes):
+ mask = tf.squeeze(
+ tf.io.decode_png(png_bytes, channels=1, dtype=tf.uint8), axis=-1)
+ mask = tf.cast(mask, dtype=tf.float32)
+ mask.set_shape([None, None])
+ return mask
+
+ height = parsed_tensors['image/height']
+ width = parsed_tensors['image/width']
+ masks = parsed_tensors['image/object/mask']
+ return tf.cond(
+ pred=tf.greater(tf.size(input=masks), 0),
+ true_fn=lambda: tf.map_fn(_decode_png_mask, masks, dtype=tf.float32),
+ false_fn=lambda: tf.zeros([0, height, width], dtype=tf.float32))
+
+ def _decode_areas(self, parsed_tensors):
+ xmin = parsed_tensors['image/object/bbox/xmin']
+ xmax = parsed_tensors['image/object/bbox/xmax']
+ ymin = parsed_tensors['image/object/bbox/ymin']
+ ymax = parsed_tensors['image/object/bbox/ymax']
+ return tf.cond(
+ tf.greater(tf.shape(parsed_tensors['image/object/area'])[0], 0),
+ lambda: parsed_tensors['image/object/area'],
+ lambda: (xmax - xmin) * (ymax - ymin))
+
+ def decode(self, serialized_example):
+ """Decode the serialized example.
+
+ Args:
+ serialized_example: a single serialized tf.Example string.
+
+ Returns:
+ decoded_tensors: a dictionary of tensors with the following fields:
+ - image: a uint8 tensor of shape [None, None, 3].
+ - source_id: a string scalar tensor.
+ - height: an integer scalar tensor.
+ - width: an integer scalar tensor.
+ - groundtruth_classes: a int64 tensor of shape [None].
+ - groundtruth_is_crowd: a bool tensor of shape [None].
+ - groundtruth_area: a float32 tensor of shape [None].
+ - groundtruth_boxes: a float32 tensor of shape [None, 4].
+ - groundtruth_instance_masks: a float32 tensor of shape
+ [None, None, None].
+ - groundtruth_instance_masks_png: a string tensor of shape [None].
+ """
+ parsed_tensors = tf.io.parse_single_example(
+ serialized=serialized_example, features=self._keys_to_features)
+ for k in parsed_tensors:
+ if isinstance(parsed_tensors[k], tf.SparseTensor):
+ if parsed_tensors[k].dtype == tf.string:
+ parsed_tensors[k] = tf.sparse.to_dense(
+ parsed_tensors[k], default_value='')
+ else:
+ parsed_tensors[k] = tf.sparse.to_dense(
+ parsed_tensors[k], default_value=0)
+
+ image = self._decode_image(parsed_tensors)
+ boxes = self._decode_boxes(parsed_tensors)
+ areas = self._decode_areas(parsed_tensors)
+ is_crowds = tf.cond(
+ tf.greater(tf.shape(parsed_tensors['image/object/is_crowd'])[0], 0),
+ lambda: tf.cast(parsed_tensors['image/object/is_crowd'], dtype=tf.bool),
+ lambda: tf.zeros_like(parsed_tensors['image/object/class/label'], dtype=tf.bool)) # pylint: disable=line-too-long
+ if self._include_mask:
+ masks = self._decode_masks(parsed_tensors)
+
+ decoded_tensors = {
+ 'image': image,
+ 'source_id': parsed_tensors['image/source_id'],
+ 'height': parsed_tensors['image/height'],
+ 'width': parsed_tensors['image/width'],
+ 'groundtruth_classes': parsed_tensors['image/object/class/label'],
+ 'groundtruth_is_crowd': is_crowds,
+ 'groundtruth_area': areas,
+ 'groundtruth_boxes': boxes,
+ }
+ if self._include_mask:
+ decoded_tensors.update({
+ 'groundtruth_instance_masks': masks,
+ 'groundtruth_instance_masks_png': parsed_tensors['image/object/mask'],
+ })
+ return decoded_tensors
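+# A minimal illustrative sketch (not part of the original file): the decoder is
+# typically mapped over a TFRecord dataset of serialized tf.Example protos. The
+# file name below is an assumption for illustration only.
+#
+#   decoder = TfExampleDecoder(include_mask=True)
+#   dataset = tf.data.TFRecordDataset('val.tfrecord').map(decoder.decode)
+#   # Each element is a dict with 'image', 'groundtruth_boxes', masks, etc.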
diff --git a/models/official/vision/detection/evaluation/__init__.py b/models/official/vision/detection/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..931c2ef11db4a949e6c2e95bca44e36bac1241e9
--- /dev/null
+++ b/models/official/vision/detection/evaluation/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/models/official/vision/detection/evaluation/coco_evaluator.py b/models/official/vision/detection/evaluation/coco_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc56a9332784dd66d5393bbf0d4c996fe5141c6d
--- /dev/null
+++ b/models/official/vision/detection/evaluation/coco_evaluator.py
@@ -0,0 +1,343 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The COCO-style evaluator.
+
+The following snippet demonstrates the use of interfaces:
+
+ evaluator = COCOEvaluator(...)
+ for _ in range(num_evals):
+ for _ in range(num_batches_per_eval):
+      predictions, groundtruths = predictor.predict(...)  # pop a batch.
+ evaluator.update(predictions, groundtruths) # aggregate internal stats.
+ evaluator.evaluate() # finish one full eval.
+
+See also: https://github.com/cocodataset/cocoapi/
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import atexit
+import tempfile
+import numpy as np
+from absl import logging
+from pycocotools import cocoeval
+import six
+import tensorflow as tf
+
+from official.vision.detection.evaluation import coco_utils
+from official.vision.detection.utils import class_utils
+
+
+class MetricWrapper(object):
+  # This is only a wrapper for the COCO metric and works on numpy arrays. So it
+ # doesn't inherit from tf.keras.layers.Layer or tf.keras.metrics.Metric.
+
+ def __init__(self, evaluator):
+ self._evaluator = evaluator
+
+ def update_state(self, y_true, y_pred):
+ labels = tf.nest.map_structure(lambda x: x.numpy(), y_true)
+ outputs = tf.nest.map_structure(lambda x: x.numpy(), y_pred)
+ groundtruths = {}
+ predictions = {}
+ for key, val in outputs.items():
+ if isinstance(val, tuple):
+ val = np.concatenate(val)
+ predictions[key] = val
+ for key, val in labels.items():
+ if isinstance(val, tuple):
+ val = np.concatenate(val)
+ groundtruths[key] = val
+ self._evaluator.update(predictions, groundtruths)
+
+ def result(self):
+ return self._evaluator.evaluate()
+
+ def reset_states(self):
+ return self._evaluator.reset()
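+  # A minimal illustrative sketch (not part of the original file), assuming an
+  # already constructed evaluator and eager tensors for labels/outputs:
+  #
+  #   metric = MetricWrapper(COCOEvaluator(annotation_file=None,
+  #                                        include_mask=False))
+  #   metric.update_state(labels, outputs)  # tensors -> numpy, then aggregate.
+  #   results = metric.result()             # runs the full COCO evaluation.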
+
+
+class COCOEvaluator(object):
+ """COCO evaluation metric class."""
+
+ def __init__(self, annotation_file, include_mask, need_rescale_bboxes=True):
+ """Constructs COCO evaluation class.
+
+ The class provides the interface to metrics_fn in TPUEstimator. The
+    _update_op() takes detections from each image and pushes them to
+ self.detections. The _evaluate() loads a JSON file in COCO annotation format
+ as the groundtruths and runs COCO evaluation.
+
+ Args:
+ annotation_file: a JSON file that stores annotations of the eval dataset.
+ If `annotation_file` is None, groundtruth annotations will be loaded
+ from the dataloader.
+ include_mask: a boolean to indicate whether or not to include the mask
+ eval.
+      need_rescale_bboxes: If true, bboxes in `predictions` will be rescaled
+        back to absolute values (`image_info` is needed in this case).
+ """
+ if annotation_file:
+ if annotation_file.startswith('gs://'):
+ _, local_val_json = tempfile.mkstemp(suffix='.json')
+ tf.io.gfile.remove(local_val_json)
+
+ tf.io.gfile.copy(annotation_file, local_val_json)
+ atexit.register(tf.io.gfile.remove, local_val_json)
+ else:
+ local_val_json = annotation_file
+ self._coco_gt = coco_utils.COCOWrapper(
+ eval_type=('mask' if include_mask else 'box'),
+ annotation_file=local_val_json)
+ self._annotation_file = annotation_file
+ self._include_mask = include_mask
+ self._metric_names = [
+ 'AP', 'AP50', 'AP75', 'APs', 'APm', 'APl', 'ARmax1', 'ARmax10',
+ 'ARmax100', 'ARs', 'ARm', 'ARl'
+ ]
+ self._required_prediction_fields = [
+ 'source_id', 'num_detections', 'detection_classes', 'detection_scores',
+ 'detection_boxes'
+ ]
+ self._need_rescale_bboxes = need_rescale_bboxes
+ if self._need_rescale_bboxes:
+ self._required_prediction_fields.append('image_info')
+ self._required_groundtruth_fields = [
+ 'source_id', 'height', 'width', 'classes', 'boxes'
+ ]
+ if self._include_mask:
+ mask_metric_names = ['mask_' + x for x in self._metric_names]
+ self._metric_names.extend(mask_metric_names)
+ self._required_prediction_fields.extend(['detection_masks'])
+ self._required_groundtruth_fields.extend(['masks'])
+
+ self.reset()
+
+ def reset(self):
+ """Resets internal states for a fresh run."""
+ self._predictions = {}
+ if not self._annotation_file:
+ self._groundtruths = {}
+
+ def evaluate(self):
+ """Evaluates with detections from all images with COCO API.
+
+ Returns:
+      coco_metric: a dictionary mapping each coco-style metric name (box and,
+        if enabled, mask) to a float numpy value.
+ """
+ if not self._annotation_file:
+      logging.info('There is no annotation_file in COCOEvaluator.')
+ gt_dataset = coco_utils.convert_groundtruths_to_coco_dataset(
+ self._groundtruths)
+ coco_gt = coco_utils.COCOWrapper(
+ eval_type=('mask' if self._include_mask else 'box'),
+ gt_dataset=gt_dataset)
+ else:
+ logging.info('Using annotation file: %s', self._annotation_file)
+ coco_gt = self._coco_gt
+ coco_predictions = coco_utils.convert_predictions_to_coco_annotations(
+ self._predictions)
+ coco_dt = coco_gt.loadRes(predictions=coco_predictions)
+ image_ids = [ann['image_id'] for ann in coco_predictions]
+
+ coco_eval = cocoeval.COCOeval(coco_gt, coco_dt, iouType='bbox')
+ coco_eval.params.imgIds = image_ids
+ coco_eval.evaluate()
+ coco_eval.accumulate()
+ coco_eval.summarize()
+ coco_metrics = coco_eval.stats
+
+ if self._include_mask:
+ mcoco_eval = cocoeval.COCOeval(coco_gt, coco_dt, iouType='segm')
+ mcoco_eval.params.imgIds = image_ids
+ mcoco_eval.evaluate()
+ mcoco_eval.accumulate()
+ mcoco_eval.summarize()
+ mask_coco_metrics = mcoco_eval.stats
+
+ if self._include_mask:
+ metrics = np.hstack((coco_metrics, mask_coco_metrics))
+ else:
+ metrics = coco_metrics
+
+ # Cleans up the internal variables in order for a fresh eval next time.
+ self.reset()
+
+ metrics_dict = {}
+ for i, name in enumerate(self._metric_names):
+ metrics_dict[name] = metrics[i].astype(np.float32)
+ return metrics_dict
+
+ def _process_predictions(self, predictions):
+ image_scale = np.tile(predictions['image_info'][:, 2:3, :], (1, 1, 2))
+ predictions['detection_boxes'] = (
+ predictions['detection_boxes'].astype(np.float32))
+ predictions['detection_boxes'] /= image_scale
+ if 'detection_outer_boxes' in predictions:
+ predictions['detection_outer_boxes'] = (
+ predictions['detection_outer_boxes'].astype(np.float32))
+ predictions['detection_outer_boxes'] /= image_scale
+
+ def update(self, predictions, groundtruths=None):
+ """Update and aggregate detection results and groundtruth data.
+
+ Args:
+ predictions: a dictionary of numpy arrays including the fields below.
+ See different parsers under `../dataloader` for more details.
+ Required fields:
+ - source_id: a numpy array of int or string of shape [batch_size].
+ - image_info [if `need_rescale_bboxes` is True]: a numpy array of
+ float of shape [batch_size, 4, 2].
+ - num_detections: a numpy array of
+ int of shape [batch_size].
+ - detection_boxes: a numpy array of float of shape [batch_size, K, 4].
+ - detection_classes: a numpy array of int of shape [batch_size, K].
+ - detection_scores: a numpy array of float of shape [batch_size, K].
+ Optional fields:
+ - detection_masks: a numpy array of float of shape
+ [batch_size, K, mask_height, mask_width].
+ groundtruths: a dictionary of numpy arrays including the fields below.
+ See also different parsers under `../dataloader` for more details.
+ Required fields:
+ - source_id: a numpy array of int or string of shape [batch_size].
+ - height: a numpy array of int of shape [batch_size].
+ - width: a numpy array of int of shape [batch_size].
+ - num_detections: a numpy array of int of shape [batch_size].
+ - boxes: a numpy array of float of shape [batch_size, K, 4].
+ - classes: a numpy array of int of shape [batch_size, K].
+ Optional fields:
+ - is_crowds: a numpy array of int of shape [batch_size, K]. If the
+ field is absent, it is assumed that this instance is not crowd.
+          - areas: a numpy array of float of shape [batch_size, K]. If the
+ field is absent, the area is calculated using either boxes or
+ masks depending on which one is available.
+ - masks: a numpy array of float of shape
+ [batch_size, K, mask_height, mask_width],
+
+ Raises:
+ ValueError: if the required prediction or groundtruth fields are not
+ present in the incoming `predictions` or `groundtruths`.
+ """
+ for k in self._required_prediction_fields:
+ if k not in predictions:
+ raise ValueError(
+ 'Missing the required key `{}` in predictions!'.format(k))
+ if self._need_rescale_bboxes:
+ self._process_predictions(predictions)
+ for k, v in six.iteritems(predictions):
+ if k not in self._predictions:
+ self._predictions[k] = [v]
+ else:
+ self._predictions[k].append(v)
+
+ if not self._annotation_file:
+ assert groundtruths
+ for k in self._required_groundtruth_fields:
+ if k not in groundtruths:
+ raise ValueError(
+ 'Missing the required key `{}` in groundtruths!'.format(k))
+ for k, v in six.iteritems(groundtruths):
+ if k not in self._groundtruths:
+ self._groundtruths[k] = [v]
+ else:
+ self._groundtruths[k].append(v)
+
+
+class ShapeMaskCOCOEvaluator(COCOEvaluator):
+ """COCO evaluation metric class for ShapeMask."""
+
+ def __init__(self, mask_eval_class, **kwargs):
+ """Constructs COCO evaluation class.
+
+ The class provides the interface to metrics_fn in TPUEstimator. The
+    _update_op() takes detections from each image and pushes them to
+ self.detections. The _evaluate() loads a JSON file in COCO annotation format
+ as the groundtruths and runs COCO evaluation.
+
+ Args:
+ mask_eval_class: the set of classes for mask evaluation.
+ **kwargs: other keyword arguments passed to the parent class initializer.
+ """
+ super(ShapeMaskCOCOEvaluator, self).__init__(**kwargs)
+ self._mask_eval_class = mask_eval_class
+ self._eval_categories = class_utils.coco_split_class_ids(mask_eval_class)
+ if mask_eval_class != 'all':
+ self._metric_names = [
+ x.replace('mask', 'novel_mask') for x in self._metric_names
+ ]
+
+ def evaluate(self):
+ """Evaluates with detections from all images with COCO API.
+
+ Returns:
+      coco_metric: a dictionary mapping each coco-style metric name (box and,
+        if enabled, mask) to a float numpy value.
+ """
+ if not self._annotation_file:
+ gt_dataset = coco_utils.convert_groundtruths_to_coco_dataset(
+ self._groundtruths)
+ coco_gt = coco_utils.COCOWrapper(
+ eval_type=('mask' if self._include_mask else 'box'),
+ gt_dataset=gt_dataset)
+ else:
+ coco_gt = self._coco_gt
+ coco_predictions = coco_utils.convert_predictions_to_coco_annotations(
+ self._predictions)
+ coco_dt = coco_gt.loadRes(predictions=coco_predictions)
+ image_ids = [ann['image_id'] for ann in coco_predictions]
+
+ coco_eval = cocoeval.COCOeval(coco_gt, coco_dt, iouType='bbox')
+ coco_eval.params.imgIds = image_ids
+ coco_eval.evaluate()
+ coco_eval.accumulate()
+ coco_eval.summarize()
+ coco_metrics = coco_eval.stats
+
+ if self._include_mask:
+ mcoco_eval = cocoeval.COCOeval(coco_gt, coco_dt, iouType='segm')
+ mcoco_eval.params.imgIds = image_ids
+ mcoco_eval.evaluate()
+ mcoco_eval.accumulate()
+ mcoco_eval.summarize()
+ if self._mask_eval_class == 'all':
+ metrics = np.hstack((coco_metrics, mcoco_eval.stats))
+ else:
+ mask_coco_metrics = mcoco_eval.category_stats
+ val_catg_idx = np.isin(mcoco_eval.params.catIds,
+ self._eval_categories)
+ # Gather the valid evaluation of the eval categories.
+ if np.any(val_catg_idx):
+ mean_val_metrics = []
+ for mid in range(len(self._metric_names) // 2):
+ mean_val_metrics.append(
+ np.nanmean(mask_coco_metrics[mid][val_catg_idx]))
+
+ mean_val_metrics = np.array(mean_val_metrics)
+ else:
+ mean_val_metrics = np.zeros(len(self._metric_names) // 2)
+ metrics = np.hstack((coco_metrics, mean_val_metrics))
+ else:
+ metrics = coco_metrics
+
+ # Cleans up the internal variables in order for a fresh eval next time.
+ self.reset()
+
+ metrics_dict = {}
+ for i, name in enumerate(self._metric_names):
+ metrics_dict[name] = metrics[i].astype(np.float32)
+ return metrics_dict
diff --git a/models/official/vision/detection/evaluation/coco_utils.py b/models/official/vision/detection/evaluation/coco_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8155d1fbb89ac143eb7cc03457a6645a5b5ab505
--- /dev/null
+++ b/models/official/vision/detection/evaluation/coco_utils.py
@@ -0,0 +1,374 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Util functions related to pycocotools and COCO eval."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import json
+
+from absl import logging
+import numpy as np
+from PIL import Image
+from pycocotools import coco
+from pycocotools import mask as mask_api
+import six
+import tensorflow as tf
+
+from official.vision.detection.dataloader import tf_example_decoder
+from official.vision.detection.utils import box_utils
+from official.vision.detection.utils import mask_utils
+
+
+class COCOWrapper(coco.COCO):
+ """COCO wrapper class.
+
+ This class wraps COCO API object, which provides the following additional
+ functionalities:
+ 1. Support string type image id.
+ 2. Support loading the groundtruth dataset using the external annotation
+ dictionary.
+ 3. Support loading the prediction results using the external annotation
+ dictionary.
+ """
+
+ def __init__(self, eval_type='box', annotation_file=None, gt_dataset=None):
+ """Instantiates a COCO-style API object.
+
+ Args:
+ eval_type: either 'box' or 'mask'.
+ annotation_file: a JSON file that stores annotations of the eval dataset.
+ This is required if `gt_dataset` is not provided.
+ gt_dataset: the groundtruth eval datatset in COCO API format.
+ """
+ if ((annotation_file and gt_dataset) or
+ ((not annotation_file) and (not gt_dataset))):
+ raise ValueError('One and only one of `annotation_file` and `gt_dataset` '
+ 'needs to be specified.')
+
+ if eval_type not in ['box', 'mask']:
+ raise ValueError('The `eval_type` can only be either `box` or `mask`.')
+
+ coco.COCO.__init__(self, annotation_file=annotation_file)
+ self._eval_type = eval_type
+ if gt_dataset:
+ self.dataset = gt_dataset
+ self.createIndex()
+
+ def loadRes(self, predictions):
+ """Loads result file and return a result api object.
+
+ Args:
+ predictions: a list of dictionary each representing an annotation in COCO
+ format. The required fields are `image_id`, `category_id`, `score`,
+ `bbox`, `segmentation`.
+
+ Returns:
+ res: result COCO api object.
+
+ Raises:
+      ValueError: if the set of image ids from predictions is not a subset of
+        the set of image ids of the groundtruth dataset.
+ """
+ res = coco.COCO()
+ res.dataset['images'] = copy.deepcopy(self.dataset['images'])
+ res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
+
+ image_ids = [ann['image_id'] for ann in predictions]
+ if set(image_ids) != (set(image_ids) & set(self.getImgIds())):
+ raise ValueError('Results do not correspond to the current dataset!')
+ for ann in predictions:
+ x1, x2, y1, y2 = [ann['bbox'][0], ann['bbox'][0] + ann['bbox'][2],
+ ann['bbox'][1], ann['bbox'][1] + ann['bbox'][3]]
+ if self._eval_type == 'box':
+ ann['area'] = ann['bbox'][2] * ann['bbox'][3]
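+        # For box-only evaluation, encode the box as its 4-corner polygon so
+        # the annotation still carries a `segmentation` field in COCO format.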
+ ann['segmentation'] = [
+ [x1, y1, x1, y2, x2, y2, x2, y1]]
+ elif self._eval_type == 'mask':
+ ann['area'] = mask_api.area(ann['segmentation'])
+
+ res.dataset['annotations'] = copy.deepcopy(predictions)
+ res.createIndex()
+ return res
+
+
+def convert_predictions_to_coco_annotations(predictions):
+ """Converts a batch of predictions to annotations in COCO format.
+
+ Args:
+ predictions: a dictionary of lists of numpy arrays including the following
+ fields. K below denotes the maximum number of instances per image.
+ Required fields:
+ - source_id: a list of numpy arrays of int or string of shape
+ [batch_size].
+ - num_detections: a list of numpy arrays of int of shape [batch_size].
+ - detection_boxes: a list of numpy arrays of float of shape
+ [batch_size, K, 4], where coordinates are in the original image
+ space (not the scaled image space).
+ - detection_classes: a list of numpy arrays of int of shape
+ [batch_size, K].
+ - detection_scores: a list of numpy arrays of float of shape
+ [batch_size, K].
+ Optional fields:
+ - detection_masks: a list of numpy arrays of float of shape
+ [batch_size, K, mask_height, mask_width].
+
+ Returns:
+ coco_predictions: prediction in COCO annotation format.
+ """
+ coco_predictions = []
+ num_batches = len(predictions['source_id'])
+ batch_size = predictions['source_id'][0].shape[0]
+ max_num_detections = predictions['detection_classes'][0].shape[1]
+ use_outer_box = 'detection_outer_boxes' in predictions
+ for i in range(num_batches):
+ predictions['detection_boxes'][i] = box_utils.yxyx_to_xywh(
+ predictions['detection_boxes'][i])
+ if use_outer_box:
+ predictions['detection_outer_boxes'][i] = box_utils.yxyx_to_xywh(
+ predictions['detection_outer_boxes'][i])
+ mask_boxes = predictions['detection_outer_boxes']
+ else:
+ mask_boxes = predictions['detection_boxes']
+
+ for j in range(batch_size):
+ if 'detection_masks' in predictions:
+ image_masks = mask_utils.paste_instance_masks(
+ predictions['detection_masks'][i][j],
+ mask_boxes[i][j],
+ int(predictions['image_info'][i][j, 0, 0]),
+ int(predictions['image_info'][i][j, 0, 1]))
+ binary_masks = (image_masks > 0.0).astype(np.uint8)
+ encoded_masks = [
+ mask_api.encode(np.asfortranarray(binary_mask))
+ for binary_mask in list(binary_masks)]
+ for k in range(max_num_detections):
+ ann = {}
+ ann['image_id'] = predictions['source_id'][i][j]
+ ann['category_id'] = predictions['detection_classes'][i][j, k]
+ ann['bbox'] = predictions['detection_boxes'][i][j, k]
+ ann['score'] = predictions['detection_scores'][i][j, k]
+ if 'detection_masks' in predictions:
+ ann['segmentation'] = encoded_masks[k]
+ coco_predictions.append(ann)
+
+ for i, ann in enumerate(coco_predictions):
+ ann['id'] = i + 1
+
+ return coco_predictions
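+# A minimal illustrative sketch (not part of the original file): each appended
+# element is a COCO-style result dict, roughly of the form (values made up):
+#
+#   {'image_id': 42, 'category_id': 3, 'bbox': [x, y, width, height],
+#    'score': 0.87, 'segmentation': <encoded RLE>, 'id': 1}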
+
+
+def convert_groundtruths_to_coco_dataset(groundtruths, label_map=None):
+ """Converts groundtruths to the dataset in COCO format.
+
+ Args:
+    groundtruths: a dictionary of lists of numpy arrays including the fields
+      below. Note that each element in a list represents the values for a
+      single example without the batch dimension. K below denotes the actual
+      number of instances for each image.
+ Required fields:
+ - source_id: a list of numpy arrays of int or string of shape
+ [batch_size].
+ - height: a list of numpy arrays of int of shape [batch_size].
+ - width: a list of numpy arrays of int of shape [batch_size].
+ - num_detections: a list of numpy arrays of int of shape [batch_size].
+ - boxes: a list of numpy arrays of float of shape [batch_size, K, 4],
+ where coordinates are in the original image space (not the
+ normalized coordinates).
+ - classes: a list of numpy arrays of int of shape [batch_size, K].
+ Optional fields:
+      - is_crowds: a list of numpy arrays of int of shape [batch_size, K]. If
+          the field is absent, it is assumed that this instance is not crowd.
+      - areas: a list of numpy arrays of float of shape [batch_size, K]. If the
+          field is absent, the area is calculated using either boxes or
+          masks depending on which one is available.
+      - masks: a list of numpy arrays of string of shape [batch_size, K].
+    label_map: (optional) a dictionary that maps a category id to its category
+      name. If `None`, collect the category mapping from the
+      `groundtruths`.
+
+ Returns:
+ coco_groundtruths: the groundtruth dataset in COCO format.
+ """
+ source_ids = np.concatenate(groundtruths['source_id'], axis=0)
+ heights = np.concatenate(groundtruths['height'], axis=0)
+ widths = np.concatenate(groundtruths['width'], axis=0)
+ gt_images = [{'id': int(i), 'height': int(h), 'width': int(w)} for i, h, w
+ in zip(source_ids, heights, widths)]
+
+ gt_annotations = []
+ num_batches = len(groundtruths['source_id'])
+ batch_size = groundtruths['source_id'][0].shape[0]
+ for i in range(num_batches):
+ for j in range(batch_size):
+ num_instances = groundtruths['num_detections'][i][j]
+ for k in range(num_instances):
+ ann = {}
+ ann['image_id'] = int(groundtruths['source_id'][i][j])
+ if 'is_crowds' in groundtruths:
+ ann['iscrowd'] = int(groundtruths['is_crowds'][i][j, k])
+ else:
+ ann['iscrowd'] = 0
+ ann['category_id'] = int(groundtruths['classes'][i][j, k])
+ boxes = groundtruths['boxes'][i]
+ ann['bbox'] = [
+ float(boxes[j, k, 1]),
+ float(boxes[j, k, 0]),
+ float(boxes[j, k, 3] - boxes[j, k, 1]),
+ float(boxes[j, k, 2] - boxes[j, k, 0])]
+ if 'areas' in groundtruths:
+ ann['area'] = float(groundtruths['areas'][i][j, k])
+ else:
+ ann['area'] = float(
+ (boxes[j, k, 3] - boxes[j, k, 1]) *
+ (boxes[j, k, 2] - boxes[j, k, 0]))
+ if 'masks' in groundtruths:
+          mask = Image.open(six.BytesIO(groundtruths['masks'][i][j, k]))
+ width, height = mask.size
+ np_mask = (
+ np.array(mask.getdata()).reshape(height, width).astype(np.uint8))
+ np_mask[np_mask > 0] = 255
+ encoded_mask = mask_api.encode(np.asfortranarray(np_mask))
+ ann['segmentation'] = encoded_mask
+ if 'areas' not in groundtruths:
+ ann['area'] = mask_api.area(encoded_mask)
+ gt_annotations.append(ann)
+
+ for i, ann in enumerate(gt_annotations):
+ ann['id'] = i + 1
+
+ if label_map:
+ gt_categories = [{'id': i, 'name': label_map[i]} for i in label_map]
+ else:
+ category_ids = [gt['category_id'] for gt in gt_annotations]
+ gt_categories = [{'id': i} for i in set(category_ids)]
+
+ gt_dataset = {
+ 'images': gt_images,
+ 'categories': gt_categories,
+ 'annotations': copy.deepcopy(gt_annotations),
+ }
+ return gt_dataset
+
+
+class COCOGroundtruthGenerator(object):
+ """Generates the groundtruth annotations from a single example."""
+
+ def __init__(self, file_pattern, num_examples, include_mask):
+ self._file_pattern = file_pattern
+ self._num_examples = num_examples
+ self._include_mask = include_mask
+ self._dataset_fn = tf.data.TFRecordDataset
+
+ def _parse_single_example(self, example):
+ """Parses a single serialized tf.Example proto.
+
+ Args:
+ example: a serialized tf.Example proto string.
+
+ Returns:
+ A dictionary of groundtruth with the following fields:
+ source_id: a scalar tensor of int64 representing the image source_id.
+ height: a scalar tensor of int64 representing the image height.
+ width: a scalar tensor of int64 representing the image width.
+ boxes: a float tensor of shape [K, 4], representing the groundtruth
+ boxes in absolute coordinates with respect to the original image size.
+ classes: a int64 tensor of shape [K], representing the class labels of
+ each instances.
+ is_crowds: a bool tensor of shape [K], indicating whether the instance
+ is crowd.
+ areas: a float tensor of shape [K], indicating the area of each
+ instance.
+ masks: a string tensor of shape [K], containing the bytes of the png
+ mask of each instance.
+ """
+ decoder = tf_example_decoder.TfExampleDecoder(
+ include_mask=self._include_mask)
+ decoded_tensors = decoder.decode(example)
+
+ image = decoded_tensors['image']
+ image_size = tf.shape(image)[0:2]
+ boxes = box_utils.denormalize_boxes(
+ decoded_tensors['groundtruth_boxes'], image_size)
+ groundtruths = {
+ 'source_id': tf.string_to_number(
+ decoded_tensors['source_id'], out_type=tf.int64),
+ 'height': decoded_tensors['height'],
+ 'width': decoded_tensors['width'],
+ 'num_detections': tf.shape(decoded_tensors['groundtruth_classes'])[0],
+ 'boxes': boxes,
+ 'classes': decoded_tensors['groundtruth_classes'],
+ 'is_crowds': decoded_tensors['groundtruth_is_crowd'],
+ 'areas': decoded_tensors['groundtruth_area'],
+ }
+ if self._include_mask:
+ groundtruths.update({
+ 'masks': decoded_tensors['groundtruth_instance_masks_png'],
+ })
+ return groundtruths
+
+ def _build_pipeline(self):
+ """Builds data pipeline to generate groundtruth annotations."""
+ dataset = tf.data.Dataset.list_files(self._file_pattern, shuffle=False)
+ dataset = dataset.apply(
+ tf.data.experimental.parallel_interleave(
+ lambda filename: self._dataset_fn(filename).prefetch(1),
+ cycle_length=32,
+ sloppy=False))
+ dataset = dataset.map(self._parse_single_example, num_parallel_calls=64)
+ dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+ dataset = dataset.batch(1, drop_remainder=False)
+ return dataset
+
+ def __call__(self):
+ with tf.Graph().as_default():
+ dataset = self._build_pipeline()
+ groundtruth = dataset.make_one_shot_iterator().get_next()
+
+ with tf.Session() as sess:
+ for _ in range(self._num_examples):
+ groundtruth_result = sess.run(groundtruth)
+ yield groundtruth_result
+
+
+def scan_and_generator_annotation_file(file_pattern,
+ num_samples,
+ include_mask,
+ annotation_file):
+ """Scans and generate the COCO-style annotation JSON file given a dataset."""
+ groundtruth_generator = COCOGroundtruthGenerator(
+ file_pattern, num_samples, include_mask)
+ generate_annotation_file(groundtruth_generator, annotation_file)
+
+
+def generate_annotation_file(groundtruth_generator,
+ annotation_file):
+ """Generates COCO-style annotation JSON file given a groundtruth generator."""
+ groundtruths = {}
+ logging.info('Loading groundtruth annotations from dataset to memory...')
+ for groundtruth in groundtruth_generator():
+ for k, v in six.iteritems(groundtruth):
+ if k not in groundtruths:
+ groundtruths[k] = [v]
+ else:
+ groundtruths[k].append(v)
+ gt_dataset = convert_groundtruths_to_coco_dataset(groundtruths)
+
+ logging.info('Saving groundtruth annotations to the JSON file...')
+ with tf.io.gfile.GFile(annotation_file, 'w') as f:
+ f.write(json.dumps(gt_dataset))
+ logging.info('Done saving the JSON file...')
diff --git a/models/official/vision/detection/evaluation/factory.py b/models/official/vision/detection/evaluation/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d44bf177071a97b663b41410a05d59d59f04456
--- /dev/null
+++ b/models/official/vision/detection/evaluation/factory.py
@@ -0,0 +1,40 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Evaluator factory."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from official.vision.detection.evaluation import coco_evaluator
+
+
+def evaluator_generator(params):
+ """Generator function for various evaluators."""
+ if params.type == 'box':
+ evaluator = coco_evaluator.COCOEvaluator(
+ annotation_file=params.val_json_file, include_mask=False)
+ elif params.type == 'box_and_mask':
+ evaluator = coco_evaluator.COCOEvaluator(
+ annotation_file=params.val_json_file, include_mask=True)
+ elif params.type == 'shapemask_box_and_mask':
+ evaluator = coco_evaluator.ShapeMaskCOCOEvaluator(
+ mask_eval_class=params.mask_eval_class,
+ annotation_file=params.val_json_file, include_mask=True)
+
+ else:
+ raise ValueError('Evaluator %s is not supported.' % params.type)
+
+ return coco_evaluator.MetricWrapper(evaluator)
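+# A minimal illustrative sketch (not part of the original file): `params` is
+# expected to expose `type`, `val_json_file` and, for ShapeMask,
+# `mask_eval_class` as attributes. A SimpleNamespace stands in for the real
+# config object here.
+#
+#   from types import SimpleNamespace
+#   params = SimpleNamespace(type='box_and_mask',
+#                            val_json_file='instances_val2017.json')
+#   evaluator = evaluator_generator(params)  # MetricWrapper(COCOEvaluator)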
diff --git a/models/official/vision/detection/executor/__init__.py b/models/official/vision/detection/executor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..931c2ef11db4a949e6c2e95bca44e36bac1241e9
--- /dev/null
+++ b/models/official/vision/detection/executor/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/models/official/vision/detection/executor/detection_executor.py b/models/official/vision/detection/executor/detection_executor.py
new file mode 100644
index 0000000000000000000000000000000000000000..26ff028cf67d6df37e5a0af31bc2e54844231fcd
--- /dev/null
+++ b/models/official/vision/detection/executor/detection_executor.py
@@ -0,0 +1,160 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""An executor class for running model on TensorFlow 2.0."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+from absl import logging
+
+import tensorflow as tf
+from official.modeling.training import distributed_executor as executor
+from official.vision.detection.utils.object_detection import visualization_utils
+
+
+class DetectionDistributedExecutor(executor.DistributedExecutor):
+ """Detection specific customer training loop executor.
+
+ Subclasses the DistributedExecutor and adds support for numpy based metrics.
+ """
+
+ def __init__(self,
+ predict_post_process_fn=None,
+ trainable_variables_filter=None,
+ **kwargs):
+ super(DetectionDistributedExecutor, self).__init__(**kwargs)
+ if predict_post_process_fn:
+ assert callable(predict_post_process_fn)
+ if trainable_variables_filter:
+ assert callable(trainable_variables_filter)
+ self._predict_post_process_fn = predict_post_process_fn
+ self._trainable_variables_filter = trainable_variables_filter
+ self.eval_steps = tf.Variable(
+ 0,
+ trainable=False,
+ dtype=tf.int32,
+ synchronization=tf.VariableSynchronization.ON_READ,
+ aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
+ shape=[])
+
+ def _create_replicated_step(self,
+ strategy,
+ model,
+ loss_fn,
+ optimizer,
+ metric=None):
+ trainable_variables = model.trainable_variables
+ if self._trainable_variables_filter:
+ trainable_variables = self._trainable_variables_filter(
+ trainable_variables)
+ logging.info('Filter trainable variables from %d to %d',
+ len(model.trainable_variables), len(trainable_variables))
+ _update_state = lambda labels, outputs: None
+ if isinstance(metric, tf.keras.metrics.Metric):
+ _update_state = lambda labels, outputs: metric.update_state(
+ labels, outputs)
+ else:
+ logging.error('Detection: train metric is not an instance of '
+ 'tf.keras.metrics.Metric.')
+
+ def _replicated_step(inputs):
+ """Replicated training step."""
+ inputs, labels = inputs
+
+ with tf.GradientTape() as tape:
+ outputs = model(inputs, training=True)
+ all_losses = loss_fn(labels, outputs)
+ losses = {}
+ for k, v in all_losses.items():
+ losses[k] = tf.reduce_mean(v)
+ per_replica_loss = losses['total_loss'] / strategy.num_replicas_in_sync
+ _update_state(labels, outputs)
+
+ grads = tape.gradient(per_replica_loss, trainable_variables)
+ optimizer.apply_gradients(zip(grads, trainable_variables))
+ return losses
+
+ return _replicated_step
+
+ def _create_test_step(self, strategy, model, metric):
+ """Creates a distributed test step."""
+
+ @tf.function
+ def test_step(iterator, eval_steps):
+ """Calculates evaluation metrics on distributed devices."""
+
+ def _test_step_fn(inputs, eval_steps):
+ """Replicated accuracy calculation."""
+ inputs, labels = inputs
+ model_outputs = model(inputs, training=False)
+ if self._predict_post_process_fn:
+ labels, prediction_outputs = self._predict_post_process_fn(
+ labels, model_outputs)
+ num_remaining_visualizations = (
+ self._params.eval.num_images_to_visualize - eval_steps)
+          # If there are visualizations remaining to be done, add the next
+          # batch's outputs for visualization.
+ #
+ # TODO(hongjunchoi): Once dynamic slicing is supported on TPU, only
+ # write correct slice of outputs to summary file.
+ if num_remaining_visualizations > 0:
+ visualization_utils.visualize_images_with_bounding_boxes(
+ inputs, prediction_outputs['detection_boxes'],
+ self.global_train_step, self.eval_summary_writer)
+
+ return labels, prediction_outputs
+
+ labels, outputs = strategy.run(
+ _test_step_fn, args=(
+ next(iterator),
+ eval_steps,
+ ))
+ outputs = tf.nest.map_structure(strategy.experimental_local_results,
+ outputs)
+ labels = tf.nest.map_structure(strategy.experimental_local_results,
+ labels)
+
+ eval_steps.assign_add(self._params.eval.batch_size)
+ return labels, outputs
+
+ return test_step
+
+ def _run_evaluation(self, test_step, current_training_step, metric,
+ test_iterator):
+ """Runs validation steps and aggregate metrics."""
+ self.eval_steps.assign(0)
+ if not test_iterator or not metric:
+ logging.warning(
+ 'Both test_iterator (%s) and metrics (%s) must not be None.',
+ test_iterator, metric)
+ return None
+ logging.info('Running evaluation after step: %s.', current_training_step)
+ while True:
+ try:
+ labels, outputs = test_step(test_iterator, self.eval_steps)
+ if metric:
+ metric.update_state(labels, outputs)
+ except (StopIteration, tf.errors.OutOfRangeError):
+ break
+
+ metric_result = metric.result()
+ if isinstance(metric, tf.keras.metrics.Metric):
+ metric_result = tf.nest.map_structure(lambda x: x.numpy().astype(float),
+ metric_result)
+ logging.info('Step: [%d] Validation metric = %s', current_training_step,
+ metric_result)
+ return metric_result
diff --git a/models/official/vision/detection/main.py b/models/official/vision/detection/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..542be3a1dcc73f82719af2d60dc9abd210787931
--- /dev/null
+++ b/models/official/vision/detection/main.py
@@ -0,0 +1,271 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Main function to train various object detection models."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import functools
+import pprint
+
+# pylint: disable=g-bad-import-order
+import tensorflow as tf
+
+from absl import app
+from absl import flags
+from absl import logging
+# pylint: enable=g-bad-import-order
+
+from official.modeling.hyperparams import params_dict
+from official.modeling.training import distributed_executor as executor
+from official.utils import hyperparams_flags
+from official.utils.flags import core as flags_core
+from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
+from official.vision.detection.configs import factory as config_factory
+from official.vision.detection.dataloader import input_reader
+from official.vision.detection.dataloader import mode_keys as ModeKeys
+from official.vision.detection.executor.detection_executor import DetectionDistributedExecutor
+from official.vision.detection.modeling import factory as model_factory
+
+hyperparams_flags.initialize_common_flags()
+flags_core.define_log_steps()
+
+flags.DEFINE_bool('enable_xla', default=False, help='Enable XLA for GPU')
+
+flags.DEFINE_string(
+    'mode', default='train', help='Mode to run: `train`, `eval` or `eval_once`.')
+
+flags.DEFINE_string(
+ 'model', default='retinanet',
+ help='Model to run: `retinanet`, `mask_rcnn` or `shapemask`.')
+
+flags.DEFINE_string('training_file_pattern', None,
+ 'Location of the train data.')
+
+flags.DEFINE_string('eval_file_pattern', None, 'Location of the eval data.')
+
+flags.DEFINE_string(
+ 'checkpoint_path', None,
+ 'The checkpoint path to eval. Only used in eval_once mode.')
+
+FLAGS = flags.FLAGS
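+
+# Example invocation (illustrative only; the config file, data patterns and
+# model_dir below are placeholders, and flags such as --config_file,
+# --params_override, --strategy_type and --model_dir are registered by
+# hyperparams_flags.initialize_common_flags() above):
+#
+#   python main.py \
+#     --model=retinanet \
+#     --mode=train \
+#     --model_dir=/tmp/retinanet \
+#     --training_file_pattern=/data/coco/train-*.tfrecord \
+#     --config_file=retinanet_config.yaml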
+
+
+def run_executor(params,
+ mode,
+ checkpoint_path=None,
+ train_input_fn=None,
+ eval_input_fn=None,
+ callbacks=None,
+ prebuilt_strategy=None):
+ """Runs the object detection model on distribution strategy defined by the user."""
+
+ if params.architecture.use_bfloat16:
+ policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
+ 'mixed_bfloat16')
+ tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
+
+ model_builder = model_factory.model_generator(params)
+
+ if prebuilt_strategy is not None:
+ strategy = prebuilt_strategy
+ else:
+ strategy_config = params.strategy_config
+ distribution_utils.configure_cluster(strategy_config.worker_hosts,
+ strategy_config.task_index)
+ strategy = distribution_utils.get_distribution_strategy(
+ distribution_strategy=params.strategy_type,
+ num_gpus=strategy_config.num_gpus,
+ all_reduce_alg=strategy_config.all_reduce_alg,
+ num_packs=strategy_config.num_packs,
+ tpu_address=strategy_config.tpu)
+
+ num_workers = int(strategy.num_replicas_in_sync + 7) // 8
+ is_multi_host = (int(num_workers) >= 2)
+
+ if mode == 'train':
+
+ def _model_fn(params):
+ return model_builder.build_model(params, mode=ModeKeys.TRAIN)
+
+ logging.info(
+ 'Train num_replicas_in_sync %d num_workers %d is_multi_host %s',
+ strategy.num_replicas_in_sync, num_workers, is_multi_host)
+
+ dist_executor = DetectionDistributedExecutor(
+ strategy=strategy,
+ params=params,
+ model_fn=_model_fn,
+ loss_fn=model_builder.build_loss_fn,
+ is_multi_host=is_multi_host,
+ predict_post_process_fn=model_builder.post_processing,
+ trainable_variables_filter=model_builder
+ .make_filter_trainable_variables_fn())
+
+ if is_multi_host:
+ train_input_fn = functools.partial(
+ train_input_fn,
+ batch_size=params.train.batch_size // strategy.num_replicas_in_sync)
+
+ return dist_executor.train(
+ train_input_fn=train_input_fn,
+ model_dir=params.model_dir,
+ iterations_per_loop=params.train.iterations_per_loop,
+ total_steps=params.train.total_steps,
+ init_checkpoint=model_builder.make_restore_checkpoint_fn(),
+ custom_callbacks=callbacks,
+ save_config=True)
+ elif mode == 'eval' or mode == 'eval_once':
+
+ def _model_fn(params):
+ return model_builder.build_model(params, mode=ModeKeys.PREDICT_WITH_GT)
+
+ logging.info('Eval num_replicas_in_sync %d num_workers %d is_multi_host %s',
+ strategy.num_replicas_in_sync, num_workers, is_multi_host)
+
+ if is_multi_host:
+ eval_input_fn = functools.partial(
+ eval_input_fn,
+ batch_size=params.eval.batch_size // strategy.num_replicas_in_sync)
+
+ dist_executor = DetectionDistributedExecutor(
+ strategy=strategy,
+ params=params,
+ model_fn=_model_fn,
+ loss_fn=model_builder.build_loss_fn,
+ is_multi_host=is_multi_host,
+ predict_post_process_fn=model_builder.post_processing,
+ trainable_variables_filter=model_builder
+ .make_filter_trainable_variables_fn())
+
+ if mode == 'eval':
+ results = dist_executor.evaluate_from_model_dir(
+ model_dir=params.model_dir,
+ eval_input_fn=eval_input_fn,
+ eval_metric_fn=model_builder.eval_metrics,
+ eval_timeout=params.eval.eval_timeout,
+ min_eval_interval=params.eval.min_eval_interval,
+ total_steps=params.train.total_steps)
+ else:
+ # Run evaluation once for a single checkpoint.
+ if not checkpoint_path:
+ raise ValueError('checkpoint_path cannot be empty.')
+ if tf.io.gfile.isdir(checkpoint_path):
+ checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
+ summary_writer = executor.SummaryWriter(params.model_dir, 'eval')
+ results, _ = dist_executor.evaluate_checkpoint(
+ checkpoint_path=checkpoint_path,
+ eval_input_fn=eval_input_fn,
+ eval_metric_fn=model_builder.eval_metrics,
+ summary_writer=summary_writer)
+ for k, v in results.items():
+ logging.info('Final eval metric %s: %f', k, v)
+ return results
+ else:
+ raise ValueError('Mode not found: %s.' % mode)
+
+
+def run(callbacks=None):
+ keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)
+
+ params = config_factory.config_generator(FLAGS.model)
+
+ params = params_dict.override_params_dict(
+ params, FLAGS.config_file, is_strict=True)
+
+ params = params_dict.override_params_dict(
+ params, FLAGS.params_override, is_strict=True)
+ params.override(
+ {
+ 'strategy_type': FLAGS.strategy_type,
+ 'model_dir': FLAGS.model_dir,
+ 'strategy_config': executor.strategy_flags_dict(),
+ },
+ is_strict=False)
+
+ # Make sure use_tpu and strategy_type are in sync.
+ params.use_tpu = (params.strategy_type == 'tpu')
+
+ if not params.use_tpu:
+ params.override({
+ 'architecture': {
+ 'use_bfloat16': False,
+ },
+ 'norm_activation': {
+ 'use_sync_bn': False,
+ },
+ }, is_strict=True)
+
+ params.validate()
+ params.lock()
+ pp = pprint.PrettyPrinter()
+ params_str = pp.pformat(params.as_dict())
+ logging.info('Model Parameters: %s', params_str)
+
+ train_input_fn = None
+ eval_input_fn = None
+ training_file_pattern = FLAGS.training_file_pattern or params.train.train_file_pattern
+ eval_file_pattern = FLAGS.eval_file_pattern or params.eval.eval_file_pattern
+ if not training_file_pattern and not eval_file_pattern:
+ raise ValueError('Must provide at least one of training_file_pattern and '
+ 'eval_file_pattern.')
+
+ if training_file_pattern:
+ # Use global batch size for single host.
+ train_input_fn = input_reader.InputFn(
+ file_pattern=training_file_pattern,
+ params=params,
+ mode=input_reader.ModeKeys.TRAIN,
+ batch_size=params.train.batch_size)
+
+ if eval_file_pattern:
+ eval_input_fn = input_reader.InputFn(
+ file_pattern=eval_file_pattern,
+ params=params,
+ mode=input_reader.ModeKeys.PREDICT_WITH_GT,
+ batch_size=params.eval.batch_size,
+ num_examples=params.eval.eval_samples)
+
+ if callbacks is None:
+ callbacks = []
+
+ if FLAGS.log_steps:
+ callbacks.append(
+ keras_utils.TimeHistory(
+ batch_size=params.train.batch_size,
+ log_steps=FLAGS.log_steps,
+ ))
+
+ return run_executor(
+ params,
+ FLAGS.mode,
+ checkpoint_path=FLAGS.checkpoint_path,
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ callbacks=callbacks)
+
+
+def main(argv):
+ del argv # Unused.
+
+ run()
+
+
+if __name__ == '__main__':
+ tf.config.set_soft_device_placement(True)
+ app.run(main)
diff --git a/models/official/vision/detection/modeling/__init__.py b/models/official/vision/detection/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..931c2ef11db4a949e6c2e95bca44e36bac1241e9
--- /dev/null
+++ b/models/official/vision/detection/modeling/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/models/official/vision/detection/modeling/architecture/__init__.py b/models/official/vision/detection/modeling/architecture/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..931c2ef11db4a949e6c2e95bca44e36bac1241e9
--- /dev/null
+++ b/models/official/vision/detection/modeling/architecture/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/models/official/vision/detection/modeling/architecture/factory.py b/models/official/vision/detection/modeling/architecture/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed5647d6fb83fbd7c404a4573ff247acb8999b8c
--- /dev/null
+++ b/models/official/vision/detection/modeling/architecture/factory.py
@@ -0,0 +1,163 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Model architecture factory."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from official.vision.detection.modeling.architecture import fpn
+from official.vision.detection.modeling.architecture import heads
+from official.vision.detection.modeling.architecture import identity
+from official.vision.detection.modeling.architecture import nn_ops
+from official.vision.detection.modeling.architecture import resnet
+
+
+def norm_activation_generator(params):
+ return nn_ops.norm_activation_builder(
+ momentum=params.batch_norm_momentum,
+ epsilon=params.batch_norm_epsilon,
+ trainable=params.batch_norm_trainable,
+ activation=params.activation)
+
+
+def backbone_generator(params):
+ """Generator function for various backbone models."""
+ if params.architecture.backbone == 'resnet':
+ resnet_params = params.resnet
+ backbone_fn = resnet.Resnet(
+ resnet_depth=resnet_params.resnet_depth,
+ activation=params.norm_activation.activation,
+ norm_activation=norm_activation_generator(
+ params.norm_activation))
+ else:
+ raise ValueError('Backbone model `{}` is not supported.'
+ .format(params.architecture.backbone))
+
+ return backbone_fn
+
+
+def multilevel_features_generator(params):
+ """Generator function for various FPN models."""
+ if params.architecture.multilevel_features == 'fpn':
+ fpn_params = params.fpn
+ fpn_fn = fpn.Fpn(
+ min_level=params.architecture.min_level,
+ max_level=params.architecture.max_level,
+ fpn_feat_dims=fpn_params.fpn_feat_dims,
+ use_separable_conv=fpn_params.use_separable_conv,
+ activation=params.norm_activation.activation,
+ use_batch_norm=fpn_params.use_batch_norm,
+ norm_activation=norm_activation_generator(
+ params.norm_activation))
+ elif params.architecture.multilevel_features == 'identity':
+ fpn_fn = identity.Identity()
+ else:
+ raise ValueError('The multi-level feature model `{}` is not supported.'
+ .format(params.architecture.multilevel_features))
+ return fpn_fn
+
+
+def retinanet_head_generator(params):
+ """Generator function for RetinaNet head architecture."""
+ head_params = params.retinanet_head
+ return heads.RetinanetHead(
+ params.architecture.min_level,
+ params.architecture.max_level,
+ params.architecture.num_classes,
+ head_params.anchors_per_location,
+ head_params.num_convs,
+ head_params.num_filters,
+ head_params.use_separable_conv,
+ norm_activation=norm_activation_generator(params.norm_activation))
+
+
+def rpn_head_generator(params):
+ """Generator function for RPN head architecture."""
+ head_params = params.rpn_head
+ return heads.RpnHead(
+ params.architecture.min_level,
+ params.architecture.max_level,
+ head_params.anchors_per_location,
+ head_params.num_convs,
+ head_params.num_filters,
+ head_params.use_separable_conv,
+ params.norm_activation.activation,
+ head_params.use_batch_norm,
+ norm_activation=norm_activation_generator(params.norm_activation))
+
+
+def fast_rcnn_head_generator(params):
+ """Generator function for Fast R-CNN head architecture."""
+ head_params = params.frcnn_head
+ return heads.FastrcnnHead(
+ params.architecture.num_classes,
+ head_params.num_convs,
+ head_params.num_filters,
+ head_params.use_separable_conv,
+ head_params.num_fcs,
+ head_params.fc_dims,
+ params.norm_activation.activation,
+ head_params.use_batch_norm,
+ norm_activation=norm_activation_generator(params.norm_activation))
+
+
+def mask_rcnn_head_generator(params):
+ """Generator function for Mask R-CNN head architecture."""
+ head_params = params.mrcnn_head
+ return heads.MaskrcnnHead(
+ params.architecture.num_classes,
+ params.architecture.mask_target_size,
+ head_params.num_convs,
+ head_params.num_filters,
+ head_params.use_separable_conv,
+ params.norm_activation.activation,
+ head_params.use_batch_norm,
+ norm_activation=norm_activation_generator(params.norm_activation))
+
+
+def shapeprior_head_generator(params):
+ """Generator function for shape prior head architecture."""
+ head_params = params.shapemask_head
+ return heads.ShapemaskPriorHead(
+ params.architecture.num_classes,
+ head_params.num_downsample_channels,
+ head_params.mask_crop_size,
+ head_params.use_category_for_mask,
+ head_params.shape_prior_path)
+
+
+def coarsemask_head_generator(params):
+ """Generator function for ShapeMask coarse mask head architecture."""
+ head_params = params.shapemask_head
+ return heads.ShapemaskCoarsemaskHead(
+ params.architecture.num_classes,
+ head_params.num_downsample_channels,
+ head_params.mask_crop_size,
+ head_params.use_category_for_mask,
+ head_params.num_convs,
+ norm_activation=norm_activation_generator(params.norm_activation))
+
+
+def finemask_head_generator(params):
+ """Generator function for Shapemask fine mask head architecture."""
+ head_params = params.shapemask_head
+ return heads.ShapemaskFinemaskHead(
+ params.architecture.num_classes,
+ head_params.num_downsample_channels,
+ head_params.mask_crop_size,
+ head_params.use_category_for_mask,
+ head_params.num_convs,
+ head_params.upsample_factor)
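+
+
+# Rough composition sketch (illustrative; the real wiring lives in the model
+# classes under modeling/, and `params` stands for a fully-populated config
+# from configs/factory.py):
+#
+#   backbone_fn = backbone_generator(params)         # images -> {level: C_l}
+#   fpn_fn = multilevel_features_generator(params)   # {level: C_l} -> {level: P_l}
+#   head_fn = retinanet_head_generator(params)       # {level: P_l} -> (class, box)
+#
+#   c = backbone_fn(images, is_training=True)
+#   p = fpn_fn(c, is_training=True)
+#   class_outputs, box_outputs = head_fn(p, is_training=True)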
diff --git a/models/official/vision/detection/modeling/architecture/fpn.py b/models/official/vision/detection/modeling/architecture/fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b968dc2e152eb66e2df7ca7673b506c123b59d0f
--- /dev/null
+++ b/models/official/vision/detection/modeling/architecture/fpn.py
@@ -0,0 +1,151 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Feature Pyramid Networks.
+
+Feature Pyramid Networks were proposed in:
+[1] Tsung-Yi Lin, Piotr Dollar, Ross Girshick, Kaiming He, Bharath Hariharan,
+    and Serge Belongie.
+ Feature Pyramid Networks for Object Detection. CVPR 2017.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import tensorflow as tf
+
+from tensorflow.python.keras import backend
+from official.vision.detection.modeling.architecture import nn_ops
+from official.vision.detection.ops import spatial_transform_ops
+
+
+class Fpn(object):
+ """Feature pyramid networks."""
+
+ def __init__(self,
+ min_level=3,
+ max_level=7,
+ fpn_feat_dims=256,
+ use_separable_conv=False,
+ activation='relu',
+ use_batch_norm=True,
+ norm_activation=nn_ops.norm_activation_builder(
+ activation='relu')):
+ """FPN initialization function.
+
+ Args:
+ min_level: `int` minimum level in FPN output feature maps.
+ max_level: `int` maximum level in FPN output feature maps.
+ fpn_feat_dims: `int` number of filters in FPN layers.
+      use_separable_conv: `bool`, if True use separable convolutions in the
+        FPN layers.
+      activation: activation function. Support 'relu' and 'swish'.
+      use_batch_norm: 'bool', indicating whether batchnorm layers are added.
+ norm_activation: an operation that includes a normalization layer
+ followed by an optional activation layer.
+ """
+ self._min_level = min_level
+ self._max_level = max_level
+ self._fpn_feat_dims = fpn_feat_dims
+ if use_separable_conv:
+ self._conv2d_op = functools.partial(
+ tf.keras.layers.SeparableConv2D, depth_multiplier=1)
+ else:
+ self._conv2d_op = tf.keras.layers.Conv2D
+ if activation == 'relu':
+ self._activation_op = tf.nn.relu
+ elif activation == 'swish':
+ self._activation_op = tf.nn.swish
+ else:
+ raise ValueError('Unsupported activation `{}`.'.format(activation))
+ self._use_batch_norm = use_batch_norm
+ self._norm_activation = norm_activation
+
+ self._norm_activations = {}
+ self._lateral_conv2d_op = {}
+ self._post_hoc_conv2d_op = {}
+ self._coarse_conv2d_op = {}
+ for level in range(self._min_level, self._max_level + 1):
+ if self._use_batch_norm:
+ self._norm_activations[level] = norm_activation(
+ use_activation=False, name='p%d-bn' % level)
+ self._lateral_conv2d_op[level] = self._conv2d_op(
+ filters=self._fpn_feat_dims,
+ kernel_size=(1, 1),
+ padding='same',
+ name='l%d' % level)
+ self._post_hoc_conv2d_op[level] = self._conv2d_op(
+ filters=self._fpn_feat_dims,
+ strides=(1, 1),
+ kernel_size=(3, 3),
+ padding='same',
+ name='post_hoc_d%d' % level)
+ self._coarse_conv2d_op[level] = self._conv2d_op(
+ filters=self._fpn_feat_dims,
+ strides=(2, 2),
+ kernel_size=(3, 3),
+ padding='same',
+ name='p%d' % level)
+
+ def __call__(self, multilevel_features, is_training=None):
+ """Returns the FPN features for a given multilevel features.
+
+ Args:
+ multilevel_features: a `dict` containing `int` keys for continuous feature
+ levels, e.g., [2, 3, 4, 5]. The values are corresponding features with
+ shape [batch_size, height_l, width_l, num_filters].
+ is_training: `bool` if True, the model is in training mode.
+
+ Returns:
+ a `dict` containing `int` keys for continuous feature levels
+ [min_level, min_level + 1, ..., max_level]. The values are corresponding
+ FPN features with shape [batch_size, height_l, width_l, fpn_feat_dims].
+ """
+ input_levels = list(multilevel_features.keys())
+ if min(input_levels) > self._min_level:
+ raise ValueError(
+          'The minimum backbone level %d should be ' % (min(input_levels)) +
+          'less than or equal to the FPN minimum level %d.' % (self._min_level))
+ backbone_max_level = min(max(input_levels), self._max_level)
+ with backend.get_graph().as_default(), tf.name_scope('fpn'):
+ # Adds lateral connections.
+ feats_lateral = {}
+ for level in range(self._min_level, backbone_max_level + 1):
+ feats_lateral[level] = self._lateral_conv2d_op[level](
+ multilevel_features[level])
+
+ # Adds top-down path.
+ feats = {backbone_max_level: feats_lateral[backbone_max_level]}
+ for level in range(backbone_max_level - 1, self._min_level - 1, -1):
+ feats[level] = spatial_transform_ops.nearest_upsampling(
+ feats[level + 1], 2) + feats_lateral[level]
+
+ # Adds post-hoc 3x3 convolution kernel.
+ for level in range(self._min_level, backbone_max_level + 1):
+ feats[level] = self._post_hoc_conv2d_op[level](feats[level])
+
+ # Adds coarser FPN levels introduced for RetinaNet.
+ for level in range(backbone_max_level + 1, self._max_level + 1):
+ feats_in = feats[level - 1]
+ if level > backbone_max_level + 1:
+ feats_in = self._activation_op(feats_in)
+ feats[level] = self._coarse_conv2d_op[level](feats_in)
+ if self._use_batch_norm:
+ # Adds batch_norm layer.
+ for level in range(self._min_level, self._max_level + 1):
+ feats[level] = self._norm_activations[level](
+ feats[level], is_training=is_training)
+ return feats
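+
+
+# Usage sketch (shapes only; assumes the default min_level=3 / max_level=7 and
+# a backbone emitting levels 2-5, which is illustrative rather than required):
+#
+#   fpn = Fpn(min_level=3, max_level=7, fpn_feat_dims=256)
+#   feats = fpn({2: c2, 3: c3, 4: c4, 5: c5}, is_training=True)
+#   # feats[3]..feats[5]: lateral 1x1 conv + top-down upsampling + 3x3 conv.
+#   # feats[6], feats[7]: extra coarse levels from strided 3x3 convs.
+#   # Every output level has fpn_feat_dims (here 256) channels.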
diff --git a/models/official/vision/detection/modeling/architecture/heads.py b/models/official/vision/detection/modeling/architecture/heads.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f6954aecbbef8e8807345e643555ba222b0e1b9
--- /dev/null
+++ b/models/official/vision/detection/modeling/architecture/heads.py
@@ -0,0 +1,999 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Classes to build various prediction heads in all supported models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.keras import backend
+from official.vision.detection.modeling.architecture import nn_ops
+from official.vision.detection.ops import spatial_transform_ops
+
+
+class RpnHead(tf.keras.layers.Layer):
+ """Region Proposal Network head."""
+
+ def __init__(self,
+ min_level,
+ max_level,
+ anchors_per_location,
+ num_convs=2,
+ num_filters=256,
+ use_separable_conv=False,
+ activation='relu',
+ use_batch_norm=True,
+ norm_activation=nn_ops.norm_activation_builder(
+ activation='relu')):
+ """Initialize params to build Region Proposal Network head.
+
+ Args:
+ min_level: `int` number of minimum feature level.
+ max_level: `int` number of maximum feature level.
+      anchors_per_location: `int` number of anchors per pixel location.
+ num_convs: `int` number that represents the number of the intermediate
+ conv layers before the prediction.
+ num_filters: `int` number that represents the number of filters of the
+ intermediate conv layers.
+      use_separable_conv: `bool`, indicating whether separable conv layers
+        are used.
+ activation: activation function. Support 'relu' and 'swish'.
+ use_batch_norm: 'bool', indicating whether batchnorm layers are added.
+ norm_activation: an operation that includes a normalization layer
+ followed by an optional activation layer.
+ """
+    super(RpnHead, self).__init__()
+    self._min_level = min_level
+ self._max_level = max_level
+ self._anchors_per_location = anchors_per_location
+ if activation == 'relu':
+ self._activation_op = tf.nn.relu
+ elif activation == 'swish':
+ self._activation_op = tf.nn.swish
+ else:
+ raise ValueError('Unsupported activation `{}`.'.format(activation))
+ self._use_batch_norm = use_batch_norm
+
+ if use_separable_conv:
+ self._conv2d_op = functools.partial(
+ tf.keras.layers.SeparableConv2D,
+ depth_multiplier=1,
+ bias_initializer=tf.zeros_initializer())
+ else:
+ self._conv2d_op = functools.partial(
+ tf.keras.layers.Conv2D,
+ kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
+ bias_initializer=tf.zeros_initializer())
+
+ self._rpn_conv = self._conv2d_op(
+ num_filters,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ activation=(None if self._use_batch_norm else self._activation_op),
+ padding='same',
+ name='rpn')
+ self._rpn_class_conv = self._conv2d_op(
+ anchors_per_location,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding='valid',
+ name='rpn-class')
+ self._rpn_box_conv = self._conv2d_op(
+ 4 * anchors_per_location,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding='valid',
+ name='rpn-box')
+
+ self._norm_activations = {}
+ if self._use_batch_norm:
+ for level in range(self._min_level, self._max_level + 1):
+ self._norm_activations[level] = norm_activation(name='rpn-l%d-bn' %
+ level)
+
+ def _shared_rpn_heads(self, features, anchors_per_location, level,
+ is_training):
+ """Shared RPN heads."""
+ features = self._rpn_conv(features)
+ if self._use_batch_norm:
+ # The batch normalization layers are not shared between levels.
+ features = self._norm_activations[level](
+ features, is_training=is_training)
+ # Proposal classification scores
+ scores = self._rpn_class_conv(features)
+ # Proposal bbox regression deltas
+ bboxes = self._rpn_box_conv(features)
+
+ return scores, bboxes
+
+ def __call__(self, features, is_training=None):
+
+ scores_outputs = {}
+ box_outputs = {}
+
+ with backend.get_graph().as_default(), tf.name_scope('rpn_head'):
+ for level in range(self._min_level, self._max_level + 1):
+ scores_output, box_output = self._shared_rpn_heads(
+ features[level], self._anchors_per_location, level, is_training)
+ scores_outputs[level] = scores_output
+ box_outputs[level] = box_output
+ return scores_outputs, box_outputs
+
+
+class FastrcnnHead(tf.keras.layers.Layer):
+ """Fast R-CNN box head."""
+
+ def __init__(self,
+ num_classes,
+ num_convs=0,
+ num_filters=256,
+ use_separable_conv=False,
+ num_fcs=2,
+ fc_dims=1024,
+ activation='relu',
+ use_batch_norm=True,
+ norm_activation=nn_ops.norm_activation_builder(
+ activation='relu')):
+ """Initialize params to build Fast R-CNN box head.
+
+ Args:
+      num_classes: an integer for the number of classes.
+ num_convs: `int` number that represents the number of the intermediate
+ conv layers before the FC layers.
+ num_filters: `int` number that represents the number of filters of the
+ intermediate conv layers.
+      use_separable_conv: `bool`, indicating whether separable conv layers
+        are used.
+ num_fcs: `int` number that represents the number of FC layers before the
+ predictions.
+ fc_dims: `int` number that represents the number of dimension of the FC
+ layers.
+ activation: activation function. Support 'relu' and 'swish'.
+ use_batch_norm: 'bool', indicating whether batchnorm layers are added.
+ norm_activation: an operation that includes a normalization layer
+ followed by an optional activation layer.
+ """
+    super(FastrcnnHead, self).__init__()
+    self._num_classes = num_classes
+
+ self._num_convs = num_convs
+ self._num_filters = num_filters
+ if use_separable_conv:
+ self._conv2d_op = functools.partial(
+ tf.keras.layers.SeparableConv2D,
+ depth_multiplier=1,
+ bias_initializer=tf.zeros_initializer())
+ else:
+ self._conv2d_op = functools.partial(
+ tf.keras.layers.Conv2D,
+ kernel_initializer=tf.keras.initializers.VarianceScaling(
+ scale=2, mode='fan_out', distribution='untruncated_normal'),
+ bias_initializer=tf.zeros_initializer())
+
+ self._num_fcs = num_fcs
+ self._fc_dims = fc_dims
+ if activation == 'relu':
+ self._activation_op = tf.nn.relu
+ elif activation == 'swish':
+ self._activation_op = tf.nn.swish
+ else:
+ raise ValueError('Unsupported activation `{}`.'.format(activation))
+ self._use_batch_norm = use_batch_norm
+ self._norm_activation = norm_activation
+
+ self._conv_ops = []
+ self._conv_bn_ops = []
+ for i in range(self._num_convs):
+ self._conv_ops.append(
+ self._conv2d_op(
+ self._num_filters,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding='same',
+ dilation_rate=(1, 1),
+ activation=(None if self._use_batch_norm else self._activation_op),
+ name='conv_{}'.format(i)))
+ if self._use_batch_norm:
+ self._conv_bn_ops.append(self._norm_activation())
+
+ self._fc_ops = []
+ self._fc_bn_ops = []
+ for i in range(self._num_fcs):
+ self._fc_ops.append(
+ tf.keras.layers.Dense(
+ units=self._fc_dims,
+ activation=(None if self._use_batch_norm else self._activation_op),
+ name='fc{}'.format(i)))
+ if self._use_batch_norm:
+ self._fc_bn_ops.append(self._norm_activation(fused=False))
+
+ self._class_predict = tf.keras.layers.Dense(
+ self._num_classes,
+ kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
+ bias_initializer=tf.zeros_initializer(),
+ name='class-predict')
+ self._box_predict = tf.keras.layers.Dense(
+ self._num_classes * 4,
+ kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.001),
+ bias_initializer=tf.zeros_initializer(),
+ name='box-predict')
+
+ def __call__(self, roi_features, is_training=None):
+ """Box and class branches for the Mask-RCNN model.
+
+ Args:
+ roi_features: A ROI feature tensor of shape
+ [batch_size, num_rois, height_l, width_l, num_filters].
+      is_training: `boolean`, if True, the model is in training mode.
+
+ Returns:
+ class_outputs: a tensor with a shape of
+ [batch_size, num_rois, num_classes], representing the class predictions.
+ box_outputs: a tensor with a shape of
+ [batch_size, num_rois, num_classes * 4], representing the box
+ predictions.
+ """
+
+ with backend.get_graph().as_default(), tf.name_scope('fast_rcnn_head'):
+      # Reshape inputs before the FC layers.
+ _, num_rois, height, width, filters = roi_features.get_shape().as_list()
+
+ net = tf.reshape(roi_features, [-1, height, width, filters])
+ for i in range(self._num_convs):
+ net = self._conv_ops[i](net)
+ if self._use_batch_norm:
+ net = self._conv_bn_ops[i](net, is_training=is_training)
+
+ filters = self._num_filters if self._num_convs > 0 else filters
+ net = tf.reshape(net, [-1, num_rois, height * width * filters])
+
+ for i in range(self._num_fcs):
+ net = self._fc_ops[i](net)
+ if self._use_batch_norm:
+ net = self._fc_bn_ops[i](net, is_training=is_training)
+
+ class_outputs = self._class_predict(net)
+ box_outputs = self._box_predict(net)
+ return class_outputs, box_outputs
+
+
+class MaskrcnnHead(tf.keras.layers.Layer):
+ """Mask R-CNN head."""
+
+ def __init__(self,
+ num_classes,
+ mask_target_size,
+ num_convs=4,
+ num_filters=256,
+ use_separable_conv=False,
+ activation='relu',
+ use_batch_norm=True,
+ norm_activation=nn_ops.norm_activation_builder(
+ activation='relu')):
+ """Initialize params to build Fast R-CNN head.
+
+ Args:
+      num_classes: an integer for the number of classes.
+      mask_target_size: an integer that is the resolution of masks.
+ num_convs: `int` number that represents the number of the intermediate
+ conv layers before the prediction.
+ num_filters: `int` number that represents the number of filters of the
+ intermediate conv layers.
+      use_separable_conv: `bool`, indicating whether separable conv layers
+        are used.
+ activation: activation function. Support 'relu' and 'swish'.
+ use_batch_norm: 'bool', indicating whether batchnorm layers are added.
+ norm_activation: an operation that includes a normalization layer
+ followed by an optional activation layer.
+ """
+    super(MaskrcnnHead, self).__init__()
+    self._num_classes = num_classes
+ self._mask_target_size = mask_target_size
+
+ self._num_convs = num_convs
+ self._num_filters = num_filters
+ if use_separable_conv:
+ self._conv2d_op = functools.partial(
+ tf.keras.layers.SeparableConv2D,
+ depth_multiplier=1,
+ bias_initializer=tf.zeros_initializer())
+ else:
+ self._conv2d_op = functools.partial(
+ tf.keras.layers.Conv2D,
+ kernel_initializer=tf.keras.initializers.VarianceScaling(
+ scale=2, mode='fan_out', distribution='untruncated_normal'),
+ bias_initializer=tf.zeros_initializer())
+ if activation == 'relu':
+ self._activation_op = tf.nn.relu
+ elif activation == 'swish':
+ self._activation_op = tf.nn.swish
+ else:
+ raise ValueError('Unsupported activation `{}`.'.format(activation))
+ self._use_batch_norm = use_batch_norm
+ self._norm_activation = norm_activation
+ self._conv2d_ops = []
+ for i in range(self._num_convs):
+ self._conv2d_ops.append(
+ self._conv2d_op(
+ self._num_filters,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding='same',
+ dilation_rate=(1, 1),
+ activation=(None if self._use_batch_norm else self._activation_op),
+ name='mask-conv-l%d' % i))
+ self._mask_conv_transpose = tf.keras.layers.Conv2DTranspose(
+ self._num_filters,
+ kernel_size=(2, 2),
+ strides=(2, 2),
+ padding='valid',
+ activation=(None if self._use_batch_norm else self._activation_op),
+ kernel_initializer=tf.keras.initializers.VarianceScaling(
+ scale=2, mode='fan_out', distribution='untruncated_normal'),
+ bias_initializer=tf.zeros_initializer(),
+ name='conv5-mask')
+
+ def __call__(self, roi_features, class_indices, is_training=None):
+ """Mask branch for the Mask-RCNN model.
+
+ Args:
+ roi_features: A ROI feature tensor of shape
+ [batch_size, num_rois, height_l, width_l, num_filters].
+      class_indices: a Tensor of shape [batch_size, num_rois], indicating
+        the class of each ROI.
+      is_training: `boolean`, if True, the model is in training mode.
+
+ Returns:
+      mask_outputs: a tensor with a shape of
+        [batch_size, num_masks, mask_height, mask_width],
+        representing the mask prediction for each ROI's predicted class.
+    """
+
+ with backend.get_graph().as_default():
+ with tf.name_scope('mask_head'):
+ _, num_rois, height, width, filters = roi_features.get_shape().as_list()
+ net = tf.reshape(roi_features, [-1, height, width, filters])
+
+ for i in range(self._num_convs):
+ net = self._conv2d_ops[i](net)
+ if self._use_batch_norm:
+ net = self._norm_activation()(net, is_training=is_training)
+
+ net = self._mask_conv_transpose(net)
+ if self._use_batch_norm:
+ net = self._norm_activation()(net, is_training=is_training)
+
+ mask_outputs = self._conv2d_op(
+ self._num_classes,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding='valid',
+ name='mask_fcn_logits')(
+ net)
+ mask_outputs = tf.reshape(mask_outputs, [
+ -1, num_rois, self._mask_target_size, self._mask_target_size,
+ self._num_classes
+ ])
+
+ with tf.name_scope('masks_post_processing'):
+ # TODO(pengchong): Figure out the way not to use the static inferred
+ # batch size.
+ batch_size, num_masks = class_indices.get_shape().as_list()
+ mask_outputs = tf.transpose(a=mask_outputs, perm=[0, 1, 4, 2, 3])
+        # Constructs indices for gather.
+ batch_indices = tf.tile(
+ tf.expand_dims(tf.range(batch_size), axis=1), [1, num_masks])
+ mask_indices = tf.tile(
+ tf.expand_dims(tf.range(num_masks), axis=0), [batch_size, 1])
+ gather_indices = tf.stack(
+ [batch_indices, mask_indices, class_indices], axis=2)
+ mask_outputs = tf.gather_nd(mask_outputs, gather_indices)
+ return mask_outputs
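+
+  # Shape walk-through of the gather above (illustrative numbers): with
+  # batch_size=2, num_masks=100, mask_target_size=28 and num_classes=91, the
+  # transpose yields [2, 100, 91, 28, 28]; gather_indices is [2, 100, 3] with
+  # (batch, mask, class) triples, so tf.gather_nd selects exactly one class
+  # slice per RoI and the returned mask_outputs has shape [2, 100, 28, 28].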
+
+
+class RetinanetHead(object):
+ """RetinaNet head."""
+
+ def __init__(self,
+ min_level,
+ max_level,
+ num_classes,
+ anchors_per_location,
+ num_convs=4,
+ num_filters=256,
+ use_separable_conv=False,
+ norm_activation=nn_ops.norm_activation_builder(
+ activation='relu')):
+ """Initialize params to build RetinaNet head.
+
+ Args:
+ min_level: `int` number of minimum feature level.
+ max_level: `int` number of maximum feature level.
+ num_classes: `int` number of classification categories.
+ anchors_per_location: `int` number of anchors per pixel location.
+      num_convs: `int` number of stacked convolutions before the last
+        prediction layer.
+      num_filters: `int` number of filters used in the head architecture.
+      use_separable_conv: `bool` to indicate whether to use separable
+        convolution.
+ norm_activation: an operation that includes a normalization layer
+ followed by an optional activation layer.
+ """
+ self._min_level = min_level
+ self._max_level = max_level
+
+ self._num_classes = num_classes
+ self._anchors_per_location = anchors_per_location
+
+ self._num_convs = num_convs
+ self._num_filters = num_filters
+ self._use_separable_conv = use_separable_conv
+ with tf.name_scope('class_net') as scope_name:
+ self._class_name_scope = tf.name_scope(scope_name)
+ with tf.name_scope('box_net') as scope_name:
+ self._box_name_scope = tf.name_scope(scope_name)
+ self._build_class_net_layers(norm_activation)
+ self._build_box_net_layers(norm_activation)
+
+ def _class_net_batch_norm_name(self, i, level):
+ return 'class-%d-%d' % (i, level)
+
+ def _box_net_batch_norm_name(self, i, level):
+ return 'box-%d-%d' % (i, level)
+
+ def _build_class_net_layers(self, norm_activation):
+ """Build re-usable layers for class prediction network."""
+ if self._use_separable_conv:
+ self._class_predict = tf.keras.layers.SeparableConv2D(
+ self._num_classes * self._anchors_per_location,
+ kernel_size=(3, 3),
+ bias_initializer=tf.constant_initializer(-np.log((1 - 0.01) / 0.01)),
+ padding='same',
+ name='class-predict')
+ else:
+ self._class_predict = tf.keras.layers.Conv2D(
+ self._num_classes * self._anchors_per_location,
+ kernel_size=(3, 3),
+ bias_initializer=tf.constant_initializer(-np.log((1 - 0.01) / 0.01)),
+ kernel_initializer=tf.keras.initializers.RandomNormal(stddev=1e-5),
+ padding='same',
+ name='class-predict')
+ self._class_conv = []
+ self._class_norm_activation = {}
+ for i in range(self._num_convs):
+ if self._use_separable_conv:
+ self._class_conv.append(
+ tf.keras.layers.SeparableConv2D(
+ self._num_filters,
+ kernel_size=(3, 3),
+ bias_initializer=tf.zeros_initializer(),
+ activation=None,
+ padding='same',
+ name='class-' + str(i)))
+ else:
+ self._class_conv.append(
+ tf.keras.layers.Conv2D(
+ self._num_filters,
+ kernel_size=(3, 3),
+ bias_initializer=tf.zeros_initializer(),
+ kernel_initializer=tf.keras.initializers.RandomNormal(
+ stddev=0.01),
+ activation=None,
+ padding='same',
+ name='class-' + str(i)))
+ for level in range(self._min_level, self._max_level + 1):
+ name = self._class_net_batch_norm_name(i, level)
+ self._class_norm_activation[name] = norm_activation(name=name)
+
+ def _build_box_net_layers(self, norm_activation):
+ """Build re-usable layers for box prediction network."""
+ if self._use_separable_conv:
+ self._box_predict = tf.keras.layers.SeparableConv2D(
+ 4 * self._anchors_per_location,
+ kernel_size=(3, 3),
+ bias_initializer=tf.zeros_initializer(),
+ padding='same',
+ name='box-predict')
+ else:
+ self._box_predict = tf.keras.layers.Conv2D(
+ 4 * self._anchors_per_location,
+ kernel_size=(3, 3),
+ bias_initializer=tf.zeros_initializer(),
+ kernel_initializer=tf.keras.initializers.RandomNormal(stddev=1e-5),
+ padding='same',
+ name='box-predict')
+ self._box_conv = []
+ self._box_norm_activation = {}
+ for i in range(self._num_convs):
+ if self._use_separable_conv:
+ self._box_conv.append(
+ tf.keras.layers.SeparableConv2D(
+ self._num_filters,
+ kernel_size=(3, 3),
+ activation=None,
+ bias_initializer=tf.zeros_initializer(),
+ padding='same',
+ name='box-' + str(i)))
+ else:
+ self._box_conv.append(
+ tf.keras.layers.Conv2D(
+ self._num_filters,
+ kernel_size=(3, 3),
+ activation=None,
+ bias_initializer=tf.zeros_initializer(),
+ kernel_initializer=tf.keras.initializers.RandomNormal(
+ stddev=0.01),
+ padding='same',
+ name='box-' + str(i)))
+ for level in range(self._min_level, self._max_level + 1):
+ name = self._box_net_batch_norm_name(i, level)
+ self._box_norm_activation[name] = norm_activation(name=name)
+
+ def __call__(self, fpn_features, is_training=None):
+ """Returns outputs of RetinaNet head."""
+ class_outputs = {}
+ box_outputs = {}
+ with backend.get_graph().as_default(), tf.name_scope('retinanet_head'):
+ for level in range(self._min_level, self._max_level + 1):
+ features = fpn_features[level]
+
+ class_outputs[level] = self.class_net(
+ features, level, is_training=is_training)
+ box_outputs[level] = self.box_net(
+ features, level, is_training=is_training)
+ return class_outputs, box_outputs
+
+ def class_net(self, features, level, is_training):
+ """Class prediction network for RetinaNet."""
+ with self._class_name_scope:
+ for i in range(self._num_convs):
+ features = self._class_conv[i](features)
+ # The convolution layers in the class net are shared among all levels,
+        # but each level has its own batch normalization to capture the statistical
+ # difference among different levels.
+ name = self._class_net_batch_norm_name(i, level)
+ features = self._class_norm_activation[name](
+ features, is_training=is_training)
+
+ classes = self._class_predict(features)
+ return classes
+
+ def box_net(self, features, level, is_training=None):
+ """Box regression network for RetinaNet."""
+ with self._box_name_scope:
+ for i in range(self._num_convs):
+ features = self._box_conv[i](features)
+ # The convolution layers in the box net are shared among all levels, but
+        # each level has its own batch normalization to capture the statistical
+ # difference among different levels.
+ name = self._box_net_batch_norm_name(i, level)
+ features = self._box_norm_activation[name](
+ features, is_training=is_training)
+
+ boxes = self._box_predict(features)
+ return boxes
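+
+  # Output shapes (illustrative): for a level-l feature map of size H_l x W_l,
+  # class_net returns [batch, H_l, W_l, num_classes * anchors_per_location]
+  # and box_net returns [batch, H_l, W_l, 4 * anchors_per_location]. The conv
+  # weights are shared across levels while the batch norms are per level.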
+
+
+# TODO(yeqing): Refactor this class when it is ready for var_scope reuse.
+class ShapemaskPriorHead(object):
+ """ShapeMask Prior head."""
+
+ def __init__(self,
+ num_classes,
+ num_downsample_channels,
+ mask_crop_size,
+ use_category_for_mask,
+ shape_prior_path):
+ """Initialize params to build RetinaNet head.
+
+ Args:
+ num_classes: Number of output classes.
+ num_downsample_channels: number of channels in mask branch.
+ mask_crop_size: feature crop size.
+ use_category_for_mask: use class information in mask branch.
+ shape_prior_path: the path to load shape priors.
+ """
+ self._mask_num_classes = num_classes if use_category_for_mask else 1
+ self._num_downsample_channels = num_downsample_channels
+ self._mask_crop_size = mask_crop_size
+ self._shape_prior_path = shape_prior_path
+ self._use_category_for_mask = use_category_for_mask
+
+ self._shape_prior_fc = tf.keras.layers.Dense(
+ self._num_downsample_channels, name='shape-prior-fc')
+
+ def __call__(self, fpn_features, boxes, outer_boxes, classes, is_training):
+ """Generate the detection priors from the box detections and FPN features.
+
+    This corresponds to Fig. 4 of the ShapeMask paper at
+ https://arxiv.org/pdf/1904.03239.pdf
+
+ Args:
+ fpn_features: a dictionary of FPN features.
+ boxes: a float tensor of shape [batch_size, num_instances, 4]
+ representing the tight gt boxes from dataloader/detection.
+ outer_boxes: a float tensor of shape [batch_size, num_instances, 4]
+ representing the loose gt boxes from dataloader/detection.
+      classes: an int Tensor of shape [batch_size, num_instances]
+ of instance classes.
+ is_training: training mode or not.
+
+ Returns:
+ instance_features: a float Tensor of shape [batch_size * num_instances,
+ mask_crop_size, mask_crop_size, num_downsample_channels]. This is the
+ instance feature crop.
+ detection_priors: A float Tensor of shape [batch_size * num_instances,
+ mask_size, mask_size, 1].
+ """
+ with backend.get_graph().as_default(), tf.name_scope('prior_mask'):
+ batch_size, num_instances, _ = boxes.get_shape().as_list()
+ outer_boxes = tf.cast(outer_boxes, tf.float32)
+ boxes = tf.cast(boxes, tf.float32)
+ instance_features = spatial_transform_ops.multilevel_crop_and_resize(
+ fpn_features, outer_boxes, output_size=self._mask_crop_size)
+ instance_features = self._shape_prior_fc(instance_features)
+
+ shape_priors = self._get_priors()
+
+ # Get uniform priors for each outer box.
+ uniform_priors = tf.ones([batch_size, num_instances, self._mask_crop_size,
+ self._mask_crop_size])
+ uniform_priors = spatial_transform_ops.crop_mask_in_target_box(
+ uniform_priors, boxes, outer_boxes, self._mask_crop_size)
+
+ # Classify shape priors using uniform priors + instance features.
+ prior_distribution = self._classify_shape_priors(
+ tf.cast(instance_features, tf.float32), uniform_priors, classes)
+
+ instance_priors = tf.gather(shape_priors, classes)
+ instance_priors *= tf.expand_dims(tf.expand_dims(
+ tf.cast(prior_distribution, tf.float32), axis=-1), axis=-1)
+ instance_priors = tf.reduce_sum(instance_priors, axis=2)
+ detection_priors = spatial_transform_ops.crop_mask_in_target_box(
+ instance_priors, boxes, outer_boxes, self._mask_crop_size)
+
+ return instance_features, detection_priors
+
+ def _get_priors(self):
+ """Load shape priors from file."""
+    # Loads class-specific or class-agnostic shape priors.
+ if self._shape_prior_path:
+ # Priors are loaded into shape [mask_num_classes, num_clusters, 32, 32].
+ priors = np.load(tf.io.gfile.GFile(self._shape_prior_path, 'rb'))
+ priors = tf.convert_to_tensor(priors, dtype=tf.float32)
+ self._num_clusters = priors.get_shape().as_list()[1]
+ else:
+      # If no prior path is given, do not use priors, i.e., priors equal a
+      # uniform empty 32x32 patch.
+ self._num_clusters = 1
+ priors = tf.zeros([self._mask_num_classes, self._num_clusters,
+ self._mask_crop_size, self._mask_crop_size])
+ return priors
+
+ def _classify_shape_priors(self, features, uniform_priors, classes):
+ """Classify the uniform prior by predicting the shape modes.
+
+ Classify the object crop features into K modes of the clusters for each
+ category.
+
+ Args:
+ features: A float Tensor of shape [batch_size, num_instances,
+ mask_size, mask_size, num_channels].
+ uniform_priors: A float Tensor of shape [batch_size, num_instances,
+ mask_size, mask_size] representing the uniform detection priors.
+      classes: An int Tensor of shape [batch_size, num_instances]
+ of detection class ids.
+
+ Returns:
+ prior_distribution: A float Tensor of shape
+ [batch_size, num_instances, num_clusters] representing the classifier
+ output probability over all possible shapes.
+ """
+
+ batch_size, num_instances, _, _, _ = features.get_shape().as_list()
+ features *= tf.expand_dims(uniform_priors, axis=-1)
+ # Reduce spatial dimension of features. The features have shape
+ # [batch_size, num_instances, num_channels].
+ features = tf.reduce_mean(features, axis=(2, 3))
+ logits = tf.keras.layers.Dense(
+ self._mask_num_classes * self._num_clusters,
+ kernel_initializer=tf.random_normal_initializer(stddev=0.01))(features)
+ logits = tf.reshape(logits,
+ [batch_size, num_instances,
+ self._mask_num_classes, self._num_clusters])
+ if self._use_category_for_mask:
+ logits = tf.gather(logits, tf.expand_dims(classes, axis=-1), batch_dims=2)
+ logits = tf.squeeze(logits, axis=2)
+ else:
+ logits = logits[:, :, 0, :]
+
+ distribution = tf.nn.softmax(logits, name='shape_prior_weights')
+ return distribution
+
+
+class ShapemaskCoarsemaskHead(object):
+ """ShapemaskCoarsemaskHead head."""
+
+ def __init__(self,
+ num_classes,
+ num_downsample_channels,
+ mask_crop_size,
+ use_category_for_mask,
+ num_convs,
+ norm_activation=nn_ops.norm_activation_builder()):
+ """Initialize params to build ShapeMask coarse and fine prediction head.
+
+ Args:
+ num_classes: `int` number of mask classification categories.
+ num_downsample_channels: `int` number of filters at mask head.
+ mask_crop_size: feature crop size.
+ use_category_for_mask: use class information in mask branch.
+      num_convs: `int` number of stacked convolutions before the last
+        prediction layer.
+ norm_activation: an operation that includes a normalization layer
+ followed by an optional activation layer.
+ """
+ self._mask_num_classes = num_classes if use_category_for_mask else 1
+ self._use_category_for_mask = use_category_for_mask
+ self._num_downsample_channels = num_downsample_channels
+ self._mask_crop_size = mask_crop_size
+ self._num_convs = num_convs
+ self._norm_activation = norm_activation
+
+ self._coarse_mask_fc = tf.keras.layers.Dense(
+ self._num_downsample_channels, name='coarse-mask-fc')
+
+ self._class_conv = []
+ self._class_norm_activation = []
+
+ for i in range(self._num_convs):
+ self._class_conv.append(tf.keras.layers.Conv2D(
+ self._num_downsample_channels,
+ kernel_size=(3, 3),
+ bias_initializer=tf.zeros_initializer(),
+ kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
+ padding='same',
+ name='coarse-mask-class-%d' % i))
+
+ self._class_norm_activation.append(
+ norm_activation(name='coarse-mask-class-%d-bn' % i))
+
+ self._class_predict = tf.keras.layers.Conv2D(
+ self._mask_num_classes,
+ kernel_size=(1, 1),
+ # Focal loss bias initialization to have foreground 0.01 probability.
+ bias_initializer=tf.constant_initializer(-np.log((1 - 0.01) / 0.01)),
+ kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
+ padding='same',
+ name='coarse-mask-class-predict')
+
+ def __call__(self, features, detection_priors, classes, is_training):
+ """Generate instance masks from FPN features and detection priors.
+
+    This corresponds to Figs. 5-6 of the ShapeMask paper at
+ https://arxiv.org/pdf/1904.03239.pdf
+
+ Args:
+ features: a float Tensor of shape [batch_size, num_instances,
+ mask_crop_size, mask_crop_size, num_downsample_channels]. This is the
+ instance feature crop.
+ detection_priors: a float Tensor of shape [batch_size, num_instances,
+ mask_crop_size, mask_crop_size, 1]. This is the detection prior for
+ the instance.
+      classes: an int Tensor of shape [batch_size, num_instances]
+ of instance classes.
+ is_training: a bool indicating whether in training mode.
+
+ Returns:
+ mask_outputs: instance mask prediction as a float Tensor of shape
+ [batch_size, num_instances, mask_size, mask_size].
+ """
+ with backend.get_graph().as_default(), tf.name_scope('coarse_mask'):
+ # Transform detection priors to have the same dimension as features.
+ detection_priors = tf.expand_dims(detection_priors, axis=-1)
+ detection_priors = self._coarse_mask_fc(detection_priors)
+
+ features += detection_priors
+ mask_logits = self.decoder_net(features, is_training)
+      # Gather the logits for the right input class.
+ if self._use_category_for_mask:
+ mask_logits = tf.transpose(mask_logits, [0, 1, 4, 2, 3])
+ mask_logits = tf.gather(mask_logits, tf.expand_dims(classes, -1),
+ batch_dims=2)
+ mask_logits = tf.squeeze(mask_logits, axis=2)
+ else:
+ mask_logits = mask_logits[..., 0]
+
+ return mask_logits
+
+ def decoder_net(self, features, is_training=False):
+ """Coarse mask decoder network architecture.
+
+ Args:
+      features: A tensor of size [batch_size, num_instances, height, width,
+        num_channels].
+ is_training: Whether batch_norm layers are in training mode.
+
+ Returns:
+      mask_logits: A tensor of size [batch_size, num_instances, height, width,
+        mask_num_classes] containing the coarse mask logits.
+ """
+ (batch_size, num_instances, height, width,
+ num_channels) = features.get_shape().as_list()
+ features = tf.reshape(features, [batch_size * num_instances, height, width,
+ num_channels])
+ for i in range(self._num_convs):
+ features = self._class_conv[i](features)
+ features = self._class_norm_activation[i](features,
+ is_training=is_training)
+
+ mask_logits = self._class_predict(features)
+ mask_logits = tf.reshape(mask_logits, [batch_size, num_instances, height,
+ width, self._mask_num_classes])
+ return mask_logits
+
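The class-conditional selection above (transpose, `tf.gather` with `batch_dims=2`, squeeze) is the least obvious step of `__call__`. The throwaway sketch below, using toy shapes that are not from this diff, shows how it picks one class-specific mask per instance.

import tensorflow as tf

# Toy shapes (hypothetical): 1 image, 2 instances, 4x4 masks, 3 classes.
batch, num_instances, size, num_classes = 1, 2, 4, 3
mask_logits = tf.random.normal([batch, num_instances, size, size, num_classes])
classes = tf.constant([[2, 0]])  # class of each instance

# Move the class axis next to the instance axis, then gather one
# class-specific mask per (batch, instance) pair via batch_dims=2.
logits = tf.transpose(mask_logits, [0, 1, 4, 2, 3])            # [b, n, c, h, w]
logits = tf.gather(logits, tf.expand_dims(classes, -1), batch_dims=2)
logits = tf.squeeze(logits, axis=2)                            # [b, n, h, w]
print(logits.shape)  # (1, 2, 4, 4)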
+
+class ShapemaskFinemaskHead(object):
+ """ShapemaskFinemaskHead head."""
+
+ def __init__(self,
+ num_classes,
+ num_downsample_channels,
+ mask_crop_size,
+ use_category_for_mask,
+ num_convs,
+ upsample_factor,
+ norm_activation=nn_ops.norm_activation_builder()):
+ """Initialize params to build ShapeMask coarse and fine prediction head.
+
+ Args:
+ num_classes: `int` number of mask classification categories.
+ num_downsample_channels: `int` number of filters at mask head.
+ mask_crop_size: feature crop size.
+ use_category_for_mask: use class information in mask branch.
+      num_convs: `int` number of stacked convolutions before the last
+        prediction layer.
+      upsample_factor: `int` upsampling factor for the fine mask.
+      norm_activation: an operation that includes a batch normalization layer
+        followed by an optional ReLU layer.
+ """
+ self._use_category_for_mask = use_category_for_mask
+ self._mask_num_classes = num_classes if use_category_for_mask else 1
+ self._num_downsample_channels = num_downsample_channels
+ self._mask_crop_size = mask_crop_size
+ self._num_convs = num_convs
+ self.up_sample_factor = upsample_factor
+
+ self._fine_mask_fc = tf.keras.layers.Dense(
+ self._num_downsample_channels, name='fine-mask-fc')
+
+ self._upsample_conv = tf.keras.layers.Conv2DTranspose(
+ self._num_downsample_channels,
+ (self.up_sample_factor, self.up_sample_factor),
+ (self.up_sample_factor, self.up_sample_factor),
+ name='fine-mask-conv2d-tran')
+
+ self._fine_class_conv = []
+ self._fine_class_bn = []
+ for i in range(self._num_convs):
+ self._fine_class_conv.append(
+ tf.keras.layers.Conv2D(
+ self._num_downsample_channels,
+ kernel_size=(3, 3),
+ bias_initializer=tf.zeros_initializer(),
+ kernel_initializer=tf.keras.initializers.RandomNormal(
+ stddev=0.01),
+ activation=None,
+ padding='same',
+ name='fine-mask-class-%d' % i))
+ self._fine_class_bn.append(norm_activation(
+ name='fine-mask-class-%d-bn' % i))
+
+ self._class_predict_conv = tf.keras.layers.Conv2D(
+ self._mask_num_classes,
+ kernel_size=(1, 1),
+ # Focal loss bias initialization to have foreground 0.01 probability.
+ bias_initializer=tf.constant_initializer(-np.log((1 - 0.01) / 0.01)),
+ kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
+ padding='same',
+ name='fine-mask-class-predict')
+
+ def __call__(self, features, mask_logits, classes, is_training):
+ """Generate instance masks from FPN features and detection priors.
+
+ This corresponds to the Fig. 5-6 of the ShapeMask paper at
+ https://arxiv.org/pdf/1904.03239.pdf
+
+ Args:
+ features: a float Tensor of shape
+ [batch_size, num_instances, mask_crop_size, mask_crop_size,
+ num_downsample_channels]. This is the instance feature crop.
+ mask_logits: a float Tensor of shape
+ [batch_size, num_instances, mask_crop_size, mask_crop_size] indicating
+ predicted mask logits.
+      classes: an int Tensor of shape [batch_size, num_instances]
+        of instance classes.
+ is_training: a bool indicating whether in training mode.
+
+ Returns:
+ mask_outputs: instance mask prediction as a float Tensor of shape
+ [batch_size, num_instances, mask_size, mask_size].
+ """
+ # Extract the foreground mean features
+ # with tf.variable_scope('fine_mask', reuse=tf.AUTO_REUSE):
+ with backend.get_graph().as_default(), tf.name_scope('fine_mask'):
+ mask_probs = tf.nn.sigmoid(mask_logits)
+ # Compute instance embedding for hard average.
+ binary_mask = tf.cast(tf.greater(mask_probs, 0.5), features.dtype)
+ instance_embedding = tf.reduce_sum(
+ features * tf.expand_dims(binary_mask, axis=-1), axis=(2, 3))
+ instance_embedding /= tf.expand_dims(
+ tf.reduce_sum(binary_mask, axis=(2, 3)) + 1e-20, axis=-1)
+ # Take the difference between crop features and mean instance features.
+ features -= tf.expand_dims(
+ tf.expand_dims(instance_embedding, axis=2), axis=2)
+
+ features += self._fine_mask_fc(tf.expand_dims(mask_probs, axis=-1))
+
+ # Decoder to generate upsampled segmentation mask.
+ mask_logits = self.decoder_net(features, is_training)
+ if self._use_category_for_mask:
+ mask_logits = tf.transpose(mask_logits, [0, 1, 4, 2, 3])
+ mask_logits = tf.gather(mask_logits,
+ tf.expand_dims(classes, -1), batch_dims=2)
+ mask_logits = tf.squeeze(mask_logits, axis=2)
+ else:
+ mask_logits = mask_logits[..., 0]
+
+ return mask_logits
+
+ def decoder_net(self, features, is_training=False):
+ """Fine mask decoder network architecture.
+
+ Args:
+ features: A tensor of size [batch, height_in, width_in, channels_in].
+ is_training: Whether batch_norm layers are in training mode.
+
+ Returns:
+      images: A feature tensor of size [batch, output_size, output_size,
+        num_channels], where output_size is self.up_sample_factor times
+        that of the input.
+ """
+ (batch_size, num_instances, height, width,
+ num_channels) = features.get_shape().as_list()
+ features = tf.reshape(features, [batch_size * num_instances, height, width,
+ num_channels])
+ for i in range(self._num_convs):
+ features = self._fine_class_conv[i](features)
+ features = self._fine_class_bn[i](features, is_training=is_training)
+
+ if self.up_sample_factor > 1:
+ features = self._upsample_conv(features)
+
+ # Predict per-class instance masks.
+ mask_logits = self._class_predict_conv(features)
+
+ mask_logits = tf.reshape(mask_logits,
+ [batch_size, num_instances,
+ height * self.up_sample_factor,
+ width * self.up_sample_factor,
+ self._mask_num_classes])
+ return mask_logits
diff --git a/models/official/vision/detection/modeling/architecture/identity.py b/models/official/vision/detection/modeling/architecture/identity.py
new file mode 100644
index 0000000000000000000000000000000000000000..acc90c4d5efddcac50eb95b1229c3c5500917445
--- /dev/null
+++ b/models/official/vision/detection/modeling/architecture/identity.py
@@ -0,0 +1,28 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Identity Fn that forwards the input features."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+class Identity(object):
+ """Identity function that forwards the input features."""
+
+ def __call__(self, features, is_training=False):
+ """Only forwards the input features."""
+ return features
+
diff --git a/models/official/vision/detection/modeling/architecture/nn_ops.py b/models/official/vision/detection/modeling/architecture/nn_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b3617d6c5b23dd31a9f891985dcf8361ff1e177
--- /dev/null
+++ b/models/official/vision/detection/modeling/architecture/nn_ops.py
@@ -0,0 +1,108 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Neural network operations commonly shared by the architectures."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import tensorflow as tf
+
+
+class NormActivation(tf.keras.layers.Layer):
+ """Combined Normalization and Activation layers."""
+
+ def __init__(self,
+ momentum=0.997,
+ epsilon=1e-4,
+ trainable=True,
+ init_zero=False,
+ use_activation=True,
+ activation='relu',
+ fused=True,
+ name=None):
+ """A class to construct layers for a batch normalization followed by a ReLU.
+
+ Args:
+      momentum: momentum for the moving average.
+      epsilon: small float added to variance to avoid dividing by zero.
+      trainable: `bool`, if True also add variables to the graph collection
+        GraphKeys.TRAINABLE_VARIABLES. If False, freeze the batch
+        normalization layer.
+      init_zero: `bool` if True, initializes the scale parameter of batch
+        normalization with 0. If False, initializes it with 1.
+      use_activation: `bool`, whether to add the optional activation layer
+        after the batch normalization layer.
+      activation: 'string', the type of the activation layer. Currently
+        supports `relu` and `swish`.
+      fused: `bool` fused option in batch normalization.
+      name: `str` name for the operation.
+ """
+ super(NormActivation, self).__init__(trainable=trainable)
+ if init_zero:
+ gamma_initializer = tf.keras.initializers.Zeros()
+ else:
+ gamma_initializer = tf.keras.initializers.Ones()
+ self._normalization_op = tf.keras.layers.BatchNormalization(
+ momentum=momentum,
+ epsilon=epsilon,
+ center=True,
+ scale=True,
+ trainable=trainable,
+ fused=fused,
+ gamma_initializer=gamma_initializer,
+ name=name)
+ self._use_activation = use_activation
+ if activation == 'relu':
+ self._activation_op = tf.nn.relu
+ elif activation == 'swish':
+ self._activation_op = tf.nn.swish
+ else:
+ raise ValueError('Unsupported activation `{}`.'.format(activation))
+
+ def __call__(self, inputs, is_training=None):
+ """Builds the normalization layer followed by an optional activation layer.
+
+ Args:
+ inputs: `Tensor` of shape `[batch, channels, ...]`.
+      is_training: `bool`, if True, the model is in training mode.
+
+ Returns:
+ A normalized `Tensor` with the same `data_format`.
+ """
+    # We need to keep training=None by default, so that it can be inherited
+    # from keras.Model.training.
+ if is_training and self.trainable:
+ is_training = True
+ inputs = self._normalization_op(inputs, training=is_training)
+
+ if self._use_activation:
+ inputs = self._activation_op(inputs)
+ return inputs
+
+
+def norm_activation_builder(momentum=0.997,
+ epsilon=1e-4,
+ trainable=True,
+ activation='relu',
+ **kwargs):
+ return functools.partial(
+ NormActivation,
+ momentum=momentum,
+ epsilon=epsilon,
+ trainable=trainable,
+ activation=activation,
+ **kwargs)
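A usage sketch (not part of this diff, and assuming the models/ checkout is on sys.path): the builder returns a functools.partial over NormActivation, so call sites can stamp out identically configured BN+activation layers and only supply a name.

import tensorflow as tf
from official.vision.detection.modeling.architecture import nn_ops

# Every call of the builder's result creates a fresh BatchNorm(+ReLU) layer
# sharing the same hyperparameters.
norm_activation = nn_ops.norm_activation_builder(momentum=0.99, activation='relu')

bn_relu = norm_activation(name='block1-bn')   # instantiate one layer
x = tf.random.normal([2, 8, 8, 16])
y = bn_relu(x, is_training=True)              # normalize, then activate
print(y.shape)                                # (2, 8, 8, 16)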
diff --git a/models/official/vision/detection/modeling/architecture/resnet.py b/models/official/vision/detection/modeling/architecture/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..abbc7213ea971f0cb014d770e7e0c1707855fb08
--- /dev/null
+++ b/models/official/vision/detection/modeling/architecture/resnet.py
@@ -0,0 +1,309 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains definitions for the post-activation form of Residual Networks.
+
+Residual networks (ResNets) were proposed in:
+[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+ Deep Residual Learning for Image Recognition. arXiv:1512.03385
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import logging
+import tensorflow as tf
+from tensorflow.python.keras import backend
+from official.vision.detection.modeling.architecture import nn_ops
+
+# TODO(b/140112644): Refactor the code with Keras style, i.e. build and call.
+class Resnet(object):
+ """Class to build ResNet family model."""
+
+ def __init__(self,
+ resnet_depth,
+ activation='relu',
+ norm_activation=nn_ops.norm_activation_builder(
+ activation='relu'),
+ data_format='channels_last'):
+ """ResNet initialization function.
+
+ Args:
+      resnet_depth: `int` depth of ResNet backbone model.
+      activation: `str` name of the activation function, `relu` or `swish`.
+      norm_activation: an operation that includes a normalization layer
+        followed by an optional activation layer.
+      data_format: `str` either "channels_first" for `[batch, channels, height,
+        width]` or "channels_last" for `[batch, height, width, channels]`.
+ """
+ self._resnet_depth = resnet_depth
+ if activation == 'relu':
+ self._activation_op = tf.nn.relu
+ elif activation == 'swish':
+ self._activation_op = tf.nn.swish
+ else:
+ raise ValueError('Unsupported activation `{}`.'.format(activation))
+ self._norm_activation = norm_activation
+ self._data_format = data_format
+
+ model_params = {
+ 10: {'block': self.residual_block, 'layers': [1, 1, 1, 1]},
+ 18: {'block': self.residual_block, 'layers': [2, 2, 2, 2]},
+ 34: {'block': self.residual_block, 'layers': [3, 4, 6, 3]},
+ 50: {'block': self.bottleneck_block, 'layers': [3, 4, 6, 3]},
+ 101: {'block': self.bottleneck_block, 'layers': [3, 4, 23, 3]},
+ 152: {'block': self.bottleneck_block, 'layers': [3, 8, 36, 3]},
+ 200: {'block': self.bottleneck_block, 'layers': [3, 24, 36, 3]}
+ }
+
+ if resnet_depth not in model_params:
+ valid_resnet_depths = ', '.join(
+ [str(depth) for depth in sorted(model_params.keys())])
+ raise ValueError(
+ 'The resnet_depth should be in [%s]. Not a valid resnet_depth:'%(
+ valid_resnet_depths), self._resnet_depth)
+ params = model_params[resnet_depth]
+ self._resnet_fn = self.resnet_v1_generator(
+ params['block'], params['layers'])
+
+ def __call__(self, inputs, is_training=None):
+ """Returns the ResNet model for a given size and number of output classes.
+
+ Args:
+      inputs: a `Tensor` with shape [batch_size, height, width, 3] representing
+ a batch of images.
+ is_training: `bool` if True, the model is in training mode.
+
+ Returns:
+ a `dict` containing `int` keys for continuous feature levels [2, 3, 4, 5].
+ The values are corresponding feature hierarchy in ResNet with shape
+ [batch_size, height_l, width_l, num_filters].
+ """
+ with backend.get_graph().as_default():
+ with tf.name_scope('resnet%s' % self._resnet_depth):
+ return self._resnet_fn(inputs, is_training)
+
+ def fixed_padding(self, inputs, kernel_size):
+ """Pads the input along the spatial dimensions independently of input size.
+
+ Args:
+ inputs: `Tensor` of size `[batch, channels, height, width]` or
+ `[batch, height, width, channels]` depending on `data_format`.
+      kernel_size: `int` kernel size to be used for `conv2d` or `max_pool2d`
+        operations. Should be a positive integer.
+
+ Returns:
+ A padded `Tensor` of the same `data_format` with size either intact
+ (if `kernel_size == 1`) or padded (if `kernel_size > 1`).
+ """
+ pad_total = kernel_size - 1
+ pad_beg = pad_total // 2
+ pad_end = pad_total - pad_beg
+ if self._data_format == 'channels_first':
+ padded_inputs = tf.pad(
+ tensor=inputs,
+ paddings=[[0, 0], [0, 0], [pad_beg, pad_end], [pad_beg, pad_end]])
+ else:
+ padded_inputs = tf.pad(
+ tensor=inputs,
+ paddings=[[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
+
+ return padded_inputs
+
+ def conv2d_fixed_padding(self, inputs, filters, kernel_size, strides):
+ """Strided 2-D convolution with explicit padding.
+
+ The padding is consistent and is based only on `kernel_size`, not on the
+ dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).
+
+ Args:
+ inputs: `Tensor` of size `[batch, channels, height_in, width_in]`.
+ filters: `int` number of filters in the convolution.
+ kernel_size: `int` size of the kernel to be used in the convolution.
+ strides: `int` strides of the convolution.
+
+ Returns:
+ A `Tensor` of shape `[batch, filters, height_out, width_out]`.
+ """
+ if strides > 1:
+ inputs = self.fixed_padding(inputs, kernel_size)
+
+ return tf.keras.layers.Conv2D(
+ filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=('SAME' if strides == 1 else 'VALID'),
+ use_bias=False,
+ kernel_initializer=tf.initializers.VarianceScaling(),
+ data_format=self._data_format)(
+ inputs=inputs)
+
+ def residual_block(self,
+ inputs,
+ filters,
+ strides,
+ use_projection=False,
+ is_training=None):
+ """Standard building block for residual networks with BN after convolutions.
+
+ Args:
+ inputs: `Tensor` of size `[batch, channels, height, width]`.
+      filters: `int` number of filters for the two convolutions in the block.
+ strides: `int` block stride. If greater than 1, this block will ultimately
+ downsample the input.
+ use_projection: `bool` for whether this block should use a projection
+ shortcut (versus the default identity shortcut). This is usually
+ `True` for the first block of a block group, which may change the
+ number of filters and the resolution.
+      is_training: `bool` if True, the model is in training mode.
+
+    Returns:
+ The output `Tensor` of the block.
+ """
+ shortcut = inputs
+ if use_projection:
+ # Projection shortcut in first layer to match filters and strides
+ shortcut = self.conv2d_fixed_padding(
+ inputs=inputs, filters=filters, kernel_size=1, strides=strides)
+ shortcut = self._norm_activation(use_activation=False)(
+ shortcut, is_training=is_training)
+
+ inputs = self.conv2d_fixed_padding(
+ inputs=inputs, filters=filters, kernel_size=3, strides=strides)
+ inputs = self._norm_activation()(inputs, is_training=is_training)
+
+ inputs = self.conv2d_fixed_padding(
+ inputs=inputs, filters=filters, kernel_size=3, strides=1)
+ inputs = self._norm_activation(use_activation=False, init_zero=True)(
+ inputs, is_training=is_training)
+
+ return self._activation_op(inputs + shortcut)
+
+ def bottleneck_block(self,
+ inputs,
+ filters,
+ strides,
+ use_projection=False,
+ is_training=None):
+ """Bottleneck block variant for residual networks with BN after convolutions.
+
+ Args:
+ inputs: `Tensor` of size `[batch, channels, height, width]`.
+ filters: `int` number of filters for the first two convolutions. Note that
+ the third and final convolution will use 4 times as many filters.
+ strides: `int` block stride. If greater than 1, this block will ultimately
+ downsample the input.
+ use_projection: `bool` for whether this block should use a projection
+ shortcut (versus the default identity shortcut). This is usually
+ `True` for the first block of a block group, which may change the
+ number of filters and the resolution.
+ is_training: `bool` if True, the model is in training mode.
+
+ Returns:
+ The output `Tensor` of the block.
+ """
+ shortcut = inputs
+ if use_projection:
+ # Projection shortcut only in first block within a group. Bottleneck
+ # blocks end with 4 times the number of filters.
+ filters_out = 4 * filters
+ shortcut = self.conv2d_fixed_padding(
+ inputs=inputs, filters=filters_out, kernel_size=1, strides=strides)
+ shortcut = self._norm_activation(use_activation=False)(
+ shortcut, is_training=is_training)
+
+ inputs = self.conv2d_fixed_padding(
+ inputs=inputs, filters=filters, kernel_size=1, strides=1)
+ inputs = self._norm_activation()(inputs, is_training=is_training)
+
+ inputs = self.conv2d_fixed_padding(
+ inputs=inputs, filters=filters, kernel_size=3, strides=strides)
+ inputs = self._norm_activation()(inputs, is_training=is_training)
+
+ inputs = self.conv2d_fixed_padding(
+ inputs=inputs, filters=4 * filters, kernel_size=1, strides=1)
+ inputs = self._norm_activation(use_activation=False, init_zero=True)(
+ inputs, is_training=is_training)
+
+ return self._activation_op(inputs + shortcut)
+
+ def block_group(self, inputs, filters, block_fn, blocks, strides, name,
+ is_training):
+ """Creates one group of blocks for the ResNet model.
+
+ Args:
+ inputs: `Tensor` of size `[batch, channels, height, width]`.
+ filters: `int` number of filters for the first convolution of the layer.
+      block_fn: `function` for the block to use within the model.
+      blocks: `int` number of blocks contained in the layer.
+      strides: `int` stride to use for the first convolution of the layer. If
+        greater than 1, this layer will downsample the input.
+      name: `str` name for the Tensor output of the block layer.
+ is_training: `bool` if True, the model is in training mode.
+
+ Returns:
+ The output `Tensor` of the block layer.
+ """
+ # Only the first block per block_group uses projection shortcut and strides.
+ inputs = block_fn(inputs, filters, strides, use_projection=True,
+ is_training=is_training)
+
+ for _ in range(1, blocks):
+ inputs = block_fn(inputs, filters, 1, is_training=is_training)
+
+ return tf.identity(inputs, name)
+
+ def resnet_v1_generator(self, block_fn, layers):
+ """Generator for ResNet v1 models.
+
+ Args:
+ block_fn: `function` for the block to use within the model. Either
+ `residual_block` or `bottleneck_block`.
+ layers: list of 4 `int`s denoting the number of blocks to include in each
+ of the 4 block groups. Each group consists of blocks that take inputs of
+ the same resolution.
+
+ Returns:
+ Model `function` that takes in `inputs` and `is_training` and returns the
+ output `Tensor` of the ResNet model.
+ """
+
+ def model(inputs, is_training=None):
+ """Creation of the model graph."""
+ inputs = self.conv2d_fixed_padding(
+ inputs=inputs, filters=64, kernel_size=7, strides=2)
+ inputs = tf.identity(inputs, 'initial_conv')
+ inputs = self._norm_activation()(inputs, is_training=is_training)
+
+ inputs = tf.keras.layers.MaxPool2D(
+ pool_size=3, strides=2, padding='SAME',
+ data_format=self._data_format)(
+ inputs)
+ inputs = tf.identity(inputs, 'initial_max_pool')
+
+ c2 = self.block_group(
+ inputs=inputs, filters=64, block_fn=block_fn, blocks=layers[0],
+ strides=1, name='block_group1', is_training=is_training)
+ c3 = self.block_group(
+ inputs=c2, filters=128, block_fn=block_fn, blocks=layers[1],
+ strides=2, name='block_group2', is_training=is_training)
+ c4 = self.block_group(
+ inputs=c3, filters=256, block_fn=block_fn, blocks=layers[2],
+ strides=2, name='block_group3', is_training=is_training)
+ c5 = self.block_group(
+ inputs=c4, filters=512, block_fn=block_fn, blocks=layers[3],
+ strides=2, name='block_group4', is_training=is_training)
+ return {2: c2, 3: c3, 4: c4, 5: c5}
+
+ return model
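A minimal construction sketch, again assuming the models/ checkout is importable and a TensorFlow 2.x release contemporary with this code: the backbone is a callable that maps an image tensor to a dict of feature maps keyed by level 2..5, each level halving the spatial resolution of the previous one.

import tensorflow as tf
from official.vision.detection.modeling.architecture import resnet

backbone = resnet.Resnet(resnet_depth=50)

# Build over a symbolic Keras input; the shapes below assume a 512x512 image.
images = tf.keras.Input(shape=(512, 512, 3))
features = backbone(images, is_training=False)
for level in sorted(features):
  print(level, features[level].shape)
# 2 (None, 128, 128, 256) ... 5 (None, 16, 16, 2048) for ResNet-50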
diff --git a/models/official/vision/detection/modeling/base_model.py b/models/official/vision/detection/modeling/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d18f12f5b7c52ca02334c4c685b70d353de83c5
--- /dev/null
+++ b/models/official/vision/detection/modeling/base_model.py
@@ -0,0 +1,138 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base Model definition."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import functools
+import re
+import tensorflow as tf
+from official.vision.detection.modeling import checkpoint_utils
+from official.vision.detection.modeling import learning_rates
+from official.vision.detection.modeling import optimizers
+
+
+def _make_filter_trainable_variables_fn(frozen_variable_prefix):
+  """Creates a function for filtering trainable variables."""
+
+ def _filter_trainable_variables(variables):
+    """Filters trainable variables.
+
+ Args:
+ variables: a list of tf.Variable to be filtered.
+
+ Returns:
+ filtered_variables: a list of tf.Variable filtered out the frozen ones.
+ """
+    # frozen_variable_prefix: a regex string specifying the prefix pattern of
+    # the frozen variables' names.
+ filtered_variables = [
+ v for v in variables
+ if not frozen_variable_prefix or
+ not re.match(frozen_variable_prefix, v.name)
+ ]
+ return filtered_variables
+
+ return _filter_trainable_variables
+
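As a toy illustration of the prefix filtering above (the variable names are made up), freezing amounts to dropping every variable whose name matches the regex from the trainable list:

import re
import tensorflow as tf

variables = [
    tf.Variable(tf.zeros([3, 3]), name='resnet50/conv1/kernel'),
    tf.Variable(tf.zeros([4]), name='retinanet_head/class-0/bias'),
]
frozen_variable_prefix = r'resnet50/'

trainable = [
    v for v in variables
    if not frozen_variable_prefix
    or not re.match(frozen_variable_prefix, v.name)
]
print([v.name for v in trainable])  # ['retinanet_head/class-0/bias:0']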
+
+class Model(object):
+ """Base class for model function."""
+
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, params):
+ self._use_bfloat16 = params.architecture.use_bfloat16
+
+ if params.architecture.use_bfloat16:
+ policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
+ 'mixed_bfloat16')
+ tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
+
+ # Optimization.
+ self._optimizer_fn = optimizers.OptimizerFactory(params.train.optimizer)
+ self._learning_rate = learning_rates.learning_rate_generator(
+ params.train.total_steps, params.train.learning_rate)
+
+ self._frozen_variable_prefix = params.train.frozen_variable_prefix
+ self._regularization_var_regex = params.train.regularization_variable_regex
+ self._l2_weight_decay = params.train.l2_weight_decay
+
+ # Checkpoint restoration.
+ self._checkpoint = params.train.checkpoint.as_dict()
+
+ # Summary.
+ self._enable_summary = params.enable_summary
+ self._model_dir = params.model_dir
+
+ @abc.abstractmethod
+ def build_outputs(self, inputs, mode):
+ """Build the graph of the forward path."""
+ pass
+
+ @abc.abstractmethod
+ def build_model(self, params, mode):
+ """Build the model object."""
+ pass
+
+ @abc.abstractmethod
+ def build_loss_fn(self):
+    """Build the loss function."""
+ pass
+
+ def post_processing(self, labels, outputs):
+ """Post-processing function."""
+ return labels, outputs
+
+ def model_outputs(self, inputs, mode):
+ """Build the model outputs."""
+ return self.build_outputs(inputs, mode)
+
+ def build_optimizer(self):
+ """Returns train_op to optimize total loss."""
+ # Sets up the optimizer.
+ return self._optimizer_fn(self._learning_rate)
+
+ def make_filter_trainable_variables_fn(self):
+    """Creates a function for filtering trainable variables."""
+ return _make_filter_trainable_variables_fn(self._frozen_variable_prefix)
+
+ def weight_decay_loss(self, trainable_variables):
+ reg_variables = [
+ v for v in trainable_variables
+ if self._regularization_var_regex is None
+ or re.match(self._regularization_var_regex, v.name)
+ ]
+
+ return self._l2_weight_decay * tf.add_n(
+ [tf.nn.l2_loss(v) for v in reg_variables])
+
+ def make_restore_checkpoint_fn(self):
+ """Returns scaffold function to restore parameters from v1 checkpoint."""
+ if 'skip_checkpoint_variables' in self._checkpoint:
+ skip_regex = self._checkpoint['skip_checkpoint_variables']
+ else:
+ skip_regex = None
+ return checkpoint_utils.make_restore_checkpoint_fn(
+ self._checkpoint['path'],
+ prefix=self._checkpoint['prefix'],
+ skip_regex=skip_regex)
+
+ def eval_metrics(self):
+ """Returns tuple of metric function and its inputs for evaluation."""
+ raise NotImplementedError('Unimplemented eval_metrics')
diff --git a/models/official/vision/detection/modeling/checkpoint_utils.py b/models/official/vision/detection/modeling/checkpoint_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bb798396a714cbbc1a36309c99ceaa636a30354
--- /dev/null
+++ b/models/official/vision/detection/modeling/checkpoint_utils.py
@@ -0,0 +1,131 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for loading checkpoints, especially for loading
+TensorFlow 1.x checkpoints into TensorFlow 2.x (Keras) models.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import re
+from absl import logging
+
+import tensorflow as tf
+
+
+def _build_assignment_map(keras_model,
+ prefix='',
+ skip_variables_regex=None,
+ var_to_shape_map=None):
+  """Computes an assignment map for loading older checkpoints into a Keras
+  model. Variable names are remapped from the original TPUEstimator model to
+  the new Keras names.
+
+  Args:
+    keras_model: tf.keras.Model object to provide variables to assign.
+    prefix: prefix in the variable name to be removed for alignment with names
+      in the checkpoint.
+    skip_variables_regex: regular expression to match the names of variables
+      that do not need to be assigned.
+ var_to_shape_map: variable name to shape mapping from the checkpoint.
+
+ Returns:
+ The variable assignment map.
+ """
+ assignment_map = {}
+
+ checkpoint_names = None
+ if var_to_shape_map:
+ checkpoint_names = list(filter(
+ lambda x: not x.endswith('Momentum') and not x.endswith(
+ 'global_step'), var_to_shape_map.keys()))
+
+ for var in keras_model.variables:
+ var_name = var.name
+
+ if skip_variables_regex and re.match(skip_variables_regex, var_name):
+ continue
+ # Trim the index of the variable.
+ if ':' in var_name:
+ var_name = var_name[:var_name.rindex(':')]
+ if var_name.startswith(prefix):
+ var_name = var_name[len(prefix):]
+
+ if not var_to_shape_map:
+ assignment_map[var_name] = var
+ continue
+
+ # Match name with variables in the checkpoint.
+ match_names = list(filter(lambda x: x.endswith(var_name), checkpoint_names))
+ try:
+ if match_names:
+        assert len(match_names) == 1, 'more than one match for {}: {}'.format(
+ var_name, match_names)
+ checkpoint_names.remove(match_names[0])
+ assignment_map[match_names[0]] = var
+ else:
+        logging.info('Variable not found in checkpoint: %s', var_name)
+ except Exception as e:
+ logging.info('Error removing the match_name: %s', match_names)
+ logging.info('Exception: %s', e)
+ raise
+  logging.info('Found %d variables in checkpoint', len(assignment_map))
+ return assignment_map
+
+
+def _get_checkpoint_map(checkpoint_path):
+ reader = tf.train.load_checkpoint(checkpoint_path)
+ return reader.get_variable_to_shape_map()
+
+
+def make_restore_checkpoint_fn(checkpoint_path, prefix='', skip_regex=None):
+  """Returns a function that restores parameters from a v1 checkpoint.
+
+  Args:
+    checkpoint_path: path of the checkpoint folder or file.
+      Example 1: '/path/to/model_dir/'
+      Example 2: '/path/to/model.ckpt-22500'
+    prefix: prefix in the variable name to be removed for alignment with names
+      in the checkpoint.
+    skip_regex: regular expression to match the names of variables that
+      do not need to be assigned.
+
+  Returns:
+    Callable[[tf.keras.Model], None]: a function that loads the v1 checkpoint
+      into the given Keras model.
+ """
+
+ def _restore_checkpoint_fn(keras_model):
+ """Loads pretrained model through scaffold function."""
+ if not checkpoint_path:
+ logging.info('checkpoint_path is empty')
+ return
+ var_prefix = prefix
+ if prefix and not prefix.endswith('/'):
+ var_prefix += '/'
+ var_to_shape_map = _get_checkpoint_map(checkpoint_path)
+ assert var_to_shape_map, 'var_to_shape_map should not be empty'
+ vars_to_load = _build_assignment_map(
+ keras_model,
+ prefix=var_prefix,
+ skip_variables_regex=skip_regex,
+ var_to_shape_map=var_to_shape_map)
+ if not vars_to_load:
+ raise ValueError('Variables to load is empty.')
+ tf.compat.v1.train.init_from_checkpoint(checkpoint_path,
+ vars_to_load)
+
+ return _restore_checkpoint_fn
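A usage sketch (not part of this diff; the checkpoint path, prefix, and regex are placeholders): build the restore function from the training config values, then apply it to an already-built Keras model.

import tensorflow as tf
from official.vision.detection.modeling import checkpoint_utils

# Stand-in model; in practice this is the built detection model.
keras_model = tf.keras.Sequential(
    [tf.keras.layers.Dense(4, input_shape=(8,), name='fc1')])

# Placeholder path/prefix/regex; real values come from params.train.checkpoint.
restore_fn = checkpoint_utils.make_restore_checkpoint_fn(
    '/path/to/backbone/model.ckpt-112603',
    prefix='resnet50/',
    skip_regex='.*global_step.*')
# restore_fn(keras_model)  # uncomment once the path points at a real checkpoint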
diff --git a/models/official/vision/detection/modeling/factory.py b/models/official/vision/detection/modeling/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..b140416dfdba90420f99a8bcb3b07cc04a63cc3e
--- /dev/null
+++ b/models/official/vision/detection/modeling/factory.py
@@ -0,0 +1,34 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Factory to build detection model."""
+
+
+from official.vision.detection.modeling import maskrcnn_model
+from official.vision.detection.modeling import retinanet_model
+from official.vision.detection.modeling import shapemask_model
+
+
+def model_generator(params):
+ """Model function generator."""
+ if params.type == 'retinanet':
+ model_fn = retinanet_model.RetinanetModel(params)
+ elif params.type == 'mask_rcnn':
+ model_fn = maskrcnn_model.MaskrcnnModel(params)
+ elif params.type == 'shapemask':
+ model_fn = shapemask_model.ShapeMaskModel(params)
+ else:
+ raise ValueError('Model %s is not supported.'% params.type)
+
+ return model_fn
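A construction sketch, assuming (it is not shown in this diff) that the companion configs/factory.py in the same checkout exposes config_generator, which is how the repo's main script builds the default params for a model type:

from official.vision.detection.configs import factory as config_factory
from official.vision.detection.modeling import factory as model_factory

# 'retinanet' could also be 'mask_rcnn' or 'shapemask'.
params = config_factory.config_generator('retinanet')
model_builder = model_factory.model_generator(params)
print(type(model_builder).__name__)  # RetinanetModel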
diff --git a/models/official/vision/detection/modeling/learning_rates.py b/models/official/vision/detection/modeling/learning_rates.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecc24ffadb073c79f71725b1adcb61cbd83127cd
--- /dev/null
+++ b/models/official/vision/detection/modeling/learning_rates.py
@@ -0,0 +1,98 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Learning rate schedule."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import numpy as np
+import tensorflow as tf
+from official.modeling.hyperparams import params_dict
+
+
+class StepLearningRateWithLinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
+ """Class to generate learning rate tensor."""
+
+ def __init__(self, total_steps, params):
+ """Creates the step learning rate tensor with linear warmup."""
+ super(StepLearningRateWithLinearWarmup, self).__init__()
+ self._total_steps = total_steps
+ assert isinstance(params, (dict, params_dict.ParamsDict))
+ if isinstance(params, dict):
+ params = params_dict.ParamsDict(params)
+ self._params = params
+
+ def __call__(self, global_step):
+ warmup_lr = self._params.warmup_learning_rate
+ warmup_steps = self._params.warmup_steps
+ init_lr = self._params.init_learning_rate
+ lr_levels = self._params.learning_rate_levels
+ lr_steps = self._params.learning_rate_steps
+ linear_warmup = (
+ warmup_lr + tf.cast(global_step, dtype=tf.float32) / warmup_steps *
+ (init_lr - warmup_lr))
+ learning_rate = tf.where(global_step < warmup_steps, linear_warmup, init_lr)
+
+ for next_learning_rate, start_step in zip(lr_levels, lr_steps):
+ learning_rate = tf.where(global_step >= start_step, next_learning_rate,
+ learning_rate)
+ return learning_rate
+
+ def get_config(self):
+ return {'_params': self._params.as_dict()}
+
+
+class CosineLearningRateWithLinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
+ """Class to generate learning rate tensor."""
+
+ def __init__(self, total_steps, params):
+    """Creates the cosine learning rate tensor with linear warmup."""
+ super(CosineLearningRateWithLinearWarmup, self).__init__()
+ self._total_steps = total_steps
+ assert isinstance(params, (dict, params_dict.ParamsDict))
+ if isinstance(params, dict):
+ params = params_dict.ParamsDict(params)
+ self._params = params
+
+ def __call__(self, global_step):
+ global_step = tf.cast(global_step, dtype=tf.float32)
+ warmup_lr = self._params.warmup_learning_rate
+ warmup_steps = self._params.warmup_steps
+ init_lr = self._params.init_learning_rate
+ total_steps = self._total_steps
+ linear_warmup = (
+ warmup_lr + global_step / warmup_steps * (init_lr - warmup_lr))
+ cosine_learning_rate = (
+ init_lr * (tf.cos(np.pi * (global_step - warmup_steps) /
+ (total_steps - warmup_steps)) + 1.0) / 2.0)
+ learning_rate = tf.where(global_step < warmup_steps, linear_warmup,
+ cosine_learning_rate)
+ return learning_rate
+
+ def get_config(self):
+ return {'_params': self._params.as_dict()}
+
+
+def learning_rate_generator(total_steps, params):
+ """The learning rate function generator."""
+ if params.type == 'step':
+ return StepLearningRateWithLinearWarmup(total_steps, params)
+ elif params.type == 'cosine':
+ return CosineLearningRateWithLinearWarmup(total_steps, params)
+ else:
+ raise ValueError('Unsupported learning rate type: {}.'.format(params.type))
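A small sketch of driving a schedule directly; the step counts and rates below are made up. learning_rate_generator reads params.type via attribute access, so a plain dict is wrapped in ParamsDict first, just as the schedule classes do internally.

import tensorflow as tf
from official.modeling.hyperparams import params_dict
from official.vision.detection.modeling import learning_rates

# Hypothetical schedule: linear warmup for 500 steps, then 10x drops.
params = params_dict.ParamsDict({
    'type': 'step',
    'warmup_learning_rate': 0.0067,
    'warmup_steps': 500,
    'init_learning_rate': 0.08,
    'learning_rate_levels': [0.008, 0.0008],
    'learning_rate_steps': [15000, 20000],
})
lr_fn = learning_rates.learning_rate_generator(total_steps=22500, params=params)

for step in [0, 250, 500, 15000, 20000]:
  print(step, float(lr_fn(tf.constant(step, tf.float32))))
# 0.0067 -> ~0.043 -> 0.08 -> 0.008 -> 0.0008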
diff --git a/models/official/vision/detection/modeling/losses.py b/models/official/vision/detection/modeling/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b993061b3c51c9ae6456d84a79f7fea5d74c77e
--- /dev/null
+++ b/models/official/vision/detection/modeling/losses.py
@@ -0,0 +1,542 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Losses used for detection models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import logging
+import tensorflow as tf
+
+
+def focal_loss(logits, targets, alpha, gamma, normalizer):
+ """Compute the focal loss between `logits` and the golden `target` values.
+
+ Focal loss = -(1-pt)^gamma * log(pt)
+ where pt is the probability of being classified to the true class.
+
+ Args:
+ logits: A float32 tensor of size
+ [batch, height_in, width_in, num_predictions].
+ targets: A float32 tensor of size
+ [batch, height_in, width_in, num_predictions].
+ alpha: A float32 scalar multiplying alpha to the loss from positive examples
+ and (1-alpha) to the loss from negative examples.
+ gamma: A float32 scalar modulating loss from hard and easy examples.
+ normalizer: A float32 scalar normalizes the total loss from all examples.
+
+ Returns:
+ loss: A float32 Tensor of size [batch, height_in, width_in, num_predictions]
+ representing normalized loss on the prediction map.
+ """
+ with tf.name_scope('focal_loss'):
+ positive_label_mask = tf.math.equal(targets, 1.0)
+ cross_entropy = (
+ tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits))
+ # Below are comments/derivations for computing modulator.
+    # For brevity, let x = logits, z = targets, r = gamma, and p_t = sigmoid(x)
+ # for positive samples and 1 - sigmoid(x) for negative examples.
+ #
+ # The modulator, defined as (1 - P_t)^r, is a critical part in focal loss
+ # computation. For r > 0, it puts more weights on hard examples, and less
+ # weights on easier ones. However if it is directly computed as (1 - P_t)^r,
+ # its back-propagation is not stable when r < 1. The implementation here
+ # resolves the issue.
+ #
+ # For positive samples (labels being 1),
+ # (1 - p_t)^r
+ # = (1 - sigmoid(x))^r
+ # = (1 - (1 / (1 + exp(-x))))^r
+ # = (exp(-x) / (1 + exp(-x)))^r
+ # = exp(log((exp(-x) / (1 + exp(-x)))^r))
+ # = exp(r * log(exp(-x)) - r * log(1 + exp(-x)))
+ # = exp(- r * x - r * log(1 + exp(-x)))
+ #
+ # For negative samples (labels being 0),
+ # (1 - p_t)^r
+ # = (sigmoid(x))^r
+ # = (1 / (1 + exp(-x)))^r
+ # = exp(log((1 / (1 + exp(-x)))^r))
+ # = exp(-r * log(1 + exp(-x)))
+ #
+ # Therefore one unified form for positive (z = 1) and negative (z = 0)
+ # samples is:
+ # (1 - p_t)^r = exp(-r * z * x - r * log(1 + exp(-x))).
+ neg_logits = -1.0 * logits
+ modulator = tf.math.exp(gamma * targets * neg_logits -
+ gamma * tf.math.log1p(tf.math.exp(neg_logits)))
+ loss = modulator * cross_entropy
+ weighted_loss = tf.where(positive_label_mask, alpha * loss,
+ (1.0 - alpha) * loss)
+ weighted_loss /= normalizer
+ return weighted_loss
+
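The comment block above derives a numerically stable form of the modulator; the throwaway check below (not library code) confirms the identity on a few hand-picked logits.

import tensorflow as tf

gamma = 1.5
x = tf.constant([-2.0, -0.5, 0.0, 0.5, 2.0])   # logits
z = tf.constant([1.0, 0.0, 1.0, 0.0, 1.0])     # targets

# Direct form (1 - p_t)^gamma vs. the stable form used in focal_loss above.
p_t = tf.where(z > 0, tf.sigmoid(x), 1.0 - tf.sigmoid(x))
direct = tf.pow(1.0 - p_t, gamma)
stable = tf.exp(-gamma * z * x - gamma * tf.math.log1p(tf.exp(-x)))
print(tf.reduce_max(tf.abs(direct - stable)).numpy())  # ~1e-7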
+
+class RpnScoreLoss(object):
+ """Region Proposal Network score loss function."""
+
+ def __init__(self, params):
+ self._rpn_batch_size_per_im = params.rpn_batch_size_per_im
+ self._binary_crossentropy = tf.keras.losses.BinaryCrossentropy(
+ reduction=tf.keras.losses.Reduction.SUM, from_logits=True)
+
+ def __call__(self, score_outputs, labels):
+ """Computes total RPN detection loss.
+
+ Computes total RPN detection loss including box and score from all levels.
+
+ Args:
+      score_outputs: an OrderedDict with keys representing levels and values
+        representing scores in [batch_size, height, width, num_anchors].
+      labels: the dictionary returned from the dataloader that includes
+        groundtruth targets.
+
+ Returns:
+ rpn_score_loss: a scalar tensor representing total score loss.
+ """
+ with tf.name_scope('rpn_loss'):
+ levels = sorted(score_outputs.keys())
+
+ score_losses = []
+ for level in levels:
+ score_losses.append(
+ self._rpn_score_loss(
+ score_outputs[level],
+ labels[level],
+ normalizer=tf.cast(
+ tf.shape(score_outputs[level])[0] *
+ self._rpn_batch_size_per_im, dtype=tf.float32)))
+
+ # Sums per level losses to total loss.
+ return tf.math.add_n(score_losses)
+
+ def _rpn_score_loss(self, score_outputs, score_targets, normalizer=1.0):
+ """Computes score loss."""
+ # score_targets has three values:
+ # (1) score_targets[i]=1, the anchor is a positive sample.
+ # (2) score_targets[i]=0, negative.
+ # (3) score_targets[i]=-1, the anchor is don't care (ignore).
+ with tf.name_scope('rpn_score_loss'):
+ mask = tf.math.logical_or(tf.math.equal(score_targets, 1),
+ tf.math.equal(score_targets, 0))
+
+ score_targets = tf.math.maximum(score_targets,
+ tf.zeros_like(score_targets))
+
+ score_targets = tf.expand_dims(score_targets, axis=-1)
+ score_outputs = tf.expand_dims(score_outputs, axis=-1)
+ score_loss = self._binary_crossentropy(
+ score_targets, score_outputs, sample_weight=mask)
+
+ score_loss /= normalizer
+ return score_loss
+
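A throwaway check (toy numbers, not library code) of how the -1 "don't care" anchors above drop out: they fail the mask, so sample_weight zeroes their contribution to the summed cross-entropy.

import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy(
    reduction=tf.keras.losses.Reduction.SUM, from_logits=True)

score_targets = tf.constant([[1.0, 0.0, -1.0, 1.0]])   # -1 = ignore
score_outputs = tf.constant([[2.0, -1.0, 5.0, 0.5]])

mask = tf.math.logical_or(tf.math.equal(score_targets, 1),
                          tf.math.equal(score_targets, 0))
targets = tf.math.maximum(score_targets, tf.zeros_like(score_targets))
loss = bce(tf.expand_dims(targets, -1), tf.expand_dims(score_outputs, -1),
           sample_weight=mask)
print(float(loss))  # the badly-scored third anchor contributes nothing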
+
+class RpnBoxLoss(object):
+ """Region Proposal Network box regression loss function."""
+
+ def __init__(self, params):
+ logging.info('RpnBoxLoss huber_loss_delta %s', params.huber_loss_delta)
+    # The delta is typically around the mean value of the regression targets.
+    # For instance, the regression targets of a 512x512 input with 6 anchors
+    # on the P2-P6 pyramid are about [0.1, 0.1, 0.2, 0.2].
+ self._huber_loss = tf.keras.losses.Huber(
+ delta=params.huber_loss_delta, reduction=tf.keras.losses.Reduction.SUM)
+
+ def __call__(self, box_outputs, labels):
+ """Computes total RPN detection loss.
+
+ Computes total RPN detection loss including box and score from all levels.
+
+ Args:
+      box_outputs: an OrderedDict with keys representing levels and values
+        representing box regression targets in
+        [batch_size, height, width, num_anchors * 4].
+      labels: the dictionary returned from the dataloader that includes
+        groundtruth targets.
+
+ Returns:
+ rpn_box_loss: a scalar tensor representing total box regression loss.
+ """
+ with tf.name_scope('rpn_loss'):
+ levels = sorted(box_outputs.keys())
+
+ box_losses = []
+ for level in levels:
+ box_losses.append(self._rpn_box_loss(box_outputs[level], labels[level]))
+
+ # Sum per level losses to total loss.
+ return tf.add_n(box_losses)
+
+ def _rpn_box_loss(self, box_outputs, box_targets, normalizer=1.0):
+ """Computes box regression loss."""
+ with tf.name_scope('rpn_box_loss'):
+ mask = tf.cast(tf.not_equal(box_targets, 0.0), dtype=tf.float32)
+ box_targets = tf.expand_dims(box_targets, axis=-1)
+ box_outputs = tf.expand_dims(box_outputs, axis=-1)
+ box_loss = self._huber_loss(box_targets, box_outputs, sample_weight=mask)
+ # The loss is normalized by the sum of non-zero weights and additional
+ # normalizer provided by the function caller. Using + 0.01 here to avoid
+ # division by zero.
+ box_loss /= normalizer * (tf.reduce_sum(mask) + 0.01)
+ return box_loss
+
+
+class FastrcnnClassLoss(object):
+ """Fast R-CNN classification loss function."""
+
+ def __init__(self):
+ self._categorical_crossentropy = tf.keras.losses.CategoricalCrossentropy(
+ reduction=tf.keras.losses.Reduction.SUM, from_logits=True)
+
+ def __call__(self, class_outputs, class_targets):
+ """Computes the class loss (Fast-RCNN branch) of Mask-RCNN.
+
+ This function implements the classification loss of the Fast-RCNN.
+
+ The classification loss is softmax on all RoIs.
+ Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/fast_rcnn_heads.py # pylint: disable=line-too-long
+
+ Args:
+ class_outputs: a float tensor representing the class prediction for each box
+ with a shape of [batch_size, num_boxes, num_classes].
+ class_targets: a float tensor representing the class label for each box
+ with a shape of [batch_size, num_boxes].
+
+ Returns:
+ a scalar tensor representing total class loss.
+ """
+ with tf.name_scope('fast_rcnn_loss'):
+ batch_size, num_boxes, num_classes = class_outputs.get_shape().as_list()
+ class_targets = tf.cast(class_targets, dtype=tf.int32)
+ class_targets_one_hot = tf.one_hot(class_targets, num_classes)
+ return self._fast_rcnn_class_loss(class_outputs, class_targets_one_hot,
+ normalizer=batch_size * num_boxes / 2.0)
+
+ def _fast_rcnn_class_loss(self, class_outputs, class_targets_one_hot,
+ normalizer):
+ """Computes classification loss."""
+ with tf.name_scope('fast_rcnn_class_loss'):
+ class_loss = self._categorical_crossentropy(class_targets_one_hot,
+ class_outputs)
+
+ class_loss /= normalizer
+ return class_loss
+
+
+class FastrcnnBoxLoss(object):
+ """Fast R-CNN box regression loss function."""
+
+ def __init__(self, params):
+ logging.info('FastrcnnBoxLoss huber_loss_delta %s', params.huber_loss_delta)
+    # The delta is typically around the mean value of the regression targets.
+    # For instance, the regression targets of a 512x512 input with 6 anchors
+    # on the P2-P6 pyramid are about [0.1, 0.1, 0.2, 0.2].
+ self._huber_loss = tf.keras.losses.Huber(
+ delta=params.huber_loss_delta, reduction=tf.keras.losses.Reduction.SUM)
+
+ def __call__(self, box_outputs, class_targets, box_targets):
+ """Computes the box loss (Fast-RCNN branch) of Mask-RCNN.
+
+ This function implements the box regression loss of the Fast-RCNN. As the
+ `box_outputs` produces `num_classes` boxes for each RoI, the reference model
+ expands `box_targets` to match the shape of `box_outputs` and selects only
+ the target that the RoI has a maximum overlap. (Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/roi_data/fast_rcnn.py) # pylint: disable=line-too-long
+ Instead, this function selects the `box_outputs` by the `class_targets` so
+ that it doesn't expand `box_targets`.
+
+ The box loss is smooth L1-loss on only positive samples of RoIs.
+ Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/fast_rcnn_heads.py # pylint: disable=line-too-long
+
+ Args:
+ box_outputs: a float tensor representing the box prediction for each box
+ with a shape of [batch_size, num_boxes, num_classes * 4].
+ class_targets: a float tensor representing the class label for each box
+ with a shape of [batch_size, num_boxes].
+ box_targets: a float tensor representing the box label for each box
+ with a shape of [batch_size, num_boxes, 4].
+
+ Returns:
+ box_loss: a scalar tensor representing total box regression loss.
+ """
+ with tf.name_scope('fast_rcnn_loss'):
+ class_targets = tf.cast(class_targets, dtype=tf.int32)
+
+ # Selects the box from `box_outputs` based on `class_targets`, with which
+ # the box has the maximum overlap.
+ (batch_size, num_rois,
+ num_class_specific_boxes) = box_outputs.get_shape().as_list()
+ num_classes = num_class_specific_boxes // 4
+ box_outputs = tf.reshape(box_outputs,
+ [batch_size, num_rois, num_classes, 4])
+
+ box_indices = tf.reshape(
+ class_targets + tf.tile(
+ tf.expand_dims(
+ tf.range(batch_size) * num_rois * num_classes, 1),
+ [1, num_rois]) + tf.tile(
+ tf.expand_dims(tf.range(num_rois) * num_classes, 0),
+ [batch_size, 1]), [-1])
+
+ box_outputs = tf.matmul(
+ tf.one_hot(
+ box_indices,
+ batch_size * num_rois * num_classes,
+ dtype=box_outputs.dtype), tf.reshape(box_outputs, [-1, 4]))
+ box_outputs = tf.reshape(box_outputs, [batch_size, -1, 4])
+
+ return self._fast_rcnn_box_loss(box_outputs, box_targets, class_targets)
+
+ def _fast_rcnn_box_loss(self, box_outputs, box_targets, class_targets,
+ normalizer=1.0):
+ """Computes box regression loss."""
+ with tf.name_scope('fast_rcnn_box_loss'):
+ mask = tf.tile(tf.expand_dims(tf.greater(class_targets, 0), axis=2),
+ [1, 1, 4])
+ mask = tf.cast(mask, dtype=tf.float32)
+ box_targets = tf.expand_dims(box_targets, axis=-1)
+ box_outputs = tf.expand_dims(box_outputs, axis=-1)
+ box_loss = self._huber_loss(box_targets, box_outputs, sample_weight=mask)
+ # The loss is normalized by the number of ones in mask,
+      # The loss is normalized by the number of ones in the mask and by an
+      # additional normalizer provided by the caller; 0.01 is added to avoid
+      # division by zero.
+ return box_loss
+
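To make the one-hot/matmul gather in __call__ above concrete, here is a throwaway example with toy shapes (not library code): each RoI's flat index is batch*num_rois*num_classes + roi*num_classes + target_class, and multiplying the one-hot of those indices with the flattened outputs pulls out exactly one 4-vector per RoI.

import tensorflow as tf

# Toy shapes: 1 image, 3 RoIs, 4 classes; box deltas are just running numbers.
batch_size, num_rois, num_classes = 1, 3, 4
box_outputs = tf.reshape(
    tf.range(batch_size * num_rois * num_classes * 4, dtype=tf.float32),
    [batch_size, num_rois, num_classes, 4])
class_targets = tf.constant([[2, 0, 3]])   # target class of each RoI

# Flat index of (batch, roi, target_class) into the flattened box outputs.
box_indices = tf.reshape(
    class_targets
    + tf.tile(tf.expand_dims(tf.range(batch_size) * num_rois * num_classes, 1),
              [1, num_rois])
    + tf.tile(tf.expand_dims(tf.range(num_rois) * num_classes, 0),
              [batch_size, 1]),
    [-1])
selected = tf.matmul(
    tf.one_hot(box_indices, batch_size * num_rois * num_classes,
               dtype=box_outputs.dtype),
    tf.reshape(box_outputs, [-1, 4]))
print(tf.reshape(selected, [batch_size, num_rois, 4]).numpy())
# Row 0 is class 2 of RoI 0, row 1 is class 0 of RoI 1, row 2 is class 3 of RoI 2.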
+
+class MaskrcnnLoss(object):
+ """Mask R-CNN instance segmentation mask loss function."""
+
+ def __init__(self):
+ self._binary_crossentropy = tf.keras.losses.BinaryCrossentropy(
+ reduction=tf.keras.losses.Reduction.SUM, from_logits=True)
+
+ def __call__(self, mask_outputs, mask_targets, select_class_targets):
+ """Computes the mask loss of Mask-RCNN.
+
+ This function implements the mask loss of Mask-RCNN. As the `mask_outputs`
+ produces `num_classes` masks for each RoI, the reference model expands
+ `mask_targets` to match the shape of `mask_outputs` and selects only the
+ target that the RoI has a maximum overlap. (Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/roi_data/mask_rcnn.py) # pylint: disable=line-too-long
+ Instead, this implementation selects the `mask_outputs` by the `class_targets`
+ so that it doesn't expand `mask_targets`. Note that the selection logic is
+ done in the post-processing of mask_rcnn_fn in mask_rcnn_architecture.py.
+
+ Args:
+ mask_outputs: a float tensor representing the prediction for each mask,
+ with a shape of
+ [batch_size, num_masks, mask_height, mask_width].
+ mask_targets: a float tensor representing the binary mask of ground truth
+ labels for each mask with a shape of
+ [batch_size, num_masks, mask_height, mask_width].
+ select_class_targets: a tensor with a shape of [batch_size, num_masks],
+ representing the foreground mask targets.
+
+ Returns:
+ mask_loss: a float tensor representing total mask loss.
+ """
+ with tf.name_scope('mask_rcnn_loss'):
+ (batch_size, num_masks, mask_height,
+ mask_width) = mask_outputs.get_shape().as_list()
+
+ weights = tf.tile(
+ tf.reshape(tf.greater(select_class_targets, 0),
+ [batch_size, num_masks, 1, 1]),
+ [1, 1, mask_height, mask_width])
+ weights = tf.cast(weights, dtype=tf.float32)
+
+ mask_targets = tf.expand_dims(mask_targets, axis=-1)
+ mask_outputs = tf.expand_dims(mask_outputs, axis=-1)
+ mask_loss = self._binary_crossentropy(mask_targets, mask_outputs,
+ sample_weight=weights)
+
+ # The loss is normalized by the number of 1's in weights and
+ # + 0.01 is used to avoid division by zero.
+ return mask_loss / (tf.reduce_sum(weights) + 0.01)
+
+
+class RetinanetClassLoss(object):
+ """RetinaNet class loss."""
+
+ def __init__(self, params, num_classes):
+ self._num_classes = num_classes
+ self._focal_loss_alpha = params.focal_loss_alpha
+ self._focal_loss_gamma = params.focal_loss_gamma
+
+ def __call__(self, cls_outputs, labels, num_positives):
+ """Computes total detection loss.
+
+ Computes total detection loss including box and class loss from all levels.
+
+ Args:
+      cls_outputs: an OrderedDict with keys representing levels and values
+        representing logits in [batch_size, height, width,
+        num_anchors * num_classes].
+      labels: the dictionary returned from the dataloader that includes
+        class groundtruth targets.
+ num_positives: number of positive examples in the minibatch.
+
+ Returns:
+      a scalar tensor representing total class loss.
+ """
+ # Sums all positives in a batch for normalization and avoids zero
+ # num_positives_sum, which would lead to inf loss during training
+ num_positives_sum = tf.reduce_sum(input_tensor=num_positives) + 1.0
+
+ cls_losses = []
+ for level in cls_outputs.keys():
+ cls_losses.append(self.class_loss(
+ cls_outputs[level], labels[level], num_positives_sum))
+ # Sums per level losses to total loss.
+ return tf.add_n(cls_losses)
+
+ def class_loss(self, cls_outputs, cls_targets, num_positives,
+ ignore_label=-2):
+ """Computes RetinaNet classification loss."""
+ # Onehot encoding for classification labels.
+ cls_targets_one_hot = tf.one_hot(cls_targets, self._num_classes)
+ bs, height, width, _, _ = cls_targets_one_hot.get_shape().as_list()
+ cls_targets_one_hot = tf.reshape(cls_targets_one_hot,
+ [bs, height, width, -1])
+ loss = focal_loss(tf.cast(cls_outputs, dtype=tf.float32),
+ tf.cast(cls_targets_one_hot, dtype=tf.float32),
+ self._focal_loss_alpha,
+ self._focal_loss_gamma,
+ num_positives)
+
+ ignore_loss = tf.where(
+ tf.equal(cls_targets, ignore_label),
+ tf.zeros_like(cls_targets, dtype=tf.float32),
+ tf.ones_like(cls_targets, dtype=tf.float32),
+ )
+ ignore_loss = tf.expand_dims(ignore_loss, -1)
+ ignore_loss = tf.tile(ignore_loss, [1, 1, 1, 1, self._num_classes])
+ ignore_loss = tf.reshape(ignore_loss, tf.shape(input=loss))
+ return tf.reduce_sum(input_tensor=ignore_loss * loss)
+
+
+class RetinanetBoxLoss(object):
+ """RetinaNet box loss."""
+
+ def __init__(self, params):
+ self._huber_loss = tf.keras.losses.Huber(
+ delta=params.huber_loss_delta, reduction=tf.keras.losses.Reduction.SUM)
+
+ def __call__(self, box_outputs, labels, num_positives):
+ """Computes box detection loss.
+
+ Computes total detection loss including box and class loss from all levels.
+
+ Args:
+      box_outputs: an OrderedDict with keys representing levels and values
+        representing box regression targets in [batch_size, height, width,
+        num_anchors * 4].
+      labels: the dictionary returned from the dataloader that includes
+        box groundtruth targets.
+ num_positives: number of positive examples in the minibatch.
+
+ Returns:
+      a scalar tensor representing total box regression loss.
+ """
+ # Sums all positives in a batch for normalization and avoids zero
+ # num_positives_sum, which would lead to inf loss during training
+ num_positives_sum = tf.reduce_sum(input_tensor=num_positives) + 1.0
+
+ box_losses = []
+ for level in box_outputs.keys():
+ # Box regression targets for this level.
+ box_targets_l = labels[level]
+ box_losses.append(
+ self.box_loss(box_outputs[level], box_targets_l, num_positives_sum))
+ # Sums per level losses to total loss.
+ return tf.add_n(box_losses)
+
+ def box_loss(self, box_outputs, box_targets, num_positives):
+ """Computes RetinaNet box regression loss."""
+ # The delta is typically set around the mean value of the regression targets.
+ # For instance, the regression targets of a 512x512 input with 6 anchors on
+ # the P3-P7 pyramid are about [0.1, 0.1, 0.2, 0.2].
+ normalizer = num_positives * 4.0
+ mask = tf.cast(tf.not_equal(box_targets, 0.0), dtype=tf.float32)
+ box_targets = tf.expand_dims(box_targets, axis=-1)
+ box_outputs = tf.expand_dims(box_outputs, axis=-1)
+ box_loss = self._huber_loss(box_targets, box_outputs, sample_weight=mask)
+ box_loss /= normalizer
+ return box_loss
+
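A small usage sketch of the masking trick in box_loss (toy values, assumed delta of 0.1): padded targets are exactly zero, so the mask removes them from the Huber loss before normalizing by num_positives * 4.

    import tensorflow as tf

    huber = tf.keras.losses.Huber(delta=0.1, reduction=tf.keras.losses.Reduction.SUM)
    box_targets = tf.constant([[0.1, -0.2, 0.05, 0.3],
                               [0.0, 0.0, 0.0, 0.0]])   # second row is padding
    box_outputs = tf.constant([[0.2, -0.1, 0.0, 0.2],
                               [0.5, 0.5, 0.5, 0.5]])
    mask = tf.cast(tf.not_equal(box_targets, 0.0), tf.float32)
    loss = huber(tf.expand_dims(box_targets, -1),
                 tf.expand_dims(box_outputs, -1),
                 sample_weight=mask)
    loss /= 1.0 * 4.0  # num_positives * 4, with one positive box in this toy batch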
+
+class ShapemaskMseLoss(object):
+ """ShapeMask mask Mean Squared Error loss function wrapper."""
+
+ def __call__(self, probs, labels, valid_mask):
+ """Compute instance segmentation loss.
+
+ Args:
+ probs: A Tensor of shape [batch_size * num_points, height, width,
+ num_classes]. The logits are not necessarily between 0 and 1.
+ labels: A float32/float16 Tensor of shape [batch_size, num_instances,
+ mask_size, mask_size], where mask_size =
+ mask_crop_size * gt_upsample_scale for fine mask, or mask_crop_size
+ for coarse masks and shape priors.
+ valid_mask: a binary mask indicating valid training masks.
+
+ Returns:
+ loss: a float tensor representing the total mask prior loss.
+ """
+ with tf.name_scope('shapemask_prior_loss'):
+ batch_size, num_instances = valid_mask.get_shape().as_list()[:2]
+ diff = (tf.cast(labels, dtype=tf.float32) -
+ tf.cast(probs, dtype=tf.float32))
+ diff *= tf.cast(
+ tf.reshape(valid_mask, [batch_size, num_instances, 1, 1]),
+ tf.float32)
+ # Adding 0.001 in the denominator to avoid division by zero.
+ loss = tf.nn.l2_loss(diff) / (tf.reduce_sum(labels) + 0.001)
+ return loss
+
+
+class ShapemaskLoss(object):
+ """ShapeMask mask loss function wrapper."""
+
+ def __init__(self):
+ self._binary_crossentropy = tf.keras.losses.BinaryCrossentropy(
+ reduction=tf.keras.losses.Reduction.SUM, from_logits=True)
+
+ def __call__(self, logits, labels, valid_mask):
+ """ShapeMask mask cross entropy loss function wrapper.
+
+ Args:
+ logits: A Tensor of shape [batch_size * num_instances, height, width,
+ num_classes]. The logits are not necessarily between 0 and 1.
+ labels: A float16/float32 Tensor of shape [batch_size, num_instances,
+ mask_size, mask_size], where mask_size =
+ mask_crop_size * gt_upsample_scale for fine mask, or mask_crop_size
+ for coarse masks and shape priors.
+ valid_mask: a binary mask of shape [batch_size, num_instances]
+ indicating valid training masks.
+ Returns:
+ loss: a float tensor representing the total mask classification loss.
+ """
+ with tf.name_scope('shapemask_loss'):
+ batch_size, num_instances = valid_mask.get_shape().as_list()[:2]
+ labels = tf.cast(labels, tf.float32)
+ logits = tf.cast(logits, tf.float32)
+ loss = self._binary_crossentropy(labels, logits)
+ loss *= tf.cast(tf.reshape(
+ valid_mask, [batch_size, num_instances, 1, 1]), loss.dtype)
+ # Adding 0.001 in the denominator to avoid division by zero.
+ loss = tf.reduce_sum(loss) / (tf.reduce_sum(labels) + 0.001)
+ return loss
diff --git a/models/official/vision/detection/modeling/maskrcnn_model.py b/models/official/vision/detection/modeling/maskrcnn_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5cbe7d56ba7d82836ef58df201aa74779cb2f69
--- /dev/null
+++ b/models/official/vision/detection/modeling/maskrcnn_model.py
@@ -0,0 +1,344 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Model defination for the Mask R-CNN Model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.keras import backend
+from official.vision.detection.dataloader import anchor
+from official.vision.detection.dataloader import mode_keys
+from official.vision.detection.evaluation import factory as eval_factory
+from official.vision.detection.modeling import base_model
+from official.vision.detection.modeling import losses
+from official.vision.detection.modeling.architecture import factory
+from official.vision.detection.ops import postprocess_ops
+from official.vision.detection.ops import roi_ops
+from official.vision.detection.ops import spatial_transform_ops
+from official.vision.detection.ops import target_ops
+from official.vision.detection.utils import box_utils
+
+
+class MaskrcnnModel(base_model.Model):
+ """Mask R-CNN model function."""
+
+ def __init__(self, params):
+ super(MaskrcnnModel, self).__init__(params)
+
+ # For eval metrics.
+ self._params = params
+ self._keras_model = None
+
+ self._include_mask = params.architecture.include_mask
+
+ # Architecture generators.
+ self._backbone_fn = factory.backbone_generator(params)
+ self._fpn_fn = factory.multilevel_features_generator(params)
+ self._rpn_head_fn = factory.rpn_head_generator(params)
+ self._generate_rois_fn = roi_ops.ROIGenerator(params.roi_proposal)
+ self._sample_rois_fn = target_ops.ROISampler(params.roi_sampling)
+ self._sample_masks_fn = target_ops.MaskSampler(
+ params.architecture.mask_target_size,
+ params.mask_sampling.num_mask_samples_per_image)
+
+ self._frcnn_head_fn = factory.fast_rcnn_head_generator(params)
+ if self._include_mask:
+ self._mrcnn_head_fn = factory.mask_rcnn_head_generator(params)
+
+ # Loss function.
+ self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss)
+ self._rpn_box_loss_fn = losses.RpnBoxLoss(params.rpn_box_loss)
+ self._frcnn_class_loss_fn = losses.FastrcnnClassLoss()
+ self._frcnn_box_loss_fn = losses.FastrcnnBoxLoss(params.frcnn_box_loss)
+ if self._include_mask:
+ self._mask_loss_fn = losses.MaskrcnnLoss()
+
+ self._generate_detections_fn = postprocess_ops.GenericDetectionGenerator(
+ params.postprocess)
+
+ self._transpose_input = params.train.transpose_input
+ assert not self._transpose_input, 'Transpose input is not supported.'
+
+ def build_outputs(self, inputs, mode):
+ is_training = mode == mode_keys.TRAIN
+ model_outputs = {}
+
+ image = inputs['image']
+ _, image_height, image_width, _ = image.get_shape().as_list()
+ backbone_features = self._backbone_fn(image, is_training)
+ fpn_features = self._fpn_fn(backbone_features, is_training)
+
+ rpn_score_outputs, rpn_box_outputs = self._rpn_head_fn(
+ fpn_features, is_training)
+ model_outputs.update({
+ 'rpn_score_outputs':
+ tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
+ rpn_score_outputs),
+ 'rpn_box_outputs':
+ tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
+ rpn_box_outputs),
+ })
+ input_anchor = anchor.Anchor(self._params.architecture.min_level,
+ self._params.architecture.max_level,
+ self._params.anchor.num_scales,
+ self._params.anchor.aspect_ratios,
+ self._params.anchor.anchor_size,
+ (image_height, image_width))
+ rpn_rois, _ = self._generate_rois_fn(rpn_box_outputs, rpn_score_outputs,
+ input_anchor.multilevel_boxes,
+ inputs['image_info'][:, 1, :],
+ is_training)
+ if is_training:
+ rpn_rois = tf.stop_gradient(rpn_rois)
+
+ # Sample proposals.
+ rpn_rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices = (
+ self._sample_rois_fn(rpn_rois, inputs['gt_boxes'],
+ inputs['gt_classes']))
+
+ # Create bounding box training targets.
+ box_targets = box_utils.encode_boxes(
+ matched_gt_boxes, rpn_rois, weights=[10.0, 10.0, 5.0, 5.0])
+ # If the target is background, the box target is set to all 0s.
+ box_targets = tf.where(
+ tf.tile(
+ tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
+ [1, 1, 4]),
+ tf.zeros_like(box_targets),
+ box_targets)
+ model_outputs.update({
+ 'class_targets': matched_gt_classes,
+ 'box_targets': box_targets,
+ })
+
+ roi_features = spatial_transform_ops.multilevel_crop_and_resize(
+ fpn_features, rpn_rois, output_size=7)
+
+ class_outputs, box_outputs = self._frcnn_head_fn(roi_features, is_training)
+
+ model_outputs.update({
+ 'class_outputs':
+ tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
+ class_outputs),
+ 'box_outputs':
+ tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
+ box_outputs),
+ })
+
+ # Add this output to train to make the checkpoint loadable in predict mode.
+ # If we skip it in train mode, the heads will be out-of-order and checkpoint
+ # loading will fail.
+ boxes, scores, classes, valid_detections = self._generate_detections_fn(
+ box_outputs, class_outputs, rpn_rois, inputs['image_info'][:, 1:2, :])
+ model_outputs.update({
+ 'num_detections': valid_detections,
+ 'detection_boxes': boxes,
+ 'detection_classes': classes,
+ 'detection_scores': scores,
+ })
+
+ if not self._include_mask:
+ return model_outputs
+
+ if is_training:
+ rpn_rois, classes, mask_targets = self._sample_masks_fn(
+ rpn_rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices,
+ inputs['gt_masks'])
+ mask_targets = tf.stop_gradient(mask_targets)
+
+ classes = tf.cast(classes, dtype=tf.int32)
+
+ model_outputs.update({
+ 'mask_targets': mask_targets,
+ 'sampled_class_targets': classes,
+ })
+ else:
+ rpn_rois = boxes
+ classes = tf.cast(classes, dtype=tf.int32)
+
+ mask_roi_features = spatial_transform_ops.multilevel_crop_and_resize(
+ fpn_features, rpn_rois, output_size=14)
+
+ mask_outputs = self._mrcnn_head_fn(mask_roi_features, classes, is_training)
+
+ if is_training:
+ model_outputs.update({
+ 'mask_outputs':
+ tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
+ mask_outputs),
+ })
+ else:
+ model_outputs.update({
+ 'detection_masks': tf.nn.sigmoid(mask_outputs)
+ })
+
+ return model_outputs
+
+ def build_loss_fn(self):
+ if self._keras_model is None:
+ raise ValueError('build_loss_fn() must be called after build_model().')
+
+ filter_fn = self.make_filter_trainable_variables_fn()
+ trainable_variables = filter_fn(self._keras_model.trainable_variables)
+
+ def _total_loss_fn(labels, outputs):
+ rpn_score_loss = self._rpn_score_loss_fn(outputs['rpn_score_outputs'],
+ labels['rpn_score_targets'])
+ rpn_box_loss = self._rpn_box_loss_fn(outputs['rpn_box_outputs'],
+ labels['rpn_box_targets'])
+
+ frcnn_class_loss = self._frcnn_class_loss_fn(outputs['class_outputs'],
+ outputs['class_targets'])
+ frcnn_box_loss = self._frcnn_box_loss_fn(outputs['box_outputs'],
+ outputs['class_targets'],
+ outputs['box_targets'])
+
+ if self._include_mask:
+ mask_loss = self._mask_loss_fn(outputs['mask_outputs'],
+ outputs['mask_targets'],
+ outputs['sampled_class_targets'])
+ else:
+ mask_loss = 0.0
+
+ model_loss = (
+ rpn_score_loss + rpn_box_loss + frcnn_class_loss + frcnn_box_loss +
+ mask_loss)
+
+ l2_regularization_loss = self.weight_decay_loss(trainable_variables)
+ total_loss = model_loss + l2_regularization_loss
+ return {
+ 'total_loss': total_loss,
+ 'loss': total_loss,
+ 'fast_rcnn_class_loss': frcnn_class_loss,
+ 'fast_rcnn_box_loss': frcnn_box_loss,
+ 'mask_loss': mask_loss,
+ 'model_loss': model_loss,
+ 'l2_regularization_loss': l2_regularization_loss,
+ 'rpn_score_loss': rpn_score_loss,
+ 'rpn_box_loss': rpn_box_loss,
+ }
+
+ return _total_loss_fn
+
+ def build_input_layers(self, params, mode):
+ is_training = mode == mode_keys.TRAIN
+ input_shape = (
+ params.maskrcnn_parser.output_size +
+ [params.maskrcnn_parser.num_channels])
+ if is_training:
+ batch_size = params.train.batch_size
+ input_layer = {
+ 'image':
+ tf.keras.layers.Input(
+ shape=input_shape,
+ batch_size=batch_size,
+ name='image',
+ dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32),
+ 'image_info':
+ tf.keras.layers.Input(
+ shape=[4, 2],
+ batch_size=batch_size,
+ name='image_info',
+ ),
+ 'gt_boxes':
+ tf.keras.layers.Input(
+ shape=[params.maskrcnn_parser.max_num_instances, 4],
+ batch_size=batch_size,
+ name='gt_boxes'),
+ 'gt_classes':
+ tf.keras.layers.Input(
+ shape=[params.maskrcnn_parser.max_num_instances],
+ batch_size=batch_size,
+ name='gt_classes',
+ dtype=tf.int64),
+ }
+ if self._include_mask:
+ input_layer['gt_masks'] = tf.keras.layers.Input(
+ shape=[
+ params.maskrcnn_parser.max_num_instances,
+ params.maskrcnn_parser.mask_crop_size,
+ params.maskrcnn_parser.mask_crop_size
+ ],
+ batch_size=batch_size,
+ name='gt_masks')
+ else:
+ batch_size = params.eval.batch_size
+ input_layer = {
+ 'image':
+ tf.keras.layers.Input(
+ shape=input_shape,
+ batch_size=batch_size,
+ name='image',
+ dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32),
+ 'image_info':
+ tf.keras.layers.Input(
+ shape=[4, 2],
+ batch_size=batch_size,
+ name='image_info',
+ ),
+ }
+ return input_layer
+
+ def build_model(self, params, mode):
+ if self._keras_model is None:
+ input_layers = self.build_input_layers(self._params, mode)
+ with backend.get_graph().as_default():
+ outputs = self.model_outputs(input_layers, mode)
+
+ model = tf.keras.models.Model(
+ inputs=input_layers, outputs=outputs, name='maskrcnn')
+ assert model is not None, 'Failed to build tf.keras.Model.'
+ model.optimizer = self.build_optimizer()
+ self._keras_model = model
+
+ return self._keras_model
+
+ def post_processing(self, labels, outputs):
+ required_output_fields = ['class_outputs', 'box_outputs']
+ for field in required_output_fields:
+ if field not in outputs:
+ raise ValueError('"%s" is missing in outputs, required %s found %s'
+ % (field, required_output_fields, outputs.keys()))
+ predictions = {
+ 'image_info': labels['image_info'],
+ 'num_detections': outputs['num_detections'],
+ 'detection_boxes': outputs['detection_boxes'],
+ 'detection_classes': outputs['detection_classes'],
+ 'detection_scores': outputs['detection_scores'],
+ }
+ if self._include_mask:
+ predictions.update({
+ 'detection_masks': outputs['detection_masks'],
+ })
+
+ if 'groundtruths' in labels:
+ predictions['source_id'] = labels['groundtruths']['source_id']
+ predictions['gt_source_id'] = labels['groundtruths']['source_id']
+ predictions['gt_height'] = labels['groundtruths']['height']
+ predictions['gt_width'] = labels['groundtruths']['width']
+ predictions['gt_image_info'] = labels['image_info']
+ predictions['gt_num_detections'] = (
+ labels['groundtruths']['num_detections'])
+ predictions['gt_boxes'] = labels['groundtruths']['boxes']
+ predictions['gt_classes'] = labels['groundtruths']['classes']
+ predictions['gt_areas'] = labels['groundtruths']['areas']
+ predictions['gt_is_crowds'] = labels['groundtruths']['is_crowds']
+ return labels, predictions
+
+ def eval_metrics(self):
+ return eval_factory.evaluator_generator(self._params.eval)
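The function returned by build_loss_fn maps (labels, outputs) to a dictionary of scalar losses; how that dictionary is consumed is handled by the surrounding training framework. Purely as an illustrative, hypothetical sketch (not the base_model API), a custom training step could look like:

    import tensorflow as tf

    def train_step(keras_model, loss_fn, optimizer, inputs, labels):
      with tf.GradientTape() as tape:
        outputs = keras_model(inputs, training=True)
        losses = loss_fn(labels, outputs)  # dict with 'total_loss' and per-head losses
        total_loss = losses['total_loss']
      grads = tape.gradient(total_loss, keras_model.trainable_variables)
      optimizer.apply_gradients(zip(grads, keras_model.trainable_variables))
      return losses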
diff --git a/models/official/vision/detection/modeling/optimizers.py b/models/official/vision/detection/modeling/optimizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd51bb59f579b3de027cba26ef3bee0e67d0c74f
--- /dev/null
+++ b/models/official/vision/detection/modeling/optimizers.py
@@ -0,0 +1,50 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Optimizers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import numpy as np
+import tensorflow as tf
+
+
+class OptimizerFactory(object):
+ """Class to generate optimizer function."""
+
+ def __init__(self, params):
+ """Creates an optimizer based on the specified flags."""
+ if params.type == 'momentum':
+ self._optimizer = functools.partial(
+ tf.keras.optimizers.SGD,
+ momentum=params.momentum,
+ nesterov=params.nesterov)
+ elif params.type == 'adam':
+ self._optimizer = tf.keras.optimizers.Adam
+ elif params.type == 'adadelta':
+ self._optimizer = tf.keras.optimizers.Adadelta
+ elif params.type == 'adagrad':
+ self._optimizer = tf.keras.optimizers.Adagrad
+ elif params.type == 'rmsprop':
+ self._optimizer = functools.partial(
+ tf.keras.optimizers.RMSprop, momentum=params.momentum)
+ else:
+ raise ValueError('Unsupported optimizer type `{}`.'.format(params.type))
+
+ def __call__(self, learning_rate):
+ return self._optimizer(learning_rate=learning_rate)
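A hedged usage sketch of OptimizerFactory; the params object here is a stand-in SimpleNamespace, since the real config schema lives elsewhere in the repository:

    from types import SimpleNamespace

    params = SimpleNamespace(type='momentum', momentum=0.9, nesterov=True)
    factory = OptimizerFactory(params)
    optimizer = factory(learning_rate=0.08)  # a tf.keras.optimizers.SGD instance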
diff --git a/models/official/vision/detection/modeling/retinanet_model.py b/models/official/vision/detection/modeling/retinanet_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff299674f0044cd208a1657a962d133744b78b77
--- /dev/null
+++ b/models/official/vision/detection/modeling/retinanet_model.py
@@ -0,0 +1,170 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Model defination for the RetinaNet Model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.keras import backend
+from official.vision.detection.dataloader import mode_keys
+from official.vision.detection.evaluation import factory as eval_factory
+from official.vision.detection.modeling import base_model
+from official.vision.detection.modeling import losses
+from official.vision.detection.modeling.architecture import factory
+from official.vision.detection.ops import postprocess_ops
+
+
+class RetinanetModel(base_model.Model):
+ """RetinaNet model function."""
+
+ def __init__(self, params):
+ super(RetinanetModel, self).__init__(params)
+
+ # For eval metrics.
+ self._params = params
+
+ # Architecture generators.
+ self._backbone_fn = factory.backbone_generator(params)
+ self._fpn_fn = factory.multilevel_features_generator(params)
+ self._head_fn = factory.retinanet_head_generator(params)
+
+ # Loss function.
+ self._cls_loss_fn = losses.RetinanetClassLoss(
+ params.retinanet_loss, params.architecture.num_classes)
+ self._box_loss_fn = losses.RetinanetBoxLoss(params.retinanet_loss)
+ self._box_loss_weight = params.retinanet_loss.box_loss_weight
+ self._keras_model = None
+
+ # Predict function.
+ self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
+ params.architecture.min_level,
+ params.architecture.max_level,
+ params.postprocess)
+
+ self._transpose_input = params.train.transpose_input
+ assert not self._transpose_input, 'Transpose input is not supported.'
+ # Input layer.
+ input_shape = (
+ params.retinanet_parser.output_size +
+ [params.retinanet_parser.num_channels])
+ self._input_layer = tf.keras.layers.Input(
+ shape=input_shape, name='',
+ dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32)
+
+ def build_outputs(self, inputs, mode):
+ # If the input image is transposed (from NHWC to HWCN), we need to revert it
+ # back to the original shape before it's used in the computation.
+ if self._transpose_input:
+ inputs = tf.transpose(inputs, [3, 0, 1, 2])
+
+ backbone_features = self._backbone_fn(
+ inputs, is_training=(mode == mode_keys.TRAIN))
+ fpn_features = self._fpn_fn(
+ backbone_features, is_training=(mode == mode_keys.TRAIN))
+ cls_outputs, box_outputs = self._head_fn(
+ fpn_features, is_training=(mode == mode_keys.TRAIN))
+
+ if self._use_bfloat16:
+ levels = cls_outputs.keys()
+ for level in levels:
+ cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
+ box_outputs[level] = tf.cast(box_outputs[level], tf.float32)
+
+ model_outputs = {
+ 'cls_outputs': cls_outputs,
+ 'box_outputs': box_outputs,
+ }
+ return model_outputs
+
+ def build_loss_fn(self):
+ if self._keras_model is None:
+ raise ValueError('build_loss_fn() must be called after build_model().')
+
+ filter_fn = self.make_filter_trainable_variables_fn()
+ trainable_variables = filter_fn(self._keras_model.trainable_variables)
+
+ def _total_loss_fn(labels, outputs):
+ cls_loss = self._cls_loss_fn(outputs['cls_outputs'],
+ labels['cls_targets'],
+ labels['num_positives'])
+ box_loss = self._box_loss_fn(outputs['box_outputs'],
+ labels['box_targets'],
+ labels['num_positives'])
+ model_loss = cls_loss + self._box_loss_weight * box_loss
+ l2_regularization_loss = self.weight_decay_loss(trainable_variables)
+ total_loss = model_loss + l2_regularization_loss
+ return {
+ 'total_loss': total_loss,
+ 'cls_loss': cls_loss,
+ 'box_loss': box_loss,
+ 'model_loss': model_loss,
+ 'l2_regularization_loss': l2_regularization_loss,
+ }
+
+ return _total_loss_fn
+
+ def build_model(self, params, mode=None):
+ if self._keras_model is None:
+ with backend.get_graph().as_default():
+ outputs = self.model_outputs(self._input_layer, mode)
+
+ model = tf.keras.models.Model(
+ inputs=self._input_layer, outputs=outputs, name='retinanet')
+ assert model is not None, 'Failed to build tf.keras.Model.'
+ model.optimizer = self.build_optimizer()
+ self._keras_model = model
+
+ return self._keras_model
+
+ def post_processing(self, labels, outputs):
+ # TODO(yeqing): Moves the output related part into build_outputs.
+ required_output_fields = ['cls_outputs', 'box_outputs']
+ for field in required_output_fields:
+ if field not in outputs:
+ raise ValueError('"%s" is missing in outputs, required %s found %s',
+ field, required_output_fields, outputs.keys())
+ required_label_fields = ['image_info', 'groundtruths']
+ for field in required_label_fields:
+ if field not in labels:
+ raise ValueError('"%s" is missing in labels, required %s found %s',
+ field, required_label_fields, labels.keys())
+ boxes, scores, classes, valid_detections = self._generate_detections_fn(
+ outputs['box_outputs'], outputs['cls_outputs'],
+ labels['anchor_boxes'], labels['image_info'][:, 1:2, :])
+ # Discards the old output tensors to save memory. The `cls_outputs` and
+ # `box_outputs` are pretty big and could potentially lead to memory issues.
+ outputs = {
+ 'source_id': labels['groundtruths']['source_id'],
+ 'image_info': labels['image_info'],
+ 'num_detections': valid_detections,
+ 'detection_boxes': boxes,
+ 'detection_classes': classes,
+ 'detection_scores': scores,
+ }
+
+ if 'groundtruths' in labels:
+ labels['source_id'] = labels['groundtruths']['source_id']
+ labels['boxes'] = labels['groundtruths']['boxes']
+ labels['classes'] = labels['groundtruths']['classes']
+ labels['areas'] = labels['groundtruths']['areas']
+ labels['is_crowds'] = labels['groundtruths']['is_crowds']
+
+ return labels, outputs
+
+ def eval_metrics(self):
+ return eval_factory.evaluator_generator(self._params.eval)
diff --git a/models/official/vision/detection/modeling/shapemask_model.py b/models/official/vision/detection/modeling/shapemask_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..174187ed02ae7a7617f259974d64b1906a3d16e0
--- /dev/null
+++ b/models/official/vision/detection/modeling/shapemask_model.py
@@ -0,0 +1,314 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Model definition for the ShapeMask Model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.keras import backend
+from official.vision.detection.dataloader import anchor
+from official.vision.detection.dataloader import mode_keys
+from official.vision.detection.evaluation import factory as eval_factory
+from official.vision.detection.modeling import base_model
+from official.vision.detection.modeling import losses
+from official.vision.detection.modeling.architecture import factory
+from official.vision.detection.ops import postprocess_ops
+from official.vision.detection.utils import box_utils
+
+
+class ShapeMaskModel(base_model.Model):
+ """ShapeMask model function."""
+
+ def __init__(self, params):
+ super(ShapeMaskModel, self).__init__(params)
+
+ self._params = params
+ self._keras_model = None
+
+ # Architecture generators.
+ self._backbone_fn = factory.backbone_generator(params)
+ self._fpn_fn = factory.multilevel_features_generator(params)
+ self._retinanet_head_fn = factory.retinanet_head_generator(params)
+ self._shape_prior_head_fn = factory.shapeprior_head_generator(params)
+ self._coarse_mask_fn = factory.coarsemask_head_generator(params)
+ self._fine_mask_fn = factory.finemask_head_generator(params)
+
+ # Loss functions.
+ self._cls_loss_fn = losses.RetinanetClassLoss(
+ params.retinanet_loss, params.architecture.num_classes)
+ self._box_loss_fn = losses.RetinanetBoxLoss(params.retinanet_loss)
+ self._box_loss_weight = params.retinanet_loss.box_loss_weight
+
+ # Mask loss function.
+ self._shapemask_prior_loss_fn = losses.ShapemaskMseLoss()
+ self._shapemask_loss_fn = losses.ShapemaskLoss()
+ self._shape_prior_loss_weight = (
+ params.shapemask_loss.shape_prior_loss_weight)
+ self._coarse_mask_loss_weight = (
+ params.shapemask_loss.coarse_mask_loss_weight)
+ self._fine_mask_loss_weight = (
+ params.shapemask_loss.fine_mask_loss_weight)
+
+ # Predict function.
+ self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
+ params.architecture.min_level,
+ params.architecture.max_level,
+ params.postprocess)
+
+ def build_outputs(self, inputs, mode):
+ is_training = mode == mode_keys.TRAIN
+ images = inputs['image']
+
+ if 'anchor_boxes' in inputs:
+ anchor_boxes = inputs['anchor_boxes']
+ else:
+ anchor_boxes = anchor.Anchor(
+ self._params.architecture.min_level,
+ self._params.architecture.max_level,
+ self._params.anchor.num_scales,
+ self._params.anchor.aspect_ratios,
+ self._params.anchor.anchor_size,
+ images.get_shape().as_list()[1:3]).multilevel_boxes
+
+ batch_size = tf.shape(images)[0]
+ for level in anchor_boxes:
+ anchor_boxes[level] = tf.tile(
+ tf.expand_dims(anchor_boxes[level], 0), [batch_size, 1, 1, 1])
+
+ backbone_features = self._backbone_fn(images, is_training=is_training)
+ fpn_features = self._fpn_fn(backbone_features, is_training=is_training)
+ cls_outputs, box_outputs = self._retinanet_head_fn(
+ fpn_features, is_training=is_training)
+
+ valid_boxes, valid_scores, valid_classes, valid_detections = (
+ self._generate_detections_fn(box_outputs, cls_outputs,
+ anchor_boxes,
+ inputs['image_info'][:, 1:2, :]))
+
+ image_size = images.get_shape().as_list()[1:3]
+ valid_outer_boxes = box_utils.compute_outer_boxes(
+ tf.reshape(valid_boxes, [-1, 4]),
+ image_size,
+ scale=self._params.shapemask_parser.outer_box_scale)
+ valid_outer_boxes = tf.reshape(valid_outer_boxes, tf.shape(valid_boxes))
+
+ # Wraps the if/else code paths in a layer so the checkpoint remains loadable
+ # in prediction mode.
+ class SampledBoxesLayer(tf.keras.layers.Layer):
+ """Layer that selects sampled boxes in training and detected boxes in eval."""
+
+ def call(self, inputs, val_boxes, val_classes, val_outer_boxes, training):
+ if training:
+ boxes = inputs['mask_boxes']
+ outer_boxes = inputs['mask_outer_boxes']
+ classes = inputs['mask_classes']
+ else:
+ boxes = val_boxes
+ classes = val_classes
+ outer_boxes = val_outer_boxes
+ return boxes, classes, outer_boxes
+
+ boxes, classes, outer_boxes = SampledBoxesLayer()(
+ inputs, valid_boxes, valid_classes,
+ valid_outer_boxes, training=is_training)
+
+ instance_features, prior_masks = self._shape_prior_head_fn(fpn_features,
+ boxes,
+ outer_boxes,
+ classes,
+ is_training)
+ coarse_mask_logits = self._coarse_mask_fn(instance_features,
+ prior_masks,
+ classes,
+ is_training)
+ fine_mask_logits = self._fine_mask_fn(instance_features,
+ coarse_mask_logits,
+ classes,
+ is_training)
+
+ model_outputs = {
+ 'cls_outputs': cls_outputs,
+ 'box_outputs': box_outputs,
+ 'fine_mask_logits': fine_mask_logits,
+ 'coarse_mask_logits': coarse_mask_logits,
+ 'prior_masks': prior_masks,
+ }
+
+ if not is_training:
+ model_outputs.update({
+ 'num_detections': valid_detections,
+ 'detection_boxes': valid_boxes,
+ 'detection_outer_boxes': valid_outer_boxes,
+ 'detection_masks': fine_mask_logits,
+ 'detection_classes': valid_classes,
+ 'detection_scores': valid_scores,
+ })
+
+ return model_outputs
+
+ def build_loss_fn(self):
+ if self._keras_model is None:
+ raise ValueError('build_loss_fn() must be called after build_model().')
+
+ filter_fn = self.make_filter_trainable_variables_fn()
+ trainable_variables = filter_fn(self._keras_model.trainable_variables)
+
+ def _total_loss_fn(labels, outputs):
+ cls_loss = self._cls_loss_fn(outputs['cls_outputs'],
+ labels['cls_targets'],
+ labels['num_positives'])
+ box_loss = self._box_loss_fn(outputs['box_outputs'],
+ labels['box_targets'],
+ labels['num_positives'])
+
+ # Adds Shapemask model losses.
+ shape_prior_loss = self._shapemask_prior_loss_fn(
+ outputs['prior_masks'],
+ labels['mask_targets'],
+ labels['mask_is_valid'])
+ coarse_mask_loss = self._shapemask_loss_fn(
+ outputs['coarse_mask_logits'],
+ labels['mask_targets'],
+ labels['mask_is_valid'])
+ fine_mask_loss = self._shapemask_loss_fn(
+ outputs['fine_mask_logits'],
+ labels['fine_mask_targets'],
+ labels['mask_is_valid'])
+
+ model_loss = (
+ cls_loss + self._box_loss_weight * box_loss +
+ shape_prior_loss * self._shape_prior_loss_weight +
+ coarse_mask_loss * self._coarse_mask_loss_weight +
+ fine_mask_loss * self._fine_mask_loss_weight)
+
+ l2_regularization_loss = self.weight_decay_loss(trainable_variables)
+ total_loss = model_loss + l2_regularization_loss
+
+ shapemask_losses = {
+ 'total_loss': total_loss,
+ 'loss': total_loss,
+ 'retinanet_cls_loss': cls_loss,
+ 'l2_regularization_loss': l2_regularization_loss,
+ 'retinanet_box_loss': box_loss,
+ 'shapemask_prior_loss': shape_prior_loss,
+ 'shapemask_coarse_mask_loss': coarse_mask_loss,
+ 'shapemask_fine_mask_loss': fine_mask_loss,
+ 'model_loss': model_loss,
+ }
+ return shapemask_losses
+
+ return _total_loss_fn
+
+ def build_input_layers(self, params, mode):
+ is_training = mode == mode_keys.TRAIN
+ input_shape = (
+ params.shapemask_parser.output_size +
+ [params.shapemask_parser.num_channels])
+ if is_training:
+ batch_size = params.train.batch_size
+ input_layer = {
+ 'image': tf.keras.layers.Input(
+ shape=input_shape,
+ batch_size=batch_size,
+ name='image',
+ dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32),
+ 'image_info': tf.keras.layers.Input(
+ shape=[4, 2],
+ batch_size=batch_size,
+ name='image_info'),
+ 'mask_classes': tf.keras.layers.Input(
+ shape=[params.shapemask_parser.num_sampled_masks],
+ batch_size=batch_size,
+ name='mask_classes',
+ dtype=tf.int64),
+ 'mask_outer_boxes': tf.keras.layers.Input(
+ shape=[params.shapemask_parser.num_sampled_masks, 4],
+ batch_size=batch_size,
+ name='mask_outer_boxes',
+ dtype=tf.float32),
+ 'mask_boxes': tf.keras.layers.Input(
+ shape=[params.shapemask_parser.num_sampled_masks, 4],
+ batch_size=batch_size,
+ name='mask_boxes',
+ dtype=tf.float32),
+ }
+ else:
+ batch_size = params.eval.batch_size
+ input_layer = {
+ 'image': tf.keras.layers.Input(
+ shape=input_shape,
+ batch_size=batch_size,
+ name='image',
+ dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32),
+ 'image_info': tf.keras.layers.Input(
+ shape=[4, 2],
+ batch_size=batch_size,
+ name='image_info'),
+ }
+ return input_layer
+
+ def build_model(self, params, mode):
+ if self._keras_model is None:
+ input_layers = self.build_input_layers(self._params, mode)
+ with backend.get_graph().as_default():
+ outputs = self.model_outputs(input_layers, mode)
+
+ model = tf.keras.models.Model(
+ inputs=input_layers, outputs=outputs, name='shapemask')
+ assert model is not None, 'Failed to build tf.keras.Model.'
+ model.optimizer = self.build_optimizer()
+ self._keras_model = model
+
+ return self._keras_model
+
+ def post_processing(self, labels, outputs):
+ required_output_fields = ['num_detections', 'detection_boxes',
+ 'detection_classes', 'detection_masks',
+ 'detection_scores']
+
+ for field in required_output_fields:
+ if field not in outputs:
+ raise ValueError(
+ '"{}" is missing in outputs, required {} found {}'.format(
+ field, required_output_fields, outputs.keys()))
+
+ required_label_fields = ['image_info']
+ for field in required_label_fields:
+ if field not in labels:
+ raise ValueError(
+ '"{}" is missing in labels, required {} found {}'.format(
+ field, required_label_fields, labels.keys()))
+
+ predictions = {
+ 'image_info': labels['image_info'],
+ 'num_detections': outputs['num_detections'],
+ 'detection_boxes': outputs['detection_boxes'],
+ 'detection_outer_boxes': outputs['detection_outer_boxes'],
+ 'detection_classes': outputs['detection_classes'],
+ 'detection_scores': outputs['detection_scores'],
+ 'detection_masks': outputs['detection_masks'],
+ }
+
+ if 'groundtruths' in labels:
+ predictions['source_id'] = labels['groundtruths']['source_id']
+ labels = labels['groundtruths']
+
+ return labels, predictions
+
+ def eval_metrics(self):
+ return eval_factory.evaluator_generator(self._params.eval)
diff --git a/models/official/vision/detection/ops/__init__.py b/models/official/vision/detection/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..931c2ef11db4a949e6c2e95bca44e36bac1241e9
--- /dev/null
+++ b/models/official/vision/detection/ops/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/models/official/vision/detection/ops/nms.py b/models/official/vision/detection/ops/nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc516e5991a824b1d2f8e0261750cde2481fda2f
--- /dev/null
+++ b/models/official/vision/detection/ops/nms.py
@@ -0,0 +1,205 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tensorflow implementation of non max suppression."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from official.vision.detection.utils import box_utils
+
+
+NMS_TILE_SIZE = 512
+
+
+def _self_suppression(iou, _, iou_sum):
+ batch_size = tf.shape(iou)[0]
+ can_suppress_others = tf.cast(
+ tf.reshape(tf.reduce_max(iou, 1) <= 0.5, [batch_size, -1, 1]), iou.dtype)
+ iou_suppressed = tf.reshape(
+ tf.cast(tf.reduce_max(can_suppress_others * iou, 1) <= 0.5, iou.dtype),
+ [batch_size, -1, 1]) * iou
+ iou_sum_new = tf.reduce_sum(iou_suppressed, [1, 2])
+ return [
+ iou_suppressed,
+ tf.reduce_any(iou_sum - iou_sum_new > 0.5), iou_sum_new
+ ]
+
+
+def _cross_suppression(boxes, box_slice, iou_threshold, inner_idx):
+ batch_size = tf.shape(boxes)[0]
+ new_slice = tf.slice(boxes, [0, inner_idx * NMS_TILE_SIZE, 0],
+ [batch_size, NMS_TILE_SIZE, 4])
+ iou = box_utils.bbox_overlap(new_slice, box_slice)
+ ret_slice = tf.expand_dims(
+ tf.cast(tf.reduce_all(iou < iou_threshold, [1]), box_slice.dtype),
+ 2) * box_slice
+ return boxes, ret_slice, iou_threshold, inner_idx + 1
+
+
+def _suppression_loop_body(boxes, iou_threshold, output_size, idx):
+ """Process boxes in the range [idx*NMS_TILE_SIZE, (idx+1)*NMS_TILE_SIZE).
+
+ Args:
+ boxes: a tensor with a shape of [batch_size, anchors, 4].
+ iou_threshold: a float representing the threshold for deciding whether boxes
+ overlap too much with respect to IOU.
+ output_size: an int32 tensor of size [batch_size]. Representing the number
+ of selected boxes for each batch.
+ idx: an integer scalar representing induction variable.
+
+ Returns:
+ boxes: updated boxes.
+ iou_threshold: pass down iou_threshold to the next iteration.
+ output_size: the updated output_size.
+ idx: the updated induction variable.
+ """
+ num_tiles = tf.shape(boxes)[1] // NMS_TILE_SIZE
+ batch_size = tf.shape(boxes)[0]
+
+ # Iterates over tiles that can possibly suppress the current tile.
+ box_slice = tf.slice(boxes, [0, idx * NMS_TILE_SIZE, 0],
+ [batch_size, NMS_TILE_SIZE, 4])
+ _, box_slice, _, _ = tf.while_loop(
+ lambda _boxes, _box_slice, _threshold, inner_idx: inner_idx < idx,
+ _cross_suppression, [boxes, box_slice, iou_threshold,
+ tf.constant(0)])
+
+ # Iterates over the current tile to compute self-suppression.
+ iou = box_utils.bbox_overlap(box_slice, box_slice)
+ mask = tf.expand_dims(
+ tf.reshape(tf.range(NMS_TILE_SIZE), [1, -1]) > tf.reshape(
+ tf.range(NMS_TILE_SIZE), [-1, 1]), 0)
+ iou *= tf.cast(tf.logical_and(mask, iou >= iou_threshold), iou.dtype)
+ suppressed_iou, _, _ = tf.while_loop(
+ lambda _iou, loop_condition, _iou_sum: loop_condition, _self_suppression,
+ [iou, tf.constant(True),
+ tf.reduce_sum(iou, [1, 2])])
+ suppressed_box = tf.reduce_sum(suppressed_iou, 1) > 0
+ box_slice *= tf.expand_dims(1.0 - tf.cast(suppressed_box, box_slice.dtype), 2)
+
+ # Uses box_slice to update the input boxes.
+ mask = tf.reshape(
+ tf.cast(tf.equal(tf.range(num_tiles), idx), boxes.dtype), [1, -1, 1, 1])
+ boxes = tf.tile(tf.expand_dims(
+ box_slice, [1]), [1, num_tiles, 1, 1]) * mask + tf.reshape(
+ boxes, [batch_size, num_tiles, NMS_TILE_SIZE, 4]) * (1 - mask)
+ boxes = tf.reshape(boxes, [batch_size, -1, 4])
+
+ # Updates output_size.
+ output_size += tf.reduce_sum(
+ tf.cast(tf.reduce_any(box_slice > 0, [2]), tf.int32), [1])
+ return boxes, iou_threshold, output_size, idx + 1
+
+
+def sorted_non_max_suppression_padded(scores,
+ boxes,
+ max_output_size,
+ iou_threshold):
+ """A wrapper that handles non-maximum suppression.
+
+ Assumption:
+ * The boxes are sorted by scores unless the box is a dot (all coordinates
+ are zero).
+ * Boxes with higher scores can be used to suppress boxes with lower scores.
+
+ The overall design of the algorithm is to handle boxes tile-by-tile:
+
+ boxes = boxes.pad_to_multiple_of(tile_size)
+ num_tiles = len(boxes) // tile_size
+ output_boxes = []
+ for i in range(num_tiles):
+ box_tile = boxes[i*tile_size : (i+1)*tile_size]
+ for j in range(i - 1):
+ suppressing_tile = boxes[j*tile_size : (j+1)*tile_size]
+ iou = bbox_overlap(box_tile, suppressing_tile)
+ # if the box is suppressed in iou, clear it to a dot
+ box_tile *= _update_boxes(iou)
+ # Iteratively handle the diagonal tile.
+ iou = _box_overlap(box_tile, box_tile)
+ iou_changed = True
+ while iou_changed:
+ # boxes that are not suppressed by anything else
+ suppressing_boxes = _get_suppressing_boxes(iou)
+ # boxes that are suppressed by suppressing_boxes
+ suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes)
+ # clear iou to 0 for boxes that are suppressed, as they cannot be used
+ # to suppress other boxes any more
+ new_iou = _clear_iou(iou, suppressed_boxes)
+ iou_changed = (new_iou != iou)
+ iou = new_iou
+ # Remaining boxes that can still suppress others are the selected boxes.
+ output_boxes.append(_get_suppressing_boxes(iou))
+ if len(output_boxes) >= max_output_size:
+ break
+
+ Args:
+ scores: a tensor with a shape of [batch_size, anchors].
+ boxes: a tensor with a shape of [batch_size, anchors, 4].
+ max_output_size: a scalar integer `Tensor` representing the maximum number
+ of boxes to be selected by non max suppression.
+ iou_threshold: a float representing the threshold for deciding whether boxes
+ overlap too much with respect to IOU.
+
+ Returns:
+ nms_scores: a tensor with a shape of [batch_size, anchors]. It has same
+ dtype as input scores.
+ nms_proposals: a tensor with a shape of [batch_size, anchors, 4]. It has
+ same dtype as input boxes.
+ """
+ batch_size = tf.shape(boxes)[0]
+ num_boxes = tf.shape(boxes)[1]
+ pad = tf.cast(
+ tf.math.ceil(tf.cast(num_boxes, tf.float32) / NMS_TILE_SIZE),
+ tf.int32) * NMS_TILE_SIZE - num_boxes
+ boxes = tf.pad(tf.cast(boxes, tf.float32), [[0, 0], [0, pad], [0, 0]])
+ scores = tf.pad(
+ tf.cast(scores, tf.float32), [[0, 0], [0, pad]], constant_values=-1)
+ num_boxes += pad
+
+ def _loop_cond(unused_boxes, unused_threshold, output_size, idx):
+ return tf.logical_and(
+ tf.reduce_min(output_size) < max_output_size,
+ idx < num_boxes // NMS_TILE_SIZE)
+
+ selected_boxes, _, output_size, _ = tf.while_loop(
+ _loop_cond, _suppression_loop_body, [
+ boxes, iou_threshold,
+ tf.zeros([batch_size], tf.int32),
+ tf.constant(0)
+ ])
+ idx = num_boxes - tf.cast(
+ tf.nn.top_k(
+ tf.cast(tf.reduce_any(selected_boxes > 0, [2]), tf.int32) *
+ tf.expand_dims(tf.range(num_boxes, 0, -1), 0), max_output_size)[0],
+ tf.int32)
+ idx = tf.minimum(idx, num_boxes - 1)
+ idx = tf.reshape(
+ idx + tf.reshape(tf.range(batch_size) * num_boxes, [-1, 1]), [-1])
+ boxes = tf.reshape(
+ tf.gather(tf.reshape(boxes, [-1, 4]), idx),
+ [batch_size, max_output_size, 4])
+ boxes = boxes * tf.cast(
+ tf.reshape(tf.range(max_output_size), [1, -1, 1]) < tf.reshape(
+ output_size, [-1, 1, 1]), boxes.dtype)
+ scores = tf.reshape(
+ tf.gather(tf.reshape(scores, [-1, 1]), idx),
+ [batch_size, max_output_size])
+ scores = scores * tf.cast(
+ tf.reshape(tf.range(max_output_size), [1, -1]) < tf.reshape(
+ output_size, [-1, 1]), scores.dtype)
+ return scores, boxes
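A hedged usage sketch of sorted_non_max_suppression_padded with toy inputs (it assumes box_utils is importable and that scores are already sorted in descending order, as the docstring requires):

    import tensorflow as tf

    scores = tf.constant([[0.9, 0.8, 0.7, 0.1]])
    boxes = tf.constant([[[0., 0., 10., 10.],
                          [1., 1., 11., 11.],     # heavily overlaps the first box
                          [20., 20., 30., 30.],
                          [0., 0., 0., 0.]]])     # an all-zero "dot" box
    nms_scores, nms_boxes = sorted_non_max_suppression_padded(
        scores, boxes, max_output_size=3, iou_threshold=0.5)
    # nms_boxes: [1, 3, 4]; entries beyond the valid detection count come back as zeros.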
diff --git a/models/official/vision/detection/ops/postprocess_ops.py b/models/official/vision/detection/ops/postprocess_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cb06c34ab114d171f30cb52e69d8dc73996e302
--- /dev/null
+++ b/models/official/vision/detection/ops/postprocess_ops.py
@@ -0,0 +1,413 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Post-processing model outputs to generate detection."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import tensorflow as tf
+
+from official.vision.detection.ops import nms
+from official.vision.detection.utils import box_utils
+
+
+def generate_detections_factory(params):
+ """Factory to select function to generate detection."""
+ if params.use_batched_nms:
+ func = functools.partial(
+ _generate_detections_batched,
+ max_total_size=params.max_total_size,
+ nms_iou_threshold=params.nms_iou_threshold,
+ score_threshold=params.score_threshold)
+ else:
+ func = functools.partial(
+ _generate_detections,
+ max_total_size=params.max_total_size,
+ nms_iou_threshold=params.nms_iou_threshold,
+ score_threshold=params.score_threshold,
+ pre_nms_num_boxes=params.pre_nms_num_boxes)
+ return func
+
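A hedged sketch of driving generate_detections_factory with a hypothetical params namespace; the field names simply mirror the attributes read above:

    from types import SimpleNamespace

    params = SimpleNamespace(
        use_batched_nms=False,
        max_total_size=100,
        nms_iou_threshold=0.5,
        score_threshold=0.05,
        pre_nms_num_boxes=5000)
    generate_fn = generate_detections_factory(params)
    # generate_fn(boxes, scores) -> (nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections)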
+
+def _select_top_k_scores(scores_in, pre_nms_num_detections):
+ """Select top_k scores and indices for each class.
+
+ Args:
+ scores_in: a Tensor with shape [batch_size, N, num_classes], which stacks
+ class logit outputs on all feature levels. The N is the number of total
+ anchors on all levels. The num_classes is the number of classes predicted
+ by the model.
+ pre_nms_num_detections: Number of candidates before NMS.
+
+ Returns:
+ scores and indices: Tensors with shape [batch_size, pre_nms_num_detections,
+ num_classes].
+ """
+ batch_size, num_anchors, num_class = scores_in.get_shape().as_list()
+ scores_trans = tf.transpose(scores_in, perm=[0, 2, 1])
+ scores_trans = tf.reshape(scores_trans, [-1, num_anchors])
+
+ top_k_scores, top_k_indices = tf.nn.top_k(
+ scores_trans, k=pre_nms_num_detections, sorted=True)
+
+ top_k_scores = tf.reshape(top_k_scores,
+ [batch_size, num_class, pre_nms_num_detections])
+ top_k_indices = tf.reshape(top_k_indices,
+ [batch_size, num_class, pre_nms_num_detections])
+
+ return tf.transpose(top_k_scores,
+ [0, 2, 1]), tf.transpose(top_k_indices, [0, 2, 1])
+
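A quick shape check of _select_top_k_scores with toy numbers (assuming this module is imported as-is):

    import tensorflow as tf

    scores_in = tf.random.uniform([2, 1000, 91])   # [batch, anchors, classes]
    top_scores, top_indices = _select_top_k_scores(scores_in, 100)
    # Both outputs have shape [2, 100, 91]: the per-class top-100 anchors for each image.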
+
+def _generate_detections(boxes,
+ scores,
+ max_total_size=100,
+ nms_iou_threshold=0.3,
+ score_threshold=0.05,
+ pre_nms_num_boxes=5000):
+ """Generate the final detections given the model outputs.
+
+ This unrolls the classes and uses while-loop based NMS, which can be
+ parallelized across the batch dimension.
+
+ Args:
+ boxes: a tensor with shape [batch_size, N, num_classes, 4] or [batch_size,
+ N, 1, 4], which stacks box predictions on all feature levels. The N is the number
+ of total anchors on all levels.
+ scores: a tensor with shape [batch_size, N, num_classes], which stacks class
+ probability on all feature levels. The N is the number of total anchors on
+ all levels. The num_classes is the number of classes predicted by the
+ model. Note that the class_outputs here is the raw score.
+ max_total_size: a scalar representing maximum number of boxes retained over
+ all classes.
+ nms_iou_threshold: a float representing the threshold for deciding whether
+ boxes overlap too much with respect to IOU.
+ score_threshold: a float representing the threshold for deciding when to
+ remove boxes based on score.
+ pre_nms_num_boxes: an int number of top candidate detections per class
+ before NMS.
+
+ Returns:
+ nms_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
+ representing top detected boxes in [y1, x1, y2, x2].
+ nms_scores: `float` Tensor of shape [batch_size, max_total_size]
+ representing sorted confidence scores for detected boxes. The values are
+ between [0, 1].
+ nms_classes: `int` Tensor of shape [batch_size, max_total_size] representing
+ classes for detected boxes.
+ valid_detections: `int` Tensor of shape [batch_size] only the top
+ `valid_detections` boxes are valid detections.
+ """
+ with tf.name_scope('generate_detections'):
+ nmsed_boxes = []
+ nmsed_classes = []
+ nmsed_scores = []
+ valid_detections = []
+ batch_size, _, num_classes_for_box, _ = boxes.get_shape().as_list()
+ _, total_anchors, num_classes = scores.get_shape().as_list()
+ # Selects top pre_nms_num scores and indices before NMS.
+ scores, indices = _select_top_k_scores(
+ scores, min(total_anchors, pre_nms_num_boxes))
+ for i in range(num_classes):
+ boxes_i = boxes[:, :, min(num_classes_for_box - 1, i), :]
+ scores_i = scores[:, :, i]
+ # Obtains pre_nms_num_boxes before running NMS.
+ boxes_i = tf.gather(boxes_i, indices[:, :, i], batch_dims=1, axis=1)
+
+ # Filter out scores.
+ boxes_i, scores_i = box_utils.filter_boxes_by_scores(
+ boxes_i, scores_i, min_score_threshold=score_threshold)
+
+ (nmsed_scores_i, nmsed_boxes_i) = nms.sorted_non_max_suppression_padded(
+ tf.cast(scores_i, tf.float32),
+ tf.cast(boxes_i, tf.float32),
+ max_total_size,
+ iou_threshold=nms_iou_threshold)
+ nmsed_classes_i = tf.fill([batch_size, max_total_size], i)
+ nmsed_boxes.append(nmsed_boxes_i)
+ nmsed_scores.append(nmsed_scores_i)
+ nmsed_classes.append(nmsed_classes_i)
+ nmsed_boxes = tf.concat(nmsed_boxes, axis=1)
+ nmsed_scores = tf.concat(nmsed_scores, axis=1)
+ nmsed_classes = tf.concat(nmsed_classes, axis=1)
+ nmsed_scores, indices = tf.nn.top_k(
+ nmsed_scores, k=max_total_size, sorted=True)
+ nmsed_boxes = tf.gather(nmsed_boxes, indices, batch_dims=1, axis=1)
+ nmsed_classes = tf.gather(nmsed_classes, indices, batch_dims=1)
+ valid_detections = tf.reduce_sum(
+ input_tensor=tf.cast(tf.greater(nmsed_scores, -1), tf.int32), axis=1)
+ return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
+
+
+def _generate_detections_per_image(boxes,
+ scores,
+ max_total_size=100,
+ nms_iou_threshold=0.3,
+ score_threshold=0.05,
+ pre_nms_num_boxes=5000):
+ """Generate the final detections per image given the model outputs.
+
+ Args:
+ boxes: a tensor with shape [N, num_classes, 4] or [N, 1, 4], which stacks box
+ predictions on all feature levels. The N is the number of total anchors on
+ all levels.
+ scores: a tensor with shape [N, num_classes], which stacks class probability
+ on all feature levels. The N is the number of total anchors on all levels.
+ The num_classes is the number of classes predicted by the model. Note that
+ the class_outputs here is the raw score.
+ max_total_size: a scalar representing maximum number of boxes retained over
+ all classes.
+ nms_iou_threshold: a float representing the threshold for deciding whether
+ boxes overlap too much with respect to IOU.
+ score_threshold: a float representing the threshold for deciding when to
+ remove boxes based on score.
+ pre_nms_num_boxes: an int number of top candidate detections per class
+ before NMS.
+
+ Returns:
+ nms_boxes: `float` Tensor of shape [max_total_size, 4] representing top
+ detected boxes in [y1, x1, y2, x2].
+ nms_scores: `float` Tensor of shape [max_total_size] representing sorted
+ confidence scores for detected boxes. The values are between [0, 1].
+ nms_classes: `int` Tensor of shape [max_total_size] representing classes for
+ detected boxes.
+ valid_detections: `int` Tensor of shape [1] only the top `valid_detections`
+ boxes are valid detections.
+ """
+ nmsed_boxes = []
+ nmsed_scores = []
+ nmsed_classes = []
+ num_classes_for_box = boxes.get_shape().as_list()[1]
+ num_classes = scores.get_shape().as_list()[1]
+ for i in range(num_classes):
+ boxes_i = boxes[:, min(num_classes_for_box - 1, i)]
+ scores_i = scores[:, i]
+
+ # Obtains pre_nms_num_boxes before running NMS.
+ scores_i, indices = tf.nn.top_k(
+ scores_i, k=tf.minimum(tf.shape(input=scores_i)[-1], pre_nms_num_boxes))
+ boxes_i = tf.gather(boxes_i, indices)
+
+ (nmsed_indices_i,
+ nmsed_num_valid_i) = tf.image.non_max_suppression_padded(
+ tf.cast(boxes_i, tf.float32),
+ tf.cast(scores_i, tf.float32),
+ max_total_size,
+ iou_threshold=nms_iou_threshold,
+ score_threshold=score_threshold,
+ pad_to_max_output_size=True,
+ name='nms_detections_' + str(i))
+ nmsed_boxes_i = tf.gather(boxes_i, nmsed_indices_i)
+ nmsed_scores_i = tf.gather(scores_i, nmsed_indices_i)
+ # Sets scores of invalid boxes to -1.
+ nmsed_scores_i = tf.where(
+ tf.less(tf.range(max_total_size), [nmsed_num_valid_i]), nmsed_scores_i,
+ -tf.ones_like(nmsed_scores_i))
+ nmsed_classes_i = tf.fill([max_total_size], i)
+ nmsed_boxes.append(nmsed_boxes_i)
+ nmsed_scores.append(nmsed_scores_i)
+ nmsed_classes.append(nmsed_classes_i)
+
+ # Concats results from all classes and sort them.
+ nmsed_boxes = tf.concat(nmsed_boxes, axis=0)
+ nmsed_scores = tf.concat(nmsed_scores, axis=0)
+ nmsed_classes = tf.concat(nmsed_classes, axis=0)
+ nmsed_scores, indices = tf.nn.top_k(
+ nmsed_scores, k=max_total_size, sorted=True)
+ nmsed_boxes = tf.gather(nmsed_boxes, indices)
+ nmsed_classes = tf.gather(nmsed_classes, indices)
+ valid_detections = tf.reduce_sum(
+ input_tensor=tf.cast(tf.greater(nmsed_scores, -1), tf.int32))
+ return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
+
+
+def _generate_detections_batched(boxes,
+ scores,
+ max_total_size,
+ nms_iou_threshold,
+ score_threshold):
+ """Generates detected boxes with scores and classes for one-stage detector.
+
+ The function takes output of multi-level ConvNets and anchor boxes and
+ generates detected boxes. Note that this uses batched NMS, which is not
+ supported on TPU currently.
+
+ Args:
+ boxes: a tensor with shape [batch_size, N, num_classes, 4] or
+ [batch_size, N, 1, 4], which stacks box predictions on all feature levels. The N
+ is the number of total anchors on all levels.
+ scores: a tensor with shape [batch_size, N, num_classes], which
+ stacks class probability on all feature levels. The N is the number of
+ total anchors on all levels. The num_classes is the number of classes
+ predicted by the model. Note that the class_outputs here is the raw score.
+ max_total_size: a scalar representing maximum number of boxes retained over
+ all classes.
+ nms_iou_threshold: a float representing the threshold for deciding whether
+ boxes overlap too much with respect to IOU.
+ score_threshold: a float representing the threshold for deciding when to
+ remove boxes based on score.
+ Returns:
+ nms_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
+ representing top detected boxes in [y1, x1, y2, x2].
+ nms_scores: `float` Tensor of shape [batch_size, max_total_size]
+ representing sorted confidence scores for detected boxes. The values are
+ between [0, 1].
+ nms_classes: `int` Tensor of shape [batch_size, max_total_size] representing
+ classes for detected boxes.
+ valid_detections: `int` Tensor of shape [batch_size] only the top
+ `valid_detections` boxes are valid detections.
+ """
+ with tf.name_scope('generate_detections'):
+ # TODO(tsungyi): Removes normalization/denormalization once the
+ # tf.image.combined_non_max_suppression is coordinate system agnostic.
+ # Normalizes maximum box coordinates to 1.
+ normalizer = tf.reduce_max(boxes)
+ boxes /= normalizer
+ (nmsed_boxes, nmsed_scores, nmsed_classes,
+ valid_detections) = tf.image.combined_non_max_suppression(
+ boxes,
+ scores,
+ max_output_size_per_class=max_total_size,
+ max_total_size=max_total_size,
+ iou_threshold=nms_iou_threshold,
+ score_threshold=score_threshold,
+ pad_per_class=False,)
+ # De-normalizes box coordinates.
+ nmsed_boxes *= normalizer
+ nmsed_classes = tf.cast(nmsed_classes, tf.int32)
+ return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
+
+
+class MultilevelDetectionGenerator(object):
+ """Generates detected boxes with scores and classes for one-stage detector."""
+
+ def __init__(self, min_level, max_level, params):
+ self._min_level = min_level
+ self._max_level = max_level
+ self._generate_detections = generate_detections_factory(params)
+
+ def __call__(self, box_outputs, class_outputs, anchor_boxes, image_shape):
+ # Collects outputs from all levels into a list.
+ boxes = []
+ scores = []
+ for i in range(self._min_level, self._max_level + 1):
+ box_outputs_i_shape = tf.shape(box_outputs[i])
+ batch_size = box_outputs_i_shape[0]
+ num_anchors_per_locations = box_outputs_i_shape[-1] // 4
+ num_classes = tf.shape(class_outputs[i])[-1] // num_anchors_per_locations
+
+      # Applies score transformation and removes the implicit background class.
+ scores_i = tf.sigmoid(
+ tf.reshape(class_outputs[i], [batch_size, -1, num_classes]))
+ scores_i = tf.slice(scores_i, [0, 0, 1], [-1, -1, -1])
+
+ # Box decoding.
+ # The anchor boxes are shared for all data in a batch.
+ # One stage detector only supports class agnostic box regression.
+ anchor_boxes_i = tf.reshape(anchor_boxes[i], [batch_size, -1, 4])
+ box_outputs_i = tf.reshape(box_outputs[i], [batch_size, -1, 4])
+ boxes_i = box_utils.decode_boxes(box_outputs_i, anchor_boxes_i)
+
+ # Box clipping.
+ boxes_i = box_utils.clip_boxes(boxes_i, image_shape)
+
+ boxes.append(boxes_i)
+ scores.append(scores_i)
+ boxes = tf.concat(boxes, axis=1)
+ scores = tf.concat(scores, axis=1)
+
+ nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
+ self._generate_detections(tf.expand_dims(boxes, axis=2), scores))
+
+ # Adds 1 to offset the background class which has index 0.
+ nmsed_classes += 1
+ return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
+
+
+class GenericDetectionGenerator(object):
+ """Generates the final detected boxes with scores and classes."""
+
+ def __init__(self, params):
+ self._generate_detections = generate_detections_factory(params)
+
+ def __call__(self, box_outputs, class_outputs, anchor_boxes, image_shape):
+ """Generate final detections.
+
+ Args:
+ box_outputs: a tensor of shape of [batch_size, K, num_classes * 4]
+ representing the class-specific box coordinates relative to anchors.
+ class_outputs: a tensor of shape of [batch_size, K, num_classes]
+        representing the class logits before applying score activation.
+ anchor_boxes: a tensor of shape of [batch_size, K, 4] representing the
+ corresponding anchor boxes w.r.t `box_outputs`.
+ image_shape: a tensor of shape of [batch_size, 2] storing the image height
+ and width w.r.t. the scaled image, i.e. the same image space as
+ `box_outputs` and `anchor_boxes`.
+
+ Returns:
+ nms_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
+ representing top detected boxes in [y1, x1, y2, x2].
+ nms_scores: `float` Tensor of shape [batch_size, max_total_size]
+ representing sorted confidence scores for detected boxes. The values are
+ between [0, 1].
+ nms_classes: `int` Tensor of shape [batch_size, max_total_size]
+ representing classes for detected boxes.
+ valid_detections: `int` Tensor of shape [batch_size] only the top
+ `valid_detections` boxes are valid detections.
+ """
+ class_outputs = tf.nn.softmax(class_outputs, axis=-1)
+
+ # Removes the background class.
+ class_outputs_shape = tf.shape(class_outputs)
+ batch_size = class_outputs_shape[0]
+ num_locations = class_outputs_shape[1]
+ num_classes = class_outputs_shape[-1]
+ num_detections = num_locations * (num_classes - 1)
+
+ class_outputs = tf.slice(class_outputs, [0, 0, 1], [-1, -1, -1])
+ box_outputs = tf.reshape(
+ box_outputs,
+ tf.stack([batch_size, num_locations, num_classes, 4], axis=-1))
+ box_outputs = tf.slice(
+ box_outputs, [0, 0, 1, 0], [-1, -1, -1, -1])
+ anchor_boxes = tf.tile(
+ tf.expand_dims(anchor_boxes, axis=2), [1, 1, num_classes - 1, 1])
+ box_outputs = tf.reshape(
+ box_outputs,
+ tf.stack([batch_size, num_detections, 4], axis=-1))
+ anchor_boxes = tf.reshape(
+ anchor_boxes,
+ tf.stack([batch_size, num_detections, 4], axis=-1))
+
+ # Box decoding.
+ decoded_boxes = box_utils.decode_boxes(
+ box_outputs, anchor_boxes, weights=[10.0, 10.0, 5.0, 5.0])
+
+ # Box clipping
+ decoded_boxes = box_utils.clip_boxes(decoded_boxes, image_shape)
+
+ decoded_boxes = tf.reshape(
+ decoded_boxes,
+ tf.stack([batch_size, num_locations, num_classes - 1, 4], axis=-1))
+
+ nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
+ self._generate_detections(decoded_boxes, class_outputs))
+
+ # Adds 1 to offset the background class which has index 0.
+ nmsed_classes += 1
+
+ return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
diff --git a/models/official/vision/detection/ops/roi_ops.py b/models/official/vision/detection/ops/roi_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..a21bc7b2882de39b12bc76dacd37047fabac1766
--- /dev/null
+++ b/models/official/vision/detection/ops/roi_ops.py
@@ -0,0 +1,237 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ROI-related ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from official.vision.detection.ops import nms
+from official.vision.detection.utils import box_utils
+
+
+def multilevel_propose_rois(rpn_boxes,
+ rpn_scores,
+ anchor_boxes,
+ image_shape,
+ rpn_pre_nms_top_k=2000,
+ rpn_post_nms_top_k=1000,
+ rpn_nms_threshold=0.7,
+ rpn_score_threshold=0.0,
+ rpn_min_size_threshold=0.0,
+ decode_boxes=True,
+ clip_boxes=True,
+ use_batched_nms=False,
+ apply_sigmoid_to_score=True):
+ """Proposes RoIs given a group of candidates from different FPN levels.
+
+ The following describes the steps:
+ 1. For each individual level:
+ a. Apply sigmoid transform if specified.
+ b. Decode boxes if specified.
+ c. Clip boxes if specified.
+      d. Filter small boxes and those that fall outside the image if specified.
+ e. Apply pre-NMS filtering including pre-NMS top k and score thresholding.
+ f. Apply NMS.
+ 2. Aggregate post-NMS boxes from each level.
+ 3. Apply an overall top k to generate the final selected RoIs.
+
+ Args:
+ rpn_boxes: a dict with keys representing FPN levels and values representing
+      box tensors of shape [batch_size, feature_h, feature_w, num_anchors * 4].
+ rpn_scores: a dict with keys representing FPN levels and values representing
+ logit tensors of shape [batch_size, feature_h, feature_w, num_anchors].
+ anchor_boxes: a dict with keys representing FPN levels and values
+ representing anchor box tensors of shape
+ [batch_size, feature_h, feature_w, num_anchors * 4].
+    image_shape: a tensor of shape [batch_size, 2] where the last dimension is
+      [height, width] of the scaled image.
+ rpn_pre_nms_top_k: an integer of top scoring RPN proposals *per level* to
+ keep before applying NMS. Default: 2000.
+ rpn_post_nms_top_k: an integer of top scoring RPN proposals *in total* to
+ keep after applying NMS. Default: 1000.
+ rpn_nms_threshold: a float between 0 and 1 representing the IoU threshold
+ used for NMS. If 0.0, no NMS is applied. Default: 0.7.
+ rpn_score_threshold: a float between 0 and 1 representing the minimal box
+ score to keep before applying NMS. This is often used as a pre-filtering
+ step for better performance. If 0, no filtering is applied. Default: 0.
+ rpn_min_size_threshold: a float representing the minimal box size in each
+ side (w.r.t. the scaled image) to keep before applying NMS. This is often
+ used as a pre-filtering step for better performance. If 0, no filtering is
+ applied. Default: 0.
+ decode_boxes: a boolean indicating whether `rpn_boxes` needs to be decoded
+ using `anchor_boxes`. If False, use `rpn_boxes` directly and ignore
+ `anchor_boxes`. Default: True.
+ clip_boxes: a boolean indicating whether boxes are first clipped to the
+      scaled image size before applying NMS. If False, no clipping is applied
+ and `image_shape` is ignored. Default: True.
+ use_batched_nms: a boolean indicating whether NMS is applied in batch using
+ `tf.image.combined_non_max_suppression`. Currently only available in
+ CPU/GPU. Default: False.
+ apply_sigmoid_to_score: a boolean indicating whether apply sigmoid to
+ `rpn_scores` before applying NMS. Default: True.
+
+ Returns:
+ selected_rois: a tensor of shape [batch_size, rpn_post_nms_top_k, 4],
+ representing the box coordinates of the selected proposals w.r.t. the
+ scaled image.
+ selected_roi_scores: a tensor of shape [batch_size, rpn_post_nms_top_k, 1],
+ representing the scores of the selected proposals.
+ """
+ with tf.name_scope('multilevel_propose_rois'):
+ rois = []
+ roi_scores = []
+ image_shape = tf.expand_dims(image_shape, axis=1)
+ for level in sorted(rpn_scores.keys()):
+ with tf.name_scope('level_%d' % level):
+ _, feature_h, feature_w, num_anchors_per_location = (
+ rpn_scores[level].get_shape().as_list())
+
+ num_boxes = feature_h * feature_w * num_anchors_per_location
+ this_level_scores = tf.reshape(rpn_scores[level], [-1, num_boxes])
+ this_level_boxes = tf.reshape(rpn_boxes[level], [-1, num_boxes, 4])
+ this_level_anchors = tf.cast(
+ tf.reshape(anchor_boxes[level], [-1, num_boxes, 4]),
+ dtype=this_level_scores.dtype)
+
+ if apply_sigmoid_to_score:
+ this_level_scores = tf.sigmoid(this_level_scores)
+
+ if decode_boxes:
+ this_level_boxes = box_utils.decode_boxes(
+ this_level_boxes, this_level_anchors)
+ if clip_boxes:
+ this_level_boxes = box_utils.clip_boxes(
+ this_level_boxes, image_shape)
+
+ if rpn_min_size_threshold > 0.0:
+ this_level_boxes, this_level_scores = box_utils.filter_boxes(
+ this_level_boxes,
+ this_level_scores,
+ image_shape,
+ rpn_min_size_threshold)
+
+ this_level_pre_nms_top_k = min(num_boxes, rpn_pre_nms_top_k)
+ this_level_post_nms_top_k = min(num_boxes, rpn_post_nms_top_k)
+ if rpn_nms_threshold > 0.0:
+ if use_batched_nms:
+ this_level_rois, this_level_roi_scores, _, _ = (
+ tf.image.combined_non_max_suppression(
+ tf.expand_dims(this_level_boxes, axis=2),
+ tf.expand_dims(this_level_scores, axis=-1),
+ max_output_size_per_class=this_level_pre_nms_top_k,
+ max_total_size=this_level_post_nms_top_k,
+ iou_threshold=rpn_nms_threshold,
+ score_threshold=rpn_score_threshold,
+ pad_per_class=False,
+ clip_boxes=False))
+ else:
+ if rpn_score_threshold > 0.0:
+ this_level_boxes, this_level_scores = (
+ box_utils.filter_boxes_by_scores(
+ this_level_boxes, this_level_scores, rpn_score_threshold))
+ this_level_boxes, this_level_scores = box_utils.top_k_boxes(
+ this_level_boxes, this_level_scores, k=this_level_pre_nms_top_k)
+ this_level_roi_scores, this_level_rois = (
+ nms.sorted_non_max_suppression_padded(
+ this_level_scores,
+ this_level_boxes,
+ max_output_size=this_level_post_nms_top_k,
+ iou_threshold=rpn_nms_threshold))
+ else:
+          # No NMS at this level: directly keep the top-k scoring boxes.
+          this_level_rois, this_level_roi_scores = box_utils.top_k_boxes(
+              this_level_boxes,
+              this_level_scores,
+              k=this_level_post_nms_top_k)
+
+ rois.append(this_level_rois)
+ roi_scores.append(this_level_roi_scores)
+
+ all_rois = tf.concat(rois, axis=1)
+ all_roi_scores = tf.concat(roi_scores, axis=1)
+
+ with tf.name_scope('top_k_rois'):
+ _, num_valid_rois = all_roi_scores.get_shape().as_list()
+ overall_top_k = min(num_valid_rois, rpn_post_nms_top_k)
+
+ selected_rois, selected_roi_scores = box_utils.top_k_boxes(
+ all_rois, all_roi_scores, k=overall_top_k)
+
+ return selected_rois, selected_roi_scores
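+
+
+# A minimal usage sketch (editorial, assuming eager TF 2.x and two tiny dummy
+# FPN levels). The dict keys are FPN levels and the shapes follow the docstring
+# above; the random logits and anchors are placeholders that only show how the
+# inputs are wired (real anchor boxes would come from the anchor generator).
+#
+#   rpn_scores = {2: tf.random.normal([1, 8, 8, 3]),
+#                 3: tf.random.normal([1, 4, 4, 3])}
+#   rpn_boxes = {2: tf.random.normal([1, 8, 8, 12]),
+#                3: tf.random.normal([1, 4, 4, 12])}
+#   anchor_boxes = {2: tf.random.uniform([1, 8, 8, 12], maxval=32.),
+#                   3: tf.random.uniform([1, 4, 4, 12], maxval=32.)}
+#   image_shape = tf.constant([[64., 64.]])
+#   rois, roi_scores = multilevel_propose_rois(
+#       rpn_boxes, rpn_scores, anchor_boxes, image_shape,
+#       rpn_pre_nms_top_k=50, rpn_post_nms_top_k=20)
+#   # rois: [1, 20, 4] proposals in the scaled-image coordinate frame.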
+
+
+class ROIGenerator(object):
+ """Proposes RoIs for the second stage processing."""
+
+ def __init__(self, params):
+ self._rpn_pre_nms_top_k = params.rpn_pre_nms_top_k
+ self._rpn_post_nms_top_k = params.rpn_post_nms_top_k
+ self._rpn_nms_threshold = params.rpn_nms_threshold
+ self._rpn_score_threshold = params.rpn_score_threshold
+ self._rpn_min_size_threshold = params.rpn_min_size_threshold
+ self._test_rpn_pre_nms_top_k = params.test_rpn_pre_nms_top_k
+ self._test_rpn_post_nms_top_k = params.test_rpn_post_nms_top_k
+ self._test_rpn_nms_threshold = params.test_rpn_nms_threshold
+ self._test_rpn_score_threshold = params.test_rpn_score_threshold
+ self._test_rpn_min_size_threshold = params.test_rpn_min_size_threshold
+ self._use_batched_nms = params.use_batched_nms
+
+ def __call__(self, boxes, scores, anchor_boxes, image_shape, is_training):
+ """Generates RoI proposals.
+
+ Args:
+ boxes: a dict with keys representing FPN levels and values representing
+        box tensors of shape [batch_size, feature_h, feature_w, num_anchors * 4].
+ scores: a dict with keys representing FPN levels and values representing
+ logit tensors of shape [batch_size, feature_h, feature_w, num_anchors].
+ anchor_boxes: a dict with keys representing FPN levels and values
+ representing anchor box tensors of shape
+ [batch_size, feature_h, feature_w, num_anchors * 4].
+      image_shape: a tensor of shape [batch_size, 2] where the last dimension
+        is [height, width] of the scaled image.
+ is_training: a bool indicating whether it is in training or inference
+ mode.
+
+ Returns:
+ proposed_rois: a tensor of shape [batch_size, rpn_post_nms_top_k, 4],
+ representing the box coordinates of the proposed RoIs w.r.t. the
+ scaled image.
+ proposed_roi_scores: a tensor of shape
+ [batch_size, rpn_post_nms_top_k, 1], representing the scores of the
+ proposed RoIs.
+
+ """
+ proposed_rois, proposed_roi_scores = multilevel_propose_rois(
+ boxes,
+ scores,
+ anchor_boxes,
+ image_shape,
+ rpn_pre_nms_top_k=(self._rpn_pre_nms_top_k if is_training
+ else self._test_rpn_pre_nms_top_k),
+ rpn_post_nms_top_k=(self._rpn_post_nms_top_k if is_training
+ else self._test_rpn_post_nms_top_k),
+ rpn_nms_threshold=(self._rpn_nms_threshold if is_training
+ else self._test_rpn_nms_threshold),
+ rpn_score_threshold=(self._rpn_score_threshold if is_training
+ else self._test_rpn_score_threshold),
+ rpn_min_size_threshold=(self._rpn_min_size_threshold if is_training
+ else self._test_rpn_min_size_threshold),
+ decode_boxes=True,
+ clip_boxes=True,
+ use_batched_nms=self._use_batched_nms,
+ apply_sigmoid_to_score=True)
+ return proposed_rois, proposed_roi_scores
diff --git a/models/official/vision/detection/ops/spatial_transform_ops.py b/models/official/vision/detection/ops/spatial_transform_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae60d20f0e8c8454bd7972e851c33b6dca56ed90
--- /dev/null
+++ b/models/official/vision/detection/ops/spatial_transform_ops.py
@@ -0,0 +1,608 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions to performa spatial transformation for Tensor."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+_EPSILON = 1e-8
+
+
+def nearest_upsampling(data, scale):
+ """Nearest neighbor upsampling implementation.
+
+ Args:
+ data: A tensor with a shape of [batch, height_in, width_in, channels].
+ scale: An integer multiple to scale resolution of input data.
+ Returns:
+ data_up: A tensor with a shape of
+ [batch, height_in*scale, width_in*scale, channels]. Same dtype as input
+ data.
+ """
+ with tf.name_scope('nearest_upsampling'):
+ bs, _, _, c = data.get_shape().as_list()
+ shape = tf.shape(input=data)
+ h = shape[1]
+ w = shape[2]
+ bs = -1 if bs is None else bs
+ # Uses reshape to quickly upsample the input. The nearest pixel is selected
+ # implicitly via broadcasting.
+ data = tf.reshape(data, [bs, h, 1, w, 1, c]) * tf.ones(
+ [1, 1, scale, 1, scale, 1], dtype=data.dtype)
+ return tf.reshape(data, [bs, h * scale, w * scale, c])
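+
+
+# A small worked example (editorial sketch): the reshape-and-broadcast trick
+# above repeats each pixel `scale` times along both spatial axes, which is
+# exactly nearest-neighbor upsampling.
+#
+#   x = tf.reshape(tf.range(4.), [1, 2, 2, 1])   # pixel grid [[0, 1], [2, 3]]
+#   y = nearest_upsampling(x, scale=2)           # shape [1, 4, 4, 1]
+#   # y[0, :, :, 0] == [[0, 0, 1, 1],
+#   #                   [0, 0, 1, 1],
+#   #                   [2, 2, 3, 3],
+#   #                   [2, 2, 3, 3]]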
+
+
+def feature_bilinear_interpolation(features, kernel_y, kernel_x):
+ """Feature bilinear interpolation.
+
+ The RoIAlign feature f can be computed by bilinear interpolation
+ of four neighboring feature points f0, f1, f2, and f3.
+
+ f(y, x) = [hy, ly] * [[f00, f01], * [hx, lx]^T
+ [f10, f11]]
+ f(y, x) = (hy*hx)f00 + (hy*lx)f01 + (ly*hx)f10 + (lx*ly)f11
+ f(y, x) = w00*f00 + w01*f01 + w10*f10 + w11*f11
+ kernel_y = [hy, ly]
+ kernel_x = [hx, lx]
+
+ Args:
+    features: a 5-D tensor of shape [batch_size, num_boxes, output_size * 2,
+      output_size * 2, num_filters].
+ kernel_y: Tensor of size [batch_size, boxes, output_size, 2, 1].
+ kernel_x: Tensor of size [batch_size, boxes, output_size, 2, 1].
+
+ Returns:
+ A 5-D tensor representing feature crop of shape
+ [batch_size, num_boxes, output_size, output_size, num_filters].
+
+ """
+ (batch_size, num_boxes, output_size, _,
+ num_filters) = features.get_shape().as_list()
+ output_size = output_size // 2
+ kernel_y = tf.reshape(kernel_y, [batch_size, num_boxes, output_size * 2, 1])
+ kernel_x = tf.reshape(kernel_x, [batch_size, num_boxes, 1, output_size * 2])
+ # Use implicit broadcast to generate the interpolation kernel. The
+ # multiplier `4` is for avg pooling.
+ interpolation_kernel = kernel_y * kernel_x * 4
+
+ # Interpolate the gathered features with computed interpolation kernels.
+ features *= tf.cast(
+ tf.expand_dims(interpolation_kernel, axis=-1), dtype=features.dtype)
+ features = tf.reshape(
+ features,
+ [batch_size * num_boxes, output_size * 2, output_size * 2, num_filters])
+ features = tf.nn.avg_pool(features, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID')
+ features = tf.reshape(
+ features, [batch_size, num_boxes, output_size, output_size, num_filters])
+ return features
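+
+
+# A worked example of the kernel above (editorial sketch): for a sample point
+# that sits 0.25 below the top-left grid point and 0.75 to its right,
+#   kernel_y = [hy, ly] = [0.75, 0.25],  kernel_x = [hx, lx] = [0.25, 0.75]
+# so the four bilinear weights are
+#   w00 = hy*hx = 0.1875, w01 = hy*lx = 0.5625,
+#   w10 = ly*hx = 0.0625, w11 = ly*lx = 0.1875   (summing to 1.0)
+# and with corner features f00=1, f01=2, f10=3, f11=4 the interpolated value is
+# 0.1875*1 + 0.5625*2 + 0.0625*3 + 0.1875*4 = 2.25.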
+
+
+def compute_grid_positions(boxes, boundaries, output_size, sample_offset):
+ """Compute the grid position w.r.t.
+
+ the corresponding feature map.
+
+ Args:
+ boxes: a 3-D tensor of shape [batch_size, num_boxes, 4] encoding the
+ information of each box w.r.t. the corresponding feature map.
+ boxes[:, :, 0:2] are the grid position in (y, x) (float) of the top-left
+ corner of each box. boxes[:, :, 2:4] are the box sizes in (h, w) (float)
+ in terms of the number of pixels of the corresponding feature map size.
+ boundaries: a 3-D tensor of shape [batch_size, num_boxes, 2] representing
+ the boundary (in (y, x)) of the corresponding feature map for each box.
+      Any resampled grid points that go beyond the boundary will be clipped.
+ output_size: a scalar indicating the output crop size.
+    sample_offset: a float in [0, 1] indicating the subpixel sample offset
+      from the grid point.
+
+ Returns:
+ kernel_y: Tensor of size [batch_size, boxes, output_size, 2, 1].
+ kernel_x: Tensor of size [batch_size, boxes, output_size, 2, 1].
+ box_grid_y0y1: Tensor of size [batch_size, boxes, output_size, 2]
+ box_grid_x0x1: Tensor of size [batch_size, boxes, output_size, 2]
+ """
+ batch_size, num_boxes, _ = boxes.get_shape().as_list()
+ box_grid_x = []
+ box_grid_y = []
+ for i in range(output_size):
+ box_grid_x.append(boxes[:, :, 1] +
+ (i + sample_offset) * boxes[:, :, 3] / output_size)
+ box_grid_y.append(boxes[:, :, 0] +
+ (i + sample_offset) * boxes[:, :, 2] / output_size)
+ box_grid_x = tf.stack(box_grid_x, axis=2)
+ box_grid_y = tf.stack(box_grid_y, axis=2)
+
+ box_grid_y0 = tf.floor(box_grid_y)
+ box_grid_x0 = tf.floor(box_grid_x)
+ box_grid_x0 = tf.maximum(0., box_grid_x0)
+ box_grid_y0 = tf.maximum(0., box_grid_y0)
+
+ box_grid_x0 = tf.minimum(box_grid_x0, tf.expand_dims(boundaries[:, :, 1], -1))
+ box_grid_x1 = tf.minimum(box_grid_x0 + 1,
+ tf.expand_dims(boundaries[:, :, 1], -1))
+ box_grid_y0 = tf.minimum(box_grid_y0, tf.expand_dims(boundaries[:, :, 0], -1))
+ box_grid_y1 = tf.minimum(box_grid_y0 + 1,
+ tf.expand_dims(boundaries[:, :, 0], -1))
+
+ box_gridx0x1 = tf.stack([box_grid_x0, box_grid_x1], axis=-1)
+ box_gridy0y1 = tf.stack([box_grid_y0, box_grid_y1], axis=-1)
+
+ # The RoIAlign feature f can be computed by bilinear interpolation of four
+ # neighboring feature points f0, f1, f2, and f3.
+ # f(y, x) = [hy, ly] * [[f00, f01], * [hx, lx]^T
+ # [f10, f11]]
+ # f(y, x) = (hy*hx)f00 + (hy*lx)f01 + (ly*hx)f10 + (lx*ly)f11
+ # f(y, x) = w00*f00 + w01*f01 + w10*f10 + w11*f11
+ ly = box_grid_y - box_grid_y0
+ lx = box_grid_x - box_grid_x0
+ hy = 1.0 - ly
+ hx = 1.0 - lx
+ kernel_y = tf.reshape(
+ tf.stack([hy, ly], axis=3), [batch_size, num_boxes, output_size, 2, 1])
+ kernel_x = tf.reshape(
+ tf.stack([hx, lx], axis=3), [batch_size, num_boxes, output_size, 2, 1])
+ return kernel_y, kernel_x, box_gridy0y1, box_gridx0x1
+
+
+def get_grid_one_hot(box_gridy0y1, box_gridx0x1, feature_height, feature_width):
+ """Get grid_one_hot from indices and feature_size."""
+ (batch_size, num_boxes, output_size, _) = box_gridx0x1.get_shape().as_list()
+ y_indices = tf.cast(
+ tf.reshape(box_gridy0y1, [batch_size, num_boxes, output_size, 2]),
+ dtype=tf.int32)
+ x_indices = tf.cast(
+ tf.reshape(box_gridx0x1, [batch_size, num_boxes, output_size, 2]),
+ dtype=tf.int32)
+
+ # shape is [batch_size, num_boxes, output_size, 2, height]
+ grid_y_one_hot = tf.one_hot(tf.cast(y_indices, tf.int32), feature_height)
+ # shape is [batch_size, num_boxes, output_size, 2, width]
+ grid_x_one_hot = tf.one_hot(tf.cast(x_indices, tf.int32), feature_width)
+
+ return grid_y_one_hot, grid_x_one_hot
+
+
+def selective_crop_and_resize(features,
+ boxes,
+ box_levels,
+ boundaries,
+ output_size=7,
+ sample_offset=0.5,
+ use_einsum_gather=False):
+ """Crop and resize boxes on a set of feature maps.
+
+ Given multiple features maps indexed by different levels, and a set of boxes
+ where each box is mapped to a certain level, it selectively crops and resizes
+ boxes from the corresponding feature maps to generate the box features.
+
+ We follow the ROIAlign technique (see https://arxiv.org/pdf/1703.06870.pdf,
+ figure 3 for reference). Specifically, for each feature map, we select an
+ (output_size, output_size) set of pixels corresponding to the box location,
+ and then use bilinear interpolation to select the feature value for each
+ pixel.
+
+ For performance, we perform the gather and interpolation on all layers as a
+ single operation. In this op the multi-level features are first stacked and
+ gathered into [2*output_size, 2*output_size] feature points. Then bilinear
+ interpolation is performed on the gathered feature points to generate
+ [output_size, output_size] RoIAlign feature map.
+
+ Here is the step-by-step algorithm:
+ 1. The multi-level features are gathered into a
+ [batch_size, num_boxes, output_size*2, output_size*2, num_filters]
+ Tensor. The Tensor contains four neighboring feature points for each
+     vertex in the output grid.
+ 2. Compute the interpolation kernel of shape
+ [batch_size, num_boxes, output_size*2, output_size*2]. The last 2 axis
+ can be seen as stacking 2x2 interpolation kernels for all vertices in the
+ output grid.
+ 3. Element-wise multiply the gathered features and interpolation kernel.
+ Then apply 2x2 average pooling to reduce spatial dimension to
+ output_size.
+
+ Args:
+ features: a 5-D tensor of shape [batch_size, num_levels, max_height,
+ max_width, num_filters] where cropping and resizing are based.
+ boxes: a 3-D tensor of shape [batch_size, num_boxes, 4] encoding the
+ information of each box w.r.t. the corresponding feature map.
+ boxes[:, :, 0:2] are the grid position in (y, x) (float) of the top-left
+ corner of each box. boxes[:, :, 2:4] are the box sizes in (h, w) (float)
+ in terms of the number of pixels of the corresponding feature map size.
+ box_levels: a 3-D tensor of shape [batch_size, num_boxes, 1] representing
+ the 0-based corresponding feature level index of each box.
+ boundaries: a 3-D tensor of shape [batch_size, num_boxes, 2] representing
+ the boundary (in (y, x)) of the corresponding feature map for each box.
+      Any resampled grid points that go beyond the boundary will be clipped.
+ output_size: a scalar indicating the output crop size.
+    sample_offset: a float in [0, 1] indicating the subpixel sample offset
+      from the grid point.
+    use_einsum_gather: whether to use einsum instead of gather. Replacing gather
+      with einsum can improve performance when the feature size is not large,
+      and einsum is also friendly to model partitioning. Gather performs better
+      when the feature size is very large and there are multiple box levels.
+
+ Returns:
+ features_per_box: a 5-D tensor of shape
+ [batch_size, num_boxes, output_size, output_size, num_filters]
+ representing the cropped features.
+ """
+ (batch_size, num_levels, max_feature_height, max_feature_width,
+ num_filters) = features.get_shape().as_list()
+ _, num_boxes, _ = boxes.get_shape().as_list()
+
+ kernel_y, kernel_x, box_gridy0y1, box_gridx0x1 = compute_grid_positions(
+ boxes, boundaries, output_size, sample_offset)
+ x_indices = tf.cast(
+ tf.reshape(box_gridx0x1, [batch_size, num_boxes, output_size * 2]),
+ dtype=tf.int32)
+ y_indices = tf.cast(
+ tf.reshape(box_gridy0y1, [batch_size, num_boxes, output_size * 2]),
+ dtype=tf.int32)
+
+ if use_einsum_gather:
+    # Bilinear interpolation is done during the last two gathers:
+ # f(y, x) = [hy, ly] * [[f00, f01], * [hx, lx]^T
+ # [f10, f11]]
+ # [[f00, f01],
+ # [f10, f11]] = tf.einsum(tf.einsum(features, y_one_hot), x_one_hot)
+ # where [hy, ly] and [hx, lx] are the bilinear interpolation kernel.
+
+ # shape is [batch_size, boxes, output_size, 2, 1]
+ grid_y_one_hot, grid_x_one_hot = get_grid_one_hot(box_gridy0y1,
+ box_gridx0x1,
+ max_feature_height,
+ max_feature_width)
+
+ # shape is [batch_size, num_boxes, output_size, height]
+ grid_y_weight = tf.reduce_sum(
+ tf.multiply(grid_y_one_hot, kernel_y), axis=-2)
+ # shape is [batch_size, num_boxes, output_size, width]
+ grid_x_weight = tf.reduce_sum(
+ tf.multiply(grid_x_one_hot, kernel_x), axis=-2)
+
+ # Gather for y_axis.
+ # shape is [batch_size, num_boxes, output_size, width, features]
+ features_per_box = tf.einsum('bmhwf,bmoh->bmowf', features,
+ tf.cast(grid_y_weight, features.dtype))
+ # Gather for x_axis.
+ # shape is [batch_size, num_boxes, output_size, output_size, features]
+ features_per_box = tf.einsum('bmhwf,bmow->bmhof', features_per_box,
+ tf.cast(grid_x_weight, features.dtype))
+ else:
+ height_dim_offset = max_feature_width
+ level_dim_offset = max_feature_height * height_dim_offset
+ batch_dim_offset = num_levels * level_dim_offset
+
+ batch_size_offset = tf.tile(
+ tf.reshape(
+ tf.range(batch_size) * batch_dim_offset, [batch_size, 1, 1, 1]),
+ [1, num_boxes, output_size * 2, output_size * 2])
+ box_levels_offset = tf.tile(
+ tf.reshape(box_levels * level_dim_offset,
+ [batch_size, num_boxes, 1, 1]),
+ [1, 1, output_size * 2, output_size * 2])
+ y_indices_offset = tf.tile(
+ tf.reshape(y_indices * height_dim_offset,
+ [batch_size, num_boxes, output_size * 2, 1]),
+ [1, 1, 1, output_size * 2])
+ x_indices_offset = tf.tile(
+ tf.reshape(x_indices, [batch_size, num_boxes, 1, output_size * 2]),
+ [1, 1, output_size * 2, 1])
+
+ indices = tf.reshape(
+ batch_size_offset + box_levels_offset + y_indices_offset +
+ x_indices_offset, [-1])
+
+ features = tf.reshape(features, [-1, num_filters])
+ # TODO(wangtao): replace tf.gather with tf.gather_nd and try to get similar
+ # performance.
+ features_per_box = tf.reshape(
+ tf.gather(features, indices),
+ [batch_size, num_boxes, output_size * 2, output_size * 2, num_filters])
+ features_per_box = feature_bilinear_interpolation(features_per_box,
+ kernel_y, kernel_x)
+
+ return features_per_box
+
+
+def multilevel_crop_and_resize(features, boxes, output_size=7):
+ """Crop and resize on multilevel feature pyramid.
+
+ Generate the (output_size, output_size) set of pixels for each input box
+  by first locating the box in the correct feature level, and then cropping
+  and resizing it using the corresponding feature map of that level.
+
+ Args:
+ features: A dictionary with key as pyramid level and value as features. The
+ features are in shape of [batch_size, height_l, width_l, num_filters].
+ boxes: A 3-D Tensor of shape [batch_size, num_boxes, 4]. Each row represents
+ a box with [y1, x1, y2, x2] in un-normalized coordinates.
+ output_size: A scalar to indicate the output crop size.
+
+ Returns:
+ A 5-D tensor representing feature crop of shape
+ [batch_size, num_boxes, output_size, output_size, num_filters].
+ """
+
+ with tf.name_scope('multilevel_crop_and_resize'):
+ levels = list(features.keys())
+ min_level = min(levels)
+ max_level = max(levels)
+ batch_size, max_feature_height, max_feature_width, num_filters = (
+ features[min_level].get_shape().as_list())
+ _, num_boxes, _ = boxes.get_shape().as_list()
+
+ # Stack feature pyramid into a features_all of shape
+ # [batch_size, levels, height, width, num_filters].
+ features_all = []
+ feature_heights = []
+ feature_widths = []
+ for level in range(min_level, max_level + 1):
+ shape = features[level].get_shape().as_list()
+ feature_heights.append(shape[1])
+ feature_widths.append(shape[2])
+      # Concat tensors of [batch_size, height_l * width_l, num_filters] for
+      # each level.
+ features_all.append(
+ tf.reshape(features[level], [batch_size, -1, num_filters]))
+ features_r2 = tf.reshape(tf.concat(features_all, 1), [-1, num_filters])
+
+ # Calculate height_l * width_l for each level.
+ level_dim_sizes = [
+ feature_widths[i] * feature_heights[i]
+ for i in range(len(feature_widths))
+ ]
+    # level_dim_offsets is the cumulative sum of level_dim_sizes.
+ level_dim_offsets = [0]
+ for i in range(len(feature_widths) - 1):
+ level_dim_offsets.append(level_dim_offsets[i] + level_dim_sizes[i])
+ batch_dim_size = level_dim_offsets[-1] + level_dim_sizes[-1]
+ level_dim_offsets = tf.constant(level_dim_offsets, tf.int32)
+ height_dim_sizes = tf.constant(feature_widths, tf.int32)
+
+ # Assigns boxes to the right level.
+ box_width = boxes[:, :, 3] - boxes[:, :, 1]
+ box_height = boxes[:, :, 2] - boxes[:, :, 0]
+ areas_sqrt = tf.sqrt(box_height * box_width)
+ levels = tf.cast(
+ tf.math.floordiv(
+ tf.math.log(tf.divide(areas_sqrt, 224.0)), tf.math.log(2.0)) +
+ 4.0,
+ dtype=tf.int32)
+ # Maps levels between [min_level, max_level].
+ levels = tf.minimum(max_level, tf.maximum(levels, min_level))
+
+ # Projects box location and sizes to corresponding feature levels.
+ scale_to_level = tf.cast(
+ tf.pow(tf.constant(2.0), tf.cast(levels, tf.float32)),
+ dtype=boxes.dtype)
+ boxes /= tf.expand_dims(scale_to_level, axis=2)
+ box_width /= scale_to_level
+ box_height /= scale_to_level
+ boxes = tf.concat([boxes[:, :, 0:2],
+ tf.expand_dims(box_height, -1),
+ tf.expand_dims(box_width, -1)], axis=-1)
+
+ # Maps levels to [0, max_level-min_level].
+ levels -= min_level
+ level_strides = tf.pow([[2.0]], tf.cast(levels, tf.float32))
+ boundary = tf.cast(
+ tf.concat([
+ tf.expand_dims(
+ [[tf.cast(max_feature_height, tf.float32)]] / level_strides - 1,
+ axis=-1),
+ tf.expand_dims(
+ [[tf.cast(max_feature_width, tf.float32)]] / level_strides - 1,
+ axis=-1),
+ ],
+ axis=-1), boxes.dtype)
+
+ # Compute grid positions.
+ kernel_y, kernel_x, box_gridy0y1, box_gridx0x1 = compute_grid_positions(
+ boxes, boundary, output_size, sample_offset=0.5)
+
+ x_indices = tf.cast(
+ tf.reshape(box_gridx0x1, [batch_size, num_boxes, output_size * 2]),
+ dtype=tf.int32)
+ y_indices = tf.cast(
+ tf.reshape(box_gridy0y1, [batch_size, num_boxes, output_size * 2]),
+ dtype=tf.int32)
+
+ batch_size_offset = tf.tile(
+ tf.reshape(
+ tf.range(batch_size) * batch_dim_size, [batch_size, 1, 1, 1]),
+ [1, num_boxes, output_size * 2, output_size * 2])
+ # Get level offset for each box. Each box belongs to one level.
+ levels_offset = tf.tile(
+ tf.reshape(
+ tf.gather(level_dim_offsets, levels),
+ [batch_size, num_boxes, 1, 1]),
+ [1, 1, output_size * 2, output_size * 2])
+ y_indices_offset = tf.tile(
+ tf.reshape(
+ y_indices * tf.expand_dims(tf.gather(height_dim_sizes, levels), -1),
+ [batch_size, num_boxes, output_size * 2, 1]),
+ [1, 1, 1, output_size * 2])
+ x_indices_offset = tf.tile(
+ tf.reshape(x_indices, [batch_size, num_boxes, 1, output_size * 2]),
+ [1, 1, output_size * 2, 1])
+ indices = tf.reshape(
+ batch_size_offset + levels_offset + y_indices_offset + x_indices_offset,
+ [-1])
+
+ # TODO(wangtao): replace tf.gather with tf.gather_nd and try to get similar
+ # performance.
+ features_per_box = tf.reshape(
+ tf.gather(features_r2, indices),
+ [batch_size, num_boxes, output_size * 2, output_size * 2, num_filters])
+
+ # Bilinear interpolation.
+ features_per_box = feature_bilinear_interpolation(features_per_box,
+ kernel_y, kernel_x)
+ return features_per_box
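+
+
+# A worked example of the level-assignment heuristic above (editorial sketch):
+# each box is routed to FPN level floor(log2(sqrt(area) / 224)) + 4, clamped to
+# [min_level, max_level]. For instance:
+#   a 224x224 box -> log2(224/224) = 0  -> level 4
+#   a 112x112 box -> log2(112/224) = -1 -> level 3
+#   a 448x448 box -> log2(448/224) = 1  -> level 5
+# The box coordinates are then divided by 2**level so that they index into the
+# chosen level's feature map.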
+
+
+def single_level_feature_crop(features, level_boxes, detection_prior_levels,
+ min_mask_level, mask_crop_size):
+ """Crop the FPN features at the appropriate levels for each detection.
+
+
+ Args:
+ features: a float tensor of shape [batch_size, num_levels,
+ max_feature_size, max_feature_size, num_downsample_channels].
+ level_boxes: a float Tensor of the level boxes to crop from.
+ [batch_size, num_instances, 4].
+ detection_prior_levels: an int Tensor of instance assigned level of shape
+ [batch_size, num_instances].
+ min_mask_level: minimum FPN level to crop mask feature from.
+ mask_crop_size: an int of mask crop size.
+
+ Returns:
+ crop_features: a float Tensor of shape [batch_size * num_instances,
+ mask_crop_size, mask_crop_size, num_downsample_channels]. This is the
+ instance feature crop.
+ """
+ (batch_size, num_levels, max_feature_size,
+ _, num_downsample_channels) = features.get_shape().as_list()
+ _, num_of_instances, _ = level_boxes.get_shape().as_list()
+ level_boxes = tf.cast(level_boxes, tf.int32)
+ assert num_of_instances == detection_prior_levels.get_shape().as_list()[1]
+
+ x_start_indices = level_boxes[:, :, 1]
+ y_start_indices = level_boxes[:, :, 0]
+ # generate the full indices (not just the starting index)
+ x_idx_list = []
+ y_idx_list = []
+ for i in range(mask_crop_size):
+ x_idx_list.append(x_start_indices + i)
+ y_idx_list.append(y_start_indices + i)
+
+ x_indices = tf.stack(x_idx_list, axis=2)
+ y_indices = tf.stack(y_idx_list, axis=2)
+ levels = detection_prior_levels - min_mask_level
+ height_dim_size = max_feature_size
+ level_dim_size = max_feature_size * height_dim_size
+ batch_dim_size = num_levels * level_dim_size
+ # TODO(weicheng) change this to gather_nd for better readability.
+ indices = tf.reshape(
+ tf.tile(
+ tf.reshape(
+ tf.range(batch_size) * batch_dim_size,
+ [batch_size, 1, 1, 1]),
+ [1, num_of_instances,
+ mask_crop_size, mask_crop_size]) +
+ tf.tile(
+ tf.reshape(levels * level_dim_size,
+ [batch_size, num_of_instances, 1, 1]),
+ [1, 1, mask_crop_size, mask_crop_size]) +
+ tf.tile(
+ tf.reshape(y_indices * height_dim_size,
+ [batch_size, num_of_instances,
+ mask_crop_size, 1]),
+ [1, 1, 1, mask_crop_size]) +
+ tf.tile(
+ tf.reshape(x_indices,
+ [batch_size, num_of_instances,
+ 1, mask_crop_size]),
+ [1, 1, mask_crop_size, 1]), [-1])
+
+ features_r2 = tf.reshape(features,
+ [-1, num_downsample_channels])
+ crop_features = tf.reshape(
+ tf.gather(features_r2, indices),
+ [batch_size * num_of_instances,
+ mask_crop_size, mask_crop_size,
+ num_downsample_channels])
+
+ return crop_features
+
+
+def crop_mask_in_target_box(masks,
+ boxes,
+ target_boxes,
+ output_size,
+ sample_offset=0,
+ use_einsum=True):
+ """Crop masks in target boxes.
+
+ Args:
+ masks: A tensor with a shape of [batch_size, num_masks, height, width].
+    boxes: a float tensor representing box coordinates that tightly enclose
+ masks with a shape of [batch_size, num_masks, 4] in un-normalized
+ coordinates. A box is represented by [ymin, xmin, ymax, xmax].
+    target_boxes: a float tensor representing target box coordinates for
+ masks with a shape of [batch_size, num_masks, 4] in un-normalized
+ coordinates. A box is represented by [ymin, xmin, ymax, xmax].
+    output_size: A scalar to indicate the output crop size. It currently only
+      supports square-shaped outputs.
+    sample_offset: a float in [0, 1] indicating the subpixel sample offset
+      from the grid point.
+ use_einsum: Use einsum to replace gather in selective_crop_and_resize.
+
+ Returns:
+ A 4-D tensor representing feature crop of shape
+ [batch_size, num_boxes, output_size, output_size].
+ """
+ with tf.name_scope('crop_mask_in_target_box'):
+ batch_size, num_masks, height, width = masks.get_shape().as_list()
+ masks = tf.reshape(masks, [batch_size*num_masks, height, width, 1])
+ # Pad zeros on the boundary of masks.
+ masks = tf.image.pad_to_bounding_box(masks, 2, 2, height + 4, width + 4)
+ masks = tf.reshape(masks, [batch_size, num_masks, height+4, width+4, 1])
+
+ # Projects target box locations and sizes to corresponding cropped
+ # mask coordinates.
+ gt_y_min, gt_x_min, gt_y_max, gt_x_max = tf.split(
+ value=boxes, num_or_size_splits=4, axis=2)
+ bb_y_min, bb_x_min, bb_y_max, bb_x_max = tf.split(
+ value=target_boxes, num_or_size_splits=4, axis=2)
+ y_transform = (bb_y_min - gt_y_min) * height / (
+ gt_y_max - gt_y_min + _EPSILON) + 2
+    x_transform = (bb_x_min - gt_x_min) * width / (
+        gt_x_max - gt_x_min + _EPSILON) + 2
+    h_transform = (bb_y_max - bb_y_min) * height / (
+        gt_y_max - gt_y_min + _EPSILON)
+ w_transform = (bb_x_max - bb_x_min) * width / (
+ gt_x_max - gt_x_min + _EPSILON)
+
+ boundaries = tf.concat([
+ tf.cast(
+ tf.ones_like(y_transform) * ((height + 4) - 1), dtype=tf.float32),
+ tf.cast(
+ tf.ones_like(x_transform) * ((width + 4) - 1), dtype=tf.float32)
+ ],
+ axis=-1)
+
+ # Reshape tensors to have the right shape for selective_crop_and_resize.
+    transformed_boxes = tf.concat(
+ [y_transform, x_transform, h_transform, w_transform], -1)
+ levels = tf.tile(tf.reshape(tf.range(num_masks), [1, num_masks]),
+ [batch_size, 1])
+
+ cropped_masks = selective_crop_and_resize(
+ masks,
+        transformed_boxes,
+ levels,
+ boundaries,
+ output_size,
+ sample_offset=sample_offset,
+ use_einsum_gather=use_einsum)
+ cropped_masks = tf.squeeze(cropped_masks, axis=-1)
+
+ return cropped_masks
diff --git a/models/official/vision/detection/ops/target_ops.py b/models/official/vision/detection/ops/target_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a7d6856511f846365041527f2532c8f2b376244
--- /dev/null
+++ b/models/official/vision/detection/ops/target_ops.py
@@ -0,0 +1,399 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Target and sampling related ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from official.vision.detection.ops import spatial_transform_ops
+from official.vision.detection.utils import box_utils
+from official.vision.detection.utils.object_detection import balanced_positive_negative_sampler
+
+
+def box_matching(boxes, gt_boxes, gt_classes):
+ """Match boxes to groundtruth boxes.
+
+ Given the proposal boxes and the groundtruth boxes and classes, perform the
+ groundtruth matching by taking the argmax of the IoU between boxes and
+ groundtruth boxes.
+
+ Args:
+ boxes: a tensor of shape of [batch_size, N, 4] representing the box
+      coordinates to be matched to groundtruth boxes.
+ gt_boxes: a tensor of shape of [batch_size, MAX_INSTANCES, 4] representing
+ the groundtruth box coordinates. It is padded with -1s to indicate the
+ invalid boxes.
+ gt_classes: [batch_size, MAX_INSTANCES] representing the groundtruth box
+ classes. It is padded with -1s to indicate the invalid classes.
+
+ Returns:
+ matched_gt_boxes: a tensor of shape of [batch_size, N, 4], representing
+ the matched groundtruth box coordinates for each input box. If the box
+ does not overlap with any groundtruth boxes, the matched boxes of it
+ will be set to all 0s.
+ matched_gt_classes: a tensor of shape of [batch_size, N], representing
+ the matched groundtruth classes for each input box. If the box does not
+ overlap with any groundtruth boxes, the matched box classes of it will
+ be set to 0, which corresponds to the background class.
+ matched_gt_indices: a tensor of shape of [batch_size, N], representing
+ the indices of the matched groundtruth boxes in the original gt_boxes
+ tensor. If the box does not overlap with any groundtruth boxes, the
+ index of the matched groundtruth will be set to -1.
+ matched_iou: a tensor of shape of [batch_size, N], representing the IoU
+ between the box and its matched groundtruth box. The matched IoU is the
+ maximum IoU of the box and all the groundtruth boxes.
+ iou: a tensor of shape of [batch_size, N, K], representing the IoU matrix
+ between boxes and the groundtruth boxes. The IoU between a box and the
+ invalid groundtruth boxes whose coordinates are [-1, -1, -1, -1] is -1.
+ """
+ # Compute IoU between boxes and gt_boxes.
+ # iou <- [batch_size, N, K]
+ iou = box_utils.bbox_overlap(boxes, gt_boxes)
+
+ # max_iou <- [batch_size, N]
+ # 0.0 -> no match to gt, or -1.0 match to no gt
+ matched_iou = tf.reduce_max(iou, axis=-1)
+
+ # background_box_mask <- bool, [batch_size, N]
+ background_box_mask = tf.less_equal(matched_iou, 0.0)
+
+ argmax_iou_indices = tf.argmax(iou, axis=-1, output_type=tf.int32)
+
+ argmax_iou_indices_shape = tf.shape(argmax_iou_indices)
+ batch_indices = (
+ tf.expand_dims(tf.range(argmax_iou_indices_shape[0]), axis=-1) *
+ tf.ones([1, argmax_iou_indices_shape[-1]], dtype=tf.int32))
+ gather_nd_indices = tf.stack([batch_indices, argmax_iou_indices], axis=-1)
+
+ matched_gt_boxes = tf.gather_nd(gt_boxes, gather_nd_indices)
+ matched_gt_boxes = tf.where(
+ tf.tile(tf.expand_dims(background_box_mask, axis=-1), [1, 1, 4]),
+ tf.zeros_like(matched_gt_boxes, dtype=matched_gt_boxes.dtype),
+ matched_gt_boxes)
+
+ matched_gt_classes = tf.gather_nd(gt_classes, gather_nd_indices)
+ matched_gt_classes = tf.where(
+ background_box_mask,
+ tf.zeros_like(matched_gt_classes),
+ matched_gt_classes)
+
+ matched_gt_indices = tf.where(
+ background_box_mask,
+ -tf.ones_like(argmax_iou_indices),
+ argmax_iou_indices)
+
+ return (matched_gt_boxes, matched_gt_classes, matched_gt_indices,
+ matched_iou, iou)
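+
+
+# A minimal usage sketch (editorial, assuming eager TF 2.x): matching two
+# proposals against a single groundtruth box. Coordinates are [y1, x1, y2, x2]
+# in pixels; the second proposal has no overlap, so it falls back to the
+# background class 0 and a matched index of -1.
+#
+#   boxes = tf.constant([[[0., 0., 10., 10.], [20., 20., 30., 30.]]])
+#   gt_boxes = tf.constant([[[0., 0., 10., 10.]]])
+#   gt_classes = tf.constant([[3]])
+#   matched_boxes, matched_classes, matched_indices, matched_iou, iou = (
+#       box_matching(boxes, gt_boxes, gt_classes))
+#   # matched_classes == [[3, 0]], matched_indices == [[0, -1]]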
+
+
+def assign_and_sample_proposals(proposed_boxes,
+ gt_boxes,
+ gt_classes,
+ num_samples_per_image=512,
+ mix_gt_boxes=True,
+ fg_fraction=0.25,
+ fg_iou_thresh=0.5,
+ bg_iou_thresh_hi=0.5,
+ bg_iou_thresh_lo=0.0):
+ """Assigns the proposals with groundtruth classes and performs subsmpling.
+
+ Given `proposed_boxes`, `gt_boxes`, and `gt_classes`, the function uses the
+ following algorithm to generate the final `num_samples_per_image` RoIs.
+ 1. Calculates the IoU between each proposal box and each gt_boxes.
+ 2. Assigns each proposed box with a groundtruth class and box by choosing
+ the largest IoU overlap.
+ 3. Samples `num_samples_per_image` boxes from all proposed boxes, and
+ returns box_targets, class_targets, and RoIs.
+
+ Args:
+ proposed_boxes: a tensor of shape of [batch_size, N, 4]. N is the number
+ of proposals before groundtruth assignment. The last dimension is the
+ box coordinates w.r.t. the scaled images in [ymin, xmin, ymax, xmax]
+ format.
+ gt_boxes: a tensor of shape of [batch_size, MAX_NUM_INSTANCES, 4].
+ The coordinates of gt_boxes are in the pixel coordinates of the scaled
+ image. This tensor might have padding of values -1 indicating the invalid
+ box coordinates.
+ gt_classes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES]. This
+ tensor might have paddings with values of -1 indicating the invalid
+ classes.
+    num_samples_per_image: an integer representing the RoI minibatch size per
+      image.
+ mix_gt_boxes: a bool indicating whether to mix the groundtruth boxes before
+ sampling proposals.
+    fg_fraction: a float representing the target fraction of the RoI minibatch
+      that is labeled foreground (i.e., class > 0).
+    fg_iou_thresh: a float representing the IoU overlap threshold for an RoI to
+      be considered foreground (if >= fg_iou_thresh).
+    bg_iou_thresh_hi: a float representing the IoU overlap threshold for an RoI
+      to be considered background (class = 0 if overlap in [LO, HI)).
+    bg_iou_thresh_lo: a float representing the IoU overlap threshold for an RoI
+      to be considered background (class = 0 if overlap in [LO, HI)).
+
+ Returns:
+ sampled_rois: a tensor of shape of [batch_size, K, 4], representing the
+ coordinates of the sampled RoIs, where K is the number of the sampled
+ RoIs, i.e. K = num_samples_per_image.
+ sampled_gt_boxes: a tensor of shape of [batch_size, K, 4], storing the
+      box coordinates of the matched groundtruth boxes of the sampled RoIs.
+ sampled_gt_classes: a tensor of shape of [batch_size, K], storing the
+ classes of the matched groundtruth boxes of the sampled RoIs.
+ sampled_gt_indices: a tensor of shape of [batch_size, K], storing the
+      indices of the sampled groundtruth boxes in the original `gt_boxes`
+ tensor, i.e. gt_boxes[sampled_gt_indices[:, i]] = sampled_gt_boxes[:, i].
+ """
+
+ with tf.name_scope('sample_proposals'):
+ if mix_gt_boxes:
+ boxes = tf.concat([proposed_boxes, gt_boxes], axis=1)
+ else:
+ boxes = proposed_boxes
+
+ (matched_gt_boxes, matched_gt_classes, matched_gt_indices,
+ matched_iou, _) = box_matching(boxes, gt_boxes, gt_classes)
+
+ positive_match = tf.greater(matched_iou, fg_iou_thresh)
+ negative_match = tf.logical_and(
+ tf.greater_equal(matched_iou, bg_iou_thresh_lo),
+ tf.less(matched_iou, bg_iou_thresh_hi))
+ ignored_match = tf.less(matched_iou, 0.0)
+
+ # re-assign negatively matched boxes to the background class.
+ matched_gt_classes = tf.where(
+ negative_match, tf.zeros_like(matched_gt_classes), matched_gt_classes)
+ matched_gt_indices = tf.where(
+ negative_match, tf.zeros_like(matched_gt_indices), matched_gt_indices)
+
+ sample_candidates = tf.logical_and(
+ tf.logical_or(positive_match, negative_match),
+ tf.logical_not(ignored_match))
+
+ sampler = (
+ balanced_positive_negative_sampler.BalancedPositiveNegativeSampler(
+ positive_fraction=fg_fraction, is_static=True))
+
+ batch_size, _ = sample_candidates.get_shape().as_list()
+ sampled_indicators = []
+ for i in range(batch_size):
+ sampled_indicator = sampler.subsample(
+ sample_candidates[i], num_samples_per_image, positive_match[i])
+ sampled_indicators.append(sampled_indicator)
+ sampled_indicators = tf.stack(sampled_indicators)
+ _, sampled_indices = tf.nn.top_k(
+ tf.cast(sampled_indicators, dtype=tf.int32),
+ k=num_samples_per_image,
+ sorted=True)
+
+ sampled_indices_shape = tf.shape(sampled_indices)
+ batch_indices = (
+ tf.expand_dims(tf.range(sampled_indices_shape[0]), axis=-1) *
+ tf.ones([1, sampled_indices_shape[-1]], dtype=tf.int32))
+ gather_nd_indices = tf.stack([batch_indices, sampled_indices], axis=-1)
+
+ sampled_rois = tf.gather_nd(boxes, gather_nd_indices)
+ sampled_gt_boxes = tf.gather_nd(matched_gt_boxes, gather_nd_indices)
+ sampled_gt_classes = tf.gather_nd(
+ matched_gt_classes, gather_nd_indices)
+ sampled_gt_indices = tf.gather_nd(
+ matched_gt_indices, gather_nd_indices)
+
+ return (sampled_rois, sampled_gt_boxes, sampled_gt_classes,
+ sampled_gt_indices)
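+
+
+# A worked example of the sampling budget above (editorial note): with the
+# defaults num_samples_per_image=512 and fg_fraction=0.25, at most
+# 512 * 0.25 = 128 RoIs whose best IoU is >= fg_iou_thresh (0.5) are kept as
+# foreground, and the remaining slots are filled with background RoIs whose
+# best IoU falls in [bg_iou_thresh_lo, bg_iou_thresh_hi) = [0.0, 0.5).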
+
+
+def sample_and_crop_foreground_masks(candidate_rois,
+ candidate_gt_boxes,
+ candidate_gt_classes,
+ candidate_gt_indices,
+ gt_masks,
+ num_mask_samples_per_image=128,
+ mask_target_size=28):
+ """Samples and creates cropped foreground masks for training.
+
+ Args:
+ candidate_rois: a tensor of shape of [batch_size, N, 4], where N is the
+ number of candidate RoIs to be considered for mask sampling. It includes
+ both positive and negative RoIs. The `num_mask_samples_per_image` positive
+ RoIs will be sampled to create mask training targets.
+ candidate_gt_boxes: a tensor of shape of [batch_size, N, 4], storing the
+ corresponding groundtruth boxes to the `candidate_rois`.
+ candidate_gt_classes: a tensor of shape of [batch_size, N], storing the
+ corresponding groundtruth classes to the `candidate_rois`. 0 in the tensor
+ corresponds to the background class, i.e. negative RoIs.
+ candidate_gt_indices: a tensor of shape [batch_size, N], storing the
+ corresponding groundtruth instance indices to the `candidate_gt_boxes`,
+ i.e. gt_boxes[candidate_gt_indices[:, i]] = candidate_gt_boxes[:, i] and
+      gt_boxes, which is of shape [batch_size, MAX_INSTANCES, 4] with
+      MAX_INSTANCES >= N, is the superset of candidate_gt_boxes.
+ gt_masks: a tensor of [batch_size, MAX_INSTANCES, mask_height, mask_width]
+      containing all the groundtruth masks from which the sampled masks are
+      drawn.
+ num_mask_samples_per_image: an integer which specifies the number of masks
+ to sample.
+ mask_target_size: an integer which specifies the final cropped mask size
+ after sampling. The output masks are resized w.r.t the sampled RoIs.
+
+ Returns:
+ foreground_rois: a tensor of shape of [batch_size, K, 4] storing the RoI
+ that corresponds to the sampled foreground masks, where
+ K = num_mask_samples_per_image.
+ foreground_classes: a tensor of shape of [batch_size, K] storing the classes
+ corresponding to the sampled foreground masks.
+    cropped_foreground_masks: a tensor of shape of
+ [batch_size, K, mask_target_size, mask_target_size] storing the cropped
+ foreground masks used for training.
+ """
+ with tf.name_scope('sample_and_crop_foreground_masks'):
+ _, fg_instance_indices = tf.nn.top_k(
+ tf.cast(tf.greater(candidate_gt_classes, 0), dtype=tf.int32),
+ k=num_mask_samples_per_image)
+
+ fg_instance_indices_shape = tf.shape(fg_instance_indices)
+ batch_indices = (
+ tf.expand_dims(tf.range(fg_instance_indices_shape[0]), axis=-1) *
+ tf.ones([1, fg_instance_indices_shape[-1]], dtype=tf.int32))
+
+ gather_nd_instance_indices = tf.stack(
+ [batch_indices, fg_instance_indices], axis=-1)
+ foreground_rois = tf.gather_nd(
+ candidate_rois, gather_nd_instance_indices)
+ foreground_boxes = tf.gather_nd(
+ candidate_gt_boxes, gather_nd_instance_indices)
+ foreground_classes = tf.gather_nd(
+ candidate_gt_classes, gather_nd_instance_indices)
+ foreground_gt_indices = tf.gather_nd(
+ candidate_gt_indices, gather_nd_instance_indices)
+
+ foreground_gt_indices_shape = tf.shape(foreground_gt_indices)
+ batch_indices = (
+ tf.expand_dims(tf.range(foreground_gt_indices_shape[0]), axis=-1) *
+ tf.ones([1, foreground_gt_indices_shape[-1]], dtype=tf.int32))
+ gather_nd_gt_indices = tf.stack(
+ [batch_indices, foreground_gt_indices], axis=-1)
+ foreground_masks = tf.gather_nd(gt_masks, gather_nd_gt_indices)
+
+ cropped_foreground_masks = spatial_transform_ops.crop_mask_in_target_box(
+ foreground_masks, foreground_boxes, foreground_rois, mask_target_size,
+ sample_offset=0.5)
+
+ return foreground_rois, foreground_classes, cropped_foreground_masks
+
+
+class ROISampler(object):
+ """Samples RoIs and creates training targets."""
+
+ def __init__(self, params):
+ self._num_samples_per_image = params.num_samples_per_image
+ self._fg_fraction = params.fg_fraction
+ self._fg_iou_thresh = params.fg_iou_thresh
+ self._bg_iou_thresh_hi = params.bg_iou_thresh_hi
+ self._bg_iou_thresh_lo = params.bg_iou_thresh_lo
+ self._mix_gt_boxes = params.mix_gt_boxes
+
+ def __call__(self, rois, gt_boxes, gt_classes):
+ """Sample and assign RoIs for training.
+
+ Args:
+ rois: a tensor of shape of [batch_size, N, 4]. N is the number
+ of proposals before groundtruth assignment. The last dimension is the
+ box coordinates w.r.t. the scaled images in [ymin, xmin, ymax, xmax]
+ format.
+ gt_boxes: a tensor of shape of [batch_size, MAX_NUM_INSTANCES, 4].
+ The coordinates of gt_boxes are in the pixel coordinates of the scaled
+ image. This tensor might have padding of values -1 indicating the
+ invalid box coordinates.
+ gt_classes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES]. This
+ tensor might have paddings with values of -1 indicating the invalid
+ classes.
+
+ Returns:
+ sampled_rois: a tensor of shape of [batch_size, K, 4], representing the
+ coordinates of the sampled RoIs, where K is the number of the sampled
+ RoIs, i.e. K = num_samples_per_image.
+ sampled_gt_boxes: a tensor of shape of [batch_size, K, 4], storing the
+        box coordinates of the matched groundtruth boxes of the sampled RoIs.
+ sampled_gt_classes: a tensor of shape of [batch_size, K], storing the
+ classes of the matched groundtruth boxes of the sampled RoIs.
+ """
+ sampled_rois, sampled_gt_boxes, sampled_gt_classes, sampled_gt_indices = (
+ assign_and_sample_proposals(
+ rois,
+ gt_boxes,
+ gt_classes,
+ num_samples_per_image=self._num_samples_per_image,
+ mix_gt_boxes=self._mix_gt_boxes,
+ fg_fraction=self._fg_fraction,
+ fg_iou_thresh=self._fg_iou_thresh,
+ bg_iou_thresh_hi=self._bg_iou_thresh_hi,
+ bg_iou_thresh_lo=self._bg_iou_thresh_lo))
+ return (sampled_rois, sampled_gt_boxes, sampled_gt_classes,
+ sampled_gt_indices)
+
+
+class MaskSampler(object):
+ """Samples and creates mask training targets."""
+
+ def __init__(self, mask_target_size, num_mask_samples_per_image):
+ self._mask_target_size = mask_target_size
+ self._num_mask_samples_per_image = num_mask_samples_per_image
+
+ def __call__(self,
+ candidate_rois,
+ candidate_gt_boxes,
+ candidate_gt_classes,
+ candidate_gt_indices,
+ gt_masks):
+ """Sample and create mask targets for training.
+
+ Args:
+ candidate_rois: a tensor of shape of [batch_size, N, 4], where N is the
+ number of candidate RoIs to be considered for mask sampling. It includes
+ both positive and negative RoIs. The `num_mask_samples_per_image`
+ positive RoIs will be sampled to create mask training targets.
+ candidate_gt_boxes: a tensor of shape of [batch_size, N, 4], storing the
+ corresponding groundtruth boxes to the `candidate_rois`.
+ candidate_gt_classes: a tensor of shape of [batch_size, N], storing the
+ corresponding groundtruth classes to the `candidate_rois`. 0 in the
+ tensor corresponds to the background class, i.e. negative RoIs.
+ candidate_gt_indices: a tensor of shape [batch_size, N], storing the
+ corresponding groundtruth instance indices to the `candidate_gt_boxes`,
+ i.e. gt_boxes[candidate_gt_indices[:, i]] = candidate_gt_boxes[:, i],
+        where gt_boxes, which is of shape [batch_size, MAX_INSTANCES, 4] with
+        MAX_INSTANCES >= N, is the superset of candidate_gt_boxes.
+ gt_masks: a tensor of [batch_size, MAX_INSTANCES, mask_height, mask_width]
+        containing all the groundtruth masks from which the sampled masks are
+        drawn. The output masks are resized w.r.t. the sampled RoIs.
+
+ Returns:
+ foreground_rois: a tensor of shape of [batch_size, K, 4] storing the RoI
+ that corresponds to the sampled foreground masks, where
+ K = num_mask_samples_per_image.
+ foreground_classes: a tensor of shape of [batch_size, K] storing the
+ classes corresponding to the sampled foreground masks.
+      cropped_foreground_masks: a tensor of shape of
+ [batch_size, K, mask_target_size, mask_target_size] storing the
+ cropped foreground masks used for training.
+ """
+ foreground_rois, foreground_classes, cropped_foreground_masks = (
+ sample_and_crop_foreground_masks(
+ candidate_rois,
+ candidate_gt_boxes,
+ candidate_gt_classes,
+ candidate_gt_indices,
+ gt_masks,
+ self._num_mask_samples_per_image,
+ self._mask_target_size))
+ return foreground_rois, foreground_classes, cropped_foreground_masks
diff --git a/models/official/vision/detection/utils/__init__.py b/models/official/vision/detection/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..931c2ef11db4a949e6c2e95bca44e36bac1241e9
--- /dev/null
+++ b/models/official/vision/detection/utils/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/models/official/vision/detection/utils/box_utils.py b/models/official/vision/detection/utils/box_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c2ebf5781f44363b090f3e272101d6014f2edd0
--- /dev/null
+++ b/models/official/vision/detection/utils/box_utils.py
@@ -0,0 +1,551 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for bounding box processing."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+EPSILON = 1e-8
+BBOX_XFORM_CLIP = np.log(1000. / 16.)
+
+
+def visualize_images_with_bounding_boxes(images, box_outputs, step,
+ summary_writer):
+ """Records subset of evaluation images with bounding boxes."""
+ image_shape = tf.shape(images[0])
+ image_height = tf.cast(image_shape[0], tf.float32)
+ image_width = tf.cast(image_shape[1], tf.float32)
+ normalized_boxes = normalize_boxes(box_outputs, [image_height, image_width])
+
+ bounding_box_color = tf.constant([[1.0, 1.0, 0.0, 1.0]])
+ image_summary = tf.image.draw_bounding_boxes(images, normalized_boxes,
+ bounding_box_color)
+ with summary_writer.as_default():
+ tf.summary.image('bounding_box_summary', image_summary, step=step)
+ summary_writer.flush()
+
+
+def yxyx_to_xywh(boxes):
+ """Converts boxes from ymin, xmin, ymax, xmax to xmin, ymin, width, height.
+
+ Args:
+ boxes: a numpy array whose last dimension is 4 representing the coordinates
+ of boxes in ymin, xmin, ymax, xmax order.
+
+ Returns:
+ boxes: a numpy array whose shape is the same as `boxes` in new format.
+
+ Raises:
+ ValueError: If the last dimension of boxes is not 4.
+ """
+ if boxes.shape[-1] != 4:
+ raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(
+ boxes.shape[-1]))
+
+ boxes_ymin = boxes[..., 0]
+ boxes_xmin = boxes[..., 1]
+ boxes_width = boxes[..., 3] - boxes[..., 1]
+ boxes_height = boxes[..., 2] - boxes[..., 0]
+ new_boxes = np.stack([boxes_xmin, boxes_ymin, boxes_width, boxes_height],
+ axis=-1)
+
+ return new_boxes
+
+
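+# Editor's note: a minimal usage sketch (illustrative values only); it relies on
+# the numpy import at the top of this module and can be removed freely.
+def _example_yxyx_to_xywh():
+  """Converts one hand-made box from [ymin, xmin, ymax, xmax] to [x, y, w, h]."""
+  boxes = np.array([[10., 20., 50., 80.]])
+  converted = yxyx_to_xywh(boxes)  # -> array([[20., 10., 60., 40.]])
+  return converted
+
+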
+def jitter_boxes(boxes, noise_scale=0.025):
+ """Jitter the box coordinates by some noise distribution.
+
+ Args:
+ boxes: a tensor whose last dimension is 4 representing the coordinates of
+ boxes in ymin, xmin, ymax, xmax order.
+ noise_scale: a python float which specifies the magnitude of noise. The rule
+ of thumb is to set this between (0, 0.1]. The default value was found
+ empirically to best mimic noisy detections.
+
+ Returns:
+ jittered_boxes: a tensor whose shape is the same as `boxes` representing
+ the jittered boxes.
+
+ Raises:
+ ValueError: If the last dimension of boxes is not 4.
+ """
+ if boxes.shape[-1] != 4:
+ raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(
+ boxes.shape[-1]))
+
+ with tf.name_scope('jitter_boxes'):
+ bbox_jitters = tf.random.normal(boxes.get_shape(), stddev=noise_scale)
+ ymin = boxes[..., 0:1]
+ xmin = boxes[..., 1:2]
+ ymax = boxes[..., 2:3]
+ xmax = boxes[..., 3:4]
+ width = xmax - xmin
+ height = ymax - ymin
+ new_center_x = (xmin + xmax) / 2.0 + bbox_jitters[..., 0:1] * width
+ new_center_y = (ymin + ymax) / 2.0 + bbox_jitters[..., 1:2] * height
+ new_width = width * tf.math.exp(bbox_jitters[..., 2:3])
+ new_height = height * tf.math.exp(bbox_jitters[..., 3:4])
+ jittered_boxes = tf.concat([
+ new_center_y - new_height * 0.5, new_center_x - new_width * 0.5,
+ new_center_y + new_height * 0.5, new_center_x + new_width * 0.5
+ ],
+ axis=-1)
+
+ return jittered_boxes
+
+
+def normalize_boxes(boxes, image_shape):
+ """Converts boxes to the normalized coordinates.
+
+ Args:
+ boxes: a tensor whose last dimension is 4 representing the coordinates
+ of boxes in ymin, xmin, ymax, xmax order.
+ image_shape: a list of two integers, a two-element vector or a tensor such
+ that all but the last dimensions are `broadcastable` to `boxes`. The last
+ dimension is 2, which represents [height, width].
+
+ Returns:
+ normalized_boxes: a tensor whose shape is the same as `boxes` representing
+ the normalized boxes.
+
+ Raises:
+ ValueError: If the last dimension of boxes is not 4.
+ """
+ if boxes.shape[-1] != 4:
+ raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(
+ boxes.shape[-1]))
+
+ with tf.name_scope('normalize_boxes'):
+ if isinstance(image_shape, list) or isinstance(image_shape, tuple):
+ height, width = image_shape
+ else:
+ image_shape = tf.cast(image_shape, dtype=boxes.dtype)
+ height = image_shape[..., 0:1]
+ width = image_shape[..., 1:2]
+
+ ymin = boxes[..., 0:1] / height
+ xmin = boxes[..., 1:2] / width
+ ymax = boxes[..., 2:3] / height
+ xmax = boxes[..., 3:4] / width
+
+ normalized_boxes = tf.concat([ymin, xmin, ymax, xmax], axis=-1)
+ return normalized_boxes
+
+
+def denormalize_boxes(boxes, image_shape):
+ """Converts boxes normalized by [height, width] to pixel coordinates.
+
+ Args:
+ boxes: a tensor whose last dimension is 4 representing the coordinates
+ of boxes in ymin, xmin, ymax, xmax order.
+ image_shape: a list of two integers, a two-element vector or a tensor such
+ that all but the last dimensions are `broadcastable` to `boxes`. The last
+ dimension is 2, which represents [height, width].
+
+ Returns:
+ denormalized_boxes: a tensor whose shape is the same as `boxes` representing
+ the denormalized boxes.
+
+ Raises:
+ ValueError: If the last dimension of boxes is not 4.
+ """
+ with tf.name_scope('denormalize_boxes'):
+ if isinstance(image_shape, list) or isinstance(image_shape, tuple):
+ height, width = image_shape
+ else:
+ image_shape = tf.cast(image_shape, dtype=boxes.dtype)
+ height, width = tf.split(image_shape, 2, axis=-1)
+
+ ymin, xmin, ymax, xmax = tf.split(boxes, 4, axis=-1)
+ ymin = ymin * height
+ xmin = xmin * width
+ ymax = ymax * height
+ xmax = xmax * width
+
+ denormalized_boxes = tf.concat([ymin, xmin, ymax, xmax], axis=-1)
+ return denormalized_boxes
+
+
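+# Editor's note: an illustrative round-trip sketch for normalize_boxes and
+# denormalize_boxes with made-up pixel coordinates; not part of the original change.
+def _example_normalize_denormalize_roundtrip():
+  boxes = tf.constant([[10., 20., 50., 80.]])       # ymin, xmin, ymax, xmax in pixels
+  norm = normalize_boxes(boxes, [100.0, 200.0])     # -> [[0.1, 0.1, 0.5, 0.4]]
+  denorm = denormalize_boxes(norm, [100.0, 200.0])  # recovers the pixel coordinates
+  return norm, denorm
+
+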
+def clip_boxes(boxes, image_shape):
+ """Clips boxes to image boundaries.
+
+ Args:
+ boxes: a tensor whose last dimension is 4 representing the coordinates
+ of boxes in ymin, xmin, ymax, xmax order.
+ image_shape: a list of two integers, a two-element vector or a tensor such
+ that all but the last dimensions are `broadcastable` to `boxes`. The last
+ dimension is 2, which represents [height, width].
+
+ Returns:
+ clipped_boxes: a tensor whose shape is the same as `boxes` representing the
+ clipped boxes.
+
+ Raises:
+ ValueError: If the last dimension of boxes is not 4.
+ """
+ if boxes.shape[-1] != 4:
+ raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(
+ boxes.shape[-1]))
+
+ with tf.name_scope('clip_boxes'):
+ if isinstance(image_shape, list) or isinstance(image_shape, tuple):
+ height, width = image_shape
+ max_length = [height - 1.0, width - 1.0, height - 1.0, width - 1.0]
+ else:
+ image_shape = tf.cast(image_shape, dtype=boxes.dtype)
+ height, width = tf.unstack(image_shape, axis=-1)
+ max_length = tf.stack(
+ [height - 1.0, width - 1.0, height - 1.0, width - 1.0], axis=-1)
+
+ clipped_boxes = tf.math.maximum(tf.math.minimum(boxes, max_length), 0.0)
+ return clipped_boxes
+
+
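+# Editor's note: a small illustrative sketch of clip_boxes; the coordinates are
+# made up and the image size is [height, width] = [100, 200].
+def _example_clip_boxes():
+  boxes = tf.constant([[-5., -10., 120., 250.]])  # partially outside the image
+  return clip_boxes(boxes, [100.0, 200.0])        # -> [[0., 0., 99., 199.]]
+
+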
+def compute_outer_boxes(boxes, image_shape, scale=1.0):
+ """Computes outer boxes that enclose the objects with a margin.
+
+ Args:
+ boxes: a tensor whose last dimension is 4 representing the coordinates of
+ boxes in ymin, xmin, ymax, xmax order.
+ image_shape: a list of two integers, a two-element vector or a tensor such
+ that all but the last dimensions are `broadcastable` to `boxes`. The last
+ dimension is 2, which represents [height, width].
+ scale: a float number specifying the scale of output outer boxes to input
+ `boxes`.
+
+ Returns:
+ outer_boxes: a tensor whose shape is the same as `boxes` representing the
+ outer boxes.
+ """
+ if scale < 1.0:
+ raise ValueError(
+ 'scale is {}, but outer box scale must be greater than 1.0.'.format(
+ scale))
+ centers_y = (boxes[..., 0] + boxes[..., 2]) / 2.0
+ centers_x = (boxes[..., 1] + boxes[..., 3]) / 2.0
+ box_height = (boxes[..., 2] - boxes[..., 0]) * scale
+ box_width = (boxes[..., 3] - boxes[..., 1]) * scale
+ outer_boxes = tf.stack([
+ centers_y - box_height / 2.0, centers_x - box_width / 2.0,
+ centers_y + box_height / 2.0, centers_x + box_width / 2.0
+ ],
+ axis=1)
+ outer_boxes = clip_boxes(outer_boxes, image_shape)
+ return outer_boxes
+
+
+def encode_boxes(boxes, anchors, weights=None):
+ """Encode boxes to targets.
+
+ Args:
+ boxes: a tensor whose last dimension is 4 representing the coordinates
+ of boxes in ymin, xmin, ymax, xmax order.
+ anchors: a tensor whose shape is the same as, or `broadcastable` to `boxes`,
+ representing the coordinates of anchors in ymin, xmin, ymax, xmax order.
+ weights: None or a list of four float numbers used to scale coordinates.
+
+ Returns:
+ encoded_boxes: a tensor whose shape is the same as `boxes` representing the
+ encoded box targets.
+
+ Raises:
+ ValueError: If the last dimension of boxes is not 4.
+ """
+ if boxes.shape[-1] != 4:
+ raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(
+ boxes.shape[-1]))
+
+ with tf.name_scope('encode_boxes'):
+ boxes = tf.cast(boxes, dtype=anchors.dtype)
+ ymin = boxes[..., 0:1]
+ xmin = boxes[..., 1:2]
+ ymax = boxes[..., 2:3]
+ xmax = boxes[..., 3:4]
+ box_h = ymax - ymin + 1.0
+ box_w = xmax - xmin + 1.0
+ box_yc = ymin + 0.5 * box_h
+ box_xc = xmin + 0.5 * box_w
+
+ anchor_ymin = anchors[..., 0:1]
+ anchor_xmin = anchors[..., 1:2]
+ anchor_ymax = anchors[..., 2:3]
+ anchor_xmax = anchors[..., 3:4]
+ anchor_h = anchor_ymax - anchor_ymin + 1.0
+ anchor_w = anchor_xmax - anchor_xmin + 1.0
+ anchor_yc = anchor_ymin + 0.5 * anchor_h
+ anchor_xc = anchor_xmin + 0.5 * anchor_w
+
+ encoded_dy = (box_yc - anchor_yc) / anchor_h
+ encoded_dx = (box_xc - anchor_xc) / anchor_w
+ encoded_dh = tf.math.log(box_h / anchor_h)
+ encoded_dw = tf.math.log(box_w / anchor_w)
+ if weights:
+ encoded_dy *= weights[0]
+ encoded_dx *= weights[1]
+ encoded_dh *= weights[2]
+ encoded_dw *= weights[3]
+
+ encoded_boxes = tf.concat(
+ [encoded_dy, encoded_dx, encoded_dh, encoded_dw],
+ axis=-1)
+ return encoded_boxes
+
+
+def decode_boxes(encoded_boxes, anchors, weights=None):
+ """Decode boxes.
+
+ Args:
+ encoded_boxes: a tensor whose last dimension is 4 representing the
+ coordinates of encoded boxes in ymin, xmin, ymax, xmax order.
+ anchors: a tensor whose shape is the same as, or `broadcastable` to `boxes`,
+ representing the coordinates of anchors in ymin, xmin, ymax, xmax order.
+ weights: None or a list of four float numbers used to scale coordinates.
+
+ Returns:
+ encoded_boxes: a tensor whose shape is the same as `boxes` representing the
+ decoded box targets.
+ """
+ if encoded_boxes.shape[-1] != 4:
+ raise ValueError('encoded_boxes.shape[-1] is {:d}, but must be 4.'.format(
+ encoded_boxes.shape[-1]))
+
+ with tf.name_scope('decode_boxes'):
+ encoded_boxes = tf.cast(encoded_boxes, dtype=anchors.dtype)
+ dy = encoded_boxes[..., 0:1]
+ dx = encoded_boxes[..., 1:2]
+ dh = encoded_boxes[..., 2:3]
+ dw = encoded_boxes[..., 3:4]
+ if weights:
+ dy /= weights[0]
+ dx /= weights[1]
+ dh /= weights[2]
+ dw /= weights[3]
+ dh = tf.math.minimum(dh, BBOX_XFORM_CLIP)
+ dw = tf.math.minimum(dw, BBOX_XFORM_CLIP)
+
+ anchor_ymin = anchors[..., 0:1]
+ anchor_xmin = anchors[..., 1:2]
+ anchor_ymax = anchors[..., 2:3]
+ anchor_xmax = anchors[..., 3:4]
+ anchor_h = anchor_ymax - anchor_ymin + 1.0
+ anchor_w = anchor_xmax - anchor_xmin + 1.0
+ anchor_yc = anchor_ymin + 0.5 * anchor_h
+ anchor_xc = anchor_xmin + 0.5 * anchor_w
+
+ decoded_boxes_yc = dy * anchor_h + anchor_yc
+ decoded_boxes_xc = dx * anchor_w + anchor_xc
+ decoded_boxes_h = tf.math.exp(dh) * anchor_h
+ decoded_boxes_w = tf.math.exp(dw) * anchor_w
+
+ decoded_boxes_ymin = decoded_boxes_yc - 0.5 * decoded_boxes_h
+ decoded_boxes_xmin = decoded_boxes_xc - 0.5 * decoded_boxes_w
+ decoded_boxes_ymax = decoded_boxes_ymin + decoded_boxes_h - 1.0
+ decoded_boxes_xmax = decoded_boxes_xmin + decoded_boxes_w - 1.0
+
+ decoded_boxes = tf.concat(
+ [decoded_boxes_ymin, decoded_boxes_xmin,
+ decoded_boxes_ymax, decoded_boxes_xmax],
+ axis=-1)
+ return decoded_boxes
+
+
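+# Editor's note: an illustrative encode/decode round trip with a single hand-made
+# anchor and box, showing that decode_boxes(encode_boxes(b, a), a) recovers b.
+def _example_encode_decode_roundtrip():
+  anchors = tf.constant([[10., 10., 50., 50.]])
+  boxes = tf.constant([[12., 8., 48., 56.]])
+  encoded = encode_boxes(boxes, anchors)    # [dy, dx, dh, dw] relative to the anchor
+  decoded = decode_boxes(encoded, anchors)  # numerically recovers `boxes`
+  return encoded, decoded
+
+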
+def filter_boxes(boxes, scores, image_shape, min_size_threshold):
+ """Filter and remove boxes that are too small or fall outside the image.
+
+ Args:
+ boxes: a tensor whose last dimension is 4 representing the coordinates of
+ boxes in ymin, xmin, ymax, xmax order.
+ scores: a tensor whose shape is the same as tf.shape(boxes)[:-1]
+ representing the original scores of the boxes.
+ image_shape: a tensor whose shape is the same as, or `broadcastable` to
+ `boxes` except the last dimension, which is 2, representing [height,
+ width] of the scaled image.
+ min_size_threshold: a float representing the minimal box size in each side
+ (w.r.t. the scaled image). Boxes whose sides are smaller than it will be
+ filtered out.
+
+ Returns:
+ filtered_boxes: a tensor whose shape is the same as `boxes` but with
+ the positions of the filtered boxes filled with 0.
+ filtered_scores: a tensor whose shape is the same as `scores` but with
+ the positions of the filtered boxes filled with 0.
+ """
+ if boxes.shape[-1] != 4:
+ raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(
+ boxes.shape[-1]))
+
+ with tf.name_scope('filter_boxes'):
+ if isinstance(image_shape, list) or isinstance(image_shape, tuple):
+ height, width = image_shape
+ else:
+ image_shape = tf.cast(image_shape, dtype=boxes.dtype)
+ height = image_shape[..., 0]
+ width = image_shape[..., 1]
+
+ ymin = boxes[..., 0]
+ xmin = boxes[..., 1]
+ ymax = boxes[..., 2]
+ xmax = boxes[..., 3]
+
+ h = ymax - ymin + 1.0
+ w = xmax - xmin + 1.0
+ yc = ymin + 0.5 * h
+ xc = xmin + 0.5 * w
+
+ min_size = tf.cast(
+ tf.math.maximum(min_size_threshold, 1.0), dtype=boxes.dtype)
+
+ filtered_size_mask = tf.math.logical_and(
+ tf.math.greater(h, min_size), tf.math.greater(w, min_size))
+ filtered_center_mask = tf.logical_and(
+ tf.math.logical_and(tf.math.greater(yc, 0.0), tf.math.less(yc, height)),
+ tf.math.logical_and(tf.math.greater(xc, 0.0), tf.math.less(xc, width)))
+ filtered_mask = tf.math.logical_and(filtered_size_mask,
+ filtered_center_mask)
+
+ filtered_scores = tf.where(filtered_mask, scores, tf.zeros_like(scores))
+ filtered_boxes = tf.cast(
+ tf.expand_dims(filtered_mask, axis=-1), dtype=boxes.dtype) * boxes
+
+ return filtered_boxes, filtered_scores
+
+
+def filter_boxes_by_scores(boxes, scores, min_score_threshold):
+ """Filter and remove boxes whose scores are smaller than the threshold.
+
+ Args:
+ boxes: a tensor whose last dimension is 4 representing the coordinates of
+ boxes in ymin, xmin, ymax, xmax order.
+ scores: a tensor whose shape is the same as tf.shape(boxes)[:-1]
+ representing the original scores of the boxes.
+ min_score_threshold: a float representing the minimal box score threshold.
+ Boxes whose score are smaller than it will be filtered out.
+
+ Returns:
+ filtered_boxes: a tensor whose shape is the same as `boxes` but with
+ the positions of the filtered boxes filled with 0.
+ filtered_scores: a tensor whose shape is the same as `scores` but with
+ the positions of the filtered boxes filled with -1.
+ """
+ if boxes.shape[-1] != 4:
+ raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(
+ boxes.shape[-1]))
+
+ with tf.name_scope('filter_boxes_by_scores'):
+ filtered_mask = tf.math.greater(scores, min_score_threshold)
+ filtered_scores = tf.where(filtered_mask, scores, -tf.ones_like(scores))
+ filtered_boxes = tf.cast(
+ tf.expand_dims(filtered_mask, axis=-1), dtype=boxes.dtype) * boxes
+
+ return filtered_boxes, filtered_scores
+
+
+def top_k_boxes(boxes, scores, k):
+ """Sort and select top k boxes according to the scores.
+
+ Args:
+ boxes: a tensor of shape [batch_size, N, 4] representing the coordinates of
+ the boxes. N is the number of boxes per image.
+ scores: a tensor of shape [batch_size, N] representing the score of the
+ boxes.
+ k: an integer or a tensor indicating the top k number.
+
+ Returns:
+ selected_boxes: a tensor of shape [batch_size, k, 4] representing the
+ selected top k box coordinates.
+ selected_scores: a tensor of shape [batch_size, k] representing the selected
+ top k box scores.
+ """
+ with tf.name_scope('top_k_boxes'):
+ selected_scores, top_k_indices = tf.nn.top_k(scores, k=k, sorted=True)
+
+ batch_size, _ = scores.get_shape().as_list()
+ if batch_size == 1:
+ selected_boxes = tf.squeeze(
+ tf.gather(boxes, top_k_indices, axis=1), axis=1)
+ else:
+ top_k_indices_shape = tf.shape(top_k_indices)
+ batch_indices = (
+ tf.expand_dims(tf.range(top_k_indices_shape[0]), axis=-1) *
+ tf.ones([1, top_k_indices_shape[-1]], dtype=tf.int32))
+ gather_nd_indices = tf.stack([batch_indices, top_k_indices], axis=-1)
+ selected_boxes = tf.gather_nd(boxes, gather_nd_indices)
+
+ return selected_boxes, selected_scores
+
+
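+# Editor's note: a toy sketch of top_k_boxes with batch size 1 and made-up scores.
+def _example_top_k_boxes():
+  boxes = tf.constant([[[0., 0., 1., 1.],
+                        [0., 0., 2., 2.],
+                        [0., 0., 3., 3.]]])  # [batch=1, N=3, 4]
+  scores = tf.constant([[0.2, 0.9, 0.5]])    # [batch=1, N=3]
+  top_boxes, top_scores = top_k_boxes(boxes, scores, k=2)
+  # top_scores == [[0.9, 0.5]] and top_boxes holds the corresponding rows.
+  return top_boxes, top_scores
+
+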
+def bbox_overlap(boxes, gt_boxes):
+ """Calculates the overlap between proposal and ground truth boxes.
+
+ Some `gt_boxes` may have been padded. The returned `iou` tensor for these
+ boxes will be -1.
+
+ Args:
+ boxes: a tensor with a shape of [batch_size, N, 4]. N is the number of
+ proposals before groundtruth assignment (e.g., rpn_post_nms_topn). The
+ last dimension is the pixel coordinates in [ymin, xmin, ymax, xmax] form.
+ gt_boxes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES, 4]. This
+ tensor might have paddings with a negative value.
+
+ Returns:
+ iou: a tensor with a shape of [batch_size, N, MAX_NUM_INSTANCES].
+ """
+ with tf.name_scope('bbox_overlap'):
+ bb_y_min, bb_x_min, bb_y_max, bb_x_max = tf.split(
+ value=boxes, num_or_size_splits=4, axis=2)
+ gt_y_min, gt_x_min, gt_y_max, gt_x_max = tf.split(
+ value=gt_boxes, num_or_size_splits=4, axis=2)
+
+ # Calculates the intersection area.
+ i_xmin = tf.math.maximum(bb_x_min, tf.transpose(gt_x_min, [0, 2, 1]))
+ i_xmax = tf.math.minimum(bb_x_max, tf.transpose(gt_x_max, [0, 2, 1]))
+ i_ymin = tf.math.maximum(bb_y_min, tf.transpose(gt_y_min, [0, 2, 1]))
+ i_ymax = tf.math.minimum(bb_y_max, tf.transpose(gt_y_max, [0, 2, 1]))
+ i_area = tf.math.maximum((i_xmax - i_xmin), 0) * tf.math.maximum(
+ (i_ymax - i_ymin), 0)
+
+ # Calculates the union area.
+ bb_area = (bb_y_max - bb_y_min) * (bb_x_max - bb_x_min)
+ gt_area = (gt_y_max - gt_y_min) * (gt_x_max - gt_x_min)
+ # Adds a small epsilon to avoid divide-by-zero.
+ u_area = bb_area + tf.transpose(gt_area, [0, 2, 1]) - i_area + 1e-8
+
+ # Calculates IoU.
+ iou = i_area / u_area
+
+ # Fills -1 for IoU entries between the padded ground truth boxes.
+ gt_invalid_mask = tf.less(
+ tf.reduce_max(gt_boxes, axis=-1, keepdims=True), 0.0)
+ padding_mask = tf.logical_or(
+ tf.zeros_like(bb_x_min, dtype=tf.bool),
+ tf.transpose(gt_invalid_mask, [0, 2, 1]))
+ iou = tf.where(padding_mask, -tf.ones_like(iou), iou)
+
+ return iou
+
+
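+# Editor's note: an illustrative IoU sketch; the second groundtruth row is padding
+# (all negative), so its IoU entry comes back as -1 per the docstring above.
+def _example_bbox_overlap():
+  boxes = tf.constant([[[0., 0., 10., 10.]]])       # [batch=1, N=1, 4]
+  gt_boxes = tf.constant([[[0., 0., 10., 10.],
+                           [-1., -1., -1., -1.]]])  # [batch=1, 2, 4]
+  return bbox_overlap(boxes, gt_boxes)               # ~[[[1.0, -1.0]]]
+
+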
+def get_non_empty_box_indices(boxes):
+ """Get indices for non-empty boxes."""
+ # Selects indices of boxes whose height and width are both greater than 0.
+ height = boxes[:, 2] - boxes[:, 0]
+ width = boxes[:, 3] - boxes[:, 1]
+ indices = tf.where(tf.logical_and(tf.greater(height, 0),
+ tf.greater(width, 0)))
+ return indices[:, 0]
diff --git a/models/official/vision/detection/utils/class_utils.py b/models/official/vision/detection/utils/class_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cce9cf982bbbce7b90ee44e67ebe65997b7a91da
--- /dev/null
+++ b/models/official/vision/detection/utils/class_utils.py
@@ -0,0 +1,44 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for handling dataset object categories."""
+
+
+def coco_split_class_ids(split_name):
+ """Returns the COCO class ids that belong to the given dataset split.
+
+ Args:
+ split_name: The name of dataset split.
+
+ Returns:
+ class_ids: a python list of integers.
+ """
+ if split_name == 'all':
+ return []
+
+ elif split_name == 'voc':
+ return [
+ 1, 2, 3, 4, 5, 6, 7, 9, 16, 17, 18, 19, 20, 21, 44, 62, 63, 64, 67, 72
+ ]
+
+ elif split_name == 'nonvoc':
+ return [
+ 8, 10, 11, 13, 14, 15, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36,
+ 37, 38, 39, 40, 41, 42, 43, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56,
+ 57, 58, 59, 60, 61, 65, 70, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84,
+ 85, 86, 87, 88, 89, 90
+ ]
+
+ else:
+ raise ValueError('Invalid split name {}!!!'.format(split_name))
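+
+
+# Editor's note: an illustrative check of the split sizes returned above; the 'voc'
+# split lists 20 category ids and 'nonvoc' lists the remaining 60 COCO ids.
+def _example_coco_split_sizes():
+  assert len(coco_split_class_ids('voc')) == 20
+  assert len(coco_split_class_ids('nonvoc')) == 60
+  assert not coco_split_class_ids('all')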
diff --git a/models/official/vision/detection/utils/dataloader_utils.py b/models/official/vision/detection/utils/dataloader_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..da82203511da50393a352bf75ee56f25c6626c05
--- /dev/null
+++ b/models/official/vision/detection/utils/dataloader_utils.py
@@ -0,0 +1,40 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for dataloader."""
+
+import tensorflow as tf
+
+from official.vision.detection.utils import input_utils
+
+
+def process_source_id(source_id):
+ """Processes source_id to the right format."""
+ if source_id.dtype == tf.string:
+ source_id = tf.cast(tf.strings.to_number(source_id), tf.int64)
+ with tf.control_dependencies([source_id]):
+ source_id = tf.cond(
+ pred=tf.equal(tf.size(input=source_id), 0),
+ true_fn=lambda: tf.cast(tf.constant(-1), tf.int64),
+ false_fn=lambda: tf.identity(source_id))
+ return source_id
+
+
+def pad_groundtruths_to_fixed_size(gt, n):
+ """Pads the first dimension of groundtruths labels to the fixed size."""
+ gt['boxes'] = input_utils.pad_to_fixed_size(gt['boxes'], n, -1)
+ gt['is_crowds'] = input_utils.pad_to_fixed_size(gt['is_crowds'], n, 0)
+ gt['areas'] = input_utils.pad_to_fixed_size(gt['areas'], n, -1)
+ gt['classes'] = input_utils.pad_to_fixed_size(gt['classes'], n, -1)
+ return gt
diff --git a/models/official/vision/detection/utils/input_utils.py b/models/official/vision/detection/utils/input_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6010dc8973f387318c4553d3014ccf495cf01fc6
--- /dev/null
+++ b/models/official/vision/detection/utils/input_utils.py
@@ -0,0 +1,366 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for input processing."""
+
+import math
+import tensorflow as tf
+
+from official.vision.detection.utils import box_utils
+from official.vision.detection.utils.object_detection import preprocessor
+
+
+def pad_to_fixed_size(input_tensor, size, constant_values=0):
+ """Pads data to a fixed length at the first dimension.
+
+ Args:
+ input_tensor: `Tensor` with any dimension.
+ size: `int` number for the first dimension of output Tensor.
+ constant_values: `int` value assigned to the paddings.
+
+ Returns:
+ `Tensor` with the first dimension padded to `size`.
+ """
+ input_shape = input_tensor.get_shape().as_list()
+ padding_shape = []
+
+ # Computes the padding length on the first dimension.
+ padding_length = tf.maximum(0, size - tf.shape(input_tensor)[0])
+ assert_length = tf.Assert(
+ tf.greater_equal(padding_length, 0), [padding_length])
+ with tf.control_dependencies([assert_length]):
+ padding_shape.append(padding_length)
+
+ # Copies shapes of the rest of input shape dimensions.
+ for i in range(1, len(input_shape)):
+ padding_shape.append(tf.shape(input=input_tensor)[i])
+
+ # Pads input tensor to the fixed first dimension.
+ paddings = tf.cast(constant_values * tf.ones(padding_shape),
+ input_tensor.dtype)
+ padded_tensor = tf.concat([input_tensor, paddings], axis=0)
+ output_shape = input_shape
+ output_shape[0] = size
+ padded_tensor.set_shape(output_shape)
+ return padded_tensor
+
+
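+# Editor's note: a small illustrative sketch of pad_to_fixed_size with made-up
+# boxes; the rows beyond the two inputs are filled with the constant value -1.
+def _example_pad_to_fixed_size():
+  boxes = tf.constant([[0.1, 0.2, 0.5, 0.6],
+                       [0.3, 0.3, 0.9, 0.8]])
+  padded = pad_to_fixed_size(boxes, size=4, constant_values=-1)
+  return padded  # shape (4, 4); the last two rows are all -1
+
+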
+def normalize_image(image,
+ offset=(0.485, 0.456, 0.406),
+ scale=(0.229, 0.224, 0.225)):
+ """Normalizes the image to zero mean and unit variance."""
+ image = tf.image.convert_image_dtype(image, dtype=tf.float32)
+ offset = tf.constant(offset)
+ offset = tf.expand_dims(offset, axis=0)
+ offset = tf.expand_dims(offset, axis=0)
+ image -= offset
+
+ scale = tf.constant(scale)
+ scale = tf.expand_dims(scale, axis=0)
+ scale = tf.expand_dims(scale, axis=0)
+ image /= scale
+ return image
+
+
+def compute_padded_size(desired_size, stride):
+ """Compute the padded size given the desired size and the stride.
+
+ The padded size will be the smallest rectangle such that each dimension is
+ the smallest multiple of the stride which is no smaller than the desired
+ dimension. For example, if desired_size = (100, 200) and stride = 32,
+ the output padded_size = (128, 224).
+
+ Args:
+ desired_size: a `Tensor` or `int` list/tuple of two elements representing
+ [height, width] of the target output image size.
+ stride: an integer, the stride of the backbone network.
+
+ Returns:
+ padded_size: a `Tensor` or `int` list/tuple of two elements representing
+ [height, width] of the padded output image size.
+ """
+ if isinstance(desired_size, list) or isinstance(desired_size, tuple):
+ padded_size = [int(math.ceil(d * 1.0 / stride) * stride)
+ for d in desired_size]
+ else:
+ padded_size = tf.cast(
+ tf.math.ceil(
+ tf.cast(desired_size, dtype=tf.float32) / stride) * stride,
+ tf.int32)
+ return padded_size
+
+
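+# Editor's note: an illustrative sketch reproducing the example from the docstring
+# above through both the python-list and the tensor code paths.
+def _example_compute_padded_size():
+  assert compute_padded_size([100, 200], 32) == [128, 224]
+  return compute_padded_size(tf.constant([100, 200]), 32)  # -> tf.Tensor([128, 224])
+
+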
+def resize_and_crop_image(image,
+ desired_size,
+ padded_size,
+ aug_scale_min=1.0,
+ aug_scale_max=1.0,
+ seed=1,
+ method=tf.image.ResizeMethod.BILINEAR):
+ """Resizes the input image to output size.
+
+ Resize and pad images given the desired output size of the image and
+ stride size.
+
+ Here are the preprocessing steps.
+ 1. For a given image, keep its aspect ratio and rescale the image to make it
+ the largest rectangle to be bounded by the rectangle specified by the
+ `desired_size`.
+ 2. Pad the rescaled image to the padded_size.
+
+ Args:
+ image: a `Tensor` of shape [height, width, 3] representing an image.
+ desired_size: a `Tensor` or `int` list/tuple of two elements representing
+ [height, width] of the desired actual output image size.
+ padded_size: a `Tensor` or `int` list/tuple of two elements representing
+ [height, width] of the padded output image size. Padding will be applied
+ after scaling the image to the desired_size.
+ aug_scale_min: a `float` with range between [0, 1.0] representing minimum
+ random scale applied to desired_size for training scale jittering.
+ aug_scale_max: a `float` with range between [1.0, inf] representing maximum
+ random scale applied to desired_size for training scale jittering.
+ seed: seed for random scale jittering.
+ method: function to resize input image to scaled image.
+
+ Returns:
+ output_image: `Tensor` of shape [height, width, 3] where [height, width]
+ equals to `output_size`.
+ image_info: a 2D `Tensor` that encodes the information of the image and the
+ applied preprocessing. It is in the format of
+ [[original_height, original_width], [desired_height, desired_width],
+ [y_scale, x_scale], [y_offset, x_offset]], where [desired_height,
+ desired_width] is the actual scaled image size, and [y_scale, x_scale] is
+ the scaling factor, which is the ratio of
+ scaled dimension / original dimension.
+ """
+ with tf.name_scope('resize_and_crop_image'):
+ image_size = tf.cast(tf.shape(input=image)[0:2], tf.float32)
+
+ random_jittering = (aug_scale_min != 1.0 or aug_scale_max != 1.0)
+
+ if random_jittering:
+ random_scale = tf.random.uniform([],
+ aug_scale_min,
+ aug_scale_max,
+ seed=seed)
+ scaled_size = tf.round(random_scale * desired_size)
+ else:
+ scaled_size = desired_size
+
+ scale = tf.minimum(
+ scaled_size[0] / image_size[0], scaled_size[1] / image_size[1])
+ scaled_size = tf.round(image_size * scale)
+
+ # Computes 2D image_scale.
+ image_scale = scaled_size / image_size
+
+ # Selects non-zero random offset (x, y) if scaled image is larger than
+ # desired_size.
+ if random_jittering:
+ max_offset = scaled_size - desired_size
+ max_offset = tf.where(tf.less(max_offset, 0),
+ tf.zeros_like(max_offset),
+ max_offset)
+ offset = max_offset * tf.random.uniform([
+ 2,
+ ], 0, 1, seed=seed)
+ offset = tf.cast(offset, tf.int32)
+ else:
+ offset = tf.zeros((2,), tf.int32)
+
+ scaled_image = tf.image.resize(
+ image, tf.cast(scaled_size, tf.int32), method=method)
+
+ if random_jittering:
+ scaled_image = scaled_image[offset[0]:offset[0] + desired_size[0],
+ offset[1]:offset[1] + desired_size[1], :]
+
+ output_image = tf.image.pad_to_bounding_box(scaled_image, 0, 0,
+ padded_size[0], padded_size[1])
+
+ image_info = tf.stack([
+ image_size,
+ tf.cast(desired_size, dtype=tf.float32),
+ image_scale,
+ tf.cast(offset, tf.float32)])
+ return output_image, image_info
+
+
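+# Editor's note: an illustrative end-to-end sketch with a random 480x640 image and
+# no scale jittering; the sizes used here are arbitrary.
+def _example_resize_and_crop_image():
+  image = tf.random.uniform([480, 640, 3])
+  out_image, image_info = resize_and_crop_image(
+      image,
+      desired_size=[512, 512],
+      padded_size=compute_padded_size([512, 512], 32))
+  # out_image has shape (512, 512, 3); image_info stacks [original size],
+  # [desired size], [y/x scale] and [y/x offset].
+  return out_image, image_info
+
+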
+def resize_and_crop_image_v2(image,
+ short_side,
+ long_side,
+ padded_size,
+ aug_scale_min=1.0,
+ aug_scale_max=1.0,
+ seed=1,
+ method=tf.image.ResizeMethod.BILINEAR):
+ """Resizes the input image to output size (Faster R-CNN style).
+
+ Resize and pad images given the specified short / long side length and the
+ stride size.
+
+ Here are the preprocessing steps.
+ 1. For a given image, keep its aspect ratio and first try to rescale the short
+ side of the original image to `short_side`.
+ 2. If the image scaled in step 1 has a long side that exceeds `long_side`,
+ keep the aspect ratio and rescale the long side of the image to `long_side`
+ instead.
+ 3. Pad the rescaled image to the padded_size.
+
+ Args:
+ image: a `Tensor` of shape [height, width, 3] representing an image.
+ short_side: a scalar `Tensor` or `int` representing the desired short side
+ to be rescaled to.
+ long_side: a scalar `Tensor` or `int` representing the desired long side to
+ be rescaled to.
+ padded_size: a `Tensor` or `int` list/tuple of two elements representing
+ [height, width] of the padded output image size. Padding will be applied
+ after scaling the image to the desired_size.
+ aug_scale_min: a `float` with range between [0, 1.0] representing minimum
+ random scale applied to desired_size for training scale jittering.
+ aug_scale_max: a `float` with range between [1.0, inf] representing maximum
+ random scale applied to desired_size for training scale jittering.
+ seed: seed for random scale jittering.
+ method: function to resize input image to scaled image.
+
+ Returns:
+ output_image: `Tensor` of shape [height, width, 3] where [height, width]
+ equals to `output_size`.
+ image_info: a 2D `Tensor` that encodes the information of the image and the
+ applied preprocessing. It is in the format of
+ [[original_height, original_width], [desired_height, desired_width],
+ [y_scale, x_scale], [y_offset, x_offset]], where [desired_height,
+ desired_width] is the actual scaled image size, and [y_scale, x_scale] is
+ the scaling factor, which is the ratio of
+ scaled dimension / original dimension.
+ """
+ with tf.name_scope('resize_and_crop_image_v2'):
+ image_size = tf.cast(tf.shape(image)[0:2], tf.float32)
+
+ scale_using_short_side = (
+ short_side / tf.math.minimum(image_size[0], image_size[1]))
+ scale_using_long_side = (
+ long_side / tf.math.maximum(image_size[0], image_size[1]))
+
+ scaled_size = tf.math.round(image_size * scale_using_short_side)
+ scaled_size = tf.where(
+ tf.math.greater(
+ tf.math.maximum(scaled_size[0], scaled_size[1]), long_side),
+ tf.math.round(image_size * scale_using_long_side), scaled_size)
+ desired_size = scaled_size
+
+ random_jittering = (aug_scale_min != 1.0 or aug_scale_max != 1.0)
+
+ if random_jittering:
+ random_scale = tf.random.uniform([],
+ aug_scale_min,
+ aug_scale_max,
+ seed=seed)
+ scaled_size = tf.math.round(random_scale * scaled_size)
+
+ # Computes 2D image_scale.
+ image_scale = scaled_size / image_size
+
+ # Selects non-zero random offset (x, y) if scaled image is larger than
+ # desired_size.
+ if random_jittering:
+ max_offset = scaled_size - desired_size
+ max_offset = tf.where(
+ tf.math.less(max_offset, 0), tf.zeros_like(max_offset), max_offset)
+ offset = max_offset * tf.random.uniform([
+ 2,
+ ], 0, 1, seed=seed)
+ offset = tf.cast(offset, tf.int32)
+ else:
+ offset = tf.zeros((2,), tf.int32)
+
+ scaled_image = tf.image.resize(
+ image, tf.cast(scaled_size, tf.int32), method=method)
+
+ if random_jittering:
+ scaled_image = scaled_image[
+ offset[0]:offset[0] + desired_size[0],
+ offset[1]:offset[1] + desired_size[1], :]
+
+ output_image = tf.image.pad_to_bounding_box(
+ scaled_image, 0, 0, padded_size[0], padded_size[1])
+
+ image_info = tf.stack([
+ image_size,
+ tf.cast(desired_size, dtype=tf.float32),
+ image_scale,
+ tf.cast(offset, tf.float32)])
+ return output_image, image_info
+
+
+def resize_and_crop_boxes(boxes,
+ image_scale,
+ output_size,
+ offset):
+ """Resizes boxes to output size with scale and offset.
+
+ Args:
+ boxes: `Tensor` of shape [N, 4] representing ground truth boxes.
+ image_scale: 2D float `Tensor` representing scale factors that apply to
+ [height, width] of input image.
+ output_size: 2D `Tensor` or `int` representing [height, width] of target
+ output image size.
+ offset: 2D `Tensor` representing top-left corner [y0, x0] to crop scaled
+ boxes.
+
+ Returns:
+ boxes: `Tensor` of shape [N, 4] representing the scaled boxes.
+ """
+ # Adjusts box coordinates based on image_scale and offset.
+ boxes *= tf.tile(tf.expand_dims(image_scale, axis=0), [1, 2])
+ boxes -= tf.tile(tf.expand_dims(offset, axis=0), [1, 2])
+ # Clips the boxes.
+ boxes = box_utils.clip_boxes(boxes, output_size)
+ return boxes
+
+
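+# Editor's note: an illustrative sketch of resize_and_crop_boxes with a made-up
+# box, a uniform 0.5 scale and no crop offset.
+def _example_resize_and_crop_boxes():
+  boxes = tf.constant([[100., 200., 300., 400.]])  # ymin, xmin, ymax, xmax
+  image_scale = tf.constant([0.5, 0.5])
+  offset = tf.constant([0., 0.])
+  scaled = resize_and_crop_boxes(boxes, image_scale, [256, 256], offset)
+  return scaled  # -> [[50., 100., 150., 200.]]
+
+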
+def resize_and_crop_masks(masks,
+ image_scale,
+ output_size,
+ offset):
+ """Resizes masks to output size with scale and offset.
+
+ Args:
+ masks: `Tensor` of shape [N, H, W, 1] representing ground truth masks.
+ image_scale: 2D float `Tensor` representing scale factors that apply to
+ [height, width] of input image.
+ output_size: 2D `Tensor` or `int` representing [height, width] of target
+ output image size.
+ offset: 2D `Tensor` representing top-left corner [y0, x0] to crop scaled
+ boxes.
+
+ Returns:
+ masks: `Tensor` of shape [N, H, W, 1] representing the scaled masks.
+ """
+ mask_size = tf.shape(input=masks)[1:3]
+ scaled_size = tf.cast(image_scale * tf.cast(mask_size, image_scale.dtype),
+ tf.int32)
+ scaled_masks = tf.image.resize(
+ masks, scaled_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
+ offset = tf.cast(offset, tf.int32)
+ scaled_masks = scaled_masks[:, offset[0]:offset[0] + output_size[0],
+ offset[1]:offset[1] + output_size[1], :]
+
+ output_masks = tf.image.pad_to_bounding_box(scaled_masks, 0, 0,
+ output_size[0], output_size[1])
+ return output_masks
+
+
+def random_horizontal_flip(image, boxes=None, masks=None):
+ """Randomly flips input image and bounding boxes."""
+ return preprocessor.random_horizontal_flip(image, boxes, masks)
diff --git a/models/official/vision/detection/utils/mask_utils.py b/models/official/vision/detection/utils/mask_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..637d0484f4b48213c4b323be6e0c88f9fa19ebcc
--- /dev/null
+++ b/models/official/vision/detection/utils/mask_utils.py
@@ -0,0 +1,192 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for segmentations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numpy as np
+import cv2
+
+
+def paste_instance_masks(masks,
+ detected_boxes,
+ image_height,
+ image_width):
+ """Paste instance masks to generate the image segmentation results.
+
+ Args:
+ masks: a numpy array of shape [N, mask_height, mask_width] representing the
+ instance masks w.r.t. the `detected_boxes`.
+ detected_boxes: a numpy array of shape [N, 4] representing the reference
+ bounding boxes.
+ image_height: an integer representing the height of the image.
+ image_width: an integer representing the width of the image.
+
+ Returns:
+ segms: a numpy array of shape [N, image_height, image_width] representing
+ the instance masks *pasted* on the image canvas.
+ """
+
+ def expand_boxes(boxes, scale):
+ """Expands an array of boxes by a given scale."""
+ # Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/boxes.py#L227 # pylint: disable=line-too-long
+ # The `boxes` in the reference implementation is in [x1, y1, x2, y2] form,
+ # whereas `boxes` here is in [x1, y1, w, h] form
+ w_half = boxes[:, 2] * .5
+ h_half = boxes[:, 3] * .5
+ x_c = boxes[:, 0] + w_half
+ y_c = boxes[:, 1] + h_half
+
+ w_half *= scale
+ h_half *= scale
+
+ boxes_exp = np.zeros(boxes.shape)
+ boxes_exp[:, 0] = x_c - w_half
+ boxes_exp[:, 2] = x_c + w_half
+ boxes_exp[:, 1] = y_c - h_half
+ boxes_exp[:, 3] = y_c + h_half
+
+ return boxes_exp
+
+ # Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/test.py#L812 # pylint: disable=line-too-long
+ # To work around an issue with cv2.resize (it seems to automatically pad
+ # with repeated border values), we manually zero-pad the masks by 1 pixel
+ # prior to resizing back to the original image resolution. This prevents
+ # "top hat" artifacts. We therefore need to expand the reference boxes by an
+ # appropriate factor.
+ _, mask_height, mask_width = masks.shape
+ scale = max((mask_width + 2.0) / mask_width,
+ (mask_height + 2.0) / mask_height)
+
+ ref_boxes = expand_boxes(detected_boxes, scale)
+ ref_boxes = ref_boxes.astype(np.int32)
+ padded_mask = np.zeros((mask_height + 2, mask_width + 2), dtype=np.float32)
+ segms = []
+ for mask_ind, mask in enumerate(masks):
+ im_mask = np.zeros((image_height, image_width), dtype=np.uint8)
+ # Process mask inside bounding boxes.
+ padded_mask[1:-1, 1:-1] = mask[:, :]
+
+ ref_box = ref_boxes[mask_ind, :]
+ w = ref_box[2] - ref_box[0] + 1
+ h = ref_box[3] - ref_box[1] + 1
+ w = np.maximum(w, 1)
+ h = np.maximum(h, 1)
+
+ mask = cv2.resize(padded_mask, (w, h))
+ mask = np.array(mask > 0.5, dtype=np.uint8)
+
+ x_0 = min(max(ref_box[0], 0), image_width)
+ x_1 = min(max(ref_box[2] + 1, 0), image_width)
+ y_0 = min(max(ref_box[1], 0), image_height)
+ y_1 = min(max(ref_box[3] + 1, 0), image_height)
+
+ im_mask[y_0:y_1, x_0:x_1] = mask[
+ (y_0 - ref_box[1]):(y_1 - ref_box[1]),
+ (x_0 - ref_box[0]):(x_1 - ref_box[0])
+ ]
+ segms.append(im_mask)
+
+ segms = np.array(segms)
+ assert masks.shape[0] == segms.shape[0]
+ return segms
+
+
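+# Editor's note: a toy sketch of paste_instance_masks; note that `detected_boxes`
+# here follow the [x, y, width, height] convention used by expand_boxes above, and
+# the values are made up.
+def _example_paste_instance_masks():
+  masks = np.ones((1, 28, 28), dtype=np.float32)
+  detected_boxes = np.array([[10., 20., 50., 40.]])  # x, y, w, h
+  segms = paste_instance_masks(masks, detected_boxes,
+                               image_height=100, image_width=120)
+  return segms  # shape (1, 100, 120), uint8 values in {0, 1}
+
+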
+def paste_instance_masks_v2(masks,
+ detected_boxes,
+ image_height,
+ image_width):
+ """Paste instance masks to generate the image segmentation (v2).
+
+ Args:
+ masks: a numpy array of shape [N, mask_height, mask_width] representing the
+ instance masks w.r.t. the `detected_boxes`.
+ detected_boxes: a numpy array of shape [N, 4] representing the reference
+ bounding boxes.
+ image_height: an integer representing the height of the image.
+ image_width: an integer representing the width of the image.
+
+ Returns:
+ segms: a numpy array of shape [N, image_height, image_width] representing
+ the instance masks *pasted* on the image canvas.
+ """
+ _, mask_height, mask_width = masks.shape
+
+ segms = []
+ for i, mask in enumerate(masks):
+ box = detected_boxes[i, :]
+ xmin = box[0]
+ ymin = box[1]
+ xmax = xmin + box[2]
+ ymax = ymin + box[3]
+
+ # Sample points of the cropped mask w.r.t. the image grid.
+ # Note that these coordinates may fall beyond the image.
+ # Pixel clipping will happen after warping.
+ xmin_int = int(math.floor(xmin))
+ xmax_int = int(math.ceil(xmax))
+ ymin_int = int(math.floor(ymin))
+ ymax_int = int(math.ceil(ymax))
+
+ alpha = box[2] / (1.0 * mask_width)
+ beta = box[3] / (1.0 * mask_height)
+ # pylint: disable=invalid-name
+ # Transformation from mask pixel indices to image coordinate.
+ M_mask_to_image = np.array(
+ [[alpha, 0, xmin],
+ [0, beta, ymin],
+ [0, 0, 1]],
+ dtype=np.float32)
+ # Transformation from image to cropped mask coordinate.
+ M_image_to_crop = np.array(
+ [[1, 0, -xmin_int],
+ [0, 1, -ymin_int],
+ [0, 0, 1]],
+ dtype=np.float32)
+ M = np.dot(M_image_to_crop, M_mask_to_image)
+ # Compensate the half pixel offset that OpenCV has in the
+ # warpPerspective implementation: the top-left pixel is sampled
+ # at (0,0), but we want it to be at (0.5, 0.5).
+ M = np.dot(
+ np.dot(
+ np.array([[1, 0, -0.5],
+ [0, 1, -0.5],
+ [0, 0, 1]], np.float32),
+ M),
+ np.array([[1, 0, 0.5],
+ [0, 1, 0.5],
+ [0, 0, 1]], np.float32))
+ # pylint: enable=invalid-name
+ cropped_mask = cv2.warpPerspective(
+ mask.astype(np.float32), M,
+ (xmax_int - xmin_int, ymax_int - ymin_int))
+ cropped_mask = np.array(cropped_mask > 0.5, dtype=np.uint8)
+
+ img_mask = np.zeros((image_height, image_width))
+ x0 = max(min(xmin_int, image_width), 0)
+ x1 = max(min(xmax_int, image_width), 0)
+ y0 = max(min(ymin_int, image_height), 0)
+ y1 = max(min(ymax_int, image_height), 0)
+ img_mask[y0:y1, x0:x1] = cropped_mask[
+ (y0 - ymin_int):(y1 - ymin_int),
+ (x0 - xmin_int):(x1 - xmin_int)]
+
+ segms.append(img_mask)
+
+ segms = np.array(segms)
+ return segms
+
diff --git a/models/official/vision/detection/utils/object_detection/__init__.py b/models/official/vision/detection/utils/object_detection/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..85c94f4b6bd7567796755895505a320405a40777
--- /dev/null
+++ b/models/official/vision/detection/utils/object_detection/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/models/official/vision/detection/utils/object_detection/argmax_matcher.py b/models/official/vision/detection/utils/object_detection/argmax_matcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f8b051bfb08a72846482c0da9c79d1b98418c38
--- /dev/null
+++ b/models/official/vision/detection/utils/object_detection/argmax_matcher.py
@@ -0,0 +1,201 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Argmax matcher implementation.
+
+This class takes a similarity matrix and matches columns to rows based on the
+maximum value per column. One can specify a matched_threshold to prevent
+low-similarity columns from producing positive matches (generally resulting in
+a negative training example) and an unmatched_threshold to ignore matches whose
+similarity falls between the two thresholds (generally resulting in neither a
+positive nor a negative training example).
+
+This matcher is used in Fast(er)-RCNN.
+
+Note: matchers are used in TargetAssigners. There is a create_target_assigner
+factory function for popular implementations.
+"""
+import tensorflow as tf
+
+from official.vision.detection.utils.object_detection import matcher
+from official.vision.detection.utils.object_detection import shape_utils
+
+
+class ArgMaxMatcher(matcher.Matcher):
+ """Matcher based on highest value.
+
+ This class computes matches from a similarity matrix. Each column is matched
+ to a single row.
+
+ To support object detection target assignment this class enables setting both
+ matched_threshold (upper threshold) and unmatched_threshold (lower threshold)
+ defining three categories of similarity which define whether examples are
+ positive, negative, or ignored:
+ (1) similarity >= matched_threshold: Highest similarity. Matched/Positive!
+ (2) matched_threshold > similarity >= unmatched_threshold: Medium similarity.
+ Depending on negatives_lower_than_unmatched, this is either
+ Unmatched/Negative OR Ignore.
+ (3) unmatched_threshold > similarity: Lowest similarity. Depending on flag
+ negatives_lower_than_unmatched, either Unmatched/Negative OR Ignore.
+ For ignored matches this class sets the values in the Match object to -2.
+ """
+
+ def __init__(self,
+ matched_threshold,
+ unmatched_threshold=None,
+ negatives_lower_than_unmatched=True,
+ force_match_for_each_row=False):
+ """Construct ArgMaxMatcher.
+
+ Args:
+ matched_threshold: Threshold for positive matches. Positive if
+ sim >= matched_threshold, where sim is the maximum value of the
+ similarity matrix for a given column. Set to None for no threshold.
+ unmatched_threshold: Threshold for negative matches. Negative if
+ sim < unmatched_threshold. Defaults to matched_threshold
+ when set to None.
+ negatives_lower_than_unmatched: Boolean which defaults to True. If True
+ then negative matches are the ones below the unmatched_threshold,
+ whereas ignored matches are in between the matched and unmatched
+ threshold. If False, then negative matches are in between the matched
+ and unmatched threshold, and everything lower than unmatched is ignored.
+ force_match_for_each_row: If True, ensures that each row is matched to
+ at least one column (which is not guaranteed otherwise if the
+ matched_threshold is high). Defaults to False. See
+ argmax_matcher_test.testMatcherForceMatch() for an example.
+
+ Raises:
+ ValueError: if unmatched_threshold is set but matched_threshold is not set
+ or if unmatched_threshold > matched_threshold.
+ """
+ if (matched_threshold is None) and (unmatched_threshold is not None):
+ raise ValueError('Need to also define matched_threshold when '
+ 'unmatched_threshold is defined')
+ self._matched_threshold = matched_threshold
+ if unmatched_threshold is None:
+ self._unmatched_threshold = matched_threshold
+ else:
+ if unmatched_threshold > matched_threshold:
+ raise ValueError('unmatched_threshold needs to be smaller or equal '
+ 'to matched_threshold')
+ self._unmatched_threshold = unmatched_threshold
+ if not negatives_lower_than_unmatched:
+ if self._unmatched_threshold == self._matched_threshold:
+ raise ValueError('When negatives are in between matched and '
+ 'unmatched thresholds, these cannot be of equal '
+ 'value. matched: %s, unmatched: %s',
+ self._matched_threshold, self._unmatched_threshold)
+ self._force_match_for_each_row = force_match_for_each_row
+ self._negatives_lower_than_unmatched = negatives_lower_than_unmatched
+
+ def _match(self, similarity_matrix):
+ """Tries to match each column of the similarity matrix to a row.
+
+ Args:
+ similarity_matrix: tensor of shape [N, M] representing any similarity
+ metric.
+
+ Returns:
+ Match object with corresponding matches for each of M columns.
+ """
+
+ def _match_when_rows_are_empty():
+ """Performs matching when the rows of similarity matrix are empty.
+
+ When the rows are empty, all detections are false positives. So we return
+ a tensor of -1's to indicate that the columns do not match to any rows.
+
+ Returns:
+ matches: int32 tensor indicating the row each column matches to.
+ """
+ similarity_matrix_shape = shape_utils.combined_static_and_dynamic_shape(
+ similarity_matrix)
+ return -1 * tf.ones([similarity_matrix_shape[1]], dtype=tf.int32)
+
+ def _match_when_rows_are_non_empty():
+ """Performs matching when the rows of similarity matrix are non empty.
+
+ Returns:
+ matches: int32 tensor indicating the row each column matches to.
+ """
+ # Matches for each column
+ matches = tf.argmax(input=similarity_matrix, axis=0, output_type=tf.int32)
+
+ # Deal with matched and unmatched threshold
+ if self._matched_threshold is not None:
+ # Gets boolean masks for the below-threshold and between-threshold columns.
+ matched_vals = tf.reduce_max(input_tensor=similarity_matrix, axis=0)
+ below_unmatched_threshold = tf.greater(self._unmatched_threshold,
+ matched_vals)
+ between_thresholds = tf.logical_and(
+ tf.greater_equal(matched_vals, self._unmatched_threshold),
+ tf.greater(self._matched_threshold, matched_vals))
+
+ if self._negatives_lower_than_unmatched:
+ matches = self._set_values_using_indicator(matches,
+ below_unmatched_threshold,
+ -1)
+ matches = self._set_values_using_indicator(matches,
+ between_thresholds,
+ -2)
+ else:
+ matches = self._set_values_using_indicator(matches,
+ below_unmatched_threshold,
+ -2)
+ matches = self._set_values_using_indicator(matches,
+ between_thresholds,
+ -1)
+
+ if self._force_match_for_each_row:
+ similarity_matrix_shape = shape_utils.combined_static_and_dynamic_shape(
+ similarity_matrix)
+ force_match_column_ids = tf.argmax(
+ input=similarity_matrix, axis=1, output_type=tf.int32)
+ force_match_column_indicators = tf.one_hot(
+ force_match_column_ids, depth=similarity_matrix_shape[1])
+ force_match_row_ids = tf.argmax(
+ input=force_match_column_indicators, axis=0, output_type=tf.int32)
+ force_match_column_mask = tf.cast(
+ tf.reduce_max(input_tensor=force_match_column_indicators, axis=0),
+ tf.bool)
+ final_matches = tf.where(force_match_column_mask, force_match_row_ids,
+ matches)
+ return final_matches
+ else:
+ return matches
+
+ if similarity_matrix.shape.is_fully_defined():
+ if similarity_matrix.shape.dims[0].value == 0:
+ return _match_when_rows_are_empty()
+ else:
+ return _match_when_rows_are_non_empty()
+ else:
+ return tf.cond(
+ pred=tf.greater(tf.shape(input=similarity_matrix)[0], 0),
+ true_fn=_match_when_rows_are_non_empty,
+ false_fn=_match_when_rows_are_empty)
+
+ def _set_values_using_indicator(self, x, indicator, val):
+ """Set the indicated fields of x to val.
+
+ Args:
+ x: tensor.
+ indicator: boolean with same shape as x.
+ val: scalar with value to set.
+
+ Returns:
+ modified tensor.
+ """
+ indicator = tf.cast(indicator, x.dtype)
+ return tf.add(tf.multiply(x, 1 - indicator), val * indicator)
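+
+
+# Editor's note: a hedged usage sketch. It assumes the base Matcher class imported
+# above exposes a `match()` wrapper around `_match()`, as in the TensorFlow Object
+# Detection API; the similarity values below are made up.
+def _example_argmax_matcher():
+  similarity = tf.constant([[0.9, 0.2, 0.3],
+                            [0.1, 0.6, 0.4]])  # [num_gt_rows, num_anchor_columns]
+  argmax_matcher = ArgMaxMatcher(matched_threshold=0.7, unmatched_threshold=0.3)
+  # Column 0 (best similarity 0.9 >= 0.7) matches row 0; columns 1 and 2 fall
+  # between the two thresholds and are therefore marked as ignored (-2) by default.
+  return argmax_matcher.match(similarity)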
diff --git a/models/official/vision/detection/utils/object_detection/balanced_positive_negative_sampler.py b/models/official/vision/detection/utils/object_detection/balanced_positive_negative_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..f969182b05a29167649d5c022a667b3f768f0143
--- /dev/null
+++ b/models/official/vision/detection/utils/object_detection/balanced_positive_negative_sampler.py
@@ -0,0 +1,274 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Class to subsample minibatches by balancing positives and negatives.
+
+Subsamples minibatches based on a pre-specified positive fraction in range
+[0,1]. The class presumes there are many more negatives than positive examples:
+if the desired batch_size cannot be achieved with the pre-specified positive
+fraction, it fills the rest with negative examples. If this is not sufficient
+for obtaining the desired batch_size, it returns fewer examples.
+
+The main function to call is Subsample(self, indicator, labels). For convenience
+one can also call SubsampleWeights(self, weights, labels) which is defined in
+the minibatch_sampler base class.
+
+When is_static is True, it implements a method that guarantees static shapes.
+It also ensures that the length of the subsample output is always batch_size,
+even when the number of examples set to True in indicator is less than batch_size.
+
+This is originally implemented in TensorFlow Object Detection API.
+"""
+
+import tensorflow as tf
+
+from official.vision.detection.utils.object_detection import minibatch_sampler
+from official.vision.detection.utils.object_detection import ops
+
+
+class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):
+ """Subsamples minibatches to a desired balance of positives and negatives."""
+
+ def __init__(self, positive_fraction=0.5, is_static=False):
+ """Constructs a minibatch sampler.
+
+ Args:
+ positive_fraction: desired fraction of positive examples (scalar in [0,1])
+ in the batch.
+ is_static: If True, uses an implementation with static shape guarantees.
+
+ Raises:
+ ValueError: if positive_fraction < 0, or positive_fraction > 1
+ """
+ if positive_fraction < 0 or positive_fraction > 1:
+ raise ValueError('positive_fraction should be in range [0,1]. '
+ 'Received: %s.' % positive_fraction)
+ self._positive_fraction = positive_fraction
+ self._is_static = is_static
+
+ def _get_num_pos_neg_samples(self, sorted_indices_tensor, sample_size):
+ """Counts the number of positive and negative examples to be sampled.
+
+ Args:
+ sorted_indices_tensor: A sorted int32 tensor of shape [N] which contains
+ the signed indices of the examples where the sign is based on the label
+        value. The examples that cannot be sampled are set to 0. It samples
+        at most sample_size*positive_fraction positive examples and fills the
+        remainder with negative examples.
+ sample_size: Size of subsamples.
+
+ Returns:
+ A tuple containing the number of positive and negative labels in the
+ subsample.
+ """
+ input_length = tf.shape(input=sorted_indices_tensor)[0]
+ valid_positive_index = tf.greater(sorted_indices_tensor,
+ tf.zeros(input_length, tf.int32))
+ num_sampled_pos = tf.reduce_sum(
+ input_tensor=tf.cast(valid_positive_index, tf.int32))
+ max_num_positive_samples = tf.constant(
+ int(sample_size * self._positive_fraction), tf.int32)
+ num_positive_samples = tf.minimum(max_num_positive_samples, num_sampled_pos)
+ num_negative_samples = tf.constant(sample_size,
+ tf.int32) - num_positive_samples
+
+ return num_positive_samples, num_negative_samples
+
+ def _get_values_from_start_and_end(self, input_tensor, num_start_samples,
+ num_end_samples, total_num_samples):
+ """slices num_start_samples and last num_end_samples from input_tensor.
+
+ Args:
+ input_tensor: An int32 tensor of shape [N] to be sliced.
+ num_start_samples: Number of examples to be sliced from the beginning
+ of the input tensor.
+ num_end_samples: Number of examples to be sliced from the end of the
+ input tensor.
+      total_num_samples: Sum of num_start_samples and num_end_samples. This
+ should be a scalar.
+
+ Returns:
+ A tensor containing the first num_start_samples and last num_end_samples
+ from input_tensor.
+
+ """
+ input_length = tf.shape(input=input_tensor)[0]
+ start_positions = tf.less(tf.range(input_length), num_start_samples)
+ end_positions = tf.greater_equal(
+ tf.range(input_length), input_length - num_end_samples)
+ selected_positions = tf.logical_or(start_positions, end_positions)
+ selected_positions = tf.cast(selected_positions, tf.float32)
+ indexed_positions = tf.multiply(tf.cumsum(selected_positions),
+ selected_positions)
+ one_hot_selector = tf.one_hot(tf.cast(indexed_positions, tf.int32) - 1,
+ total_num_samples,
+ dtype=tf.float32)
+ return tf.cast(tf.tensordot(tf.cast(input_tensor, tf.float32),
+ one_hot_selector, axes=[0, 0]), tf.int32)
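+
+  # Illustrative worked example (not part of the original implementation): for
+  # input_tensor = [5, 3, 0, -2, -7] with num_start_samples = 2,
+  # num_end_samples = 1 and total_num_samples = 3, the selected positions are
+  # [1, 1, 0, 0, 1], the 1-based cumulative indices become [1, 2, 0, 0, 3],
+  # and the one-hot selector therefore extracts [5, 3, -7].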
+
+ def _static_subsample(self, indicator, batch_size, labels):
+ """Returns subsampled minibatch.
+
+ Args:
+ indicator: boolean tensor of shape [N] whose True entries can be sampled.
+        N should be a compile-time constant.
+ batch_size: desired batch size. This scalar cannot be None.
+ labels: boolean tensor of shape [N] denoting positive(=True) and negative
+        (=False) examples. N should be a compile-time constant.
+
+ Returns:
+ sampled_idx_indicator: boolean tensor of shape [N], True for entries which
+ are sampled. It ensures the length of output of the subsample is always
+ batch_size, even when number of examples set to True in indicator is
+ less than batch_size.
+
+ Raises:
+ ValueError: if labels and indicator are not 1D boolean tensors.
+ """
+ # Check if indicator and labels have a static size.
+ if not indicator.shape.is_fully_defined():
+      raise ValueError('indicator must be static in shape when is_static is '
+                       'True')
+ if not labels.shape.is_fully_defined():
+      raise ValueError('labels must be static in shape when is_static is '
+                       'True')
+ if not isinstance(batch_size, int):
+      raise ValueError('batch_size has to be an integer when is_static is '
+                       'True.')
+
+ input_length = tf.shape(input=indicator)[0]
+
+ # Set the number of examples set True in indicator to be at least
+ # batch_size.
+ num_true_sampled = tf.reduce_sum(
+ input_tensor=tf.cast(indicator, tf.float32))
+ additional_false_sample = tf.less_equal(
+ tf.cumsum(tf.cast(tf.logical_not(indicator), tf.float32)),
+ batch_size - num_true_sampled)
+ indicator = tf.logical_or(indicator, additional_false_sample)
+
+ # Shuffle indicator and label. Need to store the permutation to restore the
+ # order post sampling.
+ permutation = tf.random.shuffle(tf.range(input_length))
+ indicator = ops.matmul_gather_on_zeroth_axis(
+ tf.cast(indicator, tf.float32), permutation)
+ labels = ops.matmul_gather_on_zeroth_axis(
+ tf.cast(labels, tf.float32), permutation)
+
+ # index (starting from 1) when indicator is True, 0 when False
+ indicator_idx = tf.where(
+ tf.cast(indicator, tf.bool), tf.range(1, input_length + 1),
+ tf.zeros(input_length, tf.int32))
+
+ # Replace -1 for negative, +1 for positive labels
+ signed_label = tf.where(
+ tf.cast(labels, tf.bool), tf.ones(input_length, tf.int32),
+ tf.scalar_mul(-1, tf.ones(input_length, tf.int32)))
+ # negative of index for negative label, positive index for positive label,
+ # 0 when indicator is False.
+ signed_indicator_idx = tf.multiply(indicator_idx, signed_label)
+ sorted_signed_indicator_idx = tf.nn.top_k(
+ signed_indicator_idx, input_length, sorted=True).values
+
+ [num_positive_samples,
+ num_negative_samples] = self._get_num_pos_neg_samples(
+ sorted_signed_indicator_idx, batch_size)
+
+ sampled_idx = self._get_values_from_start_and_end(
+ sorted_signed_indicator_idx, num_positive_samples,
+ num_negative_samples, batch_size)
+
+ # Shift the indices to start from 0 and remove any samples that are set as
+ # False.
+ sampled_idx = tf.abs(sampled_idx) - tf.ones(batch_size, tf.int32)
+ sampled_idx = tf.multiply(
+ tf.cast(tf.greater_equal(sampled_idx, tf.constant(0)), tf.int32),
+ sampled_idx)
+
+ sampled_idx_indicator = tf.cast(
+ tf.reduce_sum(
+ input_tensor=tf.one_hot(sampled_idx, depth=input_length), axis=0),
+ tf.bool)
+
+ # project back the order based on stored permutations
+ reprojections = tf.one_hot(permutation, depth=input_length,
+ dtype=tf.float32)
+ return tf.cast(tf.tensordot(
+ tf.cast(sampled_idx_indicator, tf.float32),
+ reprojections, axes=[0, 0]), tf.bool)
+
+ def subsample(self, indicator, batch_size, labels, scope=None):
+ """Returns subsampled minibatch.
+
+ Args:
+ indicator: boolean tensor of shape [N] whose True entries can be sampled.
+ batch_size: desired batch size. If None, keeps all positive samples and
+ randomly selects negative samples so that the positive sample fraction
+        matches self._positive_fraction. It cannot be None if is_static is True.
+ labels: boolean tensor of shape [N] denoting positive(=True) and negative
+ (=False) examples.
+ scope: name scope.
+
+ Returns:
+ sampled_idx_indicator: boolean tensor of shape [N], True for entries which
+ are sampled.
+
+ Raises:
+ ValueError: if labels and indicator are not 1D boolean tensors.
+ """
+ if len(indicator.get_shape().as_list()) != 1:
+ raise ValueError('indicator must be 1 dimensional, got a tensor of '
+ 'shape %s' % indicator.get_shape())
+ if len(labels.get_shape().as_list()) != 1:
+ raise ValueError('labels must be 1 dimensional, got a tensor of '
+ 'shape %s' % labels.get_shape())
+ if labels.dtype != tf.bool:
+ raise ValueError('labels should be of type bool. Received: %s' %
+ labels.dtype)
+ if indicator.dtype != tf.bool:
+ raise ValueError('indicator should be of type bool. Received: %s' %
+ indicator.dtype)
+ scope = scope or 'BalancedPositiveNegativeSampler'
+ with tf.name_scope(scope):
+ if self._is_static:
+ return self._static_subsample(indicator, batch_size, labels)
+
+ else:
+ # Only sample from indicated samples
+ negative_idx = tf.logical_not(labels)
+ positive_idx = tf.logical_and(labels, indicator)
+ negative_idx = tf.logical_and(negative_idx, indicator)
+
+ # Sample positive and negative samples separately
+ if batch_size is None:
+ max_num_pos = tf.reduce_sum(
+ input_tensor=tf.cast(positive_idx, dtype=tf.int32))
+ else:
+ max_num_pos = int(self._positive_fraction * batch_size)
+ sampled_pos_idx = self.subsample_indicator(positive_idx, max_num_pos)
+ num_sampled_pos = tf.reduce_sum(
+ input_tensor=tf.cast(sampled_pos_idx, tf.int32))
+ if batch_size is None:
+ negative_positive_ratio = (
+ 1 - self._positive_fraction) / self._positive_fraction
+ max_num_neg = tf.cast(
+ negative_positive_ratio *
+ tf.cast(num_sampled_pos, dtype=tf.float32),
+ dtype=tf.int32)
+ else:
+ max_num_neg = batch_size - num_sampled_pos
+ sampled_neg_idx = self.subsample_indicator(negative_idx, max_num_neg)
+
+ return tf.logical_or(sampled_pos_idx, sampled_neg_idx)
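+
+
+# -----------------------------------------------------------------------------
+# Minimal usage sketch (illustrative only, not part of the original module).
+# The toy `indicator` and `labels` tensors below are made up; in practice they
+# come from the anchor labeler. Assumes eager execution.
+# -----------------------------------------------------------------------------
+if __name__ == '__main__':
+  sampler = BalancedPositiveNegativeSampler(positive_fraction=0.5)
+  # Eight candidate anchors: which ones may be sampled, and which are positive.
+  indicator = tf.constant([True, True, True, True, True, True, True, False])
+  labels = tf.constant([True, False, False, True, False, False, False, False])
+  # Draw a minibatch of four anchors, roughly half of them positive.
+  sampled_mask = sampler.subsample(indicator, 4, labels)
+  print(sampled_mask)  # Boolean [8] mask with at most four True entries.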
diff --git a/models/official/vision/detection/utils/object_detection/box_coder.py b/models/official/vision/detection/utils/object_detection/box_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f20ac956dfbce1fa69d1b9e6f5b023b704e1ec8a
--- /dev/null
+++ b/models/official/vision/detection/utils/object_detection/box_coder.py
@@ -0,0 +1,151 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Base box coder.
+
+Box coders convert between coordinate frames, namely image-centric
+(with (0,0) on the top left of image) and anchor-centric (with (0,0) being
+defined by a specific anchor).
+
+Users of a BoxCoder can call two methods:
+ encode: which encodes a box with respect to a given anchor
+ (or rather, a tensor of boxes wrt a corresponding tensor of anchors) and
+ decode: which inverts this encoding with a decode operation.
+In both cases, the arguments are assumed to be in 1-1 correspondence already;
+it is not the job of a BoxCoder to perform matching.
+"""
+from abc import ABCMeta
+from abc import abstractmethod
+from abc import abstractproperty
+
+import tensorflow as tf
+
+
+# Box coder types.
+FASTER_RCNN = 'faster_rcnn'
+KEYPOINT = 'keypoint'
+MEAN_STDDEV = 'mean_stddev'
+SQUARE = 'square'
+
+
+class BoxCoder(object, metaclass=ABCMeta):
+  """Abstract base class for box coder."""
+
+ @abstractproperty
+ def code_size(self):
+ """Return the size of each code.
+
+ This number is a constant and should agree with the output of the `encode`
+ op (e.g. if rel_codes is the output of self.encode(...), then it should have
+ shape [N, code_size()]). This abstractproperty should be overridden by
+ implementations.
+
+ Returns:
+ an integer constant
+ """
+ pass
+
+ def encode(self, boxes, anchors):
+ """Encode a box list relative to an anchor collection.
+
+ Args:
+ boxes: BoxList holding N boxes to be encoded
+ anchors: BoxList of N anchors
+
+ Returns:
+ a tensor representing N relative-encoded boxes
+ """
+ with tf.name_scope('Encode'):
+ return self._encode(boxes, anchors)
+
+ def decode(self, rel_codes, anchors):
+ """Decode boxes that are encoded relative to an anchor collection.
+
+ Args:
+ rel_codes: a tensor representing N relative-encoded boxes
+ anchors: BoxList of anchors
+
+ Returns:
+ boxlist: BoxList holding N boxes encoded in the ordinary way (i.e.,
+ with corners y_min, x_min, y_max, x_max)
+ """
+ with tf.name_scope('Decode'):
+ return self._decode(rel_codes, anchors)
+
+ @abstractmethod
+ def _encode(self, boxes, anchors):
+ """Method to be overriden by implementations.
+
+ Args:
+ boxes: BoxList holding N boxes to be encoded
+ anchors: BoxList of N anchors
+
+ Returns:
+ a tensor representing N relative-encoded boxes
+ """
+ pass
+
+ @abstractmethod
+ def _decode(self, rel_codes, anchors):
+ """Method to be overriden by implementations.
+
+ Args:
+ rel_codes: a tensor representing N relative-encoded boxes
+ anchors: BoxList of anchors
+
+ Returns:
+ boxlist: BoxList holding N boxes encoded in the ordinary way (i.e.,
+ with corners y_min, x_min, y_max, x_max)
+ """
+ pass
+
+
+def batch_decode(encoded_boxes, box_coder, anchors):
+ """Decode a batch of encoded boxes.
+
+ This op takes a batch of encoded bounding boxes and transforms
+ them to a batch of bounding boxes specified by their corners in
+ the order of [y_min, x_min, y_max, x_max].
+
+ Args:
+ encoded_boxes: a float32 tensor of shape [batch_size, num_anchors,
+ code_size] representing the location of the objects.
+ box_coder: a BoxCoder object.
+ anchors: a BoxList of anchors used to encode `encoded_boxes`.
+
+ Returns:
+ decoded_boxes: a float32 tensor of shape [batch_size, num_anchors,
+      code_size] representing the corners of the objects in the order
+ of [y_min, x_min, y_max, x_max].
+
+ Raises:
+ ValueError: if batch sizes of the inputs are inconsistent, or if
+ the number of anchors inferred from encoded_boxes and anchors are
+ inconsistent.
+ """
+ encoded_boxes.get_shape().assert_has_rank(3)
+ if encoded_boxes.get_shape()[1].value != anchors.num_boxes_static():
+ raise ValueError('The number of anchors inferred from encoded_boxes'
+ ' and anchors are inconsistent: shape[1] of encoded_boxes'
+ ' %s should be equal to the number of anchors: %s.' %
+ (encoded_boxes.get_shape()[1].value,
+ anchors.num_boxes_static()))
+
+ decoded_boxes = tf.stack([
+ box_coder.decode(boxes, anchors).get()
+ for boxes in tf.unstack(encoded_boxes)
+ ])
+ return decoded_boxes
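+
+
+# -----------------------------------------------------------------------------
+# Illustrative sketch (not part of the original module): a toy concrete coder
+# showing the override contract BoxCoder expects. It simply uses the raw corner
+# differences between boxes and anchors as the code; real coders such as the
+# Faster R-CNN coder use scale-normalized, log-encoded offsets instead.
+# -----------------------------------------------------------------------------
+class _IdentityOffsetBoxCoder(BoxCoder):
+  """Toy coder whose codes are plain corner offsets (boxes - anchors)."""
+
+  @property
+  def code_size(self):
+    return 4
+
+  def _encode(self, boxes, anchors):
+    # Both arguments are BoxLists; get() returns the [N, 4] corner tensors.
+    return boxes.get() - anchors.get()
+
+  def _decode(self, rel_codes, anchors):
+    # Imported locally so the sketch does not change the module-level imports.
+    from official.vision.detection.utils.object_detection import box_list
+    return box_list.BoxList(rel_codes + anchors.get())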
diff --git a/models/official/vision/detection/utils/object_detection/box_list.py b/models/official/vision/detection/utils/object_detection/box_list.py
new file mode 100644
index 0000000000000000000000000000000000000000..113fab8c197194f1cd0099d5a177cd9f1fb6e64c
--- /dev/null
+++ b/models/official/vision/detection/utils/object_detection/box_list.py
@@ -0,0 +1,211 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Bounding Box List definition.
+
+BoxList represents a list of bounding boxes as tensorflow
+tensors, where each bounding box is represented as a row of 4 numbers,
+[y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes
+within a given list correspond to a single image. See also
+box_list_ops.py for common box related operations (such as area, iou, etc).
+
+Optionally, users can add additional related fields (such as weights).
+We assume the following things to be true about fields:
+* they correspond to boxes in the box_list along the 0th dimension
+* they have inferrable rank at graph construction time
+* all dimensions except for possibly the 0th can be inferred
+ (i.e., not None) at graph construction time.
+
+Some other notes:
+ * Following tensorflow conventions, we use height, width ordering,
+ and correspondingly, y,x (or ymin, xmin, ymax, xmax) ordering
+ * Tensors are always provided as (flat) [N, 4] tensors.
+"""
+
+import tensorflow as tf
+
+
+class BoxList(object):
+ """Box collection."""
+
+ def __init__(self, boxes):
+ """Constructs box collection.
+
+ Args:
+ boxes: a tensor of shape [N, 4] representing box corners
+
+ Raises:
+ ValueError: if invalid dimensions for bbox data or if bbox data is not in
+ float32 format.
+ """
+ if len(boxes.get_shape()) != 2 or boxes.get_shape()[-1] != 4:
+ raise ValueError('Invalid dimensions for box data.')
+ if boxes.dtype != tf.float32:
+ raise ValueError('Invalid tensor type: should be tf.float32')
+ self.data = {'boxes': boxes}
+
+ def num_boxes(self):
+ """Returns number of boxes held in collection.
+
+ Returns:
+ a tensor representing the number of boxes held in the collection.
+ """
+ return tf.shape(input=self.data['boxes'])[0]
+
+ def num_boxes_static(self):
+ """Returns number of boxes held in collection.
+
+ This number is inferred at graph construction time rather than run-time.
+
+ Returns:
+ Number of boxes held in collection (integer) or None if this is not
+ inferrable at graph construction time.
+ """
+ return self.data['boxes'].get_shape().dims[0].value
+
+ def get_all_fields(self):
+ """Returns all fields."""
+ return self.data.keys()
+
+ def get_extra_fields(self):
+ """Returns all non-box fields (i.e., everything not named 'boxes')."""
+ return [k for k in self.data.keys() if k != 'boxes']
+
+ def add_field(self, field, field_data):
+ """Add field to box list.
+
+ This method can be used to add related box data such as
+ weights/labels, etc.
+
+ Args:
+ field: a string key to access the data via `get`
+ field_data: a tensor containing the data to store in the BoxList
+ """
+ self.data[field] = field_data
+
+ def has_field(self, field):
+ return field in self.data
+
+ def get(self):
+ """Convenience function for accessing box coordinates.
+
+ Returns:
+ a tensor with shape [N, 4] representing box coordinates.
+ """
+ return self.get_field('boxes')
+
+ def set(self, boxes):
+ """Convenience function for setting box coordinates.
+
+ Args:
+ boxes: a tensor of shape [N, 4] representing box corners
+
+ Raises:
+ ValueError: if invalid dimensions for bbox data
+ """
+ if len(boxes.get_shape()) != 2 or boxes.get_shape()[-1] != 4:
+ raise ValueError('Invalid dimensions for box data.')
+ self.data['boxes'] = boxes
+
+ def get_field(self, field):
+ """Accesses a box collection and associated fields.
+
+    This function returns the data tensor stored under the specified field;
+    passing 'boxes' returns the box coordinates.
+
+ Args:
+      field: a string parameter specifying the field to be accessed; this can
+        be 'boxes' or any previously added field.
+
+ Returns:
+ a tensor representing the box collection or an associated field.
+
+ Raises:
+ ValueError: if invalid field
+ """
+ if not self.has_field(field):
+ raise ValueError('field ' + str(field) + ' does not exist')
+ return self.data[field]
+
+ def set_field(self, field, value):
+ """Sets the value of a field.
+
+ Updates the field of a box_list with a given value.
+
+ Args:
+ field: (string) name of the field to set value.
+ value: the value to assign to the field.
+
+ Raises:
+ ValueError: if the box_list does not have specified field.
+ """
+ if not self.has_field(field):
+ raise ValueError('field %s does not exist' % field)
+ self.data[field] = value
+
+ def get_center_coordinates_and_sizes(self, scope=None):
+ """Computes the center coordinates, height and width of the boxes.
+
+ Args:
+ scope: name scope of the function.
+
+ Returns:
+ a list of 4 1-D tensors [ycenter, xcenter, height, width].
+ """
+ if not scope:
+ scope = 'get_center_coordinates_and_sizes'
+ with tf.name_scope(scope):
+ box_corners = self.get()
+ ymin, xmin, ymax, xmax = tf.unstack(tf.transpose(a=box_corners))
+ width = xmax - xmin
+ height = ymax - ymin
+ ycenter = ymin + height / 2.
+ xcenter = xmin + width / 2.
+ return [ycenter, xcenter, height, width]
+
+ def transpose_coordinates(self, scope=None):
+ """Transpose the coordinate representation in a boxlist.
+
+ Args:
+ scope: name scope of the function.
+ """
+ if not scope:
+ scope = 'transpose_coordinates'
+ with tf.name_scope(scope):
+ y_min, x_min, y_max, x_max = tf.split(
+ value=self.get(), num_or_size_splits=4, axis=1)
+ self.set(tf.concat([x_min, y_min, x_max, y_max], 1))
+
+ def as_tensor_dict(self, fields=None):
+ """Retrieves specified fields as a dictionary of tensors.
+
+ Args:
+ fields: (optional) list of fields to return in the dictionary.
+ If None (default), all fields are returned.
+
+ Returns:
+ tensor_dict: A dictionary of tensors specified by fields.
+
+ Raises:
+ ValueError: if specified field is not contained in boxlist.
+ """
+ tensor_dict = {}
+ if fields is None:
+ fields = self.get_all_fields()
+ for field in fields:
+ if not self.has_field(field):
+ raise ValueError('boxlist must contain all specified fields')
+ tensor_dict[field] = self.get_field(field)
+ return tensor_dict
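+
+
+# -----------------------------------------------------------------------------
+# Minimal usage sketch (illustrative only, not part of the original module):
+# build a BoxList from an [N, 4] float32 tensor, attach an extra field and read
+# back derived quantities. Assumes eager execution.
+# -----------------------------------------------------------------------------
+if __name__ == '__main__':
+  boxes = tf.constant([[0.1, 0.1, 0.4, 0.5],
+                       [0.2, 0.3, 0.9, 0.8]], dtype=tf.float32)
+  boxlist = BoxList(boxes)
+  boxlist.add_field('scores', tf.constant([0.9, 0.4]))
+  print(boxlist.num_boxes_static())                  # 2
+  print(boxlist.get_extra_fields())                  # ['scores']
+  print(boxlist.get_center_coordinates_and_sizes())  # [ycenter, xcenter, h, w]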
diff --git a/models/official/vision/detection/utils/object_detection/box_list_ops.py b/models/official/vision/detection/utils/object_detection/box_list_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f1b06e28d588eb05c9ea8596b44d08690481eae
--- /dev/null
+++ b/models/official/vision/detection/utils/object_detection/box_list_ops.py
@@ -0,0 +1,1079 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Bounding Box List operations.
+
+Example box operations that are supported:
+ * areas: compute bounding box areas
+ * iou: pairwise intersection-over-union scores
+ * sq_dist: pairwise distances between bounding boxes
+
+Whenever box_list_ops functions output a BoxList, the fields of the incoming
+BoxList are retained unless documented otherwise.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six.moves import range
+import tensorflow as tf
+
+from official.vision.detection.utils.object_detection import box_list
+from official.vision.detection.utils.object_detection import ops
+
+
+class SortOrder(object):
+ """Enum class for sort order.
+
+ Attributes:
+ ascend: ascend order.
+ descend: descend order.
+ """
+ ascend = 1
+ descend = 2
+
+
+def area(boxlist, scope=None):
+ """Computes area of boxes.
+
+ Args:
+ boxlist: BoxList holding N boxes
+ scope: name scope.
+
+ Returns:
+ a tensor with shape [N] representing box areas.
+ """
+ with tf.name_scope(scope, 'Area'):
+ y_min, x_min, y_max, x_max = tf.split(
+ value=boxlist.get(), num_or_size_splits=4, axis=1)
+ return tf.squeeze((y_max - y_min) * (x_max - x_min), [1])
+
+
+def height_width(boxlist, scope=None):
+ """Computes height and width of boxes in boxlist.
+
+ Args:
+ boxlist: BoxList holding N boxes
+ scope: name scope.
+
+ Returns:
+ Height: A tensor with shape [N] representing box heights.
+ Width: A tensor with shape [N] representing box widths.
+ """
+ with tf.name_scope(scope, 'HeightWidth'):
+ y_min, x_min, y_max, x_max = tf.split(
+ value=boxlist.get(), num_or_size_splits=4, axis=1)
+ return tf.squeeze(y_max - y_min, [1]), tf.squeeze(x_max - x_min, [1])
+
+
+def scale(boxlist, y_scale, x_scale, scope=None):
+ """scale box coordinates in x and y dimensions.
+
+ Args:
+ boxlist: BoxList holding N boxes
+ y_scale: (float) scalar tensor
+ x_scale: (float) scalar tensor
+ scope: name scope.
+
+ Returns:
+ boxlist: BoxList holding N boxes
+ """
+ with tf.name_scope(scope, 'Scale'):
+ y_scale = tf.cast(y_scale, tf.float32)
+ x_scale = tf.cast(x_scale, tf.float32)
+ y_min, x_min, y_max, x_max = tf.split(
+ value=boxlist.get(), num_or_size_splits=4, axis=1)
+ y_min = y_scale * y_min
+ y_max = y_scale * y_max
+ x_min = x_scale * x_min
+ x_max = x_scale * x_max
+ scaled_boxlist = box_list.BoxList(
+ tf.concat([y_min, x_min, y_max, x_max], 1))
+ return _copy_extra_fields(scaled_boxlist, boxlist)
+
+
+def clip_to_window(boxlist, window, filter_nonoverlapping=True, scope=None):
+ """Clip bounding boxes to a window.
+
+ This op clips any input bounding boxes (represented by bounding box
+ corners) to a window, optionally filtering out boxes that do not
+ overlap at all with the window.
+
+ Args:
+ boxlist: BoxList holding M_in boxes
+ window: a tensor of shape [4] representing the [y_min, x_min, y_max, x_max]
+ window to which the op should clip boxes.
+ filter_nonoverlapping: whether to filter out boxes that do not overlap at
+ all with the window.
+ scope: name scope.
+
+ Returns:
+ a BoxList holding M_out boxes where M_out <= M_in
+ """
+ with tf.name_scope(scope, 'ClipToWindow'):
+ y_min, x_min, y_max, x_max = tf.split(
+ value=boxlist.get(), num_or_size_splits=4, axis=1)
+ win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)
+ y_min_clipped = tf.maximum(tf.minimum(y_min, win_y_max), win_y_min)
+ y_max_clipped = tf.maximum(tf.minimum(y_max, win_y_max), win_y_min)
+ x_min_clipped = tf.maximum(tf.minimum(x_min, win_x_max), win_x_min)
+ x_max_clipped = tf.maximum(tf.minimum(x_max, win_x_max), win_x_min)
+ clipped = box_list.BoxList(
+ tf.concat([y_min_clipped, x_min_clipped, y_max_clipped, x_max_clipped],
+ 1))
+ clipped = _copy_extra_fields(clipped, boxlist)
+ if filter_nonoverlapping:
+ areas = area(clipped)
+ nonzero_area_indices = tf.cast(
+ tf.reshape(tf.where(tf.greater(areas, 0.0)), [-1]), tf.int32)
+ clipped = gather(clipped, nonzero_area_indices)
+ return clipped
+
+
+def prune_outside_window(boxlist, window, scope=None):
+ """Prunes bounding boxes that fall outside a given window.
+
+ This function prunes bounding boxes that even partially fall outside the given
+ window. See also clip_to_window which only prunes bounding boxes that fall
+ completely outside the window, and clips any bounding boxes that partially
+ overflow.
+
+ Args:
+ boxlist: a BoxList holding M_in boxes.
+ window: a float tensor of shape [4] representing [ymin, xmin, ymax, xmax]
+ of the window
+ scope: name scope.
+
+ Returns:
+ pruned_corners: a tensor with shape [M_out, 4] where M_out <= M_in
+ valid_indices: a tensor with shape [M_out] indexing the valid bounding boxes
+ in the input tensor.
+ """
+ with tf.name_scope(scope, 'PruneOutsideWindow'):
+ y_min, x_min, y_max, x_max = tf.split(
+ value=boxlist.get(), num_or_size_splits=4, axis=1)
+ win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)
+ coordinate_violations = tf.concat([
+ tf.less(y_min, win_y_min), tf.less(x_min, win_x_min),
+ tf.greater(y_max, win_y_max), tf.greater(x_max, win_x_max)
+ ], 1)
+ valid_indices = tf.reshape(
+ tf.where(tf.logical_not(tf.reduce_any(coordinate_violations, 1))), [-1])
+ return gather(boxlist, valid_indices), valid_indices
+
+
+def prune_completely_outside_window(boxlist, window, scope=None):
+ """Prunes bounding boxes that fall completely outside of the given window.
+
+ The function clip_to_window prunes bounding boxes that fall
+ completely outside the window, but also clips any bounding boxes that
+ partially overflow. This function does not clip partially overflowing boxes.
+
+ Args:
+ boxlist: a BoxList holding M_in boxes.
+ window: a float tensor of shape [4] representing [ymin, xmin, ymax, xmax]
+ of the window
+ scope: name scope.
+
+ Returns:
+ pruned_boxlist: a new BoxList with all bounding boxes partially or fully in
+ the window.
+ valid_indices: a tensor with shape [M_out] indexing the valid bounding boxes
+ in the input tensor.
+ """
+  with tf.name_scope(scope, 'PruneCompletelyOutsideWindow'):
+ y_min, x_min, y_max, x_max = tf.split(
+ value=boxlist.get(), num_or_size_splits=4, axis=1)
+ win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)
+ coordinate_violations = tf.concat([
+ tf.greater_equal(y_min, win_y_max), tf.greater_equal(x_min, win_x_max),
+ tf.less_equal(y_max, win_y_min), tf.less_equal(x_max, win_x_min)
+ ], 1)
+ valid_indices = tf.reshape(
+ tf.where(tf.logical_not(tf.reduce_any(coordinate_violations, 1))), [-1])
+ return gather(boxlist, valid_indices), valid_indices
+
+
+def intersection(boxlist1, boxlist2, scope=None):
+ """Compute pairwise intersection areas between boxes.
+
+ Args:
+ boxlist1: BoxList holding N boxes
+ boxlist2: BoxList holding M boxes
+ scope: name scope.
+
+ Returns:
+ a tensor with shape [N, M] representing pairwise intersections
+ """
+ with tf.name_scope(scope, 'Intersection'):
+ y_min1, x_min1, y_max1, x_max1 = tf.split(
+ value=boxlist1.get(), num_or_size_splits=4, axis=1)
+ y_min2, x_min2, y_max2, x_max2 = tf.split(
+ value=boxlist2.get(), num_or_size_splits=4, axis=1)
+ all_pairs_min_ymax = tf.minimum(y_max1, tf.transpose(y_max2))
+ all_pairs_max_ymin = tf.maximum(y_min1, tf.transpose(y_min2))
+ intersect_heights = tf.maximum(0.0, all_pairs_min_ymax - all_pairs_max_ymin)
+ all_pairs_min_xmax = tf.minimum(x_max1, tf.transpose(x_max2))
+ all_pairs_max_xmin = tf.maximum(x_min1, tf.transpose(x_min2))
+ intersect_widths = tf.maximum(0.0, all_pairs_min_xmax - all_pairs_max_xmin)
+ return intersect_heights * intersect_widths
+
+
+def matched_intersection(boxlist1, boxlist2, scope=None):
+ """Compute intersection areas between corresponding boxes in two boxlists.
+
+ Args:
+ boxlist1: BoxList holding N boxes
+ boxlist2: BoxList holding N boxes
+ scope: name scope.
+
+ Returns:
+ a tensor with shape [N] representing pairwise intersections
+ """
+ with tf.name_scope(scope, 'MatchedIntersection'):
+ y_min1, x_min1, y_max1, x_max1 = tf.split(
+ value=boxlist1.get(), num_or_size_splits=4, axis=1)
+ y_min2, x_min2, y_max2, x_max2 = tf.split(
+ value=boxlist2.get(), num_or_size_splits=4, axis=1)
+ min_ymax = tf.minimum(y_max1, y_max2)
+ max_ymin = tf.maximum(y_min1, y_min2)
+ intersect_heights = tf.maximum(0.0, min_ymax - max_ymin)
+ min_xmax = tf.minimum(x_max1, x_max2)
+ max_xmin = tf.maximum(x_min1, x_min2)
+ intersect_widths = tf.maximum(0.0, min_xmax - max_xmin)
+ return tf.reshape(intersect_heights * intersect_widths, [-1])
+
+
+def iou(boxlist1, boxlist2, scope=None):
+ """Computes pairwise intersection-over-union between box collections.
+
+ Args:
+ boxlist1: BoxList holding N boxes
+ boxlist2: BoxList holding M boxes
+ scope: name scope.
+
+ Returns:
+ a tensor with shape [N, M] representing pairwise iou scores.
+ """
+ with tf.name_scope(scope, 'IOU'):
+ intersections = intersection(boxlist1, boxlist2)
+ areas1 = area(boxlist1)
+ areas2 = area(boxlist2)
+ unions = (
+ tf.expand_dims(areas1, 1) + tf.expand_dims(areas2, 0) - intersections)
+ return tf.where(
+ tf.equal(intersections, 0.0),
+ tf.zeros_like(intersections), tf.truediv(intersections, unions))
+
+
+def matched_iou(boxlist1, boxlist2, scope=None):
+ """Compute intersection-over-union between corresponding boxes in boxlists.
+
+ Args:
+ boxlist1: BoxList holding N boxes
+ boxlist2: BoxList holding N boxes
+ scope: name scope.
+
+ Returns:
+ a tensor with shape [N] representing pairwise iou scores.
+ """
+ with tf.name_scope(scope, 'MatchedIOU'):
+ intersections = matched_intersection(boxlist1, boxlist2)
+ areas1 = area(boxlist1)
+ areas2 = area(boxlist2)
+ unions = areas1 + areas2 - intersections
+ return tf.where(
+ tf.equal(intersections, 0.0),
+ tf.zeros_like(intersections), tf.truediv(intersections, unions))
+
+
+def ioa(boxlist1, boxlist2, scope=None):
+ """Computes pairwise intersection-over-area between box collections.
+
+ intersection-over-area (IOA) between two boxes box1 and box2 is defined as
+ their intersection area over box2's area. Note that ioa is not symmetric,
+ that is, ioa(box1, box2) != ioa(box2, box1).
+
+ Args:
+ boxlist1: BoxList holding N boxes
+ boxlist2: BoxList holding M boxes
+ scope: name scope.
+
+ Returns:
+ a tensor with shape [N, M] representing pairwise ioa scores.
+ """
+ with tf.name_scope(scope, 'IOA'):
+ intersections = intersection(boxlist1, boxlist2)
+ areas = tf.expand_dims(area(boxlist2), 0)
+ return tf.truediv(intersections, areas)
+
+
+def prune_non_overlapping_boxes(
+ boxlist1, boxlist2, min_overlap=0.0, scope=None):
+ """Prunes the boxes in boxlist1 that overlap less than thresh with boxlist2.
+
+ For each box in boxlist1, we want its IOA to be more than minoverlap with
+ at least one of the boxes in boxlist2. If it does not, we remove it.
+
+ Args:
+ boxlist1: BoxList holding N boxes.
+ boxlist2: BoxList holding M boxes.
+ min_overlap: Minimum required overlap between boxes, to count them as
+ overlapping.
+ scope: name scope.
+
+ Returns:
+ new_boxlist1: A pruned boxlist with size [N', 4].
+ keep_inds: A tensor with shape [N'] indexing kept bounding boxes in the
+ first input BoxList `boxlist1`.
+ """
+ with tf.name_scope(scope, 'PruneNonOverlappingBoxes'):
+ ioa_ = ioa(boxlist2, boxlist1) # [M, N] tensor
+ ioa_ = tf.reduce_max(ioa_, reduction_indices=[0]) # [N] tensor
+ keep_bool = tf.greater_equal(ioa_, tf.constant(min_overlap))
+ keep_inds = tf.squeeze(tf.where(keep_bool), axis=[1])
+ new_boxlist1 = gather(boxlist1, keep_inds)
+ return new_boxlist1, keep_inds
+
+
+def prune_small_boxes(boxlist, min_side, scope=None):
+ """Prunes small boxes in the boxlist which have a side smaller than min_side.
+
+ Args:
+ boxlist: BoxList holding N boxes.
+ min_side: Minimum width AND height of box to survive pruning.
+ scope: name scope.
+
+ Returns:
+ A pruned boxlist.
+ """
+ with tf.name_scope(scope, 'PruneSmallBoxes'):
+ height, width = height_width(boxlist)
+ is_valid = tf.logical_and(tf.greater_equal(width, min_side),
+ tf.greater_equal(height, min_side))
+ return gather(boxlist, tf.reshape(tf.where(is_valid), [-1]))
+
+
+def change_coordinate_frame(boxlist, window, scope=None):
+ """Change coordinate frame of the boxlist to be relative to window's frame.
+
+ Given a window of the form [ymin, xmin, ymax, xmax],
+ changes bounding box coordinates from boxlist to be relative to this window
+ (e.g., the min corner maps to (0,0) and the max corner maps to (1,1)).
+
+ An example use case is data augmentation: where we are given groundtruth
+ boxes (boxlist) and would like to randomly crop the image to some
+ window (window). In this case we need to change the coordinate frame of
+ each groundtruth box to be relative to this new window.
+
+ Args:
+ boxlist: A BoxList object holding N boxes.
+ window: A rank 1 tensor [4].
+ scope: name scope.
+
+ Returns:
+ Returns a BoxList object with N boxes.
+ """
+ with tf.name_scope(scope, 'ChangeCoordinateFrame'):
+ win_height = window[2] - window[0]
+ win_width = window[3] - window[1]
+ boxlist_new = scale(box_list.BoxList(
+ boxlist.get() - [window[0], window[1], window[0], window[1]]),
+ 1.0 / win_height, 1.0 / win_width)
+ boxlist_new = _copy_extra_fields(boxlist_new, boxlist)
+ return boxlist_new
+
+
+def sq_dist(boxlist1, boxlist2, scope=None):
+ """Computes the pairwise squared distances between box corners.
+
+ This op treats each box as if it were a point in a 4d Euclidean space and
+ computes pairwise squared distances.
+
+ Mathematically, we are given two matrices of box coordinates X and Y,
+ where X(i,:) is the i'th row of X, containing the 4 numbers defining the
+ corners of the i'th box in boxlist1. Similarly Y(j,:) corresponds to
+ boxlist2. We compute
+ Z(i,j) = ||X(i,:) - Y(j,:)||^2
+ = ||X(i,:)||^2 + ||Y(j,:)||^2 - 2 X(i,:)' * Y(j,:),
+
+ Args:
+ boxlist1: BoxList holding N boxes
+ boxlist2: BoxList holding M boxes
+ scope: name scope.
+
+ Returns:
+ a tensor with shape [N, M] representing pairwise distances
+ """
+ with tf.name_scope(scope, 'SqDist'):
+ sqnorm1 = tf.reduce_sum(tf.square(boxlist1.get()), 1, keep_dims=True)
+ sqnorm2 = tf.reduce_sum(tf.square(boxlist2.get()), 1, keep_dims=True)
+ innerprod = tf.matmul(boxlist1.get(), boxlist2.get(),
+ transpose_a=False, transpose_b=True)
+ return sqnorm1 + tf.transpose(sqnorm2) - 2.0 * innerprod
+
+
+def boolean_mask(boxlist, indicator, fields=None, scope=None,
+ use_static_shapes=False, indicator_sum=None):
+ """Select boxes from BoxList according to indicator and return new BoxList.
+
+ `boolean_mask` returns the subset of boxes that are marked as "True" by the
+  indicator tensor. By default, `boolean_mask` returns the boxes selected by
+  the indicator, as well as all additional fields stored in the boxlist
+ (indexing into the first dimension). However one can optionally only draw
+ from a subset of fields.
+
+ Args:
+ boxlist: BoxList holding N boxes
+ indicator: a rank-1 boolean tensor
+ fields: (optional) list of fields to also gather from. If None (default),
+ all fields are gathered from. Pass an empty fields list to only gather
+ the box coordinates.
+ scope: name scope.
+ use_static_shapes: Whether to use an implementation with static shape
+      guarantees.
+    indicator_sum: An integer containing the sum of the `indicator` vector.
+      Only required if `use_static_shapes` is True.
+
+ Returns:
+ subboxlist: a BoxList corresponding to the subset of the input BoxList
+ specified by indicator
+ Raises:
+ ValueError: if `indicator` is not a rank-1 boolean tensor.
+ """
+ with tf.name_scope(scope, 'BooleanMask'):
+ if indicator.shape.ndims != 1:
+ raise ValueError('indicator should have rank 1')
+ if indicator.dtype != tf.bool:
+ raise ValueError('indicator should be a boolean tensor')
+ if use_static_shapes:
+ if not (indicator_sum and isinstance(indicator_sum, int)):
+        raise ValueError('`indicator_sum` must be of type int')
+ selected_positions = tf.cast(indicator, dtype=tf.float32)
+ indexed_positions = tf.cast(
+ tf.multiply(
+ tf.cumsum(selected_positions), selected_positions),
+ dtype=tf.int32)
+ one_hot_selector = tf.one_hot(
+ indexed_positions - 1, indicator_sum, dtype=tf.float32)
+ sampled_indices = tf.cast(
+ tf.tensordot(
+ tf.cast(tf.range(tf.shape(indicator)[0]), dtype=tf.float32),
+ one_hot_selector,
+ axes=[0, 0]),
+ dtype=tf.int32)
+ return gather(boxlist, sampled_indices, use_static_shapes=True)
+ else:
+ subboxlist = box_list.BoxList(tf.boolean_mask(boxlist.get(), indicator))
+ if fields is None:
+ fields = boxlist.get_extra_fields()
+ for field in fields:
+ if not boxlist.has_field(field):
+ raise ValueError('boxlist must contain all specified fields')
+ subfieldlist = tf.boolean_mask(boxlist.get_field(field), indicator)
+ subboxlist.add_field(field, subfieldlist)
+ return subboxlist
+
+
+def gather(boxlist, indices, fields=None, scope=None, use_static_shapes=False):
+ """Gather boxes from BoxList according to indices and return new BoxList.
+
+ By default, `gather` returns boxes corresponding to the input index list, as
+ well as all additional fields stored in the boxlist (indexing into the
+ first dimension). However one can optionally only gather from a
+ subset of fields.
+
+ Args:
+ boxlist: BoxList holding N boxes
+ indices: a rank-1 tensor of type int32 / int64
+ fields: (optional) list of fields to also gather from. If None (default),
+ all fields are gathered from. Pass an empty fields list to only gather
+ the box coordinates.
+ scope: name scope.
+ use_static_shapes: Whether to use an implementation with static shape
+      guarantees.
+
+ Returns:
+ subboxlist: a BoxList corresponding to the subset of the input BoxList
+ specified by indices
+ Raises:
+ ValueError: if specified field is not contained in boxlist or if the
+ indices are not of type int32
+ """
+ with tf.name_scope(scope, 'Gather'):
+ if len(indices.shape.as_list()) != 1:
+ raise ValueError('indices should have rank 1')
+ if indices.dtype != tf.int32 and indices.dtype != tf.int64:
+ raise ValueError('indices should be an int32 / int64 tensor')
+ gather_op = tf.gather
+ if use_static_shapes:
+ gather_op = ops.matmul_gather_on_zeroth_axis
+ subboxlist = box_list.BoxList(gather_op(boxlist.get(), indices))
+ if fields is None:
+ fields = boxlist.get_extra_fields()
+ fields += ['boxes']
+ for field in fields:
+ if not boxlist.has_field(field):
+ raise ValueError('boxlist must contain all specified fields')
+ subfieldlist = gather_op(boxlist.get_field(field), indices)
+ subboxlist.add_field(field, subfieldlist)
+ return subboxlist
+
+
+def concatenate(boxlists, fields=None, scope=None):
+ """Concatenate list of BoxLists.
+
+ This op concatenates a list of input BoxLists into a larger BoxList. It also
+ handles concatenation of BoxList fields as long as the field tensor shapes
+ are equal except for the first dimension.
+
+ Args:
+ boxlists: list of BoxList objects
+ fields: optional list of fields to also concatenate. By default, all
+ fields from the first BoxList in the list are included in the
+ concatenation.
+ scope: name scope.
+
+ Returns:
+ a BoxList with number of boxes equal to
+ sum([boxlist.num_boxes() for boxlist in BoxList])
+ Raises:
+ ValueError: if boxlists is invalid (i.e., is not a list, is empty, or
+ contains non BoxList objects), or if requested fields are not contained in
+ all boxlists
+ """
+ with tf.name_scope(scope, 'Concatenate'):
+ if not isinstance(boxlists, list):
+ raise ValueError('boxlists should be a list')
+ if not boxlists:
+ raise ValueError('boxlists should have nonzero length')
+ for boxlist in boxlists:
+ if not isinstance(boxlist, box_list.BoxList):
+ raise ValueError('all elements of boxlists should be BoxList objects')
+ concatenated = box_list.BoxList(
+ tf.concat([boxlist.get() for boxlist in boxlists], 0))
+ if fields is None:
+ fields = boxlists[0].get_extra_fields()
+ for field in fields:
+ first_field_shape = boxlists[0].get_field(field).get_shape().as_list()
+ first_field_shape[0] = -1
+ if None in first_field_shape:
+ raise ValueError('field %s must have fully defined shape except for the'
+ ' 0th dimension.' % field)
+ for boxlist in boxlists:
+ if not boxlist.has_field(field):
+ raise ValueError('boxlist must contain all requested fields')
+ field_shape = boxlist.get_field(field).get_shape().as_list()
+ field_shape[0] = -1
+ if field_shape != first_field_shape:
+ raise ValueError('field %s must have same shape for all boxlists '
+ 'except for the 0th dimension.' % field)
+ concatenated_field = tf.concat(
+ [boxlist.get_field(field) for boxlist in boxlists], 0)
+ concatenated.add_field(field, concatenated_field)
+ return concatenated
+
+
+def sort_by_field(boxlist, field, order=SortOrder.descend, scope=None):
+ """Sort boxes and associated fields according to a scalar field.
+
+ A common use case is reordering the boxes according to descending scores.
+
+ Args:
+ boxlist: BoxList holding N boxes.
+ field: A BoxList field for sorting and reordering the BoxList.
+ order: (Optional) descend or ascend. Default is descend.
+ scope: name scope.
+
+ Returns:
+ sorted_boxlist: A sorted BoxList with the field in the specified order.
+
+ Raises:
+ ValueError: if specified field does not exist
+ ValueError: if the order is not either descend or ascend
+ """
+ with tf.name_scope(scope, 'SortByField'):
+ if order != SortOrder.descend and order != SortOrder.ascend:
+ raise ValueError('Invalid sort order')
+
+ field_to_sort = boxlist.get_field(field)
+ if len(field_to_sort.shape.as_list()) != 1:
+ raise ValueError('Field should have rank 1')
+
+ num_boxes = boxlist.num_boxes()
+ num_entries = tf.size(field_to_sort)
+ length_assert = tf.Assert(
+ tf.equal(num_boxes, num_entries),
+ ['Incorrect field size: actual vs expected.', num_entries, num_boxes])
+
+ with tf.control_dependencies([length_assert]):
+ _, sorted_indices = tf.nn.top_k(field_to_sort, num_boxes, sorted=True)
+
+ if order == SortOrder.ascend:
+ sorted_indices = tf.reverse_v2(sorted_indices, [0])
+
+ return gather(boxlist, sorted_indices)
+
+
+def visualize_boxes_in_image(image, boxlist, normalized=False, scope=None):
+ """Overlay bounding box list on image.
+
+ Currently this visualization plots a 1 pixel thick red bounding box on top
+ of the image. Note that tf.image.draw_bounding_boxes essentially is
+ 1 indexed.
+
+ Args:
+ image: an image tensor with shape [height, width, 3]
+ boxlist: a BoxList
+ normalized: (boolean) specify whether corners are to be interpreted
+ as absolute coordinates in image space or normalized with respect to the
+ image size.
+ scope: name scope.
+
+ Returns:
+ image_and_boxes: an image tensor with shape [height, width, 3]
+ """
+ with tf.name_scope(scope, 'VisualizeBoxesInImage'):
+ if not normalized:
+ height, width, _ = tf.unstack(tf.shape(image))
+ boxlist = scale(boxlist,
+ 1.0 / tf.cast(height, tf.float32),
+ 1.0 / tf.cast(width, tf.float32))
+ corners = tf.expand_dims(boxlist.get(), 0)
+ image = tf.expand_dims(image, 0)
+ return tf.squeeze(tf.image.draw_bounding_boxes(image, corners), [0])
+
+
+def filter_field_value_equals(boxlist, field, value, scope=None):
+ """Filter to keep only boxes with field entries equal to the given value.
+
+ Args:
+ boxlist: BoxList holding N boxes.
+ field: field name for filtering.
+ value: scalar value.
+ scope: name scope.
+
+ Returns:
+ a BoxList holding M boxes where M <= N
+
+ Raises:
+ ValueError: if boxlist not a BoxList object or if it does not have
+ the specified field.
+ """
+ with tf.name_scope(scope, 'FilterFieldValueEquals'):
+ if not isinstance(boxlist, box_list.BoxList):
+ raise ValueError('boxlist must be a BoxList')
+ if not boxlist.has_field(field):
+ raise ValueError('boxlist must contain the specified field')
+ filter_field = boxlist.get_field(field)
+ gather_index = tf.reshape(tf.where(tf.equal(filter_field, value)), [-1])
+ return gather(boxlist, gather_index)
+
+
+def filter_greater_than(boxlist, thresh, scope=None):
+ """Filter to keep only boxes with score exceeding a given threshold.
+
+ This op keeps the collection of boxes whose corresponding scores are
+ greater than the input threshold.
+
+ TODO(jonathanhuang): Change function name to filter_scores_greater_than
+
+ Args:
+ boxlist: BoxList holding N boxes. Must contain a 'scores' field
+ representing detection scores.
+ thresh: scalar threshold
+ scope: name scope.
+
+ Returns:
+ a BoxList holding M boxes where M <= N
+
+ Raises:
+ ValueError: if boxlist not a BoxList object or if it does not
+ have a scores field
+ """
+ with tf.name_scope(scope, 'FilterGreaterThan'):
+ if not isinstance(boxlist, box_list.BoxList):
+ raise ValueError('boxlist must be a BoxList')
+ if not boxlist.has_field('scores'):
+ raise ValueError('input boxlist must have \'scores\' field')
+ scores = boxlist.get_field('scores')
+ if len(scores.shape.as_list()) > 2:
+ raise ValueError('Scores should have rank 1 or 2')
+ if len(scores.shape.as_list()) == 2 and scores.shape.as_list()[1] != 1:
+ raise ValueError('Scores should have rank 1 or have shape '
+ 'consistent with [None, 1]')
+ high_score_indices = tf.cast(tf.reshape(
+ tf.where(tf.greater(scores, thresh)),
+ [-1]), tf.int32)
+ return gather(boxlist, high_score_indices)
+
+
+def non_max_suppression(boxlist, thresh, max_output_size, scope=None):
+ """Non maximum suppression.
+
+ This op greedily selects a subset of detection bounding boxes, pruning
+ away boxes that have high IOU (intersection over union) overlap (> thresh)
+ with already selected boxes. Note that this only works for a single class ---
+ to apply NMS to multi-class predictions, use MultiClassNonMaxSuppression.
+
+ Args:
+ boxlist: BoxList holding N boxes. Must contain a 'scores' field
+ representing detection scores.
+ thresh: scalar threshold
+ max_output_size: maximum number of retained boxes
+ scope: name scope.
+
+ Returns:
+ a BoxList holding M boxes where M <= max_output_size
+ Raises:
+ ValueError: if thresh is not in [0, 1]
+ """
+ with tf.name_scope(scope, 'NonMaxSuppression'):
+ if not 0 <= thresh <= 1.0:
+ raise ValueError('thresh must be between 0 and 1')
+ if not isinstance(boxlist, box_list.BoxList):
+ raise ValueError('boxlist must be a BoxList')
+ if not boxlist.has_field('scores'):
+ raise ValueError('input boxlist must have \'scores\' field')
+ selected_indices = tf.image.non_max_suppression(
+ boxlist.get(), boxlist.get_field('scores'),
+ max_output_size, iou_threshold=thresh)
+ return gather(boxlist, selected_indices)
+
+
+def _copy_extra_fields(boxlist_to_copy_to, boxlist_to_copy_from):
+ """Copies the extra fields of boxlist_to_copy_from to boxlist_to_copy_to.
+
+ Args:
+ boxlist_to_copy_to: BoxList to which extra fields are copied.
+ boxlist_to_copy_from: BoxList from which fields are copied.
+
+ Returns:
+ boxlist_to_copy_to with extra fields.
+ """
+ for field in boxlist_to_copy_from.get_extra_fields():
+ boxlist_to_copy_to.add_field(field, boxlist_to_copy_from.get_field(field))
+ return boxlist_to_copy_to
+
+
+def to_normalized_coordinates(boxlist, height, width,
+ check_range=True, scope=None):
+ """Converts absolute box coordinates to normalized coordinates in [0, 1].
+
+ Usually one uses the dynamic shape of the image or conv-layer tensor:
+ boxlist = box_list_ops.to_normalized_coordinates(boxlist,
+ tf.shape(images)[1],
+ tf.shape(images)[2]),
+
+ This function raises an assertion failed error at graph execution time when
+ the maximum coordinate is smaller than 1.01 (which means that coordinates are
+ already normalized). The value 1.01 is to deal with small rounding errors.
+
+ Args:
+ boxlist: BoxList with coordinates in terms of pixel-locations.
+ height: Maximum value for height of absolute box coordinates.
+ width: Maximum value for width of absolute box coordinates.
+ check_range: If True, checks if the coordinates are normalized or not.
+ scope: name scope.
+
+ Returns:
+ boxlist with normalized coordinates in [0, 1].
+ """
+ with tf.name_scope(scope, 'ToNormalizedCoordinates'):
+ height = tf.cast(height, tf.float32)
+ width = tf.cast(width, tf.float32)
+
+ if check_range:
+ max_val = tf.reduce_max(boxlist.get())
+ max_assert = tf.Assert(tf.greater(max_val, 1.01),
+ ['max value is lower than 1.01: ', max_val])
+ with tf.control_dependencies([max_assert]):
+ width = tf.identity(width)
+
+ return scale(boxlist, 1 / height, 1 / width)
+
+
+def to_absolute_coordinates(boxlist,
+ height,
+ width,
+ check_range=True,
+ maximum_normalized_coordinate=1.1,
+ scope=None):
+ """Converts normalized box coordinates to absolute pixel coordinates.
+
+ This function raises an assertion failed error when the maximum box coordinate
+ value is larger than maximum_normalized_coordinate (in which case coordinates
+ are already absolute).
+
+ Args:
+ boxlist: BoxList with coordinates in range [0, 1].
+ height: Maximum value for height of absolute box coordinates.
+ width: Maximum value for width of absolute box coordinates.
+ check_range: If True, checks if the coordinates are normalized or not.
+ maximum_normalized_coordinate: Maximum coordinate value to be considered
+      as normalized; defaults to 1.1.
+ scope: name scope.
+
+ Returns:
+ boxlist with absolute coordinates in terms of the image size.
+
+ """
+ with tf.name_scope(scope, 'ToAbsoluteCoordinates'):
+ height = tf.cast(height, tf.float32)
+ width = tf.cast(width, tf.float32)
+
+ # Ensure range of input boxes is correct.
+ if check_range:
+ box_maximum = tf.reduce_max(boxlist.get())
+ max_assert = tf.Assert(
+ tf.greater_equal(maximum_normalized_coordinate, box_maximum),
+ ['maximum box coordinate value is larger '
+ 'than %f: ' % maximum_normalized_coordinate, box_maximum])
+ with tf.control_dependencies([max_assert]):
+ width = tf.identity(width)
+
+ return scale(boxlist, height, width)
+
+
+def refine_boxes_multi_class(pool_boxes,
+ num_classes,
+ nms_iou_thresh,
+ nms_max_detections,
+ voting_iou_thresh=0.5):
+ """Refines a pool of boxes using non max suppression and box voting.
+
+ Box refinement is done independently for each class.
+
+ Args:
+ pool_boxes: (BoxList) A collection of boxes to be refined. pool_boxes must
+ have a rank 1 'scores' field and a rank 1 'classes' field.
+ num_classes: (int scalar) Number of classes.
+ nms_iou_thresh: (float scalar) iou threshold for non max suppression (NMS).
+ nms_max_detections: (int scalar) maximum output size for NMS.
+ voting_iou_thresh: (float scalar) iou threshold for box voting.
+
+ Returns:
+ BoxList of refined boxes.
+
+ Raises:
+ ValueError: if
+ a) nms_iou_thresh or voting_iou_thresh is not in [0, 1].
+ b) pool_boxes is not a BoxList.
+ c) pool_boxes does not have a scores and classes field.
+ """
+ if not 0.0 <= nms_iou_thresh <= 1.0:
+ raise ValueError('nms_iou_thresh must be between 0 and 1')
+ if not 0.0 <= voting_iou_thresh <= 1.0:
+ raise ValueError('voting_iou_thresh must be between 0 and 1')
+ if not isinstance(pool_boxes, box_list.BoxList):
+ raise ValueError('pool_boxes must be a BoxList')
+ if not pool_boxes.has_field('scores'):
+ raise ValueError('pool_boxes must have a \'scores\' field')
+ if not pool_boxes.has_field('classes'):
+ raise ValueError('pool_boxes must have a \'classes\' field')
+
+ refined_boxes = []
+ for i in range(num_classes):
+ boxes_class = filter_field_value_equals(pool_boxes, 'classes', i)
+ refined_boxes_class = refine_boxes(boxes_class, nms_iou_thresh,
+ nms_max_detections, voting_iou_thresh)
+ refined_boxes.append(refined_boxes_class)
+ return sort_by_field(concatenate(refined_boxes), 'scores')
+
+
+def refine_boxes(pool_boxes,
+ nms_iou_thresh,
+ nms_max_detections,
+ voting_iou_thresh=0.5):
+ """Refines a pool of boxes using non max suppression and box voting.
+
+ Args:
+ pool_boxes: (BoxList) A collection of boxes to be refined. pool_boxes must
+ have a rank 1 'scores' field.
+ nms_iou_thresh: (float scalar) iou threshold for non max suppression (NMS).
+ nms_max_detections: (int scalar) maximum output size for NMS.
+ voting_iou_thresh: (float scalar) iou threshold for box voting.
+
+ Returns:
+ BoxList of refined boxes.
+
+ Raises:
+ ValueError: if
+ a) nms_iou_thresh or voting_iou_thresh is not in [0, 1].
+ b) pool_boxes is not a BoxList.
+ c) pool_boxes does not have a scores field.
+ """
+ if not 0.0 <= nms_iou_thresh <= 1.0:
+ raise ValueError('nms_iou_thresh must be between 0 and 1')
+ if not 0.0 <= voting_iou_thresh <= 1.0:
+ raise ValueError('voting_iou_thresh must be between 0 and 1')
+ if not isinstance(pool_boxes, box_list.BoxList):
+ raise ValueError('pool_boxes must be a BoxList')
+ if not pool_boxes.has_field('scores'):
+ raise ValueError('pool_boxes must have a \'scores\' field')
+
+ nms_boxes = non_max_suppression(
+ pool_boxes, nms_iou_thresh, nms_max_detections)
+ return box_voting(nms_boxes, pool_boxes, voting_iou_thresh)
+
+
+def box_voting(selected_boxes, pool_boxes, iou_thresh=0.5):
+ """Performs box voting as described in S. Gidaris and N. Komodakis, ICCV 2015.
+
+ Performs box voting as described in 'Object detection via a multi-region &
+ semantic segmentation-aware CNN model', Gidaris and Komodakis, ICCV 2015. For
+ each box 'B' in selected_boxes, we find the set 'S' of boxes in pool_boxes
+ with iou overlap >= iou_thresh. The location of B is set to the weighted
+ average location of boxes in S (scores are used for weighting). And the score
+ of B is set to the average score of boxes in S.
+
+ Args:
+ selected_boxes: BoxList containing a subset of boxes in pool_boxes. These
+ boxes are usually selected from pool_boxes using non max suppression.
+ pool_boxes: BoxList containing a set of (possibly redundant) boxes.
+ iou_thresh: (float scalar) iou threshold for matching boxes in
+ selected_boxes and pool_boxes.
+
+ Returns:
+ BoxList containing averaged locations and scores for each box in
+ selected_boxes.
+
+ Raises:
+ ValueError: if
+ a) selected_boxes or pool_boxes is not a BoxList.
+ b) if iou_thresh is not in [0, 1].
+ c) pool_boxes does not have a scores field.
+ """
+ if not 0.0 <= iou_thresh <= 1.0:
+ raise ValueError('iou_thresh must be between 0 and 1')
+ if not isinstance(selected_boxes, box_list.BoxList):
+ raise ValueError('selected_boxes must be a BoxList')
+ if not isinstance(pool_boxes, box_list.BoxList):
+ raise ValueError('pool_boxes must be a BoxList')
+ if not pool_boxes.has_field('scores'):
+ raise ValueError('pool_boxes must have a \'scores\' field')
+
+ iou_ = iou(selected_boxes, pool_boxes)
+ match_indicator = tf.cast(tf.greater(iou_, iou_thresh), dtype=tf.float32)
+ num_matches = tf.reduce_sum(match_indicator, 1)
+ # TODO(kbanoop): Handle the case where some boxes in selected_boxes do not
+ # match to any boxes in pool_boxes. For such boxes without any matches, we
+ # should return the original boxes without voting.
+ match_assert = tf.Assert(
+ tf.reduce_all(tf.greater(num_matches, 0)),
+ ['Each box in selected_boxes must match with at least one box '
+ 'in pool_boxes.'])
+
+ scores = tf.expand_dims(pool_boxes.get_field('scores'), 1)
+ scores_assert = tf.Assert(
+ tf.reduce_all(tf.greater_equal(scores, 0)),
+ ['Scores must be non negative.'])
+
+ with tf.control_dependencies([scores_assert, match_assert]):
+ sum_scores = tf.matmul(match_indicator, scores)
+ averaged_scores = tf.reshape(sum_scores, [-1]) / num_matches
+
+ box_locations = tf.matmul(match_indicator,
+ pool_boxes.get() * scores) / sum_scores
+ averaged_boxes = box_list.BoxList(box_locations)
+ _copy_extra_fields(averaged_boxes, selected_boxes)
+ averaged_boxes.add_field('scores', averaged_scores)
+ return averaged_boxes
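+
+# Illustrative worked example for box_voting (not part of the original
+# implementation): with one selected box B, two pool boxes [0, 0, 1, 1] and
+# [0, 0, 0.5, 0.5] with scores [0.8, 0.2], and both overlapping B above
+# iou_thresh, the voted location is 0.8 * [0, 0, 1, 1] + 0.2 * [0, 0, 0.5, 0.5]
+# = [0, 0, 0.9, 0.9] and the voted score is (0.8 + 0.2) / 2 = 0.5.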
+
+
+def get_minimal_coverage_box(boxlist,
+ default_box=None,
+ scope=None):
+ """Creates a single bounding box which covers all boxes in the boxlist.
+
+ Args:
+ boxlist: A Boxlist.
+ default_box: A [1, 4] float32 tensor. If no boxes are present in `boxlist`,
+ this default box will be returned. If None, will use a default box of
+ [[0., 0., 1., 1.]].
+ scope: Name scope.
+
+ Returns:
+ A [1, 4] float32 tensor with a bounding box that tightly covers all the
+ boxes in the box list. If the boxlist does not contain any boxes, the
+ default box is returned.
+ """
+  with tf.name_scope(scope or 'CreateCoverageBox'):
+ num_boxes = boxlist.num_boxes()
+
+ def coverage_box(bboxes):
+ y_min, x_min, y_max, x_max = tf.split(
+ value=bboxes, num_or_size_splits=4, axis=1)
+ y_min_coverage = tf.reduce_min(y_min, axis=0)
+ x_min_coverage = tf.reduce_min(x_min, axis=0)
+ y_max_coverage = tf.reduce_max(y_max, axis=0)
+ x_max_coverage = tf.reduce_max(x_max, axis=0)
+ return tf.stack(
+ [y_min_coverage, x_min_coverage, y_max_coverage, x_max_coverage],
+ axis=1)
+
+ default_box = default_box or tf.constant([[0., 0., 1., 1.]])
+ return tf.cond(
+ tf.greater_equal(num_boxes, 1),
+ true_fn=lambda: coverage_box(boxlist.get()),
+ false_fn=lambda: default_box)
+
+
+def sample_boxes_by_jittering(boxlist,
+ num_boxes_to_sample,
+ stddev=0.1,
+ scope=None):
+ """Samples num_boxes_to_sample boxes by jittering around boxlist boxes.
+
+  It is possible that this function might generate boxes with size 0. The larger
+  the stddev, the more likely this is. For a small stddev of 0.1 this
+  probability is very small.
+
+ Args:
+ boxlist: A boxlist containing N boxes in normalized coordinates.
+ num_boxes_to_sample: A positive integer containing the number of boxes to
+ sample.
+ stddev: Standard deviation. This is used to draw random offsets for the
+ box corners from a normal distribution. The offset is multiplied by the
+ box size so will be larger in terms of pixels for larger boxes.
+ scope: Name scope.
+
+ Returns:
+ sampled_boxlist: A boxlist containing num_boxes_to_sample boxes in
+ normalized coordinates.
+ """
+  with tf.name_scope(scope or 'SampleBoxesByJittering'):
+ num_boxes = boxlist.num_boxes()
+    box_indices = tf.random.uniform(
+ [num_boxes_to_sample],
+ minval=0,
+ maxval=num_boxes,
+ dtype=tf.int32)
+ sampled_boxes = tf.gather(boxlist.get(), box_indices)
+ sampled_boxes_height = sampled_boxes[:, 2] - sampled_boxes[:, 0]
+ sampled_boxes_width = sampled_boxes[:, 3] - sampled_boxes[:, 1]
+    rand_miny_gaussian = tf.random.normal([num_boxes_to_sample], stddev=stddev)
+    rand_minx_gaussian = tf.random.normal([num_boxes_to_sample], stddev=stddev)
+    rand_maxy_gaussian = tf.random.normal([num_boxes_to_sample], stddev=stddev)
+    rand_maxx_gaussian = tf.random.normal([num_boxes_to_sample], stddev=stddev)
+ miny = rand_miny_gaussian * sampled_boxes_height + sampled_boxes[:, 0]
+ minx = rand_minx_gaussian * sampled_boxes_width + sampled_boxes[:, 1]
+ maxy = rand_maxy_gaussian * sampled_boxes_height + sampled_boxes[:, 2]
+ maxx = rand_maxx_gaussian * sampled_boxes_width + sampled_boxes[:, 3]
+ maxy = tf.maximum(miny, maxy)
+ maxx = tf.maximum(minx, maxx)
+ sampled_boxes = tf.stack([miny, minx, maxy, maxx], axis=1)
+ sampled_boxes = tf.maximum(tf.minimum(sampled_boxes, 1.0), 0.0)
+ return box_list.BoxList(sampled_boxes)
diff --git a/models/official/vision/detection/utils/object_detection/faster_rcnn_box_coder.py b/models/official/vision/detection/utils/object_detection/faster_rcnn_box_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..235df4ede474e89687a17413e81e60aa21772e23
--- /dev/null
+++ b/models/official/vision/detection/utils/object_detection/faster_rcnn_box_coder.py
@@ -0,0 +1,118 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Faster RCNN box coder.
+
+Faster RCNN box coder follows the coding schema described below:
+ ty = (y - ya) / ha
+ tx = (x - xa) / wa
+ th = log(h / ha)
+ tw = log(w / wa)
+ where x, y, w, h denote the box's center coordinates, width and height
+ respectively. Similarly, xa, ya, wa, ha denote the anchor's center
+ coordinates, width and height. tx, ty, tw and th denote the anchor-encoded
+ center, width and height respectively.
+
+ See http://arxiv.org/abs/1506.01497 for details.
+"""
+
+import tensorflow as tf
+
+from official.vision.detection.utils.object_detection import box_coder
+from official.vision.detection.utils.object_detection import box_list
+
+EPSILON = 1e-8
+
+
+class FasterRcnnBoxCoder(box_coder.BoxCoder):
+ """Faster RCNN box coder."""
+
+ def __init__(self, scale_factors=None):
+ """Constructor for FasterRcnnBoxCoder.
+
+ Args:
+ scale_factors: List of 4 positive scalars to scale ty, tx, th and tw.
+ If set to None, does not perform scaling. For Faster RCNN,
+ the open-source implementation recommends using [10.0, 10.0, 5.0, 5.0].
+ """
+ if scale_factors:
+ assert len(scale_factors) == 4
+ for scalar in scale_factors:
+ assert scalar > 0
+ self._scale_factors = scale_factors
+
+ @property
+ def code_size(self):
+ return 4
+
+ def _encode(self, boxes, anchors):
+ """Encode a box collection with respect to anchor collection.
+
+ Args:
+ boxes: BoxList holding N boxes to be encoded.
+ anchors: BoxList of anchors.
+
+ Returns:
+ a tensor representing N anchor-encoded boxes of the format
+ [ty, tx, th, tw].
+ """
+ # Convert anchors to the center coordinate representation.
+ ycenter_a, xcenter_a, ha, wa = anchors.get_center_coordinates_and_sizes()
+ ycenter, xcenter, h, w = boxes.get_center_coordinates_and_sizes()
+ # Avoid NaN in division and log below.
+ ha += EPSILON
+ wa += EPSILON
+ h += EPSILON
+ w += EPSILON
+
+ tx = (xcenter - xcenter_a) / wa
+ ty = (ycenter - ycenter_a) / ha
+ tw = tf.math.log(w / wa)
+ th = tf.math.log(h / ha)
+ # Scales location targets as used in paper for joint training.
+ if self._scale_factors:
+ ty *= self._scale_factors[0]
+ tx *= self._scale_factors[1]
+ th *= self._scale_factors[2]
+ tw *= self._scale_factors[3]
+ return tf.transpose(a=tf.stack([ty, tx, th, tw]))
+
+ def _decode(self, rel_codes, anchors):
+ """Decode relative codes to boxes.
+
+ Args:
+ rel_codes: a tensor representing N anchor-encoded boxes.
+ anchors: BoxList of anchors.
+
+ Returns:
+ boxes: BoxList holding N bounding boxes.
+ """
+ ycenter_a, xcenter_a, ha, wa = anchors.get_center_coordinates_and_sizes()
+
+ ty, tx, th, tw = tf.unstack(tf.transpose(a=rel_codes))
+ if self._scale_factors:
+ ty /= self._scale_factors[0]
+ tx /= self._scale_factors[1]
+ th /= self._scale_factors[2]
+ tw /= self._scale_factors[3]
+ w = tf.exp(tw) * wa
+ h = tf.exp(th) * ha
+ ycenter = ty * ha + ycenter_a
+ xcenter = tx * wa + xcenter_a
+ ymin = ycenter - h / 2.
+ xmin = xcenter - w / 2.
+ ymax = ycenter + h / 2.
+ xmax = xcenter + w / 2.
+ return box_list.BoxList(tf.transpose(a=tf.stack([ymin, xmin, ymax, xmax])))
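A small NumPy sketch (editorial illustration, not part of the diff) of the encode/decode formulas in the module docstring, showing that decoding recovers the original box; the anchor and box values are invented.

import numpy as np

ya, xa, ha, wa = 0.5, 0.5, 0.4, 0.4   # anchor center, height, width
y, x, h, w = 0.55, 0.45, 0.5, 0.3     # groundtruth box center, height, width
ty, tx = (y - ya) / ha, (x - xa) / wa
th, tw = np.log(h / ha), np.log(w / wa)
# Decoding inverts the encoding exactly (up to floating point).
y2, x2 = ty * ha + ya, tx * wa + xa
h2, w2 = np.exp(th) * ha, np.exp(tw) * wa
assert np.allclose([y2, x2, h2, w2], [y, x, h, w])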
diff --git a/models/official/vision/detection/utils/object_detection/matcher.py b/models/official/vision/detection/utils/object_detection/matcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a025d5e7118ee20f136c8a31b4c183de11f1e7f
--- /dev/null
+++ b/models/official/vision/detection/utils/object_detection/matcher.py
@@ -0,0 +1,243 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Matcher interface and Match class.
+
+This module defines the Matcher interface and the Match object. The job of the
+matcher is to match row and column indices based on the similarity matrix and
+other optional parameters. Each column is matched to at most one row. There
+are three possibilities for the matching:
+
+1) match: A column matches a row.
+2) no_match: A column does not match any row.
+3) ignore: A column that is neither 'match' nor no_match.
+
+The ignore case is regularly encountered in object detection: when an anchor has
+a relatively small overlap with a ground-truth box, one wants to count it neither
+as a positive example (match) nor as a negative example (no_match).
+
+The Match class is used to store the match results, and it provides simple APIs
+to query the results.
+"""
+from abc import ABCMeta
+from abc import abstractmethod
+
+import tensorflow as tf
+
+
+class Match(object):
+ """Class to store results from the matcher.
+
+ This class is used to store the results from the matcher. It provides
+ convenient methods to query the matching results.
+ """
+
+ def __init__(self, match_results):
+ """Constructs a Match object.
+
+ Args:
+ match_results: Integer tensor of shape [N] with (1) match_results[i]>=0,
+ meaning that column i is matched with row match_results[i].
+ (2) match_results[i]=-1, meaning that column i is not matched.
+ (3) match_results[i]=-2, meaning that column i is ignored.
+
+ Raises:
+      ValueError: if match_results does not have rank 1 or is not an
+        int32 tensor.
+ """
+ if match_results.shape.ndims != 1:
+ raise ValueError('match_results should have rank 1')
+ if match_results.dtype != tf.int32:
+      raise ValueError('match_results should be an int32 tensor')
+ self._match_results = match_results
+
+ @property
+ def match_results(self):
+ """The accessor for match results.
+
+ Returns:
+ the tensor which encodes the match results.
+ """
+ return self._match_results
+
+ def matched_column_indices(self):
+ """Returns column indices that match to some row.
+
+ The indices returned by this op are always sorted in increasing order.
+
+ Returns:
+ column_indices: int32 tensor of shape [K] with column indices.
+ """
+ return self._reshape_and_cast(tf.where(tf.greater(self._match_results, -1)))
+
+ def matched_column_indicator(self):
+ """Returns column indices that are matched.
+
+ Returns:
+ column_indices: int32 tensor of shape [K] with column indices.
+ """
+ return tf.greater_equal(self._match_results, 0)
+
+ def num_matched_columns(self):
+ """Returns number (int32 scalar tensor) of matched columns."""
+ return tf.size(input=self.matched_column_indices())
+
+ def unmatched_column_indices(self):
+ """Returns column indices that do not match any row.
+
+ The indices returned by this op are always sorted in increasing order.
+
+ Returns:
+ column_indices: int32 tensor of shape [K] with column indices.
+ """
+ return self._reshape_and_cast(tf.where(tf.equal(self._match_results, -1)))
+
+ def unmatched_column_indicator(self):
+ """Returns column indices that are unmatched.
+
+ Returns:
+ column_indices: int32 tensor of shape [K] with column indices.
+ """
+ return tf.equal(self._match_results, -1)
+
+ def num_unmatched_columns(self):
+ """Returns number (int32 scalar tensor) of unmatched columns."""
+ return tf.size(input=self.unmatched_column_indices())
+
+ def ignored_column_indices(self):
+ """Returns column indices that are ignored (neither Matched nor Unmatched).
+
+ The indices returned by this op are always sorted in increasing order.
+
+ Returns:
+ column_indices: int32 tensor of shape [K] with column indices.
+ """
+ return self._reshape_and_cast(tf.where(self.ignored_column_indicator()))
+
+ def ignored_column_indicator(self):
+ """Returns boolean column indicator where True means the colum is ignored.
+
+ Returns:
+ column_indicator: boolean vector which is True for all ignored column
+ indices.
+ """
+ return tf.equal(self._match_results, -2)
+
+ def num_ignored_columns(self):
+ """Returns number (int32 scalar tensor) of matched columns."""
+ return tf.size(input=self.ignored_column_indices())
+
+ def unmatched_or_ignored_column_indices(self):
+ """Returns column indices that are unmatched or ignored.
+
+ The indices returned by this op are always sorted in increasing order.
+
+ Returns:
+ column_indices: int32 tensor of shape [K] with column indices.
+ """
+ return self._reshape_and_cast(tf.where(tf.greater(0, self._match_results)))
+
+ def matched_row_indices(self):
+ """Returns row indices that match some column.
+
+    The indices returned by this op are ordered so as to be in correspondence
+    with the output of matched_column_indices(). For example if
+    self.matched_column_indices() is [0, 2], and self.matched_row_indices() is
+ [7, 3], then we know that column 0 was matched to row 7 and column 2 was
+ matched to row 3.
+
+ Returns:
+ row_indices: int32 tensor of shape [K] with row indices.
+ """
+ return self._reshape_and_cast(
+ tf.gather(self._match_results, self.matched_column_indices()))
+
+ def _reshape_and_cast(self, t):
+ return tf.cast(tf.reshape(t, [-1]), tf.int32)
+
+ def gather_based_on_match(self, input_tensor, unmatched_value,
+ ignored_value):
+ """Gathers elements from `input_tensor` based on match results.
+
+ For columns that are matched to a row, gathered_tensor[col] is set to
+ input_tensor[match_results[col]]. For columns that are unmatched,
+ gathered_tensor[col] is set to unmatched_value. Finally, for columns that
+ are ignored gathered_tensor[col] is set to ignored_value.
+
+    Note that input_tensor.shape[1:] must match unmatched_value.shape and
+    ignored_value.shape.
+
+ Args:
+ input_tensor: Tensor to gather values from.
+ unmatched_value: Constant tensor value for unmatched columns.
+ ignored_value: Constant tensor value for ignored columns.
+
+ Returns:
+ gathered_tensor: A tensor containing values gathered from input_tensor.
+ The shape of the gathered tensor is [match_results.shape[0]] +
+ input_tensor.shape[1:].
+ """
+ input_tensor = tf.concat([tf.stack([ignored_value, unmatched_value]),
+ input_tensor], axis=0)
+ gather_indices = tf.maximum(self.match_results + 2, 0)
+ gathered_tensor = tf.gather(input_tensor, gather_indices)
+ return gathered_tensor
+
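A short usage sketch (editorial, not part of the diff) of Match.gather_based_on_match with invented match results: -1 marks an unmatched column, -2 an ignored one.

import tensorflow as tf

m = Match(tf.constant([1, -1, -2, 0], dtype=tf.int32))
row_labels = tf.constant([10.0, 20.0])   # one value per groundtruth row
gathered = m.gather_based_on_match(row_labels,
                                   unmatched_value=tf.constant(0.0),
                                   ignored_value=tf.constant(-1.0))
# gathered == [20.0, 0.0, -1.0, 10.0]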
+
+class Matcher(object):
+ """Abstract base class for matcher.
+ """
+ __metaclass__ = ABCMeta
+
+ def match(self, similarity_matrix, scope=None, **params):
+ """Computes matches among row and column indices and returns the result.
+
+ Computes matches among the row and column indices based on the similarity
+ matrix and optional arguments.
+
+ Args:
+ similarity_matrix: Float tensor of shape [N, M] with pairwise similarity
+ where higher value means more similar.
+ scope: Op scope name. Defaults to 'Match' if None.
+ **params: Additional keyword arguments for specific implementations of
+ the Matcher.
+
+ Returns:
+ A Match object with the results of matching.
+ """
+ if not scope:
+ scope = 'Match'
+ with tf.name_scope(scope) as scope:
+ return Match(self._match(similarity_matrix, **params))
+
+ @abstractmethod
+ def _match(self, similarity_matrix, **params):
+ """Method to be overridden by implementations.
+
+ Args:
+ similarity_matrix: Float tensor of shape [N, M] with pairwise similarity
+ where higher value means more similar.
+ **params: Additional keyword arguments for specific implementations of
+ the Matcher.
+
+ Returns:
+ match_results: Integer tensor of shape [M]: match_results[i]>=0 means
+ that column i is matched to row match_results[i], match_results[i]=-1
+ means that the column is not matched. match_results[i]=-2 means that
+ the column is ignored (usually this happens when there is a very weak
+ match which one neither wants as positive nor negative example).
+ """
+ pass
diff --git a/models/official/vision/detection/utils/object_detection/minibatch_sampler.py b/models/official/vision/detection/utils/object_detection/minibatch_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9f529ab5976ca56f014788c1263e5887fde0444
--- /dev/null
+++ b/models/official/vision/detection/utils/object_detection/minibatch_sampler.py
@@ -0,0 +1,93 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Base minibatch sampler module.
+
+The job of the minibatch_sampler is to subsample a minibatch based on some
+criterion.
+
+The main function call is:
+ subsample(indicator, batch_size, **params).
+Indicator is a 1d boolean tensor where True denotes which examples can be
+sampled. It returns a boolean indicator where True denotes an example has been
+sampled.
+
+Subclasses should implement the Subsample function and can make use of the
+@staticmethod SubsampleIndicator.
+
+This is originally implemented in TensorFlow Object Detection API.
+"""
+
+from abc import ABCMeta
+from abc import abstractmethod
+
+import tensorflow as tf
+
+from official.vision.detection.utils.object_detection import ops
+
+
+class MinibatchSampler(object):
+ """Abstract base class for subsampling minibatches."""
+ __metaclass__ = ABCMeta
+
+ def __init__(self):
+ """Constructs a minibatch sampler."""
+ pass
+
+ @abstractmethod
+ def subsample(self, indicator, batch_size, **params):
+ """Returns subsample of entries in indicator.
+
+ Args:
+ indicator: boolean tensor of shape [N] whose True entries can be sampled.
+ batch_size: desired batch size.
+ **params: additional keyword arguments for specific implementations of
+ the MinibatchSampler.
+
+ Returns:
+ sample_indicator: boolean tensor of shape [N] whose True entries have been
+ sampled. If sum(indicator) >= batch_size, sum(is_sampled) = batch_size
+ """
+ pass
+
+ @staticmethod
+ def subsample_indicator(indicator, num_samples):
+ """Subsample indicator vector.
+
+ Given a boolean indicator vector with M elements set to `True`, the function
+ assigns all but `num_samples` of these previously `True` elements to
+ `False`. If `num_samples` is greater than M, the original indicator vector
+ is returned.
+
+ Args:
+ indicator: a 1-dimensional boolean tensor indicating which elements
+ are allowed to be sampled and which are not.
+ num_samples: int32 scalar tensor
+
+ Returns:
+ a boolean tensor with the same shape as input (indicator) tensor
+ """
+ indices = tf.where(indicator)
+ indices = tf.random.shuffle(indices)
+ indices = tf.reshape(indices, [-1])
+
+ num_samples = tf.minimum(tf.size(input=indices), num_samples)
+ selected_indices = tf.slice(indices, [0], tf.reshape(num_samples, [1]))
+
+ selected_indicator = ops.indices_to_dense_vector(
+ selected_indices,
+ tf.shape(input=indicator)[0])
+
+ return tf.equal(selected_indicator, 1)
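A quick usage sketch (editorial, not part of the diff) of the static helper above; which two of the eligible entries end up selected is random.

import tensorflow as tf

indicator = tf.constant([True, False, True, True, False])
sampled = MinibatchSampler.subsample_indicator(indicator, 2)
# `sampled` is a boolean [5] tensor with exactly two True entries, both at
# positions where `indicator` was True.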
diff --git a/models/official/vision/detection/utils/object_detection/ops.py b/models/official/vision/detection/utils/object_detection/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbfc1ae9353604986ad3f1f06a4f8e2e72bb5ca0
--- /dev/null
+++ b/models/official/vision/detection/utils/object_detection/ops.py
@@ -0,0 +1,82 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A module for helper tensorflow ops.
+
+This is originally implemented in TensorFlow Object Detection API.
+"""
+
+import tensorflow as tf
+
+from official.vision.detection.utils.object_detection import shape_utils
+
+
+def indices_to_dense_vector(indices,
+ size,
+ indices_value=1.,
+ default_value=0,
+ dtype=tf.float32):
+ """Creates dense vector with indices set to specific value and rest to zeros.
+
+ This function exists because it is unclear if it is safe to use
+ tf.sparse_to_dense(indices, [size], 1, validate_indices=False)
+ with indices which are not ordered.
+ This function accepts a dynamic size (e.g. tf.shape(tensor)[0])
+
+ Args:
+ indices: 1d Tensor with integer indices which are to be set to
+ indices_values.
+ size: scalar with size (integer) of output Tensor.
+ indices_value: values of elements specified by indices in the output vector
+ default_value: values of other elements in the output vector.
+ dtype: data type.
+
+ Returns:
+ dense 1D Tensor of shape [size] with indices set to indices_values and the
+ rest set to default_value.
+ """
+ size = tf.cast(size, dtype=tf.int32)
+ zeros = tf.ones([size], dtype=dtype) * default_value
+ values = tf.ones_like(indices, dtype=dtype) * indices_value
+
+ return tf.dynamic_stitch(
+ [tf.range(size), tf.cast(indices, dtype=tf.int32)], [zeros, values])
+
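A minimal usage sketch (editorial, not part of the diff) with invented indices.

import tensorflow as tf

dense = indices_to_dense_vector(tf.constant([1, 3]), size=5)
# dense == [0., 1., 0., 1., 0.]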
+
+def matmul_gather_on_zeroth_axis(params, indices, scope=None):
+ """Matrix multiplication based implementation of tf.gather on zeroth axis.
+
+ TODO(rathodv, jonathanhuang): enable sparse matmul option.
+
+ Args:
+ params: A float32 Tensor. The tensor from which to gather values.
+ Must be at least rank 1.
+ indices: A Tensor. Must be one of the following types: int32, int64.
+ Must be in range [0, params.shape[0])
+ scope: A name for the operation (optional).
+
+ Returns:
+ A Tensor. Has the same type as params. Values from params gathered
+ from indices given by indices, with shape indices.shape + params.shape[1:].
+ """
+ scope = scope or 'MatMulGather'
+ with tf.name_scope(scope):
+ params_shape = shape_utils.combined_static_and_dynamic_shape(params)
+ indices_shape = shape_utils.combined_static_and_dynamic_shape(indices)
+ params2d = tf.reshape(params, [params_shape[0], -1])
+ indicator_matrix = tf.one_hot(indices, params_shape[0])
+ gathered_result_flattened = tf.matmul(indicator_matrix, params2d)
+ return tf.reshape(gathered_result_flattened,
+ tf.stack(indices_shape + params_shape[1:]))
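A small check (editorial sketch, not part of the diff) that the matmul-based gather agrees with tf.gather on invented values.

import tensorflow as tf

params = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
indices = tf.constant([2, 0])
out = matmul_gather_on_zeroth_axis(params, indices)
# out == tf.gather(params, indices) == [[5., 6.], [1., 2.]]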
diff --git a/models/official/vision/detection/utils/object_detection/preprocessor.py b/models/official/vision/detection/utils/object_detection/preprocessor.py
new file mode 100644
index 0000000000000000000000000000000000000000..55da5d2dfafda816be7dcb2d334a3a0711e0b699
--- /dev/null
+++ b/models/official/vision/detection/utils/object_detection/preprocessor.py
@@ -0,0 +1,525 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Preprocess images and bounding boxes for detection.
+
+We perform two sets of operations in preprocessing stage:
+(a) operations that are applied to both training and testing data,
+(b) operations that are applied only to training data for the purpose of
+ data augmentation.
+
+A preprocessing function receives a set of inputs,
+e.g. an image and bounding boxes,
+performs an operation on them, and returns them.
+Some examples are: randomly cropping the image, randomly mirroring the image,
+ randomly changing the brightness, contrast, hue and
+ randomly jittering the bounding boxes.
+
+The image is a rank 4 tensor: [1, height, width, channels] with
+dtype=tf.float32. The groundtruth_boxes is a rank 2 tensor: [N, 4] where
+in each row there is a box with [ymin xmin ymax xmax].
+Boxes are in normalized coordinates meaning
+their coordinate values range in [0, 1]
+
+Important Note: In tensor_dict, images is a rank 4 tensor, but preprocessing
+functions receive a rank 3 tensor for processing the image. Thus, inside the
+preprocess function we squeeze the image to become a rank 3 tensor and then
+we pass it to the functions. At the end of the preprocess we expand the image
+back to rank 4.
+"""
+
+import tensorflow as tf
+
+import numpy as np
+
+from official.vision.detection.utils.object_detection import box_list
+
+
+def _flip_boxes_left_right(boxes):
+ """Left-right flip the boxes.
+
+ Args:
+ boxes: rank 2 float32 tensor containing the bounding boxes -> [N, 4].
+ Boxes are in normalized form meaning their coordinates vary
+ between [0, 1].
+ Each row is in the form of [ymin, xmin, ymax, xmax].
+
+ Returns:
+ Flipped boxes.
+ """
+ ymin, xmin, ymax, xmax = tf.split(value=boxes, num_or_size_splits=4, axis=1)
+ flipped_xmin = tf.subtract(1.0, xmax)
+ flipped_xmax = tf.subtract(1.0, xmin)
+ flipped_boxes = tf.concat([ymin, flipped_xmin, ymax, flipped_xmax], 1)
+ return flipped_boxes
+
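A one-box sketch (editorial, not part of the diff) of the flip: xmin becomes 1 - xmax and xmax becomes 1 - xmin, while the y coordinates stay put.

import tensorflow as tf

boxes = tf.constant([[0.1, 0.2, 0.5, 0.6]])
flipped = _flip_boxes_left_right(boxes)
# flipped == [[0.1, 0.4, 0.5, 0.8]]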
+
+def _flip_masks_left_right(masks):
+ """Left-right flip masks.
+
+ Args:
+ masks: rank 3 float32 tensor with shape
+ [num_instances, height, width] representing instance masks.
+
+ Returns:
+ flipped masks: rank 3 float32 tensor with shape
+ [num_instances, height, width] representing instance masks.
+ """
+ return masks[:, :, ::-1]
+
+
+def keypoint_flip_horizontal(keypoints, flip_point, flip_permutation,
+ scope=None):
+ """Flips the keypoints horizontally around the flip_point.
+
+ This operation flips the x coordinate for each keypoint around the flip_point
+ and also permutes the keypoints in a manner specified by flip_permutation.
+
+ Args:
+ keypoints: a tensor of shape [num_instances, num_keypoints, 2]
+ flip_point: (float) scalar tensor representing the x coordinate to flip the
+ keypoints around.
+ flip_permutation: rank 1 int32 tensor containing the keypoint flip
+ permutation. This specifies the mapping from original keypoint indices
+ to the flipped keypoint indices. This is used primarily for keypoints
+ that are not reflection invariant. E.g. Suppose there are 3 keypoints
+ representing ['head', 'right_eye', 'left_eye'], then a logical choice for
+ flip_permutation might be [0, 2, 1] since we want to swap the 'left_eye'
+ and 'right_eye' after a horizontal flip.
+ scope: name scope.
+
+ Returns:
+ new_keypoints: a tensor of shape [num_instances, num_keypoints, 2]
+ """
+ if not scope:
+ scope = 'FlipHorizontal'
+ with tf.name_scope(scope):
+ keypoints = tf.transpose(a=keypoints, perm=[1, 0, 2])
+ keypoints = tf.gather(keypoints, flip_permutation)
+ v, u = tf.split(value=keypoints, num_or_size_splits=2, axis=2)
+ u = flip_point * 2.0 - u
+ new_keypoints = tf.concat([v, u], 2)
+ new_keypoints = tf.transpose(a=new_keypoints, perm=[1, 0, 2])
+ return new_keypoints
+
+
+def keypoint_change_coordinate_frame(keypoints, window, scope=None):
+ """Changes coordinate frame of the keypoints to be relative to window's frame.
+
+ Given a window of the form [y_min, x_min, y_max, x_max], changes keypoint
+ coordinates from keypoints of shape [num_instances, num_keypoints, 2]
+ to be relative to this window.
+
+ An example use case is data augmentation: where we are given groundtruth
+ keypoints and would like to randomly crop the image to some window. In this
+ case we need to change the coordinate frame of each groundtruth keypoint to be
+ relative to this new window.
+
+ Args:
+ keypoints: a tensor of shape [num_instances, num_keypoints, 2]
+ window: a tensor of shape [4] representing the [y_min, x_min, y_max, x_max]
+ window we should change the coordinate frame to.
+ scope: name scope.
+
+ Returns:
+ new_keypoints: a tensor of shape [num_instances, num_keypoints, 2]
+ """
+ if not scope:
+ scope = 'ChangeCoordinateFrame'
+ with tf.name_scope(scope):
+ win_height = window[2] - window[0]
+ win_width = window[3] - window[1]
+    # Note: `box_list_ops` is not imported in this module; keypoint_scale
+    # (defined below) performs the same elementwise [y, x] scaling.
+    new_keypoints = keypoint_scale(keypoints - [window[0], window[1]],
+                                   1.0 / win_height, 1.0 / win_width)
+ return new_keypoints
+
+
+def keypoint_prune_outside_window(keypoints, window, scope=None):
+ """Prunes keypoints that fall outside a given window.
+
+ This function replaces keypoints that fall outside the given window with nan.
+ See also clip_to_window which clips any keypoints that fall outside the given
+ window.
+
+ Args:
+ keypoints: a tensor of shape [num_instances, num_keypoints, 2]
+ window: a tensor of shape [4] representing the [y_min, x_min, y_max, x_max]
+ window outside of which the op should prune the keypoints.
+ scope: name scope.
+
+ Returns:
+ new_keypoints: a tensor of shape [num_instances, num_keypoints, 2]
+ """
+ if not scope:
+ scope = 'PruneOutsideWindow'
+ with tf.name_scope(scope):
+ y, x = tf.split(value=keypoints, num_or_size_splits=2, axis=2)
+ win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)
+
+ valid_indices = tf.logical_and(
+ tf.logical_and(y >= win_y_min, y <= win_y_max),
+ tf.logical_and(x >= win_x_min, x <= win_x_max))
+
+ new_y = tf.where(valid_indices, y, np.nan * tf.ones_like(y))
+ new_x = tf.where(valid_indices, x, np.nan * tf.ones_like(x))
+ new_keypoints = tf.concat([new_y, new_x], 2)
+
+ return new_keypoints
+
+
+def random_horizontal_flip(image,
+ boxes=None,
+ masks=None,
+ keypoints=None,
+ keypoint_flip_permutation=None,
+ seed=None):
+ """Randomly flips the image and detections horizontally.
+
+ The probability of flipping the image is 50%.
+
+ Args:
+ image: rank 3 float32 tensor with shape [height, width, channels].
+ boxes: (optional) rank 2 float32 tensor with shape [N, 4]
+ containing the bounding boxes.
+ Boxes are in normalized form meaning their coordinates vary
+ between [0, 1].
+ Each row is in the form of [ymin, xmin, ymax, xmax].
+ masks: (optional) rank 3 float32 tensor with shape
+ [num_instances, height, width] containing instance masks. The masks
+ are of the same height, width as the input `image`.
+ keypoints: (optional) rank 3 float32 tensor with shape
+ [num_instances, num_keypoints, 2]. The keypoints are in y-x
+ normalized coordinates.
+ keypoint_flip_permutation: rank 1 int32 tensor containing the keypoint flip
+ permutation.
+ seed: random seed
+
+ Returns:
+ image: image which is the same shape as input image.
+
+ If boxes, masks, keypoints, and keypoint_flip_permutation are not None,
+ the function also returns the following tensors.
+
+ boxes: rank 2 float32 tensor containing the bounding boxes -> [N, 4].
+ Boxes are in normalized form meaning their coordinates vary
+ between [0, 1].
+ masks: rank 3 float32 tensor with shape [num_instances, height, width]
+ containing instance masks.
+ keypoints: rank 3 float32 tensor with shape
+ [num_instances, num_keypoints, 2]
+
+ Raises:
+ ValueError: if keypoints are provided but keypoint_flip_permutation is not.
+ """
+
+ def _flip_image(image):
+ # flip image
+ image_flipped = tf.image.flip_left_right(image)
+ return image_flipped
+
+ if keypoints is not None and keypoint_flip_permutation is None:
+ raise ValueError(
+        'keypoints are provided but keypoint_flip_permutation is not provided')
+
+ with tf.name_scope('RandomHorizontalFlip'):
+ result = []
+ # random variable defining whether to do flip or not
+ do_a_flip_random = tf.greater(tf.random.uniform([], seed=seed), 0.5)
+
+ # flip image
+ image = tf.cond(
+ pred=do_a_flip_random,
+ true_fn=lambda: _flip_image(image),
+ false_fn=lambda: image)
+ result.append(image)
+
+ # flip boxes
+ if boxes is not None:
+ boxes = tf.cond(
+ pred=do_a_flip_random,
+ true_fn=lambda: _flip_boxes_left_right(boxes),
+ false_fn=lambda: boxes)
+ result.append(boxes)
+
+ # flip masks
+ if masks is not None:
+ masks = tf.cond(
+ pred=do_a_flip_random,
+ true_fn=lambda: _flip_masks_left_right(masks),
+ false_fn=lambda: masks)
+ result.append(masks)
+
+ # flip keypoints
+ if keypoints is not None and keypoint_flip_permutation is not None:
+ permutation = keypoint_flip_permutation
+ keypoints = tf.cond(
+ pred=do_a_flip_random,
+ true_fn=lambda: keypoint_flip_horizontal(keypoints, 0.5, permutation),
+ false_fn=lambda: keypoints)
+ result.append(keypoints)
+
+ return tuple(result)
+
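A typical call (editorial sketch, not part of the diff) with only an image and boxes; masks and keypoints are omitted, so the returned tuple has two elements. The shapes and box values are invented.

import tensorflow as tf

image = tf.zeros([480, 640, 3], dtype=tf.float32)
boxes = tf.constant([[0.1, 0.2, 0.5, 0.6]])
image, boxes = random_horizontal_flip(image, boxes=boxes, seed=42)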
+
+def _compute_new_static_size(image, min_dimension, max_dimension):
+ """Compute new static shape for resize_to_range method."""
+ image_shape = image.get_shape().as_list()
+ orig_height = image_shape[0]
+ orig_width = image_shape[1]
+ num_channels = image_shape[2]
+ orig_min_dim = min(orig_height, orig_width)
+ # Calculates the larger of the possible sizes
+ large_scale_factor = min_dimension / float(orig_min_dim)
+ # Scaling orig_(height|width) by large_scale_factor will make the smaller
+ # dimension equal to min_dimension, save for floating point rounding errors.
+ # For reasonably-sized images, taking the nearest integer will reliably
+ # eliminate this error.
+ large_height = int(round(orig_height * large_scale_factor))
+ large_width = int(round(orig_width * large_scale_factor))
+ large_size = [large_height, large_width]
+ if max_dimension:
+ # Calculates the smaller of the possible sizes, use that if the larger
+ # is too big.
+ orig_max_dim = max(orig_height, orig_width)
+ small_scale_factor = max_dimension / float(orig_max_dim)
+ # Scaling orig_(height|width) by small_scale_factor will make the larger
+ # dimension equal to max_dimension, save for floating point rounding
+ # errors. For reasonably-sized images, taking the nearest integer will
+ # reliably eliminate this error.
+ small_height = int(round(orig_height * small_scale_factor))
+ small_width = int(round(orig_width * small_scale_factor))
+ small_size = [small_height, small_width]
+ new_size = large_size
+ if max(large_size) > max_dimension:
+ new_size = small_size
+ else:
+ new_size = large_size
+ return tf.constant(new_size + [num_channels])
+
+
+def _compute_new_dynamic_size(image, min_dimension, max_dimension):
+ """Compute new dynamic shape for resize_to_range method."""
+ image_shape = tf.shape(input=image)
+ orig_height = tf.cast(image_shape[0], dtype=tf.float32)
+ orig_width = tf.cast(image_shape[1], dtype=tf.float32)
+ num_channels = image_shape[2]
+ orig_min_dim = tf.minimum(orig_height, orig_width)
+ # Calculates the larger of the possible sizes
+ min_dimension = tf.constant(min_dimension, dtype=tf.float32)
+ large_scale_factor = min_dimension / orig_min_dim
+ # Scaling orig_(height|width) by large_scale_factor will make the smaller
+ # dimension equal to min_dimension, save for floating point rounding errors.
+ # For reasonably-sized images, taking the nearest integer will reliably
+ # eliminate this error.
+ large_height = tf.cast(
+ tf.round(orig_height * large_scale_factor), dtype=tf.int32)
+ large_width = tf.cast(
+ tf.round(orig_width * large_scale_factor), dtype=tf.int32)
+ large_size = tf.stack([large_height, large_width])
+ if max_dimension:
+ # Calculates the smaller of the possible sizes, use that if the larger
+ # is too big.
+ orig_max_dim = tf.maximum(orig_height, orig_width)
+ max_dimension = tf.constant(max_dimension, dtype=tf.float32)
+ small_scale_factor = max_dimension / orig_max_dim
+ # Scaling orig_(height|width) by small_scale_factor will make the larger
+ # dimension equal to max_dimension, save for floating point rounding
+ # errors. For reasonably-sized images, taking the nearest integer will
+ # reliably eliminate this error.
+ small_height = tf.cast(
+ tf.round(orig_height * small_scale_factor), dtype=tf.int32)
+ small_width = tf.cast(
+ tf.round(orig_width * small_scale_factor), dtype=tf.int32)
+ small_size = tf.stack([small_height, small_width])
+ new_size = tf.cond(
+ pred=tf.cast(tf.reduce_max(input_tensor=large_size), dtype=tf.float32) >
+ max_dimension,
+ true_fn=lambda: small_size,
+ false_fn=lambda: large_size)
+ else:
+ new_size = large_size
+ return tf.stack(tf.unstack(new_size) + [num_channels])
+
+
+def resize_to_range(image,
+ masks=None,
+ min_dimension=None,
+ max_dimension=None,
+ method=tf.image.ResizeMethod.BILINEAR,
+ align_corners=False,
+ pad_to_max_dimension=False):
+ """Resizes an image so its dimensions are within the provided value.
+
+ The output size can be described by two cases:
+ 1. If the image can be rescaled so its minimum dimension is equal to the
+ provided value without the other dimension exceeding max_dimension,
+ then do so.
+ 2. Otherwise, resize so the largest dimension is equal to max_dimension.
+
+ Args:
+ image: A 3D tensor of shape [height, width, channels]
+ masks: (optional) rank 3 float32 tensor with shape
+ [num_instances, height, width] containing instance masks.
+ min_dimension: (optional) (scalar) desired size of the smaller image
+ dimension.
+ max_dimension: (optional) (scalar) maximum allowed size
+ of the larger image dimension.
+ method: (optional) interpolation method used in resizing. Defaults to
+ BILINEAR.
+ align_corners: bool. If true, exactly align all 4 corners of the input
+ and output. Defaults to False.
+ pad_to_max_dimension: Whether to resize the image and pad it with zeros
+ so the resulting image is of the spatial size
+ [max_dimension, max_dimension]. If masks are included they are padded
+ similarly.
+
+ Returns:
+ Note that the position of the resized_image_shape changes based on whether
+ masks are present.
+ resized_image: A 3D tensor of shape [new_height, new_width, channels],
+ where the image has been resized (with bilinear interpolation) so that
+ min(new_height, new_width) == min_dimension or
+ max(new_height, new_width) == max_dimension.
+ resized_masks: If masks is not None, also outputs masks. A 3D tensor of
+ shape [num_instances, new_height, new_width].
+ resized_image_shape: A 1D tensor of shape [3] containing shape of the
+ resized image.
+
+ Raises:
+ ValueError: if the image is not a 3D tensor.
+ """
+ if len(image.get_shape()) != 3:
+ raise ValueError('Image should be 3D tensor')
+
+ with tf.name_scope('ResizeToRange'):
+ if image.get_shape().is_fully_defined():
+ new_size = _compute_new_static_size(image, min_dimension, max_dimension)
+ else:
+ new_size = _compute_new_dynamic_size(image, min_dimension, max_dimension)
+ new_image = tf.image.resize(image, new_size[:-1], method=method)
+
+ if pad_to_max_dimension:
+ new_image = tf.image.pad_to_bounding_box(
+ new_image, 0, 0, max_dimension, max_dimension)
+
+ result = [new_image]
+ if masks is not None:
+ new_masks = tf.expand_dims(masks, 3)
+ new_masks = tf.image.resize(
+ new_masks,
+ new_size[:-1],
+ method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
+ new_masks = tf.squeeze(new_masks, 3)
+ if pad_to_max_dimension:
+ new_masks = tf.image.pad_to_bounding_box(
+ new_masks, 0, 0, max_dimension, max_dimension)
+ result.append(new_masks)
+
+ result.append(new_size)
+ return result
+
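A sketch (editorial, not part of the diff) of case 2 in the docstring above: scaling the short side of an invented 600x1200 image to 800 would push the long side past 1333, so the long side is capped instead.

import tensorflow as tf

image = tf.zeros([600, 1200, 3], dtype=tf.float32)
resized_image, resized_shape = resize_to_range(
    image, min_dimension=800, max_dimension=1333)
# resized_shape is roughly [666, 1333, 3].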
+
+def _copy_extra_fields(boxlist_to_copy_to, boxlist_to_copy_from):
+ """Copies the extra fields of boxlist_to_copy_from to boxlist_to_copy_to.
+
+ Args:
+ boxlist_to_copy_to: BoxList to which extra fields are copied.
+ boxlist_to_copy_from: BoxList from which fields are copied.
+
+ Returns:
+ boxlist_to_copy_to with extra fields.
+ """
+ for field in boxlist_to_copy_from.get_extra_fields():
+ boxlist_to_copy_to.add_field(field, boxlist_to_copy_from.get_field(field))
+ return boxlist_to_copy_to
+
+
+def box_list_scale(boxlist, y_scale, x_scale, scope=None):
+ """scale box coordinates in x and y dimensions.
+
+ Args:
+ boxlist: BoxList holding N boxes
+ y_scale: (float) scalar tensor
+ x_scale: (float) scalar tensor
+ scope: name scope.
+
+ Returns:
+ boxlist: BoxList holding N boxes
+ """
+ if not scope:
+ scope = 'Scale'
+ with tf.name_scope(scope):
+ y_scale = tf.cast(y_scale, tf.float32)
+ x_scale = tf.cast(x_scale, tf.float32)
+ y_min, x_min, y_max, x_max = tf.split(
+ value=boxlist.get(), num_or_size_splits=4, axis=1)
+ y_min = y_scale * y_min
+ y_max = y_scale * y_max
+ x_min = x_scale * x_min
+ x_max = x_scale * x_max
+ scaled_boxlist = box_list.BoxList(
+ tf.concat([y_min, x_min, y_max, x_max], 1))
+ return _copy_extra_fields(scaled_boxlist, boxlist)
+
+
+def keypoint_scale(keypoints, y_scale, x_scale, scope=None):
+ """Scales keypoint coordinates in x and y dimensions.
+
+ Args:
+ keypoints: a tensor of shape [num_instances, num_keypoints, 2]
+ y_scale: (float) scalar tensor
+ x_scale: (float) scalar tensor
+ scope: name scope.
+
+ Returns:
+ new_keypoints: a tensor of shape [num_instances, num_keypoints, 2]
+ """
+ if not scope:
+ scope = 'Scale'
+ with tf.name_scope(scope):
+ y_scale = tf.cast(y_scale, tf.float32)
+ x_scale = tf.cast(x_scale, tf.float32)
+ new_keypoints = keypoints * [[[y_scale, x_scale]]]
+ return new_keypoints
+
+
+def scale_boxes_to_pixel_coordinates(image, boxes, keypoints=None):
+ """Scales boxes from normalized to pixel coordinates.
+
+ Args:
+ image: A 3D float32 tensor of shape [height, width, channels].
+ boxes: A 2D float32 tensor of shape [num_boxes, 4] containing the bounding
+ boxes in normalized coordinates. Each row is of the form
+ [ymin, xmin, ymax, xmax].
+ keypoints: (optional) rank 3 float32 tensor with shape
+ [num_instances, num_keypoints, 2]. The keypoints are in y-x normalized
+ coordinates.
+
+ Returns:
+ image: unchanged input image.
+ scaled_boxes: a 2D float32 tensor of shape [num_boxes, 4] containing the
+ bounding boxes in pixel coordinates.
+ scaled_keypoints: a 3D float32 tensor with shape
+ [num_instances, num_keypoints, 2] containing the keypoints in pixel
+ coordinates.
+ """
+ boxlist = box_list.BoxList(boxes)
+ image_height = tf.shape(input=image)[0]
+ image_width = tf.shape(input=image)[1]
+ scaled_boxes = box_list_scale(boxlist, image_height, image_width).get()
+ result = [image, scaled_boxes]
+ if keypoints is not None:
+ scaled_keypoints = keypoint_scale(keypoints, image_height, image_width)
+ result.append(scaled_keypoints)
+ return tuple(result)
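A quick sketch (editorial, not part of the diff) of the normalized-to-pixel conversion for an invented 200x400 image.

import tensorflow as tf

image = tf.zeros([200, 400, 3], dtype=tf.float32)
boxes = tf.constant([[0.0, 0.25, 0.5, 0.75]])
_, scaled_boxes = scale_boxes_to_pixel_coordinates(image, boxes)
# scaled_boxes == [[0., 100., 100., 300.]]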
diff --git a/models/official/vision/detection/utils/object_detection/region_similarity_calculator.py b/models/official/vision/detection/utils/object_detection/region_similarity_calculator.py
new file mode 100644
index 0000000000000000000000000000000000000000..0af2ce495ad53c9df0f8d2eb79f7431b02ab430e
--- /dev/null
+++ b/models/official/vision/detection/utils/object_detection/region_similarity_calculator.py
@@ -0,0 +1,143 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Region Similarity Calculators for BoxLists.
+
+Region Similarity Calculators compare a pairwise measure of similarity
+between the boxes in two BoxLists.
+"""
+from abc import ABCMeta
+from abc import abstractmethod
+
+import tensorflow as tf
+
+
+def area(boxlist, scope=None):
+ """Computes area of boxes.
+
+ Args:
+ boxlist: BoxList holding N boxes
+ scope: name scope.
+
+ Returns:
+ a tensor with shape [N] representing box areas.
+ """
+ if not scope:
+ scope = 'Area'
+ with tf.name_scope(scope):
+ y_min, x_min, y_max, x_max = tf.split(
+ value=boxlist.get(), num_or_size_splits=4, axis=1)
+ return tf.squeeze((y_max - y_min) * (x_max - x_min), [1])
+
+
+def intersection(boxlist1, boxlist2, scope=None):
+ """Compute pairwise intersection areas between boxes.
+
+ Args:
+ boxlist1: BoxList holding N boxes
+ boxlist2: BoxList holding M boxes
+ scope: name scope.
+
+ Returns:
+ a tensor with shape [N, M] representing pairwise intersections
+ """
+ if not scope:
+ scope = 'Intersection'
+ with tf.name_scope(scope):
+ y_min1, x_min1, y_max1, x_max1 = tf.split(
+ value=boxlist1.get(), num_or_size_splits=4, axis=1)
+ y_min2, x_min2, y_max2, x_max2 = tf.split(
+ value=boxlist2.get(), num_or_size_splits=4, axis=1)
+ all_pairs_min_ymax = tf.minimum(y_max1, tf.transpose(a=y_max2))
+ all_pairs_max_ymin = tf.maximum(y_min1, tf.transpose(a=y_min2))
+ intersect_heights = tf.maximum(0.0, all_pairs_min_ymax - all_pairs_max_ymin)
+ all_pairs_min_xmax = tf.minimum(x_max1, tf.transpose(a=x_max2))
+ all_pairs_max_xmin = tf.maximum(x_min1, tf.transpose(a=x_min2))
+ intersect_widths = tf.maximum(0.0, all_pairs_min_xmax - all_pairs_max_xmin)
+ return intersect_heights * intersect_widths
+
+
+def iou(boxlist1, boxlist2, scope=None):
+ """Computes pairwise intersection-over-union between box collections.
+
+ Args:
+ boxlist1: BoxList holding N boxes
+ boxlist2: BoxList holding M boxes
+ scope: name scope.
+
+ Returns:
+ a tensor with shape [N, M] representing pairwise iou scores.
+ """
+ if not scope:
+ scope = 'IOU'
+ with tf.name_scope(scope):
+ intersections = intersection(boxlist1, boxlist2)
+ areas1 = area(boxlist1)
+ areas2 = area(boxlist2)
+ unions = (
+ tf.expand_dims(areas1, 1) + tf.expand_dims(areas2, 0) - intersections)
+ return tf.where(
+ tf.equal(intersections, 0.0), tf.zeros_like(intersections),
+ tf.truediv(intersections, unions))
+
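A two-box sketch (editorial, not part of the diff) with invented coordinates; BoxList is assumed to come from the sibling box_list module used elsewhere in this directory.

import tensorflow as tf
from official.vision.detection.utils.object_detection import box_list

b1 = box_list.BoxList(tf.constant([[0.0, 0.0, 0.5, 0.5]]))
b2 = box_list.BoxList(tf.constant([[0.25, 0.25, 0.75, 0.75]]))
overlap = iou(b1, b2)
# Intersection 0.0625, union 0.4375, so overlap ~ [[0.1429]].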
+
+class RegionSimilarityCalculator(object):
+ """Abstract base class for region similarity calculator."""
+ __metaclass__ = ABCMeta
+
+ def compare(self, boxlist1, boxlist2, scope=None):
+ """Computes matrix of pairwise similarity between BoxLists.
+
+    This op (to be overridden) computes a measure of pairwise similarity between
+ the boxes in the given BoxLists. Higher values indicate more similarity.
+
+ Note that this method simply measures similarity and does not explicitly
+ perform a matching.
+
+ Args:
+ boxlist1: BoxList holding N boxes.
+ boxlist2: BoxList holding M boxes.
+ scope: Op scope name. Defaults to 'Compare' if None.
+
+ Returns:
+ a (float32) tensor of shape [N, M] with pairwise similarity score.
+ """
+ if not scope:
+ scope = 'Compare'
+ with tf.name_scope(scope) as scope:
+ return self._compare(boxlist1, boxlist2)
+
+ @abstractmethod
+ def _compare(self, boxlist1, boxlist2):
+ pass
+
+
+class IouSimilarity(RegionSimilarityCalculator):
+ """Class to compute similarity based on Intersection over Union (IOU) metric.
+
+ This class computes pairwise similarity between two BoxLists based on IOU.
+ """
+
+ def _compare(self, boxlist1, boxlist2):
+ """Compute pairwise IOU similarity between the two BoxLists.
+
+ Args:
+ boxlist1: BoxList holding N boxes.
+ boxlist2: BoxList holding M boxes.
+
+ Returns:
+ A tensor with shape [N, M] representing pairwise iou scores.
+ """
+ return iou(boxlist1, boxlist2)
diff --git a/models/official/vision/detection/utils/object_detection/shape_utils.py b/models/official/vision/detection/utils/object_detection/shape_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e30b62b7acc15b7f9f98b6c27b1a22efaf2998a8
--- /dev/null
+++ b/models/official/vision/detection/utils/object_detection/shape_utils.py
@@ -0,0 +1,112 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utils used to manipulate tensor shapes."""
+
+import tensorflow as tf
+
+
+def assert_shape_equal(shape_a, shape_b):
+ """Asserts that shape_a and shape_b are equal.
+
+ If the shapes are static, raises a ValueError when the shapes
+ mismatch.
+
+ If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
+ mismatch.
+
+ Args:
+ shape_a: a list containing shape of the first tensor.
+ shape_b: a list containing shape of the second tensor.
+
+ Returns:
+ Either a tf.no_op() when shapes are all static and a tf.assert_equal() op
+ when the shapes are dynamic.
+
+ Raises:
+ ValueError: When shapes are both static and unequal.
+ """
+ if (all(isinstance(dim, int) for dim in shape_a) and
+ all(isinstance(dim, int) for dim in shape_b)):
+ if shape_a != shape_b:
+ raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b))
+    else:
+      return tf.no_op()
+ else:
+ return tf.assert_equal(shape_a, shape_b)
+
+
+def combined_static_and_dynamic_shape(tensor):
+ """Returns a list containing static and dynamic values for the dimensions.
+
+ Returns a list of static and dynamic values for shape dimensions. This is
+ useful to preserve static shapes when available in reshape operation.
+
+ Args:
+ tensor: A tensor of any type.
+
+ Returns:
+ A list of size tensor.shape.ndims containing integers or a scalar tensor.
+ """
+ static_tensor_shape = tensor.shape.as_list()
+ dynamic_tensor_shape = tf.shape(input=tensor)
+ combined_shape = []
+ for index, dim in enumerate(static_tensor_shape):
+ if dim is not None:
+ combined_shape.append(dim)
+ else:
+ combined_shape.append(dynamic_tensor_shape[index])
+ return combined_shape
+
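A sketch (editorial, not part of the diff) of how static dimensions survive while unknown ones fall back to tf.shape, using a tf.function with a partially unknown signature.

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32)])
def batch_and_width(boxes):
  combined = combined_static_and_dynamic_shape(boxes)
  # combined == [<scalar int32 tensor>, 4]: the batch dimension stays dynamic,
  # while the statically known 4 is kept as a Python int.
  return combined[0], combined[1]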
+
+def pad_or_clip_nd(tensor, output_shape):
+ """Pad or Clip given tensor to the output shape.
+
+ Args:
+ tensor: Input tensor to pad or clip.
+ output_shape: A list of integers / scalar tensors (or None for dynamic dim)
+ representing the size to pad or clip each dimension of the input tensor.
+
+ Returns:
+ Input tensor padded and clipped to the output shape.
+ """
+ tensor_shape = tf.shape(input=tensor)
+ clip_size = [
+ tf.where(tensor_shape[i] - shape > 0, shape, -1)
+ if shape is not None else -1 for i, shape in enumerate(output_shape)
+ ]
+ clipped_tensor = tf.slice(
+ tensor,
+ begin=tf.zeros(len(clip_size), dtype=tf.int32),
+ size=clip_size)
+
+ # Pad tensor if the shape of clipped tensor is smaller than the expected
+ # shape.
+ clipped_tensor_shape = tf.shape(input=clipped_tensor)
+ trailing_paddings = [
+ shape - clipped_tensor_shape[i] if shape is not None else 0
+ for i, shape in enumerate(output_shape)
+ ]
+ paddings = tf.stack(
+ [
+ tf.zeros(len(trailing_paddings), dtype=tf.int32),
+ trailing_paddings
+ ],
+ axis=1)
+ padded_tensor = tf.pad(tensor=clipped_tensor, paddings=paddings)
+ output_static_shape = [
+ dim if not isinstance(dim, tf.Tensor) else None for dim in output_shape
+ ]
+ padded_tensor.set_shape(output_static_shape)
+ return padded_tensor
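A small sketch (editorial, not part of the diff) that pads the first dimension and clips the second of an invented 2x3 tensor.

import tensorflow as tf

t = tf.reshape(tf.range(6, dtype=tf.float32), [2, 3])
out = pad_or_clip_nd(t, [3, 2])
# out == [[0., 1.], [3., 4.], [0., 0.]]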
diff --git a/models/official/vision/detection/utils/object_detection/target_assigner.py b/models/official/vision/detection/utils/object_detection/target_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..c04448efb052b45da65366b26e7d773b62015773
--- /dev/null
+++ b/models/official/vision/detection/utils/object_detection/target_assigner.py
@@ -0,0 +1,314 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Base target assigner module.
+
+The job of a TargetAssigner is, for a given set of anchors (bounding boxes) and
+groundtruth detections (bounding boxes), to assign classification and regression
+targets to each anchor as well as weights to each anchor (specifying, e.g.,
+which anchors should not contribute to training loss).
+
+It assigns classification/regression targets by performing the following steps:
+1) Computing pairwise similarity between anchors and groundtruth boxes using a
+ provided RegionSimilarity Calculator
+2) Computing a matching based on the similarity matrix using a provided Matcher
+3) Assigning regression targets based on the matching and a provided BoxCoder
+4) Assigning classification targets based on the matching and groundtruth labels
+
+Note that TargetAssigners only operate on detections from a single
+image at a time, so any logic for applying a TargetAssigner to multiple
+images must be handled externally.
+"""
+import tensorflow as tf
+
+from official.vision.detection.utils.object_detection import box_list
+from official.vision.detection.utils.object_detection import shape_utils
+
+
+KEYPOINTS_FIELD_NAME = 'keypoints'
+
+
+class TargetAssigner(object):
+ """Target assigner to compute classification and regression targets."""
+
+ def __init__(self, similarity_calc, matcher, box_coder,
+ negative_class_weight=1.0, unmatched_cls_target=None):
+ """Construct Object Detection Target Assigner.
+
+ Args:
+ similarity_calc: a RegionSimilarityCalculator
+ matcher: Matcher used to match groundtruth to anchors.
+ box_coder: BoxCoder used to encode matching groundtruth boxes with
+ respect to anchors.
+      negative_class_weight: classification weight to be associated with negative
+ anchors (default: 1.0). The weight must be in [0., 1.].
+ unmatched_cls_target: a float32 tensor with shape [d_1, d_2, ..., d_k]
+ which is consistent with the classification target for each
+ anchor (and can be empty for scalar targets). This shape must thus be
+ compatible with the groundtruth labels that are passed to the "assign"
+ function (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]).
+ If set to None, unmatched_cls_target is set to be [0] for each anchor.
+
+ Raises:
+ ValueError: if similarity_calc is not a RegionSimilarityCalculator or
+ if matcher is not a Matcher or if box_coder is not a BoxCoder
+ """
+ self._similarity_calc = similarity_calc
+ self._matcher = matcher
+ self._box_coder = box_coder
+ self._negative_class_weight = negative_class_weight
+ if unmatched_cls_target is None:
+ self._unmatched_cls_target = tf.constant([0], tf.float32)
+ else:
+ self._unmatched_cls_target = unmatched_cls_target
+
+ @property
+ def box_coder(self):
+ return self._box_coder
+
+ def assign(self, anchors, groundtruth_boxes, groundtruth_labels=None,
+ groundtruth_weights=None, **params):
+ """Assign classification and regression targets to each anchor.
+
+ For a given set of anchors and groundtruth detections, match anchors
+ to groundtruth_boxes and assign classification and regression targets to
+ each anchor as well as weights based on the resulting match (specifying,
+ e.g., which anchors should not contribute to training loss).
+
+ Anchors that are not matched to anything are given a classification target
+ of self._unmatched_cls_target which can be specified via the constructor.
+
+ Args:
+ anchors: a BoxList representing N anchors
+ groundtruth_boxes: a BoxList representing M groundtruth boxes
+ groundtruth_labels: a tensor of shape [M, d_1, ... d_k]
+ with labels for each of the ground_truth boxes. The subshape
+ [d_1, ... d_k] can be empty (corresponding to scalar inputs). When set
+ to None, groundtruth_labels assumes a binary problem where all
+ ground_truth boxes get a positive label (of 1).
+ groundtruth_weights: a float tensor of shape [M] indicating the weight to
+        assign to all anchors matched to a particular groundtruth box. The weights
+ must be in [0., 1.]. If None, all weights are set to 1.
+ **params: Additional keyword arguments for specific implementations of
+ the Matcher.
+
+ Returns:
+ cls_targets: a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k],
+ where the subshape [d_1, ..., d_k] is compatible with groundtruth_labels
+ which has shape [num_gt_boxes, d_1, d_2, ... d_k].
+ cls_weights: a float32 tensor with shape [num_anchors]
+ reg_targets: a float32 tensor with shape [num_anchors, box_code_dimension]
+ reg_weights: a float32 tensor with shape [num_anchors]
+ match: a matcher.Match object encoding the match between anchors and
+ groundtruth boxes, with rows corresponding to groundtruth boxes
+ and columns corresponding to anchors.
+
+ Raises:
+ ValueError: if anchors or groundtruth_boxes are not of type
+ box_list.BoxList
+ """
+ if not isinstance(anchors, box_list.BoxList):
+      raise ValueError('anchors must be a BoxList')
+ if not isinstance(groundtruth_boxes, box_list.BoxList):
+      raise ValueError('groundtruth_boxes must be a BoxList')
+
+ if groundtruth_labels is None:
+ groundtruth_labels = tf.ones(tf.expand_dims(groundtruth_boxes.num_boxes(),
+ 0))
+ groundtruth_labels = tf.expand_dims(groundtruth_labels, -1)
+ unmatched_shape_assert = shape_utils.assert_shape_equal(
+ shape_utils.combined_static_and_dynamic_shape(groundtruth_labels)[1:],
+ shape_utils.combined_static_and_dynamic_shape(
+ self._unmatched_cls_target))
+ labels_and_box_shapes_assert = shape_utils.assert_shape_equal(
+ shape_utils.combined_static_and_dynamic_shape(
+ groundtruth_labels)[:1],
+ shape_utils.combined_static_and_dynamic_shape(
+ groundtruth_boxes.get())[:1])
+
+ if groundtruth_weights is None:
+ num_gt_boxes = groundtruth_boxes.num_boxes_static()
+ if not num_gt_boxes:
+ num_gt_boxes = groundtruth_boxes.num_boxes()
+ groundtruth_weights = tf.ones([num_gt_boxes], dtype=tf.float32)
+ with tf.control_dependencies(
+ [unmatched_shape_assert, labels_and_box_shapes_assert]):
+ match_quality_matrix = self._similarity_calc.compare(groundtruth_boxes,
+ anchors)
+ match = self._matcher.match(match_quality_matrix, **params)
+ reg_targets = self._create_regression_targets(anchors,
+ groundtruth_boxes,
+ match)
+ cls_targets = self._create_classification_targets(groundtruth_labels,
+ match)
+ reg_weights = self._create_regression_weights(match, groundtruth_weights)
+ cls_weights = self._create_classification_weights(match,
+ groundtruth_weights)
+
+ num_anchors = anchors.num_boxes_static()
+ if num_anchors is not None:
+ reg_targets = self._reset_target_shape(reg_targets, num_anchors)
+ cls_targets = self._reset_target_shape(cls_targets, num_anchors)
+ reg_weights = self._reset_target_shape(reg_weights, num_anchors)
+ cls_weights = self._reset_target_shape(cls_weights, num_anchors)
+
+ return cls_targets, cls_weights, reg_targets, reg_weights, match
+
+ def _reset_target_shape(self, target, num_anchors):
+ """Sets the static shape of the target.
+
+ Args:
+ target: the target tensor. Its first dimension will be overwritten.
+ num_anchors: the number of anchors, which is used to override the target's
+ first dimension.
+
+ Returns:
+ A tensor with the shape info filled in.
+ """
+ target_shape = target.get_shape().as_list()
+ target_shape[0] = num_anchors
+ target.set_shape(target_shape)
+ return target
+
+ def _create_regression_targets(self, anchors, groundtruth_boxes, match):
+ """Returns a regression target for each anchor.
+
+ Args:
+ anchors: a BoxList representing N anchors
+ groundtruth_boxes: a BoxList representing M groundtruth_boxes
+ match: a matcher.Match object
+
+ Returns:
+ reg_targets: a float32 tensor with shape [N, box_code_dimension]
+ """
+ matched_gt_boxes = match.gather_based_on_match(
+ groundtruth_boxes.get(),
+ unmatched_value=tf.zeros(4),
+ ignored_value=tf.zeros(4))
+ matched_gt_boxlist = box_list.BoxList(matched_gt_boxes)
+ if groundtruth_boxes.has_field(KEYPOINTS_FIELD_NAME):
+ groundtruth_keypoints = groundtruth_boxes.get_field(KEYPOINTS_FIELD_NAME)
+ matched_keypoints = match.gather_based_on_match(
+ groundtruth_keypoints,
+ unmatched_value=tf.zeros(groundtruth_keypoints.get_shape()[1:]),
+ ignored_value=tf.zeros(groundtruth_keypoints.get_shape()[1:]))
+ matched_gt_boxlist.add_field(KEYPOINTS_FIELD_NAME, matched_keypoints)
+ matched_reg_targets = self._box_coder.encode(matched_gt_boxlist, anchors)
+ match_results_shape = shape_utils.combined_static_and_dynamic_shape(
+ match.match_results)
+
+ # Zero out the unmatched and ignored regression targets.
+ unmatched_ignored_reg_targets = tf.tile(
+ self._default_regression_target(), [match_results_shape[0], 1])
+ matched_anchors_mask = match.matched_column_indicator()
+ # To broadcast matched_anchors_mask to the same shape as
+ # matched_reg_targets.
+ matched_anchors_mask = tf.tile(
+ tf.expand_dims(matched_anchors_mask, 1),
+ [1, tf.shape(matched_reg_targets)[1]])
+ reg_targets = tf.where(matched_anchors_mask, matched_reg_targets,
+ unmatched_ignored_reg_targets)
+ return reg_targets
+
+ def _default_regression_target(self):
+ """Returns the default target for anchors to regress to.
+
+ Default regression targets are set to zero (though in
+ this implementation what these targets are set to should
+ not matter as the regression weight of any box set to
+ regress to the default target is zero).
+
+ Returns:
+ default_target: a float32 tensor with shape [1, box_code_dimension]
+ """
+ return tf.constant([self._box_coder.code_size*[0]], tf.float32)
+
+ def _create_classification_targets(self, groundtruth_labels, match):
+ """Create classification targets for each anchor.
+
+    Assigns a classification target for each anchor to the matching
+    groundtruth label provided by match. Anchors that are not matched to
+    anything are given the target self._unmatched_cls_target.
+
+ Args:
+ groundtruth_labels: a tensor of shape [num_gt_boxes, d_1, ... d_k]
+ with labels for each of the ground_truth boxes. The subshape
+ [d_1, ... d_k] can be empty (corresponding to scalar labels).
+ match: a matcher.Match object that provides a matching between anchors
+ and groundtruth boxes.
+
+ Returns:
+ a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k], where the
+ subshape [d_1, ..., d_k] is compatible with groundtruth_labels which has
+ shape [num_gt_boxes, d_1, d_2, ... d_k].
+ """
+ return match.gather_based_on_match(
+ groundtruth_labels,
+ unmatched_value=self._unmatched_cls_target,
+ ignored_value=self._unmatched_cls_target)
+
+ def _create_regression_weights(self, match, groundtruth_weights):
+ """Set regression weight for each anchor.
+
+ Only positive anchors are set to contribute to the regression loss, so this
+ method returns a weight of 1 for every positive anchor and 0 for every
+ negative anchor.
+
+ Args:
+ match: a matcher.Match object that provides a matching between anchors
+ and groundtruth boxes.
+ groundtruth_weights: a float tensor of shape [M] indicating the weight to
+        assign to all anchors matched to a particular groundtruth box.
+
+ Returns:
+ a float32 tensor with shape [num_anchors] representing regression weights.
+ """
+ return match.gather_based_on_match(
+ groundtruth_weights, ignored_value=0., unmatched_value=0.)
+
+ def _create_classification_weights(self,
+ match,
+ groundtruth_weights):
+ """Create classification weights for each anchor.
+
+ Positive (matched) anchors are associated with a weight of
+ positive_class_weight and negative (unmatched) anchors are associated with
+ a weight of negative_class_weight. When anchors are ignored, weights are set
+ to zero. By default, both positive/negative weights are set to 1.0,
+ but they can be adjusted to handle class imbalance (which is almost always
+ the case in object detection).
+
+ Args:
+ match: a matcher.Match object that provides a matching between anchors
+ and groundtruth boxes.
+ groundtruth_weights: a float tensor of shape [M] indicating the weight to
+        assign to all anchors matched to a particular groundtruth box.
+
+ Returns:
+ a float32 tensor with shape [num_anchors] representing classification
+ weights.
+ """
+ return match.gather_based_on_match(
+ groundtruth_weights,
+ ignored_value=0.,
+ unmatched_value=self._negative_class_weight)
+
+ def get_box_coder(self):
+ """Get BoxCoder of this TargetAssigner.
+
+ Returns:
+ BoxCoder object.
+ """
+ return self._box_coder
diff --git a/models/official/vision/detection/utils/object_detection/visualization_utils.py b/models/official/vision/detection/utils/object_detection/visualization_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..db4af8089df673cd5c57c4a020b5d7e8f03846c9
--- /dev/null
+++ b/models/official/vision/detection/utils/object_detection/visualization_utils.py
@@ -0,0 +1,733 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A set of functions that are used for visualization.
+
+These functions often receive an image and perform some visualization on it.
+The functions do not return a value; instead, they modify the image in place.
+
+"""
+import collections
+import functools
+from absl import logging
+# Set headless-friendly backend.
+import matplotlib; matplotlib.use('Agg') # pylint: disable=multiple-statements
+import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top
+import numpy as np
+import PIL.Image as Image
+import PIL.ImageColor as ImageColor
+import PIL.ImageDraw as ImageDraw
+import PIL.ImageFont as ImageFont
+import six
+import tensorflow as tf
+
+from official.vision.detection.utils import box_utils
+from official.vision.detection.utils.object_detection import shape_utils
+
+
+_TITLE_LEFT_MARGIN = 10
+_TITLE_TOP_MARGIN = 10
+STANDARD_COLORS = [
+ 'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque',
+ 'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite',
+ 'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan',
+ 'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange',
+ 'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet',
+ 'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite',
+ 'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod',
+ 'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki',
+ 'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue',
+ 'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey',
+ 'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue',
+ 'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime',
+ 'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid',
+ 'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen',
+ 'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin',
+ 'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed',
+ 'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed',
+ 'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple',
+ 'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown',
+ 'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue',
+ 'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow',
+ 'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White',
+ 'WhiteSmoke', 'Yellow', 'YellowGreen'
+]
+
+
+def save_image_array_as_png(image, output_path):
+ """Saves an image (represented as a numpy array) to PNG.
+
+ Args:
+ image: a numpy array with shape [height, width, 3].
+ output_path: path to which image should be written.
+ """
+ image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
+ with tf.io.gfile.GFile(output_path, 'w') as fid:
+ image_pil.save(fid, 'PNG')
+
+
+def encode_image_array_as_png_str(image):
+ """Encodes a numpy array into a PNG string.
+
+ Args:
+ image: a numpy array with shape [height, width, 3].
+
+ Returns:
+ PNG encoded image string.
+ """
+ image_pil = Image.fromarray(np.uint8(image))
+ output = six.BytesIO()
+ image_pil.save(output, format='PNG')
+ png_string = output.getvalue()
+ output.close()
+ return png_string
+
+
+def visualize_images_with_bounding_boxes(images, box_outputs, step,
+ summary_writer):
+ """Records subset of evaluation images with bounding boxes."""
+ if not isinstance(images, list):
+ logging.warning('visualize_images_with_bounding_boxes expects list of '
+ 'images but received type: %s and value: %s',
+ type(images), images)
+ return
+
+ image_shape = tf.shape(images[0])
+ image_height = tf.cast(image_shape[0], tf.float32)
+ image_width = tf.cast(image_shape[1], tf.float32)
+ normalized_boxes = box_utils.normalize_boxes(box_outputs,
+ [image_height, image_width])
+
+ bounding_box_color = tf.constant([[1.0, 1.0, 0.0, 1.0]])
+ image_summary = tf.image.draw_bounding_boxes(
+ tf.cast(images, tf.float32), normalized_boxes, bounding_box_color)
+ with summary_writer.as_default():
+ tf.summary.image('bounding_box_summary', image_summary, step=step)
+ summary_writer.flush()
+
+
+def draw_bounding_box_on_image_array(image,
+ ymin,
+ xmin,
+ ymax,
+ xmax,
+ color='red',
+ thickness=4,
+ display_str_list=(),
+ use_normalized_coordinates=True):
+ """Adds a bounding box to an image (numpy array).
+
+ Bounding box coordinates can be specified in either absolute (pixel) or
+ normalized coordinates by setting the use_normalized_coordinates argument.
+
+ Args:
+ image: a numpy array with shape [height, width, 3].
+ ymin: ymin of bounding box.
+ xmin: xmin of bounding box.
+ ymax: ymax of bounding box.
+ xmax: xmax of bounding box.
+ color: color to draw bounding box. Default is red.
+ thickness: line thickness. Default value is 4.
+ display_str_list: list of strings to display in box
+ (each to be shown on its own line).
+ use_normalized_coordinates: If True (default), treat coordinates
+ ymin, xmin, ymax, xmax as relative to the image. Otherwise treat
+ coordinates as absolute.
+ """
+ image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
+ draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color,
+ thickness, display_str_list,
+ use_normalized_coordinates)
+ np.copyto(image, np.array(image_pil))
+
+
+def draw_bounding_box_on_image(image,
+ ymin,
+ xmin,
+ ymax,
+ xmax,
+ color='red',
+ thickness=4,
+ display_str_list=(),
+ use_normalized_coordinates=True):
+ """Adds a bounding box to an image.
+
+ Bounding box coordinates can be specified in either absolute (pixel) or
+ normalized coordinates by setting the use_normalized_coordinates argument.
+
+ Each string in display_str_list is displayed on a separate line above the
+ bounding box in black text on a rectangle filled with the input 'color'.
+ If the top of the bounding box extends to the edge of the image, the strings
+ are displayed below the bounding box.
+
+ Args:
+ image: a PIL.Image object.
+ ymin: ymin of bounding box.
+ xmin: xmin of bounding box.
+ ymax: ymax of bounding box.
+ xmax: xmax of bounding box.
+ color: color to draw bounding box. Default is red.
+ thickness: line thickness. Default value is 4.
+ display_str_list: list of strings to display in box
+ (each to be shown on its own line).
+ use_normalized_coordinates: If True (default), treat coordinates
+ ymin, xmin, ymax, xmax as relative to the image. Otherwise treat
+ coordinates as absolute.
+ """
+ draw = ImageDraw.Draw(image)
+ im_width, im_height = image.size
+ if use_normalized_coordinates:
+ (left, right, top, bottom) = (xmin * im_width, xmax * im_width,
+ ymin * im_height, ymax * im_height)
+ else:
+ (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
+ draw.line([(left, top), (left, bottom), (right, bottom),
+ (right, top), (left, top)], width=thickness, fill=color)
+ try:
+ font = ImageFont.truetype('arial.ttf', 24)
+ except IOError:
+ font = ImageFont.load_default()
+
+ # If the total height of the display strings added to the top of the bounding
+ # box exceeds the top of the image, stack the strings below the bounding box
+ # instead of above.
+ display_str_heights = [font.getsize(ds)[1] for ds in display_str_list]
+ # Each display_str has a top and bottom margin of 0.05x.
+ total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)
+
+ if top > total_display_str_height:
+ text_bottom = top
+ else:
+ text_bottom = bottom + total_display_str_height
+ # Reverse list and print from bottom to top.
+ for display_str in display_str_list[::-1]:
+ text_width, text_height = font.getsize(display_str)
+ margin = np.ceil(0.05 * text_height)
+ draw.rectangle(
+ [(left, text_bottom - text_height - 2 * margin), (left + text_width,
+ text_bottom)],
+ fill=color)
+ draw.text(
+ (left + margin, text_bottom - text_height - margin),
+ display_str,
+ fill='black',
+ font=font)
+    text_bottom -= text_height + 2 * margin
+
+
+def draw_bounding_boxes_on_image_array(image,
+ boxes,
+ color='red',
+ thickness=4,
+ display_str_list_list=()):
+ """Draws bounding boxes on image (numpy array).
+
+ Args:
+ image: a numpy array object.
+ boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax).
+ The coordinates are in normalized format between [0, 1].
+ color: color to draw bounding box. Default is red.
+ thickness: line thickness. Default value is 4.
+    display_str_list_list: a list of lists of strings, one list per bounding
+                           box. A list is used for each box because a single
+                           box may carry multiple labels.
+
+ Raises:
+ ValueError: if boxes is not a [N, 4] array
+ """
+ image_pil = Image.fromarray(image)
+ draw_bounding_boxes_on_image(image_pil, boxes, color, thickness,
+ display_str_list_list)
+ np.copyto(image, np.array(image_pil))
+
+
+def draw_bounding_boxes_on_image(image,
+ boxes,
+ color='red',
+ thickness=4,
+ display_str_list_list=()):
+ """Draws bounding boxes on image.
+
+ Args:
+ image: a PIL.Image object.
+ boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax).
+ The coordinates are in normalized format between [0, 1].
+ color: color to draw bounding box. Default is red.
+ thickness: line thickness. Default value is 4.
+    display_str_list_list: a list of lists of strings, one list per bounding
+                           box. A list is used for each box because a single
+                           box may carry multiple labels.
+
+ Raises:
+ ValueError: if boxes is not a [N, 4] array
+ """
+ boxes_shape = boxes.shape
+ if not boxes_shape:
+ return
+ if len(boxes_shape) != 2 or boxes_shape[1] != 4:
+ raise ValueError('Input must be of size [N, 4]')
+ for i in range(boxes_shape[0]):
+ display_str_list = ()
+ if display_str_list_list:
+ display_str_list = display_str_list_list[i]
+ draw_bounding_box_on_image(image, boxes[i, 0], boxes[i, 1], boxes[i, 2],
+ boxes[i, 3], color, thickness, display_str_list)
+
+
+def _visualize_boxes(image, boxes, classes, scores, category_index, **kwargs):
+ return visualize_boxes_and_labels_on_image_array(
+ image, boxes, classes, scores, category_index=category_index, **kwargs)
+
+
+def _visualize_boxes_and_masks(image, boxes, classes, scores, masks,
+ category_index, **kwargs):
+ return visualize_boxes_and_labels_on_image_array(
+ image,
+ boxes,
+ classes,
+ scores,
+ category_index=category_index,
+ instance_masks=masks,
+ **kwargs)
+
+
+def _visualize_boxes_and_keypoints(image, boxes, classes, scores, keypoints,
+ category_index, **kwargs):
+ return visualize_boxes_and_labels_on_image_array(
+ image,
+ boxes,
+ classes,
+ scores,
+ category_index=category_index,
+ keypoints=keypoints,
+ **kwargs)
+
+
+def _visualize_boxes_and_masks_and_keypoints(
+ image, boxes, classes, scores, masks, keypoints, category_index, **kwargs):
+ return visualize_boxes_and_labels_on_image_array(
+ image,
+ boxes,
+ classes,
+ scores,
+ category_index=category_index,
+ instance_masks=masks,
+ keypoints=keypoints,
+ **kwargs)
+
+
+def _resize_original_image(image, image_shape):
+ image = tf.expand_dims(image, 0)
+ image = tf.image.resize(
+ image, image_shape, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
+ return tf.cast(tf.squeeze(image, 0), tf.uint8)
+
+
+def draw_bounding_boxes_on_image_tensors(images,
+ boxes,
+ classes,
+ scores,
+ category_index,
+ original_image_spatial_shape=None,
+ true_image_shape=None,
+ instance_masks=None,
+ keypoints=None,
+ max_boxes_to_draw=20,
+ min_score_thresh=0.2,
+ use_normalized_coordinates=True):
+ """Draws bounding boxes, masks, and keypoints on batch of image tensors.
+
+ Args:
+ images: A 4D uint8 image tensor of shape [N, H, W, C]. If C > 3, additional
+ channels will be ignored. If C = 1, then we convert the images to RGB
+ images.
+ boxes: [N, max_detections, 4] float32 tensor of detection boxes.
+ classes: [N, max_detections] int tensor of detection classes. Note that
+ classes are 1-indexed.
+ scores: [N, max_detections] float32 tensor of detection scores.
+ category_index: a dict that maps integer ids to category dicts. e.g.
+      {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...}
+ original_image_spatial_shape: [N, 2] tensor containing the spatial size of
+ the original image.
+ true_image_shape: [N, 3] tensor containing the spatial size of unpadded
+ original_image.
+ instance_masks: A 4D uint8 tensor of shape [N, max_detection, H, W] with
+ instance masks.
+ keypoints: A 4D float32 tensor of shape [N, max_detection, num_keypoints, 2]
+ with keypoints.
+ max_boxes_to_draw: Maximum number of boxes to draw on an image. Default 20.
+ min_score_thresh: Minimum score threshold for visualization. Default 0.2.
+    use_normalized_coordinates: Whether to assume boxes and keypoints are in
+      normalized coordinates (as opposed to absolute coordinates).
+ Default is True.
+
+ Returns:
+ 4D image tensor of type uint8, with boxes drawn on top.
+ """
+ # Additional channels are being ignored.
+ if images.shape[3] > 3:
+ images = images[:, :, :, 0:3]
+ elif images.shape[3] == 1:
+ images = tf.image.grayscale_to_rgb(images)
+ visualization_keyword_args = {
+ 'use_normalized_coordinates': use_normalized_coordinates,
+ 'max_boxes_to_draw': max_boxes_to_draw,
+ 'min_score_thresh': min_score_thresh,
+ 'agnostic_mode': False,
+ 'line_thickness': 4
+ }
+ if true_image_shape is None:
+ true_shapes = tf.constant(-1, shape=[images.shape.as_list()[0], 3])
+ else:
+ true_shapes = true_image_shape
+ if original_image_spatial_shape is None:
+ original_shapes = tf.constant(-1, shape=[images.shape.as_list()[0], 2])
+ else:
+ original_shapes = original_image_spatial_shape
+
+ if instance_masks is not None and keypoints is None:
+ visualize_boxes_fn = functools.partial(
+ _visualize_boxes_and_masks,
+ category_index=category_index,
+ **visualization_keyword_args)
+ elems = [
+ true_shapes, original_shapes, images, boxes, classes, scores,
+ instance_masks
+ ]
+ elif instance_masks is None and keypoints is not None:
+ visualize_boxes_fn = functools.partial(
+ _visualize_boxes_and_keypoints,
+ category_index=category_index,
+ **visualization_keyword_args)
+ elems = [
+ true_shapes, original_shapes, images, boxes, classes, scores, keypoints
+ ]
+ elif instance_masks is not None and keypoints is not None:
+ visualize_boxes_fn = functools.partial(
+ _visualize_boxes_and_masks_and_keypoints,
+ category_index=category_index,
+ **visualization_keyword_args)
+ elems = [
+ true_shapes, original_shapes, images, boxes, classes, scores,
+ instance_masks, keypoints
+ ]
+ else:
+ visualize_boxes_fn = functools.partial(
+ _visualize_boxes,
+ category_index=category_index,
+ **visualization_keyword_args)
+ elems = [
+ true_shapes, original_shapes, images, boxes, classes, scores
+ ]
+
+ def draw_boxes(image_and_detections):
+ """Draws boxes on image."""
+ true_shape = image_and_detections[0]
+ original_shape = image_and_detections[1]
+ if true_image_shape is not None:
+ image = shape_utils.pad_or_clip_nd(
+ image_and_detections[2], [true_shape[0], true_shape[1], 3])
+ if original_image_spatial_shape is not None:
+ image_and_detections[2] = _resize_original_image(image, original_shape)
+
+ image_with_boxes = tf.compat.v1.py_func(visualize_boxes_fn,
+ image_and_detections[2:], tf.uint8)
+ return image_with_boxes
+
+ images = tf.map_fn(draw_boxes, elems, dtype=tf.uint8, back_prop=False)
+ return images
+
+
+def draw_keypoints_on_image_array(image,
+ keypoints,
+ color='red',
+ radius=2,
+ use_normalized_coordinates=True):
+ """Draws keypoints on an image (numpy array).
+
+ Args:
+ image: a numpy array with shape [height, width, 3].
+ keypoints: a numpy array with shape [num_keypoints, 2].
+ color: color to draw the keypoints with. Default is red.
+ radius: keypoint radius. Default value is 2.
+ use_normalized_coordinates: if True (default), treat keypoint values as
+ relative to the image. Otherwise treat them as absolute.
+ """
+ image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
+ draw_keypoints_on_image(image_pil, keypoints, color, radius,
+ use_normalized_coordinates)
+ np.copyto(image, np.array(image_pil))
+
+
+def draw_keypoints_on_image(image,
+ keypoints,
+ color='red',
+ radius=2,
+ use_normalized_coordinates=True):
+ """Draws keypoints on an image.
+
+ Args:
+ image: a PIL.Image object.
+ keypoints: a numpy array with shape [num_keypoints, 2].
+ color: color to draw the keypoints with. Default is red.
+ radius: keypoint radius. Default value is 2.
+ use_normalized_coordinates: if True (default), treat keypoint values as
+ relative to the image. Otherwise treat them as absolute.
+ """
+ draw = ImageDraw.Draw(image)
+ im_width, im_height = image.size
+ keypoints_x = [k[1] for k in keypoints]
+ keypoints_y = [k[0] for k in keypoints]
+ if use_normalized_coordinates:
+ keypoints_x = tuple([im_width * x for x in keypoints_x])
+ keypoints_y = tuple([im_height * y for y in keypoints_y])
+ for keypoint_x, keypoint_y in zip(keypoints_x, keypoints_y):
+ draw.ellipse([(keypoint_x - radius, keypoint_y - radius),
+ (keypoint_x + radius, keypoint_y + radius)],
+ outline=color, fill=color)
+
+
+def draw_mask_on_image_array(image, mask, color='red', alpha=0.4):
+ """Draws mask on an image.
+
+ Args:
+    image: uint8 numpy array with shape (img_height, img_width, 3)
+    mask: a uint8 numpy array of shape (img_height, img_width) with
+      values of either 0 or 1.
+    color: color to draw the mask with. Default is red.
+ alpha: transparency value between 0 and 1. (default: 0.4)
+
+ Raises:
+ ValueError: On incorrect data type for image or masks.
+ """
+ if image.dtype != np.uint8:
+ raise ValueError('`image` not of type np.uint8')
+ if mask.dtype != np.uint8:
+ raise ValueError('`mask` not of type np.uint8')
+ if np.any(np.logical_and(mask != 1, mask != 0)):
+ raise ValueError('`mask` elements should be in [0, 1]')
+ if image.shape[:2] != mask.shape:
+ raise ValueError('The image has spatial dimensions %s but the mask has '
+ 'dimensions %s' % (image.shape[:2], mask.shape))
+ rgb = ImageColor.getrgb(color)
+ pil_image = Image.fromarray(image)
+
+ solid_color = np.expand_dims(
+ np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3])
+ pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert('RGBA')
+ pil_mask = Image.fromarray(np.uint8(255.0*alpha*mask)).convert('L')
+ pil_image = Image.composite(pil_solid_color, pil_image, pil_mask)
+ np.copyto(image, np.array(pil_image.convert('RGB')))
+
+
+def visualize_boxes_and_labels_on_image_array(
+ image,
+ boxes,
+ classes,
+ scores,
+ category_index,
+ instance_masks=None,
+ instance_boundaries=None,
+ keypoints=None,
+ use_normalized_coordinates=False,
+ max_boxes_to_draw=20,
+ min_score_thresh=.5,
+ agnostic_mode=False,
+ line_thickness=4,
+ groundtruth_box_visualization_color='black',
+ skip_scores=False,
+ skip_labels=False):
+ """Overlay labeled boxes on an image with formatted scores and label names.
+
+ This function groups boxes that correspond to the same location
+ and creates a display string for each detection and overlays these
+ on the image. Note that this function modifies the image in place, and returns
+ that same image.
+
+ Args:
+ image: uint8 numpy array with shape (img_height, img_width, 3)
+ boxes: a numpy array of shape [N, 4]
+ classes: a numpy array of shape [N]. Note that class indices are 1-based,
+ and match the keys in the label map.
+ scores: a numpy array of shape [N] or None. If scores=None, then
+ this function assumes that the boxes to be plotted are groundtruth
+      boxes and plots all boxes as black with no classes or scores.
+ category_index: a dict containing category dictionaries (each holding
+ category index `id` and category name `name`) keyed by category indices.
+ instance_masks: a numpy array of shape [N, image_height, image_width] with
+ values ranging between 0 and 1, can be None.
+ instance_boundaries: a numpy array of shape [N, image_height, image_width]
+ with values ranging between 0 and 1, can be None.
+ keypoints: a numpy array of shape [N, num_keypoints, 2], can
+ be None
+ use_normalized_coordinates: whether boxes is to be interpreted as
+ normalized coordinates or not.
+ max_boxes_to_draw: maximum number of boxes to visualize. If None, draw
+ all boxes.
+ min_score_thresh: minimum score threshold for a box to be visualized
+ agnostic_mode: boolean (default: False) controlling whether to evaluate in
+ class-agnostic mode or not. This mode will display scores but ignore
+ classes.
+ line_thickness: integer (default: 4) controlling line width of the boxes.
+ groundtruth_box_visualization_color: box color for visualizing groundtruth
+ boxes
+ skip_scores: whether to skip score when drawing a single detection
+ skip_labels: whether to skip label when drawing a single detection
+
+ Returns:
+ uint8 numpy array with shape (img_height, img_width, 3) with overlaid boxes.
+ """
+ # Create a display string (and color) for every box location, group any boxes
+ # that correspond to the same location.
+ box_to_display_str_map = collections.defaultdict(list)
+ box_to_color_map = collections.defaultdict(str)
+ box_to_instance_masks_map = {}
+ box_to_instance_boundaries_map = {}
+ box_to_keypoints_map = collections.defaultdict(list)
+ if not max_boxes_to_draw:
+ max_boxes_to_draw = boxes.shape[0]
+ for i in range(min(max_boxes_to_draw, boxes.shape[0])):
+ if scores is None or scores[i] > min_score_thresh:
+ box = tuple(boxes[i].tolist())
+ if instance_masks is not None:
+ box_to_instance_masks_map[box] = instance_masks[i]
+ if instance_boundaries is not None:
+ box_to_instance_boundaries_map[box] = instance_boundaries[i]
+ if keypoints is not None:
+ box_to_keypoints_map[box].extend(keypoints[i])
+ if scores is None:
+ box_to_color_map[box] = groundtruth_box_visualization_color
+ else:
+ display_str = ''
+ if not skip_labels:
+ if not agnostic_mode:
+ if classes[i] in category_index.keys():
+ class_name = category_index[classes[i]]['name']
+ else:
+ class_name = 'N/A'
+ display_str = str(class_name)
+ if not skip_scores:
+ if not display_str:
+ display_str = '{}%'.format(int(100*scores[i]))
+ else:
+ display_str = '{}: {}%'.format(display_str, int(100*scores[i]))
+ box_to_display_str_map[box].append(display_str)
+ if agnostic_mode:
+ box_to_color_map[box] = 'DarkOrange'
+ else:
+ box_to_color_map[box] = STANDARD_COLORS[
+ classes[i] % len(STANDARD_COLORS)]
+
+ # Draw all boxes onto image.
+ for box, color in box_to_color_map.items():
+ ymin, xmin, ymax, xmax = box
+ if instance_masks is not None:
+ draw_mask_on_image_array(
+ image,
+ box_to_instance_masks_map[box],
+ color=color
+ )
+ if instance_boundaries is not None:
+ draw_mask_on_image_array(
+ image,
+ box_to_instance_boundaries_map[box],
+ color='red',
+ alpha=1.0
+ )
+ draw_bounding_box_on_image_array(
+ image,
+ ymin,
+ xmin,
+ ymax,
+ xmax,
+ color=color,
+ thickness=line_thickness,
+ display_str_list=box_to_display_str_map[box],
+ use_normalized_coordinates=use_normalized_coordinates)
+ if keypoints is not None:
+ draw_keypoints_on_image_array(
+ image,
+ box_to_keypoints_map[box],
+ color=color,
+ radius=line_thickness / 2,
+ use_normalized_coordinates=use_normalized_coordinates)
+
+ return image
+
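+
+# A minimal usage sketch (illustrative only): `image`, `boxes`, `classes` and
+# `scores` below are assumed to be numpy arrays shaped as documented in
+# visualize_boxes_and_labels_on_image_array above.
+#
+#   category_index = {1: {'id': 1, 'name': 'person'}, 2: {'id': 2, 'name': 'car'}}
+#   annotated = visualize_boxes_and_labels_on_image_array(
+#       image, boxes, classes, scores, category_index,
+#       use_normalized_coordinates=True, min_score_thresh=0.5)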
+
+def add_cdf_image_summary(values, name):
+ """Adds a tf.summary.image for a CDF plot of the values.
+
+ Normalizes `values` such that they sum to 1, plots the cumulative distribution
+ function and creates a tf image summary.
+
+ Args:
+ values: a 1-D float32 tensor containing the values.
+ name: name for the image summary.
+ """
+ def cdf_plot(values):
+ """Numpy function to plot CDF."""
+ normalized_values = values / np.sum(values)
+ sorted_values = np.sort(normalized_values)
+ cumulative_values = np.cumsum(sorted_values)
+ fraction_of_examples = (np.arange(cumulative_values.size, dtype=np.float32)
+ / cumulative_values.size)
+ fig = plt.figure(frameon=False)
+ ax = fig.add_subplot('111')
+ ax.plot(fraction_of_examples, cumulative_values)
+ ax.set_ylabel('cumulative normalized values')
+ ax.set_xlabel('fraction of examples')
+ fig.canvas.draw()
+ width, height = fig.get_size_inches() * fig.get_dpi()
+ image = np.fromstring(fig.canvas.tostring_rgb(), dtype='uint8').reshape(
+ 1, int(height), int(width), 3)
+ return image
+
+ cdf_plot = tf.compat.v1.py_func(cdf_plot, [values], tf.uint8)
+ tf.compat.v1.summary.image(name, cdf_plot)
+
+
+def add_hist_image_summary(values, bins, name):
+ """Adds a tf.summary.image for a histogram plot of the values.
+
+ Plots the histogram of values and creates a tf image summary.
+
+ Args:
+ values: a 1-D float32 tensor containing the values.
+ bins: bin edges which will be directly passed to np.histogram.
+ name: name for the image summary.
+ """
+
+ def hist_plot(values, bins):
+ """Numpy function to plot hist."""
+ fig = plt.figure(frameon=False)
+ ax = fig.add_subplot('111')
+ y, x = np.histogram(values, bins=bins)
+ ax.plot(x[:-1], y)
+ ax.set_ylabel('count')
+ ax.set_xlabel('value')
+ fig.canvas.draw()
+ width, height = fig.get_size_inches() * fig.get_dpi()
+ image = np.fromstring(
+ fig.canvas.tostring_rgb(), dtype='uint8').reshape(
+ 1, int(height), int(width), 3)
+ return image
+
+ hist_plot = tf.compat.v1.py_func(hist_plot, [values, bins], tf.uint8)
+ tf.compat.v1.summary.image(name, hist_plot)
diff --git a/models/official/vision/image_classification/README.md b/models/official/vision/image_classification/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..eb061d5b5f3284255bdb484cfbbb20bb3e157268
--- /dev/null
+++ b/models/official/vision/image_classification/README.md
@@ -0,0 +1,182 @@
+# Image Classification
+
+This folder contains TF 2.0 model examples for image classification:
+
+* [MNIST](#mnist)
+* [Classifier Trainer](#classifier-trainer), a framework that uses the Keras
+compile/fit methods for image classification models, including:
+ * ResNet
+ * EfficientNet[^1]
+
+[^1]: Currently a work in progress. We cannot match "AutoAugment (AA)" in [the original version](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet).
+
+For more information about other types of models, please refer to this
+[README file](../../README.md).
+
+## Before you begin
+Please make sure that you have the latest version of TensorFlow
+installed and
+[add the models folder to your Python path](/official/#running-the-models).
+
+### ImageNet preparation
+
+#### Using TFDS
+`classifier_trainer.py` supports ImageNet with
+[TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets/overview).
+
+Please see the following [example snippet](https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/scripts/download_and_prepare.py)
+for more information on how to use TFDS to download and prepare datasets, and
+specifically the [TFDS ImageNet readme](https://github.com/tensorflow/datasets/blob/master/docs/catalog/imagenet2012.md)
+for manual download instructions.
+
+#### Legacy TFRecords
+Download the ImageNet dataset and convert it to TFRecord format.
+The following [script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py)
+and [README](https://github.com/tensorflow/tpu/tree/master/tools/datasets#imagenet_to_gcspy)
+provide a few options.
+
+Note that the legacy ResNet runners, e.g.
+[resnet/resnet_ctl_imagenet_main.py](resnet/resnet_ctl_imagenet_main.py),
+require TFRecords, whereas `classifier_trainer.py` can use both by setting the
+builder to 'records' or 'tfds' in the configurations.
+
+### Running on Cloud TPUs
+
+Note: These models will **not** work with TPUs on Colab.
+
+You can train image classification models on Cloud TPUs using
+[tf.distribute.experimental.TPUStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/TPUStrategy?version=nightly).
+If you are not familiar with Cloud TPUs, it is strongly recommended that you go
+through the
+[quickstart](https://cloud.google.com/tpu/docs/quickstart) to learn how to
+create a TPU and GCE VM.
+
+### Running on multiple GPU hosts
+
+You can also train these models on multiple hosts, each with GPUs, using
+[tf.distribute.experimental.MultiWorkerMirroredStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy).
+
+The easiest way to run multi-host benchmarks is to set the
+[`TF_CONFIG`](https://www.tensorflow.org/guide/distributed_training#TF_CONFIG)
+appropriately at each host. e.g., to run using `MultiWorkerMirroredStrategy` on
+2 hosts, the `cluster` in `TF_CONFIG` should have 2 `host:port` entries, and
+host `i` should have the `task` in `TF_CONFIG` set to `{"type": "worker",
+"index": i}`. `MultiWorkerMirroredStrategy` will automatically use all the
+available GPUs at each host.
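+
+For example, a minimal `TF_CONFIG` for the first of two hosts could look like
+the following (host names and port are placeholders):
+
+```bash
+# On host 0; host 1 would set "index": 1 in the task section.
+export TF_CONFIG='{"cluster": {"worker": ["host1:12345", "host2:12345"]}, "task": {"type": "worker", "index": 0}}'
+```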
+
+## MNIST
+
+To download the data and run the MNIST sample model locally for the first time,
+run one of the following commands:
+
+```bash
+python3 mnist_main.py \
+ --model_dir=$MODEL_DIR \
+ --data_dir=$DATA_DIR \
+ --train_epochs=10 \
+ --distribution_strategy=one_device \
+ --num_gpus=$NUM_GPUS \
+ --download
+```
+
+To train the model on a Cloud TPU, run the following command:
+
+```bash
+python3 mnist_main.py \
+ --tpu=$TPU_NAME \
+ --model_dir=$MODEL_DIR \
+ --data_dir=$DATA_DIR \
+ --train_epochs=10 \
+ --distribution_strategy=tpu \
+ --download
+```
+
+Note: the `--download` flag is only required the first time you run the model.
+
+
+## Classifier Trainer
+The classifier trainer is a unified framework for running image classification
+models using Keras's compile/fit methods. Experiments should be provided in the
+form of YAML files; example configurations are included in the
+[configs/examples](./configs/examples) folder.
+
+The provided configuration files use a per-replica batch size, which is scaled
+by the number of devices. For instance, if `batch size` = 64, then for 1 GPU
+the global batch size would be 64 * 1 = 64. For 8 GPUs, the global batch size
+would be 64 * 8 = 512. Similarly, for a v3-8 TPU, the global batch size would
+be 64 * 8 = 512, and for a v3-32, the global batch size is 64 * 32 = 2048.
+
+### ResNet50
+
+#### On GPU:
+```bash
+python3 classifier_trainer.py \
+ --mode=train_and_eval \
+ --model_type=resnet \
+ --dataset=imagenet \
+ --model_dir=$MODEL_DIR \
+ --data_dir=$DATA_DIR \
+ --config_file=configs/examples/resnet/imagenet/gpu.yaml \
+ --params_override='runtime.num_gpus=$NUM_GPUS'
+```
+
+To train on multiple hosts, each with GPUs attached using
+[MultiWorkerMirroredStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy)
+please update `runtime` section in gpu.yaml
+(or override using `--params_override`) with:
+
+```YAML
+# gpu.yaml
+runtime:
+ distribution_strategy: 'multi_worker_mirrored'
+ worker_hosts: '$HOST1:port,$HOST2:port'
+ num_gpus: $NUM_GPUS
+ task_index: 0
+```
+Set `task_index: 0` on the first host, `task_index: 1` on the second, and so
+on. `$HOST1` and `$HOST2` are the IP addresses of the hosts, and `port` can be
+any free port on the hosts. Only the first host will write TensorBoard
+summaries and save checkpoints.
+
+#### On TPU:
+```bash
+python3 classifier_trainer.py \
+ --mode=train_and_eval \
+ --model_type=resnet \
+ --dataset=imagenet \
+ --tpu=$TPU_NAME \
+ --model_dir=$MODEL_DIR \
+ --data_dir=$DATA_DIR \
+ --config_file=configs/examples/resnet/imagenet/tpu.yaml
+```
+
+### EfficientNet
+**Note: EfficientNet development is a work in progress.**
+#### On GPU:
+```bash
+python3 classifier_trainer.py \
+ --mode=train_and_eval \
+ --model_type=efficientnet \
+ --dataset=imagenet \
+ --model_dir=$MODEL_DIR \
+ --data_dir=$DATA_DIR \
+ --config_file=configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml \
+ --params_override='runtime.num_gpus=$NUM_GPUS'
+```
+
+
+#### On TPU:
+```bash
+python3 classifier_trainer.py \
+ --mode=train_and_eval \
+ --model_type=efficientnet \
+ --dataset=imagenet \
+ --tpu=$TPU_NAME \
+ --model_dir=$MODEL_DIR \
+ --data_dir=$DATA_DIR \
+ --config_file=configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml
+```
+
+Note that the number of GPU devices can be overridden in the command line using
+`--params_override`. The TPU does not need this override as the device is fixed
+by providing the TPU address or name with the `--tpu` flag.
+
diff --git a/models/official/vision/image_classification/__init__.py b/models/official/vision/image_classification/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/vision/image_classification/augment.py b/models/official/vision/image_classification/augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6ef23a229c80bcc1fec92d431996688dc34eaad
--- /dev/null
+++ b/models/official/vision/image_classification/augment.py
@@ -0,0 +1,999 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""AutoAugment and RandAugment policies for enhanced image preprocessing.
+
+AutoAugment Reference: https://arxiv.org/abs/1805.09501
+RandAugment Reference: https://arxiv.org/abs/1909.13719
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import math
+import tensorflow as tf
+from typing import Any, Dict, List, Optional, Text, Tuple
+
+from tensorflow.python.keras.layers.preprocessing import image_preprocessing as image_ops
+
+# This signifies the max integer that the controller RNN could predict for the
+# augmentation scheme.
+_MAX_LEVEL = 10.
+
+
+def to_4d(image: tf.Tensor) -> tf.Tensor:
+ """Converts an input Tensor to 4 dimensions.
+
+ 4D image => [N, H, W, C] or [N, C, H, W]
+ 3D image => [1, H, W, C] or [1, C, H, W]
+ 2D image => [1, H, W, 1]
+
+ Args:
+ image: The 2/3/4D input tensor.
+
+ Returns:
+ A 4D image tensor.
+
+ Raises:
+ `TypeError` if `image` is not a 2/3/4D tensor.
+
+ """
+ shape = tf.shape(image)
+ original_rank = tf.rank(image)
+ left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32)
+ right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32)
+ new_shape = tf.concat(
+ [
+ tf.ones(shape=left_pad, dtype=tf.int32),
+ shape,
+ tf.ones(shape=right_pad, dtype=tf.int32),
+ ],
+ axis=0,
+ )
+ return tf.reshape(image, new_shape)
+
+
+def from_4d(image: tf.Tensor, ndims: tf.Tensor) -> tf.Tensor:
+ """Converts a 4D image back to `ndims` rank."""
+ shape = tf.shape(image)
+ begin = tf.cast(tf.less_equal(ndims, 3), dtype=tf.int32)
+ end = 4 - tf.cast(tf.equal(ndims, 2), dtype=tf.int32)
+ new_shape = shape[begin:end]
+ return tf.reshape(image, new_shape)
+
+
+def _convert_translation_to_transform(translations: tf.Tensor) -> tf.Tensor:
+ """Converts translations to a projective transform.
+
+ The translation matrix looks like this:
+ [[1 0 -dx]
+ [0 1 -dy]
+ [0 0 1]]
+
+ Args:
+ translations: The 2-element list representing [dx, dy], or a matrix of
+ 2-element lists representing [dx dy] to translate for each image. The
+ shape must be static.
+
+ Returns:
+ The transformation matrix of shape (num_images, 8).
+
+ Raises:
+ `TypeError` if
+ - the shape of `translations` is not known or
+ - the shape of `translations` is not rank 1 or 2.
+
+ """
+ translations = tf.convert_to_tensor(translations, dtype=tf.float32)
+ if translations.get_shape().ndims is None:
+ raise TypeError('translations rank must be statically known')
+ elif len(translations.get_shape()) == 1:
+ translations = translations[None]
+ elif len(translations.get_shape()) != 2:
+ raise TypeError('translations should have rank 1 or 2.')
+ num_translations = tf.shape(translations)[0]
+
+ return tf.concat(
+ values=[
+ tf.ones((num_translations, 1), tf.dtypes.float32),
+ tf.zeros((num_translations, 1), tf.dtypes.float32),
+ -translations[:, 0, None],
+ tf.zeros((num_translations, 1), tf.dtypes.float32),
+ tf.ones((num_translations, 1), tf.dtypes.float32),
+ -translations[:, 1, None],
+ tf.zeros((num_translations, 2), tf.dtypes.float32),
+ ],
+ axis=1,
+ )
+
+
+def _convert_angles_to_transform(
+ angles: tf.Tensor,
+ image_width: tf.Tensor,
+ image_height: tf.Tensor) -> tf.Tensor:
+ """Converts an angle or angles to a projective transform.
+
+ Args:
+    angles: A scalar angle to rotate all images by, or a vector of angles to
+      rotate a batch of images.
+ image_width: The width of the image(s) to be transformed.
+ image_height: The height of the image(s) to be transformed.
+
+ Returns:
+ A tensor of shape (num_images, 8).
+
+ Raises:
+ `TypeError` if `angles` is not rank 0 or 1.
+
+ """
+ angles = tf.convert_to_tensor(angles, dtype=tf.float32)
+ if len(angles.get_shape()) == 0: # pylint:disable=g-explicit-length-test
+ angles = angles[None]
+ elif len(angles.get_shape()) != 1:
+ raise TypeError('Angles should have a rank 0 or 1.')
+ x_offset = ((image_width - 1) -
+ (tf.math.cos(angles) * (image_width - 1) - tf.math.sin(angles) *
+ (image_height - 1))) / 2.0
+ y_offset = ((image_height - 1) -
+ (tf.math.sin(angles) * (image_width - 1) + tf.math.cos(angles) *
+ (image_height - 1))) / 2.0
+ num_angles = tf.shape(angles)[0]
+ return tf.concat(
+ values=[
+ tf.math.cos(angles)[:, None],
+ -tf.math.sin(angles)[:, None],
+ x_offset[:, None],
+ tf.math.sin(angles)[:, None],
+ tf.math.cos(angles)[:, None],
+ y_offset[:, None],
+ tf.zeros((num_angles, 2), tf.dtypes.float32),
+ ],
+ axis=1,
+ )
+
+
+def transform(image: tf.Tensor, transforms) -> tf.Tensor:
+ """Prepares input data for `image_ops.transform`."""
+ original_ndims = tf.rank(image)
+ transforms = tf.convert_to_tensor(transforms, dtype=tf.float32)
+ if transforms.shape.rank == 1:
+ transforms = transforms[None]
+ image = to_4d(image)
+ image = image_ops.transform(
+ images=image,
+ transforms=transforms,
+ interpolation='nearest')
+ return from_4d(image, original_ndims)
+
+
+def translate(image: tf.Tensor, translations) -> tf.Tensor:
+ """Translates image(s) by provided vectors.
+
+ Args:
+ image: An image Tensor of type uint8.
+ translations: A vector or matrix representing [dx dy].
+
+ Returns:
+ The translated version of the image.
+
+ """
+ transforms = _convert_translation_to_transform(translations)
+ return transform(image, transforms=transforms)
+
+
+def rotate(image: tf.Tensor, degrees: float) -> tf.Tensor:
+ """Rotates the image by degrees either clockwise or counterclockwise.
+
+ Args:
+ image: An image Tensor of type uint8.
+ degrees: Float, a scalar angle in degrees to rotate all images by. If
+      degrees is positive, the image will be rotated clockwise; otherwise it
+      will be rotated counterclockwise.
+
+ Returns:
+ The rotated version of image.
+
+ """
+ # Convert from degrees to radians.
+ degrees_to_radians = math.pi / 180.0
+ radians = tf.cast(degrees * degrees_to_radians, tf.float32)
+
+ original_ndims = tf.rank(image)
+ image = to_4d(image)
+
+ image_height = tf.cast(tf.shape(image)[1], tf.float32)
+ image_width = tf.cast(tf.shape(image)[2], tf.float32)
+ transforms = _convert_angles_to_transform(angles=radians,
+ image_width=image_width,
+ image_height=image_height)
+ # In practice, we should randomize the rotation degrees by flipping
+ # it negatively half the time, but that's done on 'degrees' outside
+ # of the function.
+ image = transform(image, transforms=transforms)
+ return from_4d(image, original_ndims)
+
+
+def blend(image1: tf.Tensor, image2: tf.Tensor, factor: float) -> tf.Tensor:
+ """Blend image1 and image2 using 'factor'.
+
+ Factor can be above 0.0. A value of 0.0 means only image1 is used.
+ A value of 1.0 means only image2 is used. A value between 0.0 and
+ 1.0 means we linearly interpolate the pixel values between the two
+ images. A value greater than 1.0 "extrapolates" the difference
+ between the two pixel values, and we clip the results to values
+ between 0 and 255.
+
+ Args:
+ image1: An image Tensor of type uint8.
+ image2: An image Tensor of type uint8.
+ factor: A floating point value above 0.0.
+
+ Returns:
+ A blended image Tensor of type uint8.
+ """
+ if factor == 0.0:
+ return tf.convert_to_tensor(image1)
+ if factor == 1.0:
+ return tf.convert_to_tensor(image2)
+
+ image1 = tf.cast(image1, tf.float32)
+ image2 = tf.cast(image2, tf.float32)
+
+ difference = image2 - image1
+ scaled = factor * difference
+
+ # Do addition in float.
+ temp = tf.cast(image1, tf.float32) + scaled
+
+ # Interpolate
+ if factor > 0.0 and factor < 1.0:
+ # Interpolation means we always stay within 0 and 255.
+ return tf.cast(temp, tf.uint8)
+
+ # Extrapolate:
+ #
+ # We need to clip and then cast.
+ return tf.cast(tf.clip_by_value(temp, 0.0, 255.0), tf.uint8)
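+
+# Illustrative sketch (not part of this module's API): given two uint8 image
+# tensors `image_a` and `image_b` of the same shape,
+#   mixed = blend(image_a, image_b, 0.5)   # linear interpolation
+#   pushed = blend(image_a, image_b, 1.5)  # extrapolation, clipped to [0, 255]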
+
+
+def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
+ """Apply cutout (https://arxiv.org/abs/1708.04552) to image.
+
+ This operation applies a (2*pad_size x 2*pad_size) mask of zeros to
+ a random location within `img`. The pixel values filled in will be of the
+  value `replace`. The location where the mask will be applied is chosen
+  uniformly at random over the whole image.
+
+ Args:
+ image: An image Tensor of type uint8.
+    pad_size: Specifies the half-width of the zero mask that will be generated
+      and applied to the image. The mask will be of size
+      (2*pad_size x 2*pad_size).
+ replace: What pixel value to fill in the image in the area that has
+ the cutout mask applied to it.
+
+ Returns:
+ An image Tensor that is of type uint8.
+ """
+ image_height = tf.shape(image)[0]
+ image_width = tf.shape(image)[1]
+
+ # Sample the center location in the image where the zero mask will be applied.
+ cutout_center_height = tf.random.uniform(
+ shape=[], minval=0, maxval=image_height,
+ dtype=tf.int32)
+
+ cutout_center_width = tf.random.uniform(
+ shape=[], minval=0, maxval=image_width,
+ dtype=tf.int32)
+
+ lower_pad = tf.maximum(0, cutout_center_height - pad_size)
+ upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size)
+ left_pad = tf.maximum(0, cutout_center_width - pad_size)
+ right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size)
+
+ cutout_shape = [image_height - (lower_pad + upper_pad),
+ image_width - (left_pad + right_pad)]
+ padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
+ mask = tf.pad(
+ tf.zeros(cutout_shape, dtype=image.dtype),
+ padding_dims, constant_values=1)
+ mask = tf.expand_dims(mask, -1)
+ mask = tf.tile(mask, [1, 1, 3])
+ image = tf.where(
+ tf.equal(mask, 0),
+ tf.ones_like(image, dtype=image.dtype) * replace,
+ image)
+ return image
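+
+# Illustrative sketch (not part of this module's API): mask out a random
+# 32x32 patch (pad_size=16) of a uint8 `image` and fill it with mid-gray.
+#   patched = cutout(image, pad_size=16, replace=128)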
+
+
+def solarize(image: tf.Tensor, threshold: int = 128) -> tf.Tensor:
+ # For each pixel in the image, select the pixel
+ # if the value is less than the threshold.
+ # Otherwise, subtract 255 from the pixel.
+ return tf.where(image < threshold, image, 255 - image)
+
+
+def solarize_add(image: tf.Tensor,
+ addition: int = 0,
+ threshold: int = 128) -> tf.Tensor:
+ # For each pixel in the image less than threshold
+ # we add 'addition' amount to it and then clip the
+ # pixel value to be between 0 and 255. The value
+ # of 'addition' is between -128 and 128.
+ added_image = tf.cast(image, tf.int64) + addition
+ added_image = tf.cast(tf.clip_by_value(added_image, 0, 255), tf.uint8)
+ return tf.where(image < threshold, added_image, image)
+
+
+def color(image: tf.Tensor, factor: float) -> tf.Tensor:
+ """Equivalent of PIL Color."""
+ degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image))
+ return blend(degenerate, image, factor)
+
+
+def contrast(image: tf.Tensor, factor: float) -> tf.Tensor:
+ """Equivalent of PIL Contrast."""
+ degenerate = tf.image.rgb_to_grayscale(image)
+ # Cast before calling tf.histogram.
+ degenerate = tf.cast(degenerate, tf.int32)
+
+ # Compute the grayscale histogram, then compute the mean pixel value,
+ # and create a constant image size of that value. Use that as the
+ # blending degenerate target of the original image.
+ hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256)
+ mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0
+ degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean
+ degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
+ degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8))
+ return blend(degenerate, image, factor)
+
+
+def brightness(image: tf.Tensor, factor: float) -> tf.Tensor:
+ """Equivalent of PIL Brightness."""
+ degenerate = tf.zeros_like(image)
+ return blend(degenerate, image, factor)
+
+
+def posterize(image: tf.Tensor, bits: int) -> tf.Tensor:
+ """Equivalent of PIL Posterize."""
+ shift = 8 - bits
+ return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift)
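+
+# Illustrative sketch (not part of this module's API): posterize(image, 4)
+# keeps only the 4 most significant bits of every channel value.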
+
+
+def wrapped_rotate(image: tf.Tensor, degrees: float, replace: int) -> tf.Tensor:
+ """Applies rotation with wrap/unwrap."""
+ image = rotate(wrap(image), degrees=degrees)
+ return unwrap(image, replace)
+
+
+def translate_x(image: tf.Tensor, pixels: int, replace: int) -> tf.Tensor:
+ """Equivalent of PIL Translate in X dimension."""
+ image = translate(wrap(image), [-pixels, 0])
+ return unwrap(image, replace)
+
+
+def translate_y(image: tf.Tensor, pixels: int, replace: int) -> tf.Tensor:
+ """Equivalent of PIL Translate in Y dimension."""
+ image = translate(wrap(image), [0, -pixels])
+ return unwrap(image, replace)
+
+
+def shear_x(image: tf.Tensor, level: float, replace: int) -> tf.Tensor:
+ """Equivalent of PIL Shearing in X dimension."""
+ # Shear parallel to x axis is a projective transform
+ # with a matrix form of:
+ # [1 level
+ # 0 1].
+ image = transform(image=wrap(image),
+ transforms=[1., level, 0., 0., 1., 0., 0., 0.])
+ return unwrap(image, replace)
+
+
+def shear_y(image: tf.Tensor, level: float, replace: int) -> tf.Tensor:
+ """Equivalent of PIL Shearing in Y dimension."""
+ # Shear parallel to y axis is a projective transform
+ # with a matrix form of:
+ # [1 0
+ # level 1].
+ image = transform(image=wrap(image),
+ transforms=[1., 0., 0., level, 1., 0., 0., 0.])
+ return unwrap(image, replace)
+
+
+def autocontrast(image: tf.Tensor) -> tf.Tensor:
+ """Implements Autocontrast function from PIL using TF ops.
+
+ Args:
+ image: A 3D uint8 tensor.
+
+ Returns:
+ The image after it has had autocontrast applied to it and will be of type
+ uint8.
+ """
+
+ def scale_channel(image: tf.Tensor) -> tf.Tensor:
+ """Scale the 2D image using the autocontrast rule."""
+    # A possibly cheaper version could use cumsum/unique_with_counts over the
+    # histogram values, rather than iterating over the entire image, to
+    # compute mins and maxes.
+ lo = tf.cast(tf.reduce_min(image), tf.float32)
+ hi = tf.cast(tf.reduce_max(image), tf.float32)
+
+ # Scale the image, making the lowest value 0 and the highest value 255.
+ def scale_values(im):
+ scale = 255.0 / (hi - lo)
+ offset = -lo * scale
+ im = tf.cast(im, tf.float32) * scale + offset
+ im = tf.clip_by_value(im, 0.0, 255.0)
+ return tf.cast(im, tf.uint8)
+
+ result = tf.cond(hi > lo, lambda: scale_values(image), lambda: image)
+ return result
+
+ # Assumes RGB for now. Scales each channel independently
+ # and then stacks the result.
+ s1 = scale_channel(image[:, :, 0])
+ s2 = scale_channel(image[:, :, 1])
+ s3 = scale_channel(image[:, :, 2])
+ image = tf.stack([s1, s2, s3], 2)
+ return image
+
+
+def sharpness(image: tf.Tensor, factor: float) -> tf.Tensor:
+ """Implements Sharpness function from PIL using TF ops."""
+ orig_image = image
+ image = tf.cast(image, tf.float32)
+ # Make image 4D for conv operation.
+ image = tf.expand_dims(image, 0)
+ # SMOOTH PIL Kernel.
+ kernel = tf.constant(
+ [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32,
+ shape=[3, 3, 1, 1]) / 13.
+ # Tile across channel dimension.
+ kernel = tf.tile(kernel, [1, 1, 3, 1])
+ strides = [1, 1, 1, 1]
+ degenerate = tf.nn.depthwise_conv2d(
+ image, kernel, strides, padding='VALID', dilations=[1, 1])
+ degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
+ degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0])
+
+ # For the borders of the resulting image, fill in the values of the
+ # original image.
+ mask = tf.ones_like(degenerate)
+ padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]])
+ padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]])
+ result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image)
+
+ # Blend the final result.
+ return blend(result, orig_image, factor)
+
+
+def equalize(image: tf.Tensor) -> tf.Tensor:
+ """Implements Equalize function from PIL using TF ops."""
+ def scale_channel(im, c):
+ """Scale the data in the channel to implement equalize."""
+ im = tf.cast(im[:, :, c], tf.int32)
+ # Compute the histogram of the image channel.
+ histo = tf.histogram_fixed_width(im, [0, 255], nbins=256)
+
+ # For the purposes of computing the step, filter out the nonzeros.
+ nonzero = tf.where(tf.not_equal(histo, 0))
+ nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1])
+ step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255
+
+ def build_lut(histo, step):
+ # Compute the cumulative sum, shifting by step // 2
+ # and then normalization by step.
+ lut = (tf.cumsum(histo) + (step // 2)) // step
+ # Shift lut, prepending with 0.
+ lut = tf.concat([[0], lut[:-1]], 0)
+ # Clip the counts to be in range. This is done
+ # in the C code for image.point.
+ return tf.clip_by_value(lut, 0, 255)
+
+ # If step is zero, return the original image. Otherwise, build
+ # lut from the full histogram and step and then index from it.
+ result = tf.cond(tf.equal(step, 0),
+ lambda: im,
+ lambda: tf.gather(build_lut(histo, step), im))
+
+ return tf.cast(result, tf.uint8)
+
+ # Assumes RGB for now. Scales each channel independently
+ # and then stacks the result.
+ s1 = scale_channel(image, 0)
+ s2 = scale_channel(image, 1)
+ s3 = scale_channel(image, 2)
+ image = tf.stack([s1, s2, s3], 2)
+ return image
+
+
+def invert(image: tf.Tensor) -> tf.Tensor:
+ """Inverts the image pixels."""
+ image = tf.convert_to_tensor(image)
+ return 255 - image
+
+
+def wrap(image: tf.Tensor) -> tf.Tensor:
+ """Returns 'image' with an extra channel set to all 1s."""
+ shape = tf.shape(image)
+ extended_channel = tf.ones([shape[0], shape[1], 1], image.dtype)
+ extended = tf.concat([image, extended_channel], axis=2)
+ return extended
+
+
+def unwrap(image: tf.Tensor, replace: int) -> tf.Tensor:
+ """Unwraps an image produced by wrap.
+
+ Wherever the last (alpha) channel is 0 at a spatial position, the other
+ three channels at that position are filled with the `replace` value
+ (typically mid-gray, 128). Operations like translate and shear on a
+ wrapped Tensor leave 0s in the empty locations, and some transformations
+ look at the intensity of values to do preprocessing, so we want these
+ empty pixels to assume the 'average' value rather than pure black.
+
+ Args:
+ image: A 3D Image Tensor with 4 channels.
+ replace: A one or three value 1D tensor to fill empty pixels.
+
+ Returns:
+ image: A 3D image Tensor with 3 channels.
+ """
+ image_shape = tf.shape(image)
+ # Flatten the spatial dimensions.
+ flattened_image = tf.reshape(image, [-1, image_shape[2]])
+
+ # Find all pixels where the last channel is zero.
+ alpha_channel = tf.expand_dims(flattened_image[:, 3], axis=-1)
+
+ replace = tf.concat([replace, tf.ones([1], image.dtype)], 0)
+
+ # Where they are zero, fill them in with 'replace'.
+ flattened_image = tf.where(
+ tf.equal(alpha_channel, 0),
+ tf.ones_like(flattened_image, dtype=image.dtype) * replace,
+ flattened_image)
+
+ image = tf.reshape(flattened_image, image_shape)
+ image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3])
+ return image
+
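+# Illustrative sketch (not part of the upstream module): wrap/unwrap bracket a
+# geometric op so that pixels the op leaves empty are filled with `replace`
+# rather than black, mirroring how the policies below use them:
+#   padded = wrap(image)                    # append an all-ones alpha channel
+#   shifted = translate(padded, [-10, 0])   # empty pixels get alpha == 0
+#   image = unwrap(shifted, [128, 128, 128])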
+
+def _randomly_negate_tensor(tensor):
+ """With 50% prob turn the tensor negative."""
+ should_flip = tf.cast(tf.floor(tf.random.uniform([]) + 0.5), tf.bool)
+ final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor)
+ return final_tensor
+
+
+def _rotate_level_to_arg(level: float):
+ level = (level/_MAX_LEVEL) * 30.
+ level = _randomly_negate_tensor(level)
+ return (level,)
+
+
+def _shrink_level_to_arg(level: float):
+ """Converts level to ratio by which we shrink the image content."""
+ if level == 0:
+ return (1.0,) # if level is zero, do not shrink the image
+ # Maximum shrinking ratio is 2.9.
+ level = 2. / (_MAX_LEVEL / level) + 0.9
+ return (level,)
+
+
+def _enhance_level_to_arg(level: float):
+ return ((level/_MAX_LEVEL) * 1.8 + 0.1,)
+
+
+def _shear_level_to_arg(level: float):
+ level = (level/_MAX_LEVEL) * 0.3
+ # Flip level to negative with 50% chance.
+ level = _randomly_negate_tensor(level)
+ return (level,)
+
+
+def _translate_level_to_arg(level: float, translate_const: float):
+ level = (level/_MAX_LEVEL) * float(translate_const)
+ # Flip level to negative with 50% chance.
+ level = _randomly_negate_tensor(level)
+ return (level,)
+
+
+def _mult_to_arg(level: float, multiplier: float = 1.):
+ return (int((level / _MAX_LEVEL) * multiplier),)
+
+
+def _apply_func_with_prob(func: Any,
+ image: tf.Tensor,
+ args: Any,
+ prob: float):
+ """Apply `func` to image w/ `args` as input with probability `prob`."""
+ assert isinstance(args, tuple)
+
+ # Apply the function with probability `prob`.
+ should_apply_op = tf.cast(
+ tf.floor(tf.random.uniform([], dtype=tf.float32) + prob), tf.bool)
+ augmented_image = tf.cond(
+ should_apply_op,
+ lambda: func(image, *args),
+ lambda: image)
+ return augmented_image
+
+
+def select_and_apply_random_policy(policies: Any, image: tf.Tensor):
+ """Select a random policy from `policies` and apply it to `image`."""
+ policy_to_select = tf.random.uniform([], maxval=len(policies), dtype=tf.int32)
+ # Note that using tf.case instead of tf.conds would result in significantly
+ # larger graphs and would even break export for some larger policies.
+ for (i, policy) in enumerate(policies):
+ image = tf.cond(
+ tf.equal(i, policy_to_select),
+ lambda selected_policy=policy: selected_policy(image),
+ lambda: image)
+ return image
+
+
+NAME_TO_FUNC = {
+ 'AutoContrast': autocontrast,
+ 'Equalize': equalize,
+ 'Invert': invert,
+ 'Rotate': wrapped_rotate,
+ 'Posterize': posterize,
+ 'Solarize': solarize,
+ 'SolarizeAdd': solarize_add,
+ 'Color': color,
+ 'Contrast': contrast,
+ 'Brightness': brightness,
+ 'Sharpness': sharpness,
+ 'ShearX': shear_x,
+ 'ShearY': shear_y,
+ 'TranslateX': translate_x,
+ 'TranslateY': translate_y,
+ 'Cutout': cutout,
+}
+
+# Functions that have a 'replace' parameter
+REPLACE_FUNCS = frozenset({
+ 'Rotate',
+ 'TranslateX',
+ 'ShearX',
+ 'ShearY',
+ 'TranslateY',
+ 'Cutout',
+})
+
+
+def level_to_arg(cutout_const: float, translate_const: float):
+ """Creates a dict mapping image operation names to their arguments."""
+
+ no_arg = lambda level: ()
+ posterize_arg = lambda level: _mult_to_arg(level, 4)
+ solarize_arg = lambda level: _mult_to_arg(level, 256)
+ solarize_add_arg = lambda level: _mult_to_arg(level, 110)
+ cutout_arg = lambda level: _mult_to_arg(level, cutout_const)
+ translate_arg = lambda level: _translate_level_to_arg(level, translate_const)
+
+ args = {
+ 'AutoContrast': no_arg,
+ 'Equalize': no_arg,
+ 'Invert': no_arg,
+ 'Rotate': _rotate_level_to_arg,
+ 'Posterize': posterize_arg,
+ 'Solarize': solarize_arg,
+ 'SolarizeAdd': solarize_add_arg,
+ 'Color': _enhance_level_to_arg,
+ 'Contrast': _enhance_level_to_arg,
+ 'Brightness': _enhance_level_to_arg,
+ 'Sharpness': _enhance_level_to_arg,
+ 'ShearX': _shear_level_to_arg,
+ 'ShearY': _shear_level_to_arg,
+ 'Cutout': cutout_arg,
+ 'TranslateX': translate_arg,
+ 'TranslateY': translate_arg,
+ }
+ return args
+
+
+def _parse_policy_info(name: Text,
+ prob: float,
+ level: float,
+ replace_value: List[int],
+ cutout_const: float,
+ translate_const: float) -> Tuple[Any, float, Any]:
+ """Return the function that corresponds to `name` and update `level` param."""
+ func = NAME_TO_FUNC[name]
+ args = level_to_arg(cutout_const, translate_const)[name](level)
+
+ if name in REPLACE_FUNCS:
+ # Add in replace arg if it is required for the function that is called.
+ args = tuple(list(args) + [replace_value])
+
+ return func, prob, args
+
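+# Illustrative sketch (not part of the upstream module): a policy entry such
+# as ('Rotate', 0.8, 7) is resolved into a callable plus concrete arguments:
+#   func, prob, args = _parse_policy_info('Rotate', 0.8, 7., [128] * 3,
+#                                         100., 250.)
+#   # func is wrapped_rotate, prob is 0.8, and args holds the rotation angle
+#   # plus the replace value, since 'Rotate' is in REPLACE_FUNCS.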
+
+class ImageAugment(object):
+ """Image augmentation class for applying image distortions."""
+
+ def distort(self, image: tf.Tensor) -> tf.Tensor:
+ """Given an image tensor, returns a distorted image with the same shape.
+
+ Args:
+ image: `Tensor` of shape [height, width, 3] representing an image.
+
+ Returns:
+ The augmented version of `image`.
+ """
+ raise NotImplementedError()
+
+
+class AutoAugment(ImageAugment):
+ """Applies the AutoAugment policy to images.
+
+ AutoAugment is from the paper: https://arxiv.org/abs/1805.09501.
+ """
+
+ def __init__(self,
+ augmentation_name: Text = 'v0',
+ policies: Optional[Dict[Text, Any]] = None,
+ cutout_const: float = 100,
+ translate_const: float = 250):
+ """Applies the AutoAugment policy to images.
+
+ Args:
+ augmentation_name: The name of the AutoAugment policy to use. The
+ available options are `v0`, `simple` and `test`. `v0` is the policy used
+ for all of the results in the paper, `simple` is the same policy with
+ the custom ops removed, and `test` is a minimal policy for debugging.
+ policies: optional mapping from a policy name to a list of sub-policies,
+ where each sub-policy is a list of `(func, prob, level)` tuples; `func`
+ is a string name of the augmentation function, `prob` is the
+ probability of applying the `func` operation, and `level` is the input
+ argument for `func`.
+ cutout_const: multiplier for applying cutout.
+ translate_const: multiplier for applying translation.
+ """
+ super(AutoAugment, self).__init__()
+
+ if policies is None:
+ self.available_policies = {
+ 'v0': self.policy_v0(),
+ 'test': self.policy_test(),
+ 'simple': self.policy_simple(),
+ }
+
+ if augmentation_name not in self.available_policies:
+ raise ValueError(
+ 'Invalid augmentation_name: {}'.format(augmentation_name))
+
+ self.augmentation_name = augmentation_name
+ self.policies = self.available_policies[augmentation_name]
+ self.cutout_const = float(cutout_const)
+ self.translate_const = float(translate_const)
+
+ def distort(self, image: tf.Tensor) -> tf.Tensor:
+ """Applies the AutoAugment policy to `image`.
+
+ AutoAugment is from the paper: https://arxiv.org/abs/1805.09501.
+
+ Args:
+ image: `Tensor` of shape [height, width, 3] representing an image.
+
+ Returns:
+ A version of image that now has data augmentation applied to it based on
+ the `policies` passed into the function.
+ """
+ input_image_type = image.dtype
+
+ if input_image_type != tf.uint8:
+ image = tf.clip_by_value(image, 0.0, 255.0)
+ image = tf.cast(image, dtype=tf.uint8)
+
+ replace_value = [128] * 3
+
+ # func is the string name of the augmentation function, prob is the
+ # probability of applying the operation and level is the parameter
+ # associated with the tf op.
+
+ # tf_policies are functions that take in an image and return an augmented
+ # image.
+ tf_policies = []
+ for policy in self.policies:
+ tf_policy = []
+ # Link string name to the correct python function and make sure the
+ # correct argument is passed into that function.
+ for policy_info in policy:
+ policy_info = list(policy_info) + [
+ replace_value, self.cutout_const, self.translate_const
+ ]
+ tf_policy.append(_parse_policy_info(*policy_info))
+ # Now build the tf policy that will apply the augmentation procedure
+ # on the image.
+ def make_final_policy(tf_policy_):
+
+ def final_policy(image_):
+ for func, prob, args in tf_policy_:
+ image_ = _apply_func_with_prob(func, image_, args, prob)
+ return image_
+
+ return final_policy
+
+ tf_policies.append(make_final_policy(tf_policy))
+
+ image = select_and_apply_random_policy(tf_policies, image)
+ image = tf.cast(image, dtype=input_image_type)
+ return image
+
+ @staticmethod
+ def policy_v0():
+ """Autoaugment policy that was used in AutoAugment Paper.
+
+ Each tuple is an augmentation operation of the form
+ (operation, probability, magnitude). Each element in policy is a
+ sub-policy that will be applied sequentially on the image.
+
+ Returns:
+ the policy.
+ """
+
+ # TODO(dankondratyuk): tensorflow_addons defines custom ops, which
+ # for some reason are not included when building/linking. This results
+ # in the error "Op type not registered
+ # 'Addons>ImageProjectiveTransformV2' in binary" when running on borg TPUs.
+ policy = [
+ [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
+ [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
+ [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
+ [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
+ [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
+ [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
+ [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
+ [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
+ [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
+ [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
+ [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
+ [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
+ [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
+ [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
+ [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
+ [('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)],
+ [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
+ [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
+ [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
+ [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
+ [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
+ [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
+ [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
+ [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
+ [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
+ ]
+ return policy
+
+ @staticmethod
+ def policy_simple():
+ """Same as `policy_v0`, except with custom ops removed."""
+
+ policy = [
+ [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
+ [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
+ [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
+ [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
+ [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
+ [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
+ [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
+ [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
+ [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
+ [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
+ [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
+ [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
+ [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
+ ]
+ return policy
+
+ @staticmethod
+ def policy_test():
+ """Autoaugment test policy for debugging."""
+ policy = [
+ [('TranslateX', 1.0, 4), ('Equalize', 1.0, 10)],
+ ]
+ return policy
+
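+# Illustrative usage sketch (not part of the upstream module), assuming uint8
+# RGB images in a tf.data pipeline:
+#   augmenter = AutoAugment(augmentation_name='v0')
+#   dataset = dataset.map(lambda img, lbl: (augmenter.distort(img), lbl))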
+
+class RandAugment(ImageAugment):
+ """Applies the RandAugment policy to images.
+
+ RandAugment is from the paper https://arxiv.org/abs/1909.13719.
+ """
+
+ def __init__(self,
+ num_layers: int = 2,
+ magnitude: float = 10.,
+ cutout_const: float = 40.,
+ translate_const: float = 100.):
+ """Applies the RandAugment policy to images.
+
+ Args:
+ num_layers: Integer, the number of augmentation transformations to apply
+ sequentially to an image. Represented as (N) in the paper. Usually best
+ values will be in the range [1, 3].
+ magnitude: Float, shared magnitude across all augmentation operations.
+ Represented as (M) in the paper. Usually best values are in the range
+ [5, 10].
+ cutout_const: multiplier for applying cutout.
+ translate_const: multiplier for applying translation.
+ """
+ super(RandAugment, self).__init__()
+
+ self.num_layers = num_layers
+ self.magnitude = float(magnitude)
+ self.cutout_const = float(cutout_const)
+ self.translate_const = float(translate_const)
+ self.available_ops = [
+ 'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize', 'Solarize',
+ 'Color', 'Contrast', 'Brightness', 'Sharpness', 'ShearX', 'ShearY',
+ 'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd'
+ ]
+
+ def distort(self, image: tf.Tensor) -> tf.Tensor:
+ """Applies the RandAugment policy to `image`.
+
+ Args:
+ image: `Tensor` of shape [height, width, 3] representing an image.
+
+ Returns:
+ The augmented version of `image`.
+ """
+ input_image_type = image.dtype
+
+ if input_image_type != tf.uint8:
+ image = tf.clip_by_value(image, 0.0, 255.0)
+ image = tf.cast(image, dtype=tf.uint8)
+
+ replace_value = [128] * 3
+ min_prob, max_prob = 0.2, 0.8
+
+ for _ in range(self.num_layers):
+ op_to_select = tf.random.uniform(
+ [], maxval=len(self.available_ops) + 1, dtype=tf.int32)
+
+ branch_fns = []
+ for (i, op_name) in enumerate(self.available_ops):
+ prob = tf.random.uniform([],
+ minval=min_prob,
+ maxval=max_prob,
+ dtype=tf.float32)
+ func, _, args = _parse_policy_info(op_name,
+ prob,
+ self.magnitude,
+ replace_value,
+ self.cutout_const,
+ self.translate_const)
+ branch_fns.append((
+ i,
+ # pylint:disable=g-long-lambda
+ lambda selected_func=func, selected_args=args: selected_func(
+ image, *selected_args)))
+ # pylint:enable=g-long-lambda
+
+ image = tf.switch_case(branch_index=op_to_select,
+ branch_fns=branch_fns,
+ default=lambda: tf.identity(image))
+
+ image = tf.cast(image, dtype=input_image_type)
+ return image
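+
+# Illustrative usage sketch (not part of the upstream module):
+#   augmenter = RandAugment(num_layers=2, magnitude=10.)
+#   aug_image = augmenter.distort(tf.zeros((224, 224, 3), dtype=tf.uint8))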
diff --git a/models/official/vision/image_classification/augment_test.py b/models/official/vision/image_classification/augment_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..76bdb2b7b9db4fc109f39674c68ae0c1169f3f12
--- /dev/null
+++ b/models/official/vision/image_classification/augment_test.py
@@ -0,0 +1,143 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for autoaugment."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+import tensorflow as tf
+
+from official.vision.image_classification import augment
+
+
+def get_dtype_test_cases():
+ return [
+ ('uint8', tf.uint8),
+ ('int32', tf.int32),
+ ('float16', tf.float16),
+ ('float32', tf.float32),
+ ]
+
+
+@parameterized.named_parameters(get_dtype_test_cases())
+class TransformsTest(parameterized.TestCase, tf.test.TestCase):
+ """Basic tests for fundamental transformations."""
+
+ def test_to_from_4d(self, dtype):
+ for shape in [(10, 10), (10, 10, 10), (10, 10, 10, 10)]:
+ original_ndims = len(shape)
+ image = tf.zeros(shape, dtype=dtype)
+ image_4d = augment.to_4d(image)
+ self.assertEqual(4, tf.rank(image_4d))
+ self.assertAllEqual(image, augment.from_4d(image_4d, original_ndims))
+
+ def test_transform(self, dtype):
+ image = tf.constant([[1, 2], [3, 4]], dtype=dtype)
+ self.assertAllEqual(augment.transform(image, transforms=[1]*8),
+ [[4, 4], [4, 4]])
+
+ def test_translate(self, dtype):
+ image = tf.constant(
+ [[1, 0, 1, 0],
+ [0, 1, 0, 1],
+ [1, 0, 1, 0],
+ [0, 1, 0, 1]],
+ dtype=dtype)
+ translations = [-1, -1]
+ translated = augment.translate(image=image,
+ translations=translations)
+ expected = [
+ [1, 0, 1, 1],
+ [0, 1, 0, 0],
+ [1, 0, 1, 1],
+ [1, 0, 1, 1]]
+ self.assertAllEqual(translated, expected)
+
+ def test_translate_shapes(self, dtype):
+ translation = [0, 0]
+ for shape in [(3, 3), (5, 5), (224, 224, 3)]:
+ image = tf.zeros(shape, dtype=dtype)
+ self.assertAllEqual(image, augment.translate(image, translation))
+
+ def test_translate_invalid_translation(self, dtype):
+ image = tf.zeros((1, 1), dtype=dtype)
+ invalid_translation = [[[1, 1]]]
+ with self.assertRaisesRegex(TypeError, 'rank 1 or 2'):
+ _ = augment.translate(image, invalid_translation)
+
+ def test_rotate(self, dtype):
+ image = tf.reshape(tf.cast(tf.range(9), dtype), (3, 3))
+ rotation = 90.
+ transformed = augment.rotate(image=image, degrees=rotation)
+ expected = [[2, 5, 8],
+ [1, 4, 7],
+ [0, 3, 6]]
+ self.assertAllEqual(transformed, expected)
+
+ def test_rotate_shapes(self, dtype):
+ degrees = 0.
+ for shape in [(3, 3), (5, 5), (224, 224, 3)]:
+ image = tf.zeros(shape, dtype=dtype)
+ self.assertAllEqual(image, augment.rotate(image, degrees))
+
+
+class AutoaugmentTest(tf.test.TestCase):
+
+ def test_autoaugment(self):
+ """Smoke test to be sure there are no syntax errors."""
+ image = tf.zeros((224, 224, 3), dtype=tf.uint8)
+
+ augmenter = augment.AutoAugment()
+ aug_image = augmenter.distort(image)
+
+ self.assertEqual((224, 224, 3), aug_image.shape)
+
+ def test_randaug(self):
+ """Smoke test to be sure there are no syntax errors."""
+ image = tf.zeros((224, 224, 3), dtype=tf.uint8)
+
+ augmenter = augment.RandAugment()
+ aug_image = augmenter.distort(image)
+
+ self.assertEqual((224, 224, 3), aug_image.shape)
+
+ def test_all_policy_ops(self):
+ """Smoke test to be sure all augmentation functions can execute."""
+
+ prob = 1
+ magnitude = 10
+ replace_value = [128] * 3
+ cutout_const = 100
+ translate_const = 250
+
+ image = tf.ones((224, 224, 3), dtype=tf.uint8)
+
+ for op_name in augment.NAME_TO_FUNC:
+ func, _, args = augment._parse_policy_info(op_name,
+ prob,
+ magnitude,
+ replace_value,
+ cutout_const,
+ translate_const)
+ image = func(image, *args)
+
+ self.assertEqual((224, 224, 3), image.shape)
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/vision/image_classification/callbacks.py b/models/official/vision/image_classification/callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..985d0c60cc0b866e10ad350986c004e4ea4ac161
--- /dev/null
+++ b/models/official/vision/image_classification/callbacks.py
@@ -0,0 +1,258 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Common modules for callbacks."""
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import os
+from typing import Any, List, MutableMapping, Text
+from absl import logging
+import tensorflow as tf
+
+from official.utils.misc import keras_utils
+from official.vision.image_classification import optimizer_factory
+
+
+def get_callbacks(model_checkpoint: bool = True,
+ include_tensorboard: bool = True,
+ time_history: bool = True,
+ track_lr: bool = True,
+ write_model_weights: bool = True,
+ apply_moving_average: bool = False,
+ initial_step: int = 0,
+ batch_size: int = 0,
+ log_steps: int = 0,
+ model_dir: str = None) -> List[tf.keras.callbacks.Callback]:
+ """Get all callbacks."""
+ model_dir = model_dir or ''
+ callbacks = []
+ if model_checkpoint:
+ ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
+ callbacks.append(tf.keras.callbacks.ModelCheckpoint(
+ ckpt_full_path, save_weights_only=True, verbose=1))
+ if include_tensorboard:
+ callbacks.append(
+ CustomTensorBoard(
+ log_dir=model_dir,
+ track_lr=track_lr,
+ initial_step=initial_step,
+ write_images=write_model_weights))
+ if time_history:
+ callbacks.append(
+ keras_utils.TimeHistory(
+ batch_size,
+ log_steps,
+ logdir=model_dir if include_tensorboard else None))
+ if apply_moving_average:
+ # Save moving average model to a different file so that
+ # we can resume training from a checkpoint
+ ckpt_full_path = os.path.join(
+ model_dir, 'average', 'model.ckpt-{epoch:04d}')
+ callbacks.append(AverageModelCheckpoint(
+ update_weights=False,
+ filepath=ckpt_full_path,
+ save_weights_only=True,
+ verbose=1))
+ callbacks.append(MovingAverageCallback())
+ return callbacks
+
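+# Illustrative usage sketch (not part of the upstream module); the argument
+# values and model_dir path here are hypothetical:
+#   callbacks = get_callbacks(model_checkpoint=True,
+#                             include_tensorboard=True,
+#                             time_history=True,
+#                             track_lr=True,
+#                             write_model_weights=False,
+#                             batch_size=128,
+#                             log_steps=100,
+#                             model_dir='/tmp/model_dir')
+#   model.fit(train_dataset, epochs=10, callbacks=callbacks)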
+
+def get_scalar_from_tensor(t: tf.Tensor) -> int:
+ """Utility function to convert a Tensor to a scalar."""
+ t = tf.keras.backend.get_value(t)
+ if callable(t):
+ return t()
+ else:
+ return t
+
+
+class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
+ """A customized TensorBoard callback that tracks additional datapoints.
+
+ Metrics tracked:
+ - Global learning rate
+
+ Attributes:
+ log_dir: the path of the directory where to save the log files to be parsed
+ by TensorBoard.
+ track_lr: `bool`, whether or not to track the global learning rate.
+ initial_step: the initial step, used for preemption recovery.
+ **kwargs: Additional arguments for backwards compatibility. Possible key is
+ `period`.
+ """
+
+ # TODO(b/146499062): track params, flops, log lr, l2 loss,
+ # classification loss
+
+ def __init__(self,
+ log_dir: str,
+ track_lr: bool = False,
+ initial_step: int = 0,
+ **kwargs):
+ super(CustomTensorBoard, self).__init__(log_dir=log_dir, **kwargs)
+ self.step = initial_step
+ self._track_lr = track_lr
+
+ def on_batch_begin(self,
+ epoch: int,
+ logs: MutableMapping[str, Any] = None) -> None:
+ self.step += 1
+ if logs is None:
+ logs = {}
+ logs.update(self._calculate_metrics())
+ super(CustomTensorBoard, self).on_batch_begin(epoch, logs)
+
+ def on_epoch_begin(self,
+ epoch: int,
+ logs: MutableMapping[str, Any] = None) -> None:
+ if logs is None:
+ logs = {}
+ metrics = self._calculate_metrics()
+ logs.update(metrics)
+ for k, v in metrics.items():
+ logging.info('Current %s: %f', k, v)
+ super(CustomTensorBoard, self).on_epoch_begin(epoch, logs)
+
+ def on_epoch_end(self,
+ epoch: int,
+ logs: MutableMapping[str, Any] = None) -> None:
+ if logs is None:
+ logs = {}
+ metrics = self._calculate_metrics()
+ logs.update(metrics)
+ super(CustomTensorBoard, self).on_epoch_end(epoch, logs)
+
+ def _calculate_metrics(self) -> MutableMapping[str, Any]:
+ logs = {}
+ # TODO(b/149030439): disable LR reporting.
+ # if self._track_lr:
+ # logs['learning_rate'] = self._calculate_lr()
+ return logs
+
+ def _calculate_lr(self) -> int:
+ """Calculates the learning rate given the current step."""
+ return get_scalar_from_tensor(
+ self._get_base_optimizer()._decayed_lr(var_dtype=tf.float32)) # pylint:disable=protected-access
+
+ def _get_base_optimizer(self) -> tf.keras.optimizers.Optimizer:
+ """Get the base optimizer used by the current model."""
+
+ optimizer = self.model.optimizer
+
+ # The optimizer might be wrapped by another class, so unwrap it
+ while hasattr(optimizer, '_optimizer'):
+ optimizer = optimizer._optimizer # pylint:disable=protected-access
+
+ return optimizer
+
+
+class MovingAverageCallback(tf.keras.callbacks.Callback):
+ """A Callback to be used with a `MovingAverage` optimizer.
+
+ Applies moving average weights to the model during validation time to test
+ and predict on the averaged weights rather than the current model weights.
+ Once training is complete, the model weights will be overwritten with the
+ averaged weights (by default).
+
+ Attributes:
+ overwrite_weights_on_train_end: Whether to overwrite the current model
+ weights with the averaged weights from the moving average optimizer.
+ **kwargs: Any additional callback arguments.
+ """
+
+ def __init__(self,
+ overwrite_weights_on_train_end: bool = False,
+ **kwargs):
+ super(MovingAverageCallback, self).__init__(**kwargs)
+ self.overwrite_weights_on_train_end = overwrite_weights_on_train_end
+
+ def set_model(self, model: tf.keras.Model):
+ super(MovingAverageCallback, self).set_model(model)
+ assert isinstance(self.model.optimizer,
+ optimizer_factory.MovingAverage)
+ self.model.optimizer.shadow_copy(self.model)
+
+ def on_test_begin(self, logs: MutableMapping[Text, Any] = None):
+ self.model.optimizer.swap_weights()
+
+ def on_test_end(self, logs: MutableMapping[Text, Any] = None):
+ self.model.optimizer.swap_weights()
+
+ def on_train_end(self, logs: MutableMapping[Text, Any] = None):
+ if self.overwrite_weights_on_train_end:
+ self.model.optimizer.assign_average_vars(self.model.variables)
+
+
+class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
+ """Saves and, optionally, assigns the averaged weights.
+
+ Taken from tfa.callbacks.AverageModelCheckpoint.
+
+ Attributes:
+ update_weights: If True, assign the moving average weights
+ to the model, and save them. If False, keep the old
+ non-averaged weights, but the saved model uses the
+ average weights.
+ See `tf.keras.callbacks.ModelCheckpoint` for the other args.
+ """
+
+ def __init__(
+ self,
+ update_weights: bool,
+ filepath: str,
+ monitor: str = 'val_loss',
+ verbose: int = 0,
+ save_best_only: bool = False,
+ save_weights_only: bool = False,
+ mode: str = 'auto',
+ save_freq: str = 'epoch',
+ **kwargs):
+ self.update_weights = update_weights
+ super().__init__(
+ filepath,
+ monitor,
+ verbose,
+ save_best_only,
+ save_weights_only,
+ mode,
+ save_freq,
+ **kwargs)
+
+ def set_model(self, model):
+ if not isinstance(model.optimizer, optimizer_factory.MovingAverage):
+ raise TypeError(
+ 'AverageModelCheckpoint is only used when training '
+ 'with MovingAverage')
+ return super().set_model(model)
+
+ def _save_model(self, epoch, logs):
+ assert isinstance(self.model.optimizer, optimizer_factory.MovingAverage)
+
+ if self.update_weights:
+ self.model.optimizer.assign_average_vars(self.model.variables)
+ return super()._save_model(epoch, logs)
+ else:
+ # Note: `model.get_weights()` gives us the weights (non-ref)
+ # whereas `model.variables` returns references to the variables.
+ non_avg_weights = self.model.get_weights()
+ self.model.optimizer.assign_average_vars(self.model.variables)
+ # result is currently None, since `super._save_model` doesn't
+ # return anything, but this may change in the future.
+ result = super()._save_model(epoch, logs)
+ self.model.set_weights(non_avg_weights)
+ return result
diff --git a/models/official/vision/image_classification/classifier_trainer.py b/models/official/vision/image_classification/classifier_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e5ea468c9a0895658c28e89b8537e0056148fa0
--- /dev/null
+++ b/models/official/vision/image_classification/classifier_trainer.py
@@ -0,0 +1,456 @@
+# Lint as: python3
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Runs an Image Classification model."""
+
+import os
+import pprint
+from typing import Any, Tuple, Text, Optional, Mapping
+
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf
+
+from official.modeling import hyperparams
+from official.modeling import performance
+from official.utils import hyperparams_flags
+from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
+from official.vision.image_classification import callbacks as custom_callbacks
+from official.vision.image_classification import dataset_factory
+from official.vision.image_classification import optimizer_factory
+from official.vision.image_classification.configs import base_configs
+from official.vision.image_classification.configs import configs
+from official.vision.image_classification.efficientnet import efficientnet_model
+from official.vision.image_classification.resnet import common
+from official.vision.image_classification.resnet import resnet_model
+
+
+def get_models() -> Mapping[str, tf.keras.Model]:
+ """Returns the mapping from model type name to Keras model."""
+ return {
+ 'efficientnet': efficientnet_model.EfficientNet.from_name,
+ 'resnet': resnet_model.resnet50,
+ }
+
+
+def get_dtype_map() -> Mapping[str, tf.dtypes.DType]:
+ """Returns the mapping from dtype string representations to TF dtypes."""
+ return {
+ 'float32': tf.float32,
+ 'bfloat16': tf.bfloat16,
+ 'float16': tf.float16,
+ 'fp32': tf.float32,
+ 'bf16': tf.bfloat16,
+ }
+
+
+def _get_metrics(one_hot: bool) -> Mapping[Text, Any]:
+ """Get a dict of available metrics to track."""
+ if one_hot:
+ return {
+ # (name, metric_fn)
+ 'acc': tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
+ 'accuracy': tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
+ 'top_1': tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
+ 'top_5': tf.keras.metrics.TopKCategoricalAccuracy(
+ k=5,
+ name='top_5_accuracy'),
+ }
+ else:
+ return {
+ # (name, metric_fn)
+ 'acc': tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
+ 'accuracy': tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
+ 'top_1': tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
+ 'top_5': tf.keras.metrics.SparseTopKCategoricalAccuracy(
+ k=5,
+ name='top_5_accuracy'),
+ }
+
+
+def get_image_size_from_model(
+ params: base_configs.ExperimentConfig) -> Optional[int]:
+ """If the given model has a preferred image size, return it."""
+ if params.model_name == 'efficientnet':
+ efficientnet_name = params.model.model_params.model_name
+ if efficientnet_name in efficientnet_model.MODEL_CONFIGS:
+ return efficientnet_model.MODEL_CONFIGS[efficientnet_name].resolution
+ return None
+
+
+def _get_dataset_builders(params: base_configs.ExperimentConfig,
+ strategy: tf.distribute.Strategy,
+ one_hot: bool
+ ) -> Tuple[Any, Any]:
+ """Create and return train and validation dataset builders."""
+ if one_hot:
+ logging.warning('label_smoothing > 0, so datasets will be one hot encoded.')
+ else:
+ logging.warning('label_smoothing not applied, so datasets will not be one '
+ 'hot encoded.')
+
+ num_devices = strategy.num_replicas_in_sync if strategy else 1
+
+ image_size = get_image_size_from_model(params)
+
+ dataset_configs = [
+ params.train_dataset, params.validation_dataset
+ ]
+ builders = []
+
+ for config in dataset_configs:
+ if config is not None and config.has_data:
+ builder = dataset_factory.DatasetBuilder(
+ config,
+ image_size=image_size or config.image_size,
+ num_devices=num_devices,
+ one_hot=one_hot)
+ else:
+ builder = None
+ builders.append(builder)
+
+ return builders
+
+
+def get_loss_scale(params: base_configs.ExperimentConfig,
+ fp16_default: float = 128.) -> float:
+ """Returns the loss scale for initializations."""
+ loss_scale = params.runtime.loss_scale
+ if loss_scale == 'dynamic':
+ return loss_scale
+ elif loss_scale is not None:
+ return float(loss_scale)
+ elif (params.train_dataset.dtype == 'float32' or
+ params.train_dataset.dtype == 'bfloat16'):
+ return 1.
+ else:
+ assert params.train_dataset.dtype == 'float16'
+ return fp16_default
+
+
+def _get_params_from_flags(flags_obj: flags.FlagValues):
+ """Get ParamsDict from flags."""
+ model = flags_obj.model_type.lower()
+ dataset = flags_obj.dataset.lower()
+ params = configs.get_config(model=model, dataset=dataset)
+
+ flags_overrides = {
+ 'model_dir': flags_obj.model_dir,
+ 'mode': flags_obj.mode,
+ 'model': {
+ 'name': model,
+ },
+ 'runtime': {
+ 'run_eagerly': flags_obj.run_eagerly,
+ 'tpu': flags_obj.tpu,
+ },
+ 'train_dataset': {
+ 'data_dir': flags_obj.data_dir,
+ },
+ 'validation_dataset': {
+ 'data_dir': flags_obj.data_dir,
+ },
+ 'train': {
+ 'time_history': {
+ 'log_steps': flags_obj.log_steps,
+ },
+ },
+ }
+
+ overriding_configs = (flags_obj.config_file,
+ flags_obj.params_override,
+ flags_overrides)
+
+ pp = pprint.PrettyPrinter()
+
+ logging.info('Base params: %s', pp.pformat(params.as_dict()))
+
+ for param in overriding_configs:
+ logging.info('Overriding params: %s', param)
+ params = hyperparams.override_params_dict(params, param, is_strict=True)
+
+ params.validate()
+ params.lock()
+
+ logging.info('Final model parameters: %s', pp.pformat(params.as_dict()))
+ return params
+
+
+def resume_from_checkpoint(model: tf.keras.Model,
+ model_dir: str,
+ train_steps: int) -> int:
+ """Resumes from the latest checkpoint, if possible.
+
+ Loads the model weights and optimizer settings from a checkpoint.
+ This function should be used in case of preemption recovery.
+
+ Args:
+ model: The model whose weights should be restored.
+ model_dir: The directory where model weights were saved.
+ train_steps: The number of steps to train.
+
+ Returns:
+ The epoch of the latest checkpoint, or 0 if not restoring.
+
+ """
+ logging.info('Load from checkpoint is enabled.')
+ latest_checkpoint = tf.train.latest_checkpoint(model_dir)
+ logging.info('latest_checkpoint: %s', latest_checkpoint)
+ if not latest_checkpoint:
+ logging.info('No checkpoint detected.')
+ return 0
+
+ logging.info('Checkpoint file %s found and restoring from '
+ 'checkpoint', latest_checkpoint)
+ model.load_weights(latest_checkpoint)
+ initial_epoch = model.optimizer.iterations // train_steps
+ logging.info('Completed loading from checkpoint.')
+ logging.info('Resuming from epoch %d', initial_epoch)
+ return int(initial_epoch)
+
+
+def initialize(params: base_configs.ExperimentConfig,
+ dataset_builder: dataset_factory.DatasetBuilder):
+ """Initializes backend related initializations."""
+ keras_utils.set_session_config(
+ enable_xla=params.runtime.enable_xla)
+ performance.set_mixed_precision_policy(dataset_builder.dtype,
+ get_loss_scale(params))
+ if tf.config.list_physical_devices('GPU'):
+ data_format = 'channels_first'
+ else:
+ data_format = 'channels_last'
+ tf.keras.backend.set_image_data_format(data_format)
+ if params.runtime.run_eagerly:
+ # Enable eager execution to allow step-by-step debugging
+ tf.config.experimental_run_functions_eagerly(True)
+ if tf.config.list_physical_devices('GPU'):
+ if params.runtime.gpu_thread_mode:
+ keras_utils.set_gpu_thread_mode_and_count(
+ per_gpu_thread_count=params.runtime.per_gpu_thread_count,
+ gpu_thread_mode=params.runtime.gpu_thread_mode,
+ num_gpus=params.runtime.num_gpus,
+ datasets_num_private_threads=params.runtime.dataset_num_private_threads) # pylint:disable=line-too-long
+ if params.runtime.batchnorm_spatial_persistent:
+ os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
+
+
+def define_classifier_flags():
+ """Defines common flags for image classification."""
+ hyperparams_flags.initialize_common_flags()
+ flags.DEFINE_string(
+ 'data_dir',
+ default=None,
+ help='The location of the input data.')
+ flags.DEFINE_string(
+ 'mode',
+ default=None,
+ help='Mode to run: `train_and_eval` or `export_only`.')
+ flags.DEFINE_bool(
+ 'run_eagerly',
+ default=None,
+ help='Use eager execution and disable autograph for debugging.')
+ flags.DEFINE_string(
+ 'model_type',
+ default=None,
+ help='The type of the model, e.g. EfficientNet, etc.')
+ flags.DEFINE_string(
+ 'dataset',
+ default=None,
+ help='The name of the dataset, e.g. ImageNet, etc.')
+ flags.DEFINE_integer(
+ 'log_steps',
+ default=100,
+ help='The interval of steps between logging of batch level stats.')
+
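+# Illustrative invocation sketch (not part of the upstream module); the paths
+# and config file are hypothetical:
+#   python classifier_trainer.py --mode=train_and_eval --model_type=resnet \
+#     --dataset=imagenet --data_dir=/data/imagenet --model_dir=/tmp/run \
+#     --config_file=my_overrides.yaml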
+
+def serialize_config(params: base_configs.ExperimentConfig,
+ model_dir: str):
+ """Serializes and saves the experiment config."""
+ params_save_path = os.path.join(model_dir, 'params.yaml')
+ logging.info('Saving experiment configuration to %s', params_save_path)
+ tf.io.gfile.makedirs(model_dir)
+ hyperparams.save_params_dict_to_yaml(params, params_save_path)
+
+
+def train_and_eval(
+ params: base_configs.ExperimentConfig,
+ strategy_override: tf.distribute.Strategy) -> Mapping[str, Any]:
+ """Runs the train and eval path using compile/fit."""
+ logging.info('Running train and eval.')
+
+ distribution_utils.configure_cluster(
+ params.runtime.worker_hosts,
+ params.runtime.task_index)
+
+ # Note: for TPUs, strategy and scope should be created before the dataset
+ strategy = strategy_override or distribution_utils.get_distribution_strategy(
+ distribution_strategy=params.runtime.distribution_strategy,
+ all_reduce_alg=params.runtime.all_reduce_alg,
+ num_gpus=params.runtime.num_gpus,
+ tpu_address=params.runtime.tpu)
+
+ strategy_scope = distribution_utils.get_strategy_scope(strategy)
+
+ logging.info('Detected %d devices.',
+ strategy.num_replicas_in_sync if strategy else 1)
+
+ label_smoothing = params.model.loss.label_smoothing
+ one_hot = label_smoothing and label_smoothing > 0
+
+ builders = _get_dataset_builders(params, strategy, one_hot)
+ datasets = [builder.build(strategy)
+ if builder else None for builder in builders]
+
+ # Unpack datasets and builders based on train/val/test splits
+ train_builder, validation_builder = builders # pylint: disable=unbalanced-tuple-unpacking
+ train_dataset, validation_dataset = datasets
+
+ train_epochs = params.train.epochs
+ train_steps = params.train.steps or train_builder.num_steps
+ validation_steps = params.evaluation.steps or validation_builder.num_steps
+
+ initialize(params, train_builder)
+
+ logging.info('Global batch size: %d', train_builder.global_batch_size)
+
+ with strategy_scope:
+ model_params = params.model.model_params.as_dict()
+ model = get_models()[params.model.name](**model_params)
+ learning_rate = optimizer_factory.build_learning_rate(
+ params=params.model.learning_rate,
+ batch_size=train_builder.global_batch_size,
+ train_epochs=train_epochs,
+ train_steps=train_steps)
+ optimizer = optimizer_factory.build_optimizer(
+ optimizer_name=params.model.optimizer.name,
+ base_learning_rate=learning_rate,
+ params=params.model.optimizer.as_dict())
+
+ metrics_map = _get_metrics(one_hot)
+ metrics = [metrics_map[metric] for metric in params.train.metrics]
+ steps_per_loop = train_steps if params.train.set_epoch_loop else 1
+
+ if one_hot:
+ loss_obj = tf.keras.losses.CategoricalCrossentropy(
+ label_smoothing=params.model.loss.label_smoothing)
+ else:
+ loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
+ model.compile(optimizer=optimizer,
+ loss=loss_obj,
+ metrics=metrics,
+ experimental_steps_per_execution=steps_per_loop)
+
+ initial_epoch = 0
+ if params.train.resume_checkpoint:
+ initial_epoch = resume_from_checkpoint(model=model,
+ model_dir=params.model_dir,
+ train_steps=train_steps)
+
+ callbacks = custom_callbacks.get_callbacks(
+ model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
+ include_tensorboard=params.train.callbacks.enable_tensorboard,
+ time_history=params.train.callbacks.enable_time_history,
+ track_lr=params.train.tensorboard.track_lr,
+ write_model_weights=params.train.tensorboard.write_model_weights,
+ initial_step=initial_epoch * train_steps,
+ batch_size=train_builder.global_batch_size,
+ log_steps=params.train.time_history.log_steps,
+ model_dir=params.model_dir)
+
+ serialize_config(params=params, model_dir=params.model_dir)
+
+ if params.evaluation.skip_eval:
+ validation_kwargs = {}
+ else:
+ validation_kwargs = {
+ 'validation_data': validation_dataset,
+ 'validation_steps': validation_steps,
+ 'validation_freq': params.evaluation.epochs_between_evals,
+ }
+
+ history = model.fit(
+ train_dataset,
+ epochs=train_epochs,
+ steps_per_epoch=train_steps,
+ initial_epoch=initial_epoch,
+ callbacks=callbacks,
+ verbose=2,
+ **validation_kwargs)
+
+ validation_output = None
+ if not params.evaluation.skip_eval:
+ validation_output = model.evaluate(
+ validation_dataset, steps=validation_steps, verbose=2)
+
+ # TODO(dankondratyuk): eval and save final test accuracy
+ stats = common.build_stats(history,
+ validation_output,
+ callbacks)
+ return stats
+
+
+def export(params: base_configs.ExperimentConfig):
+ """Runs the model export functionality."""
+ logging.info('Exporting model.')
+ model_params = params.model.model_params.as_dict()
+ model = get_models()[params.model.name](**model_params)
+ checkpoint = params.export.checkpoint
+ if checkpoint is None:
+ logging.info('No export checkpoint was provided. Using the latest '
+ 'checkpoint from model_dir.')
+ checkpoint = tf.train.latest_checkpoint(params.model_dir)
+
+ model.load_weights(checkpoint)
+ model.save(params.export.destination)
+
+
+def run(flags_obj: flags.FlagValues,
+ strategy_override: tf.distribute.Strategy = None) -> Mapping[str, Any]:
+ """Runs Image Classification model using native Keras APIs.
+
+ Args:
+ flags_obj: An object containing parsed flag values.
+ strategy_override: A `tf.distribute.Strategy` object to use for model.
+
+ Returns:
+ Dictionary of training/eval stats
+ """
+ params = _get_params_from_flags(flags_obj)
+ if params.mode == 'train_and_eval':
+ return train_and_eval(params, strategy_override)
+ elif params.mode == 'export_only':
+ export(params)
+ else:
+ raise ValueError('{} is not a valid mode.'.format(params.mode))
+
+
+def main(_):
+ stats = run(flags.FLAGS)
+ if stats:
+ logging.info('Run stats:\n%s', stats)
+
+
+if __name__ == '__main__':
+ logging.set_verbosity(logging.INFO)
+ define_classifier_flags()
+ flags.mark_flag_as_required('data_dir')
+ flags.mark_flag_as_required('mode')
+ flags.mark_flag_as_required('model_type')
+ flags.mark_flag_as_required('dataset')
+
+ app.run(main)
diff --git a/models/official/vision/image_classification/classifier_trainer_test.py b/models/official/vision/image_classification/classifier_trainer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..244425feef76bf89d4de939cb8a1914a6f0f47c6
--- /dev/null
+++ b/models/official/vision/image_classification/classifier_trainer_test.py
@@ -0,0 +1,387 @@
+# Lint as: python3
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unit tests for the classifier trainer models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import functools
+import json
+
+import os
+import sys
+
+from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional, Tuple
+
+from absl import flags
+from absl.testing import parameterized
+import tensorflow as tf
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from official.utils.flags import core as flags_core
+from official.vision.image_classification import classifier_trainer
+from official.vision.image_classification import dataset_factory
+from official.vision.image_classification import test_utils
+from official.vision.image_classification.configs import base_configs
+
+classifier_trainer.define_classifier_flags()
+
+
+def distribution_strategy_combinations() -> Iterable[Tuple[Any, ...]]:
+ """Returns the combinations of end-to-end tests to run."""
+ return combinations.combine(
+ distribution=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ strategy_combinations.mirrored_strategy_with_two_gpus,
+ ],
+ model=[
+ 'efficientnet',
+ 'resnet',
+ ],
+ mode='eager',
+ dataset=[
+ 'imagenet',
+ ],
+ )
+
+
+def get_params_override(params_override: Mapping[str, Any]) -> str:
+ """Converts params_override dict to string command."""
+ return '--params_override=' + json.dumps(params_override)
+
+
+def basic_params_override(dtype: str = 'float32') -> MutableMapping[str, Any]:
+ """Returns a basic parameter configuration for testing."""
+ return {
+ 'train_dataset': {
+ 'builder': 'synthetic',
+ 'use_per_replica_batch_size': True,
+ 'batch_size': 1,
+ 'image_size': 224,
+ 'dtype': dtype,
+ },
+ 'validation_dataset': {
+ 'builder': 'synthetic',
+ 'batch_size': 1,
+ 'use_per_replica_batch_size': True,
+ 'image_size': 224,
+ 'dtype': dtype,
+ },
+ 'train': {
+ 'steps': 1,
+ 'epochs': 1,
+ 'callbacks': {
+ 'enable_checkpoint_and_export': True,
+ 'enable_tensorboard': False,
+ },
+ },
+ 'evaluation': {
+ 'steps': 1,
+ },
+ }
+
+
+def get_trivial_model(num_classes: int) -> tf.keras.Model:
+ """Creates and compiles trivial model for ImageNet dataset."""
+ model = test_utils.trivial_model(num_classes=num_classes)
+ lr = 0.01
+ optimizer = tf.keras.optimizers.SGD(learning_rate=lr)
+ loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
+ model.compile(optimizer=optimizer,
+ loss=loss_obj,
+ run_eagerly=True)
+ return model
+
+
+def get_trivial_data() -> tf.data.Dataset:
+ """Gets trivial data in the ImageNet size."""
+ def generate_data(_) -> tf.data.Dataset:
+ image = tf.zeros(shape=(224, 224, 3), dtype=tf.float32)
+ label = tf.zeros([1], dtype=tf.int32)
+ return image, label
+
+ dataset = tf.data.Dataset.range(1)
+ dataset = dataset.repeat()
+ dataset = dataset.map(generate_data,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ dataset = dataset.prefetch(buffer_size=1).batch(1)
+ return dataset
+
+
+def run_end_to_end(main: Callable[[Any], None],
+ extra_flags: Optional[Iterable[str]] = None,
+ model_dir: Optional[str] = None):
+ """Runs the classifier trainer end-to-end."""
+ extra_flags = [] if extra_flags is None else extra_flags
+ args = [sys.argv[0], '--model_dir', model_dir] + extra_flags
+ flags_core.parse_flags(argv=args)
+ main(flags.FLAGS)
+
+
+class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
+ """Unit tests for Keras models."""
+ _tempdir = None
+
+ @classmethod
+ def setUpClass(cls): # pylint: disable=invalid-name
+ super(ClassifierTest, cls).setUpClass()
+
+ def tearDown(self):
+ super(ClassifierTest, self).tearDown()
+ tf.io.gfile.rmtree(self.get_temp_dir())
+
+ @combinations.generate(distribution_strategy_combinations())
+ def test_end_to_end_train_and_eval(self, distribution, model, dataset):
+ """Test train_and_eval and export for Keras classifier models."""
+ # Some parameters are not defined as flags (e.g. cannot run
+ # classifier_train.py --batch_size=...) by design, so use
+ # "--params_override=..." instead
+ model_dir = self.get_temp_dir()
+ base_flags = [
+ '--data_dir=not_used',
+ '--model_type=' + model,
+ '--dataset=' + dataset,
+ ]
+ train_and_eval_flags = base_flags + [
+ get_params_override(basic_params_override()),
+ '--mode=train_and_eval',
+ ]
+
+ run = functools.partial(classifier_trainer.run,
+ strategy_override=distribution)
+ run_end_to_end(main=run,
+ extra_flags=train_and_eval_flags,
+ model_dir=model_dir)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.one_device_strategy_gpu,
+ ],
+ model=[
+ 'efficientnet',
+ 'resnet',
+ ],
+ mode='eager',
+ dataset='imagenet',
+ dtype='float16',
+ ))
+ def test_gpu_train(self, distribution, model, dataset, dtype):
+ """Test train_and_eval and export for Keras classifier models."""
+ # Some parameters are not defined as flags (e.g. cannot run
+ # classifier_train.py --batch_size=...) by design, so use
+ # "--params_override=..." instead
+ model_dir = self.get_temp_dir()
+ base_flags = [
+ '--data_dir=not_used',
+ '--model_type=' + model,
+ '--dataset=' + dataset,
+ ]
+ train_and_eval_flags = base_flags + [
+ get_params_override(basic_params_override(dtype)),
+ '--mode=train_and_eval',
+ ]
+
+ export_params = basic_params_override()
+ export_path = os.path.join(model_dir, 'export')
+ export_params['export'] = {}
+ export_params['export']['destination'] = export_path
+ export_flags = base_flags + [
+ '--mode=export_only',
+ get_params_override(export_params)
+ ]
+
+ run = functools.partial(classifier_trainer.run,
+ strategy_override=distribution)
+ run_end_to_end(main=run,
+ extra_flags=train_and_eval_flags,
+ model_dir=model_dir)
+ run_end_to_end(main=run,
+ extra_flags=export_flags,
+ model_dir=model_dir)
+ self.assertTrue(os.path.exists(export_path))
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.tpu_strategy,
+ ],
+ model=[
+ 'efficientnet',
+ 'resnet',
+ ],
+ mode='eager',
+ dataset='imagenet',
+ dtype='bfloat16',
+ ))
+ def test_tpu_train(self, distribution, model, dataset, dtype):
+ """Test train_and_eval and export for Keras classifier models."""
+ # Some parameters are not defined as flags (e.g. cannot run
+ # classifier_train.py --batch_size=...) by design, so use
+ # "--params_override=..." instead
+ model_dir = self.get_temp_dir()
+ base_flags = [
+ '--data_dir=not_used',
+ '--model_type=' + model,
+ '--dataset=' + dataset,
+ ]
+ train_and_eval_flags = base_flags + [
+ get_params_override(basic_params_override(dtype)),
+ '--mode=train_and_eval',
+ ]
+
+ run = functools.partial(classifier_trainer.run,
+ strategy_override=distribution)
+ run_end_to_end(main=run,
+ extra_flags=train_and_eval_flags,
+ model_dir=model_dir)
+
+ @combinations.generate(distribution_strategy_combinations())
+ def test_end_to_end_invalid_mode(self, distribution, model, dataset):
+ """Test the Keras EfficientNet model with `strategy`."""
+ model_dir = self.get_temp_dir()
+ extra_flags = [
+ '--data_dir=not_used',
+ '--mode=invalid_mode',
+ '--model_type=' + model,
+ '--dataset=' + dataset,
+ get_params_override(basic_params_override()),
+ ]
+
+ run = functools.partial(classifier_trainer.run,
+ strategy_override=distribution)
+ with self.assertRaises(ValueError):
+ run_end_to_end(main=run, extra_flags=extra_flags, model_dir=model_dir)
+
+
+class UtilTests(parameterized.TestCase, tf.test.TestCase):
+ """Tests for individual utility functions within classifier_trainer.py."""
+
+ @parameterized.named_parameters(
+ ('efficientnet-b0', 'efficientnet', 'efficientnet-b0', 224),
+ ('efficientnet-b1', 'efficientnet', 'efficientnet-b1', 240),
+ ('efficientnet-b2', 'efficientnet', 'efficientnet-b2', 260),
+ ('efficientnet-b3', 'efficientnet', 'efficientnet-b3', 300),
+ ('efficientnet-b4', 'efficientnet', 'efficientnet-b4', 380),
+ ('efficientnet-b5', 'efficientnet', 'efficientnet-b5', 456),
+ ('efficientnet-b6', 'efficientnet', 'efficientnet-b6', 528),
+ ('efficientnet-b7', 'efficientnet', 'efficientnet-b7', 600),
+ ('resnet', 'resnet', '', None),
+ )
+ def test_get_model_size(self, model, model_name, expected):
+ config = base_configs.ExperimentConfig(
+ model_name=model,
+ model=base_configs.ModelConfig(
+ model_params={
+ 'model_name': model_name,
+ },
+ )
+ )
+ size = classifier_trainer.get_image_size_from_model(config)
+ self.assertEqual(size, expected)
+
+ @parameterized.named_parameters(
+ ('dynamic', 'dynamic', None, 'dynamic'),
+ ('scalar', 128., None, 128.),
+ ('float32', None, 'float32', 1),
+ ('float16', None, 'float16', 128),
+ )
+ def test_get_loss_scale(self, loss_scale, dtype, expected):
+ config = base_configs.ExperimentConfig(
+ runtime=base_configs.RuntimeConfig(
+ loss_scale=loss_scale),
+ train_dataset=dataset_factory.DatasetConfig(dtype=dtype))
+ ls = classifier_trainer.get_loss_scale(config, fp16_default=128)
+ self.assertEqual(ls, expected)
+
+ @parameterized.named_parameters(
+ ('float16', 'float16'),
+ ('bfloat16', 'bfloat16')
+ )
+ def test_initialize(self, dtype):
+ config = base_configs.ExperimentConfig(
+ runtime=base_configs.RuntimeConfig(
+ run_eagerly=False,
+ enable_xla=False,
+ per_gpu_thread_count=1,
+ gpu_thread_mode='gpu_private',
+ num_gpus=1,
+ dataset_num_private_threads=1,
+ ),
+ train_dataset=dataset_factory.DatasetConfig(dtype=dtype),
+ model=base_configs.ModelConfig(),
+ )
+
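+    # A bare object stands in for a DatasetBuilder below; the test only sets
+    # the attributes that initialize() is expected to read (dtype and config).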
+ class EmptyClass:
+ pass
+ fake_ds_builder = EmptyClass()
+ fake_ds_builder.dtype = dtype
+ fake_ds_builder.config = EmptyClass()
+ classifier_trainer.initialize(config, fake_ds_builder)
+
+ def test_resume_from_checkpoint(self):
+ """Tests functionality for resuming from checkpoint."""
+ # Set the keras policy
+ policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
+ tf.keras.mixed_precision.experimental.set_policy(policy)
+
+ # Get the model, datasets, and compile it.
+ model = get_trivial_model(10)
+
+ # Create the checkpoint
+ model_dir = self.get_temp_dir()
+ train_epochs = 1
+ train_steps = 10
+ ds = get_trivial_data()
+ callbacks = [
+ tf.keras.callbacks.ModelCheckpoint(
+ os.path.join(model_dir, 'model.ckpt-{epoch:04d}'),
+ save_weights_only=True)
+ ]
+ model.fit(
+ ds,
+ callbacks=callbacks,
+ epochs=train_epochs,
+ steps_per_epoch=train_steps)
+
+ # Test load from checkpoint
+ clean_model = get_trivial_model(10)
+ weights_before_load = copy.deepcopy(clean_model.get_weights())
+ initial_epoch = classifier_trainer.resume_from_checkpoint(
+ model=clean_model,
+ model_dir=model_dir,
+ train_steps=train_steps)
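+    # Exactly one epoch was trained above, so resuming should report that
+    # training left off after epoch 1.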
+ self.assertEqual(initial_epoch, 1)
+ self.assertNotAllClose(weights_before_load, clean_model.get_weights())
+
+ tf.io.gfile.rmtree(model_dir)
+
+ def test_serialize_config(self):
+ """Tests functionality for serializing data."""
+ config = base_configs.ExperimentConfig()
+ model_dir = self.get_temp_dir()
+ classifier_trainer.serialize_config(params=config, model_dir=model_dir)
+ saved_params_path = os.path.join(model_dir, 'params.yaml')
+ self.assertTrue(os.path.exists(saved_params_path))
+ tf.io.gfile.rmtree(model_dir)
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/vision/image_classification/configs/__init__.py b/models/official/vision/image_classification/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..931c2ef11db4a949e6c2e95bca44e36bac1241e9
--- /dev/null
+++ b/models/official/vision/image_classification/configs/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/models/official/vision/image_classification/configs/base_configs.py b/models/official/vision/image_classification/configs/base_configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..11fcb5305660ec71153ebfc12631f455a3464115
--- /dev/null
+++ b/models/official/vision/image_classification/configs/base_configs.py
@@ -0,0 +1,231 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Definitions for high level configuration groups.."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from typing import Any, List, Mapping, Optional
+
+import dataclasses
+
+from official.modeling import hyperparams
+from official.modeling.hyperparams import config_definitions
+
+CallbacksConfig = config_definitions.CallbacksConfig
+TensorboardConfig = config_definitions.TensorboardConfig
+RuntimeConfig = config_definitions.RuntimeConfig
+
+
+@dataclasses.dataclass
+class ExportConfig(hyperparams.Config):
+ """Configuration for exports.
+
+ Attributes:
+ checkpoint: the path to the checkpoint to export.
+ destination: the path to where the checkpoint should be exported.
+ """
+ checkpoint: str = None
+ destination: str = None
+
+
+@dataclasses.dataclass
+class MetricsConfig(hyperparams.Config):
+ """Configuration for Metrics.
+
+ Attributes:
+    accuracy: Whether or not to track accuracy as a metric. Defaults to None.
+    top_5: Whether or not to track top_5_accuracy as a metric. Defaults to
+      None.
+ """
+ accuracy: bool = None
+ top_5: bool = None
+
+
+@dataclasses.dataclass
+class TimeHistoryConfig(hyperparams.Config):
+ """Configuration for the TimeHistory callback.
+
+ Attributes:
+ log_steps: Interval of steps between logging of batch level stats.
+ """
+ log_steps: int = None
+
+
+@dataclasses.dataclass
+class TrainConfig(hyperparams.Config):
+ """Configuration for training.
+
+ Attributes:
+    resume_checkpoint: Whether or not to enable loading from a checkpoint.
+ Defaults to None.
+ epochs: The number of training epochs to run. Defaults to None.
+ steps: The number of steps to run per epoch. If None, then this will be
+ inferred based on the number of images and batch size. Defaults to None.
+ callbacks: An instance of CallbacksConfig.
+ metrics: An instance of MetricsConfig.
+ tensorboard: An instance of TensorboardConfig.
+ set_epoch_loop: Whether or not to set `experimental_steps_per_execution` to
+ equal the number of training steps in `model.compile`. This reduces the
+ number of callbacks run per epoch which significantly improves end-to-end
+ TPU training time.
+ """
+ resume_checkpoint: bool = None
+ epochs: int = None
+ steps: int = None
+ callbacks: CallbacksConfig = CallbacksConfig()
+ metrics: MetricsConfig = None
+ tensorboard: TensorboardConfig = TensorboardConfig()
+ time_history: TimeHistoryConfig = TimeHistoryConfig()
+ set_epoch_loop: bool = False
+
+
+@dataclasses.dataclass
+class EvalConfig(hyperparams.Config):
+ """Configuration for evaluation.
+
+ Attributes:
+ epochs_between_evals: The number of train epochs to run between evaluations.
+ Defaults to None.
+ steps: The number of eval steps to run during evaluation. If None, this will
+ be inferred based on the number of images and batch size. Defaults to
+ None.
+ skip_eval: Whether or not to skip evaluation.
+ """
+ epochs_between_evals: int = None
+ steps: int = None
+ skip_eval: bool = False
+
+
+@dataclasses.dataclass
+class LossConfig(hyperparams.Config):
+ """Configuration for Loss.
+
+ Attributes:
+ name: The name of the loss. Defaults to None.
+    label_smoothing: The amount of label smoothing to apply to the loss. This
+ only applies to 'categorical_cross_entropy'.
+ """
+ name: str = None
+ label_smoothing: float = None
+
+
+@dataclasses.dataclass
+class OptimizerConfig(hyperparams.Config):
+ """Configuration for Optimizers.
+
+ Attributes:
+ name: The name of the optimizer. Defaults to None.
+ decay: Decay or rho, discounting factor for gradient. Defaults to None.
+ epsilon: Small value used to avoid 0 denominator. Defaults to None.
+ momentum: Plain momentum constant. Defaults to None.
+ nesterov: Whether or not to apply Nesterov momentum. Defaults to None.
+ moving_average_decay: The amount of decay to apply. If 0 or None, then
+ exponential moving average is not used. Defaults to None.
+ lookahead: Whether or not to apply the lookahead optimizer. Defaults to
+ None.
+ beta_1: The exponential decay rate for the 1st moment estimates. Used in the
+ Adam optimizers. Defaults to None.
+ beta_2: The exponential decay rate for the 2nd moment estimates. Used in the
+ Adam optimizers. Defaults to None.
+ """
+ name: str = None
+ decay: float = None
+ epsilon: float = None
+ momentum: float = None
+ nesterov: bool = None
+ moving_average_decay: Optional[float] = None
+ lookahead: Optional[bool] = None
+ beta_1: float = None
+ beta_2: float = None
+
+
+@dataclasses.dataclass
+class LearningRateConfig(hyperparams.Config):
+ """Configuration for learning rates.
+
+ Attributes:
+ name: The name of the learning rate. Defaults to None.
+ initial_lr: The initial learning rate. Defaults to None.
+ decay_epochs: The number of decay epochs. Defaults to None.
+ decay_rate: The rate of decay. Defaults to None.
+ warmup_epochs: The number of warmup epochs. Defaults to None.
+ examples_per_epoch: the number of examples in a single epoch. Defaults to
+ None.
+ boundaries: boundaries used in piecewise constant decay with warmup.
+ multipliers: multipliers used in piecewise constant decay with warmup.
+ scale_by_batch_size: Scale the learning rate by a fraction of the batch
+ size. Set to 0 for no scaling (default).
+ staircase: Apply exponential decay at discrete values instead of continuous.
+ """
+ name: str = None
+ initial_lr: float = None
+ decay_epochs: float = None
+ decay_rate: float = None
+ warmup_epochs: int = None
+ examples_per_epoch: int = None
+ boundaries: List[int] = None
+ multipliers: List[float] = None
+ scale_by_batch_size: float = 0.
+ staircase: bool = None
+
+
+@dataclasses.dataclass
+class ModelConfig(hyperparams.Config):
+ """Configuration for Models.
+
+ Attributes:
+ name: The name of the model. Defaults to None.
+ model_params: The parameters used to create the model. Defaults to None.
+ num_classes: The number of classes in the model. Defaults to None.
+ loss: A `LossConfig` instance. Defaults to None.
+ optimizer: An `OptimizerConfig` instance. Defaults to None.
+ """
+ name: str = None
+ model_params: hyperparams.Config = None
+ num_classes: int = None
+ loss: LossConfig = None
+ optimizer: OptimizerConfig = None
+
+
+@dataclasses.dataclass
+class ExperimentConfig(hyperparams.Config):
+ """Base configuration for an image classification experiment.
+
+ Attributes:
+    model_dir: The directory to use when running an experiment.
+    model_name: The name of the model to run.
+    mode: e.g. 'train_and_eval', 'export'.
+    runtime: A `RuntimeConfig` instance.
+    train_dataset: The configuration of the training dataset.
+    validation_dataset: The configuration of the validation dataset.
+    train: A `TrainConfig` instance.
+    evaluation: An `EvalConfig` instance.
+    model: A `ModelConfig` instance.
+    export: An `ExportConfig` instance.
+ """
+ model_dir: str = None
+ model_name: str = None
+ mode: str = None
+ runtime: RuntimeConfig = None
+ train_dataset: Any = None
+ validation_dataset: Any = None
+ train: TrainConfig = None
+ evaluation: EvalConfig = None
+ model: ModelConfig = None
+ export: ExportConfig = None
diff --git a/models/official/vision/image_classification/configs/configs.py b/models/official/vision/image_classification/configs/configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a79a1cd9b563a554614b9d4f2f0b93acf016791
--- /dev/null
+++ b/models/official/vision/image_classification/configs/configs.py
@@ -0,0 +1,118 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Configuration utils for image classification experiments."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import dataclasses
+
+from official.vision.image_classification import dataset_factory
+from official.vision.image_classification.configs import base_configs
+from official.vision.image_classification.efficientnet import efficientnet_config
+from official.vision.image_classification.resnet import resnet_config
+
+
+@dataclasses.dataclass
+class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
+ """Base configuration to train efficientnet-b0 on ImageNet.
+
+ Attributes:
+    export: An `ExportConfig` instance.
+    runtime: A `RuntimeConfig` instance.
+    train_dataset: A `DatasetConfig` instance for the training split.
+    validation_dataset: A `DatasetConfig` instance for the validation split.
+    train: A `TrainConfig` instance.
+    evaluation: An `EvalConfig` instance.
+    model: A `ModelConfig` instance.
+  """
+ export: base_configs.ExportConfig = base_configs.ExportConfig()
+ runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
+ train_dataset: dataset_factory.DatasetConfig = \
+ dataset_factory.ImageNetConfig(split='train')
+ validation_dataset: dataset_factory.DatasetConfig = \
+ dataset_factory.ImageNetConfig(split='validation')
+ train: base_configs.TrainConfig = base_configs.TrainConfig(
+ resume_checkpoint=True,
+ epochs=500,
+ steps=None,
+ callbacks=base_configs.CallbacksConfig(enable_checkpoint_and_export=True,
+ enable_tensorboard=True),
+ metrics=['accuracy', 'top_5'],
+ time_history=base_configs.TimeHistoryConfig(log_steps=100),
+ tensorboard=base_configs.TensorboardConfig(track_lr=True,
+ write_model_weights=False),
+ set_epoch_loop=False)
+ evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
+ epochs_between_evals=1,
+ steps=None)
+ model: base_configs.ModelConfig = \
+ efficientnet_config.EfficientNetModelConfig()
+
+
+@dataclasses.dataclass
+class ResNetImagenetConfig(base_configs.ExperimentConfig):
+ """Base configuration to train resnet-50 on ImageNet."""
+ export: base_configs.ExportConfig = base_configs.ExportConfig()
+ runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
+ train_dataset: dataset_factory.DatasetConfig = \
+ dataset_factory.ImageNetConfig(split='train',
+ one_hot=False,
+ mean_subtract=True,
+ standardize=True)
+ validation_dataset: dataset_factory.DatasetConfig = \
+ dataset_factory.ImageNetConfig(split='validation',
+ one_hot=False,
+ mean_subtract=True,
+ standardize=True)
+ train: base_configs.TrainConfig = base_configs.TrainConfig(
+ resume_checkpoint=True,
+ epochs=90,
+ steps=None,
+ callbacks=base_configs.CallbacksConfig(enable_checkpoint_and_export=True,
+ enable_tensorboard=True),
+ metrics=['accuracy', 'top_5'],
+ time_history=base_configs.TimeHistoryConfig(log_steps=100),
+ tensorboard=base_configs.TensorboardConfig(track_lr=True,
+ write_model_weights=False),
+ set_epoch_loop=False)
+ evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
+ epochs_between_evals=1,
+ steps=None)
+ model: base_configs.ModelConfig = resnet_config.ResNetModelConfig()
+
+
+def get_config(model: str, dataset: str) -> base_configs.ExperimentConfig:
+ """Given model and dataset names, return the ExperimentConfig."""
+ dataset_model_config_map = {
+ 'imagenet': {
+ 'efficientnet': EfficientNetImageNetConfig(),
+ 'resnet': ResNetImagenetConfig(),
+ }
+ }
+ try:
+ return dataset_model_config_map[dataset][model]
+ except KeyError:
+ if dataset not in dataset_model_config_map:
+ raise KeyError('Invalid dataset received. Received: {}. Supported '
+ 'datasets include: {}'.format(
+ dataset,
+ ', '.join(dataset_model_config_map.keys())))
+ raise KeyError('Invalid model received. Received: {}. Supported models for'
+ '{} include: {}'.format(
+ model,
+ dataset,
+ ', '.join(dataset_model_config_map[dataset].keys())))
diff --git a/models/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml b/models/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6f40ffb1e3020a231832a120d9938bf77e9cc74b
--- /dev/null
+++ b/models/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml
@@ -0,0 +1,52 @@
+# Training configuration for EfficientNet-b0 trained on ImageNet on GPUs.
+# Takes ~32 minutes per epoch for 8 V100s.
+# Reaches ~76.1% within 350 epochs.
+# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
+runtime:
+ distribution_strategy: 'mirrored'
+ num_gpus: 1
+train_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'records'
+ split: 'train'
+ num_classes: 1000
+ num_examples: 1281167
+ batch_size: 32
+ use_per_replica_batch_size: True
+ dtype: 'float32'
+ augmenter:
+ name: 'autoaugment'
+validation_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'records'
+ split: 'validation'
+ num_classes: 1000
+ num_examples: 50000
+ batch_size: 32
+ use_per_replica_batch_size: True
+ dtype: 'float32'
+model:
+ model_params:
+ model_name: 'efficientnet-b0'
+ overrides:
+ num_classes: 1000
+ batch_norm: 'default'
+ dtype: 'float32'
+ activation: 'swish'
+ optimizer:
+ name: 'rmsprop'
+ momentum: 0.9
+ decay: 0.9
+ moving_average_decay: 0.0
+ lookahead: false
+ learning_rate:
+ name: 'exponential'
+ loss:
+ label_smoothing: 0.1
+train:
+ resume_checkpoint: True
+ epochs: 500
+evaluation:
+ epochs_between_evals: 1
diff --git a/models/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml b/models/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c5be7e9ba32fc7e8f3999df8e7446405dd2d4173
--- /dev/null
+++ b/models/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml
@@ -0,0 +1,52 @@
+# Training configuration for EfficientNet-b0 trained on ImageNet on TPUs.
+# Takes ~2 minutes, 50 seconds per epoch for v3-32.
+# Reaches ~76.1% within 350 epochs.
+# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
+runtime:
+ distribution_strategy: 'tpu'
+train_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'records'
+ split: 'train'
+ num_classes: 1000
+ num_examples: 1281167
+ batch_size: 128
+ use_per_replica_batch_size: True
+ dtype: 'bfloat16'
+ augmenter:
+ name: 'autoaugment'
+validation_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'records'
+ split: 'validation'
+ num_classes: 1000
+ num_examples: 50000
+ batch_size: 128
+ use_per_replica_batch_size: True
+ dtype: 'bfloat16'
+model:
+ model_params:
+ model_name: 'efficientnet-b0'
+ overrides:
+ num_classes: 1000
+ batch_norm: 'tpu'
+ dtype: 'bfloat16'
+ activation: 'swish'
+ optimizer:
+ name: 'rmsprop'
+ momentum: 0.9
+ decay: 0.9
+ moving_average_decay: 0.0
+ lookahead: false
+ learning_rate:
+ name: 'exponential'
+ loss:
+ label_smoothing: 0.1
+train:
+ resume_checkpoint: True
+ epochs: 500
+ set_epoch_loop: True
+evaluation:
+ epochs_between_evals: 1
diff --git a/models/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-gpu.yaml b/models/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2f3dce01a46c64c4d92e97091628daeadaceb21d
--- /dev/null
+++ b/models/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-gpu.yaml
@@ -0,0 +1,47 @@
+# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
+runtime:
+ distribution_strategy: 'mirrored'
+ num_gpus: 1
+train_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'records'
+ split: 'train'
+ num_classes: 1000
+ num_examples: 1281167
+ batch_size: 32
+ use_per_replica_batch_size: True
+ dtype: 'float32'
+validation_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'records'
+ split: 'validation'
+ num_classes: 1000
+ num_examples: 50000
+ batch_size: 32
+ use_per_replica_batch_size: True
+ dtype: 'float32'
+model:
+ model_params:
+ model_name: 'efficientnet-b1'
+ overrides:
+ num_classes: 1000
+ batch_norm: 'default'
+ dtype: 'float32'
+ activation: 'swish'
+ optimizer:
+ name: 'rmsprop'
+ momentum: 0.9
+ decay: 0.9
+ moving_average_decay: 0.0
+ lookahead: false
+ learning_rate:
+ name: 'exponential'
+ loss:
+ label_smoothing: 0.1
+train:
+ resume_checkpoint: True
+ epochs: 500
+evaluation:
+ epochs_between_evals: 1
diff --git a/models/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-tpu.yaml b/models/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-tpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0bb6a9fe6f0b417f92686178d4bc79a44c5a4aa7
--- /dev/null
+++ b/models/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-tpu.yaml
@@ -0,0 +1,51 @@
+# Training configuration for EfficientNet-b1 trained on ImageNet on TPUs.
+# Takes ~3 minutes, 15 seconds per epoch for v3-32.
+# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
+runtime:
+ distribution_strategy: 'tpu'
+train_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'records'
+ split: 'train'
+ num_classes: 1000
+ num_examples: 1281167
+ batch_size: 128
+ use_per_replica_batch_size: True
+ dtype: 'bfloat16'
+ augmenter:
+ name: 'autoaugment'
+validation_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'records'
+ split: 'validation'
+ num_classes: 1000
+ num_examples: 50000
+ batch_size: 128
+ use_per_replica_batch_size: True
+ dtype: 'bfloat16'
+model:
+ model_params:
+ model_name: 'efficientnet-b1'
+ overrides:
+ num_classes: 1000
+ batch_norm: 'tpu'
+ dtype: 'bfloat16'
+ activation: 'swish'
+ optimizer:
+ name: 'rmsprop'
+ momentum: 0.9
+ decay: 0.9
+ moving_average_decay: 0.0
+ lookahead: false
+ learning_rate:
+ name: 'exponential'
+ loss:
+ label_smoothing: 0.1
+train:
+ resume_checkpoint: True
+ epochs: 500
+ set_epoch_loop: True
+evaluation:
+ epochs_between_evals: 1
diff --git a/models/official/vision/image_classification/configs/examples/resnet/imagenet/gpu.yaml b/models/official/vision/image_classification/configs/examples/resnet/imagenet/gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..56844b81db70fbd5e8291a4c1c2eb60e3c488088
--- /dev/null
+++ b/models/official/vision/image_classification/configs/examples/resnet/imagenet/gpu.yaml
@@ -0,0 +1,51 @@
+# Training configuration for ResNet trained on ImageNet on GPUs.
+# Reaches > 76.1% within 90 epochs.
+# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
+runtime:
+ distribution_strategy: 'mirrored'
+ num_gpus: 1
+ batchnorm_spatial_persistent: True
+train_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'tfds'
+ split: 'train'
+ image_size: 224
+ num_classes: 1000
+ num_examples: 1281167
+ batch_size: 256
+ use_per_replica_batch_size: True
+ dtype: 'float16'
+ mean_subtract: True
+ standardize: True
+validation_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'tfds'
+ split: 'validation'
+ image_size: 224
+ num_classes: 1000
+ num_examples: 50000
+ batch_size: 256
+ use_per_replica_batch_size: True
+ dtype: 'float16'
+ mean_subtract: True
+ standardize: True
+model:
+ name: 'resnet'
+ model_params:
+ rescale_inputs: False
+ optimizer:
+ name: 'momentum'
+ momentum: 0.9
+ decay: 0.9
+ epsilon: 0.001
+ learning_rate:
+ name: 'piecewise_constant_with_warmup'
+ loss:
+ label_smoothing: 0.1
+train:
+ resume_checkpoint: True
+ epochs: 90
+evaluation:
+ epochs_between_evals: 1
diff --git a/models/official/vision/image_classification/configs/examples/resnet/imagenet/tpu.yaml b/models/official/vision/image_classification/configs/examples/resnet/imagenet/tpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ae975c16251ac0a23877bf8f6804cdea6b2baadf
--- /dev/null
+++ b/models/official/vision/image_classification/configs/examples/resnet/imagenet/tpu.yaml
@@ -0,0 +1,57 @@
+# Training configuration for ResNet trained on ImageNet on TPUs.
+# Takes ~4 minutes, 30 seconds per epoch for a v3-32.
+# Reaches > 76.1% within 90 epochs.
+# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
+runtime:
+ distribution_strategy: 'tpu'
+train_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'tfds'
+ split: 'train'
+ one_hot: False
+ image_size: 224
+ num_classes: 1000
+ num_examples: 1281167
+ batch_size: 128
+ use_per_replica_batch_size: True
+ mean_subtract: False
+ standardize: False
+ dtype: 'bfloat16'
+validation_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'tfds'
+ split: 'validation'
+ one_hot: False
+ image_size: 224
+ num_classes: 1000
+ num_examples: 50000
+ batch_size: 128
+ use_per_replica_batch_size: True
+ mean_subtract: False
+ standardize: False
+ dtype: 'bfloat16'
+model:
+ name: 'resnet'
+ model_params:
+ rescale_inputs: True
+ optimizer:
+ name: 'momentum'
+ momentum: 0.9
+ decay: 0.9
+ epsilon: 0.001
+ moving_average_decay: 0.
+ lookahead: False
+ learning_rate:
+ name: 'piecewise_constant_with_warmup'
+ loss:
+ label_smoothing: 0.1
+train:
+ callbacks:
+ enable_checkpoint_and_export: True
+ resume_checkpoint: True
+ epochs: 90
+ set_epoch_loop: True
+evaluation:
+ epochs_between_evals: 1
diff --git a/models/official/vision/image_classification/dataset_factory.py b/models/official/vision/image_classification/dataset_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9dad1268a7bed86f622f80ca28f4d485a0fab31
--- /dev/null
+++ b/models/official/vision/image_classification/dataset_factory.py
@@ -0,0 +1,536 @@
+# Lint as: python3
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Dataset utilities for vision tasks using TFDS and tf.data.Dataset."""
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import os
+from typing import Any, List, Optional, Tuple, Mapping, Union
+from absl import logging
+from dataclasses import dataclass
+import tensorflow as tf
+import tensorflow_datasets as tfds
+
+from official.modeling.hyperparams import base_config
+from official.vision.image_classification import augment
+from official.vision.image_classification import preprocessing
+
+
+AUGMENTERS = {
+ 'autoaugment': augment.AutoAugment,
+ 'randaugment': augment.RandAugment,
+}
+
+
+@dataclass
+class AugmentConfig(base_config.Config):
+ """Configuration for image augmenters.
+
+ Attributes:
+ name: The name of the image augmentation to use. Possible options are
+ None (default), 'autoaugment', or 'randaugment'.
+    params: Any parameters used to initialize the augmenter.
+ """
+ name: Optional[str] = None
+ params: Optional[Mapping[str, Any]] = None
+
+ def build(self) -> augment.ImageAugment:
+ """Build the augmenter using this config."""
+ params = self.params or {}
+ augmenter = AUGMENTERS.get(self.name, None)
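+    # An unset or unrecognized augmenter name yields None (no augmentation)
+    # rather than raising an error.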
+ return augmenter(**params) if augmenter is not None else None
+
+
+@dataclass
+class DatasetConfig(base_config.Config):
+ """The base configuration for building datasets.
+
+ Attributes:
+ name: The name of the Dataset. Usually should correspond to a TFDS dataset.
+ data_dir: The path where the dataset files are stored, if available.
+ filenames: Optional list of strings representing the TFRecord names.
+ builder: The builder type used to load the dataset. Value should be one of
+ 'tfds' (load using TFDS), 'records' (load from TFRecords), or 'synthetic'
+ (generate dummy synthetic data without reading from files).
+ split: The split of the dataset. Usually 'train', 'validation', or 'test'.
+ image_size: The size of the image in the dataset. This assumes that
+ `width` == `height`. Set to 'infer' to infer the image size from TFDS
+ info. This requires `name` to be a registered dataset in TFDS.
+    num_classes: The number of classes given by the dataset. Set to 'infer'
+      to infer the number of classes from TFDS info. This requires `name` to
+      be a registered dataset in TFDS.
+    num_channels: The number of channels given by the dataset. Set to 'infer'
+      to infer the number of channels from TFDS info. This requires `name` to
+      be a registered dataset in TFDS.
+    num_examples: The number of examples given by the dataset. Set to 'infer'
+      to infer the number of examples from TFDS info. This requires `name` to
+      be a registered dataset in TFDS.
+ batch_size: The base batch size for the dataset.
+ use_per_replica_batch_size: Whether to scale the batch size based on
+ available resources. If set to `True`, the dataset builder will return
+ batch_size multiplied by `num_devices`, the number of device replicas
+ (e.g., the number of GPUs or TPU cores). This setting should be `True` if
+ the strategy argument is passed to `build()` and `num_devices > 1`.
+ num_devices: The number of replica devices to use. This should be set by
+ `strategy.num_replicas_in_sync` when using a distribution strategy.
+ dtype: The desired dtype of the dataset. This will be set during
+ preprocessing.
+ one_hot: Whether to apply one hot encoding. Set to `True` to be able to use
+ label smoothing.
+ augmenter: The augmenter config to use. No augmentation is used by default.
+ download: Whether to download data using TFDS.
+ shuffle_buffer_size: The buffer size used for shuffling training data.
+ file_shuffle_buffer_size: The buffer size used for shuffling raw training
+ files.
+ skip_decoding: Whether to skip image decoding when loading from TFDS.
+    cache: whether to cache dataset examples. Can be used to avoid re-reading
+ from disk on the second epoch. Requires significant memory overhead.
+ tf_data_service: The URI of a tf.data service to offload preprocessing onto
+ during training. The URI should be in the format "protocol://address",
+ e.g. "grpc://tf-data-service:5050".
+ mean_subtract: whether or not to apply mean subtraction to the dataset.
+ standardize: whether or not to apply standardization to the dataset.
+ """
+ name: Optional[str] = None
+ data_dir: Optional[str] = None
+ filenames: Optional[List[str]] = None
+ builder: str = 'tfds'
+ split: str = 'train'
+ image_size: Union[int, str] = 'infer'
+ num_classes: Union[int, str] = 'infer'
+ num_channels: Union[int, str] = 'infer'
+ num_examples: Union[int, str] = 'infer'
+ batch_size: int = 128
+ use_per_replica_batch_size: bool = True
+ num_devices: int = 1
+ dtype: str = 'float32'
+ one_hot: bool = True
+ augmenter: AugmentConfig = AugmentConfig()
+ download: bool = False
+ shuffle_buffer_size: int = 10000
+ file_shuffle_buffer_size: int = 1024
+ skip_decoding: bool = True
+ cache: bool = False
+ tf_data_service: Optional[str] = None
+ mean_subtract: bool = False
+ standardize: bool = False
+
+ @property
+ def has_data(self):
+ """Whether this dataset is has any data associated with it."""
+ return self.name or self.data_dir or self.filenames
+
+
+@dataclass
+class ImageNetConfig(DatasetConfig):
+ """The base ImageNet dataset config."""
+ name: str = 'imagenet2012'
+ # Note: for large datasets like ImageNet, using records is faster than tfds
+ builder: str = 'records'
+ image_size: int = 224
+ batch_size: int = 128
+
+
+@dataclass
+class Cifar10Config(DatasetConfig):
+ """The base CIFAR-10 dataset config."""
+ name: str = 'cifar10'
+ image_size: int = 224
+ batch_size: int = 128
+ download: bool = True
+ cache: bool = True
+
+
+class DatasetBuilder:
+ """An object for building datasets.
+
+ Allows building various pipelines fetching examples, preprocessing, etc.
+ Maintains additional state information calculated from the dataset, i.e.,
+ training set split, batch size, and number of steps (batches).
+ """
+
+ def __init__(self, config: DatasetConfig, **overrides: Any):
+ """Initialize the builder from the config."""
+ self.config = config.replace(**overrides)
+ self.builder_info = None
+
+ if self.config.augmenter is not None:
+ logging.info('Using augmentation: %s', self.config.augmenter.name)
+ self.augmenter = self.config.augmenter.build()
+ else:
+ self.augmenter = None
+
+ @property
+ def is_training(self) -> bool:
+ """Whether this is the training set."""
+ return self.config.split == 'train'
+
+ @property
+ def batch_size(self) -> int:
+ """The batch size, multiplied by the number of replicas (if configured)."""
+ if self.config.use_per_replica_batch_size:
+ return self.config.batch_size * self.config.num_devices
+ else:
+ return self.config.batch_size
+
+ @property
+ def global_batch_size(self):
+ """The global batch size across all replicas."""
+ return self.batch_size
+
+ @property
+ def local_batch_size(self):
+ """The base unscaled batch size."""
+ if self.config.use_per_replica_batch_size:
+ return self.config.batch_size
+ else:
+ return self.config.batch_size // self.config.num_devices
+
+ @property
+ def num_steps(self) -> int:
+ """The number of steps (batches) to exhaust this dataset."""
+ # Always divide by the global batch size to get the correct # of steps
+ return self.num_examples // self.global_batch_size
+
+ @property
+ def dtype(self) -> tf.dtypes.DType:
+ """Converts the config's dtype string to a tf dtype.
+
+ Returns:
+      The `tf.dtypes.DType` corresponding to the configured dtype string.
+
+ Raises:
+ ValueError if the config's dtype is not supported.
+
+ """
+ dtype_map = {
+ 'float32': tf.float32,
+ 'bfloat16': tf.bfloat16,
+ 'float16': tf.float16,
+ 'fp32': tf.float32,
+ 'bf16': tf.bfloat16,
+ }
+ try:
+ return dtype_map[self.config.dtype]
+    except KeyError:
+ raise ValueError('Invalid DType provided. Supported types: {}'.format(
+ dtype_map.keys()))
+
+ @property
+ def image_size(self) -> int:
+ """The size of each image (can be inferred from the dataset)."""
+
+ if self.config.image_size == 'infer':
+ return self.info.features['image'].shape[0]
+ else:
+ return int(self.config.image_size)
+
+ @property
+ def num_channels(self) -> int:
+ """The number of image channels (can be inferred from the dataset)."""
+ if self.config.num_channels == 'infer':
+ return self.info.features['image'].shape[-1]
+ else:
+ return int(self.config.num_channels)
+
+ @property
+ def num_examples(self) -> int:
+ """The number of examples (can be inferred from the dataset)."""
+ if self.config.num_examples == 'infer':
+ return self.info.splits[self.config.split].num_examples
+ else:
+ return int(self.config.num_examples)
+
+ @property
+ def num_classes(self) -> int:
+ """The number of classes (can be inferred from the dataset)."""
+ if self.config.num_classes == 'infer':
+ return self.info.features['label'].num_classes
+ else:
+ return int(self.config.num_classes)
+
+ @property
+ def info(self) -> tfds.core.DatasetInfo:
+ """The TFDS dataset info, if available."""
+ if self.builder_info is None:
+ self.builder_info = tfds.builder(self.config.name).info
+ return self.builder_info
+
+ def build(self, strategy: tf.distribute.Strategy = None) -> tf.data.Dataset:
+ """Construct a dataset end-to-end and return it using an optional strategy.
+
+ Args:
+ strategy: a strategy that, if passed, will distribute the dataset
+ according to that strategy. If passed and `num_devices > 1`,
+ `use_per_replica_batch_size` must be set to `True`.
+
+ Returns:
+ A TensorFlow dataset outputting batched images and labels.
+ """
+ if strategy:
+ if strategy.num_replicas_in_sync != self.config.num_devices:
+        logging.warning('Passed a strategy with %d devices, but expected '
+                        '%d devices.',
+                        strategy.num_replicas_in_sync,
+                        self.config.num_devices)
+ dataset = strategy.experimental_distribute_datasets_from_function(
+ self._build)
+ else:
+ dataset = self._build()
+
+ return dataset
+
+ def _build(self, input_context: tf.distribute.InputContext = None
+ ) -> tf.data.Dataset:
+ """Construct a dataset end-to-end and return it.
+
+ Args:
+ input_context: An optional context provided by `tf.distribute` for
+ cross-replica training.
+
+ Returns:
+ A TensorFlow dataset outputting batched images and labels.
+ """
+ builders = {
+ 'tfds': self.load_tfds,
+ 'records': self.load_records,
+ 'synthetic': self.load_synthetic,
+ }
+
+ builder = builders.get(self.config.builder, None)
+
+ if builder is None:
+ raise ValueError('Unknown builder type {}'.format(self.config.builder))
+
+ self.input_context = input_context
+ dataset = builder()
+ dataset = self.pipeline(dataset)
+
+ return dataset
+
+ def load_tfds(self) -> tf.data.Dataset:
+ """Return a dataset loading files from TFDS."""
+
+ logging.info('Using TFDS to load data.')
+
+ builder = tfds.builder(self.config.name,
+ data_dir=self.config.data_dir)
+
+ if self.config.download:
+ builder.download_and_prepare()
+
+ decoders = {}
+
+ if self.config.skip_decoding:
+ decoders['image'] = tfds.decode.SkipDecoding()
+
+ read_config = tfds.ReadConfig(
+ interleave_cycle_length=10,
+ interleave_block_length=1,
+ input_context=self.input_context)
+
+ dataset = builder.as_dataset(
+ split=self.config.split,
+ as_supervised=True,
+ shuffle_files=True,
+ decoders=decoders,
+ read_config=read_config)
+
+ return dataset
+
+ def load_records(self) -> tf.data.Dataset:
+ """Return a dataset loading files with TFRecords."""
+ logging.info('Using TFRecords to load data.')
+ if self.config.filenames is None:
+ if self.config.data_dir is None:
+ raise ValueError('Dataset must specify a path for the data files.')
+
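+      # e.g. split='train' yields the glob 'train*', matching sharded TFRecord
+      # files such as 'train-00000-of-01024'.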
+ file_pattern = os.path.join(self.config.data_dir,
+ '{}*'.format(self.config.split))
+ dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)
+ else:
+ dataset = tf.data.Dataset.from_tensor_slices(self.config.filenames)
+
+ return dataset
+
+ def load_synthetic(self) -> tf.data.Dataset:
+ """Return a dataset generating dummy synthetic data."""
+ logging.info('Generating a synthetic dataset.')
+
+ def generate_data(_):
+ image = tf.zeros([self.image_size, self.image_size, self.num_channels],
+ dtype=self.dtype)
+ label = tf.zeros([1], dtype=tf.int32)
+ return image, label
+
+ dataset = tf.data.Dataset.range(1)
+ dataset = dataset.repeat()
+ dataset = dataset.map(generate_data,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+ def pipeline(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
+ """Build a pipeline fetching, shuffling, and preprocessing the dataset.
+
+ Args:
+ dataset: A `tf.data.Dataset` that loads raw files.
+
+ Returns:
+ A TensorFlow dataset outputting batched images and labels.
+ """
+ if (self.config.builder != 'tfds' and self.input_context
+ and self.input_context.num_input_pipelines > 1):
+ dataset = dataset.shard(self.input_context.num_input_pipelines,
+ self.input_context.input_pipeline_id)
+ logging.info('Sharding the dataset: input_pipeline_id=%d '
+ 'num_input_pipelines=%d',
+ self.input_context.num_input_pipelines,
+ self.input_context.input_pipeline_id)
+
+ if self.is_training and self.config.builder == 'records':
+ # Shuffle the input files.
+      dataset = dataset.shuffle(
+          buffer_size=self.config.file_shuffle_buffer_size)
+
+ if self.is_training and not self.config.cache:
+ dataset = dataset.repeat()
+
+ if self.config.builder == 'records':
+ # Read the data from disk in parallel
+ dataset = dataset.interleave(
+ tf.data.TFRecordDataset,
+ cycle_length=10,
+ block_length=1,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ if self.config.cache:
+ dataset = dataset.cache()
+
+ if self.is_training:
+ dataset = dataset.shuffle(self.config.shuffle_buffer_size)
+ dataset = dataset.repeat()
+
+ # Parse, pre-process, and batch the data in parallel
+ if self.config.builder == 'records':
+ preprocess = self.parse_record
+ else:
+ preprocess = self.preprocess
+ dataset = dataset.map(preprocess,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ if self.input_context and self.config.num_devices > 1:
+ if not self.config.use_per_replica_batch_size:
+ raise ValueError(
+ 'The builder does not support a global batch size with more than '
+ 'one replica. Got {} replicas. Please set a '
+ '`per_replica_batch_size` and enable '
+ '`use_per_replica_batch_size=True`.'.format(
+ self.config.num_devices))
+
+ # The batch size of the dataset will be multiplied by the number of
+ # replicas automatically when strategy.distribute_datasets_from_function
+ # is called, so we use local batch size here.
+ dataset = dataset.batch(self.local_batch_size,
+ drop_remainder=self.is_training)
+ else:
+ dataset = dataset.batch(self.global_batch_size,
+ drop_remainder=self.is_training)
+
+ # Prefetch overlaps in-feed with training
+ dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+
+ if self.config.tf_data_service:
+ if not hasattr(tf.data.experimental, 'service'):
+ raise ValueError('The tf_data_service flag requires Tensorflow version '
+ '>= 2.3.0, but the version is {}'.format(
+ tf.__version__))
+ dataset = dataset.apply(
+ tf.data.experimental.service.distribute(
+ processing_mode='parallel_epochs',
+ service=self.config.tf_data_service,
+ job_name='resnet_train'))
+ dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
+
+ return dataset
+
+ def parse_record(self, record: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+ """Parse an ImageNet record from a serialized string Tensor."""
+ keys_to_features = {
+ 'image/encoded':
+ tf.io.FixedLenFeature((), tf.string, ''),
+ 'image/format':
+ tf.io.FixedLenFeature((), tf.string, 'jpeg'),
+ 'image/class/label':
+ tf.io.FixedLenFeature([], tf.int64, -1),
+ 'image/class/text':
+ tf.io.FixedLenFeature([], tf.string, ''),
+ 'image/object/bbox/xmin':
+ tf.io.VarLenFeature(dtype=tf.float32),
+ 'image/object/bbox/ymin':
+ tf.io.VarLenFeature(dtype=tf.float32),
+ 'image/object/bbox/xmax':
+ tf.io.VarLenFeature(dtype=tf.float32),
+ 'image/object/bbox/ymax':
+ tf.io.VarLenFeature(dtype=tf.float32),
+ 'image/object/class/label':
+ tf.io.VarLenFeature(dtype=tf.int64),
+ }
+
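+    # Only the encoded image and class label are used below; the bounding-box
+    # features are declared for schema compatibility but otherwise ignored.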
+ parsed = tf.io.parse_single_example(record, keys_to_features)
+
+ label = tf.reshape(parsed['image/class/label'], shape=[1])
+
+ # Subtract one so that labels are in [0, 1000)
+ label -= 1
+
+ image_bytes = tf.reshape(parsed['image/encoded'], shape=[])
+ image, label = self.preprocess(image_bytes, label)
+
+ return image, label
+
+ def preprocess(self, image: tf.Tensor, label: tf.Tensor
+ ) -> Tuple[tf.Tensor, tf.Tensor]:
+ """Apply image preprocessing and augmentation to the image and label."""
+ if self.is_training:
+ image = preprocessing.preprocess_for_train(
+ image,
+ image_size=self.image_size,
+ mean_subtract=self.config.mean_subtract,
+ standardize=self.config.standardize,
+ dtype=self.dtype,
+ augmenter=self.augmenter)
+ else:
+ image = preprocessing.preprocess_for_eval(
+ image,
+ image_size=self.image_size,
+ num_channels=self.num_channels,
+ mean_subtract=self.config.mean_subtract,
+ standardize=self.config.standardize,
+ dtype=self.dtype)
+
+ label = tf.cast(label, tf.int32)
+ if self.config.one_hot:
+ label = tf.one_hot(label, self.num_classes)
+ label = tf.reshape(label, [self.num_classes])
+
+ return image, label
+
+ @classmethod
+ def from_params(cls, *args, **kwargs):
+ """Construct a dataset builder from a default config and any overrides."""
+ config = DatasetConfig.from_args(*args, **kwargs)
+ return cls(config)
diff --git a/models/official/vision/image_classification/efficientnet/__init__.py b/models/official/vision/image_classification/efficientnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/vision/image_classification/efficientnet/common_modules.py b/models/official/vision/image_classification/efficientnet/common_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c9c2097d2398ec78cae5e1265478f804860f944
--- /dev/null
+++ b/models/official/vision/image_classification/efficientnet/common_modules.py
@@ -0,0 +1,117 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Common modeling utilities."""
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+import tensorflow.compat.v1 as tf1
+from typing import Text, Optional
+
+from tensorflow.python.tpu import tpu_function
+
+
+@tf.keras.utils.register_keras_serializable(package='Vision')
+class TpuBatchNormalization(tf.keras.layers.BatchNormalization):
+ """Cross replica batch normalization."""
+
+ def __init__(self, fused: Optional[bool] = False, **kwargs):
+ if fused in (True, None):
+ raise ValueError('TpuBatchNormalization does not support fused=True.')
+ super(TpuBatchNormalization, self).__init__(fused=fused, **kwargs)
+
+ def _cross_replica_average(self, t: tf.Tensor, num_shards_per_group: int):
+ """Calculates the average value of input tensor across TPU replicas."""
+ num_shards = tpu_function.get_tpu_context().number_of_shards
+ group_assignment = None
+ if num_shards_per_group > 1:
+ if num_shards % num_shards_per_group != 0:
+ raise ValueError(
+ 'num_shards: %d mod shards_per_group: %d, should be 0' %
+ (num_shards, num_shards_per_group))
+ num_groups = num_shards // num_shards_per_group
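+      # e.g. 16 shards with 8 shards per group -> [[0, ..., 7], [8, ..., 15]]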
+ group_assignment = [[
+ x for x in range(num_shards) if x // num_shards_per_group == y
+ ] for y in range(num_groups)]
+ return tf1.tpu.cross_replica_sum(t, group_assignment) / tf.cast(
+ num_shards_per_group, t.dtype)
+
+  def _moments(self, inputs: tf.Tensor, reduction_axes, keep_dims: bool):
+ """Compute the mean and variance: it overrides the original _moments."""
+ shard_mean, shard_variance = super(TpuBatchNormalization, self)._moments(
+ inputs, reduction_axes, keep_dims=keep_dims)
+
+ num_shards = tpu_function.get_tpu_context().number_of_shards or 1
+ if num_shards <= 8: # Skip cross_replica for 2x2 or smaller slices.
+ num_shards_per_group = 1
+ else:
+ num_shards_per_group = max(8, num_shards // 8)
+ if num_shards_per_group > 1:
+ # Compute variance using: Var[X]= E[X^2] - E[X]^2.
+ shard_square_of_mean = tf.math.square(shard_mean)
+ shard_mean_of_square = shard_variance + shard_square_of_mean
+ group_mean = self._cross_replica_average(shard_mean, num_shards_per_group)
+ group_mean_of_square = self._cross_replica_average(
+ shard_mean_of_square, num_shards_per_group)
+ group_variance = group_mean_of_square - tf.math.square(group_mean)
+ return (group_mean, group_variance)
+ else:
+ return (shard_mean, shard_variance)
+
+
+def get_batch_norm(batch_norm_type: Text) -> tf.keras.layers.BatchNormalization:
+ """A helper to create a batch normalization getter.
+
+ Args:
+ batch_norm_type: The type of batch normalization layer implementation. `tpu`
+ will use `TpuBatchNormalization`.
+
+ Returns:
+ An instance of `tf.keras.layers.BatchNormalization`.
+ """
+ if batch_norm_type == 'tpu':
+ return TpuBatchNormalization
+
+ return tf.keras.layers.BatchNormalization
+
+
+def count_params(model, trainable_only=True):
+ """Returns the count of all model parameters, or just trainable ones."""
+ if not trainable_only:
+ return model.count_params()
+ else:
+ return int(np.sum([tf.keras.backend.count_params(p)
+ for p in model.trainable_weights]))
+
+
+def load_weights(model: tf.keras.Model,
+ model_weights_path: Text,
+ weights_format: Text = 'saved_model'):
+ """Load model weights from the given file path.
+
+ Args:
+ model: the model to load weights into
+ model_weights_path: the path of the model weights
+ weights_format: the model weights format. One of 'saved_model', 'h5',
+ or 'checkpoint'.
+ """
+ if weights_format == 'saved_model':
+ loaded_model = tf.keras.models.load_model(model_weights_path)
+ model.set_weights(loaded_model.get_weights())
+ else:
+ model.load_weights(model_weights_path)
diff --git a/models/official/vision/image_classification/efficientnet/efficientnet_config.py b/models/official/vision/image_classification/efficientnet/efficientnet_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a758cc63c944463ebf184eaeae26cebd5935031a
--- /dev/null
+++ b/models/official/vision/image_classification/efficientnet/efficientnet_config.py
@@ -0,0 +1,78 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Configuration definitions for EfficientNet losses, learning rates, and optimizers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from typing import Any, Mapping
+
+import dataclasses
+
+from official.modeling.hyperparams import base_config
+from official.vision.image_classification.configs import base_configs
+
+
+@dataclasses.dataclass
+class EfficientNetModelConfig(base_configs.ModelConfig):
+ """Configuration for the EfficientNet model.
+
+ This configuration will default to settings used for training efficientnet-b0
+ on a v3-8 TPU on ImageNet.
+
+ Attributes:
+ name: The name of the model. Defaults to 'EfficientNet'.
+ num_classes: The number of classes in the model.
+ model_params: A dictionary that represents the parameters of the
+ EfficientNet model. These will be passed in to the "from_name" function.
+ loss: The configuration for loss. Defaults to a categorical cross entropy
+ implementation.
+ optimizer: The configuration for optimizations. Defaults to an RMSProp
+ configuration.
+ learning_rate: The configuration for learning rate. Defaults to an
+ exponential configuration.
+ """
+ name: str = 'EfficientNet'
+ num_classes: int = 1000
+ model_params: base_config.Config = dataclasses.field(
+ default_factory=lambda: {
+ 'model_name': 'efficientnet-b0',
+ 'model_weights_path': '',
+ 'weights_format': 'saved_model',
+ 'overrides': {
+ 'batch_norm': 'default',
+ 'rescale_input': True,
+ 'num_classes': 1000,
+ 'activation': 'swish',
+ 'dtype': 'float32',
+ }
+ })
+ loss: base_configs.LossConfig = base_configs.LossConfig(
+ name='categorical_crossentropy', label_smoothing=0.1)
+ optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
+ name='rmsprop',
+ decay=0.9,
+ epsilon=0.001,
+ momentum=0.9,
+ moving_average_decay=None)
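+  # With scale_by_batch_size=1/128, the initial learning rate is intended to be
+  # scaled by global_batch_size / 128 downstream (see
+  # LearningRateConfig.scale_by_batch_size).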
+ learning_rate: base_configs.LearningRateConfig = base_configs.LearningRateConfig( # pylint: disable=line-too-long
+ name='exponential',
+ initial_lr=0.008,
+ decay_epochs=2.4,
+ decay_rate=0.97,
+ warmup_epochs=5,
+ scale_by_batch_size=1. / 128.,
+ staircase=True)
diff --git a/models/official/vision/image_classification/efficientnet/efficientnet_model.py b/models/official/vision/image_classification/efficientnet/efficientnet_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab81fc25d1200557c99f77424d34c74cf8774d84
--- /dev/null
+++ b/models/official/vision/image_classification/efficientnet/efficientnet_model.py
@@ -0,0 +1,505 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains definitions for EfficientNet model.
+
+[1] Mingxing Tan, Quoc V. Le
+ EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks.
+ ICML'19, https://arxiv.org/abs/1905.11946
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import os
+from typing import Any, Dict, Optional, Text, Tuple
+
+from absl import logging
+from dataclasses import dataclass
+import tensorflow as tf
+
+from official.modeling import tf_utils
+from official.modeling.hyperparams import base_config
+from official.vision.image_classification import preprocessing
+from official.vision.image_classification.efficientnet import common_modules
+
+
+@dataclass
+class BlockConfig(base_config.Config):
+ """Config for a single MB Conv Block."""
+ input_filters: int = 0
+ output_filters: int = 0
+ kernel_size: int = 3
+ num_repeat: int = 1
+ expand_ratio: int = 1
+ strides: Tuple[int, int] = (1, 1)
+ se_ratio: Optional[float] = None
+ id_skip: bool = True
+ fused_conv: bool = False
+ conv_type: str = 'depthwise'
+
+
+@dataclass
+class ModelConfig(base_config.Config):
+ """Default Config for Efficientnet-B0."""
+ width_coefficient: float = 1.0
+ depth_coefficient: float = 1.0
+ resolution: int = 224
+ dropout_rate: float = 0.2
+ blocks: Tuple[BlockConfig, ...] = (
+ # (input_filters, output_filters, kernel_size, num_repeat,
+ # expand_ratio, strides, se_ratio)
+ # pylint: disable=bad-whitespace
+ BlockConfig.from_args(32, 16, 3, 1, 1, (1, 1), 0.25),
+ BlockConfig.from_args(16, 24, 3, 2, 6, (2, 2), 0.25),
+ BlockConfig.from_args(24, 40, 5, 2, 6, (2, 2), 0.25),
+ BlockConfig.from_args(40, 80, 3, 3, 6, (2, 2), 0.25),
+ BlockConfig.from_args(80, 112, 5, 3, 6, (1, 1), 0.25),
+ BlockConfig.from_args(112, 192, 5, 4, 6, (2, 2), 0.25),
+ BlockConfig.from_args(192, 320, 3, 1, 6, (1, 1), 0.25),
+ # pylint: enable=bad-whitespace
+ )
+ stem_base_filters: int = 32
+ top_base_filters: int = 1280
+ activation: str = 'simple_swish'
+ batch_norm: str = 'default'
+ bn_momentum: float = 0.99
+ bn_epsilon: float = 1e-3
+ # While the original implementation used a weight decay of 1e-5,
+ # tf.nn.l2_loss divides it by 2, so we halve this to compensate in Keras
+ weight_decay: float = 5e-6
+ drop_connect_rate: float = 0.2
+ depth_divisor: int = 8
+ min_depth: Optional[int] = None
+ use_se: bool = True
+ input_channels: int = 3
+ num_classes: int = 1000
+ model_name: str = 'efficientnet'
+ rescale_input: bool = True
+ data_format: str = 'channels_last'
+ dtype: str = 'float32'
+
+
+MODEL_CONFIGS = {
+ # (width, depth, resolution, dropout)
+ 'efficientnet-b0': ModelConfig.from_args(1.0, 1.0, 224, 0.2),
+ 'efficientnet-b1': ModelConfig.from_args(1.0, 1.1, 240, 0.2),
+ 'efficientnet-b2': ModelConfig.from_args(1.1, 1.2, 260, 0.3),
+ 'efficientnet-b3': ModelConfig.from_args(1.2, 1.4, 300, 0.3),
+ 'efficientnet-b4': ModelConfig.from_args(1.4, 1.8, 380, 0.4),
+ 'efficientnet-b5': ModelConfig.from_args(1.6, 2.2, 456, 0.4),
+ 'efficientnet-b6': ModelConfig.from_args(1.8, 2.6, 528, 0.5),
+ 'efficientnet-b7': ModelConfig.from_args(2.0, 3.1, 600, 0.5),
+ 'efficientnet-b8': ModelConfig.from_args(2.2, 3.6, 672, 0.5),
+ 'efficientnet-l2': ModelConfig.from_args(4.3, 5.3, 800, 0.5),
+}
+
+CONV_KERNEL_INITIALIZER = {
+ 'class_name': 'VarianceScaling',
+ 'config': {
+ 'scale': 2.0,
+ 'mode': 'fan_out',
+ # Note: this is a truncated normal distribution
+ 'distribution': 'normal'
+ }
+}
+
+DENSE_KERNEL_INITIALIZER = {
+ 'class_name': 'VarianceScaling',
+ 'config': {
+ 'scale': 1 / 3.0,
+ 'mode': 'fan_out',
+ 'distribution': 'uniform'
+ }
+}
+
+
+def round_filters(filters: int,
+ config: ModelConfig) -> int:
+ """Round number of filters based on width coefficient."""
+ width_coefficient = config.width_coefficient
+ min_depth = config.min_depth
+ divisor = config.depth_divisor
+ orig_filters = filters
+
+ if not width_coefficient:
+ return filters
+
+ filters *= width_coefficient
+ min_depth = min_depth or divisor
+ new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_filters < 0.9 * filters:
+ new_filters += divisor
+ logging.info('round_filter input=%s output=%s', orig_filters, new_filters)
+ return int(new_filters)
+
+
+def round_repeats(repeats: int, depth_coefficient: float) -> int:
+ """Round number of repeats based on depth coefficient."""
+ return int(math.ceil(depth_coefficient * repeats))
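+
+
+# Illustrative note (not part of the upstream model code): a quick worked
+# example of the scaling arithmetic above, assuming the efficientnet-b2
+# coefficients (width 1.1, depth 1.2) and the default depth_divisor of 8.
+#   round_filters(32, b2):  32 * 1.1 = 35.2 -> int(35.2 + 4) // 8 * 8 = 32,
+#   and 32 >= 0.9 * 35.2, so 32 filters are kept.
+#   round_repeats(2, 1.2):  ceil(2 * 1.2) = 3 repeats.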
+
+
+def conv2d_block(inputs: tf.Tensor,
+ conv_filters: Optional[int],
+ config: ModelConfig,
+ kernel_size: Any = (1, 1),
+ strides: Any = (1, 1),
+ use_batch_norm: bool = True,
+ use_bias: bool = False,
+ activation: Any = None,
+ depthwise: bool = False,
+ name: Text = None):
+ """A conv2d followed by batch norm and an activation."""
+ batch_norm = common_modules.get_batch_norm(config.batch_norm)
+ bn_momentum = config.bn_momentum
+ bn_epsilon = config.bn_epsilon
+ data_format = tf.keras.backend.image_data_format()
+ weight_decay = config.weight_decay
+
+ name = name or ''
+
+ # Collect args based on what kind of conv2d block is desired
+ init_kwargs = {
+ 'kernel_size': kernel_size,
+ 'strides': strides,
+ 'use_bias': use_bias,
+ 'padding': 'same',
+ 'name': name + '_conv2d',
+ 'kernel_regularizer': tf.keras.regularizers.l2(weight_decay),
+ 'bias_regularizer': tf.keras.regularizers.l2(weight_decay),
+ }
+
+ if depthwise:
+ conv2d = tf.keras.layers.DepthwiseConv2D
+ init_kwargs.update({'depthwise_initializer': CONV_KERNEL_INITIALIZER})
+ else:
+ conv2d = tf.keras.layers.Conv2D
+ init_kwargs.update({'filters': conv_filters,
+ 'kernel_initializer': CONV_KERNEL_INITIALIZER})
+
+ x = conv2d(**init_kwargs)(inputs)
+
+ if use_batch_norm:
+ bn_axis = 1 if data_format == 'channels_first' else -1
+ x = batch_norm(axis=bn_axis,
+ momentum=bn_momentum,
+ epsilon=bn_epsilon,
+ name=name + '_bn')(x)
+
+ if activation is not None:
+ x = tf.keras.layers.Activation(activation,
+ name=name + '_activation')(x)
+ return x
+
+
+def mb_conv_block(inputs: tf.Tensor,
+ block: BlockConfig,
+ config: ModelConfig,
+ prefix: Text = None):
+ """Mobile Inverted Residual Bottleneck.
+
+ Args:
+ inputs: the Keras input to the block
+ block: BlockConfig, arguments to create a Block
+ config: ModelConfig, a set of model parameters
+ prefix: prefix for naming all layers
+
+ Returns:
+ the output of the block
+ """
+ use_se = config.use_se
+ activation = tf_utils.get_activation(config.activation)
+ drop_connect_rate = config.drop_connect_rate
+ data_format = tf.keras.backend.image_data_format()
+ use_depthwise = block.conv_type != 'no_depthwise'
+ prefix = prefix or ''
+
+ filters = block.input_filters * block.expand_ratio
+
+ x = inputs
+
+ if block.fused_conv:
+ # If we use fused mbconv, skip expansion and use regular conv.
+ x = conv2d_block(x,
+ filters,
+ config,
+ kernel_size=block.kernel_size,
+ strides=block.strides,
+ activation=activation,
+ name=prefix + 'fused')
+ else:
+ if block.expand_ratio != 1:
+ # Expansion phase
+ kernel_size = (1, 1) if use_depthwise else (3, 3)
+ x = conv2d_block(x,
+ filters,
+ config,
+ kernel_size=kernel_size,
+ activation=activation,
+ name=prefix + 'expand')
+
+ # Depthwise Convolution
+ if use_depthwise:
+ x = conv2d_block(x,
+ conv_filters=None,
+ config=config,
+ kernel_size=block.kernel_size,
+ strides=block.strides,
+ activation=activation,
+ depthwise=True,
+ name=prefix + 'depthwise')
+
+ # Squeeze and Excitation phase
+ if use_se:
+ assert block.se_ratio is not None
+ assert 0 < block.se_ratio <= 1
+ num_reduced_filters = max(1, int(
+ block.input_filters * block.se_ratio
+ ))
+
+ if data_format == 'channels_first':
+ se_shape = (filters, 1, 1)
+ else:
+ se_shape = (1, 1, filters)
+
+ se = tf.keras.layers.GlobalAveragePooling2D(name=prefix + 'se_squeeze')(x)
+ se = tf.keras.layers.Reshape(se_shape, name=prefix + 'se_reshape')(se)
+
+ se = conv2d_block(se,
+ num_reduced_filters,
+ config,
+ use_bias=True,
+ use_batch_norm=False,
+ activation=activation,
+ name=prefix + 'se_reduce')
+ se = conv2d_block(se,
+ filters,
+ config,
+ use_bias=True,
+ use_batch_norm=False,
+ activation='sigmoid',
+ name=prefix + 'se_expand')
+ x = tf.keras.layers.multiply([x, se], name=prefix + 'se_excite')
+
+ # Output phase
+ x = conv2d_block(x,
+ block.output_filters,
+ config,
+ activation=None,
+ name=prefix + 'project')
+
+ # Add identity so that quantization-aware training can insert quantization
+ # ops correctly.
+ x = tf.keras.layers.Activation(tf_utils.get_activation('identity'),
+ name=prefix + 'id')(x)
+
+ if (block.id_skip
+ and all(s == 1 for s in block.strides)
+ and block.input_filters == block.output_filters):
+ if drop_connect_rate and drop_connect_rate > 0:
+ # Apply dropconnect
+ # The only difference between dropout and dropconnect in TF is scaling by
+ # drop_connect_rate during training. See:
+ # https://github.com/keras-team/keras/pull/9898#issuecomment-380577612
+ x = tf.keras.layers.Dropout(drop_connect_rate,
+ noise_shape=(None, 1, 1, 1),
+ name=prefix + 'drop')(x)
+
+ x = tf.keras.layers.add([x, inputs], name=prefix + 'add')
+
+ return x
+
+
+def efficientnet(image_input: tf.keras.layers.Input,
+ config: ModelConfig):
+ """Creates an EfficientNet graph given the model parameters.
+
+ This function is wrapped by the `EfficientNet` class to make a tf.keras.Model.
+
+ Args:
+ image_input: the input batch of images
+ config: the model config
+
+ Returns:
+ the output of efficientnet
+ """
+ depth_coefficient = config.depth_coefficient
+ blocks = config.blocks
+ stem_base_filters = config.stem_base_filters
+ top_base_filters = config.top_base_filters
+ activation = tf_utils.get_activation(config.activation)
+ dropout_rate = config.dropout_rate
+ drop_connect_rate = config.drop_connect_rate
+ num_classes = config.num_classes
+ input_channels = config.input_channels
+ rescale_input = config.rescale_input
+ data_format = tf.keras.backend.image_data_format()
+ dtype = config.dtype
+ weight_decay = config.weight_decay
+
+ x = image_input
+ if data_format == 'channels_first':
+ # Happens on GPU/TPU if available.
+ x = tf.keras.layers.Permute((3, 1, 2))(x)
+ if rescale_input:
+ x = preprocessing.normalize_images(x,
+ num_channels=input_channels,
+ dtype=dtype,
+ data_format=data_format)
+
+ # Build stem
+ x = conv2d_block(x,
+ round_filters(stem_base_filters, config),
+ config,
+ kernel_size=[3, 3],
+ strides=[2, 2],
+ activation=activation,
+ name='stem')
+
+ # Build blocks
+ num_blocks_total = sum(
+ round_repeats(block.num_repeat, depth_coefficient) for block in blocks)
+ block_num = 0
+
+ for stack_idx, block in enumerate(blocks):
+ assert block.num_repeat > 0
+ # Update block input and output filters based on depth multiplier
+ block = block.replace(
+ input_filters=round_filters(block.input_filters, config),
+ output_filters=round_filters(block.output_filters, config),
+ num_repeat=round_repeats(block.num_repeat, depth_coefficient))
+
+ # The first block needs to take care of stride and filter size increase
+ drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
+ config = config.replace(drop_connect_rate=drop_rate)
+ block_prefix = 'stack_{}/block_0/'.format(stack_idx)
+ x = mb_conv_block(x, block, config, block_prefix)
+ block_num += 1
+ if block.num_repeat > 1:
+ block = block.replace(
+ input_filters=block.output_filters,
+ strides=[1, 1]
+ )
+
+ for block_idx in range(block.num_repeat - 1):
+ drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
+ config = config.replace(drop_connect_rate=drop_rate)
+ block_prefix = 'stack_{}/block_{}/'.format(stack_idx, block_idx + 1)
+ x = mb_conv_block(x, block, config, prefix=block_prefix)
+ block_num += 1
+
+ # Build top
+ x = conv2d_block(x,
+ round_filters(top_base_filters, config),
+ config,
+ activation=activation,
+ name='top')
+
+ # Build classifier
+ x = tf.keras.layers.GlobalAveragePooling2D(name='top_pool')(x)
+ if dropout_rate and dropout_rate > 0:
+ x = tf.keras.layers.Dropout(dropout_rate, name='top_dropout')(x)
+ x = tf.keras.layers.Dense(
+ num_classes,
+ kernel_initializer=DENSE_KERNEL_INITIALIZER,
+ kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
+ bias_regularizer=tf.keras.regularizers.l2(weight_decay),
+ name='logits')(x)
+ x = tf.keras.layers.Activation('softmax', name='probs')(x)
+
+ return x
+
+
+@tf.keras.utils.register_keras_serializable(package='Vision')
+class EfficientNet(tf.keras.Model):
+ """Wrapper class for an EfficientNet Keras model.
+
+ Contains helper methods to build, manage, and save metadata about the model.
+ """
+
+ def __init__(self,
+ config: ModelConfig = None,
+ overrides: Dict[Text, Any] = None):
+ """Create an EfficientNet model.
+
+ Args:
+ config: (optional) the main model parameters to create the model
+ overrides: (optional) a dict containing keys that can override
+ config
+ """
+ overrides = overrides or {}
+ config = config or ModelConfig()
+
+ self.config = config.replace(**overrides)
+
+ input_channels = self.config.input_channels
+ model_name = self.config.model_name
+ input_shape = (None, None, input_channels) # Should handle any size image
+ image_input = tf.keras.layers.Input(shape=input_shape)
+
+ output = efficientnet(image_input, self.config)
+
+ # Cast to float32 in case we have a different model dtype
+ output = tf.cast(output, tf.float32)
+
+ logging.info('Building model %s with params %s',
+ model_name,
+ self.config)
+
+ super(EfficientNet, self).__init__(
+ inputs=image_input, outputs=output, name=model_name)
+
+ @classmethod
+ def from_name(cls,
+ model_name: Text,
+ model_weights_path: Text = None,
+ weights_format: Text = 'saved_model',
+ overrides: Dict[Text, Any] = None):
+ """Construct an EfficientNet model from a predefined model name.
+
+ E.g., `EfficientNet.from_name('efficientnet-b0')`.
+
+ Args:
+ model_name: the predefined model name
+ model_weights_path: the path to the weights (h5 file or saved model dir)
+ weights_format: the model weights format. One of 'saved_model', 'h5',
+ or 'checkpoint'.
+ overrides: (optional) a dict containing keys that can override config
+
+ Returns:
+ A constructed EfficientNet instance.
+ """
+ model_configs = dict(MODEL_CONFIGS)
+ overrides = dict(overrides) if overrides else {}
+
+ # One can define their own custom models if necessary
+ model_configs.update(overrides.pop('model_config', {}))
+
+ if model_name not in model_configs:
+ raise ValueError('Unknown model name {}'.format(model_name))
+
+ config = model_configs[model_name]
+
+ model = cls(config=config, overrides=overrides)
+
+ if model_weights_path:
+ common_modules.load_weights(model,
+ model_weights_path,
+ weights_format=weights_format)
+
+ return model
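+
+
+# Usage sketch (illustrative only): build a B0 model and override the
+# classifier head for a 10-class problem. The override keys map onto
+# ModelConfig fields via config.replace().
+#
+#   model = EfficientNet.from_name('efficientnet-b0',
+#                                  overrides={'num_classes': 10})
+#   model.compile(optimizer='sgd',
+#                 loss='sparse_categorical_crossentropy',
+#                 metrics=['accuracy'])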
diff --git a/models/official/vision/image_classification/efficientnet/tfhub_export.py b/models/official/vision/image_classification/efficientnet/tfhub_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..3be8608a5cfc25442f5f936b4052f90b89c6cfce
--- /dev/null
+++ b/models/official/vision/image_classification/efficientnet/tfhub_export.py
@@ -0,0 +1,69 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A script to export TF-Hub SavedModel."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import os
+
+from absl import app
+from absl import flags
+
+import tensorflow as tf
+
+from official.vision.image_classification.efficientnet import efficientnet_model
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("model_name", None,
+ "EfficientNet model name.")
+flags.DEFINE_string("model_path", None,
+ "File path to TF model checkpoint.")
+flags.DEFINE_string("export_path", None,
+ "TF-Hub SavedModel destination path to export.")
+
+
+def export_tfhub(model_path, hub_destination, model_name):
+ """Restores a tf.keras.Model and saves for TF-Hub."""
+ model_configs = dict(efficientnet_model.MODEL_CONFIGS)
+ config = model_configs[model_name]
+
+ image_input = tf.keras.layers.Input(
+ shape=(None, None, 3), name="image_input", dtype=tf.float32)
+ x = image_input * 255.0
+ outputs = efficientnet_model.efficientnet(x, config)
+ hub_model = tf.keras.Model(image_input, outputs)
+ ckpt = tf.train.Checkpoint(model=hub_model)
+ ckpt.restore(model_path).assert_existing_objects_matched()
+ hub_model.save(
+ os.path.join(hub_destination, "classification"), include_optimizer=False)
+
+ feature_vector_output = hub_model.get_layer(name="top_pool").get_output_at(0)
+ hub_model2 = tf.keras.Model(image_input, feature_vector_output)
+ hub_model2.save(
+ os.path.join(hub_destination, "feature-vector"), include_optimizer=False)
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise app.UsageError("Too many command-line arguments.")
+
+ export_tfhub(FLAGS.model_path, FLAGS.export_path, FLAGS.model_name)
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/models/official/vision/image_classification/learning_rate.py b/models/official/vision/image_classification/learning_rate.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c78b04bc6297a08a8bc7823dccc00f464e05ad4
--- /dev/null
+++ b/models/official/vision/image_classification/learning_rate.py
@@ -0,0 +1,164 @@
+# Lint as: python3
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Learning rate utilities for vision tasks."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from typing import Any, List, Mapping
+
+import numpy as np
+import tensorflow as tf
+
+BASE_LEARNING_RATE = 0.1
+
+
+class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
+ """A wrapper for LearningRateSchedule that includes warmup steps."""
+
+ def __init__(
+ self,
+ lr_schedule: tf.keras.optimizers.schedules.LearningRateSchedule,
+ warmup_steps: int):
+ """Add warmup decay to a learning rate schedule.
+
+ Args:
+ lr_schedule: base learning rate scheduler
+ warmup_steps: number of warmup steps
+
+ """
+ super(WarmupDecaySchedule, self).__init__()
+ self._lr_schedule = lr_schedule
+ self._warmup_steps = warmup_steps
+
+ def __call__(self, step: int):
+ lr = self._lr_schedule(step)
+ if self._warmup_steps:
+ initial_learning_rate = tf.convert_to_tensor(
+ self._lr_schedule.initial_learning_rate, name="initial_learning_rate")
+ dtype = initial_learning_rate.dtype
+ global_step_recomp = tf.cast(step, dtype)
+ warmup_steps = tf.cast(self._warmup_steps, dtype)
+ warmup_lr = initial_learning_rate * global_step_recomp / warmup_steps
+ lr = tf.cond(global_step_recomp < warmup_steps,
+ lambda: warmup_lr,
+ lambda: lr)
+ return lr
+
+ def get_config(self) -> Mapping[str, Any]:
+ config = self._lr_schedule.get_config()
+ config.update({
+ "warmup_steps": self._warmup_steps,
+ })
+ return config
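+
+
+# Usage sketch (illustrative only): wrap any Keras schedule so that the first
+# `warmup_steps` updates ramp linearly from 0 up to the schedule's initial
+# learning rate, e.g.
+#
+#   base = tf.keras.optimizers.schedules.ExponentialDecay(
+#       initial_learning_rate=0.01, decay_steps=1000, decay_rate=0.96)
+#   lr = WarmupDecaySchedule(base, warmup_steps=100)  # lr(50) == 0.5 * 0.01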
+
+
+# TODO(b/149030439) - refactor this with
+# tf.keras.optimizers.schedules.PiecewiseConstantDecay + WarmupDecaySchedule.
+class PiecewiseConstantDecayWithWarmup(
+ tf.keras.optimizers.schedules.LearningRateSchedule):
+ """Piecewise constant decay with warmup schedule."""
+
+ def __init__(self,
+ batch_size: int,
+ epoch_size: int,
+ warmup_epochs: int,
+ boundaries: List[int],
+ multipliers: List[float]):
+ """Piecewise constant decay with warmup.
+
+ Args:
+ batch_size: The training batch size used in the experiment.
+ epoch_size: The size of an epoch, or the number of examples in an epoch.
+ warmup_epochs: The number of warmup epochs to apply.
+ boundaries: The list of epoch boundaries with strictly increasing entries.
+ multipliers: The list of multipliers/learning rates to use for the
+ piecewise portion. Its length must be one more than that of boundaries.
+
+ """
+ super(PiecewiseConstantDecayWithWarmup, self).__init__()
+ if len(boundaries) != len(multipliers) - 1:
+ raise ValueError("The length of boundaries must be 1 less than the "
+ "length of multipliers")
+
+ base_lr_batch_size = 256
+ steps_per_epoch = epoch_size // batch_size
+
+ self._rescaled_lr = BASE_LEARNING_RATE * batch_size / base_lr_batch_size
+ self._step_boundaries = [float(steps_per_epoch) * x for x in boundaries]
+ self._lr_values = [self._rescaled_lr * m for m in multipliers]
+ self._warmup_steps = warmup_epochs * steps_per_epoch
+
+ def __call__(self, step: int):
+ """Compute learning rate at given step."""
+ def warmup_lr():
+ return self._rescaled_lr * (
+ step / tf.cast(self._warmup_steps, tf.float32))
+ def piecewise_lr():
+ return tf.compat.v1.train.piecewise_constant(
+ tf.cast(step, tf.float32), self._step_boundaries, self._lr_values)
+ return tf.cond(step < self._warmup_steps, warmup_lr, piecewise_lr)
+
+ def get_config(self) -> Mapping[str, Any]:
+ return {
+ "rescaled_lr": self._rescaled_lr,
+ "step_boundaries": self._step_boundaries,
+ "lr_values": self._lr_values,
+ "warmup_steps": self._warmup_steps,
+ }
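+
+
+# Worked example (illustrative only): with batch_size=1024 the base rate is
+# rescaled to BASE_LEARNING_RATE * 1024 / 256 = 0.4, the epoch boundaries are
+# converted to step boundaries via steps_per_epoch = epoch_size // batch_size,
+# and the multipliers then scale the rescaled rate piecewise after warmup.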
+
+
+class CosineDecayWithWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
+ """Class to generate learning rate tensor."""
+
+ def __init__(self, batch_size: int, total_steps: int, warmup_steps: int):
+ """Creates the consine learning rate tensor with linear warmup.
+
+ Args:
+ batch_size: The training batch size used in the experiment.
+ total_steps: Total training steps.
+ warmup_steps: Steps for the warm up period.
+ """
+ super(CosineDecayWithWarmup, self).__init__()
+ base_lr_batch_size = 256
+ self._total_steps = total_steps
+ self._init_learning_rate = BASE_LEARNING_RATE * batch_size / base_lr_batch_size
+ self._warmup_steps = warmup_steps
+
+ def __call__(self, global_step: int):
+ global_step = tf.cast(global_step, dtype=tf.float32)
+ warmup_steps = self._warmup_steps
+ init_lr = self._init_learning_rate
+ total_steps = self._total_steps
+
+ linear_warmup = global_step / warmup_steps * init_lr
+
+ cosine_learning_rate = init_lr * (tf.cos(np.pi *
+ (global_step - warmup_steps) /
+ (total_steps - warmup_steps)) +
+ 1.0) / 2.0
+
+ learning_rate = tf.where(global_step < warmup_steps, linear_warmup,
+ cosine_learning_rate)
+ return learning_rate
+
+ def get_config(self):
+ # Note: only attributes actually set in __init__ are serialized here;
+ # the schedule does not store a separate warmup learning rate.
+ return {
+ "total_steps": self._total_steps,
+ "warmup_steps": self._warmup_steps,
+ "init_learning_rate": self._init_learning_rate,
+ }
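+
+
+# Worked example (illustrative only): CosineDecayWithWarmup(batch_size=256,
+# total_steps=3, warmup_steps=1) gives init_lr = 0.1 * 256 / 256 = 0.1 and
+# per-step rates of 0.0, 0.1, 0.05 and 0.0 for steps 0..3 (linear warmup,
+# then a cosine half-wave decaying to zero).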
diff --git a/models/official/vision/image_classification/learning_rate_test.py b/models/official/vision/image_classification/learning_rate_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..272d2935fd7f1e6a7f1810e9247c4ef505021fde
--- /dev/null
+++ b/models/official/vision/image_classification/learning_rate_test.py
@@ -0,0 +1,99 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for learning_rate."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from official.vision.image_classification import learning_rate
+
+
+class LearningRateTests(tf.test.TestCase):
+
+ def test_warmup_decay(self):
+ """Basic computational test for warmup decay."""
+ initial_lr = 0.01
+ decay_steps = 100
+ decay_rate = 0.01
+ warmup_steps = 10
+
+ base_lr = tf.keras.optimizers.schedules.ExponentialDecay(
+ initial_learning_rate=initial_lr,
+ decay_steps=decay_steps,
+ decay_rate=decay_rate)
+ lr = learning_rate.WarmupDecaySchedule(
+ lr_schedule=base_lr,
+ warmup_steps=warmup_steps)
+
+ for step in range(warmup_steps - 1):
+ config = lr.get_config()
+ self.assertEqual(config['warmup_steps'], warmup_steps)
+ self.assertAllClose(self.evaluate(lr(step)),
+ step / warmup_steps * initial_lr)
+
+ def test_piecewise_constant_decay_with_warmup(self):
+ """Basic computational test for piecewise constant decay with warmup."""
+ boundaries = [1, 2, 3]
+ warmup_epochs = boundaries[0]
+ learning_rate_multipliers = [1.0, 0.1, 0.001]
+ expected_keys = [
+ 'rescaled_lr', 'step_boundaries', 'lr_values', 'warmup_steps',
+ ]
+
+ expected_lrs = [0.0, 0.1, 0.1]
+
+ lr = learning_rate.PiecewiseConstantDecayWithWarmup(
+ batch_size=256,
+ epoch_size=256,
+ warmup_epochs=warmup_epochs,
+ boundaries=boundaries[1:],
+ multipliers=learning_rate_multipliers)
+
+ step = 0
+
+ config = lr.get_config()
+ self.assertAllInSet(list(config.keys()), expected_keys)
+
+ for boundary, expected_lr in zip(boundaries, expected_lrs):
+ for _ in range(step, boundary):
+ self.assertAllClose(self.evaluate(lr(step)), expected_lr)
+ step += 1
+
+ def test_piecewise_constant_decay_invalid_boundaries(self):
+ with self.assertRaisesRegex(ValueError,
+ 'The length of boundaries must be 1 less '):
+ learning_rate.PiecewiseConstantDecayWithWarmup(
+ batch_size=256,
+ epoch_size=256,
+ warmup_epochs=1,
+ boundaries=[1, 2],
+ multipliers=[1, 2])
+
+ def test_cosine_decay_with_warmup(self):
+ """Basic computational test for cosine decay with warmup."""
+ expected_lrs = [0.0, 0.1, 0.05, 0.0]
+
+ lr = learning_rate.CosineDecayWithWarmup(
+ batch_size=256, total_steps=3, warmup_steps=1)
+
+ for step in [0, 1, 2, 3]:
+ self.assertAllClose(lr(step), expected_lrs[step])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/vision/image_classification/mnist_main.py b/models/official/vision/image_classification/mnist_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..1470c02d05b431e95de3c5807b68678a96d2b520
--- /dev/null
+++ b/models/official/vision/image_classification/mnist_main.py
@@ -0,0 +1,171 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Runs a simple model on the MNIST dataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf
+import tensorflow_datasets as tfds
+
+from official.utils.flags import core as flags_core
+from official.utils.misc import distribution_utils
+from official.utils.misc import model_helpers
+from official.vision.image_classification.resnet import common
+
+FLAGS = flags.FLAGS
+
+
+def build_model():
+ """Constructs the ML model used to predict handwritten digits."""
+
+ image = tf.keras.layers.Input(shape=(28, 28, 1))
+
+ y = tf.keras.layers.Conv2D(filters=32,
+ kernel_size=5,
+ padding='same',
+ activation='relu')(image)
+ y = tf.keras.layers.MaxPooling2D(pool_size=(2, 2),
+ strides=(2, 2),
+ padding='same')(y)
+ y = tf.keras.layers.Conv2D(filters=32,
+ kernel_size=5,
+ padding='same',
+ activation='relu')(y)
+ y = tf.keras.layers.MaxPooling2D(pool_size=(2, 2),
+ strides=(2, 2),
+ padding='same')(y)
+ y = tf.keras.layers.Flatten()(y)
+ y = tf.keras.layers.Dense(1024, activation='relu')(y)
+ y = tf.keras.layers.Dropout(0.4)(y)
+
+ probs = tf.keras.layers.Dense(10, activation='softmax')(y)
+
+ model = tf.keras.models.Model(image, probs, name='mnist')
+
+ return model
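+
+
+# Illustrative note: with 'same' padding, the two 2x2/stride-2 poolings reduce
+# the 28x28 input to 14x14 and then 7x7, so the flattened feature vector has
+# 7 * 7 * 32 = 1568 values feeding the 1024-unit dense layer.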
+
+
+@tfds.decode.make_decoder(output_dtype=tf.float32)
+def decode_image(example, feature):
+ """Convert image to float32 and normalize from [0, 255] to [0.0, 1.0]."""
+ return tf.cast(feature.decode_example(example), dtype=tf.float32) / 255
+
+
+def run(flags_obj, datasets_override=None, strategy_override=None):
+ """Run MNIST model training and eval loop using native Keras APIs.
+
+ Args:
+ flags_obj: An object containing parsed flag values.
+ datasets_override: A pair of `tf.data.Dataset` objects to train the model,
+ representing the train and test sets.
+ strategy_override: A `tf.distribute.Strategy` object to use for model.
+
+ Returns:
+ Dictionary of training and eval stats.
+ """
+ strategy = strategy_override or distribution_utils.get_distribution_strategy(
+ distribution_strategy=flags_obj.distribution_strategy,
+ num_gpus=flags_obj.num_gpus,
+ tpu_address=flags_obj.tpu)
+
+ strategy_scope = distribution_utils.get_strategy_scope(strategy)
+
+ mnist = tfds.builder('mnist', data_dir=flags_obj.data_dir)
+ if flags_obj.download:
+ mnist.download_and_prepare()
+
+ mnist_train, mnist_test = datasets_override or mnist.as_dataset(
+ split=['train', 'test'],
+ decoders={'image': decode_image()}, # pylint: disable=no-value-for-parameter
+ as_supervised=True)
+ train_input_dataset = mnist_train.cache().repeat().shuffle(
+ buffer_size=50000).batch(flags_obj.batch_size)
+ eval_input_dataset = mnist_test.cache().repeat().batch(flags_obj.batch_size)
+
+ with strategy_scope:
+ lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
+ 0.05, decay_steps=100000, decay_rate=0.96)
+ optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)
+
+ model = build_model()
+ model.compile(
+ optimizer=optimizer,
+ loss='sparse_categorical_crossentropy',
+ metrics=['sparse_categorical_accuracy'])
+
+ num_train_examples = mnist.info.splits['train'].num_examples
+ train_steps = num_train_examples // flags_obj.batch_size
+ train_epochs = flags_obj.train_epochs
+
+ ckpt_full_path = os.path.join(flags_obj.model_dir, 'model.ckpt-{epoch:04d}')
+ callbacks = [
+ tf.keras.callbacks.ModelCheckpoint(
+ ckpt_full_path, save_weights_only=True),
+ tf.keras.callbacks.TensorBoard(log_dir=flags_obj.model_dir),
+ ]
+
+ num_eval_examples = mnist.info.splits['test'].num_examples
+ num_eval_steps = num_eval_examples // flags_obj.batch_size
+
+ history = model.fit(
+ train_input_dataset,
+ epochs=train_epochs,
+ steps_per_epoch=train_steps,
+ callbacks=callbacks,
+ validation_steps=num_eval_steps,
+ validation_data=eval_input_dataset,
+ validation_freq=flags_obj.epochs_between_evals)
+
+ export_path = os.path.join(flags_obj.model_dir, 'saved_model')
+ model.save(export_path, include_optimizer=False)
+
+ eval_output = model.evaluate(
+ eval_input_dataset, steps=num_eval_steps, verbose=2)
+
+ stats = common.build_stats(history, eval_output, callbacks)
+ return stats
+
+
+def define_mnist_flags():
+ """Define command line flags for MNIST model."""
+ flags_core.define_base(
+ clean=True,
+ num_gpu=True,
+ train_epochs=True,
+ epochs_between_evals=True,
+ distribution_strategy=True)
+ flags_core.define_device()
+ flags_core.define_distribution()
+ flags.DEFINE_bool('download', False,
+ 'Whether to download data to `--data_dir`.')
+ FLAGS.set_default('batch_size', 1024)
+
+
+def main(_):
+ model_helpers.apply_clean(FLAGS)
+ stats = run(flags.FLAGS)
+ logging.info('Run stats:\n%s', stats)
+
+
+if __name__ == '__main__':
+ logging.set_verbosity(logging.INFO)
+ define_mnist_flags()
+ app.run(main)
diff --git a/models/official/vision/image_classification/mnist_test.py b/models/official/vision/image_classification/mnist_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c05efcfe5d68fbbb3c181c19b59444db1abe5702
--- /dev/null
+++ b/models/official/vision/image_classification/mnist_test.py
@@ -0,0 +1,87 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test the Keras MNIST model on GPU."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+from absl.testing import parameterized
+import tensorflow as tf
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from official.utils.testing import integration
+from official.vision.image_classification import mnist_main
+
+
+def eager_strategy_combinations():
+ return combinations.combine(
+ distribution=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ ],
+ mode="eager",
+ )
+
+
+class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
+ """Unit tests for sample Keras MNIST model."""
+ _tempdir = None
+
+ @classmethod
+ def setUpClass(cls): # pylint: disable=invalid-name
+ super(KerasMnistTest, cls).setUpClass()
+ mnist_main.define_mnist_flags()
+
+ def tearDown(self):
+ super(KerasMnistTest, self).tearDown()
+ tf.io.gfile.rmtree(self.get_temp_dir())
+
+ @combinations.generate(eager_strategy_combinations())
+ def test_end_to_end(self, distribution):
+ """Test Keras MNIST model with `strategy`."""
+
+ extra_flags = [
+ "-train_epochs", "1",
+ # Let TFDS find the metadata folder automatically
+ "--data_dir="
+ ]
+
+ dummy_data = (
+ tf.ones(shape=(10, 28, 28, 1), dtype=tf.int32),
+ tf.range(10),
+ )
+ datasets = (
+ tf.data.Dataset.from_tensor_slices(dummy_data),
+ tf.data.Dataset.from_tensor_slices(dummy_data),
+ )
+
+ run = functools.partial(mnist_main.run,
+ datasets_override=datasets,
+ strategy_override=distribution)
+
+ integration.run_synthetic(
+ main=run,
+ synth=False,
+ tmp_root=self.get_temp_dir(),
+ extra_flags=extra_flags)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/official/vision/image_classification/optimizer_factory.py b/models/official/vision/image_classification/optimizer_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..d15aa79e0db61e36074c7227e1eca73df163ffa0
--- /dev/null
+++ b/models/official/vision/image_classification/optimizer_factory.py
@@ -0,0 +1,391 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Optimizer factory for vision tasks."""
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+from absl import logging
+import tensorflow as tf
+import tensorflow_addons as tfa
+
+from typing import Any, Dict, Text, List
+from official.vision.image_classification import learning_rate
+from official.vision.image_classification.configs import base_configs
+
+# pylint: disable=protected-access
+
+
+class MovingAverage(tf.keras.optimizers.Optimizer):
+ """Optimizer that computes a moving average of the variables.
+
+ Empirically it has been found that using the moving average of the trained
+ parameters of a deep network is better than using its trained parameters
+ directly. This optimizer allows you to compute this moving average and swap
+ the variables at save time so that any code outside of the training loop
+ will use by default the average values instead of the original ones.
+
+ Example of usage for training:
+ ```python
+ opt = tf.keras.optimizers.SGD(learning_rate)
+ opt = MovingAverage(opt)
+
+ opt.shadow_copy(model)
+ ```
+
+ At test time, swap the shadow variables to evaluate on the averaged weights:
+ ```python
+ opt.swap_weights()
+ # Test eval the model here
+ opt.swap_weights()
+ ```
+ """
+
+ def __init__(self,
+ optimizer: tf.keras.optimizers.Optimizer,
+ average_decay: float = 0.99,
+ start_step: int = 0,
+ dynamic_decay: bool = True,
+ name: Text = 'moving_average',
+ **kwargs):
+ """Construct a new MovingAverage optimizer.
+
+ Args:
+ optimizer: `tf.keras.optimizers.Optimizer` that will be
+ used to compute and apply gradients.
+ average_decay: float. Decay to use to maintain the moving averages
+ of trained variables.
+ start_step: int. What step to start the moving average.
+ dynamic_decay: bool. Whether to change the decay based on the number
+ of optimizer updates. Decay will start at 0.1 and gradually increase
+ up to `average_decay` after each optimizer update. This behavior is
+ similar to `tf.train.ExponentialMovingAverage` in TF 1.x.
+ name: Optional name for the operations created when applying
+ gradients. Defaults to "moving_average".
+ **kwargs: keyword arguments. Allowed to be {`clipnorm`,
+ `clipvalue`, `lr`, `decay`}.
+ """
+ super(MovingAverage, self).__init__(name, **kwargs)
+ self._optimizer = optimizer
+ self._average_decay = average_decay
+ self._start_step = tf.constant(start_step, tf.float32)
+ self._dynamic_decay = dynamic_decay
+ # Shadow variables are created lazily in `shadow_copy`; initialize the
+ # attributes here so `has_shadow_copy` is safe to query before that call.
+ self._model_weights = None
+ self._average_weights = None
+
+ def shadow_copy(self, model: tf.keras.Model):
+ """Creates shadow variables for the given model weights."""
+ for var in model.weights:
+ self.add_slot(var, 'average', initializer='zeros')
+ self._average_weights = [
+ self.get_slot(var, 'average') for var in model.weights
+ ]
+ self._model_weights = model.weights
+
+ @property
+ def has_shadow_copy(self):
+ """Whether this optimizer has created shadow variables."""
+ return self._model_weights is not None
+
+ def _create_slots(self, var_list):
+ self._optimizer._create_slots(var_list=var_list) # pylint: disable=protected-access
+
+ def apply_gradients(self, grads_and_vars, name: Text = None):
+ result = self._optimizer.apply_gradients(grads_and_vars, name)
+ self.update_average(self._optimizer.iterations)
+ return result
+
+ @tf.function
+ def update_average(self, step: tf.Tensor):
+ step = tf.cast(step, tf.float32)
+ if step < self._start_step:
+ decay = tf.constant(0., tf.float32)
+ elif self._dynamic_decay:
+ decay = step - self._start_step
+ decay = tf.minimum(self._average_decay, (1. + decay) / (10. + decay))
+ else:
+ decay = self._average_decay
+
+ def _apply_moving(v_moving, v_normal):
+ diff = v_moving - v_normal
+ v_moving.assign_sub(tf.cast(1. - decay, v_moving.dtype) * diff)
+ return v_moving
+
+ def _update(strategy, v_moving_and_v_normal):
+ for v_moving, v_normal in v_moving_and_v_normal:
+ strategy.extended.update(v_moving, _apply_moving, args=(v_normal,))
+
+ ctx = tf.distribute.get_replica_context()
+ return ctx.merge_call(_update, args=(zip(self._average_weights,
+ self._model_weights),))
+
+ def swap_weights(self):
+ """Swap the average and moving weights.
+
+ This is a convenience method to allow one to evaluate the averaged weights
+ at test time. Loads the weights stored in `self._average` into the model,
+ keeping a copy of the original model weights. Swapping twice will return
+ the original weights.
+ """
+ if tf.distribute.in_cross_replica_context():
+ strategy = tf.distribute.get_strategy()
+ strategy.run(self._swap_weights, args=())
+ else:
+ raise ValueError('Swapping weights must occur under a '
+ 'tf.distribute.Strategy')
+
+ @tf.function
+ def _swap_weights(self):
+ def fn_0(a, b):
+ a.assign_add(b)
+ return a
+ def fn_1(b, a):
+ b.assign(a - b)
+ return b
+ def fn_2(a, b):
+ a.assign_sub(b)
+ return a
+
+ def swap(strategy, a_and_b):
+ """Swap `a` and `b` and mirror to all devices."""
+ for a, b in a_and_b:
+ strategy.extended.update(a, fn_0, args=(b,)) # a = a + b
+ strategy.extended.update(b, fn_1, args=(a,)) # b = a - b
+ strategy.extended.update(a, fn_2, args=(b,)) # a = a - b
+
+ ctx = tf.distribute.get_replica_context()
+ return ctx.merge_call(
+ swap, args=(zip(self._average_weights, self._model_weights),))
+
+ def assign_average_vars(self, var_list: List[tf.Variable]):
+ """Assign variables in var_list with their respective averages.
+
+ Args:
+ var_list: List of model variables to be assigned to their average.
+ Returns:
+ assign_op: The op corresponding to the assignment operation of
+ variables to their average.
+ """
+ assign_op = tf.group([
+ var.assign(self.get_slot(var, 'average')) for var in var_list
+ if var.trainable
+ ])
+ return assign_op
+
+ def _create_hypers(self):
+ self._optimizer._create_hypers() # pylint: disable=protected-access
+
+ def _prepare(self, var_list):
+ return self._optimizer._prepare(var_list=var_list) # pylint: disable=protected-access
+
+ @property
+ def iterations(self):
+ return self._optimizer.iterations
+
+ @iterations.setter
+ def iterations(self, variable):
+ self._optimizer.iterations = variable
+
+ @property
+ def weights(self):
+ # return self._weights + self._optimizer.weights
+ return self._optimizer.weights
+
+ @property
+ def lr(self):
+ return self._optimizer._get_hyper('learning_rate')
+
+ @lr.setter
+ def lr(self, lr):
+ self._optimizer._set_hyper('learning_rate', lr)
+
+ @property
+ def learning_rate(self):
+ return self._optimizer._get_hyper('learning_rate')
+
+ @learning_rate.setter
+ def learning_rate(self, learning_rate): # pylint: disable=redefined-outer-name
+ self._optimizer._set_hyper('learning_rate', learning_rate)
+
+ def _resource_apply_dense(self, grad, var):
+ return self._optimizer._resource_apply_dense(grad, var)
+
+ def _resource_apply_sparse(self, grad, var, indices):
+ return self._optimizer._resource_apply_sparse(grad, var, indices)
+
+ def _resource_apply_sparse_duplicate_indices(self, grad, var, indices):
+ return self._optimizer._resource_apply_sparse_duplicate_indices(
+ grad, var, indices)
+
+ def get_config(self):
+ config = {
+ 'optimizer': tf.keras.optimizers.serialize(self._optimizer),
+ 'average_decay': self._average_decay,
+ 'start_step': self._start_step,
+ 'dynamic_decay': self._dynamic_decay,
+ }
+ base_config = super(MovingAverage, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ optimizer = tf.keras.optimizers.deserialize(
+ config.pop('optimizer'),
+ custom_objects=custom_objects,
+ )
+ return cls(optimizer, **config)
+
+
+def build_optimizer(
+ optimizer_name: Text,
+ base_learning_rate: tf.keras.optimizers.schedules.LearningRateSchedule,
+ params: Dict[Text, Any]):
+ """Build the optimizer based on name.
+
+ Args:
+ optimizer_name: String representation of the optimizer name. Examples:
+ sgd, momentum, rmsprop.
+ base_learning_rate: `tf.keras.optimizers.schedules.LearningRateSchedule`
+ base learning rate.
+ params: String -> Any dictionary representing the optimizer params.
+ This should contain optimizer specific parameters such as
+ `base_learning_rate`, `decay`, etc.
+
+ Returns:
+ A tf.keras.Optimizer.
+
+ Raises:
+ ValueError if the provided optimizer_name is not supported.
+
+ """
+ optimizer_name = optimizer_name.lower()
+ logging.info('Building %s optimizer with params %s', optimizer_name, params)
+
+ if optimizer_name == 'sgd':
+ logging.info('Using SGD optimizer')
+ nesterov = params.get('nesterov', False)
+ optimizer = tf.keras.optimizers.SGD(learning_rate=base_learning_rate,
+ nesterov=nesterov)
+ elif optimizer_name == 'momentum':
+ logging.info('Using momentum optimizer')
+ nesterov = params.get('nesterov', False)
+ optimizer = tf.keras.optimizers.SGD(learning_rate=base_learning_rate,
+ momentum=params['momentum'],
+ nesterov=nesterov)
+ elif optimizer_name == 'rmsprop':
+ logging.info('Using RMSProp')
+ rho = params.get('decay', None) or params.get('rho', 0.9)
+ momentum = params.get('momentum', 0.9)
+ epsilon = params.get('epsilon', 1e-07)
+ optimizer = tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate,
+ rho=rho,
+ momentum=momentum,
+ epsilon=epsilon)
+ elif optimizer_name == 'adam':
+ logging.info('Using Adam')
+ beta_1 = params.get('beta_1', 0.9)
+ beta_2 = params.get('beta_2', 0.999)
+ epsilon = params.get('epsilon', 1e-07)
+ optimizer = tf.keras.optimizers.Adam(learning_rate=base_learning_rate,
+ beta_1=beta_1,
+ beta_2=beta_2,
+ epsilon=epsilon)
+ elif optimizer_name == 'adamw':
+ logging.info('Using AdamW')
+ weight_decay = params.get('weight_decay', 0.01)
+ beta_1 = params.get('beta_1', 0.9)
+ beta_2 = params.get('beta_2', 0.999)
+ epsilon = params.get('epsilon', 1e-07)
+ optimizer = tfa.optimizers.AdamW(weight_decay=weight_decay,
+ learning_rate=base_learning_rate,
+ beta_1=beta_1,
+ beta_2=beta_2,
+ epsilon=epsilon)
+ else:
+ raise ValueError('Unknown optimizer %s' % optimizer_name)
+
+ if params.get('lookahead', None):
+ logging.info('Using lookahead optimizer.')
+ optimizer = tfa.optimizers.Lookahead(optimizer)
+
+ # Moving average should be applied last, as it's applied at test time
+ moving_average_decay = params.get('moving_average_decay', 0.)
+ if moving_average_decay is not None and moving_average_decay > 0.:
+ logging.info('Including moving average decay.')
+ optimizer = MovingAverage(
+ optimizer,
+ average_decay=moving_average_decay)
+ return optimizer
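+
+
+# Usage sketch (illustrative only): build a momentum optimizer with an EMA of
+# the weights. The returned object is then a MovingAverage wrapper, so
+# opt.shadow_copy(model) should be called before training so the averaged
+# slots exist.
+#
+#   lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
+#       initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.97)
+#   opt = build_optimizer('momentum', lr_schedule,
+#                         {'momentum': 0.9, 'moving_average_decay': 0.999})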
+
+
+def build_learning_rate(params: base_configs.LearningRateConfig,
+ batch_size: int = None,
+ train_epochs: int = None,
+ train_steps: int = None):
+ """Build the learning rate given the provided configuration."""
+ decay_type = params.name
+ base_lr = params.initial_lr
+ decay_rate = params.decay_rate
+ if params.decay_epochs is not None:
+ decay_steps = params.decay_epochs * train_steps
+ else:
+ decay_steps = 0
+ if params.warmup_epochs is not None:
+ warmup_steps = params.warmup_epochs * train_steps
+ else:
+ warmup_steps = 0
+
+ lr_multiplier = params.scale_by_batch_size
+
+ if lr_multiplier and lr_multiplier > 0:
+ # Scale the learning rate based on the batch size and a multiplier
+ base_lr *= lr_multiplier * batch_size
+ logging.info('Scaling the learning rate based on the batch size '
+ 'multiplier. New base_lr: %f', base_lr)
+
+ if decay_type == 'exponential':
+ logging.info('Using exponential learning rate with: '
+ 'initial_learning_rate: %f, decay_steps: %d, '
+ 'decay_rate: %f', base_lr, decay_steps, decay_rate)
+ lr = tf.keras.optimizers.schedules.ExponentialDecay(
+ initial_learning_rate=base_lr,
+ decay_steps=decay_steps,
+ decay_rate=decay_rate,
+ staircase=params.staircase)
+ elif decay_type == 'piecewise_constant_with_warmup':
+ logging.info('Using Piecewise constant decay with warmup. '
+ 'Parameters: batch_size: %d, epoch_size: %d, '
+ 'warmup_epochs: %d, boundaries: %s, multipliers: %s',
+ batch_size, params.examples_per_epoch,
+ params.warmup_epochs, params.boundaries,
+ params.multipliers)
+ lr = learning_rate.PiecewiseConstantDecayWithWarmup(
+ batch_size=batch_size,
+ epoch_size=params.examples_per_epoch,
+ warmup_epochs=params.warmup_epochs,
+ boundaries=params.boundaries,
+ multipliers=params.multipliers)
+ elif decay_type == 'cosine_with_warmup':
+ lr = learning_rate.CosineDecayWithWarmup(
+ batch_size=batch_size,
+ total_steps=train_epochs * train_steps,
+ warmup_steps=warmup_steps)
+ if warmup_steps > 0:
+ if decay_type not in [
+ 'piecewise_constant_with_warmup', 'cosine_with_warmup'
+ ]:
+ logging.info('Applying %d warmup steps to the learning rate',
+ warmup_steps)
+ lr = learning_rate.WarmupDecaySchedule(lr, warmup_steps)
+ return lr
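+
+
+# Usage sketch (illustrative only; the exact LearningRateConfig fields are
+# assumed from configs/base_configs.py):
+#
+#   params = base_configs.LearningRateConfig(
+#       name='cosine_with_warmup', initial_lr=0.1, decay_rate=None,
+#       decay_epochs=None, warmup_epochs=5, scale_by_batch_size=0.,
+#       examples_per_epoch=None, boundaries=None, multipliers=None)
+#   lr = build_learning_rate(params, batch_size=256,
+#                            train_epochs=90, train_steps=500)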
diff --git a/models/official/vision/image_classification/optimizer_factory_test.py b/models/official/vision/image_classification/optimizer_factory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c952618c126b4ee18b4a7f0ee87a91cff873a109
--- /dev/null
+++ b/models/official/vision/image_classification/optimizer_factory_test.py
@@ -0,0 +1,117 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for optimizer_factory."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+
+from absl.testing import parameterized
+from official.vision.image_classification import optimizer_factory
+from official.vision.image_classification.configs import base_configs
+
+
+class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ('sgd', 'sgd', 0., False),
+ ('momentum', 'momentum', 0., False),
+ ('rmsprop', 'rmsprop', 0., False),
+ ('adam', 'adam', 0., False),
+ ('adamw', 'adamw', 0., False),
+ ('momentum_lookahead', 'momentum', 0., True),
+ ('sgd_ema', 'sgd', 0.999, False),
+ ('momentum_ema', 'momentum', 0.999, False),
+ ('rmsprop_ema', 'rmsprop', 0.999, False))
+ def test_optimizer(self, optimizer_name, moving_average_decay, lookahead):
+ """Smoke test to be sure no syntax errors."""
+ params = {
+ 'learning_rate': 0.001,
+ 'rho': 0.09,
+ 'momentum': 0.,
+ 'epsilon': 1e-07,
+ 'moving_average_decay': moving_average_decay,
+ 'lookahead': lookahead,
+ }
+ optimizer = optimizer_factory.build_optimizer(
+ optimizer_name=optimizer_name,
+ base_learning_rate=params['learning_rate'],
+ params=params)
+ self.assertTrue(issubclass(type(optimizer), tf.keras.optimizers.Optimizer))
+
+ def test_unknown_optimizer(self):
+ with self.assertRaises(ValueError):
+ optimizer_factory.build_optimizer(
+ optimizer_name='this_optimizer_does_not_exist',
+ base_learning_rate=None,
+ params=None)
+
+ def test_learning_rate_without_decay_or_warmups(self):
+ params = base_configs.LearningRateConfig(
+ name='exponential',
+ initial_lr=0.01,
+ decay_rate=0.01,
+ decay_epochs=None,
+ warmup_epochs=None,
+ scale_by_batch_size=0.01,
+ examples_per_epoch=1,
+ boundaries=[0],
+ multipliers=[0, 1])
+ batch_size = 1
+ train_steps = 1
+
+ lr = optimizer_factory.build_learning_rate(
+ params=params,
+ batch_size=batch_size,
+ train_steps=train_steps)
+ self.assertTrue(
+ issubclass(
+ type(lr), tf.keras.optimizers.schedules.LearningRateSchedule))
+
+ @parameterized.named_parameters(
+ ('exponential', 'exponential'),
+ ('piecewise_constant_with_warmup', 'piecewise_constant_with_warmup'),
+ ('cosine_with_warmup', 'cosine_with_warmup'))
+ def test_learning_rate_with_decay_and_warmup(self, lr_decay_type):
+ """Basic smoke test for syntax."""
+ params = base_configs.LearningRateConfig(
+ name=lr_decay_type,
+ initial_lr=0.01,
+ decay_rate=0.01,
+ decay_epochs=1,
+ warmup_epochs=1,
+ scale_by_batch_size=0.01,
+ examples_per_epoch=1,
+ boundaries=[0],
+ multipliers=[0, 1])
+ batch_size = 1
+ train_epochs = 1
+ train_steps = 1
+
+ lr = optimizer_factory.build_learning_rate(
+ params=params,
+ batch_size=batch_size,
+ train_epochs=train_epochs,
+ train_steps=train_steps)
+ self.assertTrue(
+ issubclass(
+ type(lr), tf.keras.optimizers.schedules.LearningRateSchedule))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/official/vision/image_classification/preprocessing.py b/models/official/vision/image_classification/preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f2019189d4e5f9c269a67276531b4344ede7e32
--- /dev/null
+++ b/models/official/vision/image_classification/preprocessing.py
@@ -0,0 +1,391 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Preprocessing functions for images."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow as tf
+from typing import List, Optional, Text, Tuple
+
+from official.vision.image_classification import augment
+
+
+# Calculated from the ImageNet training set
+MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
+STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
+
+IMAGE_SIZE = 224
+CROP_PADDING = 32
+
+
+def mean_image_subtraction(
+ image_bytes: tf.Tensor,
+ means: Tuple[float, ...],
+ num_channels: int = 3,
+ dtype: tf.dtypes.DType = tf.float32,
+) -> tf.Tensor:
+ """Subtracts the given means from each image channel.
+
+ For example:
+ means = [123.68, 116.779, 103.939]
+ image_bytes = mean_image_subtraction(image_bytes, means)
+
+ Note that the rank of `image` must be known.
+
+ Args:
+ image_bytes: a tensor of size [height, width, C].
+ means: a C-vector of values to subtract from each channel.
+ num_channels: number of color channels in the image that will be distorted.
+ dtype: the dtype to convert the images to. Set to `None` to skip conversion.
+
+ Returns:
+ the centered image.
+
+ Raises:
+ ValueError: If the rank of `image` is unknown, if `image` has a rank other
+ than three or if the number of channels in `image` doesn't match the
+ number of values in `means`.
+ """
+ if image_bytes.get_shape().ndims != 3:
+ raise ValueError('Input must be of size [height, width, C>0]')
+
+ if len(means) != num_channels:
+ raise ValueError('len(means) must match the number of channels')
+
+ # We have a 1-D tensor of means; convert to 3-D.
+ # Note(b/130245863): we explicitly call `broadcast` instead of simply
+ # expanding dimensions for better performance.
+ means = tf.broadcast_to(means, tf.shape(image_bytes))
+ if dtype is not None:
+ means = tf.cast(means, dtype=dtype)
+
+ return image_bytes - means
+
+
+def standardize_image(
+ image_bytes: tf.Tensor,
+ stddev: Tuple[float, ...],
+ num_channels: int = 3,
+ dtype: tf.dtypes.DType = tf.float32,
+) -> tf.Tensor:
+ """Divides the given stddev from each image channel.
+
+ For example:
+ stddev = [123.68, 116.779, 103.939]
+ image_bytes = standardize_image(image_bytes, stddev)
+
+ Note that the rank of `image` must be known.
+
+ Args:
+ image_bytes: a tensor of size [height, width, C].
+ stddev: a C-vector of values to divide from each channel.
+ num_channels: number of color channels in the image that will be distorted.
+ dtype: the dtype to convert the images to. Set to `None` to skip conversion.
+
+ Returns:
+ the centered image.
+
+ Raises:
+ ValueError: If the rank of `image` is unknown, if `image` has a rank other
+ than three or if the number of channels in `image` doesn't match the
+ number of values in `stddev`.
+ """
+ if image_bytes.get_shape().ndims != 3:
+ raise ValueError('Input must be of size [height, width, C>0]')
+
+ if len(stddev) != num_channels:
+ raise ValueError('len(stddev) must match the number of channels')
+
+ # We have a 1-D tensor of stddev; convert to 3-D.
+ # Note(b/130245863): we explicitly call `broadcast` instead of simply
+ # expanding dimensions for better performance.
+ stddev = tf.broadcast_to(stddev, tf.shape(image_bytes))
+ if dtype is not None:
+ stddev = tf.cast(stddev, dtype=dtype)
+
+ return image_bytes / stddev
+
+
+def normalize_images(features: tf.Tensor,
+ mean_rgb: Tuple[float, ...] = MEAN_RGB,
+ stddev_rgb: Tuple[float, ...] = STDDEV_RGB,
+ num_channels: int = 3,
+ dtype: tf.dtypes.DType = tf.float32,
+ data_format: Text = 'channels_last') -> tf.Tensor:
+ """Normalizes the input image channels with the given mean and stddev.
+
+ Args:
+ features: `Tensor` representing decoded images in float format.
+ mean_rgb: the mean of the channels to subtract.
+ stddev_rgb: the stddev of the channels to divide.
+ num_channels: the number of channels in the input image tensor.
+ dtype: the dtype to convert the images to. Set to `None` to skip conversion.
+ data_format: the format of the input image tensor
+ ['channels_first', 'channels_last'].
+
+ Returns:
+ A normalized image `Tensor`.
+ """
+ # TODO(allencwang) - figure out how to use mean_image_subtraction and
+ # standardize_image on batches of images and replace the following.
+ if data_format == 'channels_first':
+ stats_shape = [num_channels, 1, 1]
+ else:
+ stats_shape = [1, 1, num_channels]
+
+ if dtype is not None:
+ features = tf.image.convert_image_dtype(features, dtype=dtype)
+
+ if mean_rgb is not None:
+ mean_rgb = tf.constant(mean_rgb,
+ shape=stats_shape,
+ dtype=features.dtype)
+ mean_rgb = tf.broadcast_to(mean_rgb, tf.shape(features))
+ features = features - mean_rgb
+
+ if stddev_rgb is not None:
+ stddev_rgb = tf.constant(stddev_rgb,
+ shape=stats_shape,
+ dtype=features.dtype)
+ stddev_rgb = tf.broadcast_to(stddev_rgb, tf.shape(features))
+ features = features / stddev_rgb
+
+ return features
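+
+
+# Worked example (illustrative only, assuming a float32 image already scaled
+# to the [0, 255] range): a red-channel value of 128 becomes
+# (128 - 0.485 * 255) / (0.229 * 255) = (128 - 123.675) / 58.395 ~= 0.074.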
+
+
+def decode_and_center_crop(image_bytes: tf.Tensor,
+ image_size: int = IMAGE_SIZE,
+ crop_padding: int = CROP_PADDING) -> tf.Tensor:
+  """Crops to the center of the image with padding, then resizes to image_size.
+
+ Args:
+ image_bytes: `Tensor` representing an image binary of arbitrary size.
+ image_size: image height/width dimension.
+ crop_padding: the padding size to use when centering the crop.
+
+ Returns:
+ A decoded and cropped image `Tensor`.
+ """
+ decoded = image_bytes.dtype != tf.string
+ shape = (tf.shape(image_bytes) if decoded
+ else tf.image.extract_jpeg_shape(image_bytes))
+ image_height = shape[0]
+ image_width = shape[1]
+
+ padded_center_crop_size = tf.cast(
+ ((image_size / (image_size + crop_padding)) *
+ tf.cast(tf.minimum(image_height, image_width), tf.float32)),
+ tf.int32)
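+  # For example, assuming image_size=224 and crop_padding=32 (typical
+  # defaults), the crop covers 224 / 256 = 87.5% of the shorter side, so a
+  # 500x600 input yields a 437x437 center crop, which is resized below.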
+
+ offset_height = ((image_height - padded_center_crop_size) + 1) // 2
+ offset_width = ((image_width - padded_center_crop_size) + 1) // 2
+ crop_window = tf.stack([offset_height, offset_width,
+ padded_center_crop_size, padded_center_crop_size])
+ if decoded:
+ image = tf.image.crop_to_bounding_box(
+ image_bytes,
+ offset_height=offset_height,
+ offset_width=offset_width,
+ target_height=padded_center_crop_size,
+ target_width=padded_center_crop_size)
+ else:
+ image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
+
+ image = resize_image(image_bytes=image,
+ height=image_size,
+ width=image_size)
+
+ return image
+
+
+def decode_crop_and_flip(image_bytes: tf.Tensor) -> tf.Tensor:
+ """Crops an image to a random part of the image, then randomly flips.
+
+ Args:
+ image_bytes: `Tensor` representing an image binary of arbitrary size.
+
+ Returns:
+ A decoded and cropped image `Tensor`.
+
+ """
+ decoded = image_bytes.dtype != tf.string
+ bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
+ shape = (tf.shape(image_bytes) if decoded
+ else tf.image.extract_jpeg_shape(image_bytes))
+ sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
+ shape,
+ bounding_boxes=bbox,
+ min_object_covered=0.1,
+ aspect_ratio_range=[0.75, 1.33],
+ area_range=[0.05, 1.0],
+ max_attempts=100,
+ use_image_if_no_bounding_boxes=True)
+ bbox_begin, bbox_size, _ = sample_distorted_bounding_box
+
+ # Reassemble the bounding box in the format the crop op requires.
+ offset_height, offset_width, _ = tf.unstack(bbox_begin)
+ target_height, target_width, _ = tf.unstack(bbox_size)
+ crop_window = tf.stack([offset_height, offset_width,
+ target_height, target_width])
+ if decoded:
+ cropped = tf.image.crop_to_bounding_box(
+ image_bytes,
+ offset_height=offset_height,
+ offset_width=offset_width,
+ target_height=target_height,
+ target_width=target_width)
+ else:
+ cropped = tf.image.decode_and_crop_jpeg(image_bytes,
+ crop_window,
+ channels=3)
+
+ # Flip to add a little more random distortion in.
+ cropped = tf.image.random_flip_left_right(cropped)
+ return cropped
+
+
+def resize_image(image_bytes: tf.Tensor,
+ height: int = IMAGE_SIZE,
+ width: int = IMAGE_SIZE) -> tf.Tensor:
+ """Resizes an image to a given height and width.
+
+ Args:
+ image_bytes: `Tensor` representing an image binary of arbitrary size.
+ height: image height dimension.
+ width: image width dimension.
+
+ Returns:
+ A tensor containing the resized image.
+
+ """
+ return tf.compat.v1.image.resize(
+ image_bytes, [height, width], method=tf.image.ResizeMethod.BILINEAR,
+ align_corners=False)
+
+
+def preprocess_for_eval(
+ image_bytes: tf.Tensor,
+ image_size: int = IMAGE_SIZE,
+ num_channels: int = 3,
+ mean_subtract: bool = False,
+ standardize: bool = False,
+ dtype: tf.dtypes.DType = tf.float32
+) -> tf.Tensor:
+ """Preprocesses the given image for evaluation.
+
+ Args:
+ image_bytes: `Tensor` representing an image binary of arbitrary size.
+ image_size: image height/width dimension.
+ num_channels: number of image input channels.
+ mean_subtract: whether or not to apply mean subtraction.
+ standardize: whether or not to apply standardization.
+ dtype: the dtype to convert the images to. Set to `None` to skip conversion.
+
+ Returns:
+ A preprocessed and normalized image `Tensor`.
+ """
+ images = decode_and_center_crop(image_bytes, image_size)
+ images = tf.reshape(images, [image_size, image_size, num_channels])
+
+ if mean_subtract:
+ images = mean_image_subtraction(image_bytes=images, means=MEAN_RGB)
+ if standardize:
+ images = standardize_image(image_bytes=images, stddev=STDDEV_RGB)
+ if dtype is not None:
+ images = tf.image.convert_image_dtype(images, dtype=dtype)
+
+ return images
+
+
+def load_eval_image(filename: Text, image_size: int = IMAGE_SIZE) -> tf.Tensor:
+ """Reads an image from the filesystem and applies image preprocessing.
+
+ Args:
+ filename: a filename path of an image.
+ image_size: image height/width dimension.
+
+ Returns:
+ A preprocessed and normalized image `Tensor`.
+ """
+ image_bytes = tf.io.read_file(filename)
+ image = preprocess_for_eval(image_bytes, image_size)
+
+ return image
+
+
+def build_eval_dataset(filenames: List[Text],
+ labels: List[int] = None,
+ image_size: int = IMAGE_SIZE,
+                       batch_size: int = 1) -> tf.data.Dataset:
+ """Builds a tf.data.Dataset from a list of filenames and labels.
+
+ Args:
+ filenames: a list of filename paths of images.
+ labels: a list of labels corresponding to each image.
+ image_size: image height/width dimension.
+ batch_size: the batch size used by the dataset
+
+ Returns:
+    A `tf.data.Dataset` yielding (image, label) pairs ready for iteration.
+ """
+ if labels is None:
+ labels = [0] * len(filenames)
+
+ filenames = tf.constant(filenames)
+ labels = tf.constant(labels)
+ dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
+
+ dataset = dataset.map(
+ lambda filename, label: (load_eval_image(filename, image_size), label))
+ dataset = dataset.batch(batch_size)
+
+ return dataset
+
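+# A minimal usage sketch for build_eval_dataset (the file paths and `model`
+# below are hypothetical placeholders):
+#   dataset = build_eval_dataset(['/data/img_0.jpg', '/data/img_1.jpg'],
+#                                labels=[3, 7], batch_size=2)
+#   for images, labels in dataset:
+#     predictions = model(images)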
+
+def preprocess_for_train(image_bytes: tf.Tensor,
+ image_size: int = IMAGE_SIZE,
+ augmenter: Optional[augment.ImageAugment] = None,
+ mean_subtract: bool = False,
+ standardize: bool = False,
+ dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
+ """Preprocesses the given image for training.
+
+ Args:
+ image_bytes: `Tensor` representing an image binary of
+ arbitrary size of dtype tf.uint8.
+ image_size: image height/width dimension.
+ augmenter: the image augmenter to apply.
+ mean_subtract: whether or not to apply mean subtraction.
+ standardize: whether or not to apply standardization.
+ dtype: the dtype to convert the images to. Set to `None` to skip conversion.
+
+ Returns:
+ A preprocessed and normalized image `Tensor`.
+ """
+ images = decode_crop_and_flip(image_bytes=image_bytes)
+ images = resize_image(images, height=image_size, width=image_size)
+ if mean_subtract:
+ images = mean_image_subtraction(image_bytes=images, means=MEAN_RGB)
+ if standardize:
+ images = standardize_image(image_bytes=images, stddev=STDDEV_RGB)
+ if augmenter is not None:
+ images = augmenter.distort(images)
+ if dtype is not None:
+ images = tf.image.convert_image_dtype(images, dtype)
+
+ return images
diff --git a/models/official/vision/image_classification/resnet/README.md b/models/official/vision/image_classification/resnet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5064523fbdcd4222c2159bdc1c09b7156800bf54
--- /dev/null
+++ b/models/official/vision/image_classification/resnet/README.md
@@ -0,0 +1,125 @@
+This folder contains a
+[custom training loop (CTL)](#resnet-custom-training-loop) implementation for
+ResNet50.
+
+## Before you begin
+Please refer to the [README](../README.md) in the parent directory for
+information on setup and preparing the data.
+
+## ResNet (custom training loop)
+
+Similar to the [estimator implementation](../../../r1/resnet), the Keras
+implementation has code for the ImageNet dataset. The ImageNet
+version uses a ResNet50 model implemented in
+[`resnet_model.py`](./resnet_model.py).
+
+
+### Pretrained Models
+
+* [ResNet50 Checkpoints](https://storage.googleapis.com/cloud-tpu-checkpoints/resnet/resnet50.tar.gz)
+
+* ResNet50 TFHub: [feature vector](https://tfhub.dev/tensorflow/resnet_50/feature_vector/1)
+and [classification](https://tfhub.dev/tensorflow/resnet_50/classification/1)
+
+Again, if you did not download the data to the default directory, specify the
+location with the `--data_dir` flag:
+
+```bash
+python3 resnet_ctl_imagenet_main.py --data_dir=/path/to/imagenet
+```
+
+There are more flag options you can specify. Here are some examples:
+
+- `--use_synthetic_data`: when set to true, synthetic data, rather than real
+data, are used;
+- `--batch_size`: the batch size used for the model;
+- `--model_dir`: the directory to save the model checkpoint;
+- `--train_epochs`: number of epochs to run for training the model;
+- `--train_steps`: number of steps to run for training the model. Currently only
+a number smaller than the number of batches in an epoch is supported;
+- `--skip_eval`: when set to true, evaluation, as well as validation during
+training, is skipped.
+
+For example, this is a typical command line to run with ImageNet data with
+batch size 128 per GPU:
+
+```bash
+python3 resnet_ctl_imagenet_main.py \
+ --model_dir=/tmp/model_dir/something \
+ --num_gpus=2 \
+ --batch_size=128 \
+ --train_epochs=90 \
+ --train_steps=10 \
+ --use_synthetic_data=false
+```
+
+See [`common.py`](common.py) for full list of options.
+
+### Using multiple GPUs
+
+You can train these models on multiple GPUs using `tf.distribute.Strategy` API.
+You can read more about them in this
+[guide](https://www.tensorflow.org/guide/distribute_strategy).
+
+In this example, we have made it easier to use with just the command line flag
+`--num_gpus`. By default this flag is 1 if TensorFlow was compiled with CUDA,
+and 0 otherwise.
+
+- `--num_gpus=0`: Uses `tf.distribute.OneDeviceStrategy` with CPU as the device.
+- `--num_gpus=1`: Uses `tf.distribute.OneDeviceStrategy` with GPU as the device.
+- `--num_gpus=2+`: Uses `tf.distribute.MirroredStrategy` to run synchronous
+distributed training across the GPUs.
+
+If you wish to run without `tf.distribute.Strategy`, you can do so by setting
+`--distribution_strategy=off`.
+
+### Running on multiple GPU hosts
+
+You can also train these models on multiple hosts, each with GPUs, using
+`tf.distribute.Strategy`.
+
+The easiest way to run multi-host benchmarks is to set the
+[`TF_CONFIG`](https://www.tensorflow.org/guide/distributed_training#TF_CONFIG)
+appropriately at each host. e.g., to run using `MultiWorkerMirroredStrategy` on
+2 hosts, the `cluster` in `TF_CONFIG` should have 2 `host:port` entries, and
+host `i` should have the `task` in `TF_CONFIG` set to `{"type": "worker",
+"index": i}`. `MultiWorkerMirroredStrategy` will automatically use all the
+available GPUs at each host.
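+
+As a minimal sketch, assuming two hosts named `host1` and `host2` that each
+listen on port 2222, host 0 would export:
+
+```bash
+export TF_CONFIG='{
+  "cluster": {"worker": ["host1:2222", "host2:2222"]},
+  "task": {"type": "worker", "index": 0}
+}'
+```
+
+Host 1 would export the same `cluster` but with `"index": 1` in the `task`
+entry.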
+
+### Running on Cloud TPUs
+
+Note: This model will **not** work with TPUs on Colab.
+
+You can train the ResNet CTL model on Cloud TPUs using
+`tf.distribute.TPUStrategy`. If you are not familiar with Cloud TPUs, it is
+strongly recommended that you go through the
+[quickstart](https://cloud.google.com/tpu/docs/quickstart) to learn how to
+create a TPU and GCE VM.
+
+To run ResNet model on a TPU, you must set `--distribution_strategy=tpu` and
+`--tpu=$TPU_NAME`, where `$TPU_NAME` is the name of your TPU in the Cloud Console.
+From a GCE VM, you can run the following command to train ResNet for one epoch
+on a v2-8 or v3-8 TPU by setting `TRAIN_EPOCHS` to 1:
+
+```bash
+python3 resnet_ctl_imagenet_main.py \
+ --tpu=$TPU_NAME \
+ --model_dir=$MODEL_DIR \
+ --data_dir=$DATA_DIR \
+ --batch_size=1024 \
+ --steps_per_loop=500 \
+ --train_epochs=$TRAIN_EPOCHS \
+ --use_synthetic_data=false \
+ --dtype=fp32 \
+ --enable_eager=true \
+ --enable_tensorboard=true \
+ --distribution_strategy=tpu \
+ --log_steps=50 \
+ --single_l2_loss_op=true \
+ --use_tf_function=true
+```
+
+To train the ResNet to convergence, run it for 90 epochs by setting
+`TRAIN_EPOCHS` to 90.
+
+Note: `$MODEL_DIR` and `$DATA_DIR` must be GCS paths.
diff --git a/models/official/vision/image_classification/resnet/__init__.py b/models/official/vision/image_classification/resnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/official/vision/image_classification/resnet/common.py b/models/official/vision/image_classification/resnet/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9a64aa4064978863332a8024f4e46d64b9baaef
--- /dev/null
+++ b/models/official/vision/image_classification/resnet/common.py
@@ -0,0 +1,387 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Common util functions and classes used by both keras cifar and imagenet."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import flags
+import tensorflow as tf
+
+from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
+import tensorflow_model_optimization as tfmot
+from official.utils.flags import core as flags_core
+from official.utils.misc import keras_utils
+
+FLAGS = flags.FLAGS
+BASE_LEARNING_RATE = 0.1 # This matches Jing's version.
+TRAIN_TOP_1 = 'training_accuracy_top_1'
+LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
+ (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
+]
+
+
+class PiecewiseConstantDecayWithWarmup(
+ tf.keras.optimizers.schedules.LearningRateSchedule):
+ """Piecewise constant decay with warmup schedule."""
+
+ def __init__(self, batch_size, epoch_size, warmup_epochs, boundaries,
+ multipliers, compute_lr_on_cpu=True, name=None):
+ super(PiecewiseConstantDecayWithWarmup, self).__init__()
+ if len(boundaries) != len(multipliers) - 1:
+ raise ValueError('The length of boundaries must be 1 less than the '
+ 'length of multipliers')
+
+ base_lr_batch_size = 256
+ steps_per_epoch = epoch_size // batch_size
+
+ self.rescaled_lr = BASE_LEARNING_RATE * batch_size / base_lr_batch_size
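+    # For example, with BASE_LEARNING_RATE = 0.1 and batch_size = 1024, the
+    # rescaled learning rate is 0.1 * 1024 / 256 = 0.4.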
+ self.step_boundaries = [float(steps_per_epoch) * x for x in boundaries]
+ self.lr_values = [self.rescaled_lr * m for m in multipliers]
+ self.warmup_steps = warmup_epochs * steps_per_epoch
+ self.compute_lr_on_cpu = compute_lr_on_cpu
+ self.name = name
+
+ self.learning_rate_ops_cache = {}
+
+ def __call__(self, step):
+ if tf.executing_eagerly():
+ return self._get_learning_rate(step)
+
+    # In a tf.function or graph, the current implementation of the optimizer
+    # repeatedly calls the schedule and thus creates new ops for the learning
+    # rate. To avoid this, we cache the ops when not executing eagerly.
+ graph = tf.compat.v1.get_default_graph()
+ if graph not in self.learning_rate_ops_cache:
+ if self.compute_lr_on_cpu:
+ with tf.device('/device:CPU:0'):
+ self.learning_rate_ops_cache[graph] = self._get_learning_rate(step)
+ else:
+ self.learning_rate_ops_cache[graph] = self._get_learning_rate(step)
+ return self.learning_rate_ops_cache[graph]
+
+ def _get_learning_rate(self, step):
+ """Compute learning rate at given step."""
+ with tf.name_scope('PiecewiseConstantDecayWithWarmup'):
+ def warmup_lr(step):
+ return self.rescaled_lr * (
+ tf.cast(step, tf.float32) / tf.cast(self.warmup_steps, tf.float32))
+ def piecewise_lr(step):
+ return tf.compat.v1.train.piecewise_constant(
+ step, self.step_boundaries, self.lr_values)
+ return tf.cond(step < self.warmup_steps,
+ lambda: warmup_lr(step),
+ lambda: piecewise_lr(step))
+
+ def get_config(self):
+ return {
+ 'rescaled_lr': self.rescaled_lr,
+ 'step_boundaries': self.step_boundaries,
+ 'lr_values': self.lr_values,
+ 'warmup_steps': self.warmup_steps,
+ 'compute_lr_on_cpu': self.compute_lr_on_cpu,
+ 'name': self.name
+ }
+
+
+def get_optimizer(learning_rate=0.1):
+ """Returns optimizer to use."""
+ # The learning_rate is overwritten at the beginning of each step by callback.
+ return gradient_descent_v2.SGD(learning_rate=learning_rate, momentum=0.9)
+
+
+def get_callbacks(
+ pruning_method=None,
+ enable_checkpoint_and_export=False,
+ model_dir=None):
+ """Returns common callbacks."""
+ time_callback = keras_utils.TimeHistory(
+ FLAGS.batch_size,
+ FLAGS.log_steps,
+ logdir=FLAGS.model_dir if FLAGS.enable_tensorboard else None)
+ callbacks = [time_callback]
+
+ if FLAGS.enable_tensorboard:
+ tensorboard_callback = tf.keras.callbacks.TensorBoard(
+ log_dir=FLAGS.model_dir,
+ profile_batch=FLAGS.profile_steps)
+ callbacks.append(tensorboard_callback)
+
+ is_pruning_enabled = pruning_method is not None
+ if is_pruning_enabled:
+ callbacks.append(tfmot.sparsity.keras.UpdatePruningStep())
+ if model_dir is not None:
+ callbacks.append(tfmot.sparsity.keras.PruningSummaries(
+ log_dir=model_dir, profile_batch=0))
+
+ if enable_checkpoint_and_export:
+ if model_dir is not None:
+ ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
+ callbacks.append(
+ tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
+ save_weights_only=True))
+ return callbacks
+
+
+def build_stats(history, eval_output, callbacks):
+ """Normalizes and returns dictionary of stats.
+
+ Args:
+ history: Results of the training step. Supports both categorical_accuracy
+ and sparse_categorical_accuracy.
+ eval_output: Output of the eval step. Assumes first value is eval_loss and
+ second value is accuracy_top_1.
+ callbacks: a list of callbacks which might include a time history callback
+ used during keras.fit.
+
+ Returns:
+ Dictionary of normalized results.
+ """
+ stats = {}
+ if eval_output:
+ stats['accuracy_top_1'] = float(eval_output[1])
+ stats['eval_loss'] = float(eval_output[0])
+ if history and history.history:
+ train_hist = history.history
+ # Gets final loss from training.
+ stats['loss'] = float(train_hist['loss'][-1])
+ # Gets top_1 training accuracy.
+ if 'categorical_accuracy' in train_hist:
+ stats[TRAIN_TOP_1] = float(train_hist['categorical_accuracy'][-1])
+ elif 'sparse_categorical_accuracy' in train_hist:
+ stats[TRAIN_TOP_1] = float(train_hist['sparse_categorical_accuracy'][-1])
+ elif 'accuracy' in train_hist:
+ stats[TRAIN_TOP_1] = float(train_hist['accuracy'][-1])
+
+ if not callbacks:
+ return stats
+
+ # Look for the time history callback which was used during keras.fit
+ for callback in callbacks:
+ if isinstance(callback, keras_utils.TimeHistory):
+ timestamp_log = callback.timestamp_log
+ stats['step_timestamp_log'] = timestamp_log
+ stats['train_finish_time'] = callback.train_finish_time
+ if callback.epoch_runtime_log:
+ stats['avg_exp_per_second'] = callback.average_examples_per_second
+
+ return stats
+
+
+def define_keras_flags(
+ dynamic_loss_scale=True,
+ model=False,
+ optimizer=False,
+ pretrained_filepath=False):
+ """Define flags for Keras models."""
+ flags_core.define_base(clean=True, num_gpu=True, run_eagerly=True,
+ train_epochs=True, epochs_between_evals=True,
+ distribution_strategy=True)
+ flags_core.define_performance(num_parallel_calls=False,
+ synthetic_data=True,
+ dtype=True,
+ all_reduce_alg=True,
+ num_packs=True,
+ tf_gpu_thread_mode=True,
+ datasets_num_private_threads=True,
+ dynamic_loss_scale=dynamic_loss_scale,
+ loss_scale=True,
+ fp16_implementation=True,
+ tf_data_experimental_slack=True,
+ enable_xla=True,
+ training_dataset_cache=True)
+ flags_core.define_image()
+ flags_core.define_benchmark()
+ flags_core.define_distribution()
+ flags.adopt_module_key_flags(flags_core)
+
+ flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
+ flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?')
+ # TODO(b/135607288): Remove this flag once we understand the root cause of
+ # slowdown when setting the learning phase in Keras backend.
+ flags.DEFINE_boolean(
+ name='set_learning_phase_to_train', default=True,
+ help='If skip eval, also set Keras learning phase to 1 (training).')
+ flags.DEFINE_boolean(
+ name='explicit_gpu_placement', default=False,
+ help='If not using distribution strategy, explicitly set device scope '
+ 'for the Keras training loop.')
+ flags.DEFINE_boolean(name='use_trivial_model', default=False,
+ help='Whether to use a trivial Keras model.')
+ flags.DEFINE_boolean(name='report_accuracy_metrics', default=True,
+ help='Report metrics during training and evaluation.')
+ flags.DEFINE_boolean(name='use_tensor_lr', default=True,
+ help='Use learning rate tensor instead of a callback.')
+ flags.DEFINE_boolean(
+ name='enable_tensorboard', default=False,
+ help='Whether to enable Tensorboard callback.')
+ flags.DEFINE_string(
+ name='profile_steps', default=None,
+ help='Save profiling data to model dir at given range of global steps. The '
+ 'value must be a comma separated pair of positive integers, specifying '
+ 'the first and last step to profile. For example, "--profile_steps=2,4" '
+ 'triggers the profiler to process 3 steps, starting from the 2nd step. '
+ 'Note that profiler has a non-trivial performance overhead, and the '
+ 'output file can be gigantic if profiling many steps.')
+ flags.DEFINE_integer(
+ name='train_steps', default=None,
+ help='The number of steps to run for training. If it is larger than '
+ '# batches per epoch, then use # batches per epoch. This flag will be '
+ 'ignored if train_epochs is set to be larger than 1. ')
+ flags.DEFINE_boolean(
+ name='batchnorm_spatial_persistent', default=True,
+      help='Enable the spatial persistent mode for CuDNN batch norm kernel.')
+ flags.DEFINE_boolean(
+ name='enable_get_next_as_optional', default=False,
+ help='Enable get_next_as_optional behavior in DistributedIterator.')
+ flags.DEFINE_boolean(
+ name='enable_checkpoint_and_export', default=False,
+ help='Whether to enable a checkpoint callback and export the savedmodel.')
+ flags.DEFINE_string(
+ name='tpu', default='', help='TPU address to connect to.')
+ flags.DEFINE_integer(
+ name='steps_per_loop',
+ default=500,
+ help='Number of steps per training loop. Only training step happens '
+ 'inside the loop. Callbacks will not be called inside. Will be capped at '
+ 'steps per epoch.')
+ flags.DEFINE_boolean(
+ name='use_tf_while_loop',
+ default=True,
+ help='Whether to build a tf.while_loop inside the training loop on the '
+ 'host. Setting it to True is critical to have peak performance on '
+ 'TPU.')
+
+ if model:
+ flags.DEFINE_string('model', 'resnet50_v1.5',
+ 'Name of model preset. (mobilenet, resnet50_v1.5)')
+ if optimizer:
+ flags.DEFINE_string('optimizer', 'resnet50_default',
+ 'Name of optimizer preset. '
+ '(mobilenet_default, resnet50_default)')
+ # TODO(kimjaehong): Replace as general hyper-params not only for mobilenet.
+ flags.DEFINE_float('initial_learning_rate_per_sample', 0.00007,
+ 'Initial value of learning rate per sample for '
+ 'mobilenet_default.')
+ flags.DEFINE_float('lr_decay_factor', 0.94,
+ 'Learning rate decay factor for mobilenet_default.')
+ flags.DEFINE_float('num_epochs_per_decay', 2.5,
+ 'Number of epochs per decay for mobilenet_default.')
+ if pretrained_filepath:
+ flags.DEFINE_string('pretrained_filepath', '',
+ 'Pretrained file path.')
+
+
+def get_synth_data(height, width, num_channels, num_classes, dtype):
+ """Creates a set of synthetic random data.
+
+ Args:
+ height: Integer height that will be used to create a fake image tensor.
+ width: Integer width that will be used to create a fake image tensor.
+ num_channels: Integer depth that will be used to create a fake image tensor.
+ num_classes: Number of classes that should be represented in the fake labels
+ tensor
+ dtype: Data type for features/images.
+
+ Returns:
+ A tuple of tensors representing the inputs and labels.
+
+ """
+ # Synthetic input should be within [0, 255].
+ inputs = tf.random.truncated_normal([height, width, num_channels],
+ dtype=dtype,
+ mean=127,
+ stddev=60,
+ name='synthetic_inputs')
+ labels = tf.random.uniform([1],
+ minval=0,
+ maxval=num_classes - 1,
+ dtype=tf.int32,
+ name='synthetic_labels')
+ return inputs, labels
+
+
+def define_pruning_flags():
+ """Define flags for pruning methods."""
+ flags.DEFINE_string('pruning_method', None,
+ 'Pruning method.'
+ 'None (no pruning) or polynomial_decay.')
+ flags.DEFINE_float('pruning_initial_sparsity', 0.0,
+ 'Initial sparsity for pruning.')
+ flags.DEFINE_float('pruning_final_sparsity', 0.5,
+ 'Final sparsity for pruning.')
+ flags.DEFINE_integer('pruning_begin_step', 0,
+ 'Begin step for pruning.')
+ flags.DEFINE_integer('pruning_end_step', 100000,
+ 'End step for pruning.')
+ flags.DEFINE_integer('pruning_frequency', 100,
+ 'Frequency for pruning.')
+
+
+def get_synth_input_fn(height, width, num_channels, num_classes,
+ dtype=tf.float32, drop_remainder=True):
+ """Returns an input function that returns a dataset with random data.
+
+ This input_fn returns a data set that iterates over a set of random data and
+ bypasses all preprocessing, e.g. jpeg decode and copy. The host to device
+copy is still included. This is used to find the upper throughput bound when
+ tuning the full input pipeline.
+
+ Args:
+ height: Integer height that will be used to create a fake image tensor.
+ width: Integer width that will be used to create a fake image tensor.
+ num_channels: Integer depth that will be used to create a fake image tensor.
+ num_classes: Number of classes that should be represented in the fake labels
+ tensor
+ dtype: Data type for features/images.
+ drop_remainder: A boolean indicates whether to drop the remainder of the
+ batches. If True, the batch dimension will be static.
+
+ Returns:
+ An input_fn that can be used in place of a real one to return a dataset
+ that can be used for iteration.
+ """
+ # pylint: disable=unused-argument
+ def input_fn(is_training, data_dir, batch_size, *args, **kwargs):
+ """Returns dataset filled with random data."""
+ inputs, labels = get_synth_data(height=height,
+ width=width,
+ num_channels=num_channels,
+ num_classes=num_classes,
+ dtype=dtype)
+ # Cast to float32 for Keras model.
+ labels = tf.cast(labels, dtype=tf.float32)
+ data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
+
+ # `drop_remainder` will make dataset produce outputs with known shapes.
+ data = data.batch(batch_size, drop_remainder=drop_remainder)
+ data = data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
+ return data
+
+ return input_fn
+
+
+def set_cudnn_batchnorm_mode():
+ """Set CuDNN batchnorm mode for better performance.
+
+ Note: Spatial Persistent mode may lead to accuracy losses for certain
+ models.
+ """
+ if FLAGS.batchnorm_spatial_persistent:
+ os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
+ else:
+ os.environ.pop('TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT', None)
diff --git a/models/official/vision/image_classification/resnet/imagenet_preprocessing.py b/models/official/vision/image_classification/resnet/imagenet_preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1490c22d8d769f32a6f6a1c6d29455519e8743a
--- /dev/null
+++ b/models/official/vision/image_classification/resnet/imagenet_preprocessing.py
@@ -0,0 +1,561 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Provides utilities to preprocess images.
+
+Training images are sampled using the provided bounding boxes, and subsequently
+cropped to the sampled bounding box. Images are additionally flipped randomly,
+then resized to the target output size (without aspect-ratio preservation).
+
+Images used during evaluation are resized (with aspect-ratio preservation) and
+centrally cropped.
+
+All images undergo mean color subtraction.
+
+Note that these steps are colloquially referred to as "ResNet preprocessing,"
+and they differ from "VGG preprocessing," which does not use bounding boxes
+and instead does an aspect-preserving resize followed by random crop during
+training. (These both differ from "Inception preprocessing," which introduces
+color distortion steps.)
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+from absl import logging
+import tensorflow as tf
+
+DEFAULT_IMAGE_SIZE = 224
+NUM_CHANNELS = 3
+NUM_CLASSES = 1001
+
+NUM_IMAGES = {
+ 'train': 1281167,
+ 'validation': 50000,
+}
+
+_NUM_TRAIN_FILES = 1024
+_SHUFFLE_BUFFER = 10000
+
+_R_MEAN = 123.68
+_G_MEAN = 116.78
+_B_MEAN = 103.94
+CHANNEL_MEANS = [_R_MEAN, _G_MEAN, _B_MEAN]
+
+# The lower bound for the smallest side of the image for aspect-preserving
+# resizing. For example, if an image is 500 x 1000, it will be resized to
+# _RESIZE_MIN x (_RESIZE_MIN * 2).
+_RESIZE_MIN = 256
+
+
+def process_record_dataset(dataset,
+ is_training,
+ batch_size,
+ shuffle_buffer,
+ parse_record_fn,
+ dtype=tf.float32,
+ datasets_num_private_threads=None,
+ drop_remainder=False,
+ tf_data_experimental_slack=False):
+ """Given a Dataset with raw records, return an iterator over the records.
+
+ Args:
+ dataset: A Dataset representing raw records
+ is_training: A boolean denoting whether the input is for training.
+ batch_size: The number of samples per batch.
+ shuffle_buffer: The buffer size to use when shuffling records. A larger
+ value results in better randomness, but smaller values reduce startup
+ time and use less memory.
+ parse_record_fn: A function that takes a raw record and returns the
+ corresponding (image, label) pair.
+ dtype: Data type to use for images/features.
+ datasets_num_private_threads: Number of threads for a private
+ threadpool created for all datasets computation.
+ drop_remainder: A boolean indicates whether to drop the remainder of the
+ batches. If True, the batch dimension will be static.
+ tf_data_experimental_slack: Whether to enable tf.data's
+ `experimental_slack` option.
+
+ Returns:
+ Dataset of (image, label) pairs ready for iteration.
+ """
+ # Defines a specific size thread pool for tf.data operations.
+ if datasets_num_private_threads:
+ options = tf.data.Options()
+ options.experimental_threading.private_threadpool_size = (
+ datasets_num_private_threads)
+ dataset = dataset.with_options(options)
+ logging.info(
+ 'datasets_num_private_threads: %s', datasets_num_private_threads)
+
+ if is_training:
+ # Shuffles records before repeating to respect epoch boundaries.
+ dataset = dataset.shuffle(buffer_size=shuffle_buffer)
+ # Repeats the dataset for the number of epochs to train.
+ dataset = dataset.repeat()
+
+ # Parses the raw records into images and labels.
+ dataset = dataset.map(
+ lambda value: parse_record_fn(value, is_training, dtype),
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
+
+ # Operations between the final prefetch and the get_next call to the iterator
+ # will happen synchronously during run time. We prefetch here again to
+ # background all of the above processing work and keep it out of the
+ # critical training path. Setting buffer_size to tf.data.experimental.AUTOTUNE
+ # allows DistributionStrategies to adjust how many batches to fetch based
+ # on how many devices are present.
+ dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
+
+ options = tf.data.Options()
+ options.experimental_slack = tf_data_experimental_slack
+ dataset = dataset.with_options(options)
+
+ return dataset
+
+
+def get_filenames(is_training, data_dir):
+ """Return filenames for dataset."""
+ if is_training:
+ return [
+ os.path.join(data_dir, 'train-%05d-of-01024' % i)
+ for i in range(_NUM_TRAIN_FILES)]
+ else:
+ return [
+ os.path.join(data_dir, 'validation-%05d-of-00128' % i)
+ for i in range(128)]
+
+
+def parse_example_proto(example_serialized):
+ """Parses an Example proto containing a training example of an image.
+
+ The output of the build_image_data.py image preprocessing script is a dataset
+ containing serialized Example protocol buffers. Each Example proto contains
+ the following fields (values are included as examples):
+
+ image/height: 462
+ image/width: 581
+ image/colorspace: 'RGB'
+ image/channels: 3
+ image/class/label: 615
+ image/class/synset: 'n03623198'
+ image/class/text: 'knee pad'
+ image/object/bbox/xmin: 0.1
+ image/object/bbox/xmax: 0.9
+ image/object/bbox/ymin: 0.2
+ image/object/bbox/ymax: 0.6
+ image/object/bbox/label: 615
+ image/format: 'JPEG'
+ image/filename: 'ILSVRC2012_val_00041207.JPEG'
+ image/encoded:
+
+ Args:
+ example_serialized: scalar Tensor tf.string containing a serialized
+ Example protocol buffer.
+
+ Returns:
+ image_buffer: Tensor tf.string containing the contents of a JPEG file.
+ label: Tensor tf.int32 containing the label.
+ bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
+ where each coordinate is [0, 1) and the coordinates are arranged as
+ [ymin, xmin, ymax, xmax].
+ """
+ # Dense features in Example proto.
+ feature_map = {
+ 'image/encoded': tf.io.FixedLenFeature([], dtype=tf.string,
+ default_value=''),
+ 'image/class/label': tf.io.FixedLenFeature([], dtype=tf.int64,
+ default_value=-1),
+ 'image/class/text': tf.io.FixedLenFeature([], dtype=tf.string,
+ default_value=''),
+ }
+ sparse_float32 = tf.io.VarLenFeature(dtype=tf.float32)
+ # Sparse features in Example proto.
+ feature_map.update(
+ {k: sparse_float32 for k in [
+ 'image/object/bbox/xmin', 'image/object/bbox/ymin',
+ 'image/object/bbox/xmax', 'image/object/bbox/ymax']})
+
+ features = tf.io.parse_single_example(serialized=example_serialized,
+ features=feature_map)
+ label = tf.cast(features['image/class/label'], dtype=tf.int32)
+
+ xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0)
+ ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0)
+ xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0)
+ ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0)
+
+ # Note that we impose an ordering of (y, x) just to make life difficult.
+ bbox = tf.concat([ymin, xmin, ymax, xmax], 0)
+
+ # Force the variable number of bounding boxes into the shape
+ # [1, num_boxes, coords].
+ bbox = tf.expand_dims(bbox, 0)
+ bbox = tf.transpose(a=bbox, perm=[0, 2, 1])
+
+ return features['image/encoded'], label, bbox
+
+
+def parse_record(raw_record, is_training, dtype):
+ """Parses a record containing a training example of an image.
+
+ The input record is parsed into a label and image, and the image is passed
+ through preprocessing steps (cropping, flipping, and so on).
+
+ Args:
+ raw_record: scalar Tensor tf.string containing a serialized
+ Example protocol buffer.
+ is_training: A boolean denoting whether the input is for training.
+ dtype: data type to use for images/features.
+
+ Returns:
+ Tuple with processed image tensor in a channel-last format and
+ one-hot-encoded label tensor.
+ """
+ image_buffer, label, bbox = parse_example_proto(raw_record)
+
+ image = preprocess_image(
+ image_buffer=image_buffer,
+ bbox=bbox,
+ output_height=DEFAULT_IMAGE_SIZE,
+ output_width=DEFAULT_IMAGE_SIZE,
+ num_channels=NUM_CHANNELS,
+ is_training=is_training)
+ image = tf.cast(image, dtype)
+
+ # Subtract one so that labels are in [0, 1000), and cast to float32 for
+ # Keras model.
+ label = tf.cast(tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1,
+ dtype=tf.float32)
+ return image, label
+
+
+def get_parse_record_fn(use_keras_image_data_format=False):
+ """Get a function for parsing the records, accounting for image format.
+
+  This is useful for handling different types of Keras models. For instance,
+ the current resnet_model.resnet50 input format is always channel-last,
+ whereas the keras_applications mobilenet input format depends on
+ tf.keras.backend.image_data_format(). We should set
+ use_keras_image_data_format=False for the former and True for the latter.
+
+ Args:
+ use_keras_image_data_format: A boolean denoting whether data format is keras
+ backend image data format. If False, the image format is channel-last. If
+ True, the image format matches tf.keras.backend.image_data_format().
+
+ Returns:
+ Function to use for parsing the records.
+ """
+ def parse_record_fn(raw_record, is_training, dtype):
+ image, label = parse_record(raw_record, is_training, dtype)
+ if use_keras_image_data_format:
+ if tf.keras.backend.image_data_format() == 'channels_first':
+ image = tf.transpose(image, perm=[2, 0, 1])
+ return image, label
+ return parse_record_fn
+
+
+def input_fn(is_training,
+ data_dir,
+ batch_size,
+ dtype=tf.float32,
+ datasets_num_private_threads=None,
+ parse_record_fn=parse_record,
+ input_context=None,
+ drop_remainder=False,
+ tf_data_experimental_slack=False,
+ training_dataset_cache=False,
+ filenames=None):
+ """Input function which provides batches for train or eval.
+
+ Args:
+ is_training: A boolean denoting whether the input is for training.
+ data_dir: The directory containing the input data.
+ batch_size: The number of samples per batch.
+ dtype: Data type to use for images/features
+ datasets_num_private_threads: Number of private threads for tf.data.
+ parse_record_fn: Function to use for parsing the records.
+ input_context: A `tf.distribute.InputContext` object passed in by
+ `tf.distribute.Strategy`.
+ drop_remainder: A boolean indicates whether to drop the remainder of the
+ batches. If True, the batch dimension will be static.
+ tf_data_experimental_slack: Whether to enable tf.data's
+ `experimental_slack` option.
+ training_dataset_cache: Whether to cache the training dataset on workers.
+ Typically used to improve training performance when training data is in
+ remote storage and can fit into worker memory.
+ filenames: Optional field for providing the file names of the TFRecords.
+
+ Returns:
+ A dataset that can be used for iteration.
+ """
+ if filenames is None:
+ filenames = get_filenames(is_training, data_dir)
+ dataset = tf.data.Dataset.from_tensor_slices(filenames)
+
+ if input_context:
+ logging.info(
+ 'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d',
+ input_context.input_pipeline_id, input_context.num_input_pipelines)
+ dataset = dataset.shard(input_context.num_input_pipelines,
+ input_context.input_pipeline_id)
+
+ if is_training:
+ # Shuffle the input files
+ dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)
+
+ # Convert to individual records.
+ # cycle_length = 10 means that up to 10 files will be read and deserialized in
+ # parallel. You may want to increase this number if you have a large number of
+ # CPU cores.
+ dataset = dataset.interleave(
+ tf.data.TFRecordDataset,
+ cycle_length=10,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ if is_training and training_dataset_cache:
+ # Improve training performance when training data is in remote storage and
+ # can fit into worker memory.
+ dataset = dataset.cache()
+
+ return process_record_dataset(
+ dataset=dataset,
+ is_training=is_training,
+ batch_size=batch_size,
+ shuffle_buffer=_SHUFFLE_BUFFER,
+ parse_record_fn=parse_record_fn,
+ dtype=dtype,
+ datasets_num_private_threads=datasets_num_private_threads,
+ drop_remainder=drop_remainder,
+ tf_data_experimental_slack=tf_data_experimental_slack,
+ )
+
+
+def _decode_crop_and_flip(image_buffer, bbox, num_channels):
+ """Crops the given image to a random part of the image, and randomly flips.
+
+ We use the fused decode_and_crop op, which performs better than the two ops
+ used separately in series, but note that this requires that the image be
+ passed in as an un-decoded string Tensor.
+
+ Args:
+ image_buffer: scalar string Tensor representing the raw JPEG image buffer.
+ bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
+ where each coordinate is [0, 1) and the coordinates are arranged as
+ [ymin, xmin, ymax, xmax].
+ num_channels: Integer depth of the image buffer for decoding.
+
+ Returns:
+ 3-D tensor with cropped image.
+
+ """
+ # A large fraction of image datasets contain a human-annotated bounding box
+ # delineating the region of the image containing the object of interest. We
+ # choose to create a new bounding box for the object which is a randomly
+ # distorted version of the human-annotated bounding box that obeys an
+ # allowed range of aspect ratios, sizes and overlap with the human-annotated
+ # bounding box. If no box is supplied, then we assume the bounding box is
+ # the entire image.
+ sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
+ tf.image.extract_jpeg_shape(image_buffer),
+ bounding_boxes=bbox,
+ min_object_covered=0.1,
+ aspect_ratio_range=[0.75, 1.33],
+ area_range=[0.05, 1.0],
+ max_attempts=100,
+ use_image_if_no_bounding_boxes=True)
+ bbox_begin, bbox_size, _ = sample_distorted_bounding_box
+
+ # Reassemble the bounding box in the format the crop op requires.
+ offset_y, offset_x, _ = tf.unstack(bbox_begin)
+ target_height, target_width, _ = tf.unstack(bbox_size)
+ crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
+
+ # Use the fused decode and crop op here, which is faster than each in series.
+ cropped = tf.image.decode_and_crop_jpeg(
+ image_buffer, crop_window, channels=num_channels)
+
+ # Flip to add a little more random distortion in.
+ cropped = tf.image.random_flip_left_right(cropped)
+ return cropped
+
+
+def _central_crop(image, crop_height, crop_width):
+ """Performs central crops of the given image list.
+
+ Args:
+ image: a 3-D image tensor
+ crop_height: the height of the image following the crop.
+ crop_width: the width of the image following the crop.
+
+ Returns:
+ 3-D tensor with cropped image.
+ """
+ shape = tf.shape(input=image)
+ height, width = shape[0], shape[1]
+
+ amount_to_be_cropped_h = (height - crop_height)
+ crop_top = amount_to_be_cropped_h // 2
+ amount_to_be_cropped_w = (width - crop_width)
+ crop_left = amount_to_be_cropped_w // 2
+ return tf.slice(
+ image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
+
+
+def _mean_image_subtraction(image, means, num_channels):
+ """Subtracts the given means from each image channel.
+
+ For example:
+ means = [123.68, 116.779, 103.939]
+ image = _mean_image_subtraction(image, means)
+
+ Note that the rank of `image` must be known.
+
+ Args:
+ image: a tensor of size [height, width, C].
+ means: a C-vector of values to subtract from each channel.
+ num_channels: number of color channels in the image that will be distorted.
+
+ Returns:
+ the centered image.
+
+ Raises:
+ ValueError: If the rank of `image` is unknown, if `image` has a rank other
+ than three or if the number of channels in `image` doesn't match the
+ number of values in `means`.
+ """
+ if image.get_shape().ndims != 3:
+ raise ValueError('Input must be of size [height, width, C>0]')
+
+ if len(means) != num_channels:
+ raise ValueError('len(means) must match the number of channels')
+
+ # We have a 1-D tensor of means; convert to 3-D.
+ # Note(b/130245863): we explicitly call `broadcast` instead of simply
+ # expanding dimensions for better performance.
+ means = tf.broadcast_to(means, tf.shape(image))
+
+ return image - means
+
+
+def _smallest_size_at_least(height, width, resize_min):
+  """Computes a new shape with the smallest side equal to `resize_min`.
+
+  Computes a new shape with the smallest side equal to `resize_min` while
+  preserving the original aspect ratio.
+
+ Args:
+ height: an int32 scalar tensor indicating the current height.
+ width: an int32 scalar tensor indicating the current width.
+ resize_min: A python integer or scalar `Tensor` indicating the size of
+ the smallest side after resize.
+
+ Returns:
+ new_height: an int32 scalar tensor indicating the new height.
+ new_width: an int32 scalar tensor indicating the new width.
+ """
+ resize_min = tf.cast(resize_min, tf.float32)
+
+ # Convert to floats to make subsequent calculations go smoothly.
+ height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
+
+ smaller_dim = tf.minimum(height, width)
+ scale_ratio = resize_min / smaller_dim
+
+ # Convert back to ints to make heights and widths that TF ops will accept.
+ new_height = tf.cast(height * scale_ratio, tf.int32)
+ new_width = tf.cast(width * scale_ratio, tf.int32)
+
+ return new_height, new_width
+
+
+def _aspect_preserving_resize(image, resize_min):
+ """Resize images preserving the original aspect ratio.
+
+ Args:
+ image: A 3-D image `Tensor`.
+ resize_min: A python integer or scalar `Tensor` indicating the size of
+ the smallest side after resize.
+
+ Returns:
+ resized_image: A 3-D tensor containing the resized image.
+ """
+ shape = tf.shape(input=image)
+ height, width = shape[0], shape[1]
+
+ new_height, new_width = _smallest_size_at_least(height, width, resize_min)
+
+ return _resize_image(image, new_height, new_width)
+
+
+def _resize_image(image, height, width):
+  """Simple wrapper around tf.compat.v1.image.resize.
+
+ This is primarily to make sure we use the same `ResizeMethod` and other
+ details each time.
+
+ Args:
+ image: A 3-D image `Tensor`.
+ height: The target height for the resized image.
+ width: The target width for the resized image.
+
+ Returns:
+ resized_image: A 3-D tensor containing the resized image. The first two
+ dimensions have the shape [height, width].
+ """
+ return tf.compat.v1.image.resize(
+ image, [height, width], method=tf.image.ResizeMethod.BILINEAR,
+ align_corners=False)
+
+
+def preprocess_image(image_buffer, bbox, output_height, output_width,
+ num_channels, is_training=False):
+ """Preprocesses the given image.
+
+ Preprocessing includes decoding, cropping, and resizing for both training
+ and eval images. Training preprocessing, however, introduces some random
+ distortion of the image to improve accuracy.
+
+ Args:
+ image_buffer: scalar string Tensor representing the raw JPEG image buffer.
+ bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
+ where each coordinate is [0, 1) and the coordinates are arranged as
+ [ymin, xmin, ymax, xmax].
+ output_height: The height of the image after preprocessing.
+ output_width: The width of the image after preprocessing.
+ num_channels: Integer depth of the image buffer for decoding.
+ is_training: `True` if we're preprocessing the image for training and
+ `False` otherwise.
+
+ Returns:
+ A preprocessed image.
+ """
+ if is_training:
+ # For training, we want to randomize some of the distortions.
+ image = _decode_crop_and_flip(image_buffer, bbox, num_channels)
+ image = _resize_image(image, output_height, output_width)
+ else:
+ # For validation, we want to decode, resize, then just crop the middle.
+ image = tf.image.decode_jpeg(image_buffer, channels=num_channels)
+ image = _aspect_preserving_resize(image, _RESIZE_MIN)
+ image = _central_crop(image, output_height, output_width)
+
+ image.set_shape([output_height, output_width, num_channels])
+
+ return _mean_image_subtraction(image, CHANNEL_MEANS, num_channels)
diff --git a/models/official/vision/image_classification/resnet/resnet_config.py b/models/official/vision/image_classification/resnet/resnet_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a746257f02b85eddfc72192b9474638b92378644
--- /dev/null
+++ b/models/official/vision/image_classification/resnet/resnet_config.py
@@ -0,0 +1,63 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Configuration definitions for ResNet losses, learning rates, and optimizers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from typing import Any, Mapping
+
+import dataclasses
+
+from official.modeling.hyperparams import base_config
+from official.vision.image_classification.configs import base_configs
+
+
+_RESNET_LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
+ (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
+]
+_RESNET_LR_BOUNDARIES = list(p[1] for p in _RESNET_LR_SCHEDULE[1:])
+_RESNET_LR_MULTIPLIERS = list(p[0] for p in _RESNET_LR_SCHEDULE)
+_RESNET_LR_WARMUP_EPOCHS = _RESNET_LR_SCHEDULE[0][1]
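+# With the schedule above this resolves to boundaries [30, 60, 80] (epochs),
+# multipliers [1.0, 0.1, 0.01, 0.001], and a 5-epoch warmup.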
+
+
+@dataclasses.dataclass
+class ResNetModelConfig(base_configs.ModelConfig):
+ """Configuration for the ResNet model."""
+ name: str = 'ResNet'
+ num_classes: int = 1000
+ model_params: base_config.Config = dataclasses.field(
+ default_factory=lambda: {
+ 'num_classes': 1000,
+ 'batch_size': None,
+ 'use_l2_regularizer': True,
+ 'rescale_inputs': False,
+ })
+ loss: base_configs.LossConfig = base_configs.LossConfig(
+ name='sparse_categorical_crossentropy')
+ optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
+ name='momentum',
+ decay=0.9,
+ epsilon=0.001,
+ momentum=0.9,
+ moving_average_decay=None)
+ learning_rate: base_configs.LearningRateConfig = (
+ base_configs.LearningRateConfig(
+ name='piecewise_constant_with_warmup',
+ examples_per_epoch=1281167,
+ warmup_epochs=_RESNET_LR_WARMUP_EPOCHS,
+ boundaries=_RESNET_LR_BOUNDARIES,
+ multipliers=_RESNET_LR_MULTIPLIERS))
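+
+
+# A minimal usage sketch (assuming this config is consumed by the classifier
+# trainer in this directory): fields can be overridden at construction time
+# like any dataclass, e.g.
+#   config = ResNetModelConfig(num_classes=10)
+#   config.optimizer.momentum  # -> 0.9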
diff --git a/models/official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py b/models/official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..c128dc0b99535d806634b42b99a2e56211c567ca
--- /dev/null
+++ b/models/official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
@@ -0,0 +1,196 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Runs a ResNet model on the ImageNet dataset using custom training loops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf
+
+from official.modeling import performance
+from official.staging.training import controller
+from official.utils.flags import core as flags_core
+from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
+from official.utils.misc import model_helpers
+from official.vision.image_classification.resnet import common
+from official.vision.image_classification.resnet import imagenet_preprocessing
+from official.vision.image_classification.resnet import resnet_runnable
+
+flags.DEFINE_boolean(name='use_tf_function', default=True,
+ help='Wrap the train and test step inside a '
+ 'tf.function.')
+flags.DEFINE_boolean(name='single_l2_loss_op', default=False,
+ help='Calculate L2_loss on concatenated weights, '
+ 'instead of using Keras per-layer L2 loss.')
+
+
+def build_stats(runnable, time_callback):
+ """Normalizes and returns dictionary of stats.
+
+ Args:
+ runnable: The module containing all the training and evaluation metrics.
+ time_callback: Time tracking callback instance.
+
+ Returns:
+ Dictionary of normalized results.
+ """
+ stats = {}
+
+ if not runnable.flags_obj.skip_eval:
+ stats['eval_loss'] = runnable.test_loss.result().numpy()
+ stats['eval_acc'] = runnable.test_accuracy.result().numpy()
+
+ stats['train_loss'] = runnable.train_loss.result().numpy()
+ stats['train_acc'] = runnable.train_accuracy.result().numpy()
+
+ if time_callback:
+ timestamp_log = time_callback.timestamp_log
+ stats['step_timestamp_log'] = timestamp_log
+ stats['train_finish_time'] = time_callback.train_finish_time
+ if time_callback.epoch_runtime_log:
+ stats['avg_exp_per_second'] = time_callback.average_examples_per_second
+
+ return stats
+
+
+def get_num_train_iterations(flags_obj):
+ """Returns the number of training steps, train and test epochs."""
+ train_steps = (
+ imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
+ train_epochs = flags_obj.train_epochs
+
+ if flags_obj.train_steps:
+ train_steps = min(flags_obj.train_steps, train_steps)
+ train_epochs = 1
+
+ eval_steps = math.ceil(1.0 * imagenet_preprocessing.NUM_IMAGES['validation'] /
+ flags_obj.batch_size)
+
+ return train_steps, train_epochs, eval_steps
+
+
+def _steps_to_run(steps_in_current_epoch, steps_per_epoch, steps_per_loop):
+ """Calculates steps to run on device."""
+ if steps_per_loop <= 0:
+ raise ValueError('steps_per_loop should be positive integer.')
+ if steps_per_loop == 1:
+ return steps_per_loop
+ return min(steps_per_loop, steps_per_epoch - steps_in_current_epoch)
+
+
+def run(flags_obj):
+ """Run ResNet ImageNet training and eval loop using custom training loops.
+
+ Args:
+ flags_obj: An object containing parsed flag values.
+
+ Raises:
+ ValueError: If fp16 is passed as it is not currently supported.
+
+ Returns:
+ Dictionary of training and eval stats.
+ """
+ keras_utils.set_session_config(
+ enable_xla=flags_obj.enable_xla)
+ performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))
+
+ if tf.config.list_physical_devices('GPU'):
+ if flags_obj.tf_gpu_thread_mode:
+ keras_utils.set_gpu_thread_mode_and_count(
+ per_gpu_thread_count=flags_obj.per_gpu_thread_count,
+ gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
+ num_gpus=flags_obj.num_gpus,
+ datasets_num_private_threads=flags_obj.datasets_num_private_threads)
+ common.set_cudnn_batchnorm_mode()
+
+ # TODO(anj-s): Set data_format without using Keras.
+ data_format = flags_obj.data_format
+ if data_format is None:
+ data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
+ else 'channels_last')
+ tf.keras.backend.set_image_data_format(data_format)
+
+ strategy = distribution_utils.get_distribution_strategy(
+ distribution_strategy=flags_obj.distribution_strategy,
+ num_gpus=flags_obj.num_gpus,
+ all_reduce_alg=flags_obj.all_reduce_alg,
+ num_packs=flags_obj.num_packs,
+ tpu_address=flags_obj.tpu)
+
+ per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
+ flags_obj)
+ steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
+
+ logging.info(
+ 'Training %d epochs, each epoch has %d steps, '
+ 'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
+ train_epochs * per_epoch_steps, eval_steps)
+
+ time_callback = keras_utils.TimeHistory(
+ flags_obj.batch_size,
+ flags_obj.log_steps,
+ logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
+ with distribution_utils.get_strategy_scope(strategy):
+ runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
+ per_epoch_steps)
+
+ eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
+ checkpoint_interval = (
+ per_epoch_steps if flags_obj.enable_checkpoint_and_export else None)
+ summary_interval = per_epoch_steps if flags_obj.enable_tensorboard else None
+
+ checkpoint_manager = tf.train.CheckpointManager(
+ runnable.checkpoint,
+ directory=flags_obj.model_dir,
+ max_to_keep=10,
+ step_counter=runnable.global_step,
+ checkpoint_interval=checkpoint_interval)
+
+ resnet_controller = controller.Controller(
+ strategy,
+ runnable.train,
+ runnable.evaluate if not flags_obj.skip_eval else None,
+ global_step=runnable.global_step,
+ steps_per_loop=steps_per_loop,
+ train_steps=per_epoch_steps * train_epochs,
+ checkpoint_manager=checkpoint_manager,
+ summary_interval=summary_interval,
+ eval_steps=eval_steps,
+ eval_interval=eval_interval)
+
+ time_callback.on_train_begin()
+ resnet_controller.train(evaluate=not flags_obj.skip_eval)
+ time_callback.on_train_end()
+
+ stats = build_stats(runnable, time_callback)
+ return stats
+
+
+def main(_):
+ model_helpers.apply_clean(flags.FLAGS)
+ stats = run(flags.FLAGS)
+ logging.info('Run stats:\n%s', stats)
+
+
+if __name__ == '__main__':
+ logging.set_verbosity(logging.INFO)
+ common.define_keras_flags()
+ app.run(main)
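
A minimal sketch (editorial, not part of the patch) of the step arithmetic performed by `get_num_train_iterations` and `_steps_to_run` above, assuming the standard ImageNet counts used by `imagenet_preprocessing` (1,281,167 training and 50,000 validation images) and an illustrative global batch size of 256:

```python
import math

NUM_TRAIN, NUM_VAL, BATCH = 1_281_167, 50_000, 256   # assumed ImageNet counts

train_steps = NUM_TRAIN // BATCH            # 5004 steps per epoch
eval_steps = math.ceil(NUM_VAL / BATCH)     # 196 eval steps

def steps_to_run(steps_in_current_epoch, steps_per_epoch, steps_per_loop):
    # Mirrors _steps_to_run: run at most steps_per_loop steps on device,
    # but never past the end of the current epoch.
    if steps_per_loop <= 0:
        raise ValueError('steps_per_loop should be a positive integer.')
    return min(steps_per_loop, steps_per_epoch - steps_in_current_epoch)

print(train_steps, eval_steps)               # 5004 196
print(steps_to_run(5000, train_steps, 500))  # 4 (close to the epoch boundary)
```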
diff --git a/models/official/vision/image_classification/resnet/resnet_model.py b/models/official/vision/image_classification/resnet/resnet_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..10f1233356ece188cce51ec254f0064739cd6f41
--- /dev/null
+++ b/models/official/vision/image_classification/resnet/resnet_model.py
@@ -0,0 +1,329 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ResNet50 model for Keras.
+
+Adapted from tf.keras.applications.resnet50.ResNet50().
+This is ResNet model version 1.5.
+
+Related papers/blogs:
+- https://arxiv.org/abs/1512.03385
+- https://arxiv.org/pdf/1603.05027v2.pdf
+- http://torch.ch/blog/2016/02/04/resnets.html
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.keras import backend
+from tensorflow.python.keras import initializers
+from tensorflow.python.keras import models
+from tensorflow.python.keras import regularizers
+from official.vision.image_classification.resnet import imagenet_preprocessing
+
+layers = tf.keras.layers
+
+
+def _gen_l2_regularizer(use_l2_regularizer=True, l2_weight_decay=1e-4):
+ return regularizers.l2(l2_weight_decay) if use_l2_regularizer else None
+
+
+def identity_block(input_tensor,
+ kernel_size,
+ filters,
+ stage,
+ block,
+ use_l2_regularizer=True,
+ batch_norm_decay=0.9,
+ batch_norm_epsilon=1e-5):
+ """The identity block is the block that has no conv layer at shortcut.
+
+ Args:
+ input_tensor: input tensor
+ kernel_size: default 3, the kernel size of middle conv layer at main path
+ filters: list of integers, the filters of 3 conv layer at main path
+ stage: integer, current stage label, used for generating layer names
+ block: 'a','b'..., current block label, used for generating layer names
+ use_l2_regularizer: whether to use L2 regularizer on Conv layer.
+    batch_norm_decay: Momentum of the batch norm layers.
+    batch_norm_epsilon: Epsilon of the batch norm layers.
+
+ Returns:
+ Output tensor for the block.
+ """
+ filters1, filters2, filters3 = filters
+ if backend.image_data_format() == 'channels_last':
+ bn_axis = 3
+ else:
+ bn_axis = 1
+ conv_name_base = 'res' + str(stage) + block + '_branch'
+ bn_name_base = 'bn' + str(stage) + block + '_branch'
+
+ x = layers.Conv2D(
+ filters1, (1, 1),
+ use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name=conv_name_base + '2a')(
+ input_tensor)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name=bn_name_base + '2a')(
+ x)
+ x = layers.Activation('relu')(x)
+
+ x = layers.Conv2D(
+ filters2,
+ kernel_size,
+ padding='same',
+ use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name=conv_name_base + '2b')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name=bn_name_base + '2b')(
+ x)
+ x = layers.Activation('relu')(x)
+
+ x = layers.Conv2D(
+ filters3, (1, 1),
+ use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name=conv_name_base + '2c')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name=bn_name_base + '2c')(
+ x)
+
+ x = layers.add([x, input_tensor])
+ x = layers.Activation('relu')(x)
+ return x
+
+
+def conv_block(input_tensor,
+ kernel_size,
+ filters,
+ stage,
+ block,
+ strides=(2, 2),
+ use_l2_regularizer=True,
+ batch_norm_decay=0.9,
+ batch_norm_epsilon=1e-5):
+ """A block that has a conv layer at shortcut.
+
+  Note that from stage 3 onwards, the second conv layer in the main path has
+  strides=(2, 2), and the shortcut has strides=(2, 2) as well.
+
+ Args:
+ input_tensor: input tensor
+ kernel_size: default 3, the kernel size of middle conv layer at main path
+ filters: list of integers, the filters of 3 conv layer at main path
+ stage: integer, current stage label, used for generating layer names
+ block: 'a','b'..., current block label, used for generating layer names
+ strides: Strides for the second conv layer in the block.
+ use_l2_regularizer: whether to use L2 regularizer on Conv layer.
+    batch_norm_decay: Momentum of the batch norm layers.
+    batch_norm_epsilon: Epsilon of the batch norm layers.
+
+ Returns:
+ Output tensor for the block.
+ """
+ filters1, filters2, filters3 = filters
+ if backend.image_data_format() == 'channels_last':
+ bn_axis = 3
+ else:
+ bn_axis = 1
+ conv_name_base = 'res' + str(stage) + block + '_branch'
+ bn_name_base = 'bn' + str(stage) + block + '_branch'
+
+ x = layers.Conv2D(
+ filters1, (1, 1),
+ use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name=conv_name_base + '2a')(
+ input_tensor)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name=bn_name_base + '2a')(
+ x)
+ x = layers.Activation('relu')(x)
+
+ x = layers.Conv2D(
+ filters2,
+ kernel_size,
+ strides=strides,
+ padding='same',
+ use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name=conv_name_base + '2b')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name=bn_name_base + '2b')(
+ x)
+ x = layers.Activation('relu')(x)
+
+ x = layers.Conv2D(
+ filters3, (1, 1),
+ use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name=conv_name_base + '2c')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name=bn_name_base + '2c')(
+ x)
+
+ shortcut = layers.Conv2D(
+ filters3, (1, 1),
+ strides=strides,
+ use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name=conv_name_base + '1')(
+ input_tensor)
+ shortcut = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name=bn_name_base + '1')(
+ shortcut)
+
+ x = layers.add([x, shortcut])
+ x = layers.Activation('relu')(x)
+ return x
+
+
+def resnet50(num_classes,
+ batch_size=None,
+ use_l2_regularizer=True,
+ rescale_inputs=False,
+ batch_norm_decay=0.9,
+ batch_norm_epsilon=1e-5):
+ """Instantiates the ResNet50 architecture.
+
+ Args:
+ num_classes: `int` number of classes for image classification.
+ batch_size: Size of the batches for each step.
+ use_l2_regularizer: whether to use L2 regularizer on Conv/Dense layer.
+ rescale_inputs: whether to rescale inputs from 0 to 1.
+    batch_norm_decay: Momentum of the batch norm layers.
+    batch_norm_epsilon: Epsilon of the batch norm layers.
+
+ Returns:
+ A Keras model instance.
+ """
+ input_shape = (224, 224, 3)
+ img_input = layers.Input(shape=input_shape, batch_size=batch_size)
+ if rescale_inputs:
+ # Hub image modules expect inputs in the range [0, 1]. This rescales these
+ # inputs to the range expected by the trained model.
+ x = layers.Lambda(
+ lambda x: x * 255.0 - backend.constant(
+ imagenet_preprocessing.CHANNEL_MEANS,
+ shape=[1, 1, 3],
+ dtype=x.dtype),
+ name='rescale')(
+ img_input)
+ else:
+ x = img_input
+
+ if backend.image_data_format() == 'channels_first':
+ x = layers.Permute((3, 1, 2))(x)
+ bn_axis = 1
+ else: # channels_last
+ bn_axis = 3
+
+ block_config = dict(
+ use_l2_regularizer=use_l2_regularizer,
+ batch_norm_decay=batch_norm_decay,
+ batch_norm_epsilon=batch_norm_epsilon)
+ x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
+ x = layers.Conv2D(
+ 64, (7, 7),
+ strides=(2, 2),
+ padding='valid',
+ use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='conv1')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name='bn_conv1')(
+ x)
+ x = layers.Activation('relu')(x)
+ x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
+
+ x = conv_block(
+ x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1), **block_config)
+ x = identity_block(x, 3, [64, 64, 256], stage=2, block='b', **block_config)
+ x = identity_block(x, 3, [64, 64, 256], stage=2, block='c', **block_config)
+
+ x = conv_block(x, 3, [128, 128, 512], stage=3, block='a', **block_config)
+ x = identity_block(x, 3, [128, 128, 512], stage=3, block='b', **block_config)
+ x = identity_block(x, 3, [128, 128, 512], stage=3, block='c', **block_config)
+ x = identity_block(x, 3, [128, 128, 512], stage=3, block='d', **block_config)
+
+ x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a', **block_config)
+ x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b', **block_config)
+ x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c', **block_config)
+ x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d', **block_config)
+ x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e', **block_config)
+ x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f', **block_config)
+
+ x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a', **block_config)
+ x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b', **block_config)
+ x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c', **block_config)
+
+ x = layers.GlobalAveragePooling2D()(x)
+ x = layers.Dense(
+ num_classes,
+ kernel_initializer=initializers.RandomNormal(stddev=0.01),
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ bias_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='fc1000')(
+ x)
+
+  # A softmax that is followed by the model loss cannot be done in float16
+  # due to numeric issues, so we pass dtype=float32.
+ x = layers.Activation('softmax', dtype='float32')(x)
+
+ # Create model.
+ return models.Model(img_input, x, name='resnet50')
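
A minimal usage sketch (editorial, not part of the file above): it builds the ResNet-50 v1.5 graph defined here and pushes a dummy batch through it to check shapes. The batch size and zero-valued input are illustrative only.

```python
import tensorflow as tf
from official.vision.image_classification.resnet import resnet_model

model = resnet_model.resnet50(num_classes=1000)
dummy_images = tf.zeros([2, 224, 224, 3])     # mean-subtracted inputs expected
probs = model(dummy_images, training=False)   # softmax output, kept in float32
print(probs.shape)                            # (2, 1000)
```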
diff --git a/models/official/vision/image_classification/resnet/resnet_runnable.py b/models/official/vision/image_classification/resnet/resnet_runnable.py
new file mode 100644
index 0000000000000000000000000000000000000000..473b18daf7aaf02bfb1dc86110b3ae0fd2704359
--- /dev/null
+++ b/models/official/vision/image_classification/resnet/resnet_runnable.py
@@ -0,0 +1,221 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Runs a ResNet model on the ImageNet dataset using custom training loops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from official.modeling import performance
+from official.staging.training import grad_utils
+from official.staging.training import standard_runnable
+from official.staging.training import utils
+from official.utils.flags import core as flags_core
+from official.vision.image_classification.resnet import common
+from official.vision.image_classification.resnet import imagenet_preprocessing
+from official.vision.image_classification.resnet import resnet_model
+
+
+class ResnetRunnable(standard_runnable.StandardTrainable,
+ standard_runnable.StandardEvaluable):
+ """Implements the training and evaluation APIs for Resnet model."""
+
+ def __init__(self, flags_obj, time_callback, epoch_steps):
+ standard_runnable.StandardTrainable.__init__(self,
+ flags_obj.use_tf_while_loop,
+ flags_obj.use_tf_function)
+ standard_runnable.StandardEvaluable.__init__(self,
+ flags_obj.use_tf_function)
+
+ self.strategy = tf.distribute.get_strategy()
+ self.flags_obj = flags_obj
+ self.dtype = flags_core.get_tf_dtype(flags_obj)
+ self.time_callback = time_callback
+
+ # Input pipeline related
+ batch_size = flags_obj.batch_size
+ if batch_size % self.strategy.num_replicas_in_sync != 0:
+ raise ValueError(
+ 'Batch size must be divisible by number of replicas : {}'.format(
+ self.strategy.num_replicas_in_sync))
+
+    # Auto rebatching is not supported by the
+    # `experimental_distribute_datasets_from_function()` API, which is
+    # required when cloning the dataset to multiple workers in eager mode,
+    # so we use the per-replica batch size.
+ self.batch_size = int(batch_size / self.strategy.num_replicas_in_sync)
+
+ if self.flags_obj.use_synthetic_data:
+ self.input_fn = common.get_synth_input_fn(
+ height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
+ width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
+ num_channels=imagenet_preprocessing.NUM_CHANNELS,
+ num_classes=imagenet_preprocessing.NUM_CLASSES,
+ dtype=self.dtype,
+ drop_remainder=True)
+ else:
+ self.input_fn = imagenet_preprocessing.input_fn
+
+ self.model = resnet_model.resnet50(
+ num_classes=imagenet_preprocessing.NUM_CLASSES,
+ use_l2_regularizer=not flags_obj.single_l2_loss_op)
+
+ lr_schedule = common.PiecewiseConstantDecayWithWarmup(
+ batch_size=flags_obj.batch_size,
+ epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
+ warmup_epochs=common.LR_SCHEDULE[0][1],
+ boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
+ multipliers=list(p[0] for p in common.LR_SCHEDULE),
+ compute_lr_on_cpu=True)
+ self.optimizer = common.get_optimizer(lr_schedule)
+ # Make sure iterations variable is created inside scope.
+ self.global_step = self.optimizer.iterations
+
+ use_graph_rewrite = flags_obj.fp16_implementation == 'graph_rewrite'
+ if use_graph_rewrite and not flags_obj.use_tf_function:
+ raise ValueError('--fp16_implementation=graph_rewrite requires '
+ '--use_tf_function to be true')
+ self.optimizer = performance.configure_optimizer(
+ self.optimizer,
+ use_float16=self.dtype == tf.float16,
+ use_graph_rewrite=use_graph_rewrite,
+ loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128))
+
+ self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
+ self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
+ 'train_accuracy', dtype=tf.float32)
+ self.test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
+ self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
+ 'test_accuracy', dtype=tf.float32)
+
+ self.checkpoint = tf.train.Checkpoint(
+ model=self.model, optimizer=self.optimizer)
+
+ # Handling epochs.
+ self.epoch_steps = epoch_steps
+ self.epoch_helper = utils.EpochHelper(epoch_steps, self.global_step)
+
+ def build_train_dataset(self):
+ """See base class."""
+ return utils.make_distributed_dataset(
+ self.strategy,
+ self.input_fn,
+ is_training=True,
+ data_dir=self.flags_obj.data_dir,
+ batch_size=self.batch_size,
+ parse_record_fn=imagenet_preprocessing.parse_record,
+ datasets_num_private_threads=self.flags_obj
+ .datasets_num_private_threads,
+ dtype=self.dtype,
+ drop_remainder=True)
+
+ def build_eval_dataset(self):
+ """See base class."""
+ return utils.make_distributed_dataset(
+ self.strategy,
+ self.input_fn,
+ is_training=False,
+ data_dir=self.flags_obj.data_dir,
+ batch_size=self.batch_size,
+ parse_record_fn=imagenet_preprocessing.parse_record,
+ dtype=self.dtype)
+
+ def train_loop_begin(self):
+ """See base class."""
+ # Reset all metrics
+ self.train_loss.reset_states()
+ self.train_accuracy.reset_states()
+
+ self._epoch_begin()
+ self.time_callback.on_batch_begin(self.epoch_helper.batch_index)
+
+ def train_step(self, iterator):
+ """See base class."""
+
+ def step_fn(inputs):
+ """Function to run on the device."""
+ images, labels = inputs
+ with tf.GradientTape() as tape:
+ logits = self.model(images, training=True)
+
+ prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
+ labels, logits)
+ loss = tf.reduce_sum(prediction_loss) * (1.0 /
+ self.flags_obj.batch_size)
+ num_replicas = self.strategy.num_replicas_in_sync
+ l2_weight_decay = 1e-4
+ if self.flags_obj.single_l2_loss_op:
+ l2_loss = l2_weight_decay * 2 * tf.add_n([
+ tf.nn.l2_loss(v)
+ for v in self.model.trainable_variables
+ if 'bn' not in v.name
+ ])
+
+ loss += (l2_loss / num_replicas)
+ else:
+ loss += (tf.reduce_sum(self.model.losses) / num_replicas)
+
+ grad_utils.minimize_using_explicit_allreduce(
+ tape, self.optimizer, loss, self.model.trainable_variables)
+ self.train_loss.update_state(loss)
+ self.train_accuracy.update_state(labels, logits)
+
+ self.strategy.run(step_fn, args=(next(iterator),))
+
+ def train_loop_end(self):
+ """See base class."""
+ metrics = {
+ 'train_loss': self.train_loss.result(),
+ 'train_accuracy': self.train_accuracy.result(),
+ }
+ self.time_callback.on_batch_end(self.epoch_helper.batch_index - 1)
+ self._epoch_end()
+ return metrics
+
+ def eval_begin(self):
+ """See base class."""
+ self.test_loss.reset_states()
+ self.test_accuracy.reset_states()
+
+ def eval_step(self, iterator):
+ """See base class."""
+
+ def step_fn(inputs):
+ """Function to run on the device."""
+ images, labels = inputs
+ logits = self.model(images, training=False)
+ loss = tf.keras.losses.sparse_categorical_crossentropy(labels, logits)
+ loss = tf.reduce_sum(loss) * (1.0 / self.flags_obj.batch_size)
+ self.test_loss.update_state(loss)
+ self.test_accuracy.update_state(labels, logits)
+
+ self.strategy.run(step_fn, args=(next(iterator),))
+
+ def eval_end(self):
+ """See base class."""
+ return {
+ 'test_loss': self.test_loss.result(),
+ 'test_accuracy': self.test_accuracy.result()
+ }
+
+ def _epoch_begin(self):
+ if self.epoch_helper.epoch_begin():
+ self.time_callback.on_epoch_begin(self.epoch_helper.current_epoch)
+
+ def _epoch_end(self):
+ if self.epoch_helper.epoch_end():
+ self.time_callback.on_epoch_end(self.epoch_helper.current_epoch)
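
A minimal sketch (editorial) of the loss scaling used in `train_step` above: the summed cross-entropy is divided by the global batch size and the L2 term by the number of replicas, so that adding the per-replica gradients reproduces the single-worker loss. The batch size, replica count, and stand-in weights below are illustrative.

```python
import tensorflow as tf

global_batch, num_replicas = 256, 8
per_replica_batch = global_batch // num_replicas          # 32, as in __init__

labels = tf.zeros([per_replica_batch], dtype=tf.int32)
probs = tf.nn.softmax(tf.zeros([per_replica_batch, 1000]))  # model output is softmax
per_example = tf.keras.losses.sparse_categorical_crossentropy(labels, probs)
loss = tf.reduce_sum(per_example) / global_batch            # scaled by global batch

l2_weight_decay = 1e-4
weights = [tf.ones([3, 3, 64, 64])]                          # stand-in for model weights
l2_loss = l2_weight_decay * 2 * tf.add_n([tf.nn.l2_loss(w) for w in weights])
loss += l2_loss / num_replicas                               # scaled by replica count
```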
diff --git a/models/official/vision/image_classification/resnet/tfhub_export.py b/models/official/vision/image_classification/resnet/tfhub_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff1f124a1d67c93b9deee453a23cf71133bb6434
--- /dev/null
+++ b/models/official/vision/image_classification/resnet/tfhub_export.py
@@ -0,0 +1,66 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A script to export TF-Hub SavedModel."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import os
+
+from absl import app
+from absl import flags
+
+import tensorflow as tf
+
+from official.vision.image_classification.resnet import imagenet_preprocessing
+from official.vision.image_classification.resnet import resnet_model
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("model_path", None,
+ "File path to TF model checkpoint or H5 file.")
+flags.DEFINE_string("export_path", None,
+ "TF-Hub SavedModel destination path to export.")
+
+
+def export_tfhub(model_path, hub_destination):
+ """Restores a tf.keras.Model and saves for TF-Hub."""
+ model = resnet_model.resnet50(
+ num_classes=imagenet_preprocessing.NUM_CLASSES, rescale_inputs=True)
+ model.load_weights(model_path)
+ model.save(
+ os.path.join(hub_destination, "classification"), include_optimizer=False)
+
+ # Extracts a sub-model to use pooling feature vector as model output.
+ image_input = model.get_layer(index=0).get_output_at(0)
+ feature_vector_output = model.get_layer(name="reduce_mean").get_output_at(0)
+ hub_model = tf.keras.Model(image_input, feature_vector_output)
+
+ # Exports a SavedModel.
+ hub_model.save(
+ os.path.join(hub_destination, "feature-vector"), include_optimizer=False)
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise app.UsageError("Too many command-line arguments.")
+
+ export_tfhub(FLAGS.model_path, FLAGS.export_path)
+
+
+if __name__ == "__main__":
+ app.run(main)
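
A minimal consumer-side sketch (editorial), assuming the export above was run with hypothetical paths such as `--model_path=/tmp/resnet_ckpt --export_path=/tmp/hub`; it loads the exported feature-vector SavedModel with `tensorflow_hub`:

```python
import tensorflow as tf
import tensorflow_hub as hub

# Hypothetical export destination; replace with the --export_path actually used.
feature_extractor = hub.KerasLayer('/tmp/hub/feature-vector', trainable=False)
images = tf.zeros([1, 224, 224, 3])     # inputs in [0, 1], as rescale_inputs expects
features = feature_extractor(images)    # pooled feature vector
print(features.shape)
```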
diff --git a/models/official/vision/image_classification/test_utils.py b/models/official/vision/image_classification/test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6dc91dc775ce25950a8918450548c19992eb2c4
--- /dev/null
+++ b/models/official/vision/image_classification/test_utils.py
@@ -0,0 +1,38 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test utilities for image classification tasks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras import backend
+from tensorflow.python.keras import layers
+from tensorflow.python.keras import models
+
+
+def trivial_model(num_classes):
+ """Trivial model for ImageNet dataset."""
+
+ input_shape = (224, 224, 3)
+ img_input = layers.Input(shape=input_shape)
+
+ x = layers.Lambda(lambda x: backend.reshape(x, [-1, 224 * 224 * 3]),
+ name='reshape')(img_input)
+ x = layers.Dense(1, name='fc1')(x)
+ x = layers.Dense(num_classes, name='fc1000')(x)
+ x = layers.Activation('softmax', dtype='float32')(x)
+
+ return models.Model(img_input, x, name='trivial')
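
A minimal usage sketch (editorial): the trivial model simply flattens the image and applies two dense layers, which keeps training-loop unit tests fast. The batch size below is illustrative.

```python
import tensorflow as tf
from official.vision.image_classification import test_utils

model = test_utils.trivial_model(num_classes=1000)
out = model(tf.zeros([2, 224, 224, 3]))
print(out.shape)   # (2, 1000)
```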
diff --git a/models/research/README.md b/models/research/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f9e84fb86f44b687c6a9c221fa72cd461e84c01e
--- /dev/null
+++ b/models/research/README.md
@@ -0,0 +1,124 @@
+
+
+# TensorFlow Research Models
+
+This directory contains code implementations and pre-trained models of published research papers.
+
+The research models are maintained by their respective authors.
+
+## Table of Contents
+- [Modeling Libraries and Models](#modeling-libraries-and-models)
+- [Models and Implementations](#models-and-implementations)
+ * [Computer Vision](#computer-vision)
+ * [Natural Language Processing](#natural-language-processing)
+ * [Audio and Speech](#audio-and-speech)
+ * [Reinforcement Learning](#reinforcement-learning)
+ * [Others](#others)
+- [Archived Models and Implementations](#warning-archived-models-and-implementations) (:no_entry_sign: No longer maintained)
+
+## Modeling Libraries and Models
+
+| Directory | Name | Description | Maintainer(s) |
+|-----------|------|-------------|---------------|
+| [object_detection](object_detection) | TensorFlow Object Detection API | A framework that makes it easy to construct, train and deploy object detection models. | |
+
+## Contacts (Maintainers)
+
+* Liang-Chieh Chen, github: [aquariusjay](https://github.com/aquariusjay)
+* YuKun Zhu, github: [yknzhu](https://github.com/YknZhu)
+* George Papandreou, github: [gpapan](https://github.com/gpapan)
+* Hui Hui, github: [huihui-personal](https://github.com/huihui-personal)
+* Maxwell D. Collins, github: [mcollinswisc](https://github.com/mcollinswisc)
+* Ting Liu, github: [tingliu](https://github.com/tingliu)
+
+## DeepLab Table of Contents
+
+Demo:
+
+* Colab notebook for off-the-shelf inference.
+
+Running:
+
+* Installation.
+* Running DeepLab on PASCAL VOC 2012 semantic segmentation dataset.
+* Running DeepLab on Cityscapes semantic segmentation dataset.
+* Running DeepLab on ADE20K semantic segmentation dataset.
+
+Models:
+
+* Checkpoints and frozen inference graphs.
+
+Misc:
+
+* Please check the FAQ if you have questions before reporting an issue.
+
+## Getting Help
+
+To get help with issues you may encounter while using the DeepLab Tensorflow
+implementation, create a new question on
+[StackOverflow](https://stackoverflow.com/) with the tag "tensorflow".
+
+Please report bugs (i.e., broken code, not usage questions) to the
+tensorflow/models GitHub [issue
+tracker](https://github.com/tensorflow/models/issues), prefixing the issue name
+with "deeplab".
+
+## License
+
+All the code in the deeplab folder is covered by the [LICENSE](https://github.com/tensorflow/models/blob/master/LICENSE)
+under tensorflow/models. Please refer to the LICENSE for details.
+
+## Change Logs
+
+### March 26, 2020
+* Supported EdgeTPU-DeepLab and EdgeTPU-DeepLab-slim on Cityscapes.
+**Contributor**: Yun Long.
+
+### November 20, 2019
+* Supported MobileNetV3 large and small model variants on Cityscapes.
+**Contributor**: Yukun Zhu.
+
+
+### March 27, 2019
+
+* Supported using different loss weights on different classes during training.
+**Contributor**: Yuwei Yang.
+
+
+### March 26, 2019
+
+* Supported ResNet-v1-18. **Contributor**: Michalis Raptis.
+
+
+### March 6, 2019
+
+* Released the evaluation code (under the `evaluation` folder) for image
+parsing, a.k.a. panoptic segmentation. In particular, the released code supports
+evaluating the parsing results in terms of both the parsing covering and
+panoptic quality metrics. **Contributors**: Maxwell Collins and Ting Liu.
+
+
+### February 6, 2019
+
+* Updated decoder module to exploit multiple low-level features with different
+output_strides.
+
+### December 3, 2018
+
+* Released the MobileNet-v2 checkpoint on ADE20K.
+
+
+### November 19, 2018
+
+* Supported NAS architecture for feature extraction. **Contributor**: Chenxi Liu.
+
+* Supported hard pixel mining during training.
+
+
+### October 1, 2018
+
+* Released MobileNet-v2 depth-multiplier = 0.5 COCO-pretrained checkpoints on
+PASCAL VOC 2012, and Xception-65 COCO pretrained checkpoint (i.e., no PASCAL
+pretrained).
+
+
+### September 5, 2018
+
+* Released Cityscapes pretrained checkpoints using the best dense prediction cell found.
+
+
+### May 26, 2018
+
+* Updated ADE20K pretrained checkpoint.
+
+
+### May 18, 2018
+* Added builders for ResNet-v1 and Xception model variants.
+* Added ADE20K support, including colormap and pretrained Xception_65 checkpoint.
+* Fixed a bug on using non-default depth_multiplier for MobileNet-v2.
+
+
+### March 22, 2018
+
+* Released checkpoints using MobileNet-V2 as network backbone and pretrained on
+PASCAL VOC 2012 and Cityscapes.
+
+
+### March 5, 2018
+
+* First release of DeepLab in TensorFlow including deeper Xception network
+backbone. Included checkpoints that have been pretrained on PASCAL VOC 2012
+and Cityscapes.
+
+## References
+
+1. **Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs**
+ Liang-Chieh Chen+, George Papandreou+, Iasonas Kokkinos, Kevin Murphy, Alan L. Yuille (+ equal
+ contribution).
+ [[link]](https://arxiv.org/abs/1412.7062). In ICLR, 2015.
+
+2. **DeepLab: Semantic Image Segmentation with Deep Convolutional Nets,**
+ **Atrous Convolution, and Fully Connected CRFs**
+ Liang-Chieh Chen+, George Papandreou+, Iasonas Kokkinos, Kevin Murphy, and Alan L Yuille (+ equal
+ contribution).
+ [[link]](http://arxiv.org/abs/1606.00915). TPAMI 2017.
+
+3. **Rethinking Atrous Convolution for Semantic Image Segmentation**
+ Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam.
+ [[link]](http://arxiv.org/abs/1706.05587). arXiv: 1706.05587, 2017.
+
+4. **Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation**
+ Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, Hartwig Adam.
+ [[link]](https://arxiv.org/abs/1802.02611). In ECCV, 2018.
+
+5. **ParseNet: Looking Wider to See Better**
+ Wei Liu, Andrew Rabinovich, Alexander C Berg
+ [[link]](https://arxiv.org/abs/1506.04579). arXiv:1506.04579, 2015.
+
+6. **Pyramid Scene Parsing Network**
+ Hengshuang Zhao, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, Jiaya Jia
+ [[link]](https://arxiv.org/abs/1612.01105). In CVPR, 2017.
+
+7. **Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate shift**
+ Sergey Ioffe, Christian Szegedy
+ [[link]](https://arxiv.org/abs/1502.03167). In ICML, 2015.
+
+8. **MobileNetV2: Inverted Residuals and Linear Bottlenecks**
+ Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen
+ [[link]](https://arxiv.org/abs/1801.04381). In CVPR, 2018.
+
+9. **Xception: Deep Learning with Depthwise Separable Convolutions**
+ François Chollet
+ [[link]](https://arxiv.org/abs/1610.02357). In CVPR, 2017.
+
+10. **Deformable Convolutional Networks -- COCO Detection and Segmentation Challenge 2017 Entry**
+ Haozhi Qi, Zheng Zhang, Bin Xiao, Han Hu, Bowen Cheng, Yichen Wei, Jifeng Dai
+ [[link]](http://presentations.cocodataset.org/COCO17-Detect-MSRA.pdf). ICCV COCO Challenge
+ Workshop, 2017.
+
+11. **Tensorflow: Large-Scale Machine Learning on Heterogeneous Distributed Systems**
+ M. Abadi, A. Agarwal, et al.
+ [[link]](https://arxiv.org/abs/1603.04467). arXiv:1603.04467, 2016.
+
+12. **The Pascal Visual Object Classes Challenge – A Retrospective,**
+ Mark Everingham, S. M. Ali Eslami, Luc Van Gool, Christopher K. I. Williams, John
+    Winn, and Andrew Zisserman.
+ [[link]](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/). IJCV, 2014.
+
+13. **The Cityscapes Dataset for Semantic Urban Scene Understanding**
+ Cordts, Marius, Mohamed Omran, Sebastian Ramos, Timo Rehfeld, Markus Enzweiler, Rodrigo Benenson, Uwe Franke, Stefan Roth, Bernt Schiele.
+ [[link]](https://www.cityscapes-dataset.com/). In CVPR, 2016.
+
+14. **Deep Residual Learning for Image Recognition**
+ Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
+ [[link]](https://arxiv.org/abs/1512.03385). In CVPR, 2016.
+
+15. **Progressive Neural Architecture Search**
+ Chenxi Liu, Barret Zoph, Maxim Neumann, Jonathon Shlens, Wei Hua, Li-Jia Li, Li Fei-Fei, Alan Yuille, Jonathan Huang, Kevin Murphy.
+ [[link]](https://arxiv.org/abs/1712.00559). In ECCV, 2018.
+
+16. **Searching for MobileNetV3**
+ Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam.
+ [[link]](https://arxiv.org/abs/1905.02244). In ICCV, 2019.
diff --git a/models/research/deeplab/__init__.py b/models/research/deeplab/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/deeplab/common.py b/models/research/deeplab/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..928f7176c377e69aa2c5b8bc676f092cf97819c9
--- /dev/null
+++ b/models/research/deeplab/common.py
@@ -0,0 +1,295 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Provides flags that are common to scripts.
+
+Common flags from train/eval/vis/export_model.py are collected in this script.
+"""
+import collections
+import copy
+import json
+import tensorflow as tf
+
+flags = tf.app.flags
+
+# Flags for input preprocessing.
+
+flags.DEFINE_integer('min_resize_value', None,
+ 'Desired size of the smaller image side.')
+
+flags.DEFINE_integer('max_resize_value', None,
+ 'Maximum allowed size of the larger image side.')
+
+flags.DEFINE_integer('resize_factor', None,
+ 'Resized dimensions are multiple of factor plus one.')
+
+flags.DEFINE_boolean('keep_aspect_ratio', True,
+ 'Keep aspect ratio after resizing or not.')
+
+# Model dependent flags.
+
+flags.DEFINE_integer('logits_kernel_size', 1,
+ 'The kernel size for the convolutional kernel that '
+ 'generates logits.')
+
+# When using 'mobilenet_v2', we set atrous_rates = decoder_output_stride = None.
+# When using 'xception_65' or 'resnet_v1' model variants, we set
+# atrous_rates = [6, 12, 18] (output stride 16) and decoder_output_stride = 4.
+# See core/feature_extractor.py for supported model variants.
+flags.DEFINE_string('model_variant', 'mobilenet_v2', 'DeepLab model variant.')
+
+flags.DEFINE_multi_float('image_pyramid', None,
+ 'Input scales for multi-scale feature extraction.')
+
+flags.DEFINE_boolean('add_image_level_feature', True,
+ 'Add image level feature.')
+
+flags.DEFINE_list(
+ 'image_pooling_crop_size', None,
+ 'Image pooling crop size [height, width] used in the ASPP module. When '
+    'value is None, the model performs image pooling with "crop_size". This '
+    'flag is useful when one wants to use different image pooling sizes.')
+
+flags.DEFINE_list(
+ 'image_pooling_stride', '1,1',
+ 'Image pooling stride [height, width] used in the ASPP image pooling. ')
+
+flags.DEFINE_boolean('aspp_with_batch_norm', True,
+ 'Use batch norm parameters for ASPP or not.')
+
+flags.DEFINE_boolean('aspp_with_separable_conv', True,
+ 'Use separable convolution for ASPP or not.')
+
+# Defaults to None. Set multi_grid = [1, 2, 4] when using provided
+# 'resnet_v1_{50,101}_beta' checkpoints.
+flags.DEFINE_multi_integer('multi_grid', None,
+ 'Employ a hierarchy of atrous rates for ResNet.')
+
+flags.DEFINE_float('depth_multiplier', 1.0,
+ 'Multiplier for the depth (number of channels) for all '
+ 'convolution ops used in MobileNet.')
+
+flags.DEFINE_integer('divisible_by', None,
+ 'An integer that ensures the layer # channels are '
+ 'divisible by this value. Used in MobileNet.')
+
+# For `xception_65`, use decoder_output_stride = 4. For `mobilenet_v2`, use
+# decoder_output_stride = None.
+flags.DEFINE_list('decoder_output_stride', None,
+                  'Comma-separated list of strings with the number specifying '
+                  'output stride of low-level features at each network level. '
+                  'Current semantic segmentation implementation assumes at '
+                  'most one output stride (i.e., either None or a list with '
+                  'only one element).')
+
+flags.DEFINE_boolean('decoder_use_separable_conv', True,
+ 'Employ separable convolution for decoder or not.')
+
+flags.DEFINE_enum('merge_method', 'max', ['max', 'avg'],
+ 'Scheme to merge multi scale features.')
+
+flags.DEFINE_boolean(
+ 'prediction_with_upsampled_logits', True,
+ 'When performing prediction, there are two options: (1) bilinear '
+ 'upsampling the logits followed by softmax, or (2) softmax followed by '
+ 'bilinear upsampling.')
+
+flags.DEFINE_string(
+ 'dense_prediction_cell_json',
+ '',
+ 'A JSON file that specifies the dense prediction cell.')
+
+flags.DEFINE_integer(
+ 'nas_stem_output_num_conv_filters', 20,
+ 'Number of filters of the stem output tensor in NAS models.')
+
+flags.DEFINE_bool('nas_use_classification_head', False,
+ 'Use image classification head for NAS model variants.')
+
+flags.DEFINE_bool('nas_remove_os32_stride', False,
+ 'Remove the stride in the output stride 32 branch.')
+
+flags.DEFINE_bool('use_bounded_activation', False,
+ 'Whether or not to use bounded activations. Bounded '
+ 'activations better lend themselves to quantized inference.')
+
+flags.DEFINE_boolean('aspp_with_concat_projection', True,
+ 'ASPP with concat projection.')
+
+flags.DEFINE_boolean('aspp_with_squeeze_and_excitation', False,
+ 'ASPP with squeeze and excitation.')
+
+flags.DEFINE_integer('aspp_convs_filters', 256, 'ASPP convolution filters.')
+
+flags.DEFINE_boolean('decoder_use_sum_merge', False,
+ 'Decoder uses simply sum merge.')
+
+flags.DEFINE_integer('decoder_filters', 256, 'Decoder filters.')
+
+flags.DEFINE_boolean('decoder_output_is_logits', False,
+ 'Use decoder output as logits or not.')
+
+flags.DEFINE_boolean('image_se_uses_qsigmoid', False, 'Use q-sigmoid.')
+
+flags.DEFINE_multi_float(
+ 'label_weights', None,
+ 'A list of label weights, each element represents the weight for the label '
+ 'of its index, for example, label_weights = [0.1, 0.5] means the weight '
+ 'for label 0 is 0.1 and the weight for label 1 is 0.5. If set as None, all '
+ 'the labels have the same weight 1.0.')
+
+flags.DEFINE_float('batch_norm_decay', 0.9997, 'Batchnorm decay.')
+
+FLAGS = flags.FLAGS
+
+# Constants
+
+# Perform semantic segmentation predictions.
+OUTPUT_TYPE = 'semantic'
+
+# Semantic segmentation item names.
+LABELS_CLASS = 'labels_class'
+IMAGE = 'image'
+HEIGHT = 'height'
+WIDTH = 'width'
+IMAGE_NAME = 'image_name'
+LABEL = 'label'
+ORIGINAL_IMAGE = 'original_image'
+
+# Test set name.
+TEST_SET = 'test'
+
+
+class ModelOptions(
+ collections.namedtuple('ModelOptions', [
+ 'outputs_to_num_classes',
+ 'crop_size',
+ 'atrous_rates',
+ 'output_stride',
+ 'preprocessed_images_dtype',
+ 'merge_method',
+ 'add_image_level_feature',
+ 'image_pooling_crop_size',
+ 'image_pooling_stride',
+ 'aspp_with_batch_norm',
+ 'aspp_with_separable_conv',
+ 'multi_grid',
+ 'decoder_output_stride',
+ 'decoder_use_separable_conv',
+ 'logits_kernel_size',
+ 'model_variant',
+ 'depth_multiplier',
+ 'divisible_by',
+ 'prediction_with_upsampled_logits',
+ 'dense_prediction_cell_config',
+ 'nas_architecture_options',
+ 'use_bounded_activation',
+ 'aspp_with_concat_projection',
+ 'aspp_with_squeeze_and_excitation',
+ 'aspp_convs_filters',
+ 'decoder_use_sum_merge',
+ 'decoder_filters',
+ 'decoder_output_is_logits',
+ 'image_se_uses_qsigmoid',
+ 'label_weights',
+ 'sync_batch_norm_method',
+ 'batch_norm_decay',
+ ])):
+ """Immutable class to hold model options."""
+
+ __slots__ = ()
+
+ def __new__(cls,
+ outputs_to_num_classes,
+ crop_size=None,
+ atrous_rates=None,
+ output_stride=8,
+ preprocessed_images_dtype=tf.float32):
+ """Constructor to set default values.
+
+ Args:
+ outputs_to_num_classes: A dictionary from output type to the number of
+ classes. For example, for the task of semantic segmentation with 21
+ semantic classes, we would have outputs_to_num_classes['semantic'] = 21.
+ crop_size: A tuple [crop_height, crop_width].
+ atrous_rates: A list of atrous convolution rates for ASPP.
+ output_stride: The ratio of input to output spatial resolution.
+ preprocessed_images_dtype: The type after the preprocessing function.
+
+ Returns:
+ A new ModelOptions instance.
+ """
+ dense_prediction_cell_config = None
+ if FLAGS.dense_prediction_cell_json:
+ with tf.gfile.Open(FLAGS.dense_prediction_cell_json, 'r') as f:
+ dense_prediction_cell_config = json.load(f)
+ decoder_output_stride = None
+ if FLAGS.decoder_output_stride:
+ decoder_output_stride = [
+ int(x) for x in FLAGS.decoder_output_stride]
+ if sorted(decoder_output_stride, reverse=True) != decoder_output_stride:
+        raise ValueError('Decoder output stride needs to be sorted in '
+                         'descending order.')
+ image_pooling_crop_size = None
+ if FLAGS.image_pooling_crop_size:
+ image_pooling_crop_size = [int(x) for x in FLAGS.image_pooling_crop_size]
+ image_pooling_stride = [1, 1]
+ if FLAGS.image_pooling_stride:
+ image_pooling_stride = [int(x) for x in FLAGS.image_pooling_stride]
+ label_weights = FLAGS.label_weights
+ if label_weights is None:
+ label_weights = 1.0
+ nas_architecture_options = {
+ 'nas_stem_output_num_conv_filters': (
+ FLAGS.nas_stem_output_num_conv_filters),
+ 'nas_use_classification_head': FLAGS.nas_use_classification_head,
+ 'nas_remove_os32_stride': FLAGS.nas_remove_os32_stride,
+ }
+ return super(ModelOptions, cls).__new__(
+ cls, outputs_to_num_classes, crop_size, atrous_rates, output_stride,
+ preprocessed_images_dtype,
+ FLAGS.merge_method,
+ FLAGS.add_image_level_feature,
+ image_pooling_crop_size,
+ image_pooling_stride,
+ FLAGS.aspp_with_batch_norm,
+ FLAGS.aspp_with_separable_conv,
+ FLAGS.multi_grid,
+ decoder_output_stride,
+ FLAGS.decoder_use_separable_conv,
+ FLAGS.logits_kernel_size,
+ FLAGS.model_variant,
+ FLAGS.depth_multiplier,
+ FLAGS.divisible_by,
+ FLAGS.prediction_with_upsampled_logits,
+ dense_prediction_cell_config,
+ nas_architecture_options,
+ FLAGS.use_bounded_activation,
+ FLAGS.aspp_with_concat_projection,
+ FLAGS.aspp_with_squeeze_and_excitation,
+ FLAGS.aspp_convs_filters,
+ FLAGS.decoder_use_sum_merge,
+ FLAGS.decoder_filters,
+ FLAGS.decoder_output_is_logits,
+ FLAGS.image_se_uses_qsigmoid,
+ label_weights,
+ 'None',
+ FLAGS.batch_norm_decay)
+
+ def __deepcopy__(self, memo):
+ return ModelOptions(copy.deepcopy(self.outputs_to_num_classes),
+ self.crop_size,
+ self.atrous_rates,
+ self.output_stride,
+ self.preprocessed_images_dtype)
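
A minimal construction sketch (editorial), mirroring the usage in `common_test.py` below; it assumes the flags defined above have already been parsed (e.g. inside `tf.app.run()` or a `tf.test.TestCase`). The crop size and atrous rates follow the output-stride-16 setting mentioned in the comments above.

```python
from deeplab import common

def build_pascal_options():
    # 21 PASCAL VOC classes; atrous_rates=[6, 12, 18] pairs with output_stride=16.
    return common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: 21},
        crop_size=[513, 513],
        atrous_rates=[6, 12, 18],
        output_stride=16)
```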
diff --git a/models/research/deeplab/common_test.py b/models/research/deeplab/common_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..45b64e50e3bb0a574c0ec230075e6fddff3ae996
--- /dev/null
+++ b/models/research/deeplab/common_test.py
@@ -0,0 +1,52 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for common.py."""
+import copy
+
+import tensorflow as tf
+
+from deeplab import common
+
+
+class CommonTest(tf.test.TestCase):
+
+ def testOutputsToNumClasses(self):
+ num_classes = 21
+ model_options = common.ModelOptions(
+ outputs_to_num_classes={common.OUTPUT_TYPE: num_classes})
+ self.assertEqual(model_options.outputs_to_num_classes[common.OUTPUT_TYPE],
+ num_classes)
+
+ def testDeepcopy(self):
+ num_classes = 21
+ model_options = common.ModelOptions(
+ outputs_to_num_classes={common.OUTPUT_TYPE: num_classes})
+ model_options_new = copy.deepcopy(model_options)
+ self.assertEqual((model_options_new.
+ outputs_to_num_classes[common.OUTPUT_TYPE]),
+ num_classes)
+
+ num_classes_new = 22
+ model_options_new.outputs_to_num_classes[common.OUTPUT_TYPE] = (
+ num_classes_new)
+ self.assertEqual(model_options.outputs_to_num_classes[common.OUTPUT_TYPE],
+ num_classes)
+ self.assertEqual((model_options_new.
+ outputs_to_num_classes[common.OUTPUT_TYPE]),
+ num_classes_new)
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/deeplab/core/__init__.py b/models/research/deeplab/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/deeplab/core/conv2d_ws.py b/models/research/deeplab/core/conv2d_ws.py
new file mode 100644
index 0000000000000000000000000000000000000000..9aaaf33dd3c2e098d7d5e815b4918c436ee1796c
--- /dev/null
+++ b/models/research/deeplab/core/conv2d_ws.py
@@ -0,0 +1,369 @@
+# Lint as: python2, python3
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Augment slim.conv2d with optional Weight Standardization (WS).
+
+WS is a normalization method to accelerate micro-batch training. When used with
+Group Normalization and trained with 1 image/GPU, WS is able to match or
+outperform the performance of BN trained with large batch sizes.
+[1] Siyuan Qiao, Huiyu Wang, Chenxi Liu, Wei Shen, Alan Yuille
+ Weight Standardization. arXiv:1903.10520
+[2] Lei Huang, Xianglong Liu, Yang Liu, Bo Lang, Dacheng Tao
+ Centered Weight Normalization in Accelerating Training of Deep Neural
+ Networks. ICCV 2017
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.contrib import framework as contrib_framework
+from tensorflow.contrib import layers as contrib_layers
+
+from tensorflow.contrib.layers.python.layers import layers
+from tensorflow.contrib.layers.python.layers import utils
+
+
+class Conv2D(tf.keras.layers.Conv2D, tf.layers.Layer):
+ """2D convolution layer (e.g. spatial convolution over images).
+
+ This layer creates a convolution kernel that is convolved
+ (actually cross-correlated) with the layer input to produce a tensor of
+ outputs. If `use_bias` is True (and a `bias_initializer` is provided),
+ a bias vector is created and added to the outputs. Finally, if
+ `activation` is not `None`, it is applied to the outputs as well.
+ """
+
+ def __init__(self,
+ filters,
+ kernel_size,
+ strides=(1, 1),
+ padding='valid',
+ data_format='channels_last',
+ dilation_rate=(1, 1),
+ activation=None,
+ use_bias=True,
+ kernel_initializer=None,
+ bias_initializer=tf.zeros_initializer(),
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ use_weight_standardization=False,
+ activity_regularizer=None,
+ kernel_constraint=None,
+ bias_constraint=None,
+ trainable=True,
+ name=None,
+ **kwargs):
+ """Constructs the 2D convolution layer.
+
+ Args:
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of 2 integers, specifying the height
+ and width of the 2D convolution window. Can be a single integer to
+ specify the same value for all spatial dimensions.
+ strides: An integer or tuple/list of 2 integers, specifying the strides of
+ the convolution along the height and width. Can be a single integer to
+ specify the same value for all spatial dimensions. Specifying any stride
+ value != 1 is incompatible with specifying any `dilation_rate` value !=
+ 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or
+ `channels_first`. The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape `(batch, height, width,
+ channels)` while `channels_first` corresponds to inputs with shape
+ `(batch, channels, height, width)`.
+ dilation_rate: An integer or tuple/list of 2 integers, specifying the
+ dilation rate to use for dilated convolution. Can be a single integer to
+ specify the same value for all spatial dimensions. Currently, specifying
+ any `dilation_rate` value != 1 is incompatible with specifying any
+ stride value != 1.
+ activation: Activation function. Set it to None to maintain a linear
+ activation.
+ use_bias: Boolean, whether the layer uses a bias.
+ kernel_initializer: An initializer for the convolution kernel.
+ bias_initializer: An initializer for the bias vector. If None, the default
+ initializer will be used.
+ kernel_regularizer: Optional regularizer for the convolution kernel.
+ bias_regularizer: Optional regularizer for the bias vector.
+ use_weight_standardization: Boolean, whether the layer uses weight
+ standardization.
+ activity_regularizer: Optional regularizer function for the output.
+ kernel_constraint: Optional projection function to be applied to the
+ kernel after being updated by an `Optimizer` (e.g. used to implement
+ norm constraints or value constraints for layer weights). The function
+ must take as input the unprojected variable and must return the
+ projected variable (which must have the same shape). Constraints are not
+ safe to use when doing asynchronous distributed training.
+ bias_constraint: Optional projection function to be applied to the bias
+ after being updated by an `Optimizer`.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: A string, the name of the layer.
+ **kwargs: Arbitrary keyword arguments passed to tf.keras.layers.Conv2D
+ """
+
+ super(Conv2D, self).__init__(
+ filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ dilation_rate=dilation_rate,
+ activation=activation,
+ use_bias=use_bias,
+ kernel_initializer=kernel_initializer,
+ bias_initializer=bias_initializer,
+ kernel_regularizer=kernel_regularizer,
+ bias_regularizer=bias_regularizer,
+ activity_regularizer=activity_regularizer,
+ kernel_constraint=kernel_constraint,
+ bias_constraint=bias_constraint,
+ trainable=trainable,
+ name=name,
+ **kwargs)
+ self.use_weight_standardization = use_weight_standardization
+
+ def call(self, inputs):
+ if self.use_weight_standardization:
+ mean, var = tf.nn.moments(self.kernel, [0, 1, 2], keep_dims=True)
+ kernel = (self.kernel - mean) / tf.sqrt(var + 1e-5)
+ outputs = self._convolution_op(inputs, kernel)
+ else:
+ outputs = self._convolution_op(inputs, self.kernel)
+
+ if self.use_bias:
+ if self.data_format == 'channels_first':
+ if self.rank == 1:
+ # tf.nn.bias_add does not accept a 1D input tensor.
+ bias = tf.reshape(self.bias, (1, self.filters, 1))
+ outputs += bias
+ else:
+ outputs = tf.nn.bias_add(outputs, self.bias, data_format='NCHW')
+ else:
+ outputs = tf.nn.bias_add(outputs, self.bias, data_format='NHWC')
+
+ if self.activation is not None:
+ return self.activation(outputs)
+ return outputs
+
+
+@contrib_framework.add_arg_scope
+def conv2d(inputs,
+ num_outputs,
+ kernel_size,
+ stride=1,
+ padding='SAME',
+ data_format=None,
+ rate=1,
+ activation_fn=tf.nn.relu,
+ normalizer_fn=None,
+ normalizer_params=None,
+ weights_initializer=contrib_layers.xavier_initializer(),
+ weights_regularizer=None,
+ biases_initializer=tf.zeros_initializer(),
+ biases_regularizer=None,
+ use_weight_standardization=False,
+ reuse=None,
+ variables_collections=None,
+ outputs_collections=None,
+ trainable=True,
+ scope=None):
+ """Adds a 2D convolution followed by an optional batch_norm layer.
+
+ `convolution` creates a variable called `weights`, representing the
+ convolutional kernel, that is convolved (actually cross-correlated) with the
+ `inputs` to produce a `Tensor` of activations. If a `normalizer_fn` is
+ provided (such as `batch_norm`), it is then applied. Otherwise, if
+ `normalizer_fn` is None and a `biases_initializer` is provided then a `biases`
+  variable would be created and added to the activations. Finally, if
+ `activation_fn` is not `None`, it is applied to the activations as well.
+
+ Performs atrous convolution with input stride/dilation rate equal to `rate`
+ if a value > 1 for any dimension of `rate` is specified. In this case
+ `stride` values != 1 are not supported.
+
+ Args:
+ inputs: A Tensor of rank N+2 of shape `[batch_size] + input_spatial_shape +
+ [in_channels]` if data_format does not start with "NC" (default), or
+ `[batch_size, in_channels] + input_spatial_shape` if data_format starts
+ with "NC".
+ num_outputs: Integer, the number of output filters.
+ kernel_size: A sequence of N positive integers specifying the spatial
+ dimensions of the filters. Can be a single integer to specify the same
+ value for all spatial dimensions.
+ stride: A sequence of N positive integers specifying the stride at which to
+ compute output. Can be a single integer to specify the same value for all
+ spatial dimensions. Specifying any `stride` value != 1 is incompatible
+ with specifying any `rate` value != 1.
+ padding: One of `"VALID"` or `"SAME"`.
+ data_format: A string or None. Specifies whether the channel dimension of
+ the `input` and output is the last dimension (default, or if `data_format`
+ does not start with "NC"), or the second dimension (if `data_format`
+ starts with "NC"). For N=1, the valid values are "NWC" (default) and
+ "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". For
+ N=3, the valid values are "NDHWC" (default) and "NCDHW".
+ rate: A sequence of N positive integers specifying the dilation rate to use
+ for atrous convolution. Can be a single integer to specify the same value
+ for all spatial dimensions. Specifying any `rate` value != 1 is
+ incompatible with specifying any `stride` value != 1.
+ activation_fn: Activation function. The default value is a ReLU function.
+ Explicitly set it to None to skip it and maintain a linear activation.
+ normalizer_fn: Normalization function to use instead of `biases`. If
+ `normalizer_fn` is provided then `biases_initializer` and
+ `biases_regularizer` are ignored and `biases` are not created nor added.
+      Defaults to None, i.e., no normalizer function.
+ normalizer_params: Normalization function parameters.
+ weights_initializer: An initializer for the weights.
+ weights_regularizer: Optional regularizer for the weights.
+ biases_initializer: An initializer for the biases. If None skip biases.
+ biases_regularizer: Optional regularizer for the biases.
+ use_weight_standardization: Boolean, whether the layer uses weight
+ standardization.
+ reuse: Whether or not the layer and its variables should be reused. To be
+ able to reuse the layer scope must be given.
+ variables_collections: Optional list of collections for all the variables or
+      a dictionary containing a different list of collections per variable.
+ outputs_collections: Collection to add the outputs.
+ trainable: If `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
+ scope: Optional scope for `variable_scope`.
+
+ Returns:
+ A tensor representing the output of the operation.
+
+ Raises:
+ ValueError: If `data_format` is invalid.
+    ValueError: If both `rate` and `stride` are not uniformly 1.
+ """
+ if data_format not in [None, 'NWC', 'NCW', 'NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
+ raise ValueError('Invalid data_format: %r' % (data_format,))
+
+ # pylint: disable=protected-access
+ layer_variable_getter = layers._build_variable_getter({
+ 'bias': 'biases',
+ 'kernel': 'weights'
+ })
+ # pylint: enable=protected-access
+ with tf.variable_scope(
+ scope, 'Conv', [inputs], reuse=reuse,
+ custom_getter=layer_variable_getter) as sc:
+ inputs = tf.convert_to_tensor(inputs)
+ input_rank = inputs.get_shape().ndims
+
+ if input_rank != 4:
+ raise ValueError('Convolution expects input with rank %d, got %d' %
+ (4, input_rank))
+
+ data_format = ('channels_first' if data_format and
+ data_format.startswith('NC') else 'channels_last')
+ layer = Conv2D(
+ filters=num_outputs,
+ kernel_size=kernel_size,
+ strides=stride,
+ padding=padding,
+ data_format=data_format,
+ dilation_rate=rate,
+ activation=None,
+ use_bias=not normalizer_fn and biases_initializer,
+ kernel_initializer=weights_initializer,
+ bias_initializer=biases_initializer,
+ kernel_regularizer=weights_regularizer,
+ bias_regularizer=biases_regularizer,
+ use_weight_standardization=use_weight_standardization,
+ activity_regularizer=None,
+ trainable=trainable,
+ name=sc.name,
+ dtype=inputs.dtype.base_dtype,
+ _scope=sc,
+ _reuse=reuse)
+ outputs = layer.apply(inputs)
+
+ # Add variables to collections.
+ # pylint: disable=protected-access
+ layers._add_variable_to_collections(layer.kernel, variables_collections,
+ 'weights')
+ if layer.use_bias:
+ layers._add_variable_to_collections(layer.bias, variables_collections,
+ 'biases')
+ # pylint: enable=protected-access
+ if normalizer_fn is not None:
+ normalizer_params = normalizer_params or {}
+ outputs = normalizer_fn(outputs, **normalizer_params)
+
+ if activation_fn is not None:
+ outputs = activation_fn(outputs)
+ return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
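+
+
+# Illustrative usage sketch (a hypothetical helper, not part of the original
+# file): how the wrapper above is typically called. When `rate` != 1, `stride`
+# must stay at its default of 1, mirroring the constraint in the docstring.
+def _example_atrous_conv2d(images):
+  """Hypothetical helper: 3x3 atrous conv, rate 2, with weight standardization."""
+  return conv2d(images, 32, [3, 3], rate=2, padding='SAME',
+                use_weight_standardization=True)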
+
+
+def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None):
+ """Strided 2-D convolution with 'SAME' padding.
+
+ When stride > 1, then we do explicit zero-padding, followed by conv2d with
+ 'VALID' padding.
+
+ Note that
+
+ net = conv2d_same(inputs, num_outputs, 3, stride=stride)
+
+ is equivalent to
+
+ net = conv2d(inputs, num_outputs, 3, stride=1, padding='SAME')
+ net = subsample(net, factor=stride)
+
+ whereas
+
+ net = conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME')
+
+ is different when the input's height or width is even, which is why we add the
+ current function. For more details, see ResnetUtilsTest.testConv2DSameEven().
+
+ Args:
+ inputs: A 4-D tensor of size [batch, height_in, width_in, channels].
+ num_outputs: An integer, the number of output filters.
+ kernel_size: An integer, the kernel size of the filters.
+ stride: An integer, the output stride.
+ rate: An integer, rate for atrous convolution.
+ scope: Scope.
+
+ Returns:
+ output: A 4-D tensor of size [batch, height_out, width_out, channels] with
+ the convolution output.
+ """
+ if stride == 1:
+ return conv2d(
+ inputs,
+ num_outputs,
+ kernel_size,
+ stride=1,
+ rate=rate,
+ padding='SAME',
+ scope=scope)
+ else:
+ kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
+ pad_total = kernel_size_effective - 1
+ pad_beg = pad_total // 2
+ pad_end = pad_total - pad_beg
+ inputs = tf.pad(inputs,
+ [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
+ return conv2d(
+ inputs,
+ num_outputs,
+ kernel_size,
+ stride=stride,
+ rate=rate,
+ padding='VALID',
+ scope=scope)
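+
+
+# Worked example of the explicit-padding arithmetic above (a hypothetical
+# helper, shown only for illustration): for kernel_size=3 and rate=2 the
+# effective kernel is 5, so two rows/columns of zeros are padded on each side
+# before the 'VALID' convolution.
+def _example_same_padding_amounts(kernel_size=3, rate=2):
+  kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)  # 5
+  pad_total = kernel_size_effective - 1                                 # 4
+  pad_beg = pad_total // 2                                              # 2
+  pad_end = pad_total - pad_beg                                         # 2
+  return pad_beg, pad_end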
diff --git a/models/research/deeplab/core/conv2d_ws_test.py b/models/research/deeplab/core/conv2d_ws_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6bea85ee034779ff5bf87b88cd912aa8dc863f2
--- /dev/null
+++ b/models/research/deeplab/core/conv2d_ws_test.py
@@ -0,0 +1,420 @@
+# Lint as: python2, python3
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for conv2d_ws."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.contrib import framework as contrib_framework
+from tensorflow.contrib import layers as contrib_layers
+from deeplab.core import conv2d_ws
+
+
+class ConvolutionTest(tf.test.TestCase):
+
+ def testInvalidShape(self):
+ with self.cached_session():
+ images_3d = tf.random_uniform((5, 6, 7, 9, 3), seed=1)
+ with self.assertRaisesRegexp(
+ ValueError, 'Convolution expects input with rank 4, got 5'):
+ conv2d_ws.conv2d(images_3d, 32, 3)
+
+ def testInvalidDataFormat(self):
+ height, width = 7, 9
+ with self.cached_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ with self.assertRaisesRegexp(ValueError, 'data_format'):
+ conv2d_ws.conv2d(images, 32, 3, data_format='CHWN')
+
+ def testCreateConv(self):
+ height, width = 7, 9
+ with self.cached_session():
+ images = np.random.uniform(size=(5, height, width, 4)).astype(np.float32)
+ output = conv2d_ws.conv2d(images, 32, [3, 3])
+ self.assertEqual(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
+ weights = contrib_framework.get_variables_by_name('weights')[0]
+ self.assertListEqual(weights.get_shape().as_list(), [3, 3, 4, 32])
+ biases = contrib_framework.get_variables_by_name('biases')[0]
+ self.assertListEqual(biases.get_shape().as_list(), [32])
+
+ def testCreateConvWithWS(self):
+ height, width = 7, 9
+ with self.cached_session():
+ images = np.random.uniform(size=(5, height, width, 4)).astype(np.float32)
+ output = conv2d_ws.conv2d(
+ images, 32, [3, 3], use_weight_standardization=True)
+ self.assertEqual(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
+ weights = contrib_framework.get_variables_by_name('weights')[0]
+ self.assertListEqual(weights.get_shape().as_list(), [3, 3, 4, 32])
+ biases = contrib_framework.get_variables_by_name('biases')[0]
+ self.assertListEqual(biases.get_shape().as_list(), [32])
+
+ def testCreateConvNCHW(self):
+ height, width = 7, 9
+ with self.cached_session():
+ images = np.random.uniform(size=(5, 4, height, width)).astype(np.float32)
+ output = conv2d_ws.conv2d(images, 32, [3, 3], data_format='NCHW')
+ self.assertEqual(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(), [5, 32, height, width])
+ weights = contrib_framework.get_variables_by_name('weights')[0]
+ self.assertListEqual(weights.get_shape().as_list(), [3, 3, 4, 32])
+ biases = contrib_framework.get_variables_by_name('biases')[0]
+ self.assertListEqual(biases.get_shape().as_list(), [32])
+
+ def testCreateSquareConv(self):
+ height, width = 7, 9
+ with self.cached_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = conv2d_ws.conv2d(images, 32, 3)
+ self.assertEqual(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
+
+ def testCreateConvWithTensorShape(self):
+ height, width = 7, 9
+ with self.cached_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = conv2d_ws.conv2d(images, 32, images.get_shape()[1:3])
+ self.assertEqual(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
+
+ def testCreateFullyConv(self):
+ height, width = 7, 9
+ with self.cached_session():
+ images = tf.random_uniform((5, height, width, 32), seed=1)
+ output = conv2d_ws.conv2d(
+ images, 64, images.get_shape()[1:3], padding='VALID')
+ self.assertEqual(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 64])
+ biases = contrib_framework.get_variables_by_name('biases')[0]
+ self.assertListEqual(biases.get_shape().as_list(), [64])
+
+ def testFullyConvWithCustomGetter(self):
+ height, width = 7, 9
+ with self.cached_session():
+ called = [0]
+
+ def custom_getter(getter, *args, **kwargs):
+ called[0] += 1
+ return getter(*args, **kwargs)
+
+ with tf.variable_scope('test', custom_getter=custom_getter):
+ images = tf.random_uniform((5, height, width, 32), seed=1)
+ conv2d_ws.conv2d(images, 64, images.get_shape()[1:3])
+ self.assertEqual(called[0], 2) # Custom getter called twice.
+
+ def testCreateVerticalConv(self):
+ height, width = 7, 9
+ with self.cached_session():
+ images = tf.random_uniform((5, height, width, 4), seed=1)
+ output = conv2d_ws.conv2d(images, 32, [3, 1])
+ self.assertEqual(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
+ weights = contrib_framework.get_variables_by_name('weights')[0]
+ self.assertListEqual(weights.get_shape().as_list(), [3, 1, 4, 32])
+ biases = contrib_framework.get_variables_by_name('biases')[0]
+ self.assertListEqual(biases.get_shape().as_list(), [32])
+
+ def testCreateHorizontalConv(self):
+ height, width = 7, 9
+ with self.cached_session():
+ images = tf.random_uniform((5, height, width, 4), seed=1)
+ output = conv2d_ws.conv2d(images, 32, [1, 3])
+ self.assertEqual(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
+ weights = contrib_framework.get_variables_by_name('weights')[0]
+ self.assertListEqual(weights.get_shape().as_list(), [1, 3, 4, 32])
+
+ def testCreateConvWithStride(self):
+ height, width = 6, 8
+ with self.cached_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = conv2d_ws.conv2d(images, 32, [3, 3], stride=2)
+ self.assertEqual(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, height / 2, width / 2, 32])
+
+ def testCreateConvCreatesWeightsAndBiasesVars(self):
+ height, width = 7, 9
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ with self.cached_session():
+ self.assertFalse(contrib_framework.get_variables('conv1/weights'))
+ self.assertFalse(contrib_framework.get_variables('conv1/biases'))
+ conv2d_ws.conv2d(images, 32, [3, 3], scope='conv1')
+ self.assertTrue(contrib_framework.get_variables('conv1/weights'))
+ self.assertTrue(contrib_framework.get_variables('conv1/biases'))
+
+ def testCreateConvWithScope(self):
+ height, width = 7, 9
+ with self.cached_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = conv2d_ws.conv2d(images, 32, [3, 3], scope='conv1')
+ self.assertEqual(output.op.name, 'conv1/Relu')
+
+ def testCreateConvWithCollection(self):
+ height, width = 7, 9
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ with tf.name_scope('fe'):
+ conv = conv2d_ws.conv2d(
+ images, 32, [3, 3], outputs_collections='outputs', scope='Conv')
+ output_collected = tf.get_collection('outputs')[0]
+ self.assertEqual(output_collected.aliases, ['Conv'])
+ self.assertEqual(output_collected, conv)
+
+ def testCreateConvWithoutActivation(self):
+ height, width = 7, 9
+ with self.cached_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = conv2d_ws.conv2d(images, 32, [3, 3], activation_fn=None)
+ self.assertEqual(output.op.name, 'Conv/BiasAdd')
+
+ def testCreateConvValid(self):
+ height, width = 7, 9
+ with self.cached_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = conv2d_ws.conv2d(images, 32, [3, 3], padding='VALID')
+ self.assertListEqual(output.get_shape().as_list(), [5, 5, 7, 32])
+
+ def testCreateConvWithWD(self):
+ height, width = 7, 9
+ weight_decay = 0.01
+ with self.cached_session() as sess:
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ regularizer = contrib_layers.l2_regularizer(weight_decay)
+ conv2d_ws.conv2d(images, 32, [3, 3], weights_regularizer=regularizer)
+ l2_loss = tf.nn.l2_loss(
+ contrib_framework.get_variables_by_name('weights')[0])
+ wd = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
+ self.assertEqual(wd.op.name, 'Conv/kernel/Regularizer/l2_regularizer')
+ sess.run(tf.global_variables_initializer())
+ self.assertAlmostEqual(sess.run(wd), weight_decay * l2_loss.eval())
+
+ def testCreateConvNoRegularizers(self):
+ height, width = 7, 9
+ with self.cached_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ conv2d_ws.conv2d(images, 32, [3, 3])
+ self.assertEqual(
+ tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES), [])
+
+ def testReuseVars(self):
+ height, width = 7, 9
+ with self.cached_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ conv2d_ws.conv2d(images, 32, [3, 3], scope='conv1')
+ self.assertEqual(len(contrib_framework.get_variables()), 2)
+ conv2d_ws.conv2d(images, 32, [3, 3], scope='conv1', reuse=True)
+ self.assertEqual(len(contrib_framework.get_variables()), 2)
+
+ def testNonReuseVars(self):
+ height, width = 7, 9
+ with self.cached_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ conv2d_ws.conv2d(images, 32, [3, 3])
+ self.assertEqual(len(contrib_framework.get_variables()), 2)
+ conv2d_ws.conv2d(images, 32, [3, 3])
+ self.assertEqual(len(contrib_framework.get_variables()), 4)
+
+ def testReuseConvWithWD(self):
+ height, width = 7, 9
+ with self.cached_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ weight_decay = contrib_layers.l2_regularizer(0.01)
+ with contrib_framework.arg_scope([conv2d_ws.conv2d],
+ weights_regularizer=weight_decay):
+ conv2d_ws.conv2d(images, 32, [3, 3], scope='conv1')
+ self.assertEqual(len(contrib_framework.get_variables()), 2)
+ self.assertEqual(
+ len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
+ conv2d_ws.conv2d(images, 32, [3, 3], scope='conv1', reuse=True)
+ self.assertEqual(len(contrib_framework.get_variables()), 2)
+ self.assertEqual(
+ len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
+
+ def testConvWithBatchNorm(self):
+ height, width = 7, 9
+ with self.cached_session():
+ images = tf.random_uniform((5, height, width, 32), seed=1)
+ with contrib_framework.arg_scope([conv2d_ws.conv2d],
+ normalizer_fn=contrib_layers.batch_norm,
+ normalizer_params={'decay': 0.9}):
+ net = conv2d_ws.conv2d(images, 32, [3, 3])
+ net = conv2d_ws.conv2d(net, 32, [3, 3])
+ self.assertEqual(len(contrib_framework.get_variables()), 8)
+ self.assertEqual(
+ len(contrib_framework.get_variables('Conv/BatchNorm')), 3)
+ self.assertEqual(
+ len(contrib_framework.get_variables('Conv_1/BatchNorm')), 3)
+
+ def testReuseConvWithBatchNorm(self):
+ height, width = 7, 9
+ with self.cached_session():
+ images = tf.random_uniform((5, height, width, 32), seed=1)
+ with contrib_framework.arg_scope([conv2d_ws.conv2d],
+ normalizer_fn=contrib_layers.batch_norm,
+ normalizer_params={'decay': 0.9}):
+ net = conv2d_ws.conv2d(images, 32, [3, 3], scope='Conv')
+ net = conv2d_ws.conv2d(net, 32, [3, 3], scope='Conv', reuse=True)
+ self.assertEqual(len(contrib_framework.get_variables()), 4)
+ self.assertEqual(
+ len(contrib_framework.get_variables('Conv/BatchNorm')), 3)
+ self.assertEqual(
+ len(contrib_framework.get_variables('Conv_1/BatchNorm')), 0)
+
+ def testCreateConvCreatesWeightsAndBiasesVarsWithRateTwo(self):
+ height, width = 7, 9
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ with self.cached_session():
+ self.assertFalse(contrib_framework.get_variables('conv1/weights'))
+ self.assertFalse(contrib_framework.get_variables('conv1/biases'))
+ conv2d_ws.conv2d(images, 32, [3, 3], rate=2, scope='conv1')
+ self.assertTrue(contrib_framework.get_variables('conv1/weights'))
+ self.assertTrue(contrib_framework.get_variables('conv1/biases'))
+
+ def testOutputSizeWithRateTwoSamePadding(self):
+ num_filters = 32
+ input_size = [5, 10, 12, 3]
+ expected_size = [5, 10, 12, num_filters]
+
+ images = tf.random_uniform(input_size, seed=1)
+ output = conv2d_ws.conv2d(
+ images, num_filters, [3, 3], rate=2, padding='SAME')
+ self.assertListEqual(list(output.get_shape().as_list()), expected_size)
+ with self.cached_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ self.assertEqual(output.op.name, 'Conv/Relu')
+ self.assertListEqual(list(output.eval().shape), expected_size)
+
+ def testOutputSizeWithRateTwoValidPadding(self):
+ num_filters = 32
+ input_size = [5, 10, 12, 3]
+ expected_size = [5, 6, 8, num_filters]
+
+ images = tf.random_uniform(input_size, seed=1)
+ output = conv2d_ws.conv2d(
+ images, num_filters, [3, 3], rate=2, padding='VALID')
+ self.assertListEqual(list(output.get_shape().as_list()), expected_size)
+ with self.cached_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ self.assertEqual(output.op.name, 'Conv/Relu')
+ self.assertListEqual(list(output.eval().shape), expected_size)
+
+ def testOutputSizeWithRateTwoThreeValidPadding(self):
+ num_filters = 32
+ input_size = [5, 10, 12, 3]
+ expected_size = [5, 6, 6, num_filters]
+
+ images = tf.random_uniform(input_size, seed=1)
+ output = conv2d_ws.conv2d(
+ images, num_filters, [3, 3], rate=[2, 3], padding='VALID')
+ self.assertListEqual(list(output.get_shape().as_list()), expected_size)
+ with self.cached_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ self.assertEqual(output.op.name, 'Conv/Relu')
+ self.assertListEqual(list(output.eval().shape), expected_size)
+
+ def testDynamicOutputSizeWithRateOneValidPadding(self):
+ num_filters = 32
+ input_size = [5, 9, 11, 3]
+ expected_size = [None, None, None, num_filters]
+ expected_size_dynamic = [5, 7, 9, num_filters]
+
+ with self.cached_session():
+ images = tf.placeholder(np.float32, [None, None, None, input_size[3]])
+ output = conv2d_ws.conv2d(
+ images, num_filters, [3, 3], rate=1, padding='VALID')
+ tf.global_variables_initializer().run()
+ self.assertEqual(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(), expected_size)
+ eval_output = output.eval({images: np.zeros(input_size, np.float32)})
+ self.assertListEqual(list(eval_output.shape), expected_size_dynamic)
+
+ def testDynamicOutputSizeWithRateOneValidPaddingNCHW(self):
+ if tf.test.is_gpu_available(cuda_only=True):
+ num_filters = 32
+ input_size = [5, 3, 9, 11]
+ expected_size = [None, num_filters, None, None]
+ expected_size_dynamic = [5, num_filters, 7, 9]
+
+ with self.session(use_gpu=True):
+ images = tf.placeholder(np.float32, [None, input_size[1], None, None])
+ output = conv2d_ws.conv2d(
+ images,
+ num_filters, [3, 3],
+ rate=1,
+ padding='VALID',
+ data_format='NCHW')
+ tf.global_variables_initializer().run()
+ self.assertEqual(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(), expected_size)
+ eval_output = output.eval({images: np.zeros(input_size, np.float32)})
+ self.assertListEqual(list(eval_output.shape), expected_size_dynamic)
+
+ def testDynamicOutputSizeWithRateTwoValidPadding(self):
+ num_filters = 32
+ input_size = [5, 9, 11, 3]
+ expected_size = [None, None, None, num_filters]
+ expected_size_dynamic = [5, 5, 7, num_filters]
+
+ with self.cached_session():
+ images = tf.placeholder(np.float32, [None, None, None, input_size[3]])
+ output = conv2d_ws.conv2d(
+ images, num_filters, [3, 3], rate=2, padding='VALID')
+ tf.global_variables_initializer().run()
+ self.assertEqual(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(), expected_size)
+ eval_output = output.eval({images: np.zeros(input_size, np.float32)})
+ self.assertListEqual(list(eval_output.shape), expected_size_dynamic)
+
+ def testWithScope(self):
+ num_filters = 32
+ input_size = [5, 9, 11, 3]
+ expected_size = [5, 5, 7, num_filters]
+
+ images = tf.random_uniform(input_size, seed=1)
+ output = conv2d_ws.conv2d(
+ images, num_filters, [3, 3], rate=2, padding='VALID', scope='conv7')
+ with self.cached_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ self.assertEqual(output.op.name, 'conv7/Relu')
+ self.assertListEqual(list(output.eval().shape), expected_size)
+
+ def testWithScopeWithoutActivation(self):
+ num_filters = 32
+ input_size = [5, 9, 11, 3]
+ expected_size = [5, 5, 7, num_filters]
+
+ images = tf.random_uniform(input_size, seed=1)
+ output = conv2d_ws.conv2d(
+ images,
+ num_filters, [3, 3],
+ rate=2,
+ padding='VALID',
+ activation_fn=None,
+ scope='conv7')
+ with self.cached_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ self.assertEqual(output.op.name, 'conv7/BiasAdd')
+ self.assertListEqual(list(output.eval().shape), expected_size)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/deeplab/core/dense_prediction_cell.py b/models/research/deeplab/core/dense_prediction_cell.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e32f8e227f0841d51df780618523a53c5eb4ae3
--- /dev/null
+++ b/models/research/deeplab/core/dense_prediction_cell.py
@@ -0,0 +1,290 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Dense Prediction Cell class that can be evolved in semantic segmentation.
+
+DensePredictionCell is used as a `layer` in semantic segmentation whose
+architecture is determined by the `config`, a dictionary specifying
+the architecture.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.contrib import slim as contrib_slim
+
+from deeplab.core import utils
+
+slim = contrib_slim
+
+# Local constants.
+_META_ARCHITECTURE_SCOPE = 'meta_architecture'
+_CONCAT_PROJECTION_SCOPE = 'concat_projection'
+_OP = 'op'
+_CONV = 'conv'
+_PYRAMID_POOLING = 'pyramid_pooling'
+_KERNEL = 'kernel'
+_RATE = 'rate'
+_GRID_SIZE = 'grid_size'
+_TARGET_SIZE = 'target_size'
+_INPUT = 'input'
+
+
+def dense_prediction_cell_hparams():
+ """DensePredictionCell HParams.
+
+ Returns:
+ A dictionary of hyper-parameters used for dense prediction cell with keys:
+ - reduction_size: Integer, the number of output filters for each operation
+ inside the cell.
+ - dropout_on_concat_features: Boolean, apply dropout on the concatenated
+ features or not.
+ - dropout_on_projection_features: Boolean, apply dropout on the projection
+ features or not.
+ - dropout_keep_prob: Float, when `dropout_on_concat_features' or
+ `dropout_on_projection_features' is True, the `keep_prob` value used
+ in the dropout operation.
+ - concat_channels: Integer, the concatenated features will be
+ channel-reduced to `concat_channels` channels.
+ - conv_rate_multiplier: Integer, used to multiply the convolution rates.
+ This is useful when the output_stride is changed from 16 to 8, in which
+ case we need to double the convolution rates correspondingly.
+ """
+ return {
+ 'reduction_size': 256,
+ 'dropout_on_concat_features': True,
+ 'dropout_on_projection_features': False,
+ 'dropout_keep_prob': 0.9,
+ 'concat_channels': 256,
+ 'conv_rate_multiplier': 1,
+ }
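+
+
+# Illustrative sketch (hypothetical helper): the defaults above are merged
+# with user-provided overrides via dict.update, which is exactly what
+# DensePredictionCell.__init__ does below.
+def _example_override_hparams():
+  hparams = dense_prediction_cell_hparams()
+  hparams.update({'conv_rate_multiplier': 2, 'dropout_keep_prob': 0.8})
+  return hparams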
+
+
+class DensePredictionCell(object):
+ """DensePredictionCell class used as a 'layer' in semantic segmentation."""
+
+ def __init__(self, config, hparams=None):
+ """Initializes the dense prediction cell.
+
+ Args:
+ config: A dictionary storing the architecture of a dense prediction cell.
+ hparams: A dictionary of hyper-parameters, provided by users. This
+ dictionary will be used to update the default dictionary returned by
+ dense_prediction_cell_hparams().
+
+ Raises:
+ ValueError: If `conv_rate_multiplier` has value < 1.
+ """
+ self.hparams = dense_prediction_cell_hparams()
+ if hparams is not None:
+ self.hparams.update(hparams)
+ self.config = config
+
+ # Check values in hparams are valid or not.
+ if self.hparams['conv_rate_multiplier'] < 1:
+ raise ValueError('conv_rate_multiplier cannot have value < 1.')
+
+ def _get_pyramid_pooling_arguments(
+ self, crop_size, output_stride, image_grid, image_pooling_crop_size=None):
+ """Gets arguments for pyramid pooling.
+
+ Args:
+ crop_size: A list of two integers, [crop_height, crop_width] specifying
+ whole patch crop size.
+ output_stride: Integer, output stride value for extracted features.
+ image_grid: A list of two integers, [image_grid_height, image_grid_width],
+ specifying the grid size of how the pyramid pooling will be performed.
+ image_pooling_crop_size: A list of two integers, [crop_height, crop_width]
+ specifying the crop size for image pooling operations. Note that we
+ decouple whole patch crop_size and image_pooling_crop_size as one could
+ perform the image_pooling with different crop sizes.
+
+ Returns:
+ A list of (resize_value, pooled_kernel)
+ """
+ resize_height = utils.scale_dimension(crop_size[0], 1. / output_stride)
+ resize_width = utils.scale_dimension(crop_size[1], 1. / output_stride)
+ # If image_pooling_crop_size is not specified, use crop_size.
+ if image_pooling_crop_size is None:
+ image_pooling_crop_size = crop_size
+ pooled_height = utils.scale_dimension(
+ image_pooling_crop_size[0], 1. / (output_stride * image_grid[0]))
+ pooled_width = utils.scale_dimension(
+ image_pooling_crop_size[1], 1. / (output_stride * image_grid[1]))
+ return ([resize_height, resize_width], [pooled_height, pooled_width])
+
+ def _parse_operation(self, config, crop_size, output_stride,
+ image_pooling_crop_size=None):
+ """Parses one operation.
+
+ When 'operation' is 'pyramid_pooling', we compute the required
+ hyper-parameters and save them in config.
+
+ Args:
+ config: A dictionary storing required hyper-parameters for one
+ operation.
+ crop_size: A list of two integers, [crop_height, crop_width] specifying
+ whole patch crop size.
+ output_stride: Integer, output stride value for extracted features.
+ image_pooling_crop_size: A list of two integers, [crop_height, crop_width]
+ specifying the crop size for image pooling operations. Note that we
+ decouple whole patch crop_size and image_pooling_crop_size as one could
+ perform the image_pooling with different crop sizes.
+
+ Returns:
+ A dictionary storing the related information for the operation.
+ """
+ if config[_OP] == _PYRAMID_POOLING:
+ (config[_TARGET_SIZE],
+ config[_KERNEL]) = self._get_pyramid_pooling_arguments(
+ crop_size=crop_size,
+ output_stride=output_stride,
+ image_grid=config[_GRID_SIZE],
+ image_pooling_crop_size=image_pooling_crop_size)
+
+ return config
+
+ def build_cell(self,
+ features,
+ output_stride=16,
+ crop_size=None,
+ image_pooling_crop_size=None,
+ weight_decay=0.00004,
+ reuse=None,
+ is_training=False,
+ fine_tune_batch_norm=False,
+ scope=None):
+ """Builds the dense prediction cell based on the config.
+
+ Args:
+ features: Input feature map of size [batch, height, width, channels].
+ output_stride: Int, output stride at which the features were extracted.
+ crop_size: A list [crop_height, crop_width], determining the input
+ features resolution.
+ image_pooling_crop_size: A list of two integers, [crop_height, crop_width]
+ specifying the crop size for image pooling operations. Note that we
+ decouple whole patch crop_size and image_pooling_crop_size as one could
+ perform the image_pooling with different crop sizes.
+ weight_decay: Float, the weight decay for model variables.
+ reuse: Reuse the model variables or not.
+ is_training: Boolean, is training or not.
+ fine_tune_batch_norm: Boolean, fine-tuning batch norm parameters or not.
+ scope: Optional string, specifying the variable scope.
+
+ Returns:
+ Features after passing through the constructed dense prediction cell with
+ shape = [batch, height, width, channels] where channels are determined
+ by `reduction_size` returned by dense_prediction_cell_hparams().
+
+ Raises:
+ ValueError: If a convolution is used with kernel size not equal to 1x1 or
+ 3x3, or the operation is not recognized.
+ """
+ batch_norm_params = {
+ 'is_training': is_training and fine_tune_batch_norm,
+ 'decay': 0.9997,
+ 'epsilon': 1e-5,
+ 'scale': True,
+ }
+ hparams = self.hparams
+ with slim.arg_scope(
+ [slim.conv2d, slim.separable_conv2d],
+ weights_regularizer=slim.l2_regularizer(weight_decay),
+ activation_fn=tf.nn.relu,
+ normalizer_fn=slim.batch_norm,
+ padding='SAME',
+ stride=1,
+ reuse=reuse):
+ with slim.arg_scope([slim.batch_norm], **batch_norm_params):
+ with tf.variable_scope(scope, _META_ARCHITECTURE_SCOPE, [features]):
+ depth = hparams['reduction_size']
+ branch_logits = []
+ for i, current_config in enumerate(self.config):
+ scope = 'branch%d' % i
+ current_config = self._parse_operation(
+ config=current_config,
+ crop_size=crop_size,
+ output_stride=output_stride,
+ image_pooling_crop_size=image_pooling_crop_size)
+ tf.logging.info(current_config)
+ if current_config[_INPUT] < 0:
+ operation_input = features
+ else:
+ operation_input = branch_logits[current_config[_INPUT]]
+ if current_config[_OP] == _CONV:
+ if current_config[_KERNEL] == [1, 1] or current_config[
+ _KERNEL] == 1:
+ branch_logits.append(
+ slim.conv2d(operation_input, depth, 1, scope=scope))
+ else:
+ conv_rate = [r * hparams['conv_rate_multiplier']
+ for r in current_config[_RATE]]
+ branch_logits.append(
+ utils.split_separable_conv2d(
+ operation_input,
+ filters=depth,
+ kernel_size=current_config[_KERNEL],
+ rate=conv_rate,
+ weight_decay=weight_decay,
+ scope=scope))
+ elif current_config[_OP] == _PYRAMID_POOLING:
+ pooled_features = slim.avg_pool2d(
+ operation_input,
+ kernel_size=current_config[_KERNEL],
+ stride=[1, 1],
+ padding='VALID')
+ pooled_features = slim.conv2d(
+ pooled_features,
+ depth,
+ 1,
+ scope=scope)
+ pooled_features = tf.image.resize_bilinear(
+ pooled_features,
+ current_config[_TARGET_SIZE],
+ align_corners=True)
+ # Set shape for resize_height/resize_width if they are not Tensor.
+ resize_height = current_config[_TARGET_SIZE][0]
+ resize_width = current_config[_TARGET_SIZE][1]
+ if isinstance(resize_height, tf.Tensor):
+ resize_height = None
+ if isinstance(resize_width, tf.Tensor):
+ resize_width = None
+ pooled_features.set_shape(
+ [None, resize_height, resize_width, depth])
+ branch_logits.append(pooled_features)
+ else:
+ raise ValueError('Unrecognized operation.')
+ # Merge branch logits.
+ concat_logits = tf.concat(branch_logits, 3)
+ if self.hparams['dropout_on_concat_features']:
+ concat_logits = slim.dropout(
+ concat_logits,
+ keep_prob=self.hparams['dropout_keep_prob'],
+ is_training=is_training,
+ scope=_CONCAT_PROJECTION_SCOPE + '_dropout')
+ concat_logits = slim.conv2d(concat_logits,
+ self.hparams['concat_channels'],
+ 1,
+ scope=_CONCAT_PROJECTION_SCOPE)
+ if self.hparams['dropout_on_projection_features']:
+ concat_logits = slim.dropout(
+ concat_logits,
+ keep_prob=self.hparams['dropout_keep_prob'],
+ is_training=is_training,
+ scope=_CONCAT_PROJECTION_SCOPE + '_dropout')
+ return concat_logits
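+
+
+# Illustrative sketch (hypothetical helper, assuming a TF1-style graph): a
+# two-branch cell where branch 0 is a 1x1 conv over the input features and
+# branch 1 is a 3x3 atrous conv fed by branch 0; the rates are doubled by the
+# conv_rate_multiplier override.
+def _example_build_dense_prediction_cell():
+  cell = DensePredictionCell(
+      config=[
+          {_INPUT: -1, _OP: _CONV, _KERNEL: 1},
+          {_INPUT: 0, _OP: _CONV, _KERNEL: 3, _RATE: [1, 3]},
+      ],
+      hparams={'conv_rate_multiplier': 2})
+  features = tf.random_normal([2, 33, 33, 5])
+  return cell.build_cell(features, output_stride=8, crop_size=[257, 257])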
diff --git a/models/research/deeplab/core/dense_prediction_cell_branch5_top1_cityscapes.json b/models/research/deeplab/core/dense_prediction_cell_branch5_top1_cityscapes.json
new file mode 100644
index 0000000000000000000000000000000000000000..12b093d07d1a696258cae7eaf4d793978433a69f
--- /dev/null
+++ b/models/research/deeplab/core/dense_prediction_cell_branch5_top1_cityscapes.json
@@ -0,0 +1 @@
+[{"kernel": 3, "rate": [1, 6], "op": "conv", "input": -1}, {"kernel": 3, "rate": [18, 15], "op": "conv", "input": 0}, {"kernel": 3, "rate": [6, 3], "op": "conv", "input": 1}, {"kernel": 3, "rate": [1, 1], "op": "conv", "input": 0}, {"kernel": 3, "rate": [6, 21], "op": "conv", "input": 0}]
\ No newline at end of file
diff --git a/models/research/deeplab/core/dense_prediction_cell_test.py b/models/research/deeplab/core/dense_prediction_cell_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..1396a73626d90c5da1db3b187c5a927a0c9eeb11
--- /dev/null
+++ b/models/research/deeplab/core/dense_prediction_cell_test.py
@@ -0,0 +1,136 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for dense_prediction_cell."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from deeplab.core import dense_prediction_cell
+
+
+class DensePredictionCellTest(tf.test.TestCase):
+
+ def setUp(self):
+ self.segmentation_layer = dense_prediction_cell.DensePredictionCell(
+ config=[
+ {
+ dense_prediction_cell._INPUT: -1,
+ dense_prediction_cell._OP: dense_prediction_cell._CONV,
+ dense_prediction_cell._KERNEL: 1,
+ },
+ {
+ dense_prediction_cell._INPUT: 0,
+ dense_prediction_cell._OP: dense_prediction_cell._CONV,
+ dense_prediction_cell._KERNEL: 3,
+ dense_prediction_cell._RATE: [1, 3],
+ },
+ {
+ dense_prediction_cell._INPUT: 1,
+ dense_prediction_cell._OP: (
+ dense_prediction_cell._PYRAMID_POOLING),
+ dense_prediction_cell._GRID_SIZE: [1, 2],
+ },
+ ],
+ hparams={'conv_rate_multiplier': 2})
+
+ def testPyramidPoolingArguments(self):
+ features_size, pooled_kernel = (
+ self.segmentation_layer._get_pyramid_pooling_arguments(
+ crop_size=[513, 513],
+ output_stride=16,
+ image_grid=[4, 4]))
+ self.assertListEqual(features_size, [33, 33])
+ self.assertListEqual(pooled_kernel, [9, 9])
+
+ def testPyramidPoolingArgumentsWithImageGrid1x1(self):
+ features_size, pooled_kernel = (
+ self.segmentation_layer._get_pyramid_pooling_arguments(
+ crop_size=[257, 257],
+ output_stride=16,
+ image_grid=[1, 1]))
+ self.assertListEqual(features_size, [17, 17])
+ self.assertListEqual(pooled_kernel, [17, 17])
+
+ def testParseOperationStringWithConv1x1(self):
+ operation = self.segmentation_layer._parse_operation(
+ config={
+ dense_prediction_cell._OP: dense_prediction_cell._CONV,
+ dense_prediction_cell._KERNEL: [1, 1],
+ },
+ crop_size=[513, 513], output_stride=16)
+ self.assertEqual(operation[dense_prediction_cell._OP],
+ dense_prediction_cell._CONV)
+ self.assertListEqual(operation[dense_prediction_cell._KERNEL], [1, 1])
+
+ def testParseOperationStringWithConv3x3(self):
+ operation = self.segmentation_layer._parse_operation(
+ config={
+ dense_prediction_cell._OP: dense_prediction_cell._CONV,
+ dense_prediction_cell._KERNEL: [3, 3],
+ dense_prediction_cell._RATE: [9, 6],
+ },
+ crop_size=[513, 513], output_stride=16)
+ self.assertEqual(operation[dense_prediction_cell._OP],
+ dense_prediction_cell._CONV)
+ self.assertListEqual(operation[dense_prediction_cell._KERNEL], [3, 3])
+ self.assertEqual(operation[dense_prediction_cell._RATE], [9, 6])
+
+ def testParseOperationStringWithPyramidPooling2x2(self):
+ operation = self.segmentation_layer._parse_operation(
+ config={
+ dense_prediction_cell._OP: dense_prediction_cell._PYRAMID_POOLING,
+ dense_prediction_cell._GRID_SIZE: [2, 2],
+ },
+ crop_size=[513, 513],
+ output_stride=16)
+ self.assertEqual(operation[dense_prediction_cell._OP],
+ dense_prediction_cell._PYRAMID_POOLING)
+ # The feature maps of size [33, 33] should be covered by 2x2 kernels with
+ # size [17, 17].
+ self.assertListEqual(
+ operation[dense_prediction_cell._TARGET_SIZE], [33, 33])
+ self.assertListEqual(operation[dense_prediction_cell._KERNEL], [17, 17])
+
+ def testBuildCell(self):
+ with self.test_session(graph=tf.Graph()) as sess:
+ features = tf.random_normal([2, 33, 33, 5])
+ concat_logits = self.segmentation_layer.build_cell(
+ features,
+ output_stride=8,
+ crop_size=[257, 257])
+ sess.run(tf.global_variables_initializer())
+ concat_logits = sess.run(concat_logits)
+ self.assertTrue(concat_logits.any())
+
+ def testBuildCellWithImagePoolingCropSize(self):
+ with self.test_session(graph=tf.Graph()) as sess:
+ features = tf.random_normal([2, 33, 33, 5])
+ concat_logits = self.segmentation_layer.build_cell(
+ features,
+ output_stride=8,
+ crop_size=[257, 257],
+ image_pooling_crop_size=[129, 129])
+ sess.run(tf.global_variables_initializer())
+ concat_logits = sess.run(concat_logits)
+ self.assertTrue(concat_logits.any())
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/deeplab/core/feature_extractor.py b/models/research/deeplab/core/feature_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..553bd9b6a7393dd7f3e0ebce80302919215a9bfe
--- /dev/null
+++ b/models/research/deeplab/core/feature_extractor.py
@@ -0,0 +1,711 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Extracts features for different models."""
+import copy
+import functools
+
+import tensorflow.compat.v1 as tf
+from tensorflow.contrib import slim as contrib_slim
+
+from deeplab.core import nas_network
+from deeplab.core import resnet_v1_beta
+from deeplab.core import xception
+from nets.mobilenet import conv_blocks
+from nets.mobilenet import mobilenet
+from nets.mobilenet import mobilenet_v2
+from nets.mobilenet import mobilenet_v3
+
+slim = contrib_slim
+
+# Default end point for MobileNetv2 (one-based indexing).
+_MOBILENET_V2_FINAL_ENDPOINT = 'layer_18'
+# Default end point for MobileNetv3.
+_MOBILENET_V3_LARGE_FINAL_ENDPOINT = 'layer_17'
+_MOBILENET_V3_SMALL_FINAL_ENDPOINT = 'layer_13'
+# Default end point for EdgeTPU Mobilenet.
+_MOBILENET_EDGETPU = 'layer_24'
+
+
+def _mobilenet_v2(net,
+ depth_multiplier,
+ output_stride,
+ conv_defs=None,
+ divisible_by=None,
+ reuse=None,
+ scope=None,
+ final_endpoint=None):
+ """Auxiliary function to add support for 'reuse' to mobilenet_v2.
+
+ Args:
+ net: Input tensor of shape [batch_size, height, width, channels].
+ depth_multiplier: Float multiplier for the depth (number of channels)
+ for all convolution ops. The value must be greater than zero. Typical
+ usage will be to set this value in (0, 1) to reduce the number of
+ parameters or computation cost of the model.
+ output_stride: An integer that specifies the requested ratio of input to
+ output spatial resolution. If not None, then we invoke atrous convolution
+ if necessary to prevent the network from reducing the spatial resolution
+ of the activation maps. Allowed values are 8 (accurate fully convolutional
+ mode), 16 (fast fully convolutional mode), 32 (classification mode).
+ conv_defs: MobileNet conv defs specifying the net architecture.
+ divisible_by: None (use default setting) or an integer that ensures the
+ number of channels in all layers will be divisible by this number. Used
+ in MobileNet.
+ reuse: Reuse model variables.
+ scope: Optional variable scope.
+ final_endpoint: The endpoint to construct the network up to.
+
+ Returns:
+ Features extracted by MobileNetv2.
+ """
+ if divisible_by is None:
+ divisible_by = 8 if depth_multiplier == 1.0 else 1
+ if conv_defs is None:
+ conv_defs = mobilenet_v2.V2_DEF
+ with tf.variable_scope(
+ scope, 'MobilenetV2', [net], reuse=reuse) as scope:
+ return mobilenet_v2.mobilenet_base(
+ net,
+ conv_defs=conv_defs,
+ depth_multiplier=depth_multiplier,
+ min_depth=8 if depth_multiplier == 1.0 else 1,
+ divisible_by=divisible_by,
+ final_endpoint=final_endpoint or _MOBILENET_V2_FINAL_ENDPOINT,
+ output_stride=output_stride,
+ scope=scope)
+
+
+def _mobilenet_v3(net,
+ depth_multiplier,
+ output_stride,
+ conv_defs=None,
+ divisible_by=None,
+ reuse=None,
+ scope=None,
+ final_endpoint=None):
+ """Auxiliary function to build mobilenet v3.
+
+ Args:
+ net: Input tensor of shape [batch_size, height, width, channels].
+ depth_multiplier: Float multiplier for the depth (number of channels)
+ for all convolution ops. The value must be greater than zero. Typical
+ usage will be to set this value in (0, 1) to reduce the number of
+ parameters or computation cost of the model.
+ output_stride: An integer that specifies the requested ratio of input to
+ output spatial resolution. If not None, then we invoke atrous convolution
+ if necessary to prevent the network from reducing the spatial resolution
+ of the activation maps. Allowed values are 8 (accurate fully convolutional
+ mode), 16 (fast fully convolutional mode), 32 (classification mode).
+ conv_defs: A list of ConvDef namedtuples specifying the net architecture.
+ divisible_by: None (use default setting) or an integer that ensures the
+ number of channels in all layers will be divisible by this number. Used
+ in MobileNet.
+ reuse: Reuse model variables.
+ scope: Optional variable scope.
+ final_endpoint: The endpoint to construct the network up to.
+
+ Returns:
+ net: The output tensor.
+ end_points: A set of activations for external use.
+
+ Raises:
+ ValueError: If conv_defs or final_endpoint is not specified.
+ """
+ del divisible_by
+ with tf.variable_scope(
+ scope, 'MobilenetV3', [net], reuse=reuse) as scope:
+ if conv_defs is None:
+ raise ValueError('conv_defs must be specified for mobilenet v3.')
+ if final_endpoint is None:
+ raise ValueError('Final endpoint must be specified for mobilenet v3.')
+ net, end_points = mobilenet_v3.mobilenet_base(
+ net,
+ depth_multiplier=depth_multiplier,
+ conv_defs=conv_defs,
+ output_stride=output_stride,
+ final_endpoint=final_endpoint,
+ scope=scope)
+
+ return net, end_points
+
+
+def mobilenet_v3_large_seg(net,
+ depth_multiplier,
+ output_stride,
+ divisible_by=None,
+ reuse=None,
+ scope=None,
+ final_endpoint=None):
+ """Final mobilenet v3 large model for segmentation task."""
+ del divisible_by
+ del final_endpoint
+ conv_defs = copy.deepcopy(mobilenet_v3.V3_LARGE)
+
+ # Reduce the filters by a factor of 2 in the last block.
+ for layer, expansion in [(13, 336), (14, 480), (15, 480), (16, None)]:
+ conv_defs['spec'][layer].params['num_outputs'] /= 2
+ # Update expansion size
+ if expansion is not None:
+ factor = expansion / conv_defs['spec'][layer - 1].params['num_outputs']
+ conv_defs['spec'][layer].params[
+ 'expansion_size'] = mobilenet_v3.expand_input(factor)
+
+ return _mobilenet_v3(
+ net,
+ depth_multiplier=depth_multiplier,
+ output_stride=output_stride,
+ divisible_by=8,
+ conv_defs=conv_defs,
+ reuse=reuse,
+ scope=scope,
+ final_endpoint=_MOBILENET_V3_LARGE_FINAL_ENDPOINT)
+
+
+def mobilenet_edgetpu(net,
+ depth_multiplier,
+ output_stride,
+ divisible_by=None,
+ reuse=None,
+ scope=None,
+ final_endpoint=None):
+ """EdgeTPU version of mobilenet model for segmentation task."""
+ del divisible_by
+ del final_endpoint
+ conv_defs = copy.deepcopy(mobilenet_v3.V3_EDGETPU)
+
+ return _mobilenet_v3(
+ net,
+ depth_multiplier=depth_multiplier,
+ output_stride=output_stride,
+ divisible_by=8,
+ conv_defs=conv_defs,
+ reuse=reuse,
+ scope=scope, # the scope is 'MobilenetEdgeTPU'
+ final_endpoint=_MOBILENET_EDGETPU)
+
+
+def mobilenet_v3_small_seg(net,
+ depth_multiplier,
+ output_stride,
+ divisible_by=None,
+ reuse=None,
+ scope=None,
+ final_endpoint=None):
+ """Final mobilenet v3 small model for segmentation task."""
+ del divisible_by
+ del final_endpoint
+ conv_defs = copy.deepcopy(mobilenet_v3.V3_SMALL)
+
+ # Reduce the filters by a factor of 2 in the last block.
+ for layer, expansion in [(9, 144), (10, 288), (11, 288), (12, None)]:
+ conv_defs['spec'][layer].params['num_outputs'] /= 2
+ # Update expansion size
+ if expansion is not None:
+ factor = expansion / conv_defs['spec'][layer - 1].params['num_outputs']
+ conv_defs['spec'][layer].params[
+ 'expansion_size'] = mobilenet_v3.expand_input(factor)
+
+ return _mobilenet_v3(
+ net,
+ depth_multiplier=depth_multiplier,
+ output_stride=output_stride,
+ divisible_by=8,
+ conv_defs=conv_defs,
+ reuse=reuse,
+ scope=scope,
+ final_endpoint=_MOBILENET_V3_SMALL_FINAL_ENDPOINT)
+
+
+# A map from network name to network function.
+networks_map = {
+ 'mobilenet_v2': _mobilenet_v2,
+ 'mobilenet_edgetpu': mobilenet_edgetpu,
+ 'mobilenet_v3_large_seg': mobilenet_v3_large_seg,
+ 'mobilenet_v3_small_seg': mobilenet_v3_small_seg,
+ 'resnet_v1_18': resnet_v1_beta.resnet_v1_18,
+ 'resnet_v1_18_beta': resnet_v1_beta.resnet_v1_18_beta,
+ 'resnet_v1_50': resnet_v1_beta.resnet_v1_50,
+ 'resnet_v1_50_beta': resnet_v1_beta.resnet_v1_50_beta,
+ 'resnet_v1_101': resnet_v1_beta.resnet_v1_101,
+ 'resnet_v1_101_beta': resnet_v1_beta.resnet_v1_101_beta,
+ 'xception_41': xception.xception_41,
+ 'xception_65': xception.xception_65,
+ 'xception_71': xception.xception_71,
+ 'nas_pnasnet': nas_network.pnasnet,
+ 'nas_hnasnet': nas_network.hnasnet,
+}
+
+
+def mobilenet_v2_arg_scope(is_training=True,
+ weight_decay=0.00004,
+ stddev=0.09,
+ activation=tf.nn.relu6,
+ bn_decay=0.997,
+ bn_epsilon=None,
+ bn_renorm=None):
+ """Defines the default MobilenetV2 arg scope.
+
+ Args:
+ is_training: Whether or not we're training the model. If this is set to
+ None, the is_training parameter in batch_norm is not set. Please note that
+ this also sets the is_training parameter in dropout to None.
+ weight_decay: The weight decay to use for regularizing the model.
+ stddev: Standard deviation for initialization, if negative uses xavier.
+ activation: Activation function used in the conv and separable conv
+ layers. Defaults to tf.nn.relu6.
+ bn_decay: Decay for the batch norm moving averages.
+ bn_epsilon: Batch normalization epsilon.
+ bn_renorm: Whether to use batch norm renormalization.
+
+ Returns:
+ An `arg_scope` to use for the mobilenet v2 model.
+ """
+ batch_norm_params = {
+ 'center': True,
+ 'scale': True,
+ 'decay': bn_decay,
+ }
+ if bn_epsilon is not None:
+ batch_norm_params['epsilon'] = bn_epsilon
+ if is_training is not None:
+ batch_norm_params['is_training'] = is_training
+ if bn_renorm is not None:
+ batch_norm_params['renorm'] = bn_renorm
+ dropout_params = {}
+ if is_training is not None:
+ dropout_params['is_training'] = is_training
+
+ instance_norm_params = {
+ 'center': True,
+ 'scale': True,
+ 'epsilon': 0.001,
+ }
+
+ if stddev < 0:
+ weight_initializer = slim.initializers.xavier_initializer()
+ else:
+ weight_initializer = tf.truncated_normal_initializer(stddev=stddev)
+
+ # Set weight_decay for weights in Conv and FC layers.
+ with slim.arg_scope(
+ [slim.conv2d, slim.fully_connected, slim.separable_conv2d],
+ weights_initializer=weight_initializer,
+ activation_fn=activation,
+ normalizer_fn=slim.batch_norm), \
+ slim.arg_scope(
+ [conv_blocks.expanded_conv], normalizer_fn=slim.batch_norm), \
+ slim.arg_scope([mobilenet.apply_activation], activation_fn=activation),\
+ slim.arg_scope([slim.batch_norm], **batch_norm_params), \
+ slim.arg_scope([mobilenet.mobilenet_base, mobilenet.mobilenet],
+ is_training=is_training),\
+ slim.arg_scope([slim.dropout], **dropout_params), \
+ slim.arg_scope([slim.instance_norm], **instance_norm_params), \
+ slim.arg_scope([slim.conv2d], \
+ weights_regularizer=slim.l2_regularizer(weight_decay)), \
+ slim.arg_scope([slim.separable_conv2d], weights_regularizer=None), \
+ slim.arg_scope([slim.conv2d, slim.separable_conv2d], padding='SAME') as s:
+ return s
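+
+
+# Illustrative sketch (hypothetical helper): the arg scope above is entered
+# before building a backbone so every conv/separable-conv layer picks up the
+# shared batch-norm, initializer and regularizer settings. A MobileNet v2
+# backbone is used here only to keep the sketch short; arg_scopes_map below
+# pairs this scope with the EdgeTPU and v3 variants.
+def _example_use_mobilenet_v2_arg_scope(images):
+  with slim.arg_scope(mobilenet_v2_arg_scope(is_training=False)):
+    return _mobilenet_v2(images, depth_multiplier=1.0, output_stride=16)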
+
+
+# A map from network name to network arg scope.
+arg_scopes_map = {
+ 'mobilenet_v2': mobilenet_v2.training_scope,
+ 'mobilenet_edgetpu': mobilenet_v2_arg_scope,
+ 'mobilenet_v3_large_seg': mobilenet_v2_arg_scope,
+ 'mobilenet_v3_small_seg': mobilenet_v2_arg_scope,
+ 'resnet_v1_18': resnet_v1_beta.resnet_arg_scope,
+ 'resnet_v1_18_beta': resnet_v1_beta.resnet_arg_scope,
+ 'resnet_v1_50': resnet_v1_beta.resnet_arg_scope,
+ 'resnet_v1_50_beta': resnet_v1_beta.resnet_arg_scope,
+ 'resnet_v1_101': resnet_v1_beta.resnet_arg_scope,
+ 'resnet_v1_101_beta': resnet_v1_beta.resnet_arg_scope,
+ 'xception_41': xception.xception_arg_scope,
+ 'xception_65': xception.xception_arg_scope,
+ 'xception_71': xception.xception_arg_scope,
+ 'nas_pnasnet': nas_network.nas_arg_scope,
+ 'nas_hnasnet': nas_network.nas_arg_scope,
+}
+
+# Names for end point features.
+DECODER_END_POINTS = 'decoder_end_points'
+
+# A dictionary from network name to a map of end point features.
+networks_to_feature_maps = {
+ 'mobilenet_v2': {
+ DECODER_END_POINTS: {
+ 4: ['layer_4/depthwise_output'],
+ 8: ['layer_7/depthwise_output'],
+ 16: ['layer_14/depthwise_output'],
+ },
+ },
+ 'mobilenet_v3_large_seg': {
+ DECODER_END_POINTS: {
+ 4: ['layer_4/depthwise_output'],
+ 8: ['layer_7/depthwise_output'],
+ 16: ['layer_13/depthwise_output'],
+ },
+ },
+ 'mobilenet_v3_small_seg': {
+ DECODER_END_POINTS: {
+ 4: ['layer_2/depthwise_output'],
+ 8: ['layer_4/depthwise_output'],
+ 16: ['layer_9/depthwise_output'],
+ },
+ },
+ 'resnet_v1_18': {
+ DECODER_END_POINTS: {
+ 4: ['block1/unit_1/lite_bottleneck_v1/conv2'],
+ 8: ['block2/unit_1/lite_bottleneck_v1/conv2'],
+ 16: ['block3/unit_1/lite_bottleneck_v1/conv2'],
+ },
+ },
+ 'resnet_v1_18_beta': {
+ DECODER_END_POINTS: {
+ 4: ['block1/unit_1/lite_bottleneck_v1/conv2'],
+ 8: ['block2/unit_1/lite_bottleneck_v1/conv2'],
+ 16: ['block3/unit_1/lite_bottleneck_v1/conv2'],
+ },
+ },
+ 'resnet_v1_50': {
+ DECODER_END_POINTS: {
+ 4: ['block1/unit_2/bottleneck_v1/conv3'],
+ 8: ['block2/unit_3/bottleneck_v1/conv3'],
+ 16: ['block3/unit_5/bottleneck_v1/conv3'],
+ },
+ },
+ 'resnet_v1_50_beta': {
+ DECODER_END_POINTS: {
+ 4: ['block1/unit_2/bottleneck_v1/conv3'],
+ 8: ['block2/unit_3/bottleneck_v1/conv3'],
+ 16: ['block3/unit_5/bottleneck_v1/conv3'],
+ },
+ },
+ 'resnet_v1_101': {
+ DECODER_END_POINTS: {
+ 4: ['block1/unit_2/bottleneck_v1/conv3'],
+ 8: ['block2/unit_3/bottleneck_v1/conv3'],
+ 16: ['block3/unit_22/bottleneck_v1/conv3'],
+ },
+ },
+ 'resnet_v1_101_beta': {
+ DECODER_END_POINTS: {
+ 4: ['block1/unit_2/bottleneck_v1/conv3'],
+ 8: ['block2/unit_3/bottleneck_v1/conv3'],
+ 16: ['block3/unit_22/bottleneck_v1/conv3'],
+ },
+ },
+ 'xception_41': {
+ DECODER_END_POINTS: {
+ 4: ['entry_flow/block2/unit_1/xception_module/'
+ 'separable_conv2_pointwise'],
+ 8: ['entry_flow/block3/unit_1/xception_module/'
+ 'separable_conv2_pointwise'],
+ 16: ['exit_flow/block1/unit_1/xception_module/'
+ 'separable_conv2_pointwise'],
+ },
+ },
+ 'xception_65': {
+ DECODER_END_POINTS: {
+ 4: ['entry_flow/block2/unit_1/xception_module/'
+ 'separable_conv2_pointwise'],
+ 8: ['entry_flow/block3/unit_1/xception_module/'
+ 'separable_conv2_pointwise'],
+ 16: ['exit_flow/block1/unit_1/xception_module/'
+ 'separable_conv2_pointwise'],
+ },
+ },
+ 'xception_71': {
+ DECODER_END_POINTS: {
+ 4: ['entry_flow/block3/unit_1/xception_module/'
+ 'separable_conv2_pointwise'],
+ 8: ['entry_flow/block5/unit_1/xception_module/'
+ 'separable_conv2_pointwise'],
+ 16: ['exit_flow/block1/unit_1/xception_module/'
+ 'separable_conv2_pointwise'],
+ },
+ },
+ 'nas_pnasnet': {
+ DECODER_END_POINTS: {
+ 4: ['Stem'],
+ 8: ['Cell_3'],
+ 16: ['Cell_7'],
+ },
+ },
+ 'nas_hnasnet': {
+ DECODER_END_POINTS: {
+ 4: ['Cell_2'],
+ 8: ['Cell_5'],
+ 16: ['Cell_7'],
+ },
+ },
+}
+
+# A map from feature extractor name to the network name scope used in the
+# ImageNet pretrained versions of these models.
+name_scope = {
+ 'mobilenet_v2': 'MobilenetV2',
+ 'mobilenet_edgetpu': 'MobilenetEdgeTPU',
+ 'mobilenet_v3_large_seg': 'MobilenetV3',
+ 'mobilenet_v3_small_seg': 'MobilenetV3',
+ 'resnet_v1_18': 'resnet_v1_18',
+ 'resnet_v1_18_beta': 'resnet_v1_18',
+ 'resnet_v1_50': 'resnet_v1_50',
+ 'resnet_v1_50_beta': 'resnet_v1_50',
+ 'resnet_v1_101': 'resnet_v1_101',
+ 'resnet_v1_101_beta': 'resnet_v1_101',
+ 'xception_41': 'xception_41',
+ 'xception_65': 'xception_65',
+ 'xception_71': 'xception_71',
+ 'nas_pnasnet': 'pnasnet',
+ 'nas_hnasnet': 'hnasnet',
+}
+
+# Mean pixel value.
+_MEAN_RGB = [123.15, 115.90, 103.06]
+
+
+def _preprocess_subtract_imagenet_mean(inputs, dtype=tf.float32):
+ """Subtract Imagenet mean RGB value."""
+ mean_rgb = tf.reshape(_MEAN_RGB, [1, 1, 1, 3])
+ num_channels = tf.shape(inputs)[-1]
+ # We set mean pixel as 0 for the non-RGB channels.
+ mean_rgb_extended = tf.concat(
+ [mean_rgb, tf.zeros([1, 1, 1, num_channels - 3])], axis=3)
+ return tf.cast(inputs - mean_rgb_extended, dtype=dtype)
+
+
+def _preprocess_zero_mean_unit_range(inputs, dtype=tf.float32):
+ """Map image values from [0, 255] to [-1, 1]."""
+ preprocessed_inputs = (2.0 / 255.0) * tf.to_float(inputs) - 1.0
+ return tf.cast(preprocessed_inputs, dtype=dtype)
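+
+
+# Worked example (hypothetical helper): the linear map above sends
+# 0 -> -1.0, 127.5 -> 0.0 and 255 -> 1.0, so zero padding applied after
+# preprocessing corresponds to mid-grey input pixels.
+def _example_zero_mean_unit_range_endpoints():
+  # Approximately [-1.0, 0.0, 1.0].
+  return [(2.0 / 255.0) * v - 1.0 for v in (0.0, 127.5, 255.0)]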
+
+
+_PREPROCESS_FN = {
+ 'mobilenet_v2': _preprocess_zero_mean_unit_range,
+ 'mobilenet_edgetpu': _preprocess_zero_mean_unit_range,
+ 'mobilenet_v3_large_seg': _preprocess_zero_mean_unit_range,
+ 'mobilenet_v3_small_seg': _preprocess_zero_mean_unit_range,
+ 'resnet_v1_18': _preprocess_subtract_imagenet_mean,
+ 'resnet_v1_18_beta': _preprocess_zero_mean_unit_range,
+ 'resnet_v1_50': _preprocess_subtract_imagenet_mean,
+ 'resnet_v1_50_beta': _preprocess_zero_mean_unit_range,
+ 'resnet_v1_101': _preprocess_subtract_imagenet_mean,
+ 'resnet_v1_101_beta': _preprocess_zero_mean_unit_range,
+ 'xception_41': _preprocess_zero_mean_unit_range,
+ 'xception_65': _preprocess_zero_mean_unit_range,
+ 'xception_71': _preprocess_zero_mean_unit_range,
+ 'nas_pnasnet': _preprocess_zero_mean_unit_range,
+ 'nas_hnasnet': _preprocess_zero_mean_unit_range,
+}
+
+
+def mean_pixel(model_variant=None):
+ """Gets mean pixel value.
+
+ This function returns a different mean pixel value, depending on the input
+ model_variant, which adopts different preprocessing functions. We currently
+ handle the following preprocessing functions:
+ (1) _preprocess_subtract_imagenet_mean. We simply return the mean pixel value.
+ (2) _preprocess_zero_mean_unit_range. We return [127.5, 127.5, 127.5].
+ The return values are used in such a way that the padded regions after
+ pre-processing will contain value 0.
+
+ Args:
+ model_variant: Model variant (string) for feature extraction. For
+ backwards compatibility, model_variant=None returns _MEAN_RGB.
+
+ Returns:
+ Mean pixel value.
+ """
+ if model_variant in ['resnet_v1_50',
+ 'resnet_v1_101'] or model_variant is None:
+ return _MEAN_RGB
+ else:
+ return [127.5, 127.5, 127.5]
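+
+
+# Illustrative sketch (hypothetical helper): mean_pixel picks the value that
+# matches the preprocessing function registered for the variant - the ImageNet
+# RGB mean for the vanilla ResNets, and mid-grey otherwise.
+def _example_mean_pixel_values():
+  return {
+      'resnet_v1_50': mean_pixel('resnet_v1_50'),  # _MEAN_RGB
+      'xception_65': mean_pixel('xception_65'),    # [127.5, 127.5, 127.5]
+  }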
+
+
+def extract_features(images,
+ output_stride=8,
+ multi_grid=None,
+ depth_multiplier=1.0,
+ divisible_by=None,
+ final_endpoint=None,
+ model_variant=None,
+ weight_decay=0.0001,
+ reuse=None,
+ is_training=False,
+ fine_tune_batch_norm=False,
+ regularize_depthwise=False,
+ preprocess_images=True,
+ preprocessed_images_dtype=tf.float32,
+ num_classes=None,
+ global_pool=False,
+ nas_architecture_options=None,
+ nas_training_hyper_parameters=None,
+ use_bounded_activation=False):
+ """Extracts features by the particular model_variant.
+
+ Args:
+ images: A tensor of size [batch, height, width, channels].
+ output_stride: The ratio of input to output spatial resolution.
+ multi_grid: Employ a hierarchy of different atrous rates within network.
+ depth_multiplier: Float multiplier for the depth (number of channels)
+ for all convolution ops used in MobileNet.
+ divisible_by: None (use default setting) or an integer that ensures the
+ number of channels in all layers will be divisible by this number. Used
+ in MobileNet.
+ final_endpoint: The MobileNet endpoint to construct the network up to.
+ model_variant: Model variant for feature extraction.
+ weight_decay: The weight decay for model variables.
+ reuse: Reuse the model variables or not.
+ is_training: Is training or not.
+ fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
+ regularize_depthwise: Whether or not apply L2-norm regularization on the
+ depthwise convolution weights.
+ preprocess_images: Performs preprocessing on images or not. Defaults to
+ True. Set to False if preprocessing will be done by other functions. We
+ support two types of preprocessing: (1) mean pixel subtraction and (2)
+ pixel value normalization to [-1, 1].
+ preprocessed_images_dtype: The type after the preprocessing function.
+ num_classes: Number of classes for image classification task. Defaults
+ to None for dense prediction tasks.
+ global_pool: Global pooling for image classification task. Defaults to
+ False, since dense prediction tasks do not use this.
+ nas_architecture_options: A dictionary storing NAS architecture options.
+      It is either None or its keys are:
+ - `nas_stem_output_num_conv_filters`: Number of filters of the NAS stem
+ output tensor.
+ - `nas_use_classification_head`: Boolean, use image classification head.
+ nas_training_hyper_parameters: A dictionary storing hyper-parameters for
+ training nas models. It is either None or its keys are:
+ - `drop_path_keep_prob`: Probability to keep each path in the cell when
+ training.
+ - `total_training_steps`: Total training steps to help drop path
+ probability calculation.
+ use_bounded_activation: Whether or not to use bounded activations. Bounded
+ activations better lend themselves to quantized inference. Currently,
+      bounded activation is only used in the xception model.
+
+ Returns:
+ features: A tensor of size [batch, feature_height, feature_width,
+ feature_channels], where feature_height/feature_width are determined
+ by the images height/width and output_stride.
+ end_points: A dictionary from components of the network to the corresponding
+ activation.
+
+ Raises:
+ ValueError: Unrecognized model variant.
+ """
+ if 'resnet' in model_variant:
+ arg_scope = arg_scopes_map[model_variant](
+ weight_decay=weight_decay,
+ batch_norm_decay=0.95,
+ batch_norm_epsilon=1e-5,
+ batch_norm_scale=True)
+ features, end_points = get_network(
+ model_variant, preprocess_images, preprocessed_images_dtype, arg_scope)(
+ inputs=images,
+ num_classes=num_classes,
+ is_training=(is_training and fine_tune_batch_norm),
+ global_pool=global_pool,
+ output_stride=output_stride,
+ multi_grid=multi_grid,
+ reuse=reuse,
+ scope=name_scope[model_variant])
+ elif 'xception' in model_variant:
+ arg_scope = arg_scopes_map[model_variant](
+ weight_decay=weight_decay,
+ batch_norm_decay=0.9997,
+ batch_norm_epsilon=1e-3,
+ batch_norm_scale=True,
+ regularize_depthwise=regularize_depthwise,
+ use_bounded_activation=use_bounded_activation)
+ features, end_points = get_network(
+ model_variant, preprocess_images, preprocessed_images_dtype, arg_scope)(
+ inputs=images,
+ num_classes=num_classes,
+ is_training=(is_training and fine_tune_batch_norm),
+ global_pool=global_pool,
+ output_stride=output_stride,
+ regularize_depthwise=regularize_depthwise,
+ multi_grid=multi_grid,
+ reuse=reuse,
+ scope=name_scope[model_variant])
+ elif 'mobilenet' in model_variant or model_variant.startswith('mnas'):
+ arg_scope = arg_scopes_map[model_variant](
+ is_training=(is_training and fine_tune_batch_norm),
+ weight_decay=weight_decay)
+ features, end_points = get_network(
+ model_variant, preprocess_images, preprocessed_images_dtype, arg_scope)(
+ inputs=images,
+ depth_multiplier=depth_multiplier,
+ divisible_by=divisible_by,
+ output_stride=output_stride,
+ reuse=reuse,
+ scope=name_scope[model_variant],
+ final_endpoint=final_endpoint)
+ elif model_variant.startswith('nas'):
+ arg_scope = arg_scopes_map[model_variant](
+ weight_decay=weight_decay,
+ batch_norm_decay=0.9997,
+ batch_norm_epsilon=1e-3)
+ features, end_points = get_network(
+ model_variant, preprocess_images, preprocessed_images_dtype, arg_scope)(
+ inputs=images,
+ num_classes=num_classes,
+ is_training=(is_training and fine_tune_batch_norm),
+ global_pool=global_pool,
+ output_stride=output_stride,
+ nas_architecture_options=nas_architecture_options,
+ nas_training_hyper_parameters=nas_training_hyper_parameters,
+ reuse=reuse,
+ scope=name_scope[model_variant])
+ else:
+ raise ValueError('Unknown model variant %s.' % model_variant)
+
+ return features, end_points
+
+
+def get_network(network_name, preprocess_images,
+ preprocessed_images_dtype=tf.float32, arg_scope=None):
+ """Gets the network.
+
+ Args:
+ network_name: Network name.
+ preprocess_images: Preprocesses the images or not.
+ preprocessed_images_dtype: The type after the preprocessing function.
+ arg_scope: Optional, arg_scope to build the network. If not provided the
+ default arg_scope of the network would be used.
+
+ Returns:
+ A network function that is used to extract features.
+
+ Raises:
+ ValueError: network is not supported.
+ """
+ if network_name not in networks_map:
+ raise ValueError('Unsupported network %s.' % network_name)
+ arg_scope = arg_scope or arg_scopes_map[network_name]()
+ def _identity_function(inputs, dtype=preprocessed_images_dtype):
+ return tf.cast(inputs, dtype=dtype)
+ if preprocess_images:
+ preprocess_function = _PREPROCESS_FN[network_name]
+ else:
+ preprocess_function = _identity_function
+ func = networks_map[network_name]
+ @functools.wraps(func)
+ def network_fn(inputs, *args, **kwargs):
+ with slim.arg_scope(arg_scope):
+ return func(preprocess_function(inputs, preprocessed_images_dtype),
+ *args, **kwargs)
+ return network_fn
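A minimal usage sketch for the extract_features/get_network pair defined above (assumptions: a TF 1.x graph-mode environment and that this file is importable as deeplab.core.feature_extractor; the 513 input size is only illustrative):

    import tensorflow as tf
    from deeplab.core import feature_extractor

    images = tf.placeholder(tf.float32, shape=[None, 513, 513, 3])
    # Dispatches to the xception_65 builder, applies its registered
    # preprocessing function, and returns dense features plus endpoints.
    features, end_points = feature_extractor.extract_features(
        images,
        output_stride=16,
        model_variant='xception_65',
        is_training=False)
    print(features)  # roughly [None, 33, 33, feature_channels] for a 513 input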
diff --git a/models/research/deeplab/core/nas_cell.py b/models/research/deeplab/core/nas_cell.py
new file mode 100644
index 0000000000000000000000000000000000000000..d179082dc72b6692e96289ef9ed6964165023c33
--- /dev/null
+++ b/models/research/deeplab/core/nas_cell.py
@@ -0,0 +1,221 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Cell structure used by NAS."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+from six.moves import range
+from six.moves import zip
+import tensorflow as tf
+from tensorflow.contrib import framework as contrib_framework
+from tensorflow.contrib import slim as contrib_slim
+from deeplab.core import xception as xception_utils
+from deeplab.core.utils import resize_bilinear
+from deeplab.core.utils import scale_dimension
+from tensorflow.contrib.slim.nets import resnet_utils
+
+arg_scope = contrib_framework.arg_scope
+slim = contrib_slim
+
+separable_conv2d_same = functools.partial(xception_utils.separable_conv2d_same,
+ regularize_depthwise=True)
+
+
+class NASBaseCell(object):
+ """NASNet Cell class that is used as a 'layer' in image architectures."""
+
+ def __init__(self, num_conv_filters, operations, used_hiddenstates,
+ hiddenstate_indices, drop_path_keep_prob, total_num_cells,
+ total_training_steps, batch_norm_fn=slim.batch_norm):
+ """Init function.
+
+ For more details about NAS cell, see
+ https://arxiv.org/abs/1707.07012 and https://arxiv.org/abs/1712.00559.
+
+ Args:
+ num_conv_filters: The number of filters for each convolution operation.
+ operations: List of operations that are performed in the NASNet Cell in
+ order.
+ used_hiddenstates: Binary array that signals if the hiddenstate was used
+ within the cell. This is used to determine what outputs of the cell
+ should be concatenated together.
+ hiddenstate_indices: Determines what hiddenstates should be combined
+ together with the specified operations to create the NASNet cell.
+ drop_path_keep_prob: Float, drop path keep probability.
+ total_num_cells: Integer, total number of cells.
+ total_training_steps: Integer, total training steps.
+ batch_norm_fn: Function, batch norm function. Defaults to
+ slim.batch_norm.
+ """
+ if len(hiddenstate_indices) != len(operations):
+ raise ValueError(
+ 'Number of hiddenstate_indices and operations should be the same.')
+ if len(operations) % 2:
+ raise ValueError('Number of operations should be even.')
+ self._num_conv_filters = num_conv_filters
+ self._operations = operations
+ self._used_hiddenstates = used_hiddenstates
+ self._hiddenstate_indices = hiddenstate_indices
+ self._drop_path_keep_prob = drop_path_keep_prob
+ self._total_num_cells = total_num_cells
+ self._total_training_steps = total_training_steps
+ self._batch_norm_fn = batch_norm_fn
+
+ def __call__(self, net, scope, filter_scaling, stride, prev_layer, cell_num):
+ """Runs the conv cell."""
+ self._cell_num = cell_num
+ self._filter_scaling = filter_scaling
+ self._filter_size = int(self._num_conv_filters * filter_scaling)
+
+ with tf.variable_scope(scope):
+ net = self._cell_base(net, prev_layer)
+ for i in range(len(self._operations) // 2):
+ with tf.variable_scope('comb_iter_{}'.format(i)):
+ h1 = net[self._hiddenstate_indices[i * 2]]
+ h2 = net[self._hiddenstate_indices[i * 2 + 1]]
+ with tf.variable_scope('left'):
+ h1 = self._apply_conv_operation(
+ h1, self._operations[i * 2], stride,
+ self._hiddenstate_indices[i * 2] < 2)
+ with tf.variable_scope('right'):
+ h2 = self._apply_conv_operation(
+ h2, self._operations[i * 2 + 1], stride,
+ self._hiddenstate_indices[i * 2 + 1] < 2)
+ with tf.variable_scope('combine'):
+ h = h1 + h2
+ net.append(h)
+
+ with tf.variable_scope('cell_output'):
+ net = self._combine_unused_states(net)
+
+ return net
+
+ def _cell_base(self, net, prev_layer):
+ """Runs the beginning of the conv cell before the chosen ops are run."""
+ filter_size = self._filter_size
+
+ if prev_layer is None:
+ prev_layer = net
+ else:
+ if net.shape[2] != prev_layer.shape[2]:
+ prev_layer = resize_bilinear(
+ prev_layer, tf.shape(net)[1:3], prev_layer.dtype)
+ if filter_size != prev_layer.shape[3]:
+ prev_layer = tf.nn.relu(prev_layer)
+ prev_layer = slim.conv2d(prev_layer, filter_size, 1, scope='prev_1x1')
+ prev_layer = self._batch_norm_fn(prev_layer, scope='prev_bn')
+
+ net = tf.nn.relu(net)
+ net = slim.conv2d(net, filter_size, 1, scope='1x1')
+ net = self._batch_norm_fn(net, scope='beginning_bn')
+ net = tf.split(axis=3, num_or_size_splits=1, value=net)
+ net.append(prev_layer)
+ return net
+
+ def _apply_conv_operation(self, net, operation, stride,
+ is_from_original_input):
+ """Applies the predicted conv operation to net."""
+ if stride > 1 and not is_from_original_input:
+ stride = 1
+ input_filters = net.shape[3]
+ filter_size = self._filter_size
+ if 'separable' in operation:
+ num_layers = int(operation.split('_')[-1])
+ kernel_size = int(operation.split('x')[0][-1])
+ for layer_num in range(num_layers):
+ net = tf.nn.relu(net)
+ net = separable_conv2d_same(
+ net,
+ filter_size,
+ kernel_size,
+ depth_multiplier=1,
+ scope='separable_{0}x{0}_{1}'.format(kernel_size, layer_num + 1),
+ stride=stride)
+ net = self._batch_norm_fn(
+ net, scope='bn_sep_{0}x{0}_{1}'.format(kernel_size, layer_num + 1))
+ stride = 1
+ elif 'atrous' in operation:
+ kernel_size = int(operation.split('x')[0][-1])
+ net = tf.nn.relu(net)
+ if stride == 2:
+ scaled_height = scale_dimension(tf.shape(net)[1], 0.5)
+ scaled_width = scale_dimension(tf.shape(net)[2], 0.5)
+ net = resize_bilinear(net, [scaled_height, scaled_width], net.dtype)
+ net = resnet_utils.conv2d_same(
+ net, filter_size, kernel_size, rate=1, stride=1,
+ scope='atrous_{0}x{0}'.format(kernel_size))
+ else:
+ net = resnet_utils.conv2d_same(
+ net, filter_size, kernel_size, rate=2, stride=1,
+ scope='atrous_{0}x{0}'.format(kernel_size))
+ net = self._batch_norm_fn(net, scope='bn_atr_{0}x{0}'.format(kernel_size))
+ elif operation in ['none']:
+ if stride > 1 or (input_filters != filter_size):
+ net = tf.nn.relu(net)
+ net = slim.conv2d(net, filter_size, 1, stride=stride, scope='1x1')
+ net = self._batch_norm_fn(net, scope='bn_1')
+ elif 'pool' in operation:
+ pooling_type = operation.split('_')[0]
+ pooling_shape = int(operation.split('_')[-1].split('x')[0])
+ if pooling_type == 'avg':
+ net = slim.avg_pool2d(net, pooling_shape, stride=stride, padding='SAME')
+ elif pooling_type == 'max':
+ net = slim.max_pool2d(net, pooling_shape, stride=stride, padding='SAME')
+ else:
+ raise ValueError('Unimplemented pooling type: ', pooling_type)
+ if input_filters != filter_size:
+ net = slim.conv2d(net, filter_size, 1, stride=1, scope='1x1')
+ net = self._batch_norm_fn(net, scope='bn_1')
+ else:
+ raise ValueError('Unimplemented operation', operation)
+
+ if operation != 'none':
+ net = self._apply_drop_path(net)
+ return net
+
+ def _combine_unused_states(self, net):
+ """Concatenates the unused hidden states of the cell."""
+ used_hiddenstates = self._used_hiddenstates
+ states_to_combine = ([
+ h for h, is_used in zip(net, used_hiddenstates) if not is_used])
+ net = tf.concat(values=states_to_combine, axis=3)
+ return net
+
+ @contrib_framework.add_arg_scope
+ def _apply_drop_path(self, net):
+ """Apply drop_path regularization."""
+ drop_path_keep_prob = self._drop_path_keep_prob
+ if drop_path_keep_prob < 1.0:
+ # Scale keep prob by layer number.
+ assert self._cell_num != -1
+ layer_ratio = (self._cell_num + 1) / float(self._total_num_cells)
+ drop_path_keep_prob = 1 - layer_ratio * (1 - drop_path_keep_prob)
+ # Decrease keep prob over time.
+ current_step = tf.cast(tf.train.get_or_create_global_step(), tf.float32)
+ current_ratio = tf.minimum(1.0, current_step / self._total_training_steps)
+ drop_path_keep_prob = (1 - current_ratio * (1 - drop_path_keep_prob))
+ # Drop path.
+ noise_shape = [tf.shape(net)[0], 1, 1, 1]
+ random_tensor = drop_path_keep_prob
+ random_tensor += tf.random_uniform(noise_shape, dtype=tf.float32)
+ binary_tensor = tf.cast(tf.floor(random_tensor), net.dtype)
+ keep_prob_inv = tf.cast(1.0 / drop_path_keep_prob, net.dtype)
+ net = net * keep_prob_inv * binary_tensor
+ return net
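The keep-probability schedule in _apply_drop_path above can be traced with plain Python; the numbers below are illustrative, not taken from the diff:

    base_keep_prob = 0.9          # hypothetical drop_path_keep_prob
    total_num_cells = 12
    cell_num = 3                  # 4th cell of 12
    current_step, total_training_steps = 100000, 500000

    # Scale keep prob by layer number: deeper cells drop more often.
    layer_ratio = (cell_num + 1) / float(total_num_cells)                 # ~0.333
    keep_prob = 1 - layer_ratio * (1 - base_keep_prob)                    # ~0.9667
    # Decrease keep prob over time: early in training almost nothing drops.
    current_ratio = min(1.0, current_step / float(total_training_steps))  # 0.2
    keep_prob = 1 - current_ratio * (1 - keep_prob)                       # ~0.9933
    print(round(keep_prob, 4))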
diff --git a/models/research/deeplab/core/nas_genotypes.py b/models/research/deeplab/core/nas_genotypes.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2e6dd55b450658e10acaa420a6cc31635817a8a
--- /dev/null
+++ b/models/research/deeplab/core/nas_genotypes.py
@@ -0,0 +1,45 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Genotypes used by NAS."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from tensorflow.contrib import slim as contrib_slim
+from deeplab.core import nas_cell
+
+slim = contrib_slim
+
+
+class PNASCell(nas_cell.NASBaseCell):
+ """Configuration and construction of the PNASNet-5 Cell."""
+
+ def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
+ total_training_steps, batch_norm_fn=slim.batch_norm):
+ # Name of operations: op_kernel-size_num-layers.
+ operations = [
+ 'separable_5x5_2', 'max_pool_3x3', 'separable_7x7_2', 'max_pool_3x3',
+ 'separable_5x5_2', 'separable_3x3_2', 'separable_3x3_2', 'max_pool_3x3',
+ 'separable_3x3_2', 'none'
+ ]
+ used_hiddenstates = [1, 1, 0, 0, 0, 0, 0]
+ hiddenstate_indices = [1, 1, 0, 0, 0, 0, 4, 0, 1, 0]
+
+ super(PNASCell, self).__init__(
+ num_conv_filters, operations, used_hiddenstates, hiddenstate_indices,
+ drop_path_keep_prob, total_num_cells, total_training_steps,
+ batch_norm_fn)
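A minimal construction sketch for the cell above (assuming the deeplab package is importable; the arguments mirror NASBaseCell's constructor in nas_cell.py and the values are illustrative):

    from deeplab.core import nas_genotypes

    # The PNASNet-5 genotype (operations, used_hiddenstates,
    # hiddenstate_indices) is fixed by the class; only the scalar
    # hyper-parameters are supplied here.
    cell = nas_genotypes.PNASCell(
        num_conv_filters=20,
        drop_path_keep_prob=1.0,
        total_num_cells=12,
        total_training_steps=500000)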
diff --git a/models/research/deeplab/core/nas_network.py b/models/research/deeplab/core/nas_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..1da2e04dbaa5cfcb7db6f21266daf846000481fd
--- /dev/null
+++ b/models/research/deeplab/core/nas_network.py
@@ -0,0 +1,368 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Network structure used by NAS.
+
+Here we provide a few NAS backbones for semantic segmentation.
+Currently, we have
+
+1. pnasnet
+"Progressive Neural Architecture Search", Chenxi Liu, Barret Zoph,
+Maxim Neumann, Jonathon Shlens, Wei Hua, Li-Jia Li, Li Fei-Fei,
+Alan Yuille, Jonathan Huang, Kevin Murphy. In ECCV, 2018.
+
+2. hnasnet (also called Auto-DeepLab)
+"Auto-DeepLab: Hierarchical Neural Architecture Search for Semantic
+Image Segmentation", Chenxi Liu, Liang-Chieh Chen, Florian Schroff,
+Hartwig Adam, Wei Hua, Alan Yuille, Li Fei-Fei. In CVPR, 2019.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six.moves import range
+import tensorflow as tf
+from tensorflow.contrib import framework as contrib_framework
+from tensorflow.contrib import layers as contrib_layers
+from tensorflow.contrib import slim as contrib_slim
+from tensorflow.contrib import training as contrib_training
+
+from deeplab.core import nas_genotypes
+from deeplab.core import utils
+from deeplab.core.nas_cell import NASBaseCell
+from tensorflow.contrib.slim.nets import resnet_utils
+
+arg_scope = contrib_framework.arg_scope
+slim = contrib_slim
+resize_bilinear = utils.resize_bilinear
+scale_dimension = utils.scale_dimension
+
+
+def config(num_conv_filters=20,
+ total_training_steps=500000,
+ drop_path_keep_prob=1.0):
+ return contrib_training.HParams(
+ # Multiplier when spatial size is reduced by 2.
+ filter_scaling_rate=2.0,
+ # Number of filters of the stem output tensor.
+ num_conv_filters=num_conv_filters,
+ # Probability to keep each path in the cell when training.
+ drop_path_keep_prob=drop_path_keep_prob,
+ # Total training steps to help drop path probability calculation.
+ total_training_steps=total_training_steps,
+ )
+
+
+def nas_arg_scope(weight_decay=4e-5,
+ batch_norm_decay=0.9997,
+ batch_norm_epsilon=0.001,
+ sync_batch_norm_method='None'):
+ """Default arg scope for the NAS models."""
+ batch_norm_params = {
+ # Decay for the moving averages.
+ 'decay': batch_norm_decay,
+ # epsilon to prevent 0s in variance.
+ 'epsilon': batch_norm_epsilon,
+ 'scale': True,
+ }
+ batch_norm = utils.get_batch_norm_fn(sync_batch_norm_method)
+ weights_regularizer = contrib_layers.l2_regularizer(weight_decay)
+ weights_initializer = contrib_layers.variance_scaling_initializer(
+ factor=1 / 3.0, mode='FAN_IN', uniform=True)
+ with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
+ weights_regularizer=weights_regularizer,
+ weights_initializer=weights_initializer):
+ with arg_scope([slim.fully_connected],
+ activation_fn=None, scope='FC'):
+ with arg_scope([slim.conv2d, slim.separable_conv2d],
+ activation_fn=None, biases_initializer=None):
+ with arg_scope([batch_norm], **batch_norm_params) as sc:
+ return sc
+
+
+def _nas_stem(inputs,
+ batch_norm_fn=slim.batch_norm):
+ """Stem used for NAS models."""
+ net = resnet_utils.conv2d_same(inputs, 64, 3, stride=2, scope='conv0')
+ net = batch_norm_fn(net, scope='conv0_bn')
+ net = tf.nn.relu(net)
+ net = resnet_utils.conv2d_same(net, 64, 3, stride=1, scope='conv1')
+ net = batch_norm_fn(net, scope='conv1_bn')
+ cell_outputs = [net]
+ net = tf.nn.relu(net)
+ net = resnet_utils.conv2d_same(net, 128, 3, stride=2, scope='conv2')
+ net = batch_norm_fn(net, scope='conv2_bn')
+ cell_outputs.append(net)
+ return net, cell_outputs
+
+
+def _build_nas_base(images,
+ cell,
+ backbone,
+ num_classes,
+ hparams,
+ global_pool=False,
+ output_stride=16,
+ nas_use_classification_head=False,
+ reuse=None,
+ scope=None,
+ final_endpoint=None,
+ batch_norm_fn=slim.batch_norm,
+ nas_remove_os32_stride=False):
+ """Constructs a NAS model.
+
+ Args:
+ images: A tensor of size [batch, height, width, channels].
+ cell: Cell structure used in the network.
+ backbone: Backbone structure used in the network. A list of integers in
+ which value 0 means "output_stride=4", value 1 means "output_stride=8",
+ value 2 means "output_stride=16", and value 3 means "output_stride=32".
+ num_classes: Number of classes to predict.
+ hparams: Hyperparameters needed to construct the network.
+ global_pool: If True, we perform global average pooling before computing the
+ logits. Set to True for image classification, False for dense prediction.
+    output_stride: Integer, the stride of output feature maps.
+ nas_use_classification_head: Boolean, use image classification head.
+ reuse: Whether or not the network and its variables should be reused. To be
+ able to reuse 'scope' must be given.
+ scope: Optional variable_scope.
+ final_endpoint: The endpoint to construct the network up to.
+ batch_norm_fn: Batch norm function.
+ nas_remove_os32_stride: Boolean, remove stride in output_stride 32 branch.
+
+ Returns:
+ net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
+ end_points: A dictionary from components of the network to the corresponding
+ activation.
+
+ Raises:
+ ValueError: If output_stride is not a multiple of backbone output stride.
+ """
+ with tf.variable_scope(scope, 'nas', [images], reuse=reuse):
+ end_points = {}
+ def add_and_check_endpoint(endpoint_name, net):
+ end_points[endpoint_name] = net
+ return final_endpoint and (endpoint_name == final_endpoint)
+
+ net, cell_outputs = _nas_stem(images,
+ batch_norm_fn=batch_norm_fn)
+ if add_and_check_endpoint('Stem', net):
+ return net, end_points
+
+ # Run the cells
+ filter_scaling = 1.0
+ for cell_num in range(len(backbone)):
+ stride = 1
+ if cell_num == 0:
+ if backbone[0] == 1:
+ stride = 2
+ filter_scaling *= hparams.filter_scaling_rate
+ else:
+ if backbone[cell_num] == backbone[cell_num - 1] + 1:
+ stride = 2
+ if backbone[cell_num] == 3 and nas_remove_os32_stride:
+ stride = 1
+ filter_scaling *= hparams.filter_scaling_rate
+ elif backbone[cell_num] == backbone[cell_num - 1] - 1:
+ if backbone[cell_num - 1] == 3 and nas_remove_os32_stride:
+ # No need to rescale features.
+ pass
+ else:
+ # Scale features by a factor of 2.
+ scaled_height = scale_dimension(net.shape[1].value, 2)
+ scaled_width = scale_dimension(net.shape[2].value, 2)
+ net = resize_bilinear(net, [scaled_height, scaled_width], net.dtype)
+ filter_scaling /= hparams.filter_scaling_rate
+ net = cell(
+ net,
+ scope='cell_{}'.format(cell_num),
+ filter_scaling=filter_scaling,
+ stride=stride,
+ prev_layer=cell_outputs[-2],
+ cell_num=cell_num)
+ if add_and_check_endpoint('Cell_{}'.format(cell_num), net):
+ return net, end_points
+ cell_outputs.append(net)
+ net = tf.nn.relu(net)
+
+ if nas_use_classification_head:
+ # Add image classification head.
+ # We will expand the filters for different output_strides.
+ output_stride_to_expanded_filters = {8: 256, 16: 512, 32: 1024}
+ current_output_scale = 2 + backbone[-1]
+ current_output_stride = 2 ** current_output_scale
+ if output_stride % current_output_stride != 0:
+ raise ValueError(
+ 'output_stride must be a multiple of backbone output stride.')
+ output_stride //= current_output_stride
+ rate = 1
+ if current_output_stride != 32:
+ num_downsampling = 5 - current_output_scale
+ for i in range(num_downsampling):
+        # Gradually downsample feature maps to output stride = 32.
+ target_output_stride = 2 ** (current_output_scale + 1 + i)
+ target_filters = output_stride_to_expanded_filters[
+ target_output_stride]
+ scope = 'downsample_os{}'.format(target_output_stride)
+ if output_stride != 1:
+ stride = 2
+ output_stride //= 2
+ else:
+ stride = 1
+ rate *= 2
+ net = resnet_utils.conv2d_same(
+ net, target_filters, 3, stride=stride, rate=rate,
+ scope=scope + '_conv')
+ net = batch_norm_fn(net, scope=scope + '_bn')
+ add_and_check_endpoint(scope, net)
+ net = tf.nn.relu(net)
+ # Apply 1x1 convolution to expand dimension to 2048.
+ scope = 'classification_head'
+ net = slim.conv2d(net, 2048, 1, scope=scope + '_conv')
+ net = batch_norm_fn(net, scope=scope + '_bn')
+ add_and_check_endpoint(scope, net)
+ net = tf.nn.relu(net)
+ if global_pool:
+ # Global average pooling.
+ net = tf.reduce_mean(net, [1, 2], name='global_pool', keepdims=True)
+ if num_classes is not None:
+ net = slim.conv2d(net, num_classes, 1, activation_fn=None,
+ normalizer_fn=None, scope='logits')
+ end_points['predictions'] = slim.softmax(net, scope='predictions')
+ return net, end_points
+
+
+def pnasnet(images,
+ num_classes,
+ is_training=True,
+ global_pool=False,
+ output_stride=16,
+ nas_architecture_options=None,
+ nas_training_hyper_parameters=None,
+ reuse=None,
+ scope='pnasnet',
+ final_endpoint=None,
+ sync_batch_norm_method='None'):
+ """Builds PNASNet model."""
+ if nas_architecture_options is None:
+ raise ValueError(
+ 'Using NAS model variants. nas_architecture_options cannot be None.')
+ hparams = config(num_conv_filters=nas_architecture_options[
+ 'nas_stem_output_num_conv_filters'])
+ if nas_training_hyper_parameters:
+ hparams.set_hparam('drop_path_keep_prob',
+ nas_training_hyper_parameters['drop_path_keep_prob'])
+ hparams.set_hparam('total_training_steps',
+ nas_training_hyper_parameters['total_training_steps'])
+ if not is_training:
+ tf.logging.info('During inference, setting drop_path_keep_prob = 1.0.')
+ hparams.set_hparam('drop_path_keep_prob', 1.0)
+ tf.logging.info(hparams)
+ if output_stride == 8:
+ backbone = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+ elif output_stride == 16:
+ backbone = [1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2]
+ elif output_stride == 32:
+ backbone = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
+ else:
+ raise ValueError('Unsupported output_stride ', output_stride)
+ batch_norm = utils.get_batch_norm_fn(sync_batch_norm_method)
+ cell = nas_genotypes.PNASCell(hparams.num_conv_filters,
+ hparams.drop_path_keep_prob,
+ len(backbone),
+ hparams.total_training_steps,
+ batch_norm_fn=batch_norm)
+ with arg_scope([slim.dropout, batch_norm], is_training=is_training):
+ return _build_nas_base(
+ images,
+ cell=cell,
+ backbone=backbone,
+ num_classes=num_classes,
+ hparams=hparams,
+ global_pool=global_pool,
+ output_stride=output_stride,
+ nas_use_classification_head=nas_architecture_options[
+ 'nas_use_classification_head'],
+ reuse=reuse,
+ scope=scope,
+ final_endpoint=final_endpoint,
+ batch_norm_fn=batch_norm,
+ nas_remove_os32_stride=nas_architecture_options[
+ 'nas_remove_os32_stride'])
+
+
+# pylint: disable=unused-argument
+def hnasnet(images,
+ num_classes,
+ is_training=True,
+ global_pool=False,
+ output_stride=8,
+ nas_architecture_options=None,
+ nas_training_hyper_parameters=None,
+ reuse=None,
+ scope='hnasnet',
+ final_endpoint=None,
+ sync_batch_norm_method='None'):
+ """Builds hierarchical model."""
+ if nas_architecture_options is None:
+ raise ValueError(
+ 'Using NAS model variants. nas_architecture_options cannot be None.')
+ hparams = config(num_conv_filters=nas_architecture_options[
+ 'nas_stem_output_num_conv_filters'])
+ if nas_training_hyper_parameters:
+ hparams.set_hparam('drop_path_keep_prob',
+ nas_training_hyper_parameters['drop_path_keep_prob'])
+ hparams.set_hparam('total_training_steps',
+ nas_training_hyper_parameters['total_training_steps'])
+ if not is_training:
+ tf.logging.info('During inference, setting drop_path_keep_prob = 1.0.')
+ hparams.set_hparam('drop_path_keep_prob', 1.0)
+ tf.logging.info(hparams)
+ operations = [
+ 'atrous_5x5', 'separable_3x3_2', 'separable_3x3_2', 'atrous_3x3',
+ 'separable_3x3_2', 'separable_3x3_2', 'separable_5x5_2',
+ 'separable_5x5_2', 'separable_5x5_2', 'atrous_5x5'
+ ]
+ used_hiddenstates = [1, 1, 0, 0, 0, 0, 0]
+ hiddenstate_indices = [1, 0, 1, 0, 3, 1, 4, 2, 3, 5]
+ backbone = [0, 0, 0, 1, 2, 1, 2, 2, 3, 3, 2, 1]
+ batch_norm = utils.get_batch_norm_fn(sync_batch_norm_method)
+ cell = NASBaseCell(hparams.num_conv_filters,
+ operations,
+ used_hiddenstates,
+ hiddenstate_indices,
+ hparams.drop_path_keep_prob,
+ len(backbone),
+ hparams.total_training_steps,
+ batch_norm_fn=batch_norm)
+ with arg_scope([slim.dropout, batch_norm], is_training=is_training):
+ return _build_nas_base(
+ images,
+ cell=cell,
+ backbone=backbone,
+ num_classes=num_classes,
+ hparams=hparams,
+ global_pool=global_pool,
+ output_stride=output_stride,
+ nas_use_classification_head=nas_architecture_options[
+ 'nas_use_classification_head'],
+ reuse=reuse,
+ scope=scope,
+ final_endpoint=final_endpoint,
+ batch_norm_fn=batch_norm,
+ nas_remove_os32_stride=nas_architecture_options[
+ 'nas_remove_os32_stride'])
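A hedged usage sketch for the builders above (assumptions: TF 1.x with tf.contrib available and that this file is importable as deeplab.core.nas_network; the option keys mirror the ones read inside pnasnet):

    import tensorflow as tf
    from tensorflow.contrib import slim as contrib_slim
    from deeplab.core import nas_network

    images = tf.placeholder(tf.float32, shape=[2, 321, 321, 3])
    nas_options = {
        'nas_stem_output_num_conv_filters': 20,
        'nas_use_classification_head': False,
        'nas_remove_os32_stride': False,
    }
    # nas_arg_scope supplies the initializers/regularizers the cells expect.
    with contrib_slim.arg_scope(nas_network.nas_arg_scope()):
        net, end_points = nas_network.pnasnet(
            images,
            num_classes=None,   # dense prediction: no classification logits
            is_training=False,
            output_stride=16,
            nas_architecture_options=nas_options)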
diff --git a/models/research/deeplab/core/nas_network_test.py b/models/research/deeplab/core/nas_network_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..18621b250ad7321f554b8d97449e19bde5ef4174
--- /dev/null
+++ b/models/research/deeplab/core/nas_network_test.py
@@ -0,0 +1,111 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for resnet_v1_beta module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.contrib import framework as contrib_framework
+from tensorflow.contrib import slim as contrib_slim
+from tensorflow.contrib import training as contrib_training
+
+from deeplab.core import nas_genotypes
+from deeplab.core import nas_network
+
+arg_scope = contrib_framework.arg_scope
+slim = contrib_slim
+
+
+def create_test_input(batch, height, width, channels):
+ """Creates test input tensor."""
+ if None in [batch, height, width, channels]:
+ return tf.placeholder(tf.float32, (batch, height, width, channels))
+ else:
+ return tf.to_float(
+ np.tile(
+ np.reshape(
+ np.reshape(np.arange(height), [height, 1]) +
+ np.reshape(np.arange(width), [1, width]),
+ [1, height, width, 1]),
+ [batch, 1, 1, channels]))
+
+
+class NASNetworkTest(tf.test.TestCase):
+ """Tests with complete small NAS networks."""
+
+ def _pnasnet(self,
+ images,
+ backbone,
+ num_classes,
+ is_training=True,
+ output_stride=16,
+ final_endpoint=None):
+ """Build PNASNet model backbone."""
+ hparams = contrib_training.HParams(
+ filter_scaling_rate=2.0,
+ num_conv_filters=10,
+ drop_path_keep_prob=1.0,
+ total_training_steps=200000,
+ )
+ if not is_training:
+ hparams.set_hparam('drop_path_keep_prob', 1.0)
+
+ cell = nas_genotypes.PNASCell(hparams.num_conv_filters,
+ hparams.drop_path_keep_prob,
+ len(backbone),
+ hparams.total_training_steps)
+ with arg_scope([slim.dropout, slim.batch_norm], is_training=is_training):
+ return nas_network._build_nas_base(
+ images,
+ cell=cell,
+ backbone=backbone,
+ num_classes=num_classes,
+ hparams=hparams,
+ reuse=tf.AUTO_REUSE,
+ scope='pnasnet_small',
+ final_endpoint=final_endpoint)
+
+ def testFullyConvolutionalEndpointShapes(self):
+ num_classes = 10
+ backbone = [0, 0, 0, 1, 2, 1, 2, 2, 3, 3, 2, 1]
+ inputs = create_test_input(None, 321, 321, 3)
+ with slim.arg_scope(nas_network.nas_arg_scope()):
+ _, end_points = self._pnasnet(inputs, backbone, num_classes)
+ endpoint_to_shape = {
+ 'Stem': [None, 81, 81, 128],
+ 'Cell_0': [None, 81, 81, 50],
+ 'Cell_1': [None, 81, 81, 50],
+ 'Cell_2': [None, 81, 81, 50],
+ 'Cell_3': [None, 41, 41, 100],
+ 'Cell_4': [None, 21, 21, 200],
+ 'Cell_5': [None, 41, 41, 100],
+ 'Cell_6': [None, 21, 21, 200],
+ 'Cell_7': [None, 21, 21, 200],
+ 'Cell_8': [None, 11, 11, 400],
+ 'Cell_9': [None, 11, 11, 400],
+ 'Cell_10': [None, 21, 21, 200],
+ 'Cell_11': [None, 41, 41, 100]
+ }
+ for endpoint, shape in endpoint_to_shape.items():
+ self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/deeplab/core/preprocess_utils.py b/models/research/deeplab/core/preprocess_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..440717e414d1a6f67b0947eb78830ff84baa812d
--- /dev/null
+++ b/models/research/deeplab/core/preprocess_utils.py
@@ -0,0 +1,533 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utility functions related to preprocessing inputs."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from six.moves import range
+from six.moves import zip
+import tensorflow as tf
+
+
+def flip_dim(tensor_list, prob=0.5, dim=1):
+ """Randomly flips a dimension of the given tensor.
+
+  The decision to randomly flip the `Tensors` is made together. In other words,
+  all or none of the images passed in are flipped.
+
+  Note that tf.random_flip_left_right and tf.random_flip_up_down aren't used so
+  that we can control the probability as well as ensure the same decision
+  is applied across the images.
+
+ Args:
+ tensor_list: A list of `Tensors` with the same number of dimensions.
+ prob: The probability of a left-right flip.
+ dim: The dimension to flip, 0, 1, ..
+
+ Returns:
+ outputs: A list of the possibly flipped `Tensors` as well as an indicator
+ `Tensor` at the end whose value is `True` if the inputs were flipped and
+ `False` otherwise.
+
+ Raises:
+ ValueError: If dim is negative or greater than the dimension of a `Tensor`.
+ """
+ random_value = tf.random_uniform([])
+
+ def flip():
+ flipped = []
+ for tensor in tensor_list:
+ if dim < 0 or dim >= len(tensor.get_shape().as_list()):
+ raise ValueError('dim must represent a valid dimension.')
+ flipped.append(tf.reverse_v2(tensor, [dim]))
+ return flipped
+
+ is_flipped = tf.less_equal(random_value, prob)
+ outputs = tf.cond(is_flipped, flip, lambda: tensor_list)
+ if not isinstance(outputs, (list, tuple)):
+ outputs = [outputs]
+ outputs.append(is_flipped)
+
+ return outputs
+
+
+def _image_dimensions(image, rank):
+ """Returns the dimensions of an image tensor.
+
+ Args:
+    image: A rank-D Tensor; a 3-D image has shape `[height, width, channels]`.
+    rank: The expected rank of the image.
+
+ Returns:
+    A list corresponding to the dimensions of the input image. Dimensions
+ that are statically known are python integers, otherwise they are integer
+ scalar tensors.
+ """
+ if image.get_shape().is_fully_defined():
+ return image.get_shape().as_list()
+ else:
+ static_shape = image.get_shape().with_rank(rank).as_list()
+ dynamic_shape = tf.unstack(tf.shape(image), rank)
+ return [
+ s if s is not None else d for s, d in zip(static_shape, dynamic_shape)
+ ]
+
+
+def get_label_resize_method(label):
+ """Returns the resize method of labels depending on label dtype.
+
+ Args:
+ label: Groundtruth label tensor.
+
+ Returns:
+ tf.image.ResizeMethod.BILINEAR, if label dtype is floating.
+ tf.image.ResizeMethod.NEAREST_NEIGHBOR, if label dtype is integer.
+
+ Raises:
+ ValueError: If label is neither floating nor integer.
+ """
+ if label.dtype.is_floating:
+ return tf.image.ResizeMethod.BILINEAR
+ elif label.dtype.is_integer:
+ return tf.image.ResizeMethod.NEAREST_NEIGHBOR
+ else:
+ raise ValueError('Label type must be either floating or integer.')
+
+
+def pad_to_bounding_box(image, offset_height, offset_width, target_height,
+ target_width, pad_value):
+ """Pads the given image with the given pad_value.
+
+ Works like tf.image.pad_to_bounding_box, except it can pad the image
+ with any given arbitrary pad value and also handle images whose sizes are not
+ known during graph construction.
+
+ Args:
+ image: 3-D tensor with shape [height, width, channels]
+    offset_height: Number of rows of padding to add on top.
+    offset_width: Number of columns of padding to add on the left.
+ target_height: Height of output image.
+ target_width: Width of output image.
+ pad_value: Value to pad the image tensor with.
+
+ Returns:
+ 3-D tensor of shape [target_height, target_width, channels].
+
+ Raises:
+ ValueError: If the shape of image is incompatible with the offset_* or
+ target_* arguments.
+ """
+ with tf.name_scope(None, 'pad_to_bounding_box', [image]):
+ image = tf.convert_to_tensor(image, name='image')
+ original_dtype = image.dtype
+ if original_dtype != tf.float32 and original_dtype != tf.float64:
+ # If image dtype is not float, we convert it to int32 to avoid overflow.
+ image = tf.cast(image, tf.int32)
+ image_rank_assert = tf.Assert(
+ tf.logical_or(
+ tf.equal(tf.rank(image), 3),
+ tf.equal(tf.rank(image), 4)),
+ ['Wrong image tensor rank.'])
+ with tf.control_dependencies([image_rank_assert]):
+ image -= pad_value
+ image_shape = image.get_shape()
+ is_batch = True
+ if image_shape.ndims == 3:
+ is_batch = False
+ image = tf.expand_dims(image, 0)
+ elif image_shape.ndims is None:
+ is_batch = False
+ image = tf.expand_dims(image, 0)
+ image.set_shape([None] * 4)
+ elif image.get_shape().ndims != 4:
+ raise ValueError('Input image must have either 3 or 4 dimensions.')
+ _, height, width, _ = _image_dimensions(image, rank=4)
+ target_width_assert = tf.Assert(
+ tf.greater_equal(
+ target_width, width),
+ ['target_width must be >= width'])
+ target_height_assert = tf.Assert(
+ tf.greater_equal(target_height, height),
+ ['target_height must be >= height'])
+ with tf.control_dependencies([target_width_assert]):
+ after_padding_width = target_width - offset_width - width
+ with tf.control_dependencies([target_height_assert]):
+ after_padding_height = target_height - offset_height - height
+ offset_assert = tf.Assert(
+ tf.logical_and(
+ tf.greater_equal(after_padding_width, 0),
+ tf.greater_equal(after_padding_height, 0)),
+ ['target size not possible with the given target offsets'])
+ batch_params = tf.stack([0, 0])
+ height_params = tf.stack([offset_height, after_padding_height])
+ width_params = tf.stack([offset_width, after_padding_width])
+ channel_params = tf.stack([0, 0])
+ with tf.control_dependencies([offset_assert]):
+ paddings = tf.stack([batch_params, height_params, width_params,
+ channel_params])
+ padded = tf.pad(image, paddings)
+ if not is_batch:
+ padded = tf.squeeze(padded, axis=[0])
+ outputs = padded + pad_value
+ if outputs.dtype != original_dtype:
+ outputs = tf.cast(outputs, original_dtype)
+ return outputs
+
+
+def _crop(image, offset_height, offset_width, crop_height, crop_width):
+ """Crops the given image using the provided offsets and sizes.
+
+ Note that the method doesn't assume we know the input image size but it does
+ assume we know the input image rank.
+
+ Args:
+ image: an image of shape [height, width, channels].
+ offset_height: a scalar tensor indicating the height offset.
+ offset_width: a scalar tensor indicating the width offset.
+ crop_height: the height of the cropped image.
+ crop_width: the width of the cropped image.
+
+ Returns:
+ The cropped (and resized) image.
+
+ Raises:
+ ValueError: if `image` doesn't have rank of 3.
+ InvalidArgumentError: if the rank is not 3 or if the image dimensions are
+ less than the crop size.
+ """
+ original_shape = tf.shape(image)
+
+ if len(image.get_shape().as_list()) != 3:
+ raise ValueError('input must have rank of 3')
+ original_channels = image.get_shape().as_list()[2]
+
+ rank_assertion = tf.Assert(
+ tf.equal(tf.rank(image), 3),
+ ['Rank of image must be equal to 3.'])
+ with tf.control_dependencies([rank_assertion]):
+ cropped_shape = tf.stack([crop_height, crop_width, original_shape[2]])
+
+ size_assertion = tf.Assert(
+ tf.logical_and(
+ tf.greater_equal(original_shape[0], crop_height),
+ tf.greater_equal(original_shape[1], crop_width)),
+ ['Crop size greater than the image size.'])
+
+ offsets = tf.cast(tf.stack([offset_height, offset_width, 0]), tf.int32)
+
+ # Use tf.slice instead of crop_to_bounding box as it accepts tensors to
+ # define the crop size.
+ with tf.control_dependencies([size_assertion]):
+ image = tf.slice(image, offsets, cropped_shape)
+ image = tf.reshape(image, cropped_shape)
+ image.set_shape([crop_height, crop_width, original_channels])
+ return image
+
+
+def random_crop(image_list, crop_height, crop_width):
+ """Crops the given list of images.
+
+ The function applies the same crop to each image in the list. This can be
+ effectively applied when there are multiple image inputs of the same
+ dimension such as:
+
+ image, depths, normals = random_crop([image, depths, normals], 120, 150)
+
+ Args:
+ image_list: a list of image tensors of the same dimension but possibly
+ varying channel.
+ crop_height: the new height.
+ crop_width: the new width.
+
+ Returns:
+ the image_list with cropped images.
+
+ Raises:
+    ValueError: if there are multiple image inputs provided with different
+      sizes or the images are smaller than the crop dimensions.
+ """
+ if not image_list:
+ raise ValueError('Empty image_list.')
+
+ # Compute the rank assertions.
+ rank_assertions = []
+ for i in range(len(image_list)):
+ image_rank = tf.rank(image_list[i])
+ rank_assert = tf.Assert(
+ tf.equal(image_rank, 3),
+ ['Wrong rank for tensor %s [expected] [actual]',
+ image_list[i].name, 3, image_rank])
+ rank_assertions.append(rank_assert)
+
+ with tf.control_dependencies([rank_assertions[0]]):
+ image_shape = tf.shape(image_list[0])
+ image_height = image_shape[0]
+ image_width = image_shape[1]
+ crop_size_assert = tf.Assert(
+ tf.logical_and(
+ tf.greater_equal(image_height, crop_height),
+ tf.greater_equal(image_width, crop_width)),
+ ['Crop size greater than the image size.'])
+
+ asserts = [rank_assertions[0], crop_size_assert]
+
+ for i in range(1, len(image_list)):
+ image = image_list[i]
+ asserts.append(rank_assertions[i])
+ with tf.control_dependencies([rank_assertions[i]]):
+ shape = tf.shape(image)
+ height = shape[0]
+ width = shape[1]
+
+ height_assert = tf.Assert(
+ tf.equal(height, image_height),
+ ['Wrong height for tensor %s [expected][actual]',
+ image.name, height, image_height])
+ width_assert = tf.Assert(
+ tf.equal(width, image_width),
+ ['Wrong width for tensor %s [expected][actual]',
+ image.name, width, image_width])
+ asserts.extend([height_assert, width_assert])
+
+ # Create a random bounding box.
+ #
+ # Use tf.random_uniform and not numpy.random.rand as doing the former would
+ # generate random numbers at graph eval time, unlike the latter which
+ # generates random numbers at graph definition time.
+ with tf.control_dependencies(asserts):
+ max_offset_height = tf.reshape(image_height - crop_height + 1, [])
+ max_offset_width = tf.reshape(image_width - crop_width + 1, [])
+ offset_height = tf.random_uniform(
+ [], maxval=max_offset_height, dtype=tf.int32)
+ offset_width = tf.random_uniform(
+ [], maxval=max_offset_width, dtype=tf.int32)
+
+ return [_crop(image, offset_height, offset_width,
+ crop_height, crop_width) for image in image_list]
+
+
+def get_random_scale(min_scale_factor, max_scale_factor, step_size):
+ """Gets a random scale value.
+
+ Args:
+ min_scale_factor: Minimum scale value.
+ max_scale_factor: Maximum scale value.
+ step_size: The step size from minimum to maximum value.
+
+ Returns:
+ A random scale value selected between minimum and maximum value.
+
+ Raises:
+ ValueError: min_scale_factor has unexpected value.
+ """
+ if min_scale_factor < 0 or min_scale_factor > max_scale_factor:
+ raise ValueError('Unexpected value of min_scale_factor.')
+
+ if min_scale_factor == max_scale_factor:
+ return tf.cast(min_scale_factor, tf.float32)
+
+ # When step_size = 0, we sample the value uniformly from [min, max).
+ if step_size == 0:
+ return tf.random_uniform([1],
+ minval=min_scale_factor,
+ maxval=max_scale_factor)
+
+ # When step_size != 0, we randomly select one discrete value from [min, max].
+ num_steps = int((max_scale_factor - min_scale_factor) / step_size + 1)
+ scale_factors = tf.lin_space(min_scale_factor, max_scale_factor, num_steps)
+ shuffled_scale_factors = tf.random_shuffle(scale_factors)
+ return shuffled_scale_factors[0]
+
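For intuition, the discrete grid sampled by the step_size != 0 branch above can be reproduced with plain numpy (illustrative values, not taken from the diff):

    import numpy as np

    min_scale, max_scale, step = 0.5, 2.0, 0.25
    num_steps = int((max_scale - min_scale) / step + 1)   # 7
    print(np.linspace(min_scale, max_scale, num_steps))
    # [0.5  0.75 1.   1.25 1.5  1.75 2.  ]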
+
+def randomly_scale_image_and_label(image, label=None, scale=1.0):
+ """Randomly scales image and label.
+
+ Args:
+ image: Image with shape [height, width, 3].
+ label: Label with shape [height, width, 1].
+ scale: The value to scale image and label.
+
+ Returns:
+ Scaled image and label.
+ """
+ # No random scaling if scale == 1.
+ if scale == 1.0:
+ return image, label
+ image_shape = tf.shape(image)
+ new_dim = tf.cast(
+ tf.cast([image_shape[0], image_shape[1]], tf.float32) * scale,
+ tf.int32)
+
+ # Need squeeze and expand_dims because image interpolation takes
+ # 4D tensors as input.
+ image = tf.squeeze(tf.image.resize_bilinear(
+ tf.expand_dims(image, 0),
+ new_dim,
+ align_corners=True), [0])
+ if label is not None:
+ label = tf.image.resize(
+ label,
+ new_dim,
+ method=get_label_resize_method(label),
+ align_corners=True)
+
+ return image, label
+
+
+def resolve_shape(tensor, rank=None, scope=None):
+ """Fully resolves the shape of a Tensor.
+
+  Uses the shape components already known during graph creation as much as
+  possible and resolves the remaining ones at runtime.
+
+ Args:
+ tensor: Input tensor whose shape we query.
+ rank: The rank of the tensor, provided that we know it.
+ scope: Optional name scope.
+
+ Returns:
+ shape: The full shape of the tensor.
+ """
+ with tf.name_scope(scope, 'resolve_shape', [tensor]):
+ if rank is not None:
+ shape = tensor.get_shape().with_rank(rank).as_list()
+ else:
+ shape = tensor.get_shape().as_list()
+
+ if None in shape:
+ shape_dynamic = tf.shape(tensor)
+ for i in range(len(shape)):
+ if shape[i] is None:
+ shape[i] = shape_dynamic[i]
+
+ return shape
+
+
+def resize_to_range(image,
+ label=None,
+ min_size=None,
+ max_size=None,
+ factor=None,
+ keep_aspect_ratio=True,
+ align_corners=True,
+ label_layout_is_chw=False,
+ scope=None,
+ method=tf.image.ResizeMethod.BILINEAR):
+ """Resizes image or label so their sides are within the provided range.
+
+ The output size can be described by two cases:
+ 1. If the image can be rescaled so its minimum size is equal to min_size
+ without the other side exceeding max_size, then do so.
+ 2. Otherwise, resize so the largest side is equal to max_size.
+
+ An integer in `range(factor)` is added to the computed sides so that the
+ final dimensions are multiples of `factor` plus one.
+
+ Args:
+ image: A 3D tensor of shape [height, width, channels].
+ label: (optional) A 3D tensor of shape [height, width, channels] (default)
+ or [channels, height, width] when label_layout_is_chw = True.
+ min_size: (scalar) desired size of the smaller image side.
+ max_size: (scalar) maximum allowed size of the larger image side. Note
+ that the output dimension is no larger than max_size and may be slightly
+ smaller than max_size when factor is not None.
+ factor: Make output size multiple of factor plus one.
+ keep_aspect_ratio: Boolean, keep aspect ratio or not. If True, the input
+ will be resized while keeping the original aspect ratio. If False, the
+      input will be resized to [max_size, max_size] without
+ keeping the original aspect ratio.
+ align_corners: If True, exactly align all 4 corners of input and output.
+ label_layout_is_chw: If true, the label has shape [channel, height, width].
+ We support this case because for some instance segmentation dataset, the
+ instance segmentation is saved as [num_instances, height, width].
+ scope: Optional name scope.
+ method: Image resize method. Defaults to tf.image.ResizeMethod.BILINEAR.
+
+ Returns:
+ A 3-D tensor of shape [new_height, new_width, channels], where the image
+ has been resized (with the specified method) so that
+ min(new_height, new_width) == ceil(min_size) or
+ max(new_height, new_width) == ceil(max_size).
+
+ Raises:
+ ValueError: If the image is not a 3D tensor.
+ """
+ with tf.name_scope(scope, 'resize_to_range', [image]):
+ new_tensor_list = []
+ min_size = tf.cast(min_size, tf.float32)
+ if max_size is not None:
+ max_size = tf.cast(max_size, tf.float32)
+ # Modify the max_size to be a multiple of factor plus 1 and make sure the
+ # max dimension after resizing is no larger than max_size.
+ if factor is not None:
+ max_size = (max_size - (max_size - 1) % factor)
+
+ [orig_height, orig_width, _] = resolve_shape(image, rank=3)
+ orig_height = tf.cast(orig_height, tf.float32)
+ orig_width = tf.cast(orig_width, tf.float32)
+ orig_min_size = tf.minimum(orig_height, orig_width)
+
+ # Calculate the larger of the possible sizes
+ large_scale_factor = min_size / orig_min_size
+ large_height = tf.cast(tf.floor(orig_height * large_scale_factor), tf.int32)
+ large_width = tf.cast(tf.floor(orig_width * large_scale_factor), tf.int32)
+ large_size = tf.stack([large_height, large_width])
+
+ new_size = large_size
+ if max_size is not None:
+ # Calculate the smaller of the possible sizes, use that if the larger
+ # is too big.
+ orig_max_size = tf.maximum(orig_height, orig_width)
+ small_scale_factor = max_size / orig_max_size
+ small_height = tf.cast(
+ tf.floor(orig_height * small_scale_factor), tf.int32)
+ small_width = tf.cast(tf.floor(orig_width * small_scale_factor), tf.int32)
+ small_size = tf.stack([small_height, small_width])
+ new_size = tf.cond(
+ tf.cast(tf.reduce_max(large_size), tf.float32) > max_size,
+ lambda: small_size,
+ lambda: large_size)
+ # Ensure that both output sides are multiples of factor plus one.
+ if factor is not None:
+ new_size += (factor - (new_size - 1) % factor) % factor
+ if not keep_aspect_ratio:
+      # If not keeping the aspect ratio, we resize everything to max_size,
+      # allowing us to do pre-processing without extra padding.
+ new_size = [tf.reduce_max(new_size), tf.reduce_max(new_size)]
+ new_tensor_list.append(tf.image.resize(
+ image, new_size, method=method, align_corners=align_corners))
+ if label is not None:
+ if label_layout_is_chw:
+ # Input label has shape [channel, height, width].
+ resized_label = tf.expand_dims(label, 3)
+ resized_label = tf.image.resize(
+ resized_label,
+ new_size,
+ method=get_label_resize_method(label),
+ align_corners=align_corners)
+ resized_label = tf.squeeze(resized_label, 3)
+ else:
+ # Input label has shape [height, width, channel].
+ resized_label = tf.image.resize(
+ label,
+ new_size,
+ method=get_label_resize_method(label),
+ align_corners=align_corners)
+ new_tensor_list.append(resized_label)
+ else:
+ new_tensor_list.append(None)
+ return new_tensor_list
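To make the size arithmetic in resize_to_range concrete, here is the same computation traced with integer arithmetic (standing in for tf.floor of the float products) on one illustrative input:

    # Illustrative trace of resize_to_range's size computation. Input: a
    # 600x800 image, min_size=513, max_size=513, factor=16 (values chosen
    # for illustration only).
    orig_h, orig_w = 600, 800
    min_size, max_size, factor = 513, 513, 16

    max_size -= (max_size - 1) % factor                       # 512 % 16 == 0, stays 513
    large_h, large_w = min_size, orig_w * min_size // orig_h  # (513, 684)
    # max(large) = 684 > max_size, so scale the larger side down to max_size.
    new_h, new_w = orig_h * max_size // orig_w, max_size      # (384, 513)
    # Snap each side to a multiple of factor plus one.
    new_h += (factor - (new_h - 1) % factor) % factor         # 385
    new_w += (factor - (new_w - 1) % factor) % factor         # stays 513
    print(new_h, new_w)                                       # 385 513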
diff --git a/models/research/deeplab/core/preprocess_utils_test.py b/models/research/deeplab/core/preprocess_utils_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..606fe46dd62787cf1a8adfaa27121affc4a02498
--- /dev/null
+++ b/models/research/deeplab/core/preprocess_utils_test.py
@@ -0,0 +1,515 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for preprocess_utils."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import range
+import tensorflow as tf
+
+from deeplab.core import preprocess_utils
+
+
+class PreprocessUtilsTest(tf.test.TestCase):
+
+ def testNoFlipWhenProbIsZero(self):
+ numpy_image = np.dstack([[[5., 6.],
+ [9., 0.]],
+ [[4., 3.],
+ [3., 5.]]])
+ image = tf.convert_to_tensor(numpy_image)
+
+ with self.test_session():
+ actual, is_flipped = preprocess_utils.flip_dim([image], prob=0, dim=0)
+ self.assertAllEqual(numpy_image, actual.eval())
+ self.assertAllEqual(False, is_flipped.eval())
+ actual, is_flipped = preprocess_utils.flip_dim([image], prob=0, dim=1)
+ self.assertAllEqual(numpy_image, actual.eval())
+ self.assertAllEqual(False, is_flipped.eval())
+ actual, is_flipped = preprocess_utils.flip_dim([image], prob=0, dim=2)
+ self.assertAllEqual(numpy_image, actual.eval())
+ self.assertAllEqual(False, is_flipped.eval())
+
+ def testFlipWhenProbIsOne(self):
+ numpy_image = np.dstack([[[5., 6.],
+ [9., 0.]],
+ [[4., 3.],
+ [3., 5.]]])
+ dim0_flipped = np.dstack([[[9., 0.],
+ [5., 6.]],
+ [[3., 5.],
+ [4., 3.]]])
+ dim1_flipped = np.dstack([[[6., 5.],
+ [0., 9.]],
+ [[3., 4.],
+ [5., 3.]]])
+ dim2_flipped = np.dstack([[[4., 3.],
+ [3., 5.]],
+ [[5., 6.],
+ [9., 0.]]])
+ image = tf.convert_to_tensor(numpy_image)
+
+ with self.test_session():
+ actual, is_flipped = preprocess_utils.flip_dim([image], prob=1, dim=0)
+ self.assertAllEqual(dim0_flipped, actual.eval())
+ self.assertAllEqual(True, is_flipped.eval())
+ actual, is_flipped = preprocess_utils.flip_dim([image], prob=1, dim=1)
+ self.assertAllEqual(dim1_flipped, actual.eval())
+ self.assertAllEqual(True, is_flipped.eval())
+ actual, is_flipped = preprocess_utils.flip_dim([image], prob=1, dim=2)
+ self.assertAllEqual(dim2_flipped, actual.eval())
+ self.assertAllEqual(True, is_flipped.eval())
+
+ def testFlipMultipleImagesConsistentlyWhenProbIsOne(self):
+ numpy_image = np.dstack([[[5., 6.],
+ [9., 0.]],
+ [[4., 3.],
+ [3., 5.]]])
+ numpy_label = np.dstack([[[0., 1.],
+ [2., 3.]]])
+ image_dim1_flipped = np.dstack([[[6., 5.],
+ [0., 9.]],
+ [[3., 4.],
+ [5., 3.]]])
+ label_dim1_flipped = np.dstack([[[1., 0.],
+ [3., 2.]]])
+ image = tf.convert_to_tensor(numpy_image)
+ label = tf.convert_to_tensor(numpy_label)
+
+ with self.test_session() as sess:
+ image, label, is_flipped = preprocess_utils.flip_dim(
+ [image, label], prob=1, dim=1)
+ actual_image, actual_label = sess.run([image, label])
+ self.assertAllEqual(image_dim1_flipped, actual_image)
+ self.assertAllEqual(label_dim1_flipped, actual_label)
+ self.assertEqual(True, is_flipped.eval())
+
+ def testReturnRandomFlipsOnMultipleEvals(self):
+ numpy_image = np.dstack([[[5., 6.],
+ [9., 0.]],
+ [[4., 3.],
+ [3., 5.]]])
+ dim1_flipped = np.dstack([[[6., 5.],
+ [0., 9.]],
+ [[3., 4.],
+ [5., 3.]]])
+ image = tf.convert_to_tensor(numpy_image)
+ tf.compat.v1.set_random_seed(53)
+
+ with self.test_session() as sess:
+ actual, is_flipped = preprocess_utils.flip_dim(
+ [image], prob=0.5, dim=1)
+ actual_image, actual_is_flipped = sess.run([actual, is_flipped])
+ self.assertAllEqual(numpy_image, actual_image)
+ self.assertEqual(False, actual_is_flipped)
+ actual_image, actual_is_flipped = sess.run([actual, is_flipped])
+ self.assertAllEqual(dim1_flipped, actual_image)
+ self.assertEqual(True, actual_is_flipped)
+
+ def testReturnCorrectCropOfSingleImage(self):
+ np.random.seed(0)
+
+ height, width = 10, 20
+ image = np.random.randint(0, 256, size=(height, width, 3))
+
+ crop_height, crop_width = 2, 4
+
+ image_placeholder = tf.placeholder(tf.int32, shape=(None, None, 3))
+ [cropped] = preprocess_utils.random_crop([image_placeholder],
+ crop_height,
+ crop_width)
+
+ with self.test_session():
+ cropped_image = cropped.eval(feed_dict={image_placeholder: image})
+
+ # Ensure we can find the cropped image in the original:
+ is_found = False
+ for x in range(0, width - crop_width + 1):
+ for y in range(0, height - crop_height + 1):
+ if np.isclose(image[y:y+crop_height, x:x+crop_width, :],
+ cropped_image).all():
+ is_found = True
+ break
+
+ self.assertTrue(is_found)
+
+ def testRandomCropMaintainsNumberOfChannels(self):
+ np.random.seed(0)
+
+ crop_height, crop_width = 10, 20
+ image = np.random.randint(0, 256, size=(100, 200, 3))
+
+ tf.compat.v1.set_random_seed(37)
+ image_placeholder = tf.placeholder(tf.int32, shape=(None, None, 3))
+ [cropped] = preprocess_utils.random_crop(
+ [image_placeholder], crop_height, crop_width)
+
+ with self.test_session():
+ cropped_image = cropped.eval(feed_dict={image_placeholder: image})
+ self.assertTupleEqual(cropped_image.shape, (crop_height, crop_width, 3))
+
+ def testReturnDifferentCropAreasOnTwoEvals(self):
+ tf.compat.v1.set_random_seed(0)
+
+ crop_height, crop_width = 2, 3
+ image = np.random.randint(0, 256, size=(100, 200, 3))
+ image_placeholder = tf.placeholder(tf.int32, shape=(None, None, 3))
+ [cropped] = preprocess_utils.random_crop(
+ [image_placeholder], crop_height, crop_width)
+
+ with self.test_session():
+ crop0 = cropped.eval(feed_dict={image_placeholder: image})
+ crop1 = cropped.eval(feed_dict={image_placeholder: image})
+ self.assertFalse(np.isclose(crop0, crop1).all())
+
+  def testReturnConsistentCropsOfImagesInTheList(self):
+ tf.compat.v1.set_random_seed(0)
+
+ height, width = 10, 20
+ crop_height, crop_width = 2, 3
+ labels = np.linspace(0, height * width-1, height * width)
+ labels = labels.reshape((height, width, 1))
+ image = np.tile(labels, (1, 1, 3))
+
+ image_placeholder = tf.placeholder(tf.int32, shape=(None, None, 3))
+ label_placeholder = tf.placeholder(tf.int32, shape=(None, None, 1))
+ [cropped_image, cropped_label] = preprocess_utils.random_crop(
+ [image_placeholder, label_placeholder], crop_height, crop_width)
+
+ with self.test_session() as sess:
+ cropped_image, cropped_labels = sess.run([cropped_image, cropped_label],
+ feed_dict={
+ image_placeholder: image,
+ label_placeholder: labels})
+ for i in range(3):
+ self.assertAllEqual(cropped_image[:, :, i], cropped_labels.squeeze())
+
+ def testDieOnRandomCropWhenImagesWithDifferentWidth(self):
+ crop_height, crop_width = 2, 3
+ image1 = tf.placeholder(tf.float32, name='image1', shape=(None, None, 3))
+ image2 = tf.placeholder(tf.float32, name='image2', shape=(None, None, 1))
+ cropped = preprocess_utils.random_crop(
+ [image1, image2], crop_height, crop_width)
+
+ with self.test_session() as sess:
+ with self.assertRaises(tf.errors.InvalidArgumentError):
+ sess.run(cropped, feed_dict={image1: np.random.rand(4, 5, 3),
+ image2: np.random.rand(4, 6, 1)})
+
+ def testDieOnRandomCropWhenImagesWithDifferentHeight(self):
+ crop_height, crop_width = 2, 3
+ image1 = tf.placeholder(tf.float32, name='image1', shape=(None, None, 3))
+ image2 = tf.placeholder(tf.float32, name='image2', shape=(None, None, 1))
+ cropped = preprocess_utils.random_crop(
+ [image1, image2], crop_height, crop_width)
+
+ with self.test_session() as sess:
+ with self.assertRaisesWithPredicateMatch(
+ tf.errors.InvalidArgumentError,
+ 'Wrong height for tensor'):
+ sess.run(cropped, feed_dict={image1: np.random.rand(4, 5, 3),
+ image2: np.random.rand(3, 5, 1)})
+
+ def testDieOnRandomCropWhenCropSizeIsGreaterThanImage(self):
+ crop_height, crop_width = 5, 9
+ image1 = tf.placeholder(tf.float32, name='image1', shape=(None, None, 3))
+ image2 = tf.placeholder(tf.float32, name='image2', shape=(None, None, 1))
+ cropped = preprocess_utils.random_crop(
+ [image1, image2], crop_height, crop_width)
+
+ with self.test_session() as sess:
+ with self.assertRaisesWithPredicateMatch(
+ tf.errors.InvalidArgumentError,
+ 'Crop size greater than the image size.'):
+ sess.run(cropped, feed_dict={image1: np.random.rand(4, 5, 3),
+ image2: np.random.rand(4, 5, 1)})
+
+ def testReturnPaddedImageWithNonZeroPadValue(self):
+ for dtype in [np.int32, np.int64, np.float32, np.float64]:
+ image = np.dstack([[[5, 6],
+ [9, 0]],
+ [[4, 3],
+ [3, 5]]]).astype(dtype)
+ expected_image = np.dstack([[[255, 255, 255, 255, 255],
+ [255, 255, 255, 255, 255],
+ [255, 5, 6, 255, 255],
+ [255, 9, 0, 255, 255],
+ [255, 255, 255, 255, 255]],
+ [[255, 255, 255, 255, 255],
+ [255, 255, 255, 255, 255],
+ [255, 4, 3, 255, 255],
+ [255, 3, 5, 255, 255],
+ [255, 255, 255, 255, 255]]]).astype(dtype)
+
+ with self.session() as sess:
+ padded_image = preprocess_utils.pad_to_bounding_box(
+ image, 2, 1, 5, 5, 255)
+ padded_image = sess.run(padded_image)
+ self.assertAllClose(padded_image, expected_image)
+ # Add batch size = 1 to image.
+ padded_image = preprocess_utils.pad_to_bounding_box(
+ np.expand_dims(image, 0), 2, 1, 5, 5, 255)
+ padded_image = sess.run(padded_image)
+ self.assertAllClose(padded_image, np.expand_dims(expected_image, 0))
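+
+  # Note added for clarity: pad_to_bounding_box places the 2x2 image at row
+  # offset 2 and column offset 1 inside a 5x5 canvas pre-filled with the pad
+  # value 255, which is exactly what expected_image above encodes.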
+
+ def testReturnOriginalImageWhenTargetSizeIsEqualToImageSize(self):
+ image = np.dstack([[[5, 6],
+ [9, 0]],
+ [[4, 3],
+ [3, 5]]])
+ with self.session() as sess:
+ padded_image = preprocess_utils.pad_to_bounding_box(
+ image, 0, 0, 2, 2, 255)
+ padded_image = sess.run(padded_image)
+ self.assertAllClose(padded_image, image)
+
+ def testDieOnTargetSizeGreaterThanImageSize(self):
+ image = np.dstack([[[5, 6],
+ [9, 0]],
+ [[4, 3],
+ [3, 5]]])
+ with self.test_session():
+ image_placeholder = tf.placeholder(tf.float32)
+ padded_image = preprocess_utils.pad_to_bounding_box(
+ image_placeholder, 0, 0, 2, 1, 255)
+ with self.assertRaisesWithPredicateMatch(
+ tf.errors.InvalidArgumentError,
+ 'target_width must be >= width'):
+ padded_image.eval(feed_dict={image_placeholder: image})
+ padded_image = preprocess_utils.pad_to_bounding_box(
+ image_placeholder, 0, 0, 1, 2, 255)
+ with self.assertRaisesWithPredicateMatch(
+ tf.errors.InvalidArgumentError,
+ 'target_height must be >= height'):
+ padded_image.eval(feed_dict={image_placeholder: image})
+
+ def testDieIfTargetSizeNotPossibleWithGivenOffset(self):
+ image = np.dstack([[[5, 6],
+ [9, 0]],
+ [[4, 3],
+ [3, 5]]])
+ with self.test_session():
+ image_placeholder = tf.placeholder(tf.float32)
+ padded_image = preprocess_utils.pad_to_bounding_box(
+ image_placeholder, 3, 0, 4, 4, 255)
+ with self.assertRaisesWithPredicateMatch(
+ tf.errors.InvalidArgumentError,
+ 'target size not possible with the given target offsets'):
+ padded_image.eval(feed_dict={image_placeholder: image})
+
+ def testDieIfImageTensorRankIsTwo(self):
+ image = np.vstack([[5, 6],
+ [9, 0]])
+ with self.test_session():
+ image_placeholder = tf.placeholder(tf.float32)
+ padded_image = preprocess_utils.pad_to_bounding_box(
+ image_placeholder, 0, 0, 2, 2, 255)
+ with self.assertRaisesWithPredicateMatch(
+ tf.errors.InvalidArgumentError,
+ 'Wrong image tensor rank'):
+ padded_image.eval(feed_dict={image_placeholder: image})
+
+ def testResizeTensorsToRange(self):
+ test_shapes = [[60, 40],
+ [15, 30],
+ [15, 50]]
+ min_size = 50
+ max_size = 100
+ factor = None
+ expected_shape_list = [(75, 50, 3),
+ (50, 100, 3),
+ (30, 100, 3)]
+ for i, test_shape in enumerate(test_shapes):
+ image = tf.random.normal([test_shape[0], test_shape[1], 3])
+ new_tensor_list = preprocess_utils.resize_to_range(
+ image=image,
+ label=None,
+ min_size=min_size,
+ max_size=max_size,
+ factor=factor,
+ align_corners=True)
+ with self.test_session() as session:
+ resized_image = session.run(new_tensor_list[0])
+ self.assertEqual(resized_image.shape, expected_shape_list[i])
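+
+  # Note added for clarity: the expected shapes above follow the resizing rule
+  # this test exercises: the image is scaled so its shorter side reaches
+  # min_size (60x40 -> 75x50, 15x30 -> 50x100), unless that would push the
+  # longer side past max_size, in which case the longer side is capped at
+  # max_size instead (15x50 -> 30x100).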
+
+ def testResizeTensorsToRangeWithFactor(self):
+ test_shapes = [[60, 40],
+ [15, 30],
+ [15, 50]]
+ min_size = 50
+ max_size = 98
+ factor = 8
+ expected_image_shape_list = [(81, 57, 3),
+ (49, 97, 3),
+ (33, 97, 3)]
+ expected_label_shape_list = [(81, 57, 1),
+ (49, 97, 1),
+ (33, 97, 1)]
+ for i, test_shape in enumerate(test_shapes):
+ image = tf.random.normal([test_shape[0], test_shape[1], 3])
+ label = tf.random.normal([test_shape[0], test_shape[1], 1])
+ new_tensor_list = preprocess_utils.resize_to_range(
+ image=image,
+ label=label,
+ min_size=min_size,
+ max_size=max_size,
+ factor=factor,
+ align_corners=True)
+ with self.test_session() as session:
+ new_tensor_list = session.run(new_tensor_list)
+ self.assertEqual(new_tensor_list[0].shape, expected_image_shape_list[i])
+ self.assertEqual(new_tensor_list[1].shape, expected_label_shape_list[i])
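+
+  # Note added for clarity: with a factor, each resized side is additionally
+  # rounded up so that (side - 1) is a multiple of factor, consistent with
+  # align_corners=True; e.g. the 75x50 result from the previous test becomes
+  # ceil(74 / 8) * 8 + 1 = 81 by ceil(49 / 8) * 8 + 1 = 57.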
+
+ def testResizeTensorsToRangeWithFactorAndLabelShapeCHW(self):
+ test_shapes = [[60, 40],
+ [15, 30],
+ [15, 50]]
+ min_size = 50
+ max_size = 98
+ factor = 8
+ expected_image_shape_list = [(81, 57, 3),
+ (49, 97, 3),
+ (33, 97, 3)]
+ expected_label_shape_list = [(5, 81, 57),
+ (5, 49, 97),
+ (5, 33, 97)]
+ for i, test_shape in enumerate(test_shapes):
+ image = tf.random.normal([test_shape[0], test_shape[1], 3])
+ label = tf.random.normal([5, test_shape[0], test_shape[1]])
+ new_tensor_list = preprocess_utils.resize_to_range(
+ image=image,
+ label=label,
+ min_size=min_size,
+ max_size=max_size,
+ factor=factor,
+ align_corners=True,
+ label_layout_is_chw=True)
+ with self.test_session() as session:
+ new_tensor_list = session.run(new_tensor_list)
+ self.assertEqual(new_tensor_list[0].shape, expected_image_shape_list[i])
+ self.assertEqual(new_tensor_list[1].shape, expected_label_shape_list[i])
+
+ def testResizeTensorsToRangeWithSimilarMinMaxSizes(self):
+ test_shapes = [[60, 40],
+ [15, 30],
+ [15, 50]]
+    # Values set so that one of the sides equals 97.
+ min_size = 96
+ max_size = 98
+ factor = 8
+ expected_image_shape_list = [(97, 65, 3),
+ (49, 97, 3),
+ (33, 97, 3)]
+ expected_label_shape_list = [(97, 65, 1),
+ (49, 97, 1),
+ (33, 97, 1)]
+ for i, test_shape in enumerate(test_shapes):
+ image = tf.random.normal([test_shape[0], test_shape[1], 3])
+ label = tf.random.normal([test_shape[0], test_shape[1], 1])
+ new_tensor_list = preprocess_utils.resize_to_range(
+ image=image,
+ label=label,
+ min_size=min_size,
+ max_size=max_size,
+ factor=factor,
+ align_corners=True)
+ with self.test_session() as session:
+ new_tensor_list = session.run(new_tensor_list)
+ self.assertEqual(new_tensor_list[0].shape, expected_image_shape_list[i])
+ self.assertEqual(new_tensor_list[1].shape, expected_label_shape_list[i])
+
+ def testResizeTensorsToRangeWithEqualMaxSize(self):
+ test_shapes = [[97, 38],
+ [96, 97]]
+ # Make max_size equal to the larger value of test_shapes.
+ min_size = 97
+ max_size = 97
+ factor = 8
+ expected_image_shape_list = [(97, 41, 3),
+ (97, 97, 3)]
+ expected_label_shape_list = [(97, 41, 1),
+ (97, 97, 1)]
+ for i, test_shape in enumerate(test_shapes):
+ image = tf.random.normal([test_shape[0], test_shape[1], 3])
+ label = tf.random.normal([test_shape[0], test_shape[1], 1])
+ new_tensor_list = preprocess_utils.resize_to_range(
+ image=image,
+ label=label,
+ min_size=min_size,
+ max_size=max_size,
+ factor=factor,
+ align_corners=True)
+ with self.test_session() as session:
+ new_tensor_list = session.run(new_tensor_list)
+ self.assertEqual(new_tensor_list[0].shape, expected_image_shape_list[i])
+ self.assertEqual(new_tensor_list[1].shape, expected_label_shape_list[i])
+
+ def testResizeTensorsToRangeWithPotentialErrorInTFCeil(self):
+ test_shape = [3936, 5248]
+    # Set min_size equal to max_size so the longer output side is fixed at 1441.
+ min_size = 1441
+ max_size = 1441
+ factor = 16
+ expected_image_shape = (1089, 1441, 3)
+ expected_label_shape = (1089, 1441, 1)
+ image = tf.random.normal([test_shape[0], test_shape[1], 3])
+ label = tf.random.normal([test_shape[0], test_shape[1], 1])
+ new_tensor_list = preprocess_utils.resize_to_range(
+ image=image,
+ label=label,
+ min_size=min_size,
+ max_size=max_size,
+ factor=factor,
+ align_corners=True)
+ with self.test_session() as session:
+ new_tensor_list = session.run(new_tensor_list)
+ self.assertEqual(new_tensor_list[0].shape, expected_image_shape)
+ self.assertEqual(new_tensor_list[1].shape, expected_label_shape)
+
+ def testResizeTensorsToRangeWithEqualMaxSizeWithoutAspectRatio(self):
+ test_shapes = [[97, 38],
+ [96, 97]]
+ # Make max_size equal to the larger value of test_shapes.
+ min_size = 97
+ max_size = 97
+ factor = 8
+ keep_aspect_ratio = False
+ expected_image_shape_list = [(97, 97, 3),
+ (97, 97, 3)]
+ expected_label_shape_list = [(97, 97, 1),
+ (97, 97, 1)]
+ for i, test_shape in enumerate(test_shapes):
+ image = tf.random.normal([test_shape[0], test_shape[1], 3])
+ label = tf.random.normal([test_shape[0], test_shape[1], 1])
+ new_tensor_list = preprocess_utils.resize_to_range(
+ image=image,
+ label=label,
+ min_size=min_size,
+ max_size=max_size,
+ factor=factor,
+ keep_aspect_ratio=keep_aspect_ratio,
+ align_corners=True)
+ with self.test_session() as session:
+ new_tensor_list = session.run(new_tensor_list)
+ self.assertEqual(new_tensor_list[0].shape, expected_image_shape_list[i])
+ self.assertEqual(new_tensor_list[1].shape, expected_label_shape_list[i])
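+
+  # Note added for clarity: with keep_aspect_ratio=False each side is resized
+  # independently of the other, so every test image above ends up at the
+  # factor-aligned 97x97 regardless of its original aspect ratio.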
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/deeplab/core/resnet_v1_beta.py b/models/research/deeplab/core/resnet_v1_beta.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d5f1f19a234fb13cbac2d7397d0948b64d3011b
--- /dev/null
+++ b/models/research/deeplab/core/resnet_v1_beta.py
@@ -0,0 +1,827 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Resnet v1 model variants.
+
+Code branched out from slim/nets/resnet_v1.py, and please refer to it for
+more details.
+
+The original ResNet-v1 models were proposed by:
+[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+ Deep Residual Learning for Image Recognition. arXiv:1512.03385
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+from six.moves import range
+import tensorflow as tf
+from tensorflow.contrib import slim as contrib_slim
+from deeplab.core import conv2d_ws
+from deeplab.core import utils
+from tensorflow.contrib.slim.nets import resnet_utils
+
+slim = contrib_slim
+
+_DEFAULT_MULTI_GRID = [1, 1, 1]
+_DEFAULT_MULTI_GRID_RESNET_18 = [1, 1]
+
+
+@slim.add_arg_scope
+def bottleneck(inputs,
+ depth,
+ depth_bottleneck,
+ stride,
+ unit_rate=1,
+ rate=1,
+ outputs_collections=None,
+ scope=None):
+ """Bottleneck residual unit variant with BN after convolutions.
+
+ This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for
+ its definition. Note that we use here the bottleneck variant which has an
+ extra bottleneck layer.
+
+ When putting together two consecutive ResNet blocks that use this unit, one
+ should use stride = 2 in the last unit of the first block.
+
+ Args:
+ inputs: A tensor of size [batch, height, width, channels].
+ depth: The depth of the ResNet unit output.
+ depth_bottleneck: The depth of the bottleneck layers.
+ stride: The ResNet unit's stride. Determines the amount of downsampling of
+ the units output compared to its input.
+ unit_rate: An integer, unit rate for atrous convolution.
+ rate: An integer, rate for atrous convolution.
+ outputs_collections: Collection to add the ResNet unit output.
+ scope: Optional variable_scope.
+
+ Returns:
+ The ResNet unit's output.
+ """
+ with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc:
+ depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
+ if depth == depth_in:
+ shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
+ else:
+ shortcut = conv2d_ws.conv2d(
+ inputs,
+ depth,
+ [1, 1],
+ stride=stride,
+ activation_fn=None,
+ scope='shortcut')
+
+ residual = conv2d_ws.conv2d(inputs, depth_bottleneck, [1, 1], stride=1,
+ scope='conv1')
+ residual = conv2d_ws.conv2d_same(residual, depth_bottleneck, 3, stride,
+ rate=rate*unit_rate, scope='conv2')
+ residual = conv2d_ws.conv2d(residual, depth, [1, 1], stride=1,
+ activation_fn=None, scope='conv3')
+ output = tf.nn.relu(shortcut + residual)
+
+ return slim.utils.collect_named_outputs(outputs_collections, sc.name,
+ output)
+
+
+@slim.add_arg_scope
+def lite_bottleneck(inputs,
+ depth,
+ stride,
+ unit_rate=1,
+ rate=1,
+ outputs_collections=None,
+ scope=None):
+ """Bottleneck residual unit variant with BN after convolutions.
+
+ This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for
+ its definition. Note that we use here the bottleneck variant which has an
+ extra bottleneck layer.
+
+ When putting together two consecutive ResNet blocks that use this unit, one
+ should use stride = 2 in the last unit of the first block.
+
+ Args:
+ inputs: A tensor of size [batch, height, width, channels].
+ depth: The depth of the ResNet unit output.
+ stride: The ResNet unit's stride. Determines the amount of downsampling of
+ the units output compared to its input.
+ unit_rate: An integer, unit rate for atrous convolution.
+ rate: An integer, rate for atrous convolution.
+ outputs_collections: Collection to add the ResNet unit output.
+ scope: Optional variable_scope.
+
+ Returns:
+ The ResNet unit's output.
+ """
+ with tf.variable_scope(scope, 'lite_bottleneck_v1', [inputs]) as sc:
+ depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
+ if depth == depth_in:
+ shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
+ else:
+ shortcut = conv2d_ws.conv2d(
+ inputs,
+ depth, [1, 1],
+ stride=stride,
+ activation_fn=None,
+ scope='shortcut')
+
+ residual = conv2d_ws.conv2d_same(
+ inputs, depth, 3, 1, rate=rate * unit_rate, scope='conv1')
+ with slim.arg_scope([conv2d_ws.conv2d], activation_fn=None):
+ residual = conv2d_ws.conv2d_same(
+ residual, depth, 3, stride, rate=rate * unit_rate, scope='conv2')
+ output = tf.nn.relu(shortcut + residual)
+
+ return slim.utils.collect_named_outputs(outputs_collections, sc.name,
+ output)
+
+
+def root_block_fn_for_beta_variant(net, depth_multiplier=1.0):
+ """Gets root_block_fn for beta variant.
+
+ ResNet-v1 beta variant modifies the first original 7x7 convolution to three
+ 3x3 convolutions.
+
+ Args:
+ net: A tensor of size [batch, height, width, channels], input to the model.
+    depth_multiplier: Float, multiplier applied to the number of output
+      channels of the three 3x3 convolutions (64, 64, and 128 when the
+      multiplier is 1.0).
+
+ Returns:
+ A tensor after three 3x3 convolutions.
+ """
+ net = conv2d_ws.conv2d_same(
+ net, int(64 * depth_multiplier), 3, stride=2, scope='conv1_1')
+ net = conv2d_ws.conv2d_same(
+ net, int(64 * depth_multiplier), 3, stride=1, scope='conv1_2')
+ net = conv2d_ws.conv2d_same(
+ net, int(128 * depth_multiplier), 3, stride=1, scope='conv1_3')
+
+ return net
+
+
+def resnet_v1_beta(inputs,
+ blocks,
+ num_classes=None,
+ is_training=None,
+ global_pool=True,
+ output_stride=None,
+ root_block_fn=None,
+ reuse=None,
+ scope=None,
+ sync_batch_norm_method='None'):
+ """Generator for v1 ResNet models (beta variant).
+
+ This function generates a family of modified ResNet v1 models. In particular,
+ the first original 7x7 convolution is replaced with three 3x3 convolutions.
+ See the resnet_v1_*() methods for specific model instantiations, obtained by
+ selecting different block instantiations that produce ResNets of various
+ depths.
+
+ The code is modified from slim/nets/resnet_v1.py, and please refer to it for
+ more details.
+
+ Args:
+ inputs: A tensor of size [batch, height_in, width_in, channels].
+ blocks: A list of length equal to the number of ResNet blocks. Each element
+ is a resnet_utils.Block object describing the units in the block.
+ num_classes: Number of predicted classes for classification tasks. If None
+ we return the features before the logit layer.
+ is_training: Enable/disable is_training for batch normalization.
+ global_pool: If True, we perform global average pooling before computing the
+ logits. Set to True for image classification, False for dense prediction.
+ output_stride: If None, then the output will be computed at the nominal
+ network stride. If output_stride is not None, it specifies the requested
+ ratio of input to output spatial resolution.
+ root_block_fn: The function consisting of convolution operations applied to
+ the root input. If root_block_fn is None, use the original setting of
+      ResNet-v1, which is simply one convolution with 7x7 kernel and stride=2.
+ reuse: whether or not the network and its variables should be reused. To be
+ able to reuse 'scope' must be given.
+ scope: Optional variable_scope.
+ sync_batch_norm_method: String, sync batchnorm method.
+
+ Returns:
+ net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
+ If global_pool is False, then height_out and width_out are reduced by a
+ factor of output_stride compared to the respective height_in and width_in,
+ else both height_out and width_out equal one. If num_classes is None, then
+ net is the output of the last ResNet block, potentially after global
+ average pooling. If num_classes is not None, net contains the pre-softmax
+ activations.
+ end_points: A dictionary from components of the network to the corresponding
+ activation.
+
+ Raises:
+ ValueError: If the target output_stride is not valid.
+ """
+ if root_block_fn is None:
+ root_block_fn = functools.partial(conv2d_ws.conv2d_same,
+ num_outputs=64,
+ kernel_size=7,
+ stride=2,
+ scope='conv1')
+ batch_norm = utils.get_batch_norm_fn(sync_batch_norm_method)
+ with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc:
+ end_points_collection = sc.original_name_scope + '_end_points'
+ with slim.arg_scope([
+ conv2d_ws.conv2d, bottleneck, lite_bottleneck,
+ resnet_utils.stack_blocks_dense
+ ],
+ outputs_collections=end_points_collection):
+ if is_training is not None:
+ arg_scope = slim.arg_scope([batch_norm], is_training=is_training)
+ else:
+ arg_scope = slim.arg_scope([])
+ with arg_scope:
+ net = inputs
+ if output_stride is not None:
+ if output_stride % 4 != 0:
+ raise ValueError('The output_stride needs to be a multiple of 4.')
+ output_stride //= 4
+ net = root_block_fn(net)
+ net = slim.max_pool2d(net, 3, stride=2, padding='SAME', scope='pool1')
+ net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)
+
+ if global_pool:
+ # Global average pooling.
+ net = tf.reduce_mean(net, [1, 2], name='pool5', keepdims=True)
+ if num_classes is not None:
+ net = conv2d_ws.conv2d(net, num_classes, [1, 1], activation_fn=None,
+ normalizer_fn=None, scope='logits',
+ use_weight_standardization=False)
+ # Convert end_points_collection into a dictionary of end_points.
+ end_points = slim.utils.convert_collection_to_dict(
+ end_points_collection)
+ if num_classes is not None:
+ end_points['predictions'] = slim.softmax(net, scope='predictions')
+ return net, end_points
+
+
+def resnet_v1_beta_block(scope, base_depth, num_units, stride):
+ """Helper function for creating a resnet_v1 beta variant bottleneck block.
+
+ Args:
+ scope: The scope of the block.
+ base_depth: The depth of the bottleneck layer for each unit.
+ num_units: The number of units in the block.
+ stride: The stride of the block, implemented as a stride in the last unit.
+ All other units have stride=1.
+
+ Returns:
+ A resnet_v1 bottleneck block.
+ """
+ return resnet_utils.Block(scope, bottleneck, [{
+ 'depth': base_depth * 4,
+ 'depth_bottleneck': base_depth,
+ 'stride': 1,
+ 'unit_rate': 1
+ }] * (num_units - 1) + [{
+ 'depth': base_depth * 4,
+ 'depth_bottleneck': base_depth,
+ 'stride': stride,
+ 'unit_rate': 1
+ }])
+
+
+def resnet_v1_small_beta_block(scope, base_depth, num_units, stride):
+ """Helper function for creating a resnet_18 beta variant bottleneck block.
+
+ Args:
+ scope: The scope of the block.
+ base_depth: The depth of the bottleneck layer for each unit.
+ num_units: The number of units in the block.
+ stride: The stride of the block, implemented as a stride in the last unit.
+ All other units have stride=1.
+
+ Returns:
+ A resnet_18 bottleneck block.
+ """
+ block_args = []
+ for _ in range(num_units - 1):
+ block_args.append({'depth': base_depth, 'stride': 1, 'unit_rate': 1})
+ block_args.append({'depth': base_depth, 'stride': stride, 'unit_rate': 1})
+ return resnet_utils.Block(scope, lite_bottleneck, block_args)
+
+
+def resnet_v1_18(inputs,
+ num_classes=None,
+ is_training=None,
+ global_pool=False,
+ output_stride=None,
+ multi_grid=None,
+ reuse=None,
+ scope='resnet_v1_18',
+ sync_batch_norm_method='None'):
+ """Resnet v1 18.
+
+ Args:
+ inputs: A tensor of size [batch, height_in, width_in, channels].
+ num_classes: Number of predicted classes for classification tasks. If None
+ we return the features before the logit layer.
+ is_training: Enable/disable is_training for batch normalization.
+ global_pool: If True, we perform global average pooling before computing the
+ logits. Set to True for image classification, False for dense prediction.
+ output_stride: If None, then the output will be computed at the nominal
+ network stride. If output_stride is not None, it specifies the requested
+ ratio of input to output spatial resolution.
+ multi_grid: Employ a hierarchy of different atrous rates within network.
+ reuse: whether or not the network and its variables should be reused. To be
+ able to reuse 'scope' must be given.
+ scope: Optional variable_scope.
+ sync_batch_norm_method: String, sync batchnorm method.
+
+ Returns:
+ net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
+ If global_pool is False, then height_out and width_out are reduced by a
+ factor of output_stride compared to the respective height_in and width_in,
+ else both height_out and width_out equal one. If num_classes is None, then
+ net is the output of the last ResNet block, potentially after global
+ average pooling. If num_classes is not None, net contains the pre-softmax
+ activations.
+ end_points: A dictionary from components of the network to the corresponding
+ activation.
+
+ Raises:
+    ValueError: if multi_grid is not None and does not have length = 2.
+ """
+ if multi_grid is None:
+ multi_grid = _DEFAULT_MULTI_GRID_RESNET_18
+ else:
+ if len(multi_grid) != 2:
+ raise ValueError('Expect multi_grid to have length 2.')
+
+ block4_args = []
+ for rate in multi_grid:
+ block4_args.append({'depth': 512, 'stride': 1, 'unit_rate': rate})
+
+ blocks = [
+ resnet_v1_small_beta_block(
+ 'block1', base_depth=64, num_units=2, stride=2),
+ resnet_v1_small_beta_block(
+ 'block2', base_depth=128, num_units=2, stride=2),
+ resnet_v1_small_beta_block(
+ 'block3', base_depth=256, num_units=2, stride=2),
+ resnet_utils.Block('block4', lite_bottleneck, block4_args),
+ ]
+ return resnet_v1_beta(
+ inputs,
+ blocks=blocks,
+ num_classes=num_classes,
+ is_training=is_training,
+ global_pool=global_pool,
+ output_stride=output_stride,
+ reuse=reuse,
+ scope=scope,
+ sync_batch_norm_method=sync_batch_norm_method)
+
+
+def resnet_v1_18_beta(inputs,
+ num_classes=None,
+ is_training=None,
+ global_pool=False,
+ output_stride=None,
+ multi_grid=None,
+ root_depth_multiplier=0.25,
+ reuse=None,
+ scope='resnet_v1_18',
+ sync_batch_norm_method='None'):
+ """Resnet v1 18 beta variant.
+
+ This variant modifies the first convolution layer of ResNet-v1-18. In
+ particular, it changes the original one 7x7 convolution to three 3x3
+ convolutions.
+
+ Args:
+ inputs: A tensor of size [batch, height_in, width_in, channels].
+ num_classes: Number of predicted classes for classification tasks. If None
+ we return the features before the logit layer.
+ is_training: Enable/disable is_training for batch normalization.
+ global_pool: If True, we perform global average pooling before computing the
+ logits. Set to True for image classification, False for dense prediction.
+ output_stride: If None, then the output will be computed at the nominal
+ network stride. If output_stride is not None, it specifies the requested
+ ratio of input to output spatial resolution.
+ multi_grid: Employ a hierarchy of different atrous rates within network.
+ root_depth_multiplier: Float, depth multiplier used for the first three
+ convolution layers that replace the 7x7 convolution.
+ reuse: whether or not the network and its variables should be reused. To be
+ able to reuse 'scope' must be given.
+ scope: Optional variable_scope.
+ sync_batch_norm_method: String, sync batchnorm method.
+
+ Returns:
+ net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
+ If global_pool is False, then height_out and width_out are reduced by a
+ factor of output_stride compared to the respective height_in and width_in,
+ else both height_out and width_out equal one. If num_classes is None, then
+ net is the output of the last ResNet block, potentially after global
+ average pooling. If num_classes is not None, net contains the pre-softmax
+ activations.
+ end_points: A dictionary from components of the network to the corresponding
+ activation.
+
+ Raises:
+    ValueError: if multi_grid is not None and does not have length = 2.
+ """
+ if multi_grid is None:
+ multi_grid = _DEFAULT_MULTI_GRID_RESNET_18
+ else:
+ if len(multi_grid) != 2:
+ raise ValueError('Expect multi_grid to have length 2.')
+
+ block4_args = []
+ for rate in multi_grid:
+ block4_args.append({'depth': 512, 'stride': 1, 'unit_rate': rate})
+
+ blocks = [
+ resnet_v1_small_beta_block(
+ 'block1', base_depth=64, num_units=2, stride=2),
+ resnet_v1_small_beta_block(
+ 'block2', base_depth=128, num_units=2, stride=2),
+ resnet_v1_small_beta_block(
+ 'block3', base_depth=256, num_units=2, stride=2),
+ resnet_utils.Block('block4', lite_bottleneck, block4_args),
+ ]
+ return resnet_v1_beta(
+ inputs,
+ blocks=blocks,
+ num_classes=num_classes,
+ is_training=is_training,
+ global_pool=global_pool,
+ output_stride=output_stride,
+ root_block_fn=functools.partial(root_block_fn_for_beta_variant,
+ depth_multiplier=root_depth_multiplier),
+ reuse=reuse,
+ scope=scope,
+ sync_batch_norm_method=sync_batch_norm_method)
+
+
+def resnet_v1_50(inputs,
+ num_classes=None,
+ is_training=None,
+ global_pool=False,
+ output_stride=None,
+ multi_grid=None,
+ reuse=None,
+ scope='resnet_v1_50',
+ sync_batch_norm_method='None'):
+ """Resnet v1 50.
+
+ Args:
+ inputs: A tensor of size [batch, height_in, width_in, channels].
+ num_classes: Number of predicted classes for classification tasks. If None
+ we return the features before the logit layer.
+ is_training: Enable/disable is_training for batch normalization.
+ global_pool: If True, we perform global average pooling before computing the
+ logits. Set to True for image classification, False for dense prediction.
+ output_stride: If None, then the output will be computed at the nominal
+ network stride. If output_stride is not None, it specifies the requested
+ ratio of input to output spatial resolution.
+ multi_grid: Employ a hierarchy of different atrous rates within network.
+ reuse: whether or not the network and its variables should be reused. To be
+ able to reuse 'scope' must be given.
+ scope: Optional variable_scope.
+ sync_batch_norm_method: String, sync batchnorm method.
+
+ Returns:
+ net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
+ If global_pool is False, then height_out and width_out are reduced by a
+ factor of output_stride compared to the respective height_in and width_in,
+ else both height_out and width_out equal one. If num_classes is None, then
+ net is the output of the last ResNet block, potentially after global
+ average pooling. If num_classes is not None, net contains the pre-softmax
+ activations.
+ end_points: A dictionary from components of the network to the corresponding
+ activation.
+
+ Raises:
+ ValueError: if multi_grid is not None and does not have length = 3.
+ """
+ if multi_grid is None:
+ multi_grid = _DEFAULT_MULTI_GRID
+ else:
+ if len(multi_grid) != 3:
+ raise ValueError('Expect multi_grid to have length 3.')
+
+ blocks = [
+ resnet_v1_beta_block(
+ 'block1', base_depth=64, num_units=3, stride=2),
+ resnet_v1_beta_block(
+ 'block2', base_depth=128, num_units=4, stride=2),
+ resnet_v1_beta_block(
+ 'block3', base_depth=256, num_units=6, stride=2),
+ resnet_utils.Block('block4', bottleneck, [
+ {'depth': 2048, 'depth_bottleneck': 512, 'stride': 1,
+ 'unit_rate': rate} for rate in multi_grid]),
+ ]
+ return resnet_v1_beta(
+ inputs,
+ blocks=blocks,
+ num_classes=num_classes,
+ is_training=is_training,
+ global_pool=global_pool,
+ output_stride=output_stride,
+ reuse=reuse,
+ scope=scope,
+ sync_batch_norm_method=sync_batch_norm_method)
+
+
+def resnet_v1_50_beta(inputs,
+ num_classes=None,
+ is_training=None,
+ global_pool=False,
+ output_stride=None,
+ multi_grid=None,
+ reuse=None,
+ scope='resnet_v1_50',
+ sync_batch_norm_method='None'):
+ """Resnet v1 50 beta variant.
+
+ This variant modifies the first convolution layer of ResNet-v1-50. In
+ particular, it changes the original one 7x7 convolution to three 3x3
+ convolutions.
+
+ Args:
+ inputs: A tensor of size [batch, height_in, width_in, channels].
+ num_classes: Number of predicted classes for classification tasks. If None
+ we return the features before the logit layer.
+ is_training: Enable/disable is_training for batch normalization.
+ global_pool: If True, we perform global average pooling before computing the
+ logits. Set to True for image classification, False for dense prediction.
+ output_stride: If None, then the output will be computed at the nominal
+ network stride. If output_stride is not None, it specifies the requested
+ ratio of input to output spatial resolution.
+ multi_grid: Employ a hierarchy of different atrous rates within network.
+ reuse: whether or not the network and its variables should be reused. To be
+ able to reuse 'scope' must be given.
+ scope: Optional variable_scope.
+ sync_batch_norm_method: String, sync batchnorm method.
+
+ Returns:
+ net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
+ If global_pool is False, then height_out and width_out are reduced by a
+ factor of output_stride compared to the respective height_in and width_in,
+ else both height_out and width_out equal one. If num_classes is None, then
+ net is the output of the last ResNet block, potentially after global
+ average pooling. If num_classes is not None, net contains the pre-softmax
+ activations.
+ end_points: A dictionary from components of the network to the corresponding
+ activation.
+
+ Raises:
+ ValueError: if multi_grid is not None and does not have length = 3.
+ """
+ if multi_grid is None:
+ multi_grid = _DEFAULT_MULTI_GRID
+ else:
+ if len(multi_grid) != 3:
+ raise ValueError('Expect multi_grid to have length 3.')
+
+ blocks = [
+ resnet_v1_beta_block(
+ 'block1', base_depth=64, num_units=3, stride=2),
+ resnet_v1_beta_block(
+ 'block2', base_depth=128, num_units=4, stride=2),
+ resnet_v1_beta_block(
+ 'block3', base_depth=256, num_units=6, stride=2),
+ resnet_utils.Block('block4', bottleneck, [
+ {'depth': 2048, 'depth_bottleneck': 512, 'stride': 1,
+ 'unit_rate': rate} for rate in multi_grid]),
+ ]
+ return resnet_v1_beta(
+ inputs,
+ blocks=blocks,
+ num_classes=num_classes,
+ is_training=is_training,
+ global_pool=global_pool,
+ output_stride=output_stride,
+ root_block_fn=functools.partial(root_block_fn_for_beta_variant),
+ reuse=reuse,
+ scope=scope,
+ sync_batch_norm_method=sync_batch_norm_method)
+
+
+def resnet_v1_101(inputs,
+ num_classes=None,
+ is_training=None,
+ global_pool=False,
+ output_stride=None,
+ multi_grid=None,
+ reuse=None,
+ scope='resnet_v1_101',
+ sync_batch_norm_method='None'):
+ """Resnet v1 101.
+
+ Args:
+ inputs: A tensor of size [batch, height_in, width_in, channels].
+ num_classes: Number of predicted classes for classification tasks. If None
+ we return the features before the logit layer.
+ is_training: Enable/disable is_training for batch normalization.
+ global_pool: If True, we perform global average pooling before computing the
+ logits. Set to True for image classification, False for dense prediction.
+ output_stride: If None, then the output will be computed at the nominal
+ network stride. If output_stride is not None, it specifies the requested
+ ratio of input to output spatial resolution.
+ multi_grid: Employ a hierarchy of different atrous rates within network.
+ reuse: whether or not the network and its variables should be reused. To be
+ able to reuse 'scope' must be given.
+ scope: Optional variable_scope.
+ sync_batch_norm_method: String, sync batchnorm method.
+
+ Returns:
+ net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
+ If global_pool is False, then height_out and width_out are reduced by a
+ factor of output_stride compared to the respective height_in and width_in,
+ else both height_out and width_out equal one. If num_classes is None, then
+ net is the output of the last ResNet block, potentially after global
+ average pooling. If num_classes is not None, net contains the pre-softmax
+ activations.
+ end_points: A dictionary from components of the network to the corresponding
+ activation.
+
+ Raises:
+ ValueError: if multi_grid is not None and does not have length = 3.
+ """
+ if multi_grid is None:
+ multi_grid = _DEFAULT_MULTI_GRID
+ else:
+ if len(multi_grid) != 3:
+ raise ValueError('Expect multi_grid to have length 3.')
+
+ blocks = [
+ resnet_v1_beta_block(
+ 'block1', base_depth=64, num_units=3, stride=2),
+ resnet_v1_beta_block(
+ 'block2', base_depth=128, num_units=4, stride=2),
+ resnet_v1_beta_block(
+ 'block3', base_depth=256, num_units=23, stride=2),
+ resnet_utils.Block('block4', bottleneck, [
+ {'depth': 2048, 'depth_bottleneck': 512, 'stride': 1,
+ 'unit_rate': rate} for rate in multi_grid]),
+ ]
+ return resnet_v1_beta(
+ inputs,
+ blocks=blocks,
+ num_classes=num_classes,
+ is_training=is_training,
+ global_pool=global_pool,
+ output_stride=output_stride,
+ reuse=reuse,
+ scope=scope,
+ sync_batch_norm_method=sync_batch_norm_method)
+
+
+def resnet_v1_101_beta(inputs,
+ num_classes=None,
+ is_training=None,
+ global_pool=False,
+ output_stride=None,
+ multi_grid=None,
+ reuse=None,
+ scope='resnet_v1_101',
+ sync_batch_norm_method='None'):
+ """Resnet v1 101 beta variant.
+
+ This variant modifies the first convolution layer of ResNet-v1-101. In
+ particular, it changes the original one 7x7 convolution to three 3x3
+ convolutions.
+
+ Args:
+ inputs: A tensor of size [batch, height_in, width_in, channels].
+ num_classes: Number of predicted classes for classification tasks. If None
+ we return the features before the logit layer.
+ is_training: Enable/disable is_training for batch normalization.
+ global_pool: If True, we perform global average pooling before computing the
+ logits. Set to True for image classification, False for dense prediction.
+ output_stride: If None, then the output will be computed at the nominal
+ network stride. If output_stride is not None, it specifies the requested
+ ratio of input to output spatial resolution.
+ multi_grid: Employ a hierarchy of different atrous rates within network.
+ reuse: whether or not the network and its variables should be reused. To be
+ able to reuse 'scope' must be given.
+ scope: Optional variable_scope.
+ sync_batch_norm_method: String, sync batchnorm method.
+
+ Returns:
+ net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
+ If global_pool is False, then height_out and width_out are reduced by a
+ factor of output_stride compared to the respective height_in and width_in,
+ else both height_out and width_out equal one. If num_classes is None, then
+ net is the output of the last ResNet block, potentially after global
+ average pooling. If num_classes is not None, net contains the pre-softmax
+ activations.
+ end_points: A dictionary from components of the network to the corresponding
+ activation.
+
+ Raises:
+ ValueError: if multi_grid is not None and does not have length = 3.
+ """
+ if multi_grid is None:
+ multi_grid = _DEFAULT_MULTI_GRID
+ else:
+ if len(multi_grid) != 3:
+ raise ValueError('Expect multi_grid to have length 3.')
+
+ blocks = [
+ resnet_v1_beta_block(
+ 'block1', base_depth=64, num_units=3, stride=2),
+ resnet_v1_beta_block(
+ 'block2', base_depth=128, num_units=4, stride=2),
+ resnet_v1_beta_block(
+ 'block3', base_depth=256, num_units=23, stride=2),
+ resnet_utils.Block('block4', bottleneck, [
+ {'depth': 2048, 'depth_bottleneck': 512, 'stride': 1,
+ 'unit_rate': rate} for rate in multi_grid]),
+ ]
+ return resnet_v1_beta(
+ inputs,
+ blocks=blocks,
+ num_classes=num_classes,
+ is_training=is_training,
+ global_pool=global_pool,
+ output_stride=output_stride,
+ root_block_fn=functools.partial(root_block_fn_for_beta_variant),
+ reuse=reuse,
+ scope=scope,
+ sync_batch_norm_method=sync_batch_norm_method)
+
+
+def resnet_arg_scope(weight_decay=0.0001,
+ batch_norm_decay=0.997,
+ batch_norm_epsilon=1e-5,
+ batch_norm_scale=True,
+ activation_fn=tf.nn.relu,
+ use_batch_norm=True,
+ sync_batch_norm_method='None',
+ normalization_method='unspecified',
+ use_weight_standardization=False):
+ """Defines the default ResNet arg scope.
+
+ Args:
+ weight_decay: The weight decay to use for regularizing the model.
+ batch_norm_decay: The moving average decay when estimating layer activation
+ statistics in batch normalization.
+ batch_norm_epsilon: Small constant to prevent division by zero when
+ normalizing activations by their variance in batch normalization.
+ batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
+ activations in the batch normalization layer.
+ activation_fn: The activation function which is used in ResNet.
+ use_batch_norm: Deprecated in favor of normalization_method.
+ sync_batch_norm_method: String, sync batchnorm method.
+ normalization_method: String, one of `batch`, `none`, or `group`, to use
+ batch normalization, no normalization, or group normalization.
+ use_weight_standardization: Boolean, whether to use weight standardization.
+
+ Returns:
+ An `arg_scope` to use for the resnet models.
+ """
+ batch_norm_params = {
+ 'decay': batch_norm_decay,
+ 'epsilon': batch_norm_epsilon,
+ 'scale': batch_norm_scale,
+ }
+ batch_norm = utils.get_batch_norm_fn(sync_batch_norm_method)
+ if normalization_method == 'batch':
+ normalizer_fn = batch_norm
+ elif normalization_method == 'none':
+ normalizer_fn = None
+ elif normalization_method == 'group':
+ normalizer_fn = slim.group_norm
+ elif normalization_method == 'unspecified':
+ normalizer_fn = batch_norm if use_batch_norm else None
+ else:
+ raise ValueError('Unrecognized normalization_method %s' %
+ normalization_method)
+
+ with slim.arg_scope([conv2d_ws.conv2d],
+ weights_regularizer=slim.l2_regularizer(weight_decay),
+ weights_initializer=slim.variance_scaling_initializer(),
+ activation_fn=activation_fn,
+ normalizer_fn=normalizer_fn,
+ use_weight_standardization=use_weight_standardization):
+ with slim.arg_scope([batch_norm], **batch_norm_params):
+ # The following implies padding='SAME' for pool1, which makes feature
+ # alignment easier for dense prediction tasks. This is also used in
+ # https://github.com/facebook/fb.resnet.torch. However the accompanying
+ # code of 'Deep Residual Learning for Image Recognition' uses
+ # padding='VALID' for pool1. You can switch to that choice by setting
+ # slim.arg_scope([slim.max_pool2d], padding='VALID').
+ with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
+ return arg_sc
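+
+
+# A minimal usage sketch (added for illustration, not part of the original
+# file; the input shape and output_stride below are assumptions):
+#
+#   inputs = tf.placeholder(tf.float32, [1, 513, 513, 3])
+#   with slim.arg_scope(resnet_arg_scope()):
+#     net, end_points = resnet_v1_50(
+#         inputs, num_classes=None, is_training=False,
+#         global_pool=False, output_stride=16)
+#   # net is the final block output, with spatial size roughly 513 / 16
+#   # (33x33 here) and 2048 channels.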
diff --git a/models/research/deeplab/core/resnet_v1_beta_test.py b/models/research/deeplab/core/resnet_v1_beta_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b61edcce21803047ea047a0acb1bc9e7ae147da
--- /dev/null
+++ b/models/research/deeplab/core/resnet_v1_beta_test.py
@@ -0,0 +1,564 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for resnet_v1_beta module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import numpy as np
+import six
+import tensorflow as tf
+from tensorflow.contrib import slim as contrib_slim
+
+from deeplab.core import resnet_v1_beta
+from tensorflow.contrib.slim.nets import resnet_utils
+
+slim = contrib_slim
+
+
+def create_test_input(batch, height, width, channels):
+ """Create test input tensor."""
+ if None in [batch, height, width, channels]:
+ return tf.placeholder(tf.float32, (batch, height, width, channels))
+ else:
+ return tf.to_float(
+ np.tile(
+ np.reshape(
+ np.reshape(np.arange(height), [height, 1]) +
+ np.reshape(np.arange(width), [1, width]),
+ [1, height, width, 1]),
+ [batch, 1, 1, channels]))
+
+
+class ResnetCompleteNetworkTest(tf.test.TestCase):
+ """Tests with complete small ResNet v1 networks."""
+
+ def _resnet_small_lite_bottleneck(self,
+ inputs,
+ num_classes=None,
+ is_training=True,
+ global_pool=True,
+ output_stride=None,
+ multi_grid=None,
+ reuse=None,
+ scope='resnet_v1_small'):
+ """A shallow and thin ResNet v1 with lite_bottleneck."""
+ if multi_grid is None:
+ multi_grid = [1, 1]
+ else:
+ if len(multi_grid) != 2:
+ raise ValueError('Expect multi_grid to have length 2.')
+ block = resnet_v1_beta.resnet_v1_small_beta_block
+ blocks = [
+ block('block1', base_depth=1, num_units=1, stride=2),
+ block('block2', base_depth=2, num_units=1, stride=2),
+ block('block3', base_depth=4, num_units=1, stride=2),
+ resnet_utils.Block('block4', resnet_v1_beta.lite_bottleneck, [
+ {'depth': 8,
+ 'stride': 1,
+ 'unit_rate': rate} for rate in multi_grid])]
+ return resnet_v1_beta.resnet_v1_beta(
+ inputs,
+ blocks,
+ num_classes=num_classes,
+ is_training=is_training,
+ global_pool=global_pool,
+ output_stride=output_stride,
+ root_block_fn=functools.partial(
+ resnet_v1_beta.root_block_fn_for_beta_variant,
+ depth_multiplier=0.25),
+ reuse=reuse,
+ scope=scope)
+
+ def _resnet_small(self,
+ inputs,
+ num_classes=None,
+ is_training=True,
+ global_pool=True,
+ output_stride=None,
+ multi_grid=None,
+ reuse=None,
+ scope='resnet_v1_small'):
+ """A shallow and thin ResNet v1 for faster tests."""
+ if multi_grid is None:
+ multi_grid = [1, 1, 1]
+ else:
+ if len(multi_grid) != 3:
+ raise ValueError('Expect multi_grid to have length 3.')
+
+ block = resnet_v1_beta.resnet_v1_beta_block
+ blocks = [
+ block('block1', base_depth=1, num_units=1, stride=2),
+ block('block2', base_depth=2, num_units=1, stride=2),
+ block('block3', base_depth=4, num_units=1, stride=2),
+ resnet_utils.Block('block4', resnet_v1_beta.bottleneck, [
+ {'depth': 32, 'depth_bottleneck': 8, 'stride': 1,
+ 'unit_rate': rate} for rate in multi_grid])]
+
+ return resnet_v1_beta.resnet_v1_beta(
+ inputs,
+ blocks,
+ num_classes=num_classes,
+ is_training=is_training,
+ global_pool=global_pool,
+ output_stride=output_stride,
+ root_block_fn=functools.partial(
+ resnet_v1_beta.root_block_fn_for_beta_variant),
+ reuse=reuse,
+ scope=scope)
+
+ def testClassificationEndPointsWithLiteBottleneck(self):
+ global_pool = True
+ num_classes = 10
+ inputs = create_test_input(2, 224, 224, 3)
+ with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+ logits, end_points = self._resnet_small_lite_bottleneck(
+ inputs,
+ num_classes,
+ global_pool=global_pool,
+ scope='resnet')
+
+ self.assertTrue(logits.op.name.startswith('resnet/logits'))
+ self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
+ self.assertIn('predictions', end_points)
+ self.assertListEqual(end_points['predictions'].get_shape().as_list(),
+ [2, 1, 1, num_classes])
+
+ def testClassificationEndPointsWithMultigridAndLiteBottleneck(self):
+ global_pool = True
+ num_classes = 10
+ inputs = create_test_input(2, 224, 224, 3)
+ multi_grid = [1, 2]
+ with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+ logits, end_points = self._resnet_small_lite_bottleneck(
+ inputs,
+ num_classes,
+ global_pool=global_pool,
+ multi_grid=multi_grid,
+ scope='resnet')
+
+ self.assertTrue(logits.op.name.startswith('resnet/logits'))
+ self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
+ self.assertIn('predictions', end_points)
+ self.assertListEqual(end_points['predictions'].get_shape().as_list(),
+ [2, 1, 1, num_classes])
+
+ def testClassificationShapesWithLiteBottleneck(self):
+ global_pool = True
+ num_classes = 10
+ inputs = create_test_input(2, 224, 224, 3)
+ with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+ _, end_points = self._resnet_small_lite_bottleneck(
+ inputs,
+ num_classes,
+ global_pool=global_pool,
+ scope='resnet')
+ endpoint_to_shape = {
+ 'resnet/conv1_1': [2, 112, 112, 16],
+ 'resnet/conv1_2': [2, 112, 112, 16],
+ 'resnet/conv1_3': [2, 112, 112, 32],
+ 'resnet/block1': [2, 28, 28, 1],
+ 'resnet/block2': [2, 14, 14, 2],
+ 'resnet/block3': [2, 7, 7, 4],
+ 'resnet/block4': [2, 7, 7, 8]}
+ for endpoint, shape in six.iteritems(endpoint_to_shape):
+ self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
+
+ def testFullyConvolutionalEndpointShapesWithLiteBottleneck(self):
+ global_pool = False
+ num_classes = 10
+ inputs = create_test_input(2, 321, 321, 3)
+ with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+ _, end_points = self._resnet_small_lite_bottleneck(
+ inputs,
+ num_classes,
+ global_pool=global_pool,
+ scope='resnet')
+ endpoint_to_shape = {
+ 'resnet/conv1_1': [2, 161, 161, 16],
+ 'resnet/conv1_2': [2, 161, 161, 16],
+ 'resnet/conv1_3': [2, 161, 161, 32],
+ 'resnet/block1': [2, 41, 41, 1],
+ 'resnet/block2': [2, 21, 21, 2],
+ 'resnet/block3': [2, 11, 11, 4],
+ 'resnet/block4': [2, 11, 11, 8]}
+ for endpoint, shape in six.iteritems(endpoint_to_shape):
+ self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
+
+ def testAtrousFullyConvolutionalEndpointShapesWithLiteBottleneck(self):
+ global_pool = False
+ num_classes = 10
+ output_stride = 8
+ inputs = create_test_input(2, 321, 321, 3)
+ with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+ _, end_points = self._resnet_small_lite_bottleneck(
+ inputs,
+ num_classes,
+ global_pool=global_pool,
+ output_stride=output_stride,
+ scope='resnet')
+ endpoint_to_shape = {
+ 'resnet/conv1_1': [2, 161, 161, 16],
+ 'resnet/conv1_2': [2, 161, 161, 16],
+ 'resnet/conv1_3': [2, 161, 161, 32],
+ 'resnet/block1': [2, 41, 41, 1],
+ 'resnet/block2': [2, 41, 41, 2],
+ 'resnet/block3': [2, 41, 41, 4],
+ 'resnet/block4': [2, 41, 41, 8]}
+ for endpoint, shape in six.iteritems(endpoint_to_shape):
+ self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
+
+ def testAtrousFullyConvolutionalValuesWithLiteBottleneck(self):
+ """Verify dense feature extraction with atrous convolution."""
+ nominal_stride = 32
+ for output_stride in [4, 8, 16, 32, None]:
+ with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+ with tf.Graph().as_default():
+ with self.test_session() as sess:
+ tf.set_random_seed(0)
+ inputs = create_test_input(2, 81, 81, 3)
+ # Dense feature extraction followed by subsampling.
+ output, _ = self._resnet_small_lite_bottleneck(
+ inputs,
+ None,
+ is_training=False,
+ global_pool=False,
+ output_stride=output_stride)
+ if output_stride is None:
+ factor = 1
+ else:
+ factor = nominal_stride // output_stride
+ output = resnet_utils.subsample(output, factor)
+ # Make the two networks use the same weights.
+ tf.get_variable_scope().reuse_variables()
+ # Feature extraction at the nominal network rate.
+ expected, _ = self._resnet_small_lite_bottleneck(
+ inputs,
+ None,
+ is_training=False,
+ global_pool=False)
+ sess.run(tf.global_variables_initializer())
+ self.assertAllClose(output.eval(), expected.eval(),
+ atol=1e-4, rtol=1e-4)
+
+ def testUnknownBatchSizeWithLiteBottleneck(self):
+ batch = 2
+ height, width = 65, 65
+ global_pool = True
+ num_classes = 10
+ inputs = create_test_input(None, height, width, 3)
+ with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+ logits, _ = self._resnet_small_lite_bottleneck(
+ inputs,
+ num_classes,
+ global_pool=global_pool,
+ scope='resnet')
+ self.assertTrue(logits.op.name.startswith('resnet/logits'))
+ self.assertListEqual(logits.get_shape().as_list(),
+ [None, 1, 1, num_classes])
+ images = create_test_input(batch, height, width, 3)
+ with self.test_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ output = sess.run(logits, {inputs: images.eval()})
+ self.assertEqual(output.shape, (batch, 1, 1, num_classes))
+
+ def testFullyConvolutionalUnknownHeightWidthWithLiteBottleneck(self):
+ batch = 2
+ height, width = 65, 65
+ global_pool = False
+ inputs = create_test_input(batch, None, None, 3)
+ with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+ output, _ = self._resnet_small_lite_bottleneck(
+ inputs,
+ None,
+ global_pool=global_pool)
+ self.assertListEqual(output.get_shape().as_list(),
+ [batch, None, None, 8])
+ images = create_test_input(batch, height, width, 3)
+ with self.test_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ output = sess.run(output, {inputs: images.eval()})
+ self.assertEqual(output.shape, (batch, 3, 3, 8))
+
+ def testAtrousFullyConvolutionalUnknownHeightWidthWithLiteBottleneck(self):
+ batch = 2
+ height, width = 65, 65
+ global_pool = False
+ output_stride = 8
+ inputs = create_test_input(batch, None, None, 3)
+ with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+ output, _ = self._resnet_small_lite_bottleneck(
+ inputs,
+ None,
+ global_pool=global_pool,
+ output_stride=output_stride)
+ self.assertListEqual(output.get_shape().as_list(),
+ [batch, None, None, 8])
+ images = create_test_input(batch, height, width, 3)
+ with self.test_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ output = sess.run(output, {inputs: images.eval()})
+ self.assertEqual(output.shape, (batch, 9, 9, 8))
+
+ def testClassificationEndPoints(self):
+ global_pool = True
+ num_classes = 10
+ inputs = create_test_input(2, 224, 224, 3)
+ with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+ logits, end_points = self._resnet_small(inputs,
+ num_classes,
+ global_pool=global_pool,
+ scope='resnet')
+
+ self.assertTrue(logits.op.name.startswith('resnet/logits'))
+ self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
+ self.assertIn('predictions', end_points)
+ self.assertListEqual(end_points['predictions'].get_shape().as_list(),
+ [2, 1, 1, num_classes])
+
+ def testClassificationEndPointsWithWS(self):
+ global_pool = True
+ num_classes = 10
+ inputs = create_test_input(2, 224, 224, 3)
+ with slim.arg_scope(
+ resnet_v1_beta.resnet_arg_scope(use_weight_standardization=True)):
+ logits, end_points = self._resnet_small(
+ inputs, num_classes, global_pool=global_pool, scope='resnet')
+
+ self.assertTrue(logits.op.name.startswith('resnet/logits'))
+ self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
+ self.assertIn('predictions', end_points)
+ self.assertListEqual(end_points['predictions'].get_shape().as_list(),
+ [2, 1, 1, num_classes])
+
+ def testClassificationEndPointsWithGN(self):
+ global_pool = True
+ num_classes = 10
+ inputs = create_test_input(2, 224, 224, 3)
+ with slim.arg_scope(
+ resnet_v1_beta.resnet_arg_scope(normalization_method='group')):
+ with slim.arg_scope([slim.group_norm], groups=1):
+ logits, end_points = self._resnet_small(
+ inputs, num_classes, global_pool=global_pool, scope='resnet')
+
+ self.assertTrue(logits.op.name.startswith('resnet/logits'))
+ self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
+ self.assertIn('predictions', end_points)
+ self.assertListEqual(end_points['predictions'].get_shape().as_list(),
+ [2, 1, 1, num_classes])
+
+ def testInvalidGroupsWithGN(self):
+ global_pool = True
+ num_classes = 10
+ inputs = create_test_input(2, 224, 224, 3)
+ with self.assertRaisesRegexp(ValueError, 'Invalid groups'):
+ with slim.arg_scope(
+ resnet_v1_beta.resnet_arg_scope(normalization_method='group')):
+ with slim.arg_scope([slim.group_norm], groups=32):
+ _, _ = self._resnet_small(
+ inputs, num_classes, global_pool=global_pool, scope='resnet')
+
+ def testClassificationEndPointsWithGNWS(self):
+ global_pool = True
+ num_classes = 10
+ inputs = create_test_input(2, 224, 224, 3)
+ with slim.arg_scope(
+ resnet_v1_beta.resnet_arg_scope(
+ normalization_method='group', use_weight_standardization=True)):
+ with slim.arg_scope([slim.group_norm], groups=1):
+ logits, end_points = self._resnet_small(
+ inputs, num_classes, global_pool=global_pool, scope='resnet')
+
+ self.assertTrue(logits.op.name.startswith('resnet/logits'))
+ self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
+ self.assertIn('predictions', end_points)
+ self.assertListEqual(end_points['predictions'].get_shape().as_list(),
+ [2, 1, 1, num_classes])
+
+ def testClassificationEndPointsWithMultigrid(self):
+ global_pool = True
+ num_classes = 10
+ inputs = create_test_input(2, 224, 224, 3)
+ multi_grid = [1, 2, 4]
+ with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+ logits, end_points = self._resnet_small(inputs,
+ num_classes,
+ global_pool=global_pool,
+ multi_grid=multi_grid,
+ scope='resnet')
+
+ self.assertTrue(logits.op.name.startswith('resnet/logits'))
+ self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
+ self.assertIn('predictions', end_points)
+ self.assertListEqual(end_points['predictions'].get_shape().as_list(),
+ [2, 1, 1, num_classes])
+
+ def testClassificationShapes(self):
+ global_pool = True
+ num_classes = 10
+ inputs = create_test_input(2, 224, 224, 3)
+ with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+ _, end_points = self._resnet_small(inputs,
+ num_classes,
+ global_pool=global_pool,
+ scope='resnet')
+ endpoint_to_shape = {
+ 'resnet/conv1_1': [2, 112, 112, 64],
+ 'resnet/conv1_2': [2, 112, 112, 64],
+ 'resnet/conv1_3': [2, 112, 112, 128],
+ 'resnet/block1': [2, 28, 28, 4],
+ 'resnet/block2': [2, 14, 14, 8],
+ 'resnet/block3': [2, 7, 7, 16],
+ 'resnet/block4': [2, 7, 7, 32]}
+ for endpoint, shape in six.iteritems(endpoint_to_shape):
+ self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
+
+ def testFullyConvolutionalEndpointShapes(self):
+ global_pool = False
+ num_classes = 10
+ inputs = create_test_input(2, 321, 321, 3)
+ with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+ _, end_points = self._resnet_small(inputs,
+ num_classes,
+ global_pool=global_pool,
+ scope='resnet')
+ endpoint_to_shape = {
+ 'resnet/conv1_1': [2, 161, 161, 64],
+ 'resnet/conv1_2': [2, 161, 161, 64],
+ 'resnet/conv1_3': [2, 161, 161, 128],
+ 'resnet/block1': [2, 41, 41, 4],
+ 'resnet/block2': [2, 21, 21, 8],
+ 'resnet/block3': [2, 11, 11, 16],
+ 'resnet/block4': [2, 11, 11, 32]}
+ for endpoint, shape in six.iteritems(endpoint_to_shape):
+ self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
+
+ def testAtrousFullyConvolutionalEndpointShapes(self):
+ global_pool = False
+ num_classes = 10
+ output_stride = 8
+ inputs = create_test_input(2, 321, 321, 3)
+ with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+ _, end_points = self._resnet_small(inputs,
+ num_classes,
+ global_pool=global_pool,
+ output_stride=output_stride,
+ scope='resnet')
+ endpoint_to_shape = {
+ 'resnet/conv1_1': [2, 161, 161, 64],
+ 'resnet/conv1_2': [2, 161, 161, 64],
+ 'resnet/conv1_3': [2, 161, 161, 128],
+ 'resnet/block1': [2, 41, 41, 4],
+ 'resnet/block2': [2, 41, 41, 8],
+ 'resnet/block3': [2, 41, 41, 16],
+ 'resnet/block4': [2, 41, 41, 32]}
+ for endpoint, shape in six.iteritems(endpoint_to_shape):
+ self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
+
+ def testAtrousFullyConvolutionalValues(self):
+ """Verify dense feature extraction with atrous convolution."""
+ nominal_stride = 32
+ for output_stride in [4, 8, 16, 32, None]:
+ with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+ with tf.Graph().as_default():
+ with self.test_session() as sess:
+ tf.set_random_seed(0)
+ inputs = create_test_input(2, 81, 81, 3)
+ # Dense feature extraction followed by subsampling.
+ output, _ = self._resnet_small(inputs,
+ None,
+ is_training=False,
+ global_pool=False,
+ output_stride=output_stride)
+ if output_stride is None:
+ factor = 1
+ else:
+ factor = nominal_stride // output_stride
+ output = resnet_utils.subsample(output, factor)
+ # Make the two networks use the same weights.
+ tf.get_variable_scope().reuse_variables()
+ # Feature extraction at the nominal network rate.
+ expected, _ = self._resnet_small(inputs,
+ None,
+ is_training=False,
+ global_pool=False)
+ sess.run(tf.global_variables_initializer())
+ self.assertAllClose(output.eval(), expected.eval(),
+ atol=1e-4, rtol=1e-4)
+
+ def testUnknownBatchSize(self):
+ batch = 2
+ height, width = 65, 65
+ global_pool = True
+ num_classes = 10
+ inputs = create_test_input(None, height, width, 3)
+ with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+ logits, _ = self._resnet_small(inputs,
+ num_classes,
+ global_pool=global_pool,
+ scope='resnet')
+ self.assertTrue(logits.op.name.startswith('resnet/logits'))
+ self.assertListEqual(logits.get_shape().as_list(),
+ [None, 1, 1, num_classes])
+ images = create_test_input(batch, height, width, 3)
+ with self.test_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ output = sess.run(logits, {inputs: images.eval()})
+ self.assertEqual(output.shape, (batch, 1, 1, num_classes))
+
+ def testFullyConvolutionalUnknownHeightWidth(self):
+ batch = 2
+ height, width = 65, 65
+ global_pool = False
+ inputs = create_test_input(batch, None, None, 3)
+ with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+ output, _ = self._resnet_small(inputs,
+ None,
+ global_pool=global_pool)
+ self.assertListEqual(output.get_shape().as_list(),
+ [batch, None, None, 32])
+ images = create_test_input(batch, height, width, 3)
+ with self.test_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ output = sess.run(output, {inputs: images.eval()})
+ self.assertEqual(output.shape, (batch, 3, 3, 32))
+
+ def testAtrousFullyConvolutionalUnknownHeightWidth(self):
+ batch = 2
+ height, width = 65, 65
+ global_pool = False
+ output_stride = 8
+ inputs = create_test_input(batch, None, None, 3)
+ with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+ output, _ = self._resnet_small(inputs,
+ None,
+ global_pool=global_pool,
+ output_stride=output_stride)
+ self.assertListEqual(output.get_shape().as_list(),
+ [batch, None, None, 32])
+ images = create_test_input(batch, height, width, 3)
+ with self.test_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ output = sess.run(output, {inputs: images.eval()})
+ self.assertEqual(output.shape, (batch, 9, 9, 32))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/deeplab/core/utils.py b/models/research/deeplab/core/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bf3d09ad4647c757da5f9ebb3c2f676e3ccc00c
--- /dev/null
+++ b/models/research/deeplab/core/utils.py
@@ -0,0 +1,214 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""This script contains utility functions."""
+import tensorflow as tf
+from tensorflow.contrib import framework as contrib_framework
+from tensorflow.contrib import slim as contrib_slim
+
+slim = contrib_slim
+
+
+# Quantized version of sigmoid function.
+q_sigmoid = lambda x: tf.nn.relu6(x + 3) * 0.16667
+
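+# Editor's note (illustrative): this is the "hard sigmoid" approximation used,
+# e.g., for MobileNetV3-style gating; q_sigmoid(0.) ~= 0.5 and q_sigmoid(3.) ~= 1.0.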
+
+def resize_bilinear(images, size, output_dtype=tf.float32):
+ """Returns resized images as output_type.
+
+ Args:
+ images: A tensor of size [batch, height_in, width_in, channels].
+ size: A 1-D int32 Tensor of 2 elements: new_height, new_width. The new size
+ for the images.
+ output_dtype: The destination type.
+ Returns:
+ A tensor of size [batch, height_out, width_out, channels] with dtype
+ output_dtype.
+ """
+ images = tf.image.resize_bilinear(images, size, align_corners=True)
+ return tf.cast(images, dtype=output_dtype)
+
+
+def scale_dimension(dim, scale):
+ """Scales the input dimension.
+
+ Args:
+ dim: Input dimension (a scalar or a scalar Tensor).
+ scale: The amount of scaling applied to the input.
+
+ Returns:
+ Scaled dimension.
+ """
+ if isinstance(dim, tf.Tensor):
+ return tf.cast((tf.to_float(dim) - 1.0) * scale + 1.0, dtype=tf.int32)
+ else:
+ return int((float(dim) - 1.0) * scale + 1.0)
+
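+# Editor's sketch (hypothetical helper, not part of the DeepLab API): the
+# corner-aligned rule above is (dim - 1) * scale + 1, matching utils_test.py.
+def _scale_dimension_example():  # pragma: no cover
+  assert scale_dimension(321, 0.5) == 161   # int((321 - 1) * 0.5 + 1)
+  assert scale_dimension(321, 0.75) == 241  # int((321 - 1) * 0.75 + 1)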
+
+def split_separable_conv2d(inputs,
+ filters,
+ kernel_size=3,
+ rate=1,
+ weight_decay=0.00004,
+ depthwise_weights_initializer_stddev=0.33,
+ pointwise_weights_initializer_stddev=0.06,
+ scope=None):
+ """Splits a separable conv2d into depthwise and pointwise conv2d.
+
+ This operation differs from `tf.layers.separable_conv2d` as this operation
+ applies an activation function between the depthwise and pointwise conv2d.
+
+ Args:
+ inputs: Input tensor with shape [batch, height, width, channels].
+ filters: Number of filters in the 1x1 pointwise convolution.
+ kernel_size: A list of length 2: [kernel_height, kernel_width] of
+ the filters. Can be an int if both values are the same.
+ rate: Atrous convolution rate for the depthwise convolution.
+ weight_decay: The weight decay to use for regularizing the model.
+ depthwise_weights_initializer_stddev: The standard deviation of the
+ truncated normal weight initializer for depthwise convolution.
+ pointwise_weights_initializer_stddev: The standard deviation of the
+ truncated normal weight initializer for pointwise convolution.
+ scope: Optional scope for the operation.
+
+ Returns:
+ Computed features after split separable conv2d.
+ """
+ outputs = slim.separable_conv2d(
+ inputs,
+ None,
+ kernel_size=kernel_size,
+ depth_multiplier=1,
+ rate=rate,
+ weights_initializer=tf.truncated_normal_initializer(
+ stddev=depthwise_weights_initializer_stddev),
+ weights_regularizer=None,
+ scope=scope + '_depthwise')
+ return slim.conv2d(
+ outputs,
+ filters,
+ 1,
+ weights_initializer=tf.truncated_normal_initializer(
+ stddev=pointwise_weights_initializer_stddev),
+ weights_regularizer=slim.l2_regularizer(weight_decay),
+ scope=scope + '_pointwise')
+
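+# Editor's sketch (hypothetical, not part of the DeepLab API): a minimal call,
+# e.g. a 3x3 depthwise conv at atrous rate 2 followed by a 1x1 pointwise conv
+# projecting to 256 channels, as one might use in an ASPP-style branch.
+def _split_separable_conv2d_example():  # pragma: no cover
+  features = tf.zeros([1, 65, 65, 32])
+  return split_separable_conv2d(
+      features, filters=256, kernel_size=3, rate=2, scope='example_branch')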
+
+def get_label_weight_mask(labels, ignore_label, num_classes, label_weights=1.0):
+ """Gets the label weight mask.
+
+ Args:
+ labels: A Tensor of labels with the shape of [-1].
+ ignore_label: Integer, label to ignore.
+ num_classes: Integer, the number of semantic classes.
+ label_weights: A float or a list of weights. If it is a float, it means all
+ the labels have the same weight. If it is a list of weights, then each
+ element in the list represents the weight for the label of its index, for
+ example, label_weights = [0.1, 0.5] means the weight for label 0 is 0.1
+ and the weight for label 1 is 0.5.
+
+ Returns:
+ A Tensor of label weights with the same shape as labels. Each element is
+ the weight for the label at the same index in labels, and the element is
+ 0.0 if the label is the one to ignore.
+
+ Raises:
+ ValueError: If label_weights is neither a float nor a list, or if
+ label_weights is a list and its length is not equal to num_classes.
+ """
+ if not isinstance(label_weights, (float, list)):
+ raise ValueError(
+ 'The type of label_weights is invalid, it must be a float or a list.')
+
+ if isinstance(label_weights, list) and len(label_weights) != num_classes:
+ raise ValueError(
+ 'Length of label_weights must be equal to num_classes if it is a list, '
+ 'label_weights: %s, num_classes: %d.' % (label_weights, num_classes))
+
+ not_ignore_mask = tf.not_equal(labels, ignore_label)
+ not_ignore_mask = tf.cast(not_ignore_mask, tf.float32)
+ if isinstance(label_weights, float):
+ return not_ignore_mask * label_weights
+
+ label_weights = tf.constant(label_weights, tf.float32)
+ weight_mask = tf.einsum('...y,y->...',
+ tf.one_hot(labels, num_classes, dtype=tf.float32),
+ label_weights)
+ return tf.multiply(not_ignore_mask, weight_mask)
+
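+# Editor's sketch (hypothetical, not part of the DeepLab API), mirroring the
+# cases in utils_test.py: entries equal to ignore_label are zeroed out, and a
+# list of weights assigns label_weights[i] to label i.
+def _label_weight_mask_example():  # pragma: no cover
+  labels = tf.constant([0, 4, 1, 3, 2])
+  uniform = get_label_weight_mask(
+      labels, ignore_label=4, num_classes=5, label_weights=0.5)
+  # uniform evaluates to [0.5, 0.0, 0.5, 0.5, 0.5].
+  per_class = get_label_weight_mask(
+      labels, ignore_label=4, num_classes=5,
+      label_weights=[0.0, 0.1, 0.2, 0.3, 0.4])
+  # per_class evaluates to [0.0, 0.0, 0.1, 0.3, 0.2].
+  return uniform, per_class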
+
+def get_batch_norm_fn(sync_batch_norm_method):
+ """Gets batch norm function.
+
+ Currently we only support the following methods:
+ - `None` (no sync batch norm). We use slim.batch_norm in this case.
+
+ Args:
+ sync_batch_norm_method: String, method used to sync batch norm.
+
+ Returns:
+ Batchnorm function.
+
+ Raises:
+ ValueError: If sync_batch_norm_method is not supported.
+ """
+ if sync_batch_norm_method == 'None':
+ return slim.batch_norm
+ else:
+ raise ValueError('Unsupported sync_batch_norm_method.')
+
+
+def get_batch_norm_params(decay=0.9997,
+ epsilon=1e-5,
+ center=True,
+ scale=True,
+ is_training=True,
+ sync_batch_norm_method='None',
+ initialize_gamma_as_zeros=False):
+ """Gets batch norm parameters.
+
+ Args:
+ decay: Float, decay for the moving average.
+ epsilon: Float, value added to variance to avoid dividing by zero.
+ center: Boolean. If True, add offset of `beta` to normalized tensor. If
+ False, `beta` is ignored.
+ scale: Boolean. If True, multiply by `gamma`. If False, `gamma` is not used.
+ is_training: Boolean, whether or not the layer is in training mode.
+ sync_batch_norm_method: String, method used to sync batch norm.
+ initialize_gamma_as_zeros: Boolean, initializing `gamma` as zeros or not.
+
+ Returns:
+ A dictionary for batchnorm parameters.
+
+ Raises:
+ ValueError: If sync_batch_norm_method is not supported.
+ """
+ batch_norm_params = {
+ 'is_training': is_training,
+ 'decay': decay,
+ 'epsilon': epsilon,
+ 'scale': scale,
+ 'center': center,
+ }
+ if initialize_gamma_as_zeros:
+ if sync_batch_norm_method == 'None':
+ # Slim-type gamma_initializer.
+ batch_norm_params['param_initializers'] = {
+ 'gamma': tf.zeros_initializer(),
+ }
+ else:
+ raise ValueError('Unsupported sync_batch_norm_method.')
+ return batch_norm_params
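+
+
+# Editor's sketch (hypothetical, not part of the DeepLab API): the two helpers
+# above are typically combined into a slim arg_scope for the backbone.
+def _batch_norm_arg_scope_example(is_training):  # pragma: no cover
+  batch_norm = get_batch_norm_fn('None')  # Resolves to slim.batch_norm.
+  batch_norm_params = get_batch_norm_params(is_training=is_training)
+  return slim.arg_scope([batch_norm], **batch_norm_params)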
diff --git a/models/research/deeplab/core/utils_test.py b/models/research/deeplab/core/utils_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfdb63ef2d3faaa8090867a5382876972b0cff3d
--- /dev/null
+++ b/models/research/deeplab/core/utils_test.py
@@ -0,0 +1,90 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for utils.py."""
+
+import numpy as np
+import tensorflow as tf
+
+from deeplab.core import utils
+
+
+class UtilsTest(tf.test.TestCase):
+
+ def testScaleDimensionOutput(self):
+ self.assertEqual(161, utils.scale_dimension(321, 0.5))
+ self.assertEqual(193, utils.scale_dimension(321, 0.6))
+ self.assertEqual(241, utils.scale_dimension(321, 0.75))
+
+ def testGetLabelWeightMask_withFloatLabelWeights(self):
+ labels = tf.constant([0, 4, 1, 3, 2])
+ ignore_label = 4
+ num_classes = 5
+ label_weights = 0.5
+ expected_label_weight_mask = np.array([0.5, 0.0, 0.5, 0.5, 0.5],
+ dtype=np.float32)
+
+ with self.test_session() as sess:
+ label_weight_mask = utils.get_label_weight_mask(
+ labels, ignore_label, num_classes, label_weights=label_weights)
+ label_weight_mask = sess.run(label_weight_mask)
+ self.assertAllEqual(label_weight_mask, expected_label_weight_mask)
+
+ def testGetLabelWeightMask_withListLabelWeights(self):
+ labels = tf.constant([0, 4, 1, 3, 2])
+ ignore_label = 4
+ num_classes = 5
+ label_weights = [0.0, 0.1, 0.2, 0.3, 0.4]
+ expected_label_weight_mask = np.array([0.0, 0.0, 0.1, 0.3, 0.2],
+ dtype=np.float32)
+
+ with self.test_session() as sess:
+ label_weight_mask = utils.get_label_weight_mask(
+ labels, ignore_label, num_classes, label_weights=label_weights)
+ label_weight_mask = sess.run(label_weight_mask)
+ self.assertAllEqual(label_weight_mask, expected_label_weight_mask)
+
+ def testGetLabelWeightMask_withInvalidLabelWeightsType(self):
+ labels = tf.constant([0, 4, 1, 3, 2])
+ ignore_label = 4
+ num_classes = 5
+
+ self.assertRaisesWithRegexpMatch(
+ ValueError,
+ '^The type of label_weights is invalid, it must be a float or a list',
+ utils.get_label_weight_mask,
+ labels=labels,
+ ignore_label=ignore_label,
+ num_classes=num_classes,
+ label_weights=None)
+
+ def testGetLabelWeightMask_withInvalidLabelWeightsLength(self):
+ labels = tf.constant([0, 4, 1, 3, 2])
+ ignore_label = 4
+ num_classes = 5
+ label_weights = [0.0, 0.1, 0.2]
+
+ self.assertRaisesWithRegexpMatch(
+ ValueError,
+ '^Length of label_weights must be equal to num_classes if it is a list',
+ utils.get_label_weight_mask,
+ labels=labels,
+ ignore_label=ignore_label,
+ num_classes=num_classes,
+ label_weights=label_weights)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/deeplab/core/xception.py b/models/research/deeplab/core/xception.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9925714716ea0346dd8df75b956a876e52bde69
--- /dev/null
+++ b/models/research/deeplab/core/xception.py
@@ -0,0 +1,945 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+r"""Xception model.
+
+"Xception: Deep Learning with Depthwise Separable Convolutions"
+François Chollet
+https://arxiv.org/abs/1610.02357
+
+We implement the modified version by Jifeng Dai et al. for their COCO 2017
+detection challenge submission, where the model is made deeper and has aligned
+features for dense prediction tasks. See their slides for details:
+
+"Deformable Convolutional Networks -- COCO Detection and Segmentation Challenge
+2017 Entry"
+Haozhi Qi, Zheng Zhang, Bin Xiao, Han Hu, Bowen Cheng, Yichen Wei and Jifeng Dai
+ICCV 2017 COCO Challenge workshop
+http://presentations.cocodataset.org/COCO17-Detect-MSRA.pdf
+
+We made a few more changes on top of MSRA's modifications:
+1. Fully convolutional: All the max-pooling layers are replaced with separable
+ conv2d with stride = 2. This allows us to use atrous convolution to extract
+ feature maps at any resolution.
+
+2. We support adding ReLU and BatchNorm after depthwise convolution, motivated
+ by the design of MobileNetv1.
+
+"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision
+Applications"
+Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang,
+Tobias Weyand, Marco Andreetto, Hartwig Adam
+https://arxiv.org/abs/1704.04861
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+from six.moves import range
+import tensorflow as tf
+from tensorflow.contrib import slim as contrib_slim
+
+from deeplab.core import utils
+from tensorflow.contrib.slim.nets import resnet_utils
+from nets.mobilenet import conv_blocks as mobilenet_v3_ops
+
+slim = contrib_slim
+
+
+_DEFAULT_MULTI_GRID = [1, 1, 1]
+# The cap for tf.clip_by_value.
+_CLIP_CAP = 6
+
+
+class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])):
+ """A named tuple describing an Xception block.
+
+ Its parts are:
+ scope: The scope of the block.
+ unit_fn: The Xception unit function which takes as input a tensor and
+ returns another tensor with the output of the Xception unit.
+ args: A list of length equal to the number of units in the block. The list
+ contains one dictionary for each unit in the block to serve as argument to
+ unit_fn.
+ """
+
+
+def fixed_padding(inputs, kernel_size, rate=1):
+ """Pads the input along the spatial dimensions independently of input size.
+
+ Args:
+ inputs: A tensor of size [batch, height_in, width_in, channels].
+ kernel_size: The kernel to be used in the conv2d or max_pool2d operation.
+ Should be a positive integer.
+ rate: An integer, rate for atrous convolution.
+
+ Returns:
+ output: A tensor of size [batch, height_out, width_out, channels] with the
+ input, either intact (if kernel_size == 1) or padded (if kernel_size > 1).
+ """
+ kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
+ pad_total = kernel_size_effective - 1
+ pad_beg = pad_total // 2
+ pad_end = pad_total - pad_beg
+ padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end],
+ [pad_beg, pad_end], [0, 0]])
+ return padded_inputs
+
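+# Editor's sketch (hypothetical, not part of the DeepLab API): with
+# kernel_size=3 and rate=2 the effective kernel size is 3 + (3 - 1) * (2 - 1) = 5,
+# so two pixels are padded on each side regardless of the input size.
+def _fixed_padding_example():  # pragma: no cover
+  images = tf.zeros([1, 65, 65, 3])
+  return fixed_padding(images, kernel_size=3, rate=2)  # Shape [1, 69, 69, 3].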
+
+@slim.add_arg_scope
+def separable_conv2d_same(inputs,
+ num_outputs,
+ kernel_size,
+ depth_multiplier,
+ stride,
+ rate=1,
+ use_explicit_padding=True,
+ regularize_depthwise=False,
+ scope=None,
+ **kwargs):
+ """Strided 2-D separable convolution with 'SAME' padding.
+
+ If stride > 1 and use_explicit_padding is True, then we do explicit zero-
+ padding, followed by conv2d with 'VALID' padding.
+
+ Note that
+
+ net = separable_conv2d_same(inputs, num_outputs, 3,
+ depth_multiplier=1, stride=stride)
+
+ is equivalent to
+
+ net = slim.separable_conv2d(inputs, num_outputs, 3,
+ depth_multiplier=1, stride=1, padding='SAME')
+ net = resnet_utils.subsample(net, factor=stride)
+
+ whereas
+
+ net = slim.separable_conv2d(inputs, num_outputs, 3, stride=stride,
+ depth_multiplier=1, padding='SAME')
+
+ is different when the input's height or width is even, which is why we add the
+ current function.
+
+ Consequently, if the input feature map has even height or width, setting
+ `use_explicit_padding=False` will result in feature misalignment by one pixel
+ along the corresponding dimension.
+
+ Args:
+ inputs: A 4-D tensor of size [batch, height_in, width_in, channels].
+ num_outputs: An integer, the number of output filters.
+ kernel_size: An int with the kernel_size of the filters.
+ depth_multiplier: The number of depthwise convolution output channels for
+ each input channel. The total number of depthwise convolution output
+ channels will be equal to `num_filters_in * depth_multiplier`.
+ stride: An integer, the output stride.
+ rate: An integer, rate for atrous convolution.
+ use_explicit_padding: If True, use explicit padding to make the model fully
+ compatible with the open source version, otherwise use the native
+ Tensorflow 'SAME' padding.
+ regularize_depthwise: Whether or not to apply L2-norm regularization on the
+ depthwise convolution weights.
+ scope: Scope.
+ **kwargs: additional keyword arguments to pass to slim.conv2d
+
+ Returns:
+ output: A 4-D tensor of size [batch, height_out, width_out, channels] with
+ the convolution output.
+ """
+ def _separable_conv2d(padding):
+ """Wrapper for separable conv2d."""
+ return slim.separable_conv2d(inputs,
+ num_outputs,
+ kernel_size,
+ depth_multiplier=depth_multiplier,
+ stride=stride,
+ rate=rate,
+ padding=padding,
+ scope=scope,
+ **kwargs)
+ def _split_separable_conv2d(padding):
+ """Splits separable conv2d into depthwise and pointwise conv2d."""
+ outputs = slim.separable_conv2d(inputs,
+ None,
+ kernel_size,
+ depth_multiplier=depth_multiplier,
+ stride=stride,
+ rate=rate,
+ padding=padding,
+ scope=scope + '_depthwise',
+ **kwargs)
+ return slim.conv2d(outputs,
+ num_outputs,
+ 1,
+ scope=scope + '_pointwise',
+ **kwargs)
+ if stride == 1 or not use_explicit_padding:
+ if regularize_depthwise:
+ outputs = _separable_conv2d(padding='SAME')
+ else:
+ outputs = _split_separable_conv2d(padding='SAME')
+ else:
+ inputs = fixed_padding(inputs, kernel_size, rate)
+ if regularize_depthwise:
+ outputs = _separable_conv2d(padding='VALID')
+ else:
+ outputs = _split_separable_conv2d(padding='VALID')
+ return outputs
+
+
+@slim.add_arg_scope
+def xception_module(inputs,
+ depth_list,
+ skip_connection_type,
+ stride,
+ kernel_size=3,
+ unit_rate_list=None,
+ rate=1,
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=False,
+ outputs_collections=None,
+ scope=None,
+ use_bounded_activation=False,
+ use_explicit_padding=True,
+ use_squeeze_excite=False,
+ se_pool_size=None):
+ """An Xception module.
+
+ The output of one Xception module is equal to the sum of `residual` and
+ `shortcut`, where `residual` is the feature computed by three separable
+ convolution. The `shortcut` is the feature computed by 1x1 convolution with
+ or without striding. In some cases, the `shortcut` path could be a simple
+ identity function or none (i.e., no shortcut).
+
+ Note that we replace the max pooling operations in the Xception module with
+ another separable convolution with striding, since the atrous rate is not
+ properly supported in the current TensorFlow max pooling implementation.
+
+ Args:
+ inputs: A tensor of size [batch, height, width, channels].
+ depth_list: A list of three integers specifying the depth values of one
+ Xception module.
+ skip_connection_type: Skip connection type for the residual path. Only
+ supports 'conv', 'sum', or 'none'.
+ stride: The block unit's stride. Determines the amount of downsampling of
+ the units output compared to its input.
+ kernel_size: Integer, convolution kernel size.
+ unit_rate_list: A list of three integers, determining the unit rate for
+ each separable convolution in the xception module.
+ rate: An integer, rate for atrous convolution.
+ activation_fn_in_separable_conv: Includes activation function in the
+ separable convolution or not.
+ regularize_depthwise: Whether or not to apply L2-norm regularization on the
+ depthwise convolution weights.
+ outputs_collections: Collection to add the Xception unit output.
+ scope: Optional variable_scope.
+ use_bounded_activation: Whether or not to use bounded activations. Bounded
+ activations better lend themselves to quantized inference.
+ use_explicit_padding: If True, use explicit padding to make the model fully
+ compatible with the open source version, otherwise use the native
+ Tensorflow 'SAME' padding.
+ use_squeeze_excite: Boolean, use squeeze-and-excitation or not.
+ se_pool_size: None or integer specifying the pooling size used in SE module.
+
+ Returns:
+ The Xception module's output.
+
+ Raises:
+ ValueError: If depth_list and unit_rate_list do not contain three elements,
+ or if stride != 1 for the third separable convolution operation in the
+ residual path, or unsupported skip connection type.
+ """
+ if len(depth_list) != 3:
+ raise ValueError('Expect three elements in depth_list.')
+ if unit_rate_list:
+ if len(unit_rate_list) != 3:
+ raise ValueError('Expect three elements in unit_rate_list.')
+
+ with tf.variable_scope(scope, 'xception_module', [inputs]) as sc:
+ residual = inputs
+
+ def _separable_conv(features, depth, kernel_size, depth_multiplier,
+ regularize_depthwise, rate, stride, scope):
+ """Separable conv block."""
+ if activation_fn_in_separable_conv:
+ activation_fn = tf.nn.relu6 if use_bounded_activation else tf.nn.relu
+ else:
+ if use_bounded_activation:
+ # When use_bounded_activation is True, we clip the feature values and
+ # apply relu6 for activation.
+ activation_fn = lambda x: tf.clip_by_value(x, -_CLIP_CAP, _CLIP_CAP)
+ features = tf.nn.relu6(features)
+ else:
+ # Original network design.
+ activation_fn = None
+ features = tf.nn.relu(features)
+ return separable_conv2d_same(features,
+ depth,
+ kernel_size,
+ depth_multiplier=depth_multiplier,
+ stride=stride,
+ rate=rate,
+ activation_fn=activation_fn,
+ use_explicit_padding=use_explicit_padding,
+ regularize_depthwise=regularize_depthwise,
+ scope=scope)
+ for i in range(3):
+ residual = _separable_conv(residual,
+ depth_list[i],
+ kernel_size=kernel_size,
+ depth_multiplier=1,
+ regularize_depthwise=regularize_depthwise,
+ rate=rate*unit_rate_list[i],
+ stride=stride if i == 2 else 1,
+ scope='separable_conv' + str(i+1))
+ if use_squeeze_excite:
+ residual = mobilenet_v3_ops.squeeze_excite(
+ input_tensor=residual,
+ squeeze_factor=16,
+ inner_activation_fn=tf.nn.relu,
+ gating_fn=lambda x: tf.nn.relu6(x+3)*0.16667,
+ pool=se_pool_size)
+
+ if skip_connection_type == 'conv':
+ shortcut = slim.conv2d(inputs,
+ depth_list[-1],
+ [1, 1],
+ stride=stride,
+ activation_fn=None,
+ scope='shortcut')
+ if use_bounded_activation:
+ residual = tf.clip_by_value(residual, -_CLIP_CAP, _CLIP_CAP)
+ shortcut = tf.clip_by_value(shortcut, -_CLIP_CAP, _CLIP_CAP)
+ outputs = residual + shortcut
+ if use_bounded_activation:
+ outputs = tf.nn.relu6(outputs)
+ elif skip_connection_type == 'sum':
+ if use_bounded_activation:
+ residual = tf.clip_by_value(residual, -_CLIP_CAP, _CLIP_CAP)
+ inputs = tf.clip_by_value(inputs, -_CLIP_CAP, _CLIP_CAP)
+ outputs = residual + inputs
+ if use_bounded_activation:
+ outputs = tf.nn.relu6(outputs)
+ elif skip_connection_type == 'none':
+ outputs = residual
+ else:
+ raise ValueError('Unsupported skip connection type.')
+
+ return slim.utils.collect_named_outputs(outputs_collections,
+ sc.name,
+ outputs)
+
+
+@slim.add_arg_scope
+def stack_blocks_dense(net,
+ blocks,
+ output_stride=None,
+ outputs_collections=None):
+ """Stacks Xception blocks and controls output feature density.
+
+ First, this function creates scopes for the Xception in the form of
+ 'block_name/unit_1', 'block_name/unit_2', etc.
+
+ Second, this function allows the user to explicitly control the output
+ stride, which is the ratio of the input to output spatial resolution. This
+ is useful for dense prediction tasks such as semantic segmentation or
+ object detection.
+
+ Control of the output feature density is implemented by atrous convolution.
+
+ Args:
+ net: A tensor of size [batch, height, width, channels].
+ blocks: A list of length equal to the number of Xception blocks. Each
+ element is an Xception Block object describing the units in the block.
+ output_stride: If None, then the output will be computed at the nominal
+ network stride. If output_stride is not None, it specifies the requested
+ ratio of input to output spatial resolution, which needs to be equal to
+ the product of unit strides from the start up to some level of Xception.
+ For example, if the Xception employs units with strides 1, 2, 1, 3, 4, 1,
+ then valid values for the output_stride are 1, 2, 6, 24 or None (which
+ is equivalent to output_stride=24).
+ outputs_collections: Collection to add the Xception block outputs.
+
+ Returns:
+ net: Output tensor with stride equal to the specified output_stride.
+
+ Raises:
+ ValueError: If the target output_stride is not valid.
+ """
+ # The current_stride variable keeps track of the effective stride of the
+ # activations. This allows us to invoke atrous convolution whenever applying
+ # the next residual unit would result in the activations having stride larger
+ # than the target output_stride.
+ current_stride = 1
+
+ # The atrous convolution rate parameter.
+ rate = 1
+
+ for block in blocks:
+ with tf.variable_scope(block.scope, 'block', [net]) as sc:
+ for i, unit in enumerate(block.args):
+ if output_stride is not None and current_stride > output_stride:
+ raise ValueError('The target output_stride cannot be reached.')
+ with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
+ # If we have reached the target output_stride, then we need to employ
+ # atrous convolution with stride=1 and multiply the atrous rate by the
+ # current unit's stride for use in subsequent layers.
+ if output_stride is not None and current_stride == output_stride:
+ net = block.unit_fn(net, rate=rate, **dict(unit, stride=1))
+ rate *= unit.get('stride', 1)
+ else:
+ net = block.unit_fn(net, rate=1, **unit)
+ current_stride *= unit.get('stride', 1)
+
+ # Collect activations at the block's end before performing subsampling.
+ net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)
+
+ if output_stride is not None and current_stride != output_stride:
+ raise ValueError('The target output_stride cannot be reached.')
+
+ return net
+
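+# Editor's worked example of the bookkeeping above: with block unit strides
+# 2, 2, 2, 2 and output_stride=8, the first three stride-2 units run normally
+# (current_stride reaches 8); the fourth unit is then rewritten to stride=1 and
+# the atrous rate doubles to 2, so the feature map stays at 1/8 resolution while
+# subsequent units use dilated convolutions.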
+
+def xception(inputs,
+ blocks,
+ num_classes=None,
+ is_training=True,
+ global_pool=True,
+ keep_prob=0.5,
+ output_stride=None,
+ reuse=None,
+ scope=None,
+ sync_batch_norm_method='None'):
+ """Generator for Xception models.
+
+ This function generates a family of Xception models. See the xception_*()
+ methods for specific model instantiations, obtained by selecting different
+ block instantiations that produce Xception of various depths.
+
+ Args:
+ inputs: A tensor of size [batch, height_in, width_in, channels]. Must be
+ floating point. If a pretrained checkpoint is used, pixel values should be
+ the same as during training (see go/slim-classification-models for
+ specifics).
+ blocks: A list of length equal to the number of Xception blocks. Each
+ element is an Xception Block object describing the units in the block.
+ num_classes: Number of predicted classes for classification tasks.
+ If 0 or None, we return the features before the logit layer.
+ is_training: whether batch_norm layers are in training mode.
+ global_pool: If True, we perform global average pooling before computing the
+ logits. Set to True for image classification, False for dense prediction.
+ keep_prob: Keep probability used in the pre-logits dropout layer.
+ output_stride: If None, then the output will be computed at the nominal
+ network stride. If output_stride is not None, it specifies the requested
+ ratio of input to output spatial resolution.
+ reuse: whether or not the network and its variables should be reused. To be
+ able to reuse, 'scope' must be given.
+ scope: Optional variable_scope.
+ sync_batch_norm_method: String, sync batchnorm method. Currently only
+ supports `None`.
+
+ Returns:
+ net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
+ If global_pool is False, then height_out and width_out are reduced by a
+ factor of output_stride compared to the respective height_in and width_in,
+ else both height_out and width_out equal one. If num_classes is 0 or None,
+ then net is the output of the last Xception block, potentially after
+ global average pooling. If num_classes is a non-zero integer, net contains
+ the pre-softmax activations.
+ end_points: A dictionary from components of the network to the corresponding
+ activation.
+
+ Raises:
+ ValueError: If the target output_stride is not valid.
+ """
+ with tf.variable_scope(
+ scope, 'xception', [inputs], reuse=reuse) as sc:
+ end_points_collection = sc.original_name_scope + 'end_points'
+ batch_norm = utils.get_batch_norm_fn(sync_batch_norm_method)
+ with slim.arg_scope([slim.conv2d,
+ slim.separable_conv2d,
+ xception_module,
+ stack_blocks_dense],
+ outputs_collections=end_points_collection):
+ with slim.arg_scope([batch_norm], is_training=is_training):
+ net = inputs
+ if output_stride is not None:
+ if output_stride % 2 != 0:
+ raise ValueError('The output_stride needs to be a multiple of 2.')
+ output_stride //= 2
+ # Root block function operated on inputs.
+ net = resnet_utils.conv2d_same(net, 32, 3, stride=2,
+ scope='entry_flow/conv1_1')
+ net = resnet_utils.conv2d_same(net, 64, 3, stride=1,
+ scope='entry_flow/conv1_2')
+
+ # Extract features for entry_flow, middle_flow, and exit_flow.
+ net = stack_blocks_dense(net, blocks, output_stride)
+
+ # Convert end_points_collection into a dictionary of end_points.
+ end_points = slim.utils.convert_collection_to_dict(
+ end_points_collection, clear_collection=True)
+
+ if global_pool:
+ # Global average pooling.
+ net = tf.reduce_mean(net, [1, 2], name='global_pool', keepdims=True)
+ end_points['global_pool'] = net
+ if num_classes:
+ net = slim.dropout(net, keep_prob=keep_prob, is_training=is_training,
+ scope='prelogits_dropout')
+ net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
+ normalizer_fn=None, scope='logits')
+ end_points[sc.name + '/logits'] = net
+ end_points['predictions'] = slim.softmax(net, scope='predictions')
+ return net, end_points
+
+
+def xception_block(scope,
+ depth_list,
+ skip_connection_type,
+ activation_fn_in_separable_conv,
+ regularize_depthwise,
+ num_units,
+ stride,
+ kernel_size=3,
+ unit_rate_list=None,
+ use_squeeze_excite=False,
+ se_pool_size=None):
+ """Helper function for creating a Xception block.
+
+ Args:
+ scope: The scope of the block.
+ depth_list: The depth of the bottleneck layer for each unit.
+ skip_connection_type: Skip connection type for the residual path. Only
+ supports 'conv', 'sum', or 'none'.
+ activation_fn_in_separable_conv: Includes activation function in the
+ separable convolution or not.
+ regularize_depthwise: Whether or not to apply L2-norm regularization on the
+ depthwise convolution weights.
+ num_units: The number of units in the block.
+ stride: The stride of the block, implemented as a stride in the last unit.
+ All other units have stride=1.
+ kernel_size: Integer, convolution kernel size.
+ unit_rate_list: A list of three integers, determining the unit rate in the
+ corresponding xception block.
+ use_squeeze_excite: Boolean, use squeeze-and-excitation or not.
+ se_pool_size: None or integer specifying the pooling size used in SE module.
+
+ Returns:
+ An Xception block.
+ """
+ if unit_rate_list is None:
+ unit_rate_list = _DEFAULT_MULTI_GRID
+ return Block(scope, xception_module, [{
+ 'depth_list': depth_list,
+ 'skip_connection_type': skip_connection_type,
+ 'activation_fn_in_separable_conv': activation_fn_in_separable_conv,
+ 'regularize_depthwise': regularize_depthwise,
+ 'stride': stride,
+ 'kernel_size': kernel_size,
+ 'unit_rate_list': unit_rate_list,
+ 'use_squeeze_excite': use_squeeze_excite,
+ 'se_pool_size': se_pool_size,
+ }] * num_units)
+
+
+def xception_41(inputs,
+ num_classes=None,
+ is_training=True,
+ global_pool=True,
+ keep_prob=0.5,
+ output_stride=None,
+ regularize_depthwise=False,
+ multi_grid=None,
+ reuse=None,
+ scope='xception_41',
+ sync_batch_norm_method='None'):
+ """Xception-41 model."""
+ blocks = [
+ xception_block('entry_flow/block1',
+ depth_list=[128, 128, 128],
+ skip_connection_type='conv',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=2),
+ xception_block('entry_flow/block2',
+ depth_list=[256, 256, 256],
+ skip_connection_type='conv',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=2),
+ xception_block('entry_flow/block3',
+ depth_list=[728, 728, 728],
+ skip_connection_type='conv',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=2),
+ xception_block('middle_flow/block1',
+ depth_list=[728, 728, 728],
+ skip_connection_type='sum',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=8,
+ stride=1),
+ xception_block('exit_flow/block1',
+ depth_list=[728, 1024, 1024],
+ skip_connection_type='conv',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=2),
+ xception_block('exit_flow/block2',
+ depth_list=[1536, 1536, 2048],
+ skip_connection_type='none',
+ activation_fn_in_separable_conv=True,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=1,
+ unit_rate_list=multi_grid),
+ ]
+ return xception(inputs,
+ blocks=blocks,
+ num_classes=num_classes,
+ is_training=is_training,
+ global_pool=global_pool,
+ keep_prob=keep_prob,
+ output_stride=output_stride,
+ reuse=reuse,
+ scope=scope,
+ sync_batch_norm_method=sync_batch_norm_method)
+
+
+def xception_65_factory(inputs,
+ num_classes=None,
+ is_training=True,
+ global_pool=True,
+ keep_prob=0.5,
+ output_stride=None,
+ regularize_depthwise=False,
+ kernel_size=3,
+ multi_grid=None,
+ reuse=None,
+ use_squeeze_excite=False,
+ se_pool_size=None,
+ scope='xception_65',
+ sync_batch_norm_method='None'):
+ """Xception-65 model factory."""
+ blocks = [
+ xception_block('entry_flow/block1',
+ depth_list=[128, 128, 128],
+ skip_connection_type='conv',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=2,
+ kernel_size=kernel_size,
+ use_squeeze_excite=False,
+ se_pool_size=se_pool_size),
+ xception_block('entry_flow/block2',
+ depth_list=[256, 256, 256],
+ skip_connection_type='conv',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=2,
+ kernel_size=kernel_size,
+ use_squeeze_excite=False,
+ se_pool_size=se_pool_size),
+ xception_block('entry_flow/block3',
+ depth_list=[728, 728, 728],
+ skip_connection_type='conv',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=2,
+ kernel_size=kernel_size,
+ use_squeeze_excite=use_squeeze_excite,
+ se_pool_size=se_pool_size),
+ xception_block('middle_flow/block1',
+ depth_list=[728, 728, 728],
+ skip_connection_type='sum',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=16,
+ stride=1,
+ kernel_size=kernel_size,
+ use_squeeze_excite=use_squeeze_excite,
+ se_pool_size=se_pool_size),
+ xception_block('exit_flow/block1',
+ depth_list=[728, 1024, 1024],
+ skip_connection_type='conv',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=2,
+ kernel_size=kernel_size,
+ use_squeeze_excite=use_squeeze_excite,
+ se_pool_size=se_pool_size),
+ xception_block('exit_flow/block2',
+ depth_list=[1536, 1536, 2048],
+ skip_connection_type='none',
+ activation_fn_in_separable_conv=True,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=1,
+ kernel_size=kernel_size,
+ unit_rate_list=multi_grid,
+ use_squeeze_excite=False,
+ se_pool_size=se_pool_size),
+ ]
+ return xception(inputs,
+ blocks=blocks,
+ num_classes=num_classes,
+ is_training=is_training,
+ global_pool=global_pool,
+ keep_prob=keep_prob,
+ output_stride=output_stride,
+ reuse=reuse,
+ scope=scope,
+ sync_batch_norm_method=sync_batch_norm_method)
+
+
+def xception_65(inputs,
+ num_classes=None,
+ is_training=True,
+ global_pool=True,
+ keep_prob=0.5,
+ output_stride=None,
+ regularize_depthwise=False,
+ multi_grid=None,
+ reuse=None,
+ scope='xception_65',
+ sync_batch_norm_method='None'):
+ """Xception-65 model."""
+ return xception_65_factory(
+ inputs=inputs,
+ num_classes=num_classes,
+ is_training=is_training,
+ global_pool=global_pool,
+ keep_prob=keep_prob,
+ output_stride=output_stride,
+ regularize_depthwise=regularize_depthwise,
+ multi_grid=multi_grid,
+ reuse=reuse,
+ scope=scope,
+ use_squeeze_excite=False,
+ se_pool_size=None,
+ sync_batch_norm_method=sync_batch_norm_method)
+
+
+def xception_71_factory(inputs,
+ num_classes=None,
+ is_training=True,
+ global_pool=True,
+ keep_prob=0.5,
+ output_stride=None,
+ regularize_depthwise=False,
+ kernel_size=3,
+ multi_grid=None,
+ reuse=None,
+ scope='xception_71',
+ use_squeeze_excite=False,
+ se_pool_size=None,
+ sync_batch_norm_method='None'):
+ """Xception-71 model factory."""
+ blocks = [
+ xception_block('entry_flow/block1',
+ depth_list=[128, 128, 128],
+ skip_connection_type='conv',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=2,
+ kernel_size=kernel_size,
+ use_squeeze_excite=False,
+ se_pool_size=se_pool_size),
+ xception_block('entry_flow/block2',
+ depth_list=[256, 256, 256],
+ skip_connection_type='conv',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=1,
+ kernel_size=kernel_size,
+ use_squeeze_excite=False,
+ se_pool_size=se_pool_size),
+ xception_block('entry_flow/block3',
+ depth_list=[256, 256, 256],
+ skip_connection_type='conv',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=2,
+ kernel_size=kernel_size,
+ use_squeeze_excite=False,
+ se_pool_size=se_pool_size),
+ xception_block('entry_flow/block4',
+ depth_list=[728, 728, 728],
+ skip_connection_type='conv',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=1,
+ kernel_size=kernel_size,
+ use_squeeze_excite=use_squeeze_excite,
+ se_pool_size=se_pool_size),
+ xception_block('entry_flow/block5',
+ depth_list=[728, 728, 728],
+ skip_connection_type='conv',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=2,
+ kernel_size=kernel_size,
+ use_squeeze_excite=use_squeeze_excite,
+ se_pool_size=se_pool_size),
+ xception_block('middle_flow/block1',
+ depth_list=[728, 728, 728],
+ skip_connection_type='sum',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=16,
+ stride=1,
+ kernel_size=kernel_size,
+ use_squeeze_excite=use_squeeze_excite,
+ se_pool_size=se_pool_size),
+ xception_block('exit_flow/block1',
+ depth_list=[728, 1024, 1024],
+ skip_connection_type='conv',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=2,
+ kernel_size=kernel_size,
+ use_squeeze_excite=use_squeeze_excite,
+ se_pool_size=se_pool_size),
+ xception_block('exit_flow/block2',
+ depth_list=[1536, 1536, 2048],
+ skip_connection_type='none',
+ activation_fn_in_separable_conv=True,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=1,
+ kernel_size=kernel_size,
+ unit_rate_list=multi_grid,
+ use_squeeze_excite=False,
+ se_pool_size=se_pool_size),
+ ]
+ return xception(inputs,
+ blocks=blocks,
+ num_classes=num_classes,
+ is_training=is_training,
+ global_pool=global_pool,
+ keep_prob=keep_prob,
+ output_stride=output_stride,
+ reuse=reuse,
+ scope=scope,
+ sync_batch_norm_method=sync_batch_norm_method)
+
+
+def xception_71(inputs,
+ num_classes=None,
+ is_training=True,
+ global_pool=True,
+ keep_prob=0.5,
+ output_stride=None,
+ regularize_depthwise=False,
+ multi_grid=None,
+ reuse=None,
+ scope='xception_71',
+ sync_batch_norm_method='None'):
+ """Xception-71 model."""
+ return xception_71_factory(
+ inputs=inputs,
+ num_classes=num_classes,
+ is_training=is_training,
+ global_pool=global_pool,
+ keep_prob=keep_prob,
+ output_stride=output_stride,
+ regularize_depthwise=regularize_depthwise,
+ multi_grid=multi_grid,
+ reuse=reuse,
+ scope=scope,
+ use_squeeze_excite=False,
+ se_pool_size=None,
+ sync_batch_norm_method=sync_batch_norm_method)
+
+
+def xception_arg_scope(weight_decay=0.00004,
+ batch_norm_decay=0.9997,
+ batch_norm_epsilon=0.001,
+ batch_norm_scale=True,
+ weights_initializer_stddev=0.09,
+ regularize_depthwise=False,
+ use_batch_norm=True,
+ use_bounded_activation=False,
+ sync_batch_norm_method='None'):
+ """Defines the default Xception arg scope.
+
+ Args:
+ weight_decay: The weight decay to use for regularizing the model.
+ batch_norm_decay: The moving average decay when estimating layer activation
+ statistics in batch normalization.
+ batch_norm_epsilon: Small constant to prevent division by zero when
+ normalizing activations by their variance in batch normalization.
+ batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
+ activations in the batch normalization layer.
+ weights_initializer_stddev: The standard deviation of the truncated normal
+ weight initializer.
+ regularize_depthwise: Whether or not to apply L2-norm regularization on the
+ depthwise convolution weights.
+ use_batch_norm: Whether or not to use batch normalization.
+ use_bounded_activation: Whether or not to use bounded activations. Bounded
+ activations better lend themselves to quantized inference.
+ sync_batch_norm_method: String, sync batchnorm method. Currently only
+ supports `None`. Also, it is only effective for Xception.
+
+ Returns:
+ An `arg_scope` to use for the Xception models.
+ """
+ batch_norm_params = {
+ 'decay': batch_norm_decay,
+ 'epsilon': batch_norm_epsilon,
+ 'scale': batch_norm_scale,
+ }
+ if regularize_depthwise:
+ depthwise_regularizer = slim.l2_regularizer(weight_decay)
+ else:
+ depthwise_regularizer = None
+ activation_fn = tf.nn.relu6 if use_bounded_activation else tf.nn.relu
+ batch_norm = utils.get_batch_norm_fn(sync_batch_norm_method)
+ with slim.arg_scope(
+ [slim.conv2d, slim.separable_conv2d],
+ weights_initializer=tf.truncated_normal_initializer(
+ stddev=weights_initializer_stddev),
+ activation_fn=activation_fn,
+ normalizer_fn=batch_norm if use_batch_norm else None):
+ with slim.arg_scope([batch_norm], **batch_norm_params):
+ with slim.arg_scope(
+ [slim.conv2d],
+ weights_regularizer=slim.l2_regularizer(weight_decay)):
+ with slim.arg_scope(
+ [slim.separable_conv2d],
+ weights_regularizer=depthwise_regularizer):
+ with slim.arg_scope(
+ [xception_module],
+ use_bounded_activation=use_bounded_activation,
+ use_explicit_padding=not use_bounded_activation) as arg_sc:
+ return arg_sc
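+
+
+# Editor's sketch (hypothetical, not part of the DeepLab API): the intended
+# usage pattern, also exercised in xception_test.py -- build the backbone
+# inside the arg_scope so the conv / batch-norm defaults apply.
+def _xception_arg_scope_example():  # pragma: no cover
+  images = tf.zeros([1, 224, 224, 3])
+  with slim.arg_scope(xception_arg_scope()):
+    return xception_65(images, num_classes=1001, is_training=False)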
diff --git a/models/research/deeplab/core/xception_test.py b/models/research/deeplab/core/xception_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc338daa6e56d1290f5a9330a6728c1f8512881e
--- /dev/null
+++ b/models/research/deeplab/core/xception_test.py
@@ -0,0 +1,488 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for xception.py."""
+import numpy as np
+import six
+import tensorflow as tf
+from tensorflow.contrib import slim as contrib_slim
+
+from deeplab.core import xception
+from tensorflow.contrib.slim.nets import resnet_utils
+
+slim = contrib_slim
+
+
+def create_test_input(batch, height, width, channels):
+ """Create test input tensor."""
+ if None in [batch, height, width, channels]:
+ return tf.placeholder(tf.float32, (batch, height, width, channels))
+ else:
+ return tf.cast(
+ np.tile(
+ np.reshape(
+ np.reshape(np.arange(height), [height, 1]) +
+ np.reshape(np.arange(width), [1, width]),
+ [1, height, width, 1]),
+ [batch, 1, 1, channels]),
+ tf.float32)
+
+
+class UtilityFunctionTest(tf.test.TestCase):
+
+ def testSeparableConv2DSameWithInputEvenSize(self):
+ n, n2 = 4, 2
+
+ # Input image.
+ x = create_test_input(1, n, n, 1)
+
+ # Convolution kernel.
+ dw = create_test_input(1, 3, 3, 1)
+ dw = tf.reshape(dw, [3, 3, 1, 1])
+
+ tf.get_variable('Conv/depthwise_weights', initializer=dw)
+ tf.get_variable('Conv/pointwise_weights',
+ initializer=tf.ones([1, 1, 1, 1]))
+ tf.get_variable('Conv/biases', initializer=tf.zeros([1]))
+ tf.get_variable_scope().reuse_variables()
+
+ y1 = slim.separable_conv2d(x, 1, [3, 3], depth_multiplier=1,
+ stride=1, scope='Conv')
+ y1_expected = tf.cast([[14, 28, 43, 26],
+ [28, 48, 66, 37],
+ [43, 66, 84, 46],
+ [26, 37, 46, 22]], tf.float32)
+ y1_expected = tf.reshape(y1_expected, [1, n, n, 1])
+
+ y2 = resnet_utils.subsample(y1, 2)
+ y2_expected = tf.cast([[14, 43],
+ [43, 84]], tf.float32)
+ y2_expected = tf.reshape(y2_expected, [1, n2, n2, 1])
+
+ y3 = xception.separable_conv2d_same(x, 1, 3, depth_multiplier=1,
+ regularize_depthwise=True,
+ stride=2, scope='Conv')
+ y3_expected = y2_expected
+
+ y4 = slim.separable_conv2d(x, 1, [3, 3], depth_multiplier=1,
+ stride=2, scope='Conv')
+ y4_expected = tf.cast([[48, 37],
+ [37, 22]], tf.float32)
+ y4_expected = tf.reshape(y4_expected, [1, n2, n2, 1])
+
+ with self.test_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ self.assertAllClose(y1.eval(), y1_expected.eval())
+ self.assertAllClose(y2.eval(), y2_expected.eval())
+ self.assertAllClose(y3.eval(), y3_expected.eval())
+ self.assertAllClose(y4.eval(), y4_expected.eval())
+
+ def testSeparableConv2DSameWithInputOddSize(self):
+ n, n2 = 5, 3
+
+ # Input image.
+ x = create_test_input(1, n, n, 1)
+
+ # Convolution kernel.
+ dw = create_test_input(1, 3, 3, 1)
+ dw = tf.reshape(dw, [3, 3, 1, 1])
+
+ tf.get_variable('Conv/depthwise_weights', initializer=dw)
+ tf.get_variable('Conv/pointwise_weights',
+ initializer=tf.ones([1, 1, 1, 1]))
+ tf.get_variable('Conv/biases', initializer=tf.zeros([1]))
+ tf.get_variable_scope().reuse_variables()
+
+ y1 = slim.separable_conv2d(x, 1, [3, 3], depth_multiplier=1,
+ stride=1, scope='Conv')
+ y1_expected = tf.cast([[14, 28, 43, 58, 34],
+ [28, 48, 66, 84, 46],
+ [43, 66, 84, 102, 55],
+ [58, 84, 102, 120, 64],
+ [34, 46, 55, 64, 30]], tf.float32)
+ y1_expected = tf.reshape(y1_expected, [1, n, n, 1])
+
+ y2 = resnet_utils.subsample(y1, 2)
+ y2_expected = tf.cast([[14, 43, 34],
+ [43, 84, 55],
+ [34, 55, 30]], tf.float32)
+ y2_expected = tf.reshape(y2_expected, [1, n2, n2, 1])
+
+ y3 = xception.separable_conv2d_same(x, 1, 3, depth_multiplier=1,
+ regularize_depthwise=True,
+ stride=2, scope='Conv')
+ y3_expected = y2_expected
+
+ y4 = slim.separable_conv2d(x, 1, [3, 3], depth_multiplier=1,
+ stride=2, scope='Conv')
+ y4_expected = y2_expected
+
+ with self.test_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ self.assertAllClose(y1.eval(), y1_expected.eval())
+ self.assertAllClose(y2.eval(), y2_expected.eval())
+ self.assertAllClose(y3.eval(), y3_expected.eval())
+ self.assertAllClose(y4.eval(), y4_expected.eval())
+
+
+class XceptionNetworkTest(tf.test.TestCase):
+ """Tests with small Xception network."""
+
+ def _xception_small(self,
+ inputs,
+ num_classes=None,
+ is_training=True,
+ global_pool=True,
+ output_stride=None,
+ regularize_depthwise=True,
+ reuse=None,
+ scope='xception_small'):
+ """A shallow and thin Xception for faster tests."""
+ block = xception.xception_block
+ blocks = [
+ block('entry_flow/block1',
+ depth_list=[1, 1, 1],
+ skip_connection_type='conv',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=2),
+ block('entry_flow/block2',
+ depth_list=[2, 2, 2],
+ skip_connection_type='conv',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=2),
+ block('entry_flow/block3',
+ depth_list=[4, 4, 4],
+ skip_connection_type='conv',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=1),
+ block('entry_flow/block4',
+ depth_list=[4, 4, 4],
+ skip_connection_type='conv',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=2),
+ block('middle_flow/block1',
+ depth_list=[4, 4, 4],
+ skip_connection_type='sum',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=2,
+ stride=1),
+ block('exit_flow/block1',
+ depth_list=[8, 8, 8],
+ skip_connection_type='conv',
+ activation_fn_in_separable_conv=False,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=2),
+ block('exit_flow/block2',
+ depth_list=[16, 16, 16],
+ skip_connection_type='none',
+ activation_fn_in_separable_conv=True,
+ regularize_depthwise=regularize_depthwise,
+ num_units=1,
+ stride=1),
+ ]
+ return xception.xception(inputs,
+ blocks=blocks,
+ num_classes=num_classes,
+ is_training=is_training,
+ global_pool=global_pool,
+ output_stride=output_stride,
+ reuse=reuse,
+ scope=scope)
+
+ def testClassificationEndPoints(self):
+ global_pool = True
+ num_classes = 3
+ inputs = create_test_input(2, 32, 32, 3)
+ with slim.arg_scope(xception.xception_arg_scope()):
+ logits, end_points = self._xception_small(
+ inputs,
+ num_classes=num_classes,
+ global_pool=global_pool,
+ scope='xception')
+ self.assertTrue(
+ logits.op.name.startswith('xception/logits'))
+ self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
+ self.assertTrue('predictions' in end_points)
+ self.assertListEqual(end_points['predictions'].get_shape().as_list(),
+ [2, 1, 1, num_classes])
+ self.assertTrue('global_pool' in end_points)
+ self.assertListEqual(end_points['global_pool'].get_shape().as_list(),
+ [2, 1, 1, 16])
+
+ def testEndpointNames(self):
+ global_pool = True
+ num_classes = 3
+ inputs = create_test_input(2, 32, 32, 3)
+ with slim.arg_scope(xception.xception_arg_scope()):
+ _, end_points = self._xception_small(
+ inputs,
+ num_classes=num_classes,
+ global_pool=global_pool,
+ scope='xception')
+ expected = [
+ 'xception/entry_flow/conv1_1',
+ 'xception/entry_flow/conv1_2',
+ 'xception/entry_flow/block1/unit_1/xception_module/separable_conv1',
+ 'xception/entry_flow/block1/unit_1/xception_module/separable_conv2',
+ 'xception/entry_flow/block1/unit_1/xception_module/separable_conv3',
+ 'xception/entry_flow/block1/unit_1/xception_module/shortcut',
+ 'xception/entry_flow/block1/unit_1/xception_module',
+ 'xception/entry_flow/block1',
+ 'xception/entry_flow/block2/unit_1/xception_module/separable_conv1',
+ 'xception/entry_flow/block2/unit_1/xception_module/separable_conv2',
+ 'xception/entry_flow/block2/unit_1/xception_module/separable_conv3',
+ 'xception/entry_flow/block2/unit_1/xception_module/shortcut',
+ 'xception/entry_flow/block2/unit_1/xception_module',
+ 'xception/entry_flow/block2',
+ 'xception/entry_flow/block3/unit_1/xception_module/separable_conv1',
+ 'xception/entry_flow/block3/unit_1/xception_module/separable_conv2',
+ 'xception/entry_flow/block3/unit_1/xception_module/separable_conv3',
+ 'xception/entry_flow/block3/unit_1/xception_module/shortcut',
+ 'xception/entry_flow/block3/unit_1/xception_module',
+ 'xception/entry_flow/block3',
+ 'xception/entry_flow/block4/unit_1/xception_module/separable_conv1',
+ 'xception/entry_flow/block4/unit_1/xception_module/separable_conv2',
+ 'xception/entry_flow/block4/unit_1/xception_module/separable_conv3',
+ 'xception/entry_flow/block4/unit_1/xception_module/shortcut',
+ 'xception/entry_flow/block4/unit_1/xception_module',
+ 'xception/entry_flow/block4',
+ 'xception/middle_flow/block1/unit_1/xception_module/separable_conv1',
+ 'xception/middle_flow/block1/unit_1/xception_module/separable_conv2',
+ 'xception/middle_flow/block1/unit_1/xception_module/separable_conv3',
+ 'xception/middle_flow/block1/unit_1/xception_module',
+ 'xception/middle_flow/block1/unit_2/xception_module/separable_conv1',
+ 'xception/middle_flow/block1/unit_2/xception_module/separable_conv2',
+ 'xception/middle_flow/block1/unit_2/xception_module/separable_conv3',
+ 'xception/middle_flow/block1/unit_2/xception_module',
+ 'xception/middle_flow/block1',
+ 'xception/exit_flow/block1/unit_1/xception_module/separable_conv1',
+ 'xception/exit_flow/block1/unit_1/xception_module/separable_conv2',
+ 'xception/exit_flow/block1/unit_1/xception_module/separable_conv3',
+ 'xception/exit_flow/block1/unit_1/xception_module/shortcut',
+ 'xception/exit_flow/block1/unit_1/xception_module',
+ 'xception/exit_flow/block1',
+ 'xception/exit_flow/block2/unit_1/xception_module/separable_conv1',
+ 'xception/exit_flow/block2/unit_1/xception_module/separable_conv2',
+ 'xception/exit_flow/block2/unit_1/xception_module/separable_conv3',
+ 'xception/exit_flow/block2/unit_1/xception_module',
+ 'xception/exit_flow/block2',
+ 'global_pool',
+ 'xception/logits',
+ 'predictions',
+ ]
+ self.assertItemsEqual(list(end_points.keys()), expected)
+
+ def testClassificationShapes(self):
+ global_pool = True
+ num_classes = 3
+ inputs = create_test_input(2, 64, 64, 3)
+ with slim.arg_scope(xception.xception_arg_scope()):
+ _, end_points = self._xception_small(
+ inputs,
+ num_classes,
+ global_pool=global_pool,
+ scope='xception')
+ endpoint_to_shape = {
+ 'xception/entry_flow/conv1_1': [2, 32, 32, 32],
+ 'xception/entry_flow/block1': [2, 16, 16, 1],
+ 'xception/entry_flow/block2': [2, 8, 8, 2],
+ 'xception/entry_flow/block4': [2, 4, 4, 4],
+ 'xception/middle_flow/block1': [2, 4, 4, 4],
+ 'xception/exit_flow/block1': [2, 2, 2, 8],
+ 'xception/exit_flow/block2': [2, 2, 2, 16]}
+ for endpoint, shape in six.iteritems(endpoint_to_shape):
+ self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
+
+ def testFullyConvolutionalEndpointShapes(self):
+ global_pool = False
+ num_classes = 3
+ inputs = create_test_input(2, 65, 65, 3)
+ with slim.arg_scope(xception.xception_arg_scope()):
+ _, end_points = self._xception_small(
+ inputs,
+ num_classes,
+ global_pool=global_pool,
+ scope='xception')
+ endpoint_to_shape = {
+ 'xception/entry_flow/conv1_1': [2, 33, 33, 32],
+ 'xception/entry_flow/block1': [2, 17, 17, 1],
+ 'xception/entry_flow/block2': [2, 9, 9, 2],
+ 'xception/entry_flow/block4': [2, 5, 5, 4],
+ 'xception/middle_flow/block1': [2, 5, 5, 4],
+ 'xception/exit_flow/block1': [2, 3, 3, 8],
+ 'xception/exit_flow/block2': [2, 3, 3, 16]}
+ for endpoint, shape in six.iteritems(endpoint_to_shape):
+ self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
+
+ def testAtrousFullyConvolutionalEndpointShapes(self):
+ global_pool = False
+ num_classes = 3
+ output_stride = 8
+ inputs = create_test_input(2, 65, 65, 3)
+ with slim.arg_scope(xception.xception_arg_scope()):
+ _, end_points = self._xception_small(
+ inputs,
+ num_classes,
+ global_pool=global_pool,
+ output_stride=output_stride,
+ scope='xception')
+ endpoint_to_shape = {
+ 'xception/entry_flow/block1': [2, 17, 17, 1],
+ 'xception/entry_flow/block2': [2, 9, 9, 2],
+ 'xception/entry_flow/block4': [2, 9, 9, 4],
+ 'xception/middle_flow/block1': [2, 9, 9, 4],
+ 'xception/exit_flow/block1': [2, 9, 9, 8],
+ 'xception/exit_flow/block2': [2, 9, 9, 16]}
+ for endpoint, shape in six.iteritems(endpoint_to_shape):
+ self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
+
+ def testAtrousFullyConvolutionalValues(self):
+ """Verify dense feature extraction with atrous convolution."""
+ nominal_stride = 32
+ for output_stride in [4, 8, 16, 32, None]:
+ with slim.arg_scope(xception.xception_arg_scope()):
+ with tf.Graph().as_default():
+ with self.test_session() as sess:
+ tf.set_random_seed(0)
+ inputs = create_test_input(2, 96, 97, 3)
+ # Dense feature extraction followed by subsampling.
+ output, _ = self._xception_small(
+ inputs,
+ None,
+ is_training=False,
+ global_pool=False,
+ output_stride=output_stride)
+ if output_stride is None:
+ factor = 1
+ else:
+ factor = nominal_stride // output_stride
+ output = resnet_utils.subsample(output, factor)
+ # Make the two networks use the same weights.
+ tf.get_variable_scope().reuse_variables()
+ # Feature extraction at the nominal network rate.
+ expected, _ = self._xception_small(
+ inputs,
+ None,
+ is_training=False,
+ global_pool=False)
+ sess.run(tf.global_variables_initializer())
+ self.assertAllClose(output.eval(), expected.eval(),
+ atol=1e-5, rtol=1e-5)
+
+ def testUnknownBatchSize(self):
+ batch = 2
+ height, width = 65, 65
+ global_pool = True
+ num_classes = 10
+ inputs = create_test_input(None, height, width, 3)
+ with slim.arg_scope(xception.xception_arg_scope()):
+ logits, _ = self._xception_small(
+ inputs,
+ num_classes,
+ global_pool=global_pool,
+ scope='xception')
+ self.assertTrue(logits.op.name.startswith('xception/logits'))
+ self.assertListEqual(logits.get_shape().as_list(),
+ [None, 1, 1, num_classes])
+ images = create_test_input(batch, height, width, 3)
+ with self.test_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ output = sess.run(logits, {inputs: images.eval()})
+ self.assertEquals(output.shape, (batch, 1, 1, num_classes))
+
+ def testFullyConvolutionalUnknownHeightWidth(self):
+ batch = 2
+ height, width = 65, 65
+ global_pool = False
+ inputs = create_test_input(batch, None, None, 3)
+ with slim.arg_scope(xception.xception_arg_scope()):
+ output, _ = self._xception_small(
+ inputs,
+ None,
+ global_pool=global_pool)
+ self.assertListEqual(output.get_shape().as_list(),
+ [batch, None, None, 16])
+ images = create_test_input(batch, height, width, 3)
+ with self.test_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ output = sess.run(output, {inputs: images.eval()})
+ self.assertEquals(output.shape, (batch, 3, 3, 16))
+
+ def testAtrousFullyConvolutionalUnknownHeightWidth(self):
+ batch = 2
+ height, width = 65, 65
+ global_pool = False
+ output_stride = 8
+ inputs = create_test_input(batch, None, None, 3)
+ with slim.arg_scope(xception.xception_arg_scope()):
+ output, _ = self._xception_small(
+ inputs,
+ None,
+ global_pool=global_pool,
+ output_stride=output_stride)
+ self.assertListEqual(output.get_shape().as_list(),
+ [batch, None, None, 16])
+ images = create_test_input(batch, height, width, 3)
+ with self.test_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ output = sess.run(output, {inputs: images.eval()})
+ self.assertEquals(output.shape, (batch, 9, 9, 16))
+
+ def testEndpointsReuse(self):
+ inputs = create_test_input(2, 32, 32, 3)
+ with slim.arg_scope(xception.xception_arg_scope()):
+ _, end_points0 = xception.xception_65(
+ inputs,
+ num_classes=10,
+ reuse=False)
+ with slim.arg_scope(xception.xception_arg_scope()):
+ _, end_points1 = xception.xception_65(
+ inputs,
+ num_classes=10,
+ reuse=True)
+ self.assertItemsEqual(list(end_points0.keys()), list(end_points1.keys()))
+
+  def testUseBoundedActivation(self):
+ global_pool = False
+ num_classes = 3
+ output_stride = 16
+ for use_bounded_activation in (True, False):
+ tf.reset_default_graph()
+ inputs = create_test_input(2, 65, 65, 3)
+ with slim.arg_scope(xception.xception_arg_scope(
+ use_bounded_activation=use_bounded_activation)):
+ _, _ = self._xception_small(
+ inputs,
+ num_classes,
+ global_pool=global_pool,
+ output_stride=output_stride,
+ scope='xception')
+ for node in tf.get_default_graph().as_graph_def().node:
+ if node.op.startswith('Relu'):
+ self.assertEqual(node.op == 'Relu6', use_bounded_activation)
+
+if __name__ == '__main__':
+ tf.test.main()
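
The odd-input test above relies on the identity that, for an odd-sized input, a stride-2 convolution with 'SAME' padding equals the stride-1 convolution followed by subsampling every other pixel (resnet_utils.subsample). A minimal 1-D NumPy sketch of that identity follows; it is illustrative only, not part of the committed files, and conv1d_same is a hypothetical helper, not a library function.

import numpy as np

def conv1d_same(x, k, stride):
    # Naive 1-D convolution with TF-style 'SAME' zero padding (hypothetical helper).
    n, ksize = len(x), len(k)
    out_len = -(-n // stride)                        # ceil(n / stride)
    pad_total = max((out_len - 1) * stride + ksize - n, 0)
    pad_begin = pad_total // 2
    xp = np.pad(x, (pad_begin, pad_total - pad_begin))
    return np.array([np.dot(xp[i * stride:i * stride + ksize], k)
                     for i in range(out_len)])

x = np.ones(5)   # odd input size, as in testSeparableConv2DSameWithInputOddSize
k = np.ones(3)
full = conv1d_same(x, k, stride=1)                   # [2. 3. 3. 3. 2.]
print(np.allclose(conv1d_same(x, k, stride=2), full[::2]))   # True
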
diff --git a/models/research/deeplab/datasets/__init__.py b/models/research/deeplab/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/deeplab/datasets/build_ade20k_data.py b/models/research/deeplab/datasets/build_ade20k_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc04ed0db04c83af6deaad8b59087624e8bd40e8
--- /dev/null
+++ b/models/research/deeplab/datasets/build_ade20k_data.py
@@ -0,0 +1,123 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Converts ADE20K data to TFRecord file format with Example protos."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import math
+import os
+import random
+import sys
+import build_data
+from six.moves import range
+import tensorflow as tf
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string(
+ 'train_image_folder',
+ './ADE20K/ADEChallengeData2016/images/training',
+    'Folder containing training images')
+tf.app.flags.DEFINE_string(
+ 'train_image_label_folder',
+ './ADE20K/ADEChallengeData2016/annotations/training',
+    'Folder containing annotations for training images')
+
+tf.app.flags.DEFINE_string(
+ 'val_image_folder',
+ './ADE20K/ADEChallengeData2016/images/validation',
+ 'Folder containing validation images')
+
+tf.app.flags.DEFINE_string(
+ 'val_image_label_folder',
+ './ADE20K/ADEChallengeData2016/annotations/validation',
+ 'Folder containing annotations for validation')
+
+tf.app.flags.DEFINE_string(
+ 'output_dir', './ADE20K/tfrecord',
+    'Path to save converted TFRecord of TensorFlow examples')
+
+_NUM_SHARDS = 4
+
+
+def _convert_dataset(dataset_split, dataset_dir, dataset_label_dir):
+ """Converts the ADE20k dataset into into tfrecord format.
+
+ Args:
+ dataset_split: Dataset split (e.g., train, val).
+    dataset_dir: Directory in which the dataset is located.
+    dataset_label_dir: Directory in which the annotations are located.
+
+ Raises:
+ RuntimeError: If loaded image and label have different shape.
+ """
+
+ img_names = tf.gfile.Glob(os.path.join(dataset_dir, '*.jpg'))
+ random.shuffle(img_names)
+ seg_names = []
+ for f in img_names:
+ # get the filename without the extension
+ basename = os.path.basename(f).split('.')[0]
+    # find its corresponding annotation: same basename with a .png extension
+ seg = os.path.join(dataset_label_dir, basename+'.png')
+ seg_names.append(seg)
+
+ num_images = len(img_names)
+ num_per_shard = int(math.ceil(num_images / _NUM_SHARDS))
+
+ image_reader = build_data.ImageReader('jpeg', channels=3)
+ label_reader = build_data.ImageReader('png', channels=1)
+
+ for shard_id in range(_NUM_SHARDS):
+ output_filename = os.path.join(
+ FLAGS.output_dir,
+ '%s-%05d-of-%05d.tfrecord' % (dataset_split, shard_id, _NUM_SHARDS))
+ with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
+ start_idx = shard_id * num_per_shard
+ end_idx = min((shard_id + 1) * num_per_shard, num_images)
+ for i in range(start_idx, end_idx):
+ sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
+ i + 1, num_images, shard_id))
+ sys.stdout.flush()
+ # Read the image.
+ image_filename = img_names[i]
+ image_data = tf.gfile.FastGFile(image_filename, 'rb').read()
+ height, width = image_reader.read_image_dims(image_data)
+ # Read the semantic segmentation annotation.
+ seg_filename = seg_names[i]
+ seg_data = tf.gfile.FastGFile(seg_filename, 'rb').read()
+ seg_height, seg_width = label_reader.read_image_dims(seg_data)
+ if height != seg_height or width != seg_width:
+ raise RuntimeError('Shape mismatched between image and label.')
+ # Convert to tf example.
+ example = build_data.image_seg_to_tfexample(
+ image_data, img_names[i], height, width, seg_data)
+ tfrecord_writer.write(example.SerializeToString())
+ sys.stdout.write('\n')
+ sys.stdout.flush()
+
+
+def main(unused_argv):
+ tf.gfile.MakeDirs(FLAGS.output_dir)
+ _convert_dataset(
+ 'train', FLAGS.train_image_folder, FLAGS.train_image_label_folder)
+ _convert_dataset('val', FLAGS.val_image_folder, FLAGS.val_image_label_folder)
+
+
+if __name__ == '__main__':
+ tf.app.run()
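
A minimal sketch, not part of the committed files, of how _convert_dataset pairs each image with its annotation: the *.jpg basename is reused with a .png extension under the annotation folder. The paths below are hypothetical examples.

import os

img_names = ['./ADEChallengeData2016/images/training/ADE_train_00000001.jpg']
dataset_label_dir = './ADEChallengeData2016/annotations/training'

seg_names = []
for f in img_names:
    basename = os.path.basename(f).split('.')[0]      # 'ADE_train_00000001'
    seg_names.append(os.path.join(dataset_label_dir, basename + '.png'))

print(seg_names[0])
# ./ADEChallengeData2016/annotations/training/ADE_train_00000001.png
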
diff --git a/models/research/deeplab/datasets/build_cityscapes_data.py b/models/research/deeplab/datasets/build_cityscapes_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce81baef20a460abaa634d3f1dcb6760a0858dec
--- /dev/null
+++ b/models/research/deeplab/datasets/build_cityscapes_data.py
@@ -0,0 +1,188 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Converts Cityscapes data to TFRecord file format with Example protos.
+
+The Cityscapes dataset is expected to have the following directory structure:
+
+ + cityscapes
+    - build_cityscapes_data.py (current working directory).
+ - build_data.py
+ + cityscapesscripts
+ + annotation
+ + evaluation
+ + helpers
+ + preparation
+ + viewer
+ + gtFine
+ + train
+ + val
+ + test
+ + leftImg8bit
+ + train
+ + val
+ + test
+ + tfrecord
+
+This script converts the data into sharded TFRecord files and saves them in the tfrecord folder.
+
+Note that before running this script, the users should (1) register at the
+Cityscapes dataset website, https://www.cityscapes-dataset.com, to
+download the dataset, and (2) run the script provided by Cityscapes,
+`preparation/createTrainIdLabelImgs.py`, to generate the training ground truth.
+
+Also note that the TensorFlow model is trained with `TrainId` labels instead
+of the `EvalId` labels used on the evaluation server. Thus, the users need to
+convert the predicted labels to `EvalId` for evaluation on the server. See
+vis.py for more details.
+
+The Example proto contains the following fields:
+
+ image/encoded: encoded image content.
+ image/filename: image filename.
+ image/format: image file format.
+ image/height: image height.
+ image/width: image width.
+ image/channels: image channels.
+ image/segmentation/class/encoded: encoded semantic segmentation content.
+ image/segmentation/class/format: semantic segmentation file format.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import glob
+import math
+import os.path
+import re
+import sys
+import build_data
+from six.moves import range
+import tensorflow as tf
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string('cityscapes_root',
+ './cityscapes',
+ 'Cityscapes dataset root folder.')
+
+tf.app.flags.DEFINE_string(
+ 'output_dir',
+ './tfrecord',
+ 'Path to save converted SSTable of TensorFlow examples.')
+
+
+_NUM_SHARDS = 10
+
+# A map from data type to folder name that saves the data.
+_FOLDERS_MAP = {
+ 'image': 'leftImg8bit',
+ 'label': 'gtFine',
+}
+
+# A map from data type to filename postfix.
+_POSTFIX_MAP = {
+ 'image': '_leftImg8bit',
+ 'label': '_gtFine_labelTrainIds',
+}
+
+# A map from data type to data format.
+_DATA_FORMAT_MAP = {
+ 'image': 'png',
+ 'label': 'png',
+}
+
+# Image file pattern.
+_IMAGE_FILENAME_RE = re.compile('(.+)' + _POSTFIX_MAP['image'])
+
+
+def _get_files(data, dataset_split):
+ """Gets files for the specified data type and dataset split.
+
+ Args:
+ data: String, desired data ('image' or 'label').
+ dataset_split: String, dataset split ('train', 'val', 'test')
+
+ Returns:
+    A list of sorted file names, or None when getting labels for the
+    test set.
+ """
+ if data == 'label' and dataset_split == 'test':
+ return None
+ pattern = '*%s.%s' % (_POSTFIX_MAP[data], _DATA_FORMAT_MAP[data])
+ search_files = os.path.join(
+ FLAGS.cityscapes_root, _FOLDERS_MAP[data], dataset_split, '*', pattern)
+ filenames = glob.glob(search_files)
+ return sorted(filenames)
+
+
+def _convert_dataset(dataset_split):
+ """Converts the specified dataset split to TFRecord format.
+
+ Args:
+ dataset_split: The dataset split (e.g., train, val).
+
+ Raises:
+ RuntimeError: If loaded image and label have different shape, or if the
+ image file with specified postfix could not be found.
+ """
+ image_files = _get_files('image', dataset_split)
+ label_files = _get_files('label', dataset_split)
+
+ num_images = len(image_files)
+ num_per_shard = int(math.ceil(num_images / _NUM_SHARDS))
+
+ image_reader = build_data.ImageReader('png', channels=3)
+ label_reader = build_data.ImageReader('png', channels=1)
+
+ for shard_id in range(_NUM_SHARDS):
+ shard_filename = '%s-%05d-of-%05d.tfrecord' % (
+ dataset_split, shard_id, _NUM_SHARDS)
+ output_filename = os.path.join(FLAGS.output_dir, shard_filename)
+ with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
+ start_idx = shard_id * num_per_shard
+ end_idx = min((shard_id + 1) * num_per_shard, num_images)
+ for i in range(start_idx, end_idx):
+ sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
+ i + 1, num_images, shard_id))
+ sys.stdout.flush()
+ # Read the image.
+ image_data = tf.gfile.FastGFile(image_files[i], 'rb').read()
+ height, width = image_reader.read_image_dims(image_data)
+ # Read the semantic segmentation annotation.
+ seg_data = tf.gfile.FastGFile(label_files[i], 'rb').read()
+ seg_height, seg_width = label_reader.read_image_dims(seg_data)
+ if height != seg_height or width != seg_width:
+ raise RuntimeError('Shape mismatched between image and label.')
+ # Convert to tf example.
+ re_match = _IMAGE_FILENAME_RE.search(image_files[i])
+ if re_match is None:
+ raise RuntimeError('Invalid image filename: ' + image_files[i])
+ filename = os.path.basename(re_match.group(1))
+ example = build_data.image_seg_to_tfexample(
+ image_data, filename, height, width, seg_data)
+ tfrecord_writer.write(example.SerializeToString())
+ sys.stdout.write('\n')
+ sys.stdout.flush()
+
+
+def main(unused_argv):
+ # Only support converting 'train' and 'val' sets for now.
+ for dataset_split in ['train', 'val']:
+ _convert_dataset(dataset_split)
+
+
+if __name__ == '__main__':
+ tf.app.run()
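
A small sketch, not part of the committed files, of how _IMAGE_FILENAME_RE recovers the record name stored in each tf example. The example path is hypothetical but follows the Cityscapes leftImg8bit naming convention assumed by the script.

import os
import re

_POSTFIX_MAP = {'image': '_leftImg8bit'}
_IMAGE_FILENAME_RE = re.compile('(.+)' + _POSTFIX_MAP['image'])

image_file = './cityscapes/leftImg8bit/train/aachen/aachen_000000_000019_leftImg8bit.png'
re_match = _IMAGE_FILENAME_RE.search(image_file)
filename = os.path.basename(re_match.group(1))
print(filename)   # aachen_000000_000019
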
diff --git a/models/research/deeplab/datasets/build_data.py b/models/research/deeplab/datasets/build_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..45628674dbf3653ca0ca20014a968794bb8cd861
--- /dev/null
+++ b/models/research/deeplab/datasets/build_data.py
@@ -0,0 +1,161 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Contains common utility functions and classes for building dataset.
+
+This script contains utility functions and classes to convert a dataset to
+TFRecord file format with Example protos.
+
+The Example proto contains the following fields:
+
+ image/encoded: encoded image content.
+ image/filename: image filename.
+ image/format: image file format.
+ image/height: image height.
+ image/width: image width.
+ image/channels: image channels.
+ image/segmentation/class/encoded: encoded semantic segmentation content.
+ image/segmentation/class/format: semantic segmentation file format.
+"""
+import collections
+import six
+import tensorflow as tf
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_enum('image_format', 'png', ['jpg', 'jpeg', 'png'],
+ 'Image format.')
+
+tf.app.flags.DEFINE_enum('label_format', 'png', ['png'],
+ 'Segmentation label format.')
+
+# A map from image format to expected data format.
+_IMAGE_FORMAT_MAP = {
+ 'jpg': 'jpeg',
+ 'jpeg': 'jpeg',
+ 'png': 'png',
+}
+
+
+class ImageReader(object):
+ """Helper class that provides TensorFlow image coding utilities."""
+
+ def __init__(self, image_format='jpeg', channels=3):
+ """Class constructor.
+
+ Args:
+ image_format: Image format. Only 'jpeg', 'jpg', or 'png' are supported.
+ channels: Image channels.
+ """
+ with tf.Graph().as_default():
+ self._decode_data = tf.placeholder(dtype=tf.string)
+ self._image_format = image_format
+ self._session = tf.Session()
+ if self._image_format in ('jpeg', 'jpg'):
+ self._decode = tf.image.decode_jpeg(self._decode_data,
+ channels=channels)
+ elif self._image_format == 'png':
+ self._decode = tf.image.decode_png(self._decode_data,
+ channels=channels)
+
+ def read_image_dims(self, image_data):
+ """Reads the image dimensions.
+
+ Args:
+ image_data: string of image data.
+
+ Returns:
+ image_height and image_width.
+ """
+ image = self.decode_image(image_data)
+ return image.shape[:2]
+
+ def decode_image(self, image_data):
+ """Decodes the image data string.
+
+ Args:
+ image_data: string of image data.
+
+ Returns:
+ Decoded image data.
+
+ Raises:
+ ValueError: Value of image channels not supported.
+ """
+ image = self._session.run(self._decode,
+ feed_dict={self._decode_data: image_data})
+ if len(image.shape) != 3 or image.shape[2] not in (1, 3):
+      raise ValueError('The number of image channels is not supported.')
+
+ return image
+
+
+def _int64_list_feature(values):
+ """Returns a TF-Feature of int64_list.
+
+ Args:
+ values: A scalar or list of values.
+
+ Returns:
+ A TF-Feature.
+ """
+ if not isinstance(values, collections.Iterable):
+ values = [values]
+
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
+
+
+def _bytes_list_feature(values):
+ """Returns a TF-Feature of bytes.
+
+ Args:
+ values: A string.
+
+ Returns:
+ A TF-Feature.
+ """
+ def norm2bytes(value):
+ return value.encode() if isinstance(value, str) and six.PY3 else value
+
+ return tf.train.Feature(
+ bytes_list=tf.train.BytesList(value=[norm2bytes(values)]))
+
+
+def image_seg_to_tfexample(image_data, filename, height, width, seg_data):
+ """Converts one image/segmentation pair to tf example.
+
+ Args:
+ image_data: string of image data.
+ filename: image filename.
+ height: image height.
+ width: image width.
+ seg_data: string of semantic segmentation data.
+
+ Returns:
+ tf example of one image/segmentation pair.
+ """
+ return tf.train.Example(features=tf.train.Features(feature={
+ 'image/encoded': _bytes_list_feature(image_data),
+ 'image/filename': _bytes_list_feature(filename),
+ 'image/format': _bytes_list_feature(
+ _IMAGE_FORMAT_MAP[FLAGS.image_format]),
+ 'image/height': _int64_list_feature(height),
+ 'image/width': _int64_list_feature(width),
+ 'image/channels': _int64_list_feature(3),
+ 'image/segmentation/class/encoded': (
+ _bytes_list_feature(seg_data)),
+ 'image/segmentation/class/format': _bytes_list_feature(
+ FLAGS.label_format),
+ }))
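
A minimal sketch, not part of the committed files, of how one serialized record written by image_seg_to_tfexample could be inspected again. It assumes a TFRecord shard produced by one of the build_*_data.py scripts (the path below is hypothetical) and uses the TF1 record iterator.

import tensorflow as tf

record_path = './tfrecord/train-00000-of-00004.tfrecord'  # hypothetical shard
for serialized in tf.python_io.tf_record_iterator(record_path):
    example = tf.train.Example.FromString(serialized)
    feature = example.features.feature
    print('filename:', feature['image/filename'].bytes_list.value[0])
    print('size: %d x %d' % (feature['image/height'].int64_list.value[0],
                             feature['image/width'].int64_list.value[0]))
    image_bytes = feature['image/encoded'].bytes_list.value[0]
    label_bytes = feature['image/segmentation/class/encoded'].bytes_list.value[0]
    break  # inspect only the first record
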
diff --git a/models/research/deeplab/datasets/build_voc2012_data.py b/models/research/deeplab/datasets/build_voc2012_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0bdecb6a0f954d90164ac64b55966d0fe754557
--- /dev/null
+++ b/models/research/deeplab/datasets/build_voc2012_data.py
@@ -0,0 +1,146 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Converts PASCAL VOC 2012 data to TFRecord file format with Example protos.
+
+PASCAL VOC 2012 dataset is expected to have the following directory structure:
+
+ + pascal_voc_seg
+ - build_data.py
+ - build_voc2012_data.py (current working directory).
+ + VOCdevkit
+ + VOC2012
+ + JPEGImages
+ + SegmentationClass
+ + ImageSets
+ + Segmentation
+ + tfrecord
+
+Image folder:
+ ./VOCdevkit/VOC2012/JPEGImages
+
+Semantic segmentation annotations:
+ ./VOCdevkit/VOC2012/SegmentationClass
+
+List folder:
+ ./VOCdevkit/VOC2012/ImageSets/Segmentation
+
+This script converts the data into sharded TFRecord files and saves them in the tfrecord folder.
+
+The Example proto contains the following fields:
+
+ image/encoded: encoded image content.
+ image/filename: image filename.
+ image/format: image file format.
+ image/height: image height.
+ image/width: image width.
+ image/channels: image channels.
+ image/segmentation/class/encoded: encoded semantic segmentation content.
+ image/segmentation/class/format: semantic segmentation file format.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import math
+import os.path
+import sys
+import build_data
+from six.moves import range
+import tensorflow as tf
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string('image_folder',
+ './VOCdevkit/VOC2012/JPEGImages',
+ 'Folder containing images.')
+
+tf.app.flags.DEFINE_string(
+ 'semantic_segmentation_folder',
+ './VOCdevkit/VOC2012/SegmentationClassRaw',
+ 'Folder containing semantic segmentation annotations.')
+
+tf.app.flags.DEFINE_string(
+ 'list_folder',
+ './VOCdevkit/VOC2012/ImageSets/Segmentation',
+ 'Folder containing lists for training and validation')
+
+tf.app.flags.DEFINE_string(
+ 'output_dir',
+ './tfrecord',
+ 'Path to save converted SSTable of TensorFlow examples.')
+
+
+_NUM_SHARDS = 4
+
+
+def _convert_dataset(dataset_split):
+ """Converts the specified dataset split to TFRecord format.
+
+ Args:
+ dataset_split: The dataset split (e.g., train, test).
+
+ Raises:
+ RuntimeError: If loaded image and label have different shape.
+ """
+ dataset = os.path.basename(dataset_split)[:-4]
+ sys.stdout.write('Processing ' + dataset)
+ filenames = [x.strip('\n') for x in open(dataset_split, 'r')]
+ num_images = len(filenames)
+ num_per_shard = int(math.ceil(num_images / _NUM_SHARDS))
+
+ image_reader = build_data.ImageReader('jpeg', channels=3)
+ label_reader = build_data.ImageReader('png', channels=1)
+
+ for shard_id in range(_NUM_SHARDS):
+ output_filename = os.path.join(
+ FLAGS.output_dir,
+ '%s-%05d-of-%05d.tfrecord' % (dataset, shard_id, _NUM_SHARDS))
+ with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
+ start_idx = shard_id * num_per_shard
+ end_idx = min((shard_id + 1) * num_per_shard, num_images)
+ for i in range(start_idx, end_idx):
+ sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
+ i + 1, len(filenames), shard_id))
+ sys.stdout.flush()
+ # Read the image.
+ image_filename = os.path.join(
+ FLAGS.image_folder, filenames[i] + '.' + FLAGS.image_format)
+ image_data = tf.gfile.GFile(image_filename, 'rb').read()
+ height, width = image_reader.read_image_dims(image_data)
+ # Read the semantic segmentation annotation.
+ seg_filename = os.path.join(
+ FLAGS.semantic_segmentation_folder,
+ filenames[i] + '.' + FLAGS.label_format)
+ seg_data = tf.gfile.GFile(seg_filename, 'rb').read()
+ seg_height, seg_width = label_reader.read_image_dims(seg_data)
+ if height != seg_height or width != seg_width:
+ raise RuntimeError('Shape mismatched between image and label.')
+ # Convert to tf example.
+ example = build_data.image_seg_to_tfexample(
+ image_data, filenames[i], height, width, seg_data)
+ tfrecord_writer.write(example.SerializeToString())
+ sys.stdout.write('\n')
+ sys.stdout.flush()
+
+
+def main(unused_argv):
+ dataset_splits = tf.gfile.Glob(os.path.join(FLAGS.list_folder, '*.txt'))
+ for dataset_split in dataset_splits:
+ _convert_dataset(dataset_split)
+
+
+if __name__ == '__main__':
+ tf.app.run()
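
The sharding arithmetic shared by these conversion scripts can be summarized in a few lines: num_per_shard is the ceiling of images per shard, and the last shard simply ends early if the count does not divide evenly. The sketch below is illustrative only, with 1464 standing in for the VOC 2012 train split size.

import math

num_images = 1464          # e.g. the VOC 2012 'train' split size
_NUM_SHARDS = 4
num_per_shard = int(math.ceil(num_images / _NUM_SHARDS))   # 366

for shard_id in range(_NUM_SHARDS):
    start_idx = shard_id * num_per_shard
    end_idx = min((shard_id + 1) * num_per_shard, num_images)
    print('shard %d: images [%d, %d)' % (shard_id, start_idx, end_idx))
# shard 0: images [0, 366) ... shard 3: images [1098, 1464)
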
diff --git a/models/research/deeplab/datasets/convert_cityscapes.sh b/models/research/deeplab/datasets/convert_cityscapes.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a95b5d66aad79ae7cbd6ad2d3ee60550ab7f6239
--- /dev/null
+++ b/models/research/deeplab/datasets/convert_cityscapes.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Script to preprocess the Cityscapes dataset. Note that (1) the users should
+# register at the Cityscapes dataset website,
+# https://www.cityscapes-dataset.com/downloads/, to download the dataset,
+# and (2) the users should download the utility scripts provided by
+# Cityscapes at https://github.com/mcordts/cityscapesScripts.
+#
+# Usage:
+# bash ./convert_cityscapes.sh
+#
+# The folder structure is assumed to be:
+# + datasets
+# - build_cityscapes_data.py
+# - convert_cityscapes.sh
+# + cityscapes
+# + cityscapesscripts (downloaded scripts)
+# + gtFine
+# + leftImg8bit
+#
+
+# Exit immediately if a command exits with a non-zero status.
+set -e
+
+CURRENT_DIR=$(pwd)
+WORK_DIR="."
+
+# Root path for Cityscapes dataset.
+CITYSCAPES_ROOT="${WORK_DIR}/cityscapes"
+
+# Create training labels.
+python "${CITYSCAPES_ROOT}/cityscapesscripts/preparation/createTrainIdLabelImgs.py"
+
+# Build TFRecords of the dataset.
+# First, create output directory for storing TFRecords.
+OUTPUT_DIR="${CITYSCAPES_ROOT}/tfrecord"
+mkdir -p "${OUTPUT_DIR}"
+
+BUILD_SCRIPT="${CURRENT_DIR}/build_cityscapes_data.py"
+
+echo "Converting Cityscapes dataset..."
+python "${BUILD_SCRIPT}" \
+ --cityscapes_root="${CITYSCAPES_ROOT}" \
+ --output_dir="${OUTPUT_DIR}" \
diff --git a/models/research/deeplab/datasets/data_generator.py b/models/research/deeplab/datasets/data_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..d84e66f9c48181d579a027daa08206491d995b65
--- /dev/null
+++ b/models/research/deeplab/datasets/data_generator.py
@@ -0,0 +1,350 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Wrapper for providing semantic segmentaion data.
+
+The SegmentationDataset class provides both images and annotations (semantic
+segmentation and/or instance segmentation) for TensorFlow. Currently, we
+support the following datasets:
+
+1. PASCAL VOC 2012 (http://host.robots.ox.ac.uk/pascal/VOC/voc2012/).
+
+PASCAL VOC 2012 semantic segmentation dataset annotates 20 foreground objects
+(e.g., bike, person, and so on) and leaves all the other semantic classes as
+one background class. The dataset contains 1464, 1449, and 1456 annotated
+images for the training, validation, and test sets, respectively.
+
+2. Cityscapes dataset (https://www.cityscapes-dataset.com)
+
+The Cityscapes dataset contains 19 semantic labels (such as road, person, car,
+and so on) for urban street scenes.
+
+3. ADE20K dataset (http://groups.csail.mit.edu/vision/datasets/ADE20K)
+
+The ADE20K dataset contains 150 semantic labels for both urban street scenes and
+indoor scenes.
+
+References:
+ M. Everingham, S. M. A. Eslami, L. V. Gool, C. K. I. Williams, J. Winn,
+ and A. Zisserman, The pascal visual object classes challenge a retrospective.
+ IJCV, 2014.
+
+ M. Cordts, M. Omran, S. Ramos, T. Rehfeld, M. Enzweiler, R. Benenson,
+ U. Franke, S. Roth, and B. Schiele, "The cityscapes dataset for semantic urban
+ scene understanding," In Proc. of CVPR, 2016.
+
+ B. Zhou, H. Zhao, X. Puig, S. Fidler, A. Barriuso, A. Torralba, "Scene Parsing
+ through ADE20K dataset", In Proc. of CVPR, 2017.
+"""
+
+import collections
+import os
+import tensorflow as tf
+from deeplab import common
+from deeplab import input_preprocess
+
+# Named tuple to describe the dataset properties.
+DatasetDescriptor = collections.namedtuple(
+ 'DatasetDescriptor',
+ [
+ 'splits_to_sizes', # Splits of the dataset into training, val and test.
+ 'num_classes', # Number of semantic classes, including the
+ # background class (if exists). For example, there
+ # are 20 foreground classes + 1 background class in
+ # the PASCAL VOC 2012 dataset. Thus, we set
+ # num_classes=21.
+ 'ignore_label', # Ignore label value.
+ ])
+
+_CITYSCAPES_INFORMATION = DatasetDescriptor(
+ splits_to_sizes={'train_fine': 2975,
+ 'train_coarse': 22973,
+ 'trainval_fine': 3475,
+ 'trainval_coarse': 23473,
+ 'val_fine': 500,
+ 'test_fine': 1525},
+ num_classes=19,
+ ignore_label=255,
+)
+
+_PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor(
+ splits_to_sizes={
+ 'train': 1464,
+ 'train_aug': 10582,
+ 'trainval': 2913,
+ 'val': 1449,
+ },
+ num_classes=21,
+ ignore_label=255,
+)
+
+_ADE20K_INFORMATION = DatasetDescriptor(
+ splits_to_sizes={
+ 'train': 20210, # num of samples in images/training
+ 'val': 2000, # num of samples in images/validation
+ },
+ num_classes=151,
+ ignore_label=0,
+)
+
+_DATASETS_INFORMATION = {
+ 'cityscapes': _CITYSCAPES_INFORMATION,
+ 'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
+ 'ade20k': _ADE20K_INFORMATION,
+}
+
+# Default file pattern of TFRecord of TensorFlow Example.
+_FILE_PATTERN = '%s-*'
+
+
+def get_cityscapes_dataset_name():
+ return 'cityscapes'
+
+
+class Dataset(object):
+ """Represents input dataset for deeplab model."""
+
+ def __init__(self,
+ dataset_name,
+ split_name,
+ dataset_dir,
+ batch_size,
+ crop_size,
+ min_resize_value=None,
+ max_resize_value=None,
+ resize_factor=None,
+ min_scale_factor=1.,
+ max_scale_factor=1.,
+ scale_factor_step_size=0,
+ model_variant=None,
+ num_readers=1,
+ is_training=False,
+ should_shuffle=False,
+ should_repeat=False):
+ """Initializes the dataset.
+
+ Args:
+ dataset_name: Dataset name.
+      split_name: A train/val split name.
+ dataset_dir: The directory of the dataset sources.
+ batch_size: Batch size.
+ crop_size: The size used to crop the image and label.
+ min_resize_value: Desired size of the smaller image side.
+ max_resize_value: Maximum allowed size of the larger image side.
+ resize_factor: Resized dimensions are multiple of factor plus one.
+ min_scale_factor: Minimum scale factor value.
+ max_scale_factor: Maximum scale factor value.
+ scale_factor_step_size: The step size from min scale factor to max scale
+ factor. The input is randomly scaled based on the value of
+ (min_scale_factor, max_scale_factor, scale_factor_step_size).
+ model_variant: Model variant (string) for choosing how to mean-subtract
+ the images. See feature_extractor.network_map for supported model
+ variants.
+ num_readers: Number of readers for data provider.
+ is_training: Boolean, if dataset is for training or not.
+ should_shuffle: Boolean, if should shuffle the input data.
+ should_repeat: Boolean, if should repeat the input data.
+
+ Raises:
+ ValueError: Dataset name and split name are not supported.
+ """
+ if dataset_name not in _DATASETS_INFORMATION:
+ raise ValueError('The specified dataset is not supported yet.')
+ self.dataset_name = dataset_name
+
+ splits_to_sizes = _DATASETS_INFORMATION[dataset_name].splits_to_sizes
+
+ if split_name not in splits_to_sizes:
+ raise ValueError('data split name %s not recognized' % split_name)
+
+ if model_variant is None:
+ tf.logging.warning('Please specify a model_variant. See '
+ 'feature_extractor.network_map for supported model '
+ 'variants.')
+
+ self.split_name = split_name
+ self.dataset_dir = dataset_dir
+ self.batch_size = batch_size
+ self.crop_size = crop_size
+ self.min_resize_value = min_resize_value
+ self.max_resize_value = max_resize_value
+ self.resize_factor = resize_factor
+ self.min_scale_factor = min_scale_factor
+ self.max_scale_factor = max_scale_factor
+ self.scale_factor_step_size = scale_factor_step_size
+ self.model_variant = model_variant
+ self.num_readers = num_readers
+ self.is_training = is_training
+ self.should_shuffle = should_shuffle
+ self.should_repeat = should_repeat
+
+ self.num_of_classes = _DATASETS_INFORMATION[self.dataset_name].num_classes
+ self.ignore_label = _DATASETS_INFORMATION[self.dataset_name].ignore_label
+
+ def _parse_function(self, example_proto):
+ """Function to parse the example proto.
+
+ Args:
+ example_proto: Proto in the format of tf.Example.
+
+ Returns:
+ A dictionary with parsed image, label, height, width and image name.
+
+ Raises:
+ ValueError: Label is of wrong shape.
+ """
+
+ # Currently only supports jpeg and png.
+ # Need to use this logic because the shape is not known for
+ # tf.image.decode_image and we rely on this info to
+ # extend label if necessary.
+ def _decode_image(content, channels):
+ return tf.cond(
+ tf.image.is_jpeg(content),
+ lambda: tf.image.decode_jpeg(content, channels),
+ lambda: tf.image.decode_png(content, channels))
+
+ features = {
+ 'image/encoded':
+ tf.FixedLenFeature((), tf.string, default_value=''),
+ 'image/filename':
+ tf.FixedLenFeature((), tf.string, default_value=''),
+ 'image/format':
+ tf.FixedLenFeature((), tf.string, default_value='jpeg'),
+ 'image/height':
+ tf.FixedLenFeature((), tf.int64, default_value=0),
+ 'image/width':
+ tf.FixedLenFeature((), tf.int64, default_value=0),
+ 'image/segmentation/class/encoded':
+ tf.FixedLenFeature((), tf.string, default_value=''),
+ 'image/segmentation/class/format':
+ tf.FixedLenFeature((), tf.string, default_value='png'),
+ }
+
+ parsed_features = tf.parse_single_example(example_proto, features)
+
+ image = _decode_image(parsed_features['image/encoded'], channels=3)
+
+ label = None
+ if self.split_name != common.TEST_SET:
+ label = _decode_image(
+ parsed_features['image/segmentation/class/encoded'], channels=1)
+
+ image_name = parsed_features['image/filename']
+ if image_name is None:
+ image_name = tf.constant('')
+
+ sample = {
+ common.IMAGE: image,
+ common.IMAGE_NAME: image_name,
+ common.HEIGHT: parsed_features['image/height'],
+ common.WIDTH: parsed_features['image/width'],
+ }
+
+ if label is not None:
+ if label.get_shape().ndims == 2:
+ label = tf.expand_dims(label, 2)
+ elif label.get_shape().ndims == 3 and label.shape.dims[2] == 1:
+ pass
+ else:
+ raise ValueError('Input label shape must be [height, width], or '
+ '[height, width, 1].')
+
+ label.set_shape([None, None, 1])
+
+ sample[common.LABELS_CLASS] = label
+
+ return sample
+
+ def _preprocess_image(self, sample):
+ """Preprocesses the image and label.
+
+ Args:
+ sample: A sample containing image and label.
+
+ Returns:
+ sample: Sample with preprocessed image and label.
+
+ Raises:
+ ValueError: Ground truth label not provided during training.
+ """
+ image = sample[common.IMAGE]
+ label = sample[common.LABELS_CLASS]
+
+ original_image, image, label = input_preprocess.preprocess_image_and_label(
+ image=image,
+ label=label,
+ crop_height=self.crop_size[0],
+ crop_width=self.crop_size[1],
+ min_resize_value=self.min_resize_value,
+ max_resize_value=self.max_resize_value,
+ resize_factor=self.resize_factor,
+ min_scale_factor=self.min_scale_factor,
+ max_scale_factor=self.max_scale_factor,
+ scale_factor_step_size=self.scale_factor_step_size,
+ ignore_label=self.ignore_label,
+ is_training=self.is_training,
+ model_variant=self.model_variant)
+
+ sample[common.IMAGE] = image
+
+ if not self.is_training:
+ # Original image is only used during visualization.
+ sample[common.ORIGINAL_IMAGE] = original_image
+
+ if label is not None:
+ sample[common.LABEL] = label
+
+    # Remove the common.LABELS_CLASS key from the sample since it is only used
+    # to derive the label and is not used in training and evaluation.
+ sample.pop(common.LABELS_CLASS, None)
+
+ return sample
+
+ def get_one_shot_iterator(self):
+ """Gets an iterator that iterates across the dataset once.
+
+ Returns:
+ An iterator of type tf.data.Iterator.
+ """
+
+ files = self._get_all_files()
+
+ dataset = (
+ tf.data.TFRecordDataset(files, num_parallel_reads=self.num_readers)
+ .map(self._parse_function, num_parallel_calls=self.num_readers)
+ .map(self._preprocess_image, num_parallel_calls=self.num_readers))
+
+ if self.should_shuffle:
+ dataset = dataset.shuffle(buffer_size=100)
+
+ if self.should_repeat:
+ dataset = dataset.repeat() # Repeat forever for training.
+ else:
+ dataset = dataset.repeat(1)
+
+ dataset = dataset.batch(self.batch_size).prefetch(self.batch_size)
+ return dataset.make_one_shot_iterator()
+
+ def _get_all_files(self):
+ """Gets all the files to read data from.
+
+ Returns:
+ A list of input files.
+ """
+ file_pattern = _FILE_PATTERN
+ file_pattern = os.path.join(self.dataset_dir,
+ file_pattern % self.split_name)
+ return tf.gfile.Glob(file_pattern)
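
A minimal sketch, not part of the committed files, of how the Dataset class might be driven in TF1 graph mode; dataset_dir and crop_size below are hypothetical values chosen for illustration.

import tensorflow as tf
from deeplab import common
from deeplab.datasets import data_generator

dataset = data_generator.Dataset(
    dataset_name='pascal_voc_seg',
    split_name='val',
    dataset_dir='./pascal_voc_seg/tfrecord',   # hypothetical TFRecord folder
    batch_size=1,
    crop_size=[513, 513],                      # a common DeepLab crop size
    model_variant='mobilenet_v2',
    is_training=False)

batch = dataset.get_one_shot_iterator().get_next()
with tf.Session() as sess:
    sample = sess.run(batch)
    print(sample[common.IMAGE].shape)   # expected (1, 513, 513, 3)
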
diff --git a/models/research/deeplab/datasets/data_generator_test.py b/models/research/deeplab/datasets/data_generator_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4425d01da0c6f3bafaaff7c038498349a5c3f98
--- /dev/null
+++ b/models/research/deeplab/datasets/data_generator_test.py
@@ -0,0 +1,115 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for deeplab.datasets.data_generator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from six.moves import range
+import tensorflow as tf
+
+from deeplab import common
+from deeplab.datasets import data_generator
+
+ImageAttributes = collections.namedtuple(
+ 'ImageAttributes', ['image', 'label', 'height', 'width', 'image_name'])
+
+
+class DatasetTest(tf.test.TestCase):
+
+  # Note: the training dataset cannot be tested deterministically since it
+  # involves a shuffle operation. With shuffling disabled, the training dataset
+  # behaves the same as the validation dataset, so it is not tested separately.
+ def testPascalVocSegTestData(self):
+ dataset = data_generator.Dataset(
+ dataset_name='pascal_voc_seg',
+ split_name='val',
+ dataset_dir=
+ 'deeplab/testing/pascal_voc_seg',
+ batch_size=1,
+ crop_size=[3, 3], # Use small size for testing.
+ min_resize_value=3,
+ max_resize_value=3,
+ resize_factor=None,
+ min_scale_factor=0.01,
+ max_scale_factor=2.0,
+ scale_factor_step_size=0.25,
+ is_training=False,
+ model_variant='mobilenet_v2')
+
+ self.assertAllEqual(dataset.num_of_classes, 21)
+ self.assertAllEqual(dataset.ignore_label, 255)
+
+ num_of_images = 3
+ with self.test_session() as sess:
+ iterator = dataset.get_one_shot_iterator()
+
+ for i in range(num_of_images):
+ batch = iterator.get_next()
+ batch, = sess.run([batch])
+ image_attributes = _get_attributes_of_image(i)
+ self.assertEqual(batch[common.HEIGHT][0], image_attributes.height)
+ self.assertEqual(batch[common.WIDTH][0], image_attributes.width)
+ self.assertEqual(batch[common.IMAGE_NAME][0],
+ image_attributes.image_name.encode())
+
+ # All data have been read.
+ with self.assertRaisesRegexp(tf.errors.OutOfRangeError, ''):
+ sess.run([iterator.get_next()])
+
+
+def _get_attributes_of_image(index):
+ """Gets the attributes of the image.
+
+ Args:
+ index: Index of image in all images.
+
+ Returns:
+ Attributes of the image in the format of ImageAttributes.
+
+ Raises:
+ ValueError: If index is of wrong value.
+ """
+ if index == 0:
+ return ImageAttributes(
+ image=None,
+ label=None,
+ height=366,
+ width=500,
+ image_name='2007_000033')
+ elif index == 1:
+ return ImageAttributes(
+ image=None,
+ label=None,
+ height=335,
+ width=500,
+ image_name='2007_000042')
+ elif index == 2:
+ return ImageAttributes(
+ image=None,
+ label=None,
+ height=333,
+ width=500,
+ image_name='2007_000061')
+ else:
+ raise ValueError('Index can only be 0, 1 or 2.')
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/deeplab/datasets/download_and_convert_ade20k.sh b/models/research/deeplab/datasets/download_and_convert_ade20k.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3614ae42c16e4f727a725066be8948b666995241
--- /dev/null
+++ b/models/research/deeplab/datasets/download_and_convert_ade20k.sh
@@ -0,0 +1,80 @@
+#!/bin/bash
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Script to download and preprocess the ADE20K dataset.
+#
+# Usage:
+# bash ./download_and_convert_ade20k.sh
+#
+# The folder structure is assumed to be:
+# + datasets
+# - build_data.py
+# - build_ade20k_data.py
+# - download_and_convert_ade20k.sh
+# + ADE20K
+# + tfrecord
+# + ADEChallengeData2016
+# + annotations
+# + training
+# + validation
+# + images
+# + training
+# + validation
+
+# Exit immediately if a command exits with a non-zero status.
+set -e
+
+CURRENT_DIR=$(pwd)
+WORK_DIR="./ADE20K"
+mkdir -p "${WORK_DIR}"
+cd "${WORK_DIR}"
+
+# Helper function to download and unpack ADE20K dataset.
+download_and_uncompress() {
+ local BASE_URL=${1}
+ local FILENAME=${2}
+
+ if [ ! -f "${FILENAME}" ]; then
+ echo "Downloading ${FILENAME} to ${WORK_DIR}"
+ wget -nd -c "${BASE_URL}/${FILENAME}"
+ fi
+ echo "Uncompressing ${FILENAME}"
+ unzip "${FILENAME}"
+}
+
+# Download the images.
+BASE_URL="http://data.csail.mit.edu/places/ADEchallenge"
+FILENAME="ADEChallengeData2016.zip"
+
+download_and_uncompress "${BASE_URL}" "${FILENAME}"
+
+cd "${CURRENT_DIR}"
+
+# Root path for ADE20K dataset.
+ADE20K_ROOT="${WORK_DIR}/ADEChallengeData2016"
+
+# Build TFRecords of the dataset.
+# First, create output directory for storing TFRecords.
+OUTPUT_DIR="${WORK_DIR}/tfrecord"
+mkdir -p "${OUTPUT_DIR}"
+
+echo "Converting ADE20K dataset..."
+python ./build_ade20k_data.py \
+ --train_image_folder="${ADE20K_ROOT}/images/training/" \
+ --train_image_label_folder="${ADE20K_ROOT}/annotations/training/" \
+ --val_image_folder="${ADE20K_ROOT}/images/validation/" \
+ --val_image_label_folder="${ADE20K_ROOT}/annotations/validation/" \
+ --output_dir="${OUTPUT_DIR}"
diff --git a/models/research/deeplab/datasets/download_and_convert_voc2012.sh b/models/research/deeplab/datasets/download_and_convert_voc2012.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c02235182d427dfb1d63154a8266ad37b0a1d53f
--- /dev/null
+++ b/models/research/deeplab/datasets/download_and_convert_voc2012.sh
@@ -0,0 +1,91 @@
+#!/bin/bash
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Script to download and preprocess the PASCAL VOC 2012 dataset.
+#
+# Usage:
+# bash ./download_and_convert_voc2012.sh
+#
+# The folder structure is assumed to be:
+# + datasets
+# - build_data.py
+# - build_voc2012_data.py
+# - download_and_convert_voc2012.sh
+# - remove_gt_colormap.py
+# + pascal_voc_seg
+# + VOCdevkit
+# + VOC2012
+# + JPEGImages
+# + SegmentationClass
+#
+
+# Exit immediately if a command exits with a non-zero status.
+set -e
+
+CURRENT_DIR=$(pwd)
+WORK_DIR="./pascal_voc_seg"
+SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
+mkdir -p "${WORK_DIR}"
+cd "${WORK_DIR}"
+
+# Helper function to download and unpack VOC 2012 dataset.
+download_and_uncompress() {
+ local BASE_URL=${1}
+ local FILENAME=${2}
+
+ if [ ! -f "${FILENAME}" ]; then
+ echo "Downloading ${FILENAME} to ${WORK_DIR}"
+ wget -nd -c "${BASE_URL}/${FILENAME}"
+ fi
+ echo "Uncompressing ${FILENAME}"
+ tar -xf "${FILENAME}"
+}
+
+# Download the images.
+BASE_URL="http://host.robots.ox.ac.uk/pascal/VOC/voc2012/"
+FILENAME="VOCtrainval_11-May-2012.tar"
+
+download_and_uncompress "${BASE_URL}" "${FILENAME}"
+
+cd "${CURRENT_DIR}"
+
+# Root path for PASCAL VOC 2012 dataset.
+PASCAL_ROOT="${WORK_DIR}/VOCdevkit/VOC2012"
+
+# Remove the colormap in the ground truth annotations.
+SEG_FOLDER="${PASCAL_ROOT}/SegmentationClass"
+SEMANTIC_SEG_FOLDER="${PASCAL_ROOT}/SegmentationClassRaw"
+
+echo "Removing the color map in ground truth annotations..."
+python3 "${SCRIPT_DIR}/remove_gt_colormap.py" \
+ --original_gt_folder="${SEG_FOLDER}" \
+ --output_dir="${SEMANTIC_SEG_FOLDER}"
+
+# Build TFRecords of the dataset.
+# First, create output directory for storing TFRecords.
+OUTPUT_DIR="${WORK_DIR}/tfrecord"
+mkdir -p "${OUTPUT_DIR}"
+
+IMAGE_FOLDER="${PASCAL_ROOT}/JPEGImages"
+LIST_FOLDER="${PASCAL_ROOT}/ImageSets/Segmentation"
+
+echo "Converting PASCAL VOC 2012 dataset..."
+python3 "${SCRIPT_DIR}/build_voc2012_data.py" \
+ --image_folder="${IMAGE_FOLDER}" \
+ --semantic_segmentation_folder="${SEMANTIC_SEG_FOLDER}" \
+ --list_folder="${LIST_FOLDER}" \
+ --image_format="jpg" \
+ --output_dir="${OUTPUT_DIR}"
diff --git a/models/research/deeplab/datasets/remove_gt_colormap.py b/models/research/deeplab/datasets/remove_gt_colormap.py
new file mode 100644
index 0000000000000000000000000000000000000000..900570038ed0f1add9d670157494d4cab6bf5324
--- /dev/null
+++ b/models/research/deeplab/datasets/remove_gt_colormap.py
@@ -0,0 +1,83 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Removes the color map from segmentation annotations.
+
+Removes the color map from the ground truth segmentation annotations and saves
+the results to output_dir.
+"""
+import glob
+import os.path
+import numpy as np
+
+from PIL import Image
+
+import tensorflow as tf
+
+FLAGS = tf.compat.v1.flags.FLAGS
+
+tf.compat.v1.flags.DEFINE_string('original_gt_folder',
+ './VOCdevkit/VOC2012/SegmentationClass',
+ 'Original ground truth annotations.')
+
+tf.compat.v1.flags.DEFINE_string('segmentation_format', 'png', 'Segmentation format.')
+
+tf.compat.v1.flags.DEFINE_string('output_dir',
+ './VOCdevkit/VOC2012/SegmentationClassRaw',
+                                 'Folder to save modified ground truth annotations.')
+
+
+def _remove_colormap(filename):
+ """Removes the color map from the annotation.
+
+ Args:
+ filename: Ground truth annotation filename.
+
+ Returns:
+ Annotation without color map.
+ """
+ return np.array(Image.open(filename))
+
+
+def _save_annotation(annotation, filename):
+ """Saves the annotation as png file.
+
+ Args:
+ annotation: Segmentation annotation.
+ filename: Output filename.
+ """
+ pil_image = Image.fromarray(annotation.astype(dtype=np.uint8))
+ with tf.io.gfile.GFile(filename, mode='w') as f:
+ pil_image.save(f, 'PNG')
+
+
+def main(unused_argv):
+ # Create the output directory if not exists.
+ if not tf.io.gfile.isdir(FLAGS.output_dir):
+ tf.io.gfile.makedirs(FLAGS.output_dir)
+
+ annotations = glob.glob(os.path.join(FLAGS.original_gt_folder,
+ '*.' + FLAGS.segmentation_format))
+ for annotation in annotations:
+ raw_annotation = _remove_colormap(annotation)
+ filename = os.path.basename(annotation)[:-4]
+ _save_annotation(raw_annotation,
+ os.path.join(
+ FLAGS.output_dir,
+ filename + '.' + FLAGS.segmentation_format))
+
+
+if __name__ == '__main__':
+ tf.compat.v1.app.run()
diff --git a/models/research/deeplab/deeplab_demo.ipynb b/models/research/deeplab/deeplab_demo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..81ccfde1b6484625ad1e0662d9a4cf12941d262c
--- /dev/null
+++ b/models/research/deeplab/deeplab_demo.ipynb
@@ -0,0 +1,369 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "KFPcBuVFw61h"
+ },
+ "source": [
+ "# Overview\n",
+ "\n",
+ "This colab demonstrates the steps to use the DeepLab model to perform semantic segmentation on a sample input image. Expected outputs are semantic labels overlayed on the sample image.\n",
+ "\n",
+ "### About DeepLab\n",
+ "The models used in this colab perform semantic segmentation. Semantic segmentation models focus on assigning semantic labels, such as sky, person, or car, to multiple objects and stuff in a single image."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "t3ozFsEEP-u_"
+ },
+ "source": [
+ "# Instructions\n",
+ "\u003ch3\u003e\u003ca href=\"https://cloud.google.com/tpu/\"\u003e\u003cimg valign=\"middle\" src=\"https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png\" width=\"50\"\u003e\u003c/a\u003e \u0026nbsp;\u0026nbsp;Use a free TPU device\u003c/h3\u003e\n",
+ "\n",
+ " 1. On the main menu, click Runtime and select **Change runtime type**. Set \"TPU\" as the hardware accelerator.\n",
+ " 1. Click Runtime again and select **Runtime \u003e Run All**. You can also run the cells manually with Shift-ENTER."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "7cRiapZ1P3wy"
+ },
+ "source": [
+ "## Import Libraries"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "cellView": "code",
+ "colab": {},
+ "colab_type": "code",
+ "id": "kAbdmRmvq0Je"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from io import BytesIO\n",
+ "import tarfile\n",
+ "import tempfile\n",
+ "from six.moves import urllib\n",
+ "\n",
+ "from matplotlib import gridspec\n",
+ "from matplotlib import pyplot as plt\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "\n",
+ "%tensorflow_version 1.x\n",
+ "import tensorflow as tf"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "p47cYGGOQE1W"
+ },
+ "source": [
+ "## Import helper methods\n",
+ "These methods help us perform the following tasks:\n",
+ "* Load the latest version of the pretrained DeepLab model\n",
+ "* Load the colormap from the PASCAL VOC dataset\n",
+ "* Adds colors to various labels, such as \"pink\" for people, \"green\" for bicycle and more\n",
+ "* Visualize an image, and add an overlay of colors on various regions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "cellView": "code",
+ "colab": {},
+ "colab_type": "code",
+ "id": "vN0kU6NJ1Ye5"
+ },
+ "outputs": [],
+ "source": [
+ "class DeepLabModel(object):\n",
+ " \"\"\"Class to load deeplab model and run inference.\"\"\"\n",
+ "\n",
+ " INPUT_TENSOR_NAME = 'ImageTensor:0'\n",
+ " OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'\n",
+ " INPUT_SIZE = 513\n",
+ " FROZEN_GRAPH_NAME = 'frozen_inference_graph'\n",
+ "\n",
+ " def __init__(self, tarball_path):\n",
+ " \"\"\"Creates and loads pretrained deeplab model.\"\"\"\n",
+ " self.graph = tf.Graph()\n",
+ "\n",
+ " graph_def = None\n",
+ " # Extract frozen graph from tar archive.\n",
+ " tar_file = tarfile.open(tarball_path)\n",
+ " for tar_info in tar_file.getmembers():\n",
+ " if self.FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):\n",
+ " file_handle = tar_file.extractfile(tar_info)\n",
+ " graph_def = tf.GraphDef.FromString(file_handle.read())\n",
+ " break\n",
+ "\n",
+ " tar_file.close()\n",
+ "\n",
+ " if graph_def is None:\n",
+ " raise RuntimeError('Cannot find inference graph in tar archive.')\n",
+ "\n",
+ " with self.graph.as_default():\n",
+ " tf.import_graph_def(graph_def, name='')\n",
+ "\n",
+ " self.sess = tf.Session(graph=self.graph)\n",
+ "\n",
+ " def run(self, image):\n",
+ " \"\"\"Runs inference on a single image.\n",
+ "\n",
+ " Args:\n",
+ " image: A PIL.Image object, raw input image.\n",
+ "\n",
+ " Returns:\n",
+ " resized_image: RGB image resized from original input image.\n",
+ " seg_map: Segmentation map of `resized_image`.\n",
+ " \"\"\"\n",
+ " width, height = image.size\n",
+ " resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)\n",
+ " target_size = (int(resize_ratio * width), int(resize_ratio * height))\n",
+ " resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS)\n",
+ " batch_seg_map = self.sess.run(\n",
+ " self.OUTPUT_TENSOR_NAME,\n",
+ " feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})\n",
+ " seg_map = batch_seg_map[0]\n",
+ " return resized_image, seg_map\n",
+ "\n",
+ "\n",
+ "def create_pascal_label_colormap():\n",
+ " \"\"\"Creates a label colormap used in PASCAL VOC segmentation benchmark.\n",
+ "\n",
+ " Returns:\n",
+ " A Colormap for visualizing segmentation results.\n",
+ " \"\"\"\n",
+ " colormap = np.zeros((256, 3), dtype=int)\n",
+ " ind = np.arange(256, dtype=int)\n",
+ "\n",
+ " for shift in reversed(range(8)):\n",
+ " for channel in range(3):\n",
+ " colormap[:, channel] |= ((ind \u003e\u003e channel) \u0026 1) \u003c\u003c shift\n",
+ " ind \u003e\u003e= 3\n",
+ "\n",
+ " return colormap\n",
+ "\n",
+ "\n",
+ "def label_to_color_image(label):\n",
+ " \"\"\"Adds color defined by the dataset colormap to the label.\n",
+ "\n",
+ " Args:\n",
+ " label: A 2D array with integer type, storing the segmentation label.\n",
+ "\n",
+ " Returns:\n",
+ " result: A 2D array with floating type. The element of the array\n",
+ " is the color indexed by the corresponding element in the input label\n",
+ " to the PASCAL color map.\n",
+ "\n",
+ " Raises:\n",
+ " ValueError: If label is not of rank 2 or its value is larger than color\n",
+ " map maximum entry.\n",
+ " \"\"\"\n",
+ " if label.ndim != 2:\n",
+ " raise ValueError('Expect 2-D input label')\n",
+ "\n",
+ " colormap = create_pascal_label_colormap()\n",
+ "\n",
+ " if np.max(label) \u003e= len(colormap):\n",
+ " raise ValueError('label value too large.')\n",
+ "\n",
+ " return colormap[label]\n",
+ "\n",
+ "\n",
+ "def vis_segmentation(image, seg_map):\n",
+ " \"\"\"Visualizes input image, segmentation map and overlay view.\"\"\"\n",
+ " plt.figure(figsize=(15, 5))\n",
+ " grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])\n",
+ "\n",
+ " plt.subplot(grid_spec[0])\n",
+ " plt.imshow(image)\n",
+ " plt.axis('off')\n",
+ " plt.title('input image')\n",
+ "\n",
+ " plt.subplot(grid_spec[1])\n",
+ " seg_image = label_to_color_image(seg_map).astype(np.uint8)\n",
+ " plt.imshow(seg_image)\n",
+ " plt.axis('off')\n",
+ " plt.title('segmentation map')\n",
+ "\n",
+ " plt.subplot(grid_spec[2])\n",
+ " plt.imshow(image)\n",
+ " plt.imshow(seg_image, alpha=0.7)\n",
+ " plt.axis('off')\n",
+ " plt.title('segmentation overlay')\n",
+ "\n",
+ " unique_labels = np.unique(seg_map)\n",
+ " ax = plt.subplot(grid_spec[3])\n",
+ " plt.imshow(\n",
+ " FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')\n",
+ " ax.yaxis.tick_right()\n",
+ " plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])\n",
+ " plt.xticks([], [])\n",
+ " ax.tick_params(width=0.0)\n",
+ " plt.grid('off')\n",
+ " plt.show()\n",
+ "\n",
+ "\n",
+ "LABEL_NAMES = np.asarray([\n",
+ " 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',\n",
+ " 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',\n",
+ " 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tv'\n",
+ "])\n",
+ "\n",
+ "FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)\n",
+ "FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "nGcZzNkASG9A"
+ },
+ "source": [
+ "## Select a pretrained model\n",
+ "We have trained the DeepLab model using various backbone networks. Select one from the MODEL_NAME list."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "c4oXKmnjw6i_"
+ },
+ "outputs": [],
+ "source": [
+ "MODEL_NAME = 'mobilenetv2_coco_voctrainaug' # @param ['mobilenetv2_coco_voctrainaug', 'mobilenetv2_coco_voctrainval', 'xception_coco_voctrainaug', 'xception_coco_voctrainval']\n",
+ "\n",
+ "_DOWNLOAD_URL_PREFIX = 'http://download.tensorflow.org/models/'\n",
+ "_MODEL_URLS = {\n",
+ " 'mobilenetv2_coco_voctrainaug':\n",
+ " 'deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz',\n",
+ " 'mobilenetv2_coco_voctrainval':\n",
+ " 'deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz',\n",
+ " 'xception_coco_voctrainaug':\n",
+ " 'deeplabv3_pascal_train_aug_2018_01_04.tar.gz',\n",
+ " 'xception_coco_voctrainval':\n",
+ " 'deeplabv3_pascal_trainval_2018_01_04.tar.gz',\n",
+ "}\n",
+ "_TARBALL_NAME = 'deeplab_model.tar.gz'\n",
+ "\n",
+ "model_dir = tempfile.mkdtemp()\n",
+ "tf.gfile.MakeDirs(model_dir)\n",
+ "\n",
+ "download_path = os.path.join(model_dir, _TARBALL_NAME)\n",
+ "print('downloading model, this might take a while...')\n",
+ "urllib.request.urlretrieve(_DOWNLOAD_URL_PREFIX + _MODEL_URLS[MODEL_NAME],\n",
+ " download_path)\n",
+ "print('download completed! loading DeepLab model...')\n",
+ "\n",
+ "MODEL = DeepLabModel(download_path)\n",
+ "print('model loaded successfully!')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "SZst78N-4OKO"
+ },
+ "source": [
+ "## Run on sample images\n",
+ "\n",
+ "Select one of sample images (leave `IMAGE_URL` empty) or feed any internet image\n",
+ "url for inference.\n",
+ "\n",
+ "Note that this colab uses single scale inference for fast computation,\n",
+ "so the results may slightly differ from the visualizations in the\n",
+ "[README](https://github.com/tensorflow/models/blob/master/research/deeplab/README.md) file,\n",
+ "which uses multi-scale and left-right flipped inputs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "cellView": "form",
+ "colab": {},
+ "colab_type": "code",
+ "id": "edGukUHXyymr"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "SAMPLE_IMAGE = 'image1' # @param ['image1', 'image2', 'image3']\n",
+ "IMAGE_URL = '' #@param {type:\"string\"}\n",
+ "\n",
+ "_SAMPLE_URL = ('https://github.com/tensorflow/models/blob/master/research/'\n",
+ " 'deeplab/g3doc/img/%s.jpg?raw=true')\n",
+ "\n",
+ "\n",
+ "def run_visualization(url):\n",
+ " \"\"\"Inferences DeepLab model and visualizes result.\"\"\"\n",
+ " try:\n",
+ " f = urllib.request.urlopen(url)\n",
+ " jpeg_str = f.read()\n",
+ " original_im = Image.open(BytesIO(jpeg_str))\n",
+ " except IOError:\n",
+ " print('Cannot retrieve image. Please check url: ' + url)\n",
+ " return\n",
+ "\n",
+ " print('running deeplab on image %s...' % url)\n",
+ " resized_im, seg_map = MODEL.run(original_im)\n",
+ "\n",
+ " vis_segmentation(resized_im, seg_map)\n",
+ "\n",
+ "\n",
+ "image_url = IMAGE_URL or _SAMPLE_URL % SAMPLE_IMAGE\n",
+ "run_visualization(image_url)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "aUbVoHScTJYe"
+ },
+ "source": [
+ "## What's next\n",
+ "\n",
+ "* Learn about [Cloud TPUs](https://cloud.google.com/tpu/docs) that Google designed and optimized specifically to speed up and scale up ML workloads for training and inference and to enable ML engineers and researchers to iterate more quickly.\n",
+ "* Explore the range of [Cloud TPU tutorials and Colabs](https://cloud.google.com/tpu/docs/tutorials) to find other examples that can be used when implementing your ML project.\n",
+ "* For more information on running the DeepLab model on Cloud TPUs, see the [DeepLab tutorial](https://cloud.google.com/tpu/docs/tutorials/deeplab).\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "name": "DeepLab Demo.ipynb",
+ "provenance": [],
+ "toc_visible": true,
+ "version": "0.3.2"
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/models/research/deeplab/deprecated/__init__.py b/models/research/deeplab/deprecated/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/deeplab/deprecated/segmentation_dataset.py b/models/research/deeplab/deprecated/segmentation_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a5980b1d940878cd1aead4a5d301cca7b4a642b
--- /dev/null
+++ b/models/research/deeplab/deprecated/segmentation_dataset.py
@@ -0,0 +1,200 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Provides data from semantic segmentation datasets.
+
+The SegmentationDataset class provides both images and annotations (semantic
+segmentation and/or instance segmentation) for TensorFlow. Currently, we
+support the following datasets:
+
+1. PASCAL VOC 2012 (http://host.robots.ox.ac.uk/pascal/VOC/voc2012/).
+
+PASCAL VOC 2012 semantic segmentation dataset annotates 20 foreground objects
+(e.g., bike, person, and so on) and leaves all the other semantic classes as
+one background class. The dataset contains 1464, 1449, and 1456 annotated
+images for the training, validation, and test sets, respectively.
+
+2. Cityscapes dataset (https://www.cityscapes-dataset.com)
+
+The Cityscapes dataset contains 19 semantic labels (such as road, person, car,
+and so on) for urban street scenes.
+
+3. ADE20K dataset (http://groups.csail.mit.edu/vision/datasets/ADE20K)
+
+The ADE20K dataset contains 150 semantic labels covering both urban street
+scenes and indoor scenes.
+
+References:
+ M. Everingham, S. M. A. Eslami, L. V. Gool, C. K. I. Williams, J. Winn,
+ and A. Zisserman, The pascal visual object classes challenge a retrospective.
+ IJCV, 2014.
+
+ M. Cordts, M. Omran, S. Ramos, T. Rehfeld, M. Enzweiler, R. Benenson,
+ U. Franke, S. Roth, and B. Schiele, "The cityscapes dataset for semantic urban
+ scene understanding," In Proc. of CVPR, 2016.
+
+ B. Zhou, H. Zhao, X. Puig, S. Fidler, A. Barriuso, A. Torralba, "Scene Parsing
+ through ADE20K dataset", In Proc. of CVPR, 2017.
+"""
+import collections
+import os.path
+import tensorflow as tf
+from tensorflow.contrib import slim as contrib_slim
+
+slim = contrib_slim
+
+dataset = slim.dataset
+
+tfexample_decoder = slim.tfexample_decoder
+
+
+_ITEMS_TO_DESCRIPTIONS = {
+ 'image': 'A color image of varying height and width.',
+    'labels_class': ('A semantic segmentation label whose size matches image. '
+                     'Its values range from 0 (background) to num_classes.'),
+}
+
+# Named tuple to describe the dataset properties.
+DatasetDescriptor = collections.namedtuple(
+ 'DatasetDescriptor',
+ ['splits_to_sizes', # Splits of the dataset into training, val, and test.
+ 'num_classes', # Number of semantic classes, including the background
+                      # class (if it exists). For example, there are 20
+ # foreground classes + 1 background class in the PASCAL
+ # VOC 2012 dataset. Thus, we set num_classes=21.
+ 'ignore_label', # Ignore label value.
+ ]
+)
+
+_CITYSCAPES_INFORMATION = DatasetDescriptor(
+ splits_to_sizes={
+ 'train': 2975,
+ 'val': 500,
+ },
+ num_classes=19,
+ ignore_label=255,
+)
+
+_PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor(
+ splits_to_sizes={
+ 'train': 1464,
+ 'train_aug': 10582,
+ 'trainval': 2913,
+ 'val': 1449,
+ },
+ num_classes=21,
+ ignore_label=255,
+)
+
+# These numbers (i.e., the 'train'/'val' split sizes) seem to have to be
+# hard-coded; you are required to figure them out for your own dataset.
+_ADE20K_INFORMATION = DatasetDescriptor(
+ splits_to_sizes={
+ 'train': 20210, # num of samples in images/training
+ 'val': 2000, # num of samples in images/validation
+ },
+ num_classes=151,
+ ignore_label=0,
+)
+
+
+_DATASETS_INFORMATION = {
+ 'cityscapes': _CITYSCAPES_INFORMATION,
+ 'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
+ 'ade20k': _ADE20K_INFORMATION,
+}
+
+# Default file pattern of TFRecord of TensorFlow Example.
+_FILE_PATTERN = '%s-*'
+
+
+def get_cityscapes_dataset_name():
+ return 'cityscapes'
+
+
+def get_dataset(dataset_name, split_name, dataset_dir):
+ """Gets an instance of slim Dataset.
+
+ Args:
+ dataset_name: Dataset name.
+ split_name: A train/val Split name.
+ dataset_dir: The directory of the dataset sources.
+
+ Returns:
+ An instance of slim Dataset.
+
+ Raises:
+ ValueError: if the dataset_name or split_name is not recognized.
+ """
+ if dataset_name not in _DATASETS_INFORMATION:
+ raise ValueError('The specified dataset is not supported yet.')
+
+ splits_to_sizes = _DATASETS_INFORMATION[dataset_name].splits_to_sizes
+
+ if split_name not in splits_to_sizes:
+ raise ValueError('data split name %s not recognized' % split_name)
+
+ # Prepare the variables for different datasets.
+ num_classes = _DATASETS_INFORMATION[dataset_name].num_classes
+ ignore_label = _DATASETS_INFORMATION[dataset_name].ignore_label
+
+ file_pattern = _FILE_PATTERN
+ file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
+
+ # Specify how the TF-Examples are decoded.
+ keys_to_features = {
+ 'image/encoded': tf.FixedLenFeature(
+ (), tf.string, default_value=''),
+ 'image/filename': tf.FixedLenFeature(
+ (), tf.string, default_value=''),
+ 'image/format': tf.FixedLenFeature(
+ (), tf.string, default_value='jpeg'),
+ 'image/height': tf.FixedLenFeature(
+ (), tf.int64, default_value=0),
+ 'image/width': tf.FixedLenFeature(
+ (), tf.int64, default_value=0),
+ 'image/segmentation/class/encoded': tf.FixedLenFeature(
+ (), tf.string, default_value=''),
+ 'image/segmentation/class/format': tf.FixedLenFeature(
+ (), tf.string, default_value='png'),
+ }
+ items_to_handlers = {
+ 'image': tfexample_decoder.Image(
+ image_key='image/encoded',
+ format_key='image/format',
+ channels=3),
+ 'image_name': tfexample_decoder.Tensor('image/filename'),
+ 'height': tfexample_decoder.Tensor('image/height'),
+ 'width': tfexample_decoder.Tensor('image/width'),
+ 'labels_class': tfexample_decoder.Image(
+ image_key='image/segmentation/class/encoded',
+ format_key='image/segmentation/class/format',
+ channels=1),
+ }
+
+ decoder = tfexample_decoder.TFExampleDecoder(
+ keys_to_features, items_to_handlers)
+
+ return dataset.Dataset(
+ data_sources=file_pattern,
+ reader=tf.TFRecordReader,
+ decoder=decoder,
+ num_samples=splits_to_sizes[split_name],
+ items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
+ ignore_label=ignore_label,
+ num_classes=num_classes,
+ name=dataset_name,
+ multi_label=True)
diff --git a/models/research/deeplab/eval.py b/models/research/deeplab/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f5fb8ba9c7493e45e567e9cd5ba9fe567dd9690
--- /dev/null
+++ b/models/research/deeplab/eval.py
@@ -0,0 +1,227 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Evaluation script for the DeepLab model.
+
+See model.py for more details and usage.
+"""
+
+import numpy as np
+import six
+import tensorflow as tf
+from tensorflow.contrib import metrics as contrib_metrics
+from tensorflow.contrib import quantize as contrib_quantize
+from tensorflow.contrib import tfprof as contrib_tfprof
+from tensorflow.contrib import training as contrib_training
+from deeplab import common
+from deeplab import model
+from deeplab.datasets import data_generator
+
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')
+
+# Settings for log directories.
+
+flags.DEFINE_string('eval_logdir', None, 'Where to write the event logs.')
+
+flags.DEFINE_string('checkpoint_dir', None, 'Directory of model checkpoints.')
+
+# Settings for evaluating the model.
+
+flags.DEFINE_integer('eval_batch_size', 1,
+ 'The number of images in each batch during evaluation.')
+
+flags.DEFINE_list('eval_crop_size', '513,513',
+ 'Image crop size [height, width] for evaluation.')
+
+flags.DEFINE_integer('eval_interval_secs', 60 * 5,
+ 'How often (in seconds) to run evaluation.')
+
+# For `xception_65`, use atrous_rates = [12, 24, 36] if output_stride = 8, or
+# rates = [6, 12, 18] if output_stride = 16. For `mobilenet_v2`, use None. Note
+# one could use different atrous_rates/output_stride during training/evaluation.
+flags.DEFINE_multi_integer('atrous_rates', None,
+ 'Atrous rates for atrous spatial pyramid pooling.')
+
+flags.DEFINE_integer('output_stride', 16,
+ 'The ratio of input to output spatial resolution.')
+
+# Change to [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] for multi-scale test.
+flags.DEFINE_multi_float('eval_scales', [1.0],
+ 'The scales to resize images for evaluation.')
+
+# Change to True for adding flipped images during test.
+flags.DEFINE_bool('add_flipped_images', False,
+ 'Add flipped images for evaluation or not.')
+
+flags.DEFINE_integer(
+ 'quantize_delay_step', -1,
+ 'Steps to start quantized training. If < 0, will not quantize model.')
+
+# Dataset settings.
+
+flags.DEFINE_string('dataset', 'pascal_voc_seg',
+ 'Name of the segmentation dataset.')
+
+flags.DEFINE_string('eval_split', 'val',
+                    'Which split of the dataset is used for evaluation.')
+
+flags.DEFINE_string('dataset_dir', None, 'Where the dataset resides.')
+
+flags.DEFINE_integer('max_number_of_evaluations', 0,
+ 'Maximum number of eval iterations. Will loop '
+ 'indefinitely upon nonpositive values.')
+
+
+def main(unused_argv):
+ tf.logging.set_verbosity(tf.logging.INFO)
+
+ dataset = data_generator.Dataset(
+ dataset_name=FLAGS.dataset,
+ split_name=FLAGS.eval_split,
+ dataset_dir=FLAGS.dataset_dir,
+ batch_size=FLAGS.eval_batch_size,
+ crop_size=[int(sz) for sz in FLAGS.eval_crop_size],
+ min_resize_value=FLAGS.min_resize_value,
+ max_resize_value=FLAGS.max_resize_value,
+ resize_factor=FLAGS.resize_factor,
+ model_variant=FLAGS.model_variant,
+ num_readers=2,
+ is_training=False,
+ should_shuffle=False,
+ should_repeat=False)
+
+ tf.gfile.MakeDirs(FLAGS.eval_logdir)
+ tf.logging.info('Evaluating on %s set', FLAGS.eval_split)
+
+ with tf.Graph().as_default():
+ samples = dataset.get_one_shot_iterator().get_next()
+
+ model_options = common.ModelOptions(
+ outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_of_classes},
+ crop_size=[int(sz) for sz in FLAGS.eval_crop_size],
+ atrous_rates=FLAGS.atrous_rates,
+ output_stride=FLAGS.output_stride)
+
+ # Set shape in order for tf.contrib.tfprof.model_analyzer to work properly.
+ samples[common.IMAGE].set_shape(
+ [FLAGS.eval_batch_size,
+ int(FLAGS.eval_crop_size[0]),
+ int(FLAGS.eval_crop_size[1]),
+ 3])
+ if tuple(FLAGS.eval_scales) == (1.0,):
+ tf.logging.info('Performing single-scale test.')
+ predictions = model.predict_labels(samples[common.IMAGE], model_options,
+ image_pyramid=FLAGS.image_pyramid)
+ else:
+ tf.logging.info('Performing multi-scale test.')
+ if FLAGS.quantize_delay_step >= 0:
+ raise ValueError(
+ 'Quantize mode is not supported with multi-scale test.')
+
+ predictions = model.predict_labels_multi_scale(
+ samples[common.IMAGE],
+ model_options=model_options,
+ eval_scales=FLAGS.eval_scales,
+ add_flipped_images=FLAGS.add_flipped_images)
+ predictions = predictions[common.OUTPUT_TYPE]
+ predictions = tf.reshape(predictions, shape=[-1])
+ labels = tf.reshape(samples[common.LABEL], shape=[-1])
+ weights = tf.to_float(tf.not_equal(labels, dataset.ignore_label))
+
+ # Set ignore_label regions to label 0, because metrics.mean_iou requires
+ # range of labels = [0, dataset.num_classes). Note the ignore_label regions
+ # are not evaluated since the corresponding regions contain weights = 0.
+ labels = tf.where(
+ tf.equal(labels, dataset.ignore_label), tf.zeros_like(labels), labels)
+
+ predictions_tag = 'miou'
+ for eval_scale in FLAGS.eval_scales:
+ predictions_tag += '_' + str(eval_scale)
+ if FLAGS.add_flipped_images:
+ predictions_tag += '_flipped'
+
+ # Define the evaluation metric.
+ metric_map = {}
+ num_classes = dataset.num_of_classes
+ metric_map['eval/%s_overall' % predictions_tag] = tf.metrics.mean_iou(
+ labels=labels, predictions=predictions, num_classes=num_classes,
+ weights=weights)
+ # IoU for each class.
+ one_hot_predictions = tf.one_hot(predictions, num_classes)
+ one_hot_predictions = tf.reshape(one_hot_predictions, [-1, num_classes])
+ one_hot_labels = tf.one_hot(labels, num_classes)
+ one_hot_labels = tf.reshape(one_hot_labels, [-1, num_classes])
+ for c in range(num_classes):
+ predictions_tag_c = '%s_class_%d' % (predictions_tag, c)
+ tp, tp_op = tf.metrics.true_positives(
+ labels=one_hot_labels[:, c], predictions=one_hot_predictions[:, c],
+ weights=weights)
+ fp, fp_op = tf.metrics.false_positives(
+ labels=one_hot_labels[:, c], predictions=one_hot_predictions[:, c],
+ weights=weights)
+ fn, fn_op = tf.metrics.false_negatives(
+ labels=one_hot_labels[:, c], predictions=one_hot_predictions[:, c],
+ weights=weights)
+ tp_fp_fn_op = tf.group(tp_op, fp_op, fn_op)
+ iou = tf.where(tf.greater(tp + fn, 0.0),
+ tp / (tp + fn + fp),
+ tf.constant(np.NaN))
+ metric_map['eval/%s' % predictions_tag_c] = (iou, tp_fp_fn_op)
+
+ (metrics_to_values,
+ metrics_to_updates) = contrib_metrics.aggregate_metric_map(metric_map)
+
+ summary_ops = []
+ for metric_name, metric_value in six.iteritems(metrics_to_values):
+ op = tf.summary.scalar(metric_name, metric_value)
+ op = tf.Print(op, [metric_value], metric_name)
+ summary_ops.append(op)
+
+ summary_op = tf.summary.merge(summary_ops)
+ summary_hook = contrib_training.SummaryAtEndHook(
+ log_dir=FLAGS.eval_logdir, summary_op=summary_op)
+ hooks = [summary_hook]
+
+ num_eval_iters = None
+ if FLAGS.max_number_of_evaluations > 0:
+ num_eval_iters = FLAGS.max_number_of_evaluations
+
+ if FLAGS.quantize_delay_step >= 0:
+ contrib_quantize.create_eval_graph()
+
+ contrib_tfprof.model_analyzer.print_model_analysis(
+ tf.get_default_graph(),
+ tfprof_options=contrib_tfprof.model_analyzer
+ .TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
+ contrib_tfprof.model_analyzer.print_model_analysis(
+ tf.get_default_graph(),
+ tfprof_options=contrib_tfprof.model_analyzer.FLOAT_OPS_OPTIONS)
+ contrib_training.evaluate_repeatedly(
+ checkpoint_dir=FLAGS.checkpoint_dir,
+ master=FLAGS.master,
+ eval_ops=list(metrics_to_updates.values()),
+ max_number_of_evaluations=num_eval_iters,
+ hooks=hooks,
+ eval_interval_secs=FLAGS.eval_interval_secs)
+
+
+if __name__ == '__main__':
+ flags.mark_flag_as_required('checkpoint_dir')
+ flags.mark_flag_as_required('eval_logdir')
+ flags.mark_flag_as_required('dataset_dir')
+ tf.app.run()
diff --git a/models/research/deeplab/evaluation/README.md b/models/research/deeplab/evaluation/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..69255384e9a293e7acada74b40bd4288ba121edb
--- /dev/null
+++ b/models/research/deeplab/evaluation/README.md
@@ -0,0 +1,311 @@
+# Evaluation Metrics for Whole Image Parsing
+
+Whole Image Parsing [1], also known as Panoptic Segmentation [2], generalizes
+the tasks of semantic segmentation for "stuff" classes and instance
+segmentation for "thing" classes, assigning both semantic and instance labels
+to every pixel in an image.
+
+Previous works evaluate the parsing result with separate metrics (e.g., one for
+semantic segmentation and one for object detection). Recently, Kirillov et al.
+proposed the unified instance-based Panoptic Quality (PQ) metric [2], which has
+since been adopted by several benchmarks [3, 4].
+
+However, we notice that the instance-based PQ metric often places
+disproportionate emphasis on small instance parsing, as well as on "thing" over
+"stuff" classes. To remedy these effects, we propose an alternative
+region-based Parsing Covering (PC) metric [5], which adapts the Covering
+metric [6], previously used for class-agnostic segmentation quality
+evaluation, to the task of image parsing.
+
+Here, we provide implementations of both PQ and PC for evaluating the parsing
+results. We briefly explain both metrics below for reference.
+
+## Panoptic Quality (PQ)
+
+Given a groundtruth segmentation S and a predicted segmentation S', PQ is
+defined as follows:
+
+
+```
+PQ = Σ_{(R, R') ∈ TP} IoU(R, R') / ( |TP| + 0.5 |FP| + 0.5 |FN| )
+```
+
+where R and R' are groundtruth regions and predicted regions respectively,
+and |TP|, |FP|, and |FN| are the number of true positives, false positives,
+and false negatives. The matching is determined by a threshold of 0.5
+Intersection-Over-Union (IOU).
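+
+For intuition, below is a minimal NumPy sketch of the PQ formula above (not the
+official implementation in this directory; it assumes the matching of regions
+at the 0.5 IoU threshold has already been done, so only the matched IoUs and
+the FP/FN counts are needed):
+
+```python
+import numpy as np
+
+
+def panoptic_quality(matched_ious, num_fp, num_fn):
+  """PQ from the IoUs of matched (TP) pairs and the FP/FN counts."""
+  num_tp = len(matched_ious)
+  if num_tp + num_fp + num_fn == 0:
+    return 0.0
+  return float(np.sum(matched_ious)) / (num_tp + 0.5 * num_fp + 0.5 * num_fn)
+
+
+# Example: two matched regions with IoU 0.9 and 0.7, one FP and one FN.
+print(panoptic_quality([0.9, 0.7], num_fp=1, num_fn=1))  # 1.6 / 3.0 = 0.533...
+```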
+
+PQ treats all regions of the same "stuff" class as one instance, and the
+size of instances is not considered. For example, instances with 10 × 10
+pixels contribute equally to the metric as instances with 1000 × 1000 pixels.
+Therefore, PQ is sensitive to false positives with small regions and some
+heuristics could improve the performance, such as removing those small
+regions (as also pointed out in the open-sourced evaluation code from [2]).
+Thus, we argue that PQ is suitable in applications where one cares equally for
+the parsing quality of instances irrespective of their sizes.
+
+## Parsing Covering (PC)
+
+We notice that there are applications where one pays more attention to large
+objects, e.g., autonomous driving (where nearby objects are more important
+than far away ones). Motivated by this, we propose to also evaluate the
+quality of image parsing results by extending the existing Covering metric [6],
+which accounts for instance sizes. Specifically, our proposed metric, Parsing
+Covering (PC), is defined as follows:
+
+```
+Cov_i = (1 / N_i) · Σ_{R ∈ S_i} |R| · max_{R' ∈ S_i'} IoU(R, R')
+
+PC = (1 / C) · Σ_i Cov_i
+```
+
+where S_i and S_i' are the groundtruth segmentation and predicted segmentation
+for the i-th semantic class respectively, and N_i is the total number of pixels
+of groundtruth regions from S_i. The Covering for class i, Cov_i, is computed
+in the same way as the original Covering metric, except that only groundtruth
+regions from S_i and predicted regions from S_i' are considered. PC is then
+obtained by computing the average of Cov_i over the C semantic classes.
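+
+As a rough illustration of the formula above, here is a toy NumPy sketch of the
+per-class Covering term from boolean region masks (a hypothetical helper, not
+the implementation in `parsing_covering.py`, and without the optional
+normalization by image size):
+
+```python
+import numpy as np
+
+
+def class_covering(gt_masks, pred_masks):
+  """Covering for one semantic class, given lists of boolean 2D masks."""
+  total_gt_area = sum(int(mask.sum()) for mask in gt_masks)
+  if total_gt_area == 0:
+    return 0.0
+  weighted_iou = 0.0
+  for gt in gt_masks:
+    best_iou = 0.0
+    for pred in pred_masks:
+      union = np.logical_or(gt, pred).sum()
+      if union > 0:
+        best_iou = max(best_iou, np.logical_and(gt, pred).sum() / union)
+    weighted_iou += gt.sum() * best_iou
+  return weighted_iou / total_gt_area
+```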
+
+A notable difference between PQ and the proposed PC is that there is no
+matching involved in PC and hence no matching threshold. As an attempt to
+treat equally "thing" and "stuff", the segmentation of "stuff" classes still
+receives partial PC score if the segmentation is only partially correct. For
+example, if one out of three equally-sized trees is perfectly segmented, the
+model will get the same partial score by using PC regardless of considering
+"tree" as "stuff" or "thing".
+
+## Tutorial
+
+To evaluate the parsing results with PQ and PC, we provide two options:
+
+1. Python off-line evaluation with results saved in the [COCO format](http://cocodataset.org/#format-results).
+2. TensorFlow on-line evaluation.
+
+Below, we explain each option in detail.
+
+#### 1. Python off-line evaluation with results saved in COCO format
+
+[COCO result format](http://cocodataset.org/#format-results) has been
+adopted by several benchmarks [3, 4]. Therefore, we provide a convenient
+function, `eval_coco_format`, to evaluate the results saved in COCO format
+in terms of PC and re-implemented PQ.
+
+Before using the provided function, the users need to download the official COCO
+panoptic segmentation task API. Please see [installation](../g3doc/installation.md#add-libraries-to-pythonpath)
+for reference.
+
+Once the official COCO panoptic segmentation task API is downloaded, the
+users should be able to run the `eval_coco_format.py` to evaluate the parsing
+results in terms of both PC and reimplemented PQ.
+
+To be concrete, let's take a look at the function, `eval_coco_format` in
+`eval_coco_format.py`:
+
+```python
+eval_coco_format(gt_json_file,
+ pred_json_file,
+ gt_folder=None,
+ pred_folder=None,
+ metric='pq',
+ num_categories=201,
+ ignored_label=0,
+ max_instances_per_category=256,
+ intersection_offset=None,
+ normalize_by_image_size=True,
+ num_workers=0,
+                 print_digits=3)
+```
+where
+
+1. `gt_json_file`: Path to a JSON file giving ground-truth annotations in COCO
+format.
+2. `pred_json_file`: Path to a JSON file for the predictions to evaluate.
+3. `gt_folder`: Folder containing panoptic-format ID images to match
+ground-truth annotations to image regions.
+4. `pred_folder`: Path to a folder containing ID images for predictions.
+5. `metric`: Name of a metric to compute. Set to `pc` or `pq` for evaluation in PC
+or PQ, respectively.
+6. `num_categories`: The number of segmentation categories (or "classes") in the
+dataset.
+7. `ignored_label`: A category id that is ignored in evaluation, e.g. the "void"
+label in COCO panoptic segmentation dataset.
+8. `max_instances_per_category`: The maximum number of instances for each
+category to ensure unique instance labels.
+9. `intersection_offset`: The maximum number of unique labels.
+10. `normalize_by_image_size`: Whether to normalize groundtruth instance region
+areas by image size when using PC.
+11. `num_workers`: If set to a positive number, will spawn child processes to
+compute parts of the metric in parallel by splitting the images between the
+workers. If set to -1, will use the value of multiprocessing.cpu_count().
+12. `print_digits`: Number of significant digits to print in summary of computed
+metrics.
+
+The input arguments have default values set for the COCO panoptic segmentation
+dataset. Thus, users only need to provide the `gt_json_file` and the
+`pred_json_file` (following the COCO format) to run the evaluation on COCO with
+PQ. If users want to evaluate the results on other datasets, they may need
+to change the default values.
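+
+For example, a hypothetical invocation (the file paths below are placeholders)
+could look like the following:
+
+```python
+from deeplab.evaluation import eval_coco_format
+
+# Requires the official COCO panoptic API (panopticapi) on the PYTHONPATH, as
+# described above. Both paths are placeholders; point them at your own files.
+result = eval_coco_format.eval_coco_format(
+    gt_json_file='/path/to/panoptic_groundtruth.json',
+    pred_json_file='/path/to/panoptic_predictions.json',
+    metric='pq')
+```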
+
+As an example, the interested users could take a look at the provided unit
+test, `test_compare_pq_with_reference_eval`, in `eval_coco_format_test.py`.
+
+#### 2. TensorFlow on-line evaluation
+
+Users may also want to run the TensorFlow on-line evaluation, similar to the
+[tf.contrib.metrics.streaming_mean_iou](https://www.tensorflow.org/api_docs/python/tf/contrib/metrics/streaming_mean_iou).
+
+Below, we provide a code snippet that shows how to use the provided
+`streaming_panoptic_quality` and `streaming_parsing_covering`.
+
+```python
+metric_map = {}
+metric_map['panoptic_quality'] = streaming_metrics.streaming_panoptic_quality(
+ category_label,
+ instance_label,
+ category_prediction,
+ instance_prediction,
+ num_classes=201,
+ max_instances_per_category=256,
+ ignored_label=0,
+ offset=256*256)
+metric_map['parsing_covering'] = streaming_metrics.streaming_parsing_covering(
+ category_label,
+ instance_label,
+ category_prediction,
+ instance_prediction,
+ num_classes=201,
+ max_instances_per_category=256,
+ ignored_label=0,
+ offset=256*256,
+ normalize_by_image_size=True)
+metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map(
+ metric_map)
+```
+where `metric_map` is a dictionary storing the streamed results of PQ and PC.
+
+The `category_label` and the `instance_label` are the semantic segmentation and
+instance segmentation groundtruth, respectively. That is, in the panoptic
+segmentation format:
+panoptic_label = category_label * max_instances_per_category + instance_label.
+Similarly, the `category_prediction` and the `instance_prediction` are the
+predicted semantic segmentation and instance segmentation, respectively.
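+
+As a small sketch of this encoding (assuming `max_instances_per_category` is
+256, as in the snippet above):
+
+```python
+MAX_INSTANCES_PER_CATEGORY = 256
+
+
+def encode_panoptic(category_label, instance_label):
+  return category_label * MAX_INSTANCES_PER_CATEGORY + instance_label
+
+
+def decode_panoptic(panoptic_label):
+  return (panoptic_label // MAX_INSTANCES_PER_CATEGORY,
+          panoptic_label % MAX_INSTANCES_PER_CATEGORY)
+
+
+# A pixel with category id 17 and instance id 3:
+assert encode_panoptic(17, 3) == 17 * 256 + 3
+assert decode_panoptic(17 * 256 + 3) == (17, 3)
+```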
+
+Below, we provide a code snippet about how to summarize the results in the
+context of tf.summary.
+
+```python
+summary_ops = []
+for metric_name, metric_value in metrics_to_values.iteritems():
+ if metric_name == 'panoptic_quality':
+ [pq, sq, rq, total_tp, total_fn, total_fp] = tf.unstack(
+ metric_value, 6, axis=0)
+ panoptic_metrics = {
+ # Panoptic quality.
+ 'pq': pq,
+ # Segmentation quality.
+ 'sq': sq,
+ # Recognition quality.
+ 'rq': rq,
+ # Total true positives.
+ 'total_tp': total_tp,
+ # Total false negatives.
+ 'total_fn': total_fn,
+ # Total false positives.
+ 'total_fp': total_fp,
+ }
+ # Find the valid classes that will be used for evaluation. We will
+ # ignore the `ignore_label` class and other classes which have (tp + fn
+ # + fp) equal to 0.
+ valid_classes = tf.logical_and(
+ tf.not_equal(tf.range(0, num_classes), void_label),
+ tf.not_equal(total_tp + total_fn + total_fp, 0))
+ for target_metric, target_value in panoptic_metrics.iteritems():
+ output_metric_name = '{}_{}'.format(metric_name, target_metric)
+ op = tf.summary.scalar(
+ output_metric_name,
+ tf.reduce_mean(tf.boolean_mask(target_value, valid_classes)))
+ op = tf.Print(op, [target_value], output_metric_name + '_classwise: ',
+ summarize=num_classes)
+ op = tf.Print(
+ op,
+ [tf.reduce_mean(tf.boolean_mask(target_value, valid_classes))],
+ output_metric_name + '_mean: ',
+ summarize=1)
+ summary_ops.append(op)
+ elif metric_name == 'parsing_covering':
+ [per_class_covering,
+ total_per_class_weighted_ious,
+ total_per_class_gt_areas] = tf.unstack(metric_value, 3, axis=0)
+ # Find the valid classes that will be used for evaluation. We will
+ # ignore the `void_label` class and other classes which have
+ # total_per_class_weighted_ious + total_per_class_gt_areas equal to 0.
+ valid_classes = tf.logical_and(
+ tf.not_equal(tf.range(0, num_classes), void_label),
+ tf.not_equal(
+ total_per_class_weighted_ious + total_per_class_gt_areas, 0))
+ op = tf.summary.scalar(
+ metric_name,
+ tf.reduce_mean(tf.boolean_mask(per_class_covering, valid_classes)))
+ op = tf.Print(op, [per_class_covering], metric_name + '_classwise: ',
+ summarize=num_classes)
+ op = tf.Print(
+ op,
+ [tf.reduce_mean(
+ tf.boolean_mask(per_class_covering, valid_classes))],
+ metric_name + '_mean: ',
+ summarize=1)
+ summary_ops.append(op)
+ else:
+ raise ValueError('The metric_name "%s" is not supported.' % metric_name)
+```
+
+Afterwards, users could use the following code to run the evaluation in
+TensorFlow.
+
+Users can take a look at eval.py for reference; it provides a simple example
+of running the streaming evaluation of mIoU for semantic segmentation.
+
+```python
+metric_values = slim.evaluation.evaluation_loop(
+ master=FLAGS.master,
+ checkpoint_dir=FLAGS.checkpoint_dir,
+ logdir=FLAGS.eval_logdir,
+ num_evals=num_batches,
+ eval_op=metrics_to_updates.values(),
+ final_op=metrics_to_values.values(),
+ summary_op=tf.summary.merge(summary_ops),
+ max_number_of_evaluations=FLAGS.max_number_of_evaluations,
+ eval_interval_secs=FLAGS.eval_interval_secs)
+```
+
+
+### References
+
+1. **Image Parsing: Unifying Segmentation, Detection, and Recognition**
+ Zhuowen Tu, Xiangrong Chen, Alan L. Yuille, and Song-Chun Zhu
+ IJCV, 2005.
+
+2. **Panoptic Segmentation**
+ Alexander Kirillov, Kaiming He, Ross Girshick, Carsten Rother and Piotr
+ Dollár
+ arXiv:1801.00868, 2018.
+
+3. **Microsoft COCO: Common Objects in Context**
+ Tsung-Yi Lin, Michael Maire, Serge Belongie, Lubomir Bourdev, Ross
+ Girshick, James Hays, Pietro Perona, Deva Ramanan, C. Lawrence Zitnick,
+ Piotr Dollar
+ In the Proc. of ECCV, 2014.
+
+4. **The Mapillary Vistas Dataset for Semantic Understanding of Street Scenes**
+ Gerhard Neuhold, Tobias Ollmann, Samuel Rota Bulò, and Peter Kontschieder
+ In the Proc. of ICCV, 2017.
+
+5. **DeeperLab: Single-Shot Image Parser**
+ Tien-Ju Yang, Maxwell D. Collins, Yukun Zhu, Jyh-Jing Hwang, Ting Liu,
+ Xiao Zhang, Vivienne Sze, George Papandreou, Liang-Chieh Chen
+ arXiv: 1902.05093, 2019.
+
+6. **Contour Detection and Hierarchical Image Segmentation**
+ Pablo Arbelaez, Michael Maire, Charless Fowlkes, and Jitendra Malik
+ PAMI, 2011
diff --git a/models/research/deeplab/evaluation/__init__.py b/models/research/deeplab/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/deeplab/evaluation/base_metric.py b/models/research/deeplab/evaluation/base_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee7606ef44c1c2c027e593c494659f0dbcd455d3
--- /dev/null
+++ b/models/research/deeplab/evaluation/base_metric.py
@@ -0,0 +1,191 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines the top-level interface for evaluating segmentations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import numpy as np
+import six
+
+
+_EPSILON = 1e-10
+
+
+def realdiv_maybe_zero(x, y):
+ """Element-wise x / y where y may contain zeros, for those returns 0 too."""
+ return np.where(
+ np.less(np.abs(y), _EPSILON), np.zeros_like(x), np.divide(x, y))
+
+
+@six.add_metaclass(abc.ABCMeta)
+class SegmentationMetric(object):
+ """Abstract base class for computers of segmentation metrics.
+
+ Subclasses will implement both:
+ 1. Comparing the predicted segmentation for an image with the groundtruth.
+ 2. Computing the final metric over a set of images.
+  These are often done as separate steps because intermediate values other than
+  the metric itself must be accumulated across images; the actual metric value
+  is computed from these accumulations only after all the images have been
+  compared.
+
+ A simple usage would be:
+
+ metric = MetricImplementation(...)
+    for <image>, <groundtruth> in evaluation_set:
+      <prediction> = run_segmentation(<image>)
+      metric.compare_and_accumulate(<groundtruth>, <prediction>)
+ print(metric.result())
+
+ """
+
+ def __init__(self, num_categories, ignored_label, max_instances_per_category,
+ offset):
+ """Base initialization for SegmentationMetric.
+
+ Args:
+      num_categories: The number of segmentation categories (or "classes") in the
+ dataset.
+ ignored_label: A category id that is ignored in evaluation, e.g. the void
+ label as defined in COCO panoptic segmentation dataset.
+ max_instances_per_category: The maximum number of instances for each
+ category. Used in ensuring unique instance labels.
+ offset: The maximum number of unique labels. This is used, by multiplying
+ the ground-truth labels, to generate unique ids for individual regions
+ of overlap between groundtruth and predicted segments.
+ """
+ self.num_categories = num_categories
+ self.ignored_label = ignored_label
+ self.max_instances_per_category = max_instances_per_category
+ self.offset = offset
+ self.reset()
+
+ def _naively_combine_labels(self, category_array, instance_array):
+ """Naively creates a combined label array from categories and instances."""
+ return (category_array.astype(np.uint32) * self.max_instances_per_category +
+ instance_array.astype(np.uint32))
+
+ @abc.abstractmethod
+ def compare_and_accumulate(
+ self, groundtruth_category_array, groundtruth_instance_array,
+ predicted_category_array, predicted_instance_array):
+ """Compares predicted segmentation with groundtruth, accumulates its metric.
+
+ It is not assumed that instance ids are unique across different categories.
+ See for example combine_semantic_and_instance_predictions.py in official
+ PanopticAPI evaluation code for issues to consider when fusing category
+ and instance labels.
+
+    Instance ids of the ignored category carry a special meaning: id 0 is
+    "void" and the remaining ones are crowd instances.
+
+ Args:
+ groundtruth_category_array: A 2D numpy uint16 array of groundtruth
+ per-pixel category labels.
+ groundtruth_instance_array: A 2D numpy uint16 array of groundtruth
+ instance labels.
+ predicted_category_array: A 2D numpy uint16 array of predicted per-pixel
+ category labels.
+ predicted_instance_array: A 2D numpy uint16 array of predicted instance
+ labels.
+
+ Returns:
+ The value of the metric over all comparisons done so far, including this
+ one, as a float scalar.
+ """
+ raise NotImplementedError('Must be implemented in subclasses.')
+
+ @abc.abstractmethod
+ def result(self):
+ """Computes the metric over all comparisons done so far."""
+ raise NotImplementedError('Must be implemented in subclasses.')
+
+ @abc.abstractmethod
+ def detailed_results(self, is_thing=None):
+ """Computes and returns the detailed final metric results.
+
+ Args:
+ is_thing: A boolean array of length `num_categories`. The entry
+ `is_thing[category_id]` is True iff that category is a "thing" category
+ instead of "stuff."
+
+ Returns:
+ A dictionary with a breakdown of metrics and/or metric factors by things,
+ stuff, and all categories.
+ """
+ raise NotImplementedError('Not implemented in subclasses.')
+
+ @abc.abstractmethod
+ def result_per_category(self):
+ """For supported metrics, return individual per-category metric values.
+
+ Returns:
+ A numpy array of shape `[self.num_categories]`, where index `i` is the
+ metrics value over only that category.
+ """
+ raise NotImplementedError('Not implemented in subclass.')
+
+ def print_detailed_results(self, is_thing=None, print_digits=3):
+ """Prints out a detailed breakdown of metric results.
+
+ Args:
+ is_thing: A boolean array of length num_categories.
+ `is_thing[category_id]` will say whether that category is a "thing"
+ rather than "stuff."
+ print_digits: Number of significant digits to print in computed metrics.
+ """
+ raise NotImplementedError('Not implemented in subclass.')
+
+ @abc.abstractmethod
+ def merge(self, other_instance):
+ """Combines the accumulated results of another instance into self.
+
+ The following two cases should put `metric_a` into an equivalent state.
+
+ Case 1 (with merge):
+
+ metric_a = MetricsSubclass(...)
+ metric_a.compare_and_accumulate()
+ metric_a.compare_and_accumulate()
+
+ metric_b = MetricsSubclass(...)
+ metric_b.compare_and_accumulate()
+ metric_b.compare_and_accumulate()
+
+ metric_a.merge(metric_b)
+
+ Case 2 (without merge):
+
+ metric_a = MetricsSubclass(...)
+ metric_a.compare_and_accumulate()
+ metric_a.compare_and_accumulate()
+ metric_a.compare_and_accumulate()
+ metric_a.compare_and_accumulate()
+
+ Args:
+ other_instance: Another compatible instance of the same metric subclass.
+ """
+ raise NotImplementedError('Not implemented in subclass.')
+
+ @abc.abstractmethod
+ def reset(self):
+ """Resets the accumulation to the metric class's state at initialization.
+
+ Note that this function will be called in SegmentationMetric.__init__.
+ """
+ raise NotImplementedError('Must be implemented in subclasses.')
diff --git a/models/research/deeplab/evaluation/eval_coco_format.py b/models/research/deeplab/evaluation/eval_coco_format.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a26446f16b9787f246034a58247eb36d0064f80
--- /dev/null
+++ b/models/research/deeplab/evaluation/eval_coco_format.py
@@ -0,0 +1,338 @@
+# Lint as: python2, python3
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Computes evaluation metrics on groundtruth and predictions in COCO format.
+
+The Common Objects in Context (COCO) dataset defines a format for specifying
+combined semantic and instance segmentations as "panoptic" segmentations. This
+is done with the combination of JSON and image files as specified at:
+http://cocodataset.org/#format-results
+where the JSON file specifies the overall structure of the result,
+including the categories for each annotation, and the images specify the image
+region for each annotation in that image by its ID.
+
+This script computes additional metrics such as Parsing Covering on datasets and
+predictions in this format. An implementation of Panoptic Quality is also
+provided for convenience.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import json
+import multiprocessing
+import os
+
+from absl import app
+from absl import flags
+from absl import logging
+import numpy as np
+from PIL import Image
+import utils as panopticapi_utils
+import six
+
+from deeplab.evaluation import panoptic_quality
+from deeplab.evaluation import parsing_covering
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string(
+ 'gt_json_file', None,
+    'Path to a JSON file giving ground-truth annotations in COCO format.')
+flags.DEFINE_string('pred_json_file', None,
+ 'Path to a JSON file for the predictions to evaluate.')
+flags.DEFINE_string(
+ 'gt_folder', None,
+ 'Folder containing panoptic-format ID images to match ground-truth '
+ 'annotations to image regions.')
+flags.DEFINE_string('pred_folder', None,
+ 'Folder containing ID images for predictions.')
+flags.DEFINE_enum(
+ 'metric', 'pq', ['pq', 'pc'], 'Shorthand name of a metric to compute. '
+ 'Supported values are:\n'
+ 'Panoptic Quality (pq)\n'
+ 'Parsing Covering (pc)')
+flags.DEFINE_integer(
+ 'num_categories', 201,
+ 'The number of segmentation categories (or "classes") in the dataset.')
+flags.DEFINE_integer(
+ 'ignored_label', 0,
+ 'A category id that is ignored in evaluation, e.g. the void label as '
+ 'defined in COCO panoptic segmentation dataset.')
+flags.DEFINE_integer(
+ 'max_instances_per_category', 256,
+ 'The maximum number of instances for each category. Used in ensuring '
+ 'unique instance labels.')
+flags.DEFINE_integer('intersection_offset', None,
+ 'The maximum number of unique labels.')
+flags.DEFINE_bool(
+ 'normalize_by_image_size', True,
+ 'Whether to normalize groundtruth instance region areas by image size. If '
+ 'True, groundtruth instance areas and weighted IoUs will be divided by the '
+ 'size of the corresponding image before accumulated across the dataset. '
+ 'Only used for Parsing Covering (pc) evaluation.')
+flags.DEFINE_integer(
+ 'num_workers', 0, 'If set to a positive number, will spawn child processes '
+ 'to compute parts of the metric in parallel by splitting '
+ 'the images between the workers. If set to -1, will use '
+ 'the value of multiprocessing.cpu_count().')
+flags.DEFINE_integer('print_digits', 3,
+ 'Number of significant digits to print in metrics.')
+
+
+def _build_metric(metric,
+ num_categories,
+ ignored_label,
+ max_instances_per_category,
+ intersection_offset=None,
+ normalize_by_image_size=True):
+ """Creates a metric aggregator objet of the given name."""
+ if metric == 'pq':
+ logging.warning('One should check Panoptic Quality results against the '
+ 'official COCO API code. Small numerical differences '
+ '(< 0.1%) can be magnified by rounding.')
+ return panoptic_quality.PanopticQuality(num_categories, ignored_label,
+ max_instances_per_category,
+ intersection_offset)
+ elif metric == 'pc':
+ return parsing_covering.ParsingCovering(
+ num_categories, ignored_label, max_instances_per_category,
+ intersection_offset, normalize_by_image_size)
+ else:
+ raise ValueError('No implementation for metric "%s"' % metric)
+
+
+def _matched_annotations(gt_json, pred_json):
+ """Yields a set of (groundtruth, prediction) image annotation pairs.."""
+ image_id_to_pred_ann = {
+ annotation['image_id']: annotation
+ for annotation in pred_json['annotations']
+ }
+ for gt_ann in gt_json['annotations']:
+ image_id = gt_ann['image_id']
+ pred_ann = image_id_to_pred_ann[image_id]
+ yield gt_ann, pred_ann
+
+
+def _open_panoptic_id_image(image_path):
+ """Loads a COCO-format panoptic ID image from file."""
+ return panopticapi_utils.rgb2id(
+ np.array(Image.open(image_path), dtype=np.uint32))
+
+
+def _split_panoptic(ann_json, id_array, ignored_label, allow_crowds):
+ """Given the COCO JSON and ID map, splits into categories and instances."""
+ category = np.zeros(id_array.shape, np.uint16)
+ instance = np.zeros(id_array.shape, np.uint16)
+ next_instance_id = collections.defaultdict(int)
+ # Skip instance label 0 for ignored label. That is reserved for void.
+ next_instance_id[ignored_label] = 1
+ for segment_info in ann_json['segments_info']:
+ if allow_crowds and segment_info['iscrowd']:
+ category_id = ignored_label
+ else:
+ category_id = segment_info['category_id']
+ mask = np.equal(id_array, segment_info['id'])
+ category[mask] = category_id
+ instance[mask] = next_instance_id[category_id]
+ next_instance_id[category_id] += 1
+ return category, instance
+
+
+def _category_and_instance_from_annotation(ann_json, folder, ignored_label,
+ allow_crowds):
+ """Given the COCO JSON annotations, finds maps of categories and instances."""
+ panoptic_id_image = _open_panoptic_id_image(
+ os.path.join(folder, ann_json['file_name']))
+ return _split_panoptic(ann_json, panoptic_id_image, ignored_label,
+ allow_crowds)
+
+
+def _compute_metric(metric_aggregator, gt_folder, pred_folder,
+ annotation_pairs):
+ """Iterates over matched annotation pairs and computes a metric over them."""
+ for gt_ann, pred_ann in annotation_pairs:
+ # We only expect "iscrowd" to appear in the ground-truth, and not in model
+ # output. In predicted JSON it is simply ignored, as done in official code.
+ gt_category, gt_instance = _category_and_instance_from_annotation(
+ gt_ann, gt_folder, metric_aggregator.ignored_label, True)
+ pred_category, pred_instance = _category_and_instance_from_annotation(
+ pred_ann, pred_folder, metric_aggregator.ignored_label, False)
+
+ metric_aggregator.compare_and_accumulate(gt_category, gt_instance,
+ pred_category, pred_instance)
+ return metric_aggregator
+
+
+def _iterate_work_queue(work_queue):
+ """Creates an iterable that retrieves items from a queue until one is None."""
+ task = work_queue.get(block=True)
+ while task is not None:
+ yield task
+ task = work_queue.get(block=True)
+
+
+def _run_metrics_worker(metric_aggregator, gt_folder, pred_folder, work_queue,
+ result_queue):
+ result = _compute_metric(metric_aggregator, gt_folder, pred_folder,
+ _iterate_work_queue(work_queue))
+ result_queue.put(result, block=True)
+
+
+def _is_thing_array(categories_json, ignored_label):
+ """is_thing[category_id] is a bool on if category is "thing" or "stuff"."""
+ is_thing_dict = {}
+ for category_json in categories_json:
+ is_thing_dict[category_json['id']] = bool(category_json['isthing'])
+
+  # Check our assumption that the category ids are consecutive.
+  # The metrics can usually handle non-consecutive ids, but warn here so the
+  # user is aware.
+ max_category_id = max(six.iterkeys(is_thing_dict))
+ if len(is_thing_dict) != max_category_id + 1:
+ seen_ids = six.viewkeys(is_thing_dict)
+ all_ids = set(six.moves.range(max_category_id + 1))
+ unseen_ids = all_ids.difference(seen_ids)
+ if unseen_ids != {ignored_label}:
+ logging.warning(
+ 'Nonconsecutive category ids or no category JSON specified for ids: '
+ '%s', unseen_ids)
+
+ is_thing_array = np.zeros(max_category_id + 1)
+ for category_id, is_thing in six.iteritems(is_thing_dict):
+ is_thing_array[category_id] = is_thing
+
+ return is_thing_array
+
+
+def eval_coco_format(gt_json_file,
+ pred_json_file,
+ gt_folder=None,
+ pred_folder=None,
+ metric='pq',
+ num_categories=201,
+ ignored_label=0,
+ max_instances_per_category=256,
+ intersection_offset=None,
+ normalize_by_image_size=True,
+ num_workers=0,
+ print_digits=3):
+ """Top-level code to compute metrics on a COCO-format result.
+
+  Note that the default values are set for the COCO panoptic segmentation
+  dataset, so users may want to change them when evaluating their own dataset.
+
+ Args:
+ gt_json_file: Path to a JSON file giving ground-truth annotations in COCO
+ format.
+ pred_json_file: Path to a JSON file for the predictions to evaluate.
+ gt_folder: Folder containing panoptic-format ID images to match ground-truth
+ annotations to image regions.
+ pred_folder: Folder containing ID images for predictions.
+ metric: Name of a metric to compute.
+ num_categories: The number of segmentation categories (or "classes") in the
+ dataset.
+ ignored_label: A category id that is ignored in evaluation, e.g. the "void"
+ label as defined in the COCO panoptic segmentation dataset.
+ max_instances_per_category: The maximum number of instances for each
+ category. Used in ensuring unique instance labels.
+ intersection_offset: The maximum number of unique labels.
+ normalize_by_image_size: Whether to normalize groundtruth instance region
+ areas by image size. If True, groundtruth instance areas and weighted IoUs
+ will be divided by the size of the corresponding image before accumulated
+ across the dataset. Only used for Parsing Covering (pc) evaluation.
+ num_workers: If set to a positive number, will spawn child processes to
+ compute parts of the metric in parallel by splitting the images between
+ the workers. If set to -1, will use the value of
+ multiprocessing.cpu_count().
+ print_digits: Number of significant digits to print in summary of computed
+ metrics.
+
+ Returns:
+    The detailed metric results as a nested dictionary, keyed first by the
+    category set ('All', 'Things', 'Stuff') and then by the metric name.
+ """
+ with open(gt_json_file, 'r') as gt_json_fo:
+ gt_json = json.load(gt_json_fo)
+ with open(pred_json_file, 'r') as pred_json_fo:
+ pred_json = json.load(pred_json_fo)
+ if gt_folder is None:
+ gt_folder = gt_json_file.replace('.json', '')
+ if pred_folder is None:
+ pred_folder = pred_json_file.replace('.json', '')
+ if intersection_offset is None:
+ intersection_offset = (num_categories + 1) * max_instances_per_category
+
+ metric_aggregator = _build_metric(
+ metric, num_categories, ignored_label, max_instances_per_category,
+ intersection_offset, normalize_by_image_size)
+
+ if num_workers == -1:
+ logging.info('Attempting to get the CPU count to set # workers.')
+ num_workers = multiprocessing.cpu_count()
+
+ if num_workers > 0:
+ logging.info('Computing metric in parallel with %d workers.', num_workers)
+ work_queue = multiprocessing.Queue()
+ result_queue = multiprocessing.Queue()
+ workers = []
+ worker_args = (metric_aggregator, gt_folder, pred_folder, work_queue,
+ result_queue)
+ for _ in six.moves.range(num_workers):
+ workers.append(
+ multiprocessing.Process(target=_run_metrics_worker, args=worker_args))
+ for worker in workers:
+ worker.start()
+ for ann_pair in _matched_annotations(gt_json, pred_json):
+ work_queue.put(ann_pair, block=True)
+
+    # Will cause each worker to return a result and terminate upon receiving a
+    # None task.
+ for _ in six.moves.range(num_workers):
+ work_queue.put(None, block=True)
+
+ # Retrieve results.
+ for _ in six.moves.range(num_workers):
+ metric_aggregator.merge(result_queue.get(block=True))
+
+ for worker in workers:
+ worker.join()
+ else:
+ logging.info('Computing metric in a single process.')
+ annotation_pairs = _matched_annotations(gt_json, pred_json)
+ _compute_metric(metric_aggregator, gt_folder, pred_folder, annotation_pairs)
+
+ is_thing = _is_thing_array(gt_json['categories'], ignored_label)
+ metric_aggregator.print_detailed_results(
+ is_thing=is_thing, print_digits=print_digits)
+ return metric_aggregator.detailed_results(is_thing=is_thing)
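+
+# Example usage (illustrative only; the paths below are placeholders and not
+# files shipped with this module):
+#
+#   results = eval_coco_format(
+#       gt_json_file='/path/to/panoptic_gt.json',
+#       pred_json_file='/path/to/panoptic_pred.json',
+#       gt_folder='/path/to/panoptic_gt',
+#       pred_folder='/path/to/panoptic_pred',
+#       metric='pq',
+#       num_workers=0)
+#   print(results['All']['pq'])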
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise app.UsageError('Too many command-line arguments.')
+
+ eval_coco_format(FLAGS.gt_json_file, FLAGS.pred_json_file, FLAGS.gt_folder,
+ FLAGS.pred_folder, FLAGS.metric, FLAGS.num_categories,
+ FLAGS.ignored_label, FLAGS.max_instances_per_category,
+ FLAGS.intersection_offset, FLAGS.normalize_by_image_size,
+ FLAGS.num_workers, FLAGS.print_digits)
+
+
+if __name__ == '__main__':
+ flags.mark_flags_as_required(
+ ['gt_json_file', 'gt_folder', 'pred_json_file', 'pred_folder'])
+ app.run(main)
diff --git a/models/research/deeplab/evaluation/eval_coco_format_test.py b/models/research/deeplab/evaluation/eval_coco_format_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9093ff127e5dce27775421b9d136fc7cbc27c77
--- /dev/null
+++ b/models/research/deeplab/evaluation/eval_coco_format_test.py
@@ -0,0 +1,140 @@
+# Lint as: python2, python3
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for eval_coco_format script."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import flags
+from absl.testing import absltest
+import evaluation as panopticapi_eval
+
+from deeplab.evaluation import eval_coco_format
+
+_TEST_DIR = 'deeplab/evaluation/testdata'
+
+FLAGS = flags.FLAGS
+
+
+class EvalCocoFormatTest(absltest.TestCase):
+
+ def test_compare_pq_with_reference_eval(self):
+ sample_data_dir = os.path.join(_TEST_DIR)
+ gt_json_file = os.path.join(sample_data_dir, 'coco_gt.json')
+ gt_folder = os.path.join(sample_data_dir, 'coco_gt')
+ pred_json_file = os.path.join(sample_data_dir, 'coco_pred.json')
+ pred_folder = os.path.join(sample_data_dir, 'coco_pred')
+
+ panopticapi_results = panopticapi_eval.pq_compute(
+ gt_json_file, pred_json_file, gt_folder, pred_folder)
+ deeplab_results = eval_coco_format.eval_coco_format(
+ gt_json_file,
+ pred_json_file,
+ gt_folder,
+ pred_folder,
+ metric='pq',
+ num_categories=7,
+ ignored_label=0,
+ max_instances_per_category=256,
+ intersection_offset=(256 * 256))
+ self.assertCountEqual(
+ list(deeplab_results.keys()), ['All', 'Things', 'Stuff'])
+ for cat_group in ['All', 'Things', 'Stuff']:
+ self.assertCountEqual(deeplab_results[cat_group], ['pq', 'sq', 'rq', 'n'])
+ for metric in ['pq', 'sq', 'rq', 'n']:
+ self.assertAlmostEqual(deeplab_results[cat_group][metric],
+ panopticapi_results[cat_group][metric])
+
+ def test_compare_pc_with_golden_value(self):
+ sample_data_dir = os.path.join(_TEST_DIR)
+ gt_json_file = os.path.join(sample_data_dir, 'coco_gt.json')
+ gt_folder = os.path.join(sample_data_dir, 'coco_gt')
+ pred_json_file = os.path.join(sample_data_dir, 'coco_pred.json')
+ pred_folder = os.path.join(sample_data_dir, 'coco_pred')
+
+ deeplab_results = eval_coco_format.eval_coco_format(
+ gt_json_file,
+ pred_json_file,
+ gt_folder,
+ pred_folder,
+ metric='pc',
+ num_categories=7,
+ ignored_label=0,
+ max_instances_per_category=256,
+ intersection_offset=(256 * 256),
+ normalize_by_image_size=False)
+ self.assertCountEqual(
+ list(deeplab_results.keys()), ['All', 'Things', 'Stuff'])
+ for cat_group in ['All', 'Things', 'Stuff']:
+ self.assertCountEqual(deeplab_results[cat_group], ['pc', 'n'])
+ self.assertAlmostEqual(deeplab_results['All']['pc'], 0.68210561)
+ self.assertEqual(deeplab_results['All']['n'], 6)
+ self.assertAlmostEqual(deeplab_results['Things']['pc'], 0.5890529)
+ self.assertEqual(deeplab_results['Things']['n'], 4)
+ self.assertAlmostEqual(deeplab_results['Stuff']['pc'], 0.86821097)
+ self.assertEqual(deeplab_results['Stuff']['n'], 2)
+
+ def test_compare_pc_with_golden_value_normalize_by_size(self):
+ sample_data_dir = os.path.join(_TEST_DIR)
+ gt_json_file = os.path.join(sample_data_dir, 'coco_gt.json')
+ gt_folder = os.path.join(sample_data_dir, 'coco_gt')
+ pred_json_file = os.path.join(sample_data_dir, 'coco_pred.json')
+ pred_folder = os.path.join(sample_data_dir, 'coco_pred')
+
+ deeplab_results = eval_coco_format.eval_coco_format(
+ gt_json_file,
+ pred_json_file,
+ gt_folder,
+ pred_folder,
+ metric='pc',
+ num_categories=7,
+ ignored_label=0,
+ max_instances_per_category=256,
+ intersection_offset=(256 * 256),
+ normalize_by_image_size=True)
+ self.assertCountEqual(
+ list(deeplab_results.keys()), ['All', 'Things', 'Stuff'])
+ self.assertAlmostEqual(deeplab_results['All']['pc'], 0.68214908840)
+
+ def test_pc_with_multiple_workers(self):
+ sample_data_dir = os.path.join(_TEST_DIR)
+ gt_json_file = os.path.join(sample_data_dir, 'coco_gt.json')
+ gt_folder = os.path.join(sample_data_dir, 'coco_gt')
+ pred_json_file = os.path.join(sample_data_dir, 'coco_pred.json')
+ pred_folder = os.path.join(sample_data_dir, 'coco_pred')
+
+ deeplab_results = eval_coco_format.eval_coco_format(
+ gt_json_file,
+ pred_json_file,
+ gt_folder,
+ pred_folder,
+ metric='pc',
+ num_categories=7,
+ ignored_label=0,
+ max_instances_per_category=256,
+ intersection_offset=(256 * 256),
+ num_workers=3,
+ normalize_by_image_size=False)
+ self.assertCountEqual(
+ list(deeplab_results.keys()), ['All', 'Things', 'Stuff'])
+ self.assertAlmostEqual(deeplab_results['All']['pc'], 0.68210561668)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/models/research/deeplab/evaluation/g3doc/img/equation_pc.png b/models/research/deeplab/evaluation/g3doc/img/equation_pc.png
new file mode 100644
index 0000000000000000000000000000000000000000..90f15e7a461f929db9774f2c3c7e9dca549f433b
Binary files /dev/null and b/models/research/deeplab/evaluation/g3doc/img/equation_pc.png differ
diff --git a/models/research/deeplab/evaluation/g3doc/img/equation_pq.png b/models/research/deeplab/evaluation/g3doc/img/equation_pq.png
new file mode 100644
index 0000000000000000000000000000000000000000..13a4393c181f27f5eb9be43f47e9cb132b28b924
Binary files /dev/null and b/models/research/deeplab/evaluation/g3doc/img/equation_pq.png differ
diff --git a/models/research/deeplab/evaluation/panoptic_quality.py b/models/research/deeplab/evaluation/panoptic_quality.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7d0f3f98f09819feda52bc89069333665ff5d94
--- /dev/null
+++ b/models/research/deeplab/evaluation/panoptic_quality.py
@@ -0,0 +1,259 @@
+# Lint as: python2, python3
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implementation of the Panoptic Quality metric.
+
+Panoptic Quality is an instance-based metric for evaluating the task of
+image parsing, aka panoptic segmentation.
+
+Please see the paper for details:
+"Panoptic Segmentation", Alexander Kirillov, Kaiming He, Ross Girshick,
+Carsten Rother and Piotr Dollar. arXiv:1801.00868, 2018.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import numpy as np
+import prettytable
+import six
+
+from deeplab.evaluation import base_metric
+
+
+def _ids_to_counts(id_array):
+ """Given a numpy array, a mapping from each unique entry to its count."""
+ ids, counts = np.unique(id_array, return_counts=True)
+ return dict(six.moves.zip(ids, counts))
+
+
+class PanopticQuality(base_metric.SegmentationMetric):
+ """Metric class for Panoptic Quality.
+
+ "Panoptic Segmentation" by Alexander Kirillov, Kaiming He, Ross Girshick,
+ Carsten Rother, Piotr Dollar.
+ https://arxiv.org/abs/1801.00868
+ """
+
+ def compare_and_accumulate(
+ self, groundtruth_category_array, groundtruth_instance_array,
+ predicted_category_array, predicted_instance_array):
+ """See base class."""
+ # First, combine the category and instance labels so that every unique
+ # value for (category, instance) is assigned a unique integer label.
+ pred_segment_id = self._naively_combine_labels(predicted_category_array,
+ predicted_instance_array)
+ gt_segment_id = self._naively_combine_labels(groundtruth_category_array,
+ groundtruth_instance_array)
+
+ # Pre-calculate areas for all groundtruth and predicted segments.
+ gt_segment_areas = _ids_to_counts(gt_segment_id)
+ pred_segment_areas = _ids_to_counts(pred_segment_id)
+
+ # We assume there is only one void segment and it has instance id = 0.
+ void_segment_id = self.ignored_label * self.max_instances_per_category
+
+    # There may be other ignored groundtruth segments with instance id > 0; find
+    # those ids using the unique segment ids extracted in the area computation
+    # above.
+ ignored_segment_ids = {
+ gt_segment_id for gt_segment_id in six.iterkeys(gt_segment_areas)
+ if (gt_segment_id //
+ self.max_instances_per_category) == self.ignored_label
+ }
+
+ # Next, combine the groundtruth and predicted labels. Dividing up the pixels
+ # based on which groundtruth segment and which predicted segment they belong
+ # to, this will assign a different 32-bit integer label to each choice
+ # of (groundtruth segment, predicted segment), encoded as
+ # gt_segment_id * offset + pred_segment_id.
+ intersection_id_array = (
+ gt_segment_id.astype(np.uint32) * self.offset +
+ pred_segment_id.astype(np.uint32))
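+    # Worked example (illustrative numbers): with max_instances_per_category
+    # equal to 256 and offset equal to 256 * 256, a groundtruth segment of
+    # category 3, instance 2 has gt_segment_id = 3 * 256 + 2 = 770, and a
+    # predicted segment of category 3, instance 5 has pred_segment_id = 773.
+    # Their overlapping pixels are therefore counted under intersection id
+    # 770 * 65536 + 773 = 50463493, which decodes back to (770, 773) via
+    # // offset and % offset below.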
+
+ # For every combination of (groundtruth segment, predicted segment) with a
+ # non-empty intersection, this counts the number of pixels in that
+ # intersection.
+ intersection_areas = _ids_to_counts(intersection_id_array)
+
+ # Helper function that computes the area of the overlap between a predicted
+ # segment and the ground-truth void/ignored segment.
+ def prediction_void_overlap(pred_segment_id):
+ void_intersection_id = void_segment_id * self.offset + pred_segment_id
+ return intersection_areas.get(void_intersection_id, 0)
+
+    # Helper function that computes a predicted segment's total overlap with
+    # all ignored groundtruth segments.
+ def prediction_ignored_overlap(pred_segment_id):
+ total_ignored_overlap = 0
+ for ignored_segment_id in ignored_segment_ids:
+ intersection_id = ignored_segment_id * self.offset + pred_segment_id
+ total_ignored_overlap += intersection_areas.get(intersection_id, 0)
+ return total_ignored_overlap
+
+    # Sets populated with the groundtruth/predicted segments that have been
+    # matched with an overlapping predicted/groundtruth segment, respectively.
+ gt_matched = set()
+ pred_matched = set()
+
+ # Calculate IoU per pair of intersecting segments of the same category.
+ for intersection_id, intersection_area in six.iteritems(intersection_areas):
+ gt_segment_id = intersection_id // self.offset
+ pred_segment_id = intersection_id % self.offset
+
+ gt_category = gt_segment_id // self.max_instances_per_category
+ pred_category = pred_segment_id // self.max_instances_per_category
+ if gt_category != pred_category:
+ continue
+
+ # Union between the groundtruth and predicted segments being compared does
+ # not include the portion of the predicted segment that consists of
+ # groundtruth "void" pixels.
+ union = (
+ gt_segment_areas[gt_segment_id] +
+ pred_segment_areas[pred_segment_id] - intersection_area -
+ prediction_void_overlap(pred_segment_id))
+ iou = intersection_area / union
+ if iou > 0.5:
+ self.tp_per_class[gt_category] += 1
+ self.iou_per_class[gt_category] += iou
+ gt_matched.add(gt_segment_id)
+ pred_matched.add(pred_segment_id)
+
+ # Count false negatives for each category.
+ for gt_segment_id in six.iterkeys(gt_segment_areas):
+ if gt_segment_id in gt_matched:
+ continue
+ category = gt_segment_id // self.max_instances_per_category
+ # Failing to detect a void segment is not a false negative.
+ if category == self.ignored_label:
+ continue
+ self.fn_per_class[category] += 1
+
+ # Count false positives for each category.
+ for pred_segment_id in six.iterkeys(pred_segment_areas):
+ if pred_segment_id in pred_matched:
+ continue
+      # A false positive is not penalized if it is mostly ignored in the
+      # groundtruth.
+ if (prediction_ignored_overlap(pred_segment_id) /
+ pred_segment_areas[pred_segment_id]) > 0.5:
+ continue
+ category = pred_segment_id // self.max_instances_per_category
+ self.fp_per_class[category] += 1
+
+ return self.result()
+
+ def _valid_categories(self):
+ """Categories with a "valid" value for the metric, have > 0 instances.
+
+ We will ignore the `ignore_label` class and other classes which have
+ `tp + fn + fp = 0`.
+
+ Returns:
+ Boolean array of shape `[num_categories]`.
+ """
+ valid_categories = np.not_equal(
+ self.tp_per_class + self.fn_per_class + self.fp_per_class, 0)
+ if self.ignored_label >= 0 and self.ignored_label < self.num_categories:
+ valid_categories[self.ignored_label] = False
+ return valid_categories
+
+ def detailed_results(self, is_thing=None):
+ """See base class."""
+ valid_categories = self._valid_categories()
+
+ # If known, break down which categories are valid _and_ things/stuff.
+ category_sets = collections.OrderedDict()
+ category_sets['All'] = valid_categories
+ if is_thing is not None:
+ category_sets['Things'] = np.logical_and(valid_categories, is_thing)
+ category_sets['Stuff'] = np.logical_and(valid_categories,
+ np.logical_not(is_thing))
+
+ # Compute individual per-class metrics that constitute factors of PQ.
+ sq = base_metric.realdiv_maybe_zero(self.iou_per_class, self.tp_per_class)
+ rq = base_metric.realdiv_maybe_zero(
+ self.tp_per_class,
+ self.tp_per_class + 0.5 * self.fn_per_class + 0.5 * self.fp_per_class)
+ pq = np.multiply(sq, rq)
+
+ # Assemble detailed results dictionary.
+ results = {}
+ for category_set_name, in_category_set in six.iteritems(category_sets):
+ if np.any(in_category_set):
+ results[category_set_name] = {
+ 'pq': np.mean(pq[in_category_set]),
+ 'sq': np.mean(sq[in_category_set]),
+ 'rq': np.mean(rq[in_category_set]),
+ # The number of categories in this subset.
+ 'n': np.sum(in_category_set.astype(np.int32)),
+ }
+ else:
+ results[category_set_name] = {'pq': 0, 'sq': 0, 'rq': 0, 'n': 0}
+
+ return results
+
+ def result_per_category(self):
+ """See base class."""
+ sq = base_metric.realdiv_maybe_zero(self.iou_per_class, self.tp_per_class)
+ rq = base_metric.realdiv_maybe_zero(
+ self.tp_per_class,
+ self.tp_per_class + 0.5 * self.fn_per_class + 0.5 * self.fp_per_class)
+ return np.multiply(sq, rq)
+
+ def print_detailed_results(self, is_thing=None, print_digits=3):
+ """See base class."""
+ results = self.detailed_results(is_thing=is_thing)
+
+ tab = prettytable.PrettyTable()
+
+ tab.add_column('', [], align='l')
+ for fieldname in ['PQ', 'SQ', 'RQ', 'N']:
+ tab.add_column(fieldname, [], align='r')
+
+ for category_set, subset_results in six.iteritems(results):
+ data_cols = [
+ round(subset_results[col_key], print_digits) * 100
+ for col_key in ['pq', 'sq', 'rq']
+ ]
+ data_cols += [subset_results['n']]
+ tab.add_row([category_set] + data_cols)
+
+ print(tab)
+
+ def result(self):
+ """See base class."""
+ pq_per_class = self.result_per_category()
+ valid_categories = self._valid_categories()
+ if not np.any(valid_categories):
+ return 0.
+ return np.mean(pq_per_class[valid_categories])
+
+ def merge(self, other_instance):
+ """See base class."""
+ self.iou_per_class += other_instance.iou_per_class
+ self.tp_per_class += other_instance.tp_per_class
+ self.fn_per_class += other_instance.fn_per_class
+ self.fp_per_class += other_instance.fp_per_class
+
+ def reset(self):
+ """See base class."""
+ self.iou_per_class = np.zeros(self.num_categories, dtype=np.float64)
+ self.tp_per_class = np.zeros(self.num_categories, dtype=np.float64)
+ self.fn_per_class = np.zeros(self.num_categories, dtype=np.float64)
+ self.fp_per_class = np.zeros(self.num_categories, dtype=np.float64)
diff --git a/models/research/deeplab/evaluation/panoptic_quality_test.py b/models/research/deeplab/evaluation/panoptic_quality_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..00c88c293b8edc39b7ceb28b2fc7fea4da1c3cb0
--- /dev/null
+++ b/models/research/deeplab/evaluation/panoptic_quality_test.py
@@ -0,0 +1,336 @@
+# Lint as: python2, python3
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Panoptic Quality metric."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from absl.testing import absltest
+import numpy as np
+import six
+
+from deeplab.evaluation import panoptic_quality
+from deeplab.evaluation import test_utils
+
+# See the definition of the color names at:
+# https://en.wikipedia.org/wiki/Web_colors.
+_CLASS_COLOR_MAP = {
+ (0, 0, 0): 0,
+ (0, 0, 255): 1, # Person (blue).
+ (255, 0, 0): 2, # Bear (red).
+ (0, 255, 0): 3, # Tree (lime).
+ (255, 0, 255): 4, # Bird (fuchsia).
+ (0, 255, 255): 5, # Sky (aqua).
+ (255, 255, 0): 6, # Cat (yellow).
+}
+
+
+class PanopticQualityTest(absltest.TestCase):
+
+ def test_perfect_match(self):
+ categories = np.zeros([6, 6], np.uint16)
+ instances = np.array([
+ [1, 1, 1, 1, 1, 1],
+ [1, 2, 2, 2, 2, 1],
+ [1, 2, 2, 2, 2, 1],
+ [1, 2, 2, 2, 2, 1],
+ [1, 2, 2, 1, 1, 1],
+ [1, 2, 1, 1, 1, 1],
+ ],
+ dtype=np.uint16)
+
+ pq = panoptic_quality.PanopticQuality(
+ num_categories=1,
+ ignored_label=2,
+ max_instances_per_category=16,
+ offset=16)
+ pq.compare_and_accumulate(categories, instances, categories, instances)
+ np.testing.assert_array_equal(pq.iou_per_class, [2.0])
+ np.testing.assert_array_equal(pq.tp_per_class, [2])
+ np.testing.assert_array_equal(pq.fn_per_class, [0])
+ np.testing.assert_array_equal(pq.fp_per_class, [0])
+ np.testing.assert_array_equal(pq.result_per_category(), [1.0])
+ self.assertEqual(pq.result(), 1.0)
+
+ def test_totally_wrong(self):
+ det_categories = np.array([
+ [0, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 1, 0],
+ [0, 1, 1, 1, 1, 0],
+ [0, 1, 1, 1, 1, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ ],
+ dtype=np.uint16)
+ gt_categories = 1 - det_categories
+ instances = np.zeros([6, 6], np.uint16)
+
+ pq = panoptic_quality.PanopticQuality(
+ num_categories=2,
+ ignored_label=2,
+ max_instances_per_category=1,
+ offset=16)
+ pq.compare_and_accumulate(gt_categories, instances, det_categories,
+ instances)
+ np.testing.assert_array_equal(pq.iou_per_class, [0.0, 0.0])
+ np.testing.assert_array_equal(pq.tp_per_class, [0, 0])
+ np.testing.assert_array_equal(pq.fn_per_class, [1, 1])
+ np.testing.assert_array_equal(pq.fp_per_class, [1, 1])
+ np.testing.assert_array_equal(pq.result_per_category(), [0.0, 0.0])
+ self.assertEqual(pq.result(), 0.0)
+
+ def test_matches_by_iou(self):
+ good_det_labels = np.array(
+ [
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 2, 2, 2, 2, 1],
+ [1, 2, 2, 2, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ ],
+ dtype=np.uint16)
+ gt_labels = np.array(
+ [
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 2, 2, 2, 1],
+ [1, 2, 2, 2, 2, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ ],
+ dtype=np.uint16)
+
+ pq = panoptic_quality.PanopticQuality(
+ num_categories=1,
+ ignored_label=2,
+ max_instances_per_category=16,
+ offset=16)
+ pq.compare_and_accumulate(
+ np.zeros_like(gt_labels), gt_labels, np.zeros_like(good_det_labels),
+ good_det_labels)
+
+ # iou(1, 1) = 28/30
+ # iou(2, 2) = 6/8
+ np.testing.assert_array_almost_equal(pq.iou_per_class, [28 / 30 + 6 / 8])
+ np.testing.assert_array_equal(pq.tp_per_class, [2])
+ np.testing.assert_array_equal(pq.fn_per_class, [0])
+ np.testing.assert_array_equal(pq.fp_per_class, [0])
+ self.assertAlmostEqual(pq.result(), (28 / 30 + 6 / 8) / 2)
+
+ bad_det_labels = np.array(
+ [
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 2, 2, 1],
+ [1, 1, 1, 2, 2, 1],
+ [1, 1, 1, 2, 2, 1],
+ [1, 1, 1, 1, 1, 1],
+ ],
+ dtype=np.uint16)
+
+ pq.reset()
+ pq.compare_and_accumulate(
+ np.zeros_like(gt_labels), gt_labels, np.zeros_like(bad_det_labels),
+ bad_det_labels)
+
+ # iou(1, 1) = 27/32
+ np.testing.assert_array_almost_equal(pq.iou_per_class, [27 / 32])
+ np.testing.assert_array_equal(pq.tp_per_class, [1])
+ np.testing.assert_array_equal(pq.fn_per_class, [1])
+ np.testing.assert_array_equal(pq.fp_per_class, [1])
+ self.assertAlmostEqual(pq.result(), (27 / 32) * (1 / 2))
+
+ def test_wrong_instances(self):
+ categories = np.array([
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 2, 2, 1, 2, 2],
+ [1, 2, 2, 1, 2, 2],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ ],
+ dtype=np.uint16)
+ predicted_instances = np.array([
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1, 1],
+ [0, 0, 0, 0, 1, 1],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ ],
+ dtype=np.uint16)
+ groundtruth_instances = np.zeros([6, 6], dtype=np.uint16)
+
+ pq = panoptic_quality.PanopticQuality(
+ num_categories=3,
+ ignored_label=0,
+ max_instances_per_category=10,
+ offset=100)
+ pq.compare_and_accumulate(categories, groundtruth_instances, categories,
+ predicted_instances)
+
+ np.testing.assert_array_equal(pq.iou_per_class, [0.0, 1.0, 0.0])
+ np.testing.assert_array_equal(pq.tp_per_class, [0, 1, 0])
+ np.testing.assert_array_equal(pq.fn_per_class, [0, 0, 1])
+ np.testing.assert_array_equal(pq.fp_per_class, [0, 0, 2])
+ np.testing.assert_array_equal(pq.result_per_category(), [0, 1, 0])
+ self.assertAlmostEqual(pq.result(), 0.5)
+
+ def test_instance_order_is_arbitrary(self):
+ categories = np.array([
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 2, 2, 1, 2, 2],
+ [1, 2, 2, 1, 2, 2],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ ],
+ dtype=np.uint16)
+ predicted_instances = np.array([
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1, 1],
+ [0, 0, 0, 0, 1, 1],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ ],
+ dtype=np.uint16)
+ groundtruth_instances = np.array([
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 1, 1, 0, 0, 0],
+ [0, 1, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ ],
+ dtype=np.uint16)
+
+ pq = panoptic_quality.PanopticQuality(
+ num_categories=3,
+ ignored_label=0,
+ max_instances_per_category=10,
+ offset=100)
+ pq.compare_and_accumulate(categories, groundtruth_instances, categories,
+ predicted_instances)
+
+ np.testing.assert_array_equal(pq.iou_per_class, [0.0, 1.0, 2.0])
+ np.testing.assert_array_equal(pq.tp_per_class, [0, 1, 2])
+ np.testing.assert_array_equal(pq.fn_per_class, [0, 0, 0])
+ np.testing.assert_array_equal(pq.fp_per_class, [0, 0, 0])
+ np.testing.assert_array_equal(pq.result_per_category(), [0, 1, 1])
+ self.assertAlmostEqual(pq.result(), 1.0)
+
+ def test_matches_expected(self):
+ pred_classes = test_utils.read_segmentation_with_rgb_color_map(
+ 'team_pred_class.png', _CLASS_COLOR_MAP)
+ pred_instances = test_utils.read_test_image(
+ 'team_pred_instance.png', mode='L')
+
+ instance_class_map = {
+ 0: 0,
+ 47: 1,
+ 97: 1,
+ 133: 1,
+ 150: 1,
+ 174: 1,
+ 198: 2,
+ 215: 1,
+ 244: 1,
+ 255: 1,
+ }
+ gt_instances, gt_classes = test_utils.panoptic_segmentation_with_class_map(
+ 'team_gt_instance.png', instance_class_map)
+
+ pq = panoptic_quality.PanopticQuality(
+ num_categories=3,
+ ignored_label=0,
+ max_instances_per_category=256,
+ offset=256 * 256)
+ pq.compare_and_accumulate(gt_classes, gt_instances, pred_classes,
+ pred_instances)
+ np.testing.assert_array_almost_equal(
+ pq.iou_per_class, [2.06104, 5.26827, 0.54069], decimal=4)
+ np.testing.assert_array_equal(pq.tp_per_class, [1, 7, 1])
+ np.testing.assert_array_equal(pq.fn_per_class, [0, 1, 0])
+ np.testing.assert_array_equal(pq.fp_per_class, [0, 0, 0])
+ np.testing.assert_array_almost_equal(pq.result_per_category(),
+ [2.061038, 0.702436, 0.54069])
+ self.assertAlmostEqual(pq.result(), 0.62156287)
+
+ def test_merge_accumulates_all_across_instances(self):
+ categories = np.zeros([6, 6], np.uint16)
+ good_det_labels = np.array([
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 2, 2, 2, 2, 1],
+ [1, 2, 2, 2, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ ],
+ dtype=np.uint16)
+ gt_labels = np.array([
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 2, 2, 2, 1],
+ [1, 2, 2, 2, 2, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ ],
+ dtype=np.uint16)
+
+ good_pq = panoptic_quality.PanopticQuality(
+ num_categories=1,
+ ignored_label=2,
+ max_instances_per_category=16,
+ offset=16)
+ for _ in six.moves.range(2):
+ good_pq.compare_and_accumulate(categories, gt_labels, categories,
+ good_det_labels)
+
+ bad_det_labels = np.array([
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 2, 2, 1],
+ [1, 1, 1, 2, 2, 1],
+ [1, 1, 1, 2, 2, 1],
+ [1, 1, 1, 1, 1, 1],
+ ],
+ dtype=np.uint16)
+
+ bad_pq = panoptic_quality.PanopticQuality(
+ num_categories=1,
+ ignored_label=2,
+ max_instances_per_category=16,
+ offset=16)
+ for _ in six.moves.range(2):
+ bad_pq.compare_and_accumulate(categories, gt_labels, categories,
+ bad_det_labels)
+
+ good_pq.merge(bad_pq)
+
+ np.testing.assert_array_almost_equal(
+ good_pq.iou_per_class, [2 * (28 / 30 + 6 / 8) + 2 * (27 / 32)])
+ np.testing.assert_array_equal(good_pq.tp_per_class, [2 * 2 + 2])
+ np.testing.assert_array_equal(good_pq.fn_per_class, [2])
+ np.testing.assert_array_equal(good_pq.fp_per_class, [2])
+ self.assertAlmostEqual(good_pq.result(), 0.63177083)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/models/research/deeplab/evaluation/parsing_covering.py b/models/research/deeplab/evaluation/parsing_covering.py
new file mode 100644
index 0000000000000000000000000000000000000000..a40e55fc6be7a7563ceba75db17b188244dad832
--- /dev/null
+++ b/models/research/deeplab/evaluation/parsing_covering.py
@@ -0,0 +1,246 @@
+# Lint as: python2, python3
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implementation of the Parsing Covering metric.
+
+Parsing Covering is a region-based metric for evaluating the task of
+image parsing, aka panoptic segmentation.
+
+Please see the paper for details:
+"DeeperLab: Single-Shot Image Parser", Tien-Ju Yang, Maxwell D. Collins,
+Yukun Zhu, Jyh-Jing Hwang, Ting Liu, Xiao Zhang, Vivienne Sze,
+George Papandreou, Liang-Chieh Chen. arXiv: 1902.05093, 2019.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import numpy as np
+import prettytable
+import six
+
+from deeplab.evaluation import base_metric
+
+
+class ParsingCovering(base_metric.SegmentationMetric):
+ r"""Metric class for Parsing Covering.
+
+  Computes the segmentation covering metric introduced in (Arbelaez et al.,
+  2010), extended to handle multi-class semantic labels (a.k.a. parsing
+  covering). Specifically, segmentation covering (SC) is defined in Eq. (8) of
+  (Arbelaez et al., 2010) as:
+
+ SC(c) = \sum_{R\in S}(|R| * \max_{R'\in S'}O(R,R')) / \sum_{R\in S}|R|,
+
+ where S are the groundtruth instance regions and S' are the predicted
+ instance regions. The parsing covering is simply:
+
+ PC = \sum_{c=1}^{C}SC(c) / C,
+
+ where C is the number of classes.
+ """
+
+ def __init__(self,
+ num_categories,
+ ignored_label,
+ max_instances_per_category,
+ offset,
+ normalize_by_image_size=True):
+ """Initialization for ParsingCovering.
+
+ Args:
+      num_categories: The number of segmentation categories (or "classes") in
+        the dataset.
+ ignored_label: A category id that is ignored in evaluation, e.g. the void
+ label as defined in COCO panoptic segmentation dataset.
+ max_instances_per_category: The maximum number of instances for each
+ category. Used in ensuring unique instance labels.
+ offset: The maximum number of unique labels. This is used, by multiplying
+ the ground-truth labels, to generate unique ids for individual regions
+ of overlap between groundtruth and predicted segments.
+ normalize_by_image_size: Whether to normalize groundtruth instance region
+ areas by image size. If True, groundtruth instance areas and weighted
+ IoUs will be divided by the size of the corresponding image before
+ accumulated across the dataset.
+ """
+ super(ParsingCovering, self).__init__(num_categories, ignored_label,
+ max_instances_per_category, offset)
+ self.normalize_by_image_size = normalize_by_image_size
+
+ def compare_and_accumulate(
+ self, groundtruth_category_array, groundtruth_instance_array,
+ predicted_category_array, predicted_instance_array):
+ """See base class."""
+ # Allocate intermediate data structures.
+ max_ious = np.zeros([self.num_categories, self.max_instances_per_category],
+ dtype=np.float64)
+ gt_areas = np.zeros([self.num_categories, self.max_instances_per_category],
+ dtype=np.float64)
+ pred_areas = np.zeros(
+ [self.num_categories, self.max_instances_per_category],
+ dtype=np.float64)
+ # This is a dictionary in the format:
+ # {(category, gt_instance): [(pred_instance, intersection_area)]}.
+ intersections = collections.defaultdict(list)
+
+ # First, combine the category and instance labels so that every unique
+ # value for (category, instance) is assigned a unique integer label.
+ pred_segment_id = self._naively_combine_labels(predicted_category_array,
+ predicted_instance_array)
+ gt_segment_id = self._naively_combine_labels(groundtruth_category_array,
+ groundtruth_instance_array)
+
+ # Next, combine the groundtruth and predicted labels. Dividing up the pixels
+ # based on which groundtruth segment and which predicted segment they belong
+ # to, this will assign a different 32-bit integer label to each choice
+ # of (groundtruth segment, predicted segment), encoded as
+ # gt_segment_id * offset + pred_segment_id.
+ intersection_id_array = (
+ gt_segment_id.astype(np.uint32) * self.offset +
+ pred_segment_id.astype(np.uint32))
+
+ # For every combination of (groundtruth segment, predicted segment) with a
+ # non-empty intersection, this counts the number of pixels in that
+ # intersection.
+ intersection_ids, intersection_areas = np.unique(
+ intersection_id_array, return_counts=True)
+
+ # Find areas of all groundtruth and predicted instances, as well as of their
+ # intersections.
+ for intersection_id, intersection_area in six.moves.zip(
+ intersection_ids, intersection_areas):
+ gt_segment_id = intersection_id // self.offset
+ gt_category = gt_segment_id // self.max_instances_per_category
+ if gt_category == self.ignored_label:
+ continue
+ gt_instance = gt_segment_id % self.max_instances_per_category
+ gt_areas[gt_category, gt_instance] += intersection_area
+
+ pred_segment_id = intersection_id % self.offset
+ pred_category = pred_segment_id // self.max_instances_per_category
+ pred_instance = pred_segment_id % self.max_instances_per_category
+ pred_areas[pred_category, pred_instance] += intersection_area
+ if pred_category != gt_category:
+ continue
+
+ intersections[gt_category, gt_instance].append((pred_instance,
+ intersection_area))
+
+ # Find maximum IoU for every groundtruth instance.
+ for gt_label, instance_intersections in six.iteritems(intersections):
+ category, gt_instance = gt_label
+ gt_area = gt_areas[category, gt_instance]
+ ious = []
+ for pred_instance, intersection_area in instance_intersections:
+ pred_area = pred_areas[category, pred_instance]
+ union = gt_area + pred_area - intersection_area
+ ious.append(intersection_area / union)
+ max_ious[category, gt_instance] = max(ious)
+
+ # Normalize groundtruth instance areas by image size if necessary.
+ if self.normalize_by_image_size:
+ gt_areas /= groundtruth_category_array.size
+
+ # Compute per-class weighted IoUs and areas summed over all groundtruth
+ # instances.
+ self.weighted_iou_per_class += np.sum(max_ious * gt_areas, axis=-1)
+ self.gt_area_per_class += np.sum(gt_areas, axis=-1)
+
+ return self.result()
+
+ def result_per_category(self):
+ """See base class."""
+ return base_metric.realdiv_maybe_zero(self.weighted_iou_per_class,
+ self.gt_area_per_class)
+
+ def _valid_categories(self):
+ """Categories with a "valid" value for the metric, have > 0 instances.
+
+ We will ignore the `ignore_label` class and other classes which have
+ groundtruth area of 0.
+
+ Returns:
+ Boolean array of shape `[num_categories]`.
+ """
+ valid_categories = np.not_equal(self.gt_area_per_class, 0)
+ if self.ignored_label >= 0 and self.ignored_label < self.num_categories:
+ valid_categories[self.ignored_label] = False
+ return valid_categories
+
+ def detailed_results(self, is_thing=None):
+ """See base class."""
+ valid_categories = self._valid_categories()
+
+ # If known, break down which categories are valid _and_ things/stuff.
+ category_sets = collections.OrderedDict()
+ category_sets['All'] = valid_categories
+ if is_thing is not None:
+ category_sets['Things'] = np.logical_and(valid_categories, is_thing)
+ category_sets['Stuff'] = np.logical_and(valid_categories,
+ np.logical_not(is_thing))
+
+ covering_per_class = self.result_per_category()
+ results = {}
+ for category_set_name, in_category_set in six.iteritems(category_sets):
+ if np.any(in_category_set):
+ results[category_set_name] = {
+ 'pc': np.mean(covering_per_class[in_category_set]),
+ # The number of valid categories in this subset.
+ 'n': np.sum(in_category_set.astype(np.int32)),
+ }
+ else:
+ results[category_set_name] = {'pc': 0, 'n': 0}
+
+ return results
+
+ def print_detailed_results(self, is_thing=None, print_digits=3):
+ """See base class."""
+ results = self.detailed_results(is_thing=is_thing)
+
+ tab = prettytable.PrettyTable()
+
+ tab.add_column('', [], align='l')
+ for fieldname in ['PC', 'N']:
+ tab.add_column(fieldname, [], align='r')
+
+ for category_set, subset_results in six.iteritems(results):
+ data_cols = [
+ round(subset_results['pc'], print_digits) * 100, subset_results['n']
+ ]
+ tab.add_row([category_set] + data_cols)
+
+ print(tab)
+
+ def result(self):
+ """See base class."""
+ covering_per_class = self.result_per_category()
+ valid_categories = self._valid_categories()
+ if not np.any(valid_categories):
+ return 0.
+ return np.mean(covering_per_class[valid_categories])
+
+ def merge(self, other_instance):
+ """See base class."""
+ self.weighted_iou_per_class += other_instance.weighted_iou_per_class
+ self.gt_area_per_class += other_instance.gt_area_per_class
+
+ def reset(self):
+ """See base class."""
+ self.weighted_iou_per_class = np.zeros(
+ self.num_categories, dtype=np.float64)
+ self.gt_area_per_class = np.zeros(self.num_categories, dtype=np.float64)
diff --git a/models/research/deeplab/evaluation/parsing_covering_test.py b/models/research/deeplab/evaluation/parsing_covering_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..124d1b372559ef672ea1c9f821eac3fec52c97ea
--- /dev/null
+++ b/models/research/deeplab/evaluation/parsing_covering_test.py
@@ -0,0 +1,173 @@
+# Lint as: python2, python3
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Parsing Covering metric."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+
+from absl.testing import absltest
+import numpy as np
+
+from deeplab.evaluation import parsing_covering
+from deeplab.evaluation import test_utils
+
+# See the definition of the color names at:
+# https://en.wikipedia.org/wiki/Web_colors.
+_CLASS_COLOR_MAP = {
+ (0, 0, 0): 0,
+ (0, 0, 255): 1, # Person (blue).
+ (255, 0, 0): 2, # Bear (red).
+ (0, 255, 0): 3, # Tree (lime).
+ (255, 0, 255): 4, # Bird (fuchsia).
+ (0, 255, 255): 5, # Sky (aqua).
+ (255, 255, 0): 6, # Cat (yellow).
+}
+
+
+class ParsingCoveringTest(absltest.TestCase):
+
+ def test_perfect_match(self):
+ categories = np.zeros([6, 6], np.uint16)
+ instances = np.array([
+ [2, 2, 2, 2, 2, 2],
+ [2, 4, 4, 4, 4, 2],
+ [2, 4, 4, 4, 4, 2],
+ [2, 4, 4, 4, 4, 2],
+ [2, 4, 4, 2, 2, 2],
+ [2, 4, 2, 2, 2, 2],
+ ],
+ dtype=np.uint16)
+
+ pc = parsing_covering.ParsingCovering(
+ num_categories=3,
+ ignored_label=2,
+ max_instances_per_category=2,
+ offset=16,
+ normalize_by_image_size=False)
+ pc.compare_and_accumulate(categories, instances, categories, instances)
+ np.testing.assert_array_equal(pc.weighted_iou_per_class, [0.0, 21.0, 0.0])
+ np.testing.assert_array_equal(pc.gt_area_per_class, [0.0, 21.0, 0.0])
+ np.testing.assert_array_equal(pc.result_per_category(), [0.0, 1.0, 0.0])
+ self.assertEqual(pc.result(), 1.0)
+
+ def test_totally_wrong(self):
+ categories = np.zeros([6, 6], np.uint16)
+ gt_instances = np.array([
+ [0, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 1, 0],
+ [0, 1, 1, 1, 1, 0],
+ [0, 1, 1, 1, 1, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ ],
+ dtype=np.uint16)
+ pred_instances = 1 - gt_instances
+
+ pc = parsing_covering.ParsingCovering(
+ num_categories=2,
+ ignored_label=0,
+ max_instances_per_category=1,
+ offset=16,
+ normalize_by_image_size=False)
+ pc.compare_and_accumulate(categories, gt_instances, categories,
+ pred_instances)
+ np.testing.assert_array_equal(pc.weighted_iou_per_class, [0.0, 0.0])
+ np.testing.assert_array_equal(pc.gt_area_per_class, [0.0, 10.0])
+ np.testing.assert_array_equal(pc.result_per_category(), [0.0, 0.0])
+ self.assertEqual(pc.result(), 0.0)
+
+ def test_matches_expected(self):
+ pred_classes = test_utils.read_segmentation_with_rgb_color_map(
+ 'team_pred_class.png', _CLASS_COLOR_MAP)
+ pred_instances = test_utils.read_test_image(
+ 'team_pred_instance.png', mode='L')
+
+ instance_class_map = {
+ 0: 0,
+ 47: 1,
+ 97: 1,
+ 133: 1,
+ 150: 1,
+ 174: 1,
+ 198: 2,
+ 215: 1,
+ 244: 1,
+ 255: 1,
+ }
+ gt_instances, gt_classes = test_utils.panoptic_segmentation_with_class_map(
+ 'team_gt_instance.png', instance_class_map)
+
+ pc = parsing_covering.ParsingCovering(
+ num_categories=3,
+ ignored_label=0,
+ max_instances_per_category=256,
+ offset=256 * 256,
+ normalize_by_image_size=False)
+ pc.compare_and_accumulate(gt_classes, gt_instances, pred_classes,
+ pred_instances)
+ np.testing.assert_array_almost_equal(
+ pc.weighted_iou_per_class, [0.0, 39864.14634, 3136], decimal=4)
+ np.testing.assert_array_equal(pc.gt_area_per_class, [0.0, 56870, 5800])
+ np.testing.assert_array_almost_equal(
+ pc.result_per_category(), [0.0, 0.70097, 0.54069], decimal=4)
+ self.assertAlmostEqual(pc.result(), 0.6208296732)
+
+ def test_matches_expected_normalize_by_size(self):
+ pred_classes = test_utils.read_segmentation_with_rgb_color_map(
+ 'team_pred_class.png', _CLASS_COLOR_MAP)
+ pred_instances = test_utils.read_test_image(
+ 'team_pred_instance.png', mode='L')
+
+ instance_class_map = {
+ 0: 0,
+ 47: 1,
+ 97: 1,
+ 133: 1,
+ 150: 1,
+ 174: 1,
+ 198: 2,
+ 215: 1,
+ 244: 1,
+ 255: 1,
+ }
+ gt_instances, gt_classes = test_utils.panoptic_segmentation_with_class_map(
+ 'team_gt_instance.png', instance_class_map)
+
+ pc = parsing_covering.ParsingCovering(
+ num_categories=3,
+ ignored_label=0,
+ max_instances_per_category=256,
+ offset=256 * 256,
+ normalize_by_image_size=True)
+ pc.compare_and_accumulate(gt_classes, gt_instances, pred_classes,
+ pred_instances)
+ np.testing.assert_array_almost_equal(
+ pc.weighted_iou_per_class, [0.0, 0.5002088756, 0.03935002196],
+ decimal=4)
+ np.testing.assert_array_almost_equal(
+ pc.gt_area_per_class, [0.0, 0.7135955832, 0.07277746408], decimal=4)
+ # Note that the per-category and overall PCs are identical to those without
+ # normalization in the previous test, because we only have a single image.
+ np.testing.assert_array_almost_equal(
+ pc.result_per_category(), [0.0, 0.70097, 0.54069], decimal=4)
+ self.assertAlmostEqual(pc.result(), 0.6208296732)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/models/research/deeplab/evaluation/streaming_metrics.py b/models/research/deeplab/evaluation/streaming_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..8313792676a62e58af70300a0cfa43528a904435
--- /dev/null
+++ b/models/research/deeplab/evaluation/streaming_metrics.py
@@ -0,0 +1,240 @@
+# Lint as: python2, python3
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Code to compute segmentation in a "streaming" pattern in Tensorflow.
+
+These aggregate the metrics over the examples of the evaluation set. Each
+example is assumed to be fed in as part of a stream, and the metric
+implementation accumulates across them.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from deeplab.evaluation import panoptic_quality
+from deeplab.evaluation import parsing_covering
+
+_EPSILON = 1e-10
+
+
+def _realdiv_maybe_zero(x, y):
+ """Support tf.realdiv(x, y) where y may contain zeros."""
+ return tf.where(tf.less(y, _EPSILON), tf.zeros_like(x), tf.realdiv(x, y))
+
+
+def _running_total(value, shape, name=None):
+ """Maintains a running total of tensor `value` between calls."""
+ with tf.variable_scope(name, 'running_total', [value]):
+ total_var = tf.get_variable(
+ 'total',
+ shape,
+ value.dtype,
+ initializer=tf.zeros_initializer(),
+ trainable=False,
+ collections=[
+ tf.GraphKeys.LOCAL_VARIABLES, tf.GraphKeys.METRIC_VARIABLES
+ ])
+ updated_total = tf.assign_add(total_var, value, use_locking=True)
+
+ return total_var, updated_total
+
+
+def _panoptic_quality_helper(
+ groundtruth_category_array, groundtruth_instance_array,
+ predicted_category_array, predicted_instance_array, num_classes,
+ max_instances_per_category, ignored_label, offset):
+ """Helper function to compute panoptic quality."""
+ pq = panoptic_quality.PanopticQuality(num_classes, ignored_label,
+ max_instances_per_category, offset)
+ pq.compare_and_accumulate(groundtruth_category_array,
+ groundtruth_instance_array,
+ predicted_category_array, predicted_instance_array)
+ return pq.iou_per_class, pq.tp_per_class, pq.fn_per_class, pq.fp_per_class
+
+
+def streaming_panoptic_quality(groundtruth_categories,
+ groundtruth_instances,
+ predicted_categories,
+ predicted_instances,
+ num_classes,
+ max_instances_per_category,
+ ignored_label,
+ offset,
+ name=None):
+ """Aggregates the panoptic metric across calls with different input tensors.
+
+ See tf.metrics.* functions for comparable functionality and usage.
+
+ Args:
+ groundtruth_categories: A 2D uint16 tensor of groundtruth category labels.
+ groundtruth_instances: A 2D uint16 tensor of groundtruth instance labels.
+ predicted_categories: A 2D uint16 tensor of predicted category labels.
+ predicted_instances: A 2D uint16 tensor of predicted instance labels.
+ num_classes: Number of classes in the dataset as an integer.
+ max_instances_per_category: The maximum number of instances for each class
+ as an integer or integer tensor.
+ ignored_label: The class id to be ignored in evaluation as an integer or
+ integer tensor.
+ offset: The maximum number of unique labels as an integer or integer tensor.
+ name: An optional variable_scope name.
+
+ Returns:
+ qualities: A tensor of shape `[6, num_classes]`, where (1) panoptic quality,
+ (2) segmentation quality, (3) recognition quality, (4) total_tp,
+ (5) total_fn and (6) total_fp are saved in the respective rows.
+ update_ops: List of operations that update the running overall panoptic
+ quality.
+
+ Raises:
+ RuntimeError: If eager execution is enabled.
+ """
+ if tf.executing_eagerly():
+ raise RuntimeError('Cannot aggregate when eager execution is enabled.')
+
+ input_args = [
+ tf.convert_to_tensor(groundtruth_categories, tf.uint16),
+ tf.convert_to_tensor(groundtruth_instances, tf.uint16),
+ tf.convert_to_tensor(predicted_categories, tf.uint16),
+ tf.convert_to_tensor(predicted_instances, tf.uint16),
+ tf.convert_to_tensor(num_classes, tf.int32),
+ tf.convert_to_tensor(max_instances_per_category, tf.int32),
+ tf.convert_to_tensor(ignored_label, tf.int32),
+ tf.convert_to_tensor(offset, tf.int32),
+ ]
+ return_types = [
+ tf.float64,
+ tf.float64,
+ tf.float64,
+ tf.float64,
+ ]
+ with tf.variable_scope(name, 'streaming_panoptic_quality', input_args):
+ panoptic_results = tf.py_func(
+ _panoptic_quality_helper, input_args, return_types, stateful=False)
+ iou, tp, fn, fp = tuple(panoptic_results)
+
+ total_iou, updated_iou = _running_total(
+ iou, [num_classes], name='iou_total')
+ total_tp, updated_tp = _running_total(tp, [num_classes], name='tp_total')
+ total_fn, updated_fn = _running_total(fn, [num_classes], name='fn_total')
+ total_fp, updated_fp = _running_total(fp, [num_classes], name='fp_total')
+ update_ops = [updated_iou, updated_tp, updated_fn, updated_fp]
+
+ sq = _realdiv_maybe_zero(total_iou, total_tp)
+ rq = _realdiv_maybe_zero(total_tp,
+ total_tp + 0.5 * total_fn + 0.5 * total_fp)
+ pq = tf.multiply(sq, rq)
+ qualities = tf.stack([pq, sq, rq, total_tp, total_fn, total_fp], axis=0)
+ return qualities, update_ops
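+
+# Sketch of typical usage in a TF1 graph/session, mirroring the unit tests for
+# this module (tensor and feed names below are placeholders):
+#
+#   qualities, update_ops = streaming_panoptic_quality(
+#       gt_class_tensor, gt_instance_tensor, pred_class_tensor,
+#       pred_instance_tensor, num_classes=3, max_instances_per_category=256,
+#       ignored_label=0, offset=256 * 256)
+#   pq, sq, rq, tp, fn, fp = tf.unstack(qualities, 6, axis=0)
+#   with tf.Session() as sess:
+#     sess.run(tf.local_variables_initializer())
+#     for feed_dict in per_image_feed_dicts:  # one feed per evaluated image
+#       sess.run(update_ops, feed_dict=feed_dict)
+#     final_pq = sess.run(pq)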
+
+
+def _parsing_covering_helper(
+ groundtruth_category_array, groundtruth_instance_array,
+ predicted_category_array, predicted_instance_array, num_classes,
+ max_instances_per_category, ignored_label, offset, normalize_by_image_size):
+ """Helper function to compute parsing covering."""
+ pc = parsing_covering.ParsingCovering(num_classes, ignored_label,
+ max_instances_per_category, offset,
+ normalize_by_image_size)
+ pc.compare_and_accumulate(groundtruth_category_array,
+ groundtruth_instance_array,
+ predicted_category_array, predicted_instance_array)
+ return pc.weighted_iou_per_class, pc.gt_area_per_class
+
+
+def streaming_parsing_covering(groundtruth_categories,
+ groundtruth_instances,
+ predicted_categories,
+ predicted_instances,
+ num_classes,
+ max_instances_per_category,
+ ignored_label,
+ offset,
+ normalize_by_image_size=True,
+ name=None):
+ """Aggregates the covering across calls with different input tensors.
+
+ See tf.metrics.* functions for comparable functionality and usage.
+
+ Args:
+ groundtruth_categories: A 2D uint16 tensor of groundtruth category labels.
+ groundtruth_instances: A 2D uint16 tensor of groundtruth instance labels.
+ predicted_categories: A 2D uint16 tensor of predicted category labels.
+ predicted_instances: A 2D uint16 tensor of predicted instance labels.
+ num_classes: Number of classes in the dataset as an integer.
+ max_instances_per_category: The maximum number of instances for each class
+ as an integer or integer tensor.
+ ignored_label: The class id to be ignored in evaluation as an integer or
+ integer tensor.
+ offset: The maximum number of unique labels as an integer or integer tensor.
+ normalize_by_image_size: Whether to normalize groundtruth region areas by
+ image size. If True, groundtruth instance areas and weighted IoUs will be
+ divided by the size of the corresponding image before accumulated across
+ the dataset.
+ name: An optional variable_scope name.
+
+ Returns:
+ coverings: A tensor of shape `[3, num_classes]`, where (1) per class
+ coverings, (2) per class sum of weighted IoUs, and (3) per class sum of
+      groundtruth region areas are saved in the respective rows.
+ update_ops: List of operations that update the running overall parsing
+ covering.
+
+ Raises:
+ RuntimeError: If eager execution is enabled.
+ """
+ if tf.executing_eagerly():
+ raise RuntimeError('Cannot aggregate when eager execution is enabled.')
+
+ input_args = [
+ tf.convert_to_tensor(groundtruth_categories, tf.uint16),
+ tf.convert_to_tensor(groundtruth_instances, tf.uint16),
+ tf.convert_to_tensor(predicted_categories, tf.uint16),
+ tf.convert_to_tensor(predicted_instances, tf.uint16),
+ tf.convert_to_tensor(num_classes, tf.int32),
+ tf.convert_to_tensor(max_instances_per_category, tf.int32),
+ tf.convert_to_tensor(ignored_label, tf.int32),
+ tf.convert_to_tensor(offset, tf.int32),
+ tf.convert_to_tensor(normalize_by_image_size, tf.bool),
+ ]
+ return_types = [
+ tf.float64,
+ tf.float64,
+ ]
+ with tf.variable_scope(name, 'streaming_parsing_covering', input_args):
+ covering_results = tf.py_func(
+ _parsing_covering_helper, input_args, return_types, stateful=False)
+ weighted_iou_per_class, gt_area_per_class = tuple(covering_results)
+
+ total_weighted_iou_per_class, updated_weighted_iou_per_class = (
+ _running_total(
+ weighted_iou_per_class, [num_classes],
+ name='weighted_iou_per_class_total'))
+ total_gt_area_per_class, updated_gt_area_per_class = _running_total(
+ gt_area_per_class, [num_classes], name='gt_area_per_class_total')
+
+ covering_per_class = _realdiv_maybe_zero(total_weighted_iou_per_class,
+ total_gt_area_per_class)
+ coverings = tf.stack([
+ covering_per_class,
+ total_weighted_iou_per_class,
+ total_gt_area_per_class,
+ ],
+ axis=0)
+ update_ops = [updated_weighted_iou_per_class, updated_gt_area_per_class]
+
+ return coverings, update_ops
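
Reviewer note: the streaming covering metric added above is meant to be driven like the `tf.metrics.*` API referenced in its docstring — build the ops once, run the update op per image, then fetch the accumulated result. A minimal graph-mode sketch follows; `dataset_iterator` is a hypothetical stand-in for whatever yields numpy label maps and is not part of this change.

```python
import tensorflow as tf
from deeplab.evaluation import streaming_metrics

gt_class = tf.placeholder(tf.uint16)
gt_instance = tf.placeholder(tf.uint16)
pred_class = tf.placeholder(tf.uint16)
pred_instance = tf.placeholder(tf.uint16)

coverings, update_ops = streaming_metrics.streaming_parsing_covering(
    gt_class, gt_instance, pred_class, pred_instance,
    num_classes=3, max_instances_per_category=256,
    ignored_label=0, offset=256 * 256)
# Row 0: per-class covering, row 1: summed weighted IoUs, row 2: summed GT areas.
per_class_covering = tf.unstack(coverings, num=3, axis=0)[0]

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  for gt_c, gt_i, pr_c, pr_i in dataset_iterator:  # hypothetical iterator
    sess.run(update_ops, feed_dict={gt_class: gt_c, gt_instance: gt_i,
                                    pred_class: pr_c, pred_instance: pr_i})
  print(sess.run(per_class_covering))
```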
diff --git a/models/research/deeplab/evaluation/streaming_metrics_test.py b/models/research/deeplab/evaluation/streaming_metrics_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..656007e6238e5c106dd8eee08fe65e4ba7457801
--- /dev/null
+++ b/models/research/deeplab/evaluation/streaming_metrics_test.py
@@ -0,0 +1,549 @@
+# Lint as: python2, python3
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for segmentation "streaming" metrics."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+
+
+import numpy as np
+import six
+import tensorflow as tf
+
+from deeplab.evaluation import streaming_metrics
+from deeplab.evaluation import test_utils
+
+# See the definition of the color names at:
+# https://en.wikipedia.org/wiki/Web_colors.
+_CLASS_COLOR_MAP = {
+ (0, 0, 0): 0,
+ (0, 0, 255): 1, # Person (blue).
+ (255, 0, 0): 2, # Bear (red).
+ (0, 255, 0): 3, # Tree (lime).
+ (255, 0, 255): 4, # Bird (fuchsia).
+ (0, 255, 255): 5, # Sky (aqua).
+ (255, 255, 0): 6, # Cat (yellow).
+}
+
+
+class StreamingPanopticQualityTest(tf.test.TestCase):
+
+ def test_streaming_metric_on_single_image(self):
+ offset = 256 * 256
+
+ instance_class_map = {
+ 0: 0,
+ 47: 1,
+ 97: 1,
+ 133: 1,
+ 150: 1,
+ 174: 1,
+ 198: 2,
+ 215: 1,
+ 244: 1,
+ 255: 1,
+ }
+ gt_instances, gt_classes = test_utils.panoptic_segmentation_with_class_map(
+ 'team_gt_instance.png', instance_class_map)
+
+ pred_classes = test_utils.read_segmentation_with_rgb_color_map(
+ 'team_pred_class.png', _CLASS_COLOR_MAP)
+ pred_instances = test_utils.read_test_image(
+ 'team_pred_instance.png', mode='L')
+
+ gt_class_tensor = tf.placeholder(tf.uint16)
+ gt_instance_tensor = tf.placeholder(tf.uint16)
+ pred_class_tensor = tf.placeholder(tf.uint16)
+ pred_instance_tensor = tf.placeholder(tf.uint16)
+ qualities, update_pq = streaming_metrics.streaming_panoptic_quality(
+ gt_class_tensor,
+ gt_instance_tensor,
+ pred_class_tensor,
+ pred_instance_tensor,
+ num_classes=3,
+ max_instances_per_category=256,
+ ignored_label=0,
+ offset=offset)
+ pq, sq, rq, total_tp, total_fn, total_fp = tf.unstack(qualities, 6, axis=0)
+ feed_dict = {
+ gt_class_tensor: gt_classes,
+ gt_instance_tensor: gt_instances,
+ pred_class_tensor: pred_classes,
+ pred_instance_tensor: pred_instances
+ }
+
+ with self.session() as sess:
+ sess.run(tf.local_variables_initializer())
+ sess.run(update_pq, feed_dict=feed_dict)
+ (result_pq, result_sq, result_rq, result_total_tp, result_total_fn,
+ result_total_fp) = sess.run([pq, sq, rq, total_tp, total_fn, total_fp],
+ feed_dict=feed_dict)
+ np.testing.assert_array_almost_equal(
+ result_pq, [2.06104, 0.7024, 0.54069], decimal=4)
+ np.testing.assert_array_almost_equal(
+ result_sq, [2.06104, 0.7526, 0.54069], decimal=4)
+ np.testing.assert_array_almost_equal(result_rq, [1., 0.9333, 1.], decimal=4)
+ np.testing.assert_array_almost_equal(
+ result_total_tp, [1., 7., 1.], decimal=4)
+ np.testing.assert_array_almost_equal(
+ result_total_fn, [0., 1., 0.], decimal=4)
+ np.testing.assert_array_almost_equal(
+ result_total_fp, [0., 0., 0.], decimal=4)
+
+ def test_streaming_metric_on_multiple_images(self):
+ num_classes = 7
+ offset = 256 * 256
+
+ bird_gt_instance_class_map = {
+ 92: 5,
+ 176: 3,
+ 255: 4,
+ }
+ cat_gt_instance_class_map = {
+ 0: 0,
+ 255: 6,
+ }
+ team_gt_instance_class_map = {
+ 0: 0,
+ 47: 1,
+ 97: 1,
+ 133: 1,
+ 150: 1,
+ 174: 1,
+ 198: 2,
+ 215: 1,
+ 244: 1,
+ 255: 1,
+ }
+ test_image = collections.namedtuple(
+ 'TestImage',
+ ['gt_class_map', 'gt_path', 'pred_inst_path', 'pred_class_path'])
+ test_images = [
+ test_image(bird_gt_instance_class_map, 'bird_gt.png',
+ 'bird_pred_instance.png', 'bird_pred_class.png'),
+ test_image(cat_gt_instance_class_map, 'cat_gt.png',
+ 'cat_pred_instance.png', 'cat_pred_class.png'),
+ test_image(team_gt_instance_class_map, 'team_gt_instance.png',
+ 'team_pred_instance.png', 'team_pred_class.png'),
+ ]
+
+ gt_classes = []
+ gt_instances = []
+ pred_classes = []
+ pred_instances = []
+ for test_image in test_images:
+ (image_gt_instances,
+ image_gt_classes) = test_utils.panoptic_segmentation_with_class_map(
+ test_image.gt_path, test_image.gt_class_map)
+ gt_classes.append(image_gt_classes)
+ gt_instances.append(image_gt_instances)
+
+ pred_classes.append(
+ test_utils.read_segmentation_with_rgb_color_map(
+ test_image.pred_class_path, _CLASS_COLOR_MAP))
+ pred_instances.append(
+ test_utils.read_test_image(test_image.pred_inst_path, mode='L'))
+
+ gt_class_tensor = tf.placeholder(tf.uint16)
+ gt_instance_tensor = tf.placeholder(tf.uint16)
+ pred_class_tensor = tf.placeholder(tf.uint16)
+ pred_instance_tensor = tf.placeholder(tf.uint16)
+ qualities, update_pq = streaming_metrics.streaming_panoptic_quality(
+ gt_class_tensor,
+ gt_instance_tensor,
+ pred_class_tensor,
+ pred_instance_tensor,
+ num_classes=num_classes,
+ max_instances_per_category=256,
+ ignored_label=0,
+ offset=offset)
+ pq, sq, rq, total_tp, total_fn, total_fp = tf.unstack(qualities, 6, axis=0)
+ with self.session() as sess:
+ sess.run(tf.local_variables_initializer())
+ for pred_class, pred_instance, gt_class, gt_instance in six.moves.zip(
+ pred_classes, pred_instances, gt_classes, gt_instances):
+ sess.run(
+ update_pq,
+ feed_dict={
+ gt_class_tensor: gt_class,
+ gt_instance_tensor: gt_instance,
+ pred_class_tensor: pred_class,
+ pred_instance_tensor: pred_instance
+ })
+ (result_pq, result_sq, result_rq, result_total_tp, result_total_fn,
+ result_total_fp) = sess.run(
+ [pq, sq, rq, total_tp, total_fn, total_fp],
+ feed_dict={
+ gt_class_tensor: 0,
+ gt_instance_tensor: 0,
+ pred_class_tensor: 0,
+ pred_instance_tensor: 0
+ })
+ np.testing.assert_array_almost_equal(
+ result_pq,
+ [4.3107, 0.7024, 0.54069, 0.745353, 0.85768, 0.99107, 0.77410],
+ decimal=4)
+ np.testing.assert_array_almost_equal(
+ result_sq, [5.3883, 0.7526, 0.5407, 0.7454, 0.8577, 0.9911, 0.7741],
+ decimal=4)
+ np.testing.assert_array_almost_equal(
+ result_rq, [0.8, 0.9333, 1., 1., 1., 1., 1.], decimal=4)
+ np.testing.assert_array_almost_equal(
+ result_total_tp, [2., 7., 1., 1., 1., 1., 1.], decimal=4)
+ np.testing.assert_array_almost_equal(
+ result_total_fn, [0., 1., 0., 0., 0., 0., 0.], decimal=4)
+ np.testing.assert_array_almost_equal(
+ result_total_fp, [1., 0., 0., 0., 0., 0., 0.], decimal=4)
+
+
+class StreamingParsingCoveringTest(tf.test.TestCase):
+
+ def test_streaming_metric_on_single_image(self):
+ offset = 256 * 256
+
+ instance_class_map = {
+ 0: 0,
+ 47: 1,
+ 97: 1,
+ 133: 1,
+ 150: 1,
+ 174: 1,
+ 198: 2,
+ 215: 1,
+ 244: 1,
+ 255: 1,
+ }
+ gt_instances, gt_classes = test_utils.panoptic_segmentation_with_class_map(
+ 'team_gt_instance.png', instance_class_map)
+
+ pred_classes = test_utils.read_segmentation_with_rgb_color_map(
+ 'team_pred_class.png', _CLASS_COLOR_MAP)
+ pred_instances = test_utils.read_test_image(
+ 'team_pred_instance.png', mode='L')
+
+ gt_class_tensor = tf.placeholder(tf.uint16)
+ gt_instance_tensor = tf.placeholder(tf.uint16)
+ pred_class_tensor = tf.placeholder(tf.uint16)
+ pred_instance_tensor = tf.placeholder(tf.uint16)
+ coverings, update_ops = streaming_metrics.streaming_parsing_covering(
+ gt_class_tensor,
+ gt_instance_tensor,
+ pred_class_tensor,
+ pred_instance_tensor,
+ num_classes=3,
+ max_instances_per_category=256,
+ ignored_label=0,
+ offset=offset,
+ normalize_by_image_size=False)
+ (per_class_coverings, per_class_weighted_ious, per_class_gt_areas) = (
+ tf.unstack(coverings, num=3, axis=0))
+ feed_dict = {
+ gt_class_tensor: gt_classes,
+ gt_instance_tensor: gt_instances,
+ pred_class_tensor: pred_classes,
+ pred_instance_tensor: pred_instances
+ }
+
+ with self.session() as sess:
+ sess.run(tf.local_variables_initializer())
+ sess.run(update_ops, feed_dict=feed_dict)
+ (result_per_class_coverings, result_per_class_weighted_ious,
+ result_per_class_gt_areas) = (
+ sess.run([
+ per_class_coverings,
+ per_class_weighted_ious,
+ per_class_gt_areas,
+ ],
+ feed_dict=feed_dict))
+
+ np.testing.assert_array_almost_equal(
+ result_per_class_coverings, [0.0, 0.7009696912, 0.5406896552],
+ decimal=4)
+ np.testing.assert_array_almost_equal(
+ result_per_class_weighted_ious, [0.0, 39864.14634, 3136], decimal=4)
+ np.testing.assert_array_equal(result_per_class_gt_areas, [0, 56870, 5800])
+
+ def test_streaming_metric_on_multiple_images(self):
+ """Tests streaming parsing covering metric."""
+ num_classes = 7
+ offset = 256 * 256
+
+ bird_gt_instance_class_map = {
+ 92: 5,
+ 176: 3,
+ 255: 4,
+ }
+ cat_gt_instance_class_map = {
+ 0: 0,
+ 255: 6,
+ }
+ team_gt_instance_class_map = {
+ 0: 0,
+ 47: 1,
+ 97: 1,
+ 133: 1,
+ 150: 1,
+ 174: 1,
+ 198: 2,
+ 215: 1,
+ 244: 1,
+ 255: 1,
+ }
+ test_image = collections.namedtuple(
+ 'TestImage',
+ ['gt_class_map', 'gt_path', 'pred_inst_path', 'pred_class_path'])
+ test_images = [
+ test_image(bird_gt_instance_class_map, 'bird_gt.png',
+ 'bird_pred_instance.png', 'bird_pred_class.png'),
+ test_image(cat_gt_instance_class_map, 'cat_gt.png',
+ 'cat_pred_instance.png', 'cat_pred_class.png'),
+ test_image(team_gt_instance_class_map, 'team_gt_instance.png',
+ 'team_pred_instance.png', 'team_pred_class.png'),
+ ]
+
+ gt_classes = []
+ gt_instances = []
+ pred_classes = []
+ pred_instances = []
+ for test_image in test_images:
+ (image_gt_instances,
+ image_gt_classes) = test_utils.panoptic_segmentation_with_class_map(
+ test_image.gt_path, test_image.gt_class_map)
+ gt_classes.append(image_gt_classes)
+ gt_instances.append(image_gt_instances)
+
+ pred_instances.append(
+ test_utils.read_test_image(test_image.pred_inst_path, mode='L'))
+ pred_classes.append(
+ test_utils.read_segmentation_with_rgb_color_map(
+ test_image.pred_class_path, _CLASS_COLOR_MAP))
+
+ gt_class_tensor = tf.placeholder(tf.uint16)
+ gt_instance_tensor = tf.placeholder(tf.uint16)
+ pred_class_tensor = tf.placeholder(tf.uint16)
+ pred_instance_tensor = tf.placeholder(tf.uint16)
+ coverings, update_ops = streaming_metrics.streaming_parsing_covering(
+ gt_class_tensor,
+ gt_instance_tensor,
+ pred_class_tensor,
+ pred_instance_tensor,
+ num_classes=num_classes,
+ max_instances_per_category=256,
+ ignored_label=0,
+ offset=offset,
+ normalize_by_image_size=False)
+ (per_class_coverings, per_class_weighted_ious, per_class_gt_areas) = (
+ tf.unstack(coverings, num=3, axis=0))
+
+ with self.session() as sess:
+ sess.run(tf.local_variables_initializer())
+ for pred_class, pred_instance, gt_class, gt_instance in six.moves.zip(
+ pred_classes, pred_instances, gt_classes, gt_instances):
+ sess.run(
+ update_ops,
+ feed_dict={
+ gt_class_tensor: gt_class,
+ gt_instance_tensor: gt_instance,
+ pred_class_tensor: pred_class,
+ pred_instance_tensor: pred_instance
+ })
+ (result_per_class_coverings, result_per_class_weighted_ious,
+ result_per_class_gt_areas) = (
+ sess.run(
+ [
+ per_class_coverings,
+ per_class_weighted_ious,
+ per_class_gt_areas,
+ ],
+ feed_dict={
+ gt_class_tensor: 0,
+ gt_instance_tensor: 0,
+ pred_class_tensor: 0,
+ pred_instance_tensor: 0
+ }))
+
+ np.testing.assert_array_almost_equal(
+ result_per_class_coverings, [
+ 0.0,
+ 0.7009696912,
+ 0.5406896552,
+ 0.7453531599,
+ 0.8576779026,
+ 0.9910687881,
+ 0.7741046032,
+ ],
+ decimal=4)
+ np.testing.assert_array_almost_equal(
+ result_per_class_weighted_ious, [
+ 0.0,
+ 39864.14634,
+ 3136,
+ 1177.657993,
+ 2498.41573,
+ 33366.31289,
+ 26671,
+ ],
+ decimal=4)
+ np.testing.assert_array_equal(result_per_class_gt_areas, [
+ 0.0,
+ 56870,
+ 5800,
+ 1580,
+ 2913,
+ 33667,
+ 34454,
+ ])
+
+ def test_streaming_metric_on_multiple_images_normalize_by_size(self):
+ """Tests streaming parsing covering metric with image size normalization."""
+ num_classes = 7
+ offset = 256 * 256
+
+ bird_gt_instance_class_map = {
+ 92: 5,
+ 176: 3,
+ 255: 4,
+ }
+ cat_gt_instance_class_map = {
+ 0: 0,
+ 255: 6,
+ }
+ team_gt_instance_class_map = {
+ 0: 0,
+ 47: 1,
+ 97: 1,
+ 133: 1,
+ 150: 1,
+ 174: 1,
+ 198: 2,
+ 215: 1,
+ 244: 1,
+ 255: 1,
+ }
+ test_image = collections.namedtuple(
+ 'TestImage',
+ ['gt_class_map', 'gt_path', 'pred_inst_path', 'pred_class_path'])
+ test_images = [
+ test_image(bird_gt_instance_class_map, 'bird_gt.png',
+ 'bird_pred_instance.png', 'bird_pred_class.png'),
+ test_image(cat_gt_instance_class_map, 'cat_gt.png',
+ 'cat_pred_instance.png', 'cat_pred_class.png'),
+ test_image(team_gt_instance_class_map, 'team_gt_instance.png',
+ 'team_pred_instance.png', 'team_pred_class.png'),
+ ]
+
+ gt_classes = []
+ gt_instances = []
+ pred_classes = []
+ pred_instances = []
+ for test_image in test_images:
+ (image_gt_instances,
+ image_gt_classes) = test_utils.panoptic_segmentation_with_class_map(
+ test_image.gt_path, test_image.gt_class_map)
+ gt_classes.append(image_gt_classes)
+ gt_instances.append(image_gt_instances)
+
+ pred_instances.append(
+ test_utils.read_test_image(test_image.pred_inst_path, mode='L'))
+ pred_classes.append(
+ test_utils.read_segmentation_with_rgb_color_map(
+ test_image.pred_class_path, _CLASS_COLOR_MAP))
+
+ gt_class_tensor = tf.placeholder(tf.uint16)
+ gt_instance_tensor = tf.placeholder(tf.uint16)
+ pred_class_tensor = tf.placeholder(tf.uint16)
+ pred_instance_tensor = tf.placeholder(tf.uint16)
+ coverings, update_ops = streaming_metrics.streaming_parsing_covering(
+ gt_class_tensor,
+ gt_instance_tensor,
+ pred_class_tensor,
+ pred_instance_tensor,
+ num_classes=num_classes,
+ max_instances_per_category=256,
+ ignored_label=0,
+ offset=offset,
+ normalize_by_image_size=True)
+ (per_class_coverings, per_class_weighted_ious, per_class_gt_areas) = (
+ tf.unstack(coverings, num=3, axis=0))
+
+ with self.session() as sess:
+ sess.run(tf.local_variables_initializer())
+ for pred_class, pred_instance, gt_class, gt_instance in six.moves.zip(
+ pred_classes, pred_instances, gt_classes, gt_instances):
+ sess.run(
+ update_ops,
+ feed_dict={
+ gt_class_tensor: gt_class,
+ gt_instance_tensor: gt_instance,
+ pred_class_tensor: pred_class,
+ pred_instance_tensor: pred_instance
+ })
+ (result_per_class_coverings, result_per_class_weighted_ious,
+ result_per_class_gt_areas) = (
+ sess.run(
+ [
+ per_class_coverings,
+ per_class_weighted_ious,
+ per_class_gt_areas,
+ ],
+ feed_dict={
+ gt_class_tensor: 0,
+ gt_instance_tensor: 0,
+ pred_class_tensor: 0,
+ pred_instance_tensor: 0
+ }))
+
+ np.testing.assert_array_almost_equal(
+ result_per_class_coverings, [
+ 0.0,
+ 0.7009696912,
+ 0.5406896552,
+ 0.7453531599,
+ 0.8576779026,
+ 0.9910687881,
+ 0.7741046032,
+ ],
+ decimal=4)
+ np.testing.assert_array_almost_equal(
+ result_per_class_weighted_ious, [
+ 0.0,
+ 0.5002088756,
+ 0.03935002196,
+ 0.03086105851,
+ 0.06547211033,
+ 0.8743792686,
+ 0.2549565051,
+ ],
+ decimal=4)
+ np.testing.assert_array_almost_equal(
+ result_per_class_gt_areas, [
+ 0.0,
+ 0.7135955832,
+ 0.07277746408,
+ 0.04140461216,
+ 0.07633647799,
+ 0.8822589099,
+ 0.3293566581,
+ ],
+ decimal=4)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/deeplab/evaluation/test_utils.py b/models/research/deeplab/evaluation/test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ad4f551271527ef1d1990398de5523b074d5779
--- /dev/null
+++ b/models/research/deeplab/evaluation/test_utils.py
@@ -0,0 +1,119 @@
+# Lint as: python2, python3
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions to set up unit tests on Panoptic Segmentation code."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+
+
+from absl import flags
+import numpy as np
+import scipy.misc
+import six
+from six.moves import map
+
+FLAGS = flags.FLAGS
+
+_TEST_DIR = 'deeplab/evaluation/testdata'
+
+
+def read_test_image(testdata_path, *args, **kwargs):
+ """Loads a test image.
+
+ Args:
+ testdata_path: Image path relative to panoptic_segmentation/testdata as a
+ string.
+ *args: Additional positional arguments passed to `imread`.
+ **kwargs: Additional keyword arguments passed to `imread`.
+
+ Returns:
+ The image, as a numpy array.
+ """
+ image_path = os.path.join(_TEST_DIR, testdata_path)
+ return scipy.misc.imread(image_path, *args, **kwargs)
+
+
+def read_segmentation_with_rgb_color_map(image_testdata_path,
+ rgb_to_semantic_label,
+ output_dtype=None):
+ """Reads a test segmentation as an image and a map from colors to labels.
+
+ Args:
+ image_testdata_path: Image path relative to panoptic_segmentation/testdata
+ as a string.
+ rgb_to_semantic_label: Mapping from RGB colors to integer labels as a
+ dictionary.
+ output_dtype: Type of the output labels. If None, defaults to the type of
+ the provided color map.
+
+ Returns:
+ A 2D numpy array of labels.
+
+ Raises:
+ ValueError: On an incomplete `rgb_to_semantic_label`.
+ """
+ rgb_image = read_test_image(image_testdata_path, mode='RGB')
+ if len(rgb_image.shape) != 3 or rgb_image.shape[2] != 3:
+ raise AssertionError(
+        'Expected RGB image, actual shape is %s' % rgb_image.shape)
+
+ num_pixels = rgb_image.shape[0] * rgb_image.shape[1]
+ unique_colors = np.unique(np.reshape(rgb_image, [num_pixels, 3]), axis=0)
+ if not set(map(tuple, unique_colors)).issubset(
+ six.viewkeys(rgb_to_semantic_label)):
+ raise ValueError('RGB image has colors not in color map.')
+
+ output_dtype = output_dtype or type(
+ next(six.itervalues(rgb_to_semantic_label)))
+ output_labels = np.empty(rgb_image.shape[:2], dtype=output_dtype)
+ for rgb_color, int_label in six.iteritems(rgb_to_semantic_label):
+ color_array = np.array(rgb_color, ndmin=3)
+ output_labels[np.all(rgb_image == color_array, axis=2)] = int_label
+ return output_labels
+
+
+def panoptic_segmentation_with_class_map(instance_testdata_path,
+ instance_label_to_semantic_label):
+ """Reads in a panoptic segmentation with an instance map and a map to classes.
+
+ Args:
+ instance_testdata_path: Path to a grayscale instance map, given as a string
+ and relative to panoptic_segmentation/testdata.
+ instance_label_to_semantic_label: A map from instance labels to class
+ labels.
+
+ Returns:
+ A tuple `(instance_labels, class_labels)` of numpy arrays.
+
+ Raises:
+    ValueError: On a mismatched set of instances in the
+      `instance_label_to_semantic_label`.
+ """
+ instance_labels = read_test_image(instance_testdata_path, mode='L')
+ if set(np.unique(instance_labels)) != set(
+ six.iterkeys(instance_label_to_semantic_label)):
+ raise ValueError('Provided class map does not match present instance ids.')
+
+ class_labels = np.empty_like(instance_labels)
+ for instance_id, class_id in six.iteritems(instance_label_to_semantic_label):
+ class_labels[instance_labels == instance_id] = class_id
+
+ return instance_labels, class_labels
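
Reviewer note: taken together, these helpers turn the grayscale/RGB test images into the label maps the metric tests feed in. A minimal sketch using the test data shipped in this change (paths are relative to `deeplab/evaluation/testdata`; the maps shown are the ones used by the tests):

```python
from deeplab.evaluation import test_utils

# Predicted classes come from an RGB class map; colors follow the tests'
# _CLASS_COLOR_MAP convention.
pred_classes = test_utils.read_segmentation_with_rgb_color_map(
    'team_pred_class.png', {(0, 0, 0): 0, (0, 0, 255): 1, (255, 0, 0): 2})
pred_instances = test_utils.read_test_image('team_pred_instance.png', mode='L')

# Ground truth: every instance id present in the image must appear in the map.
instance_to_class = {0: 0, 47: 1, 97: 1, 133: 1, 150: 1,
                     174: 1, 198: 2, 215: 1, 244: 1, 255: 1}
gt_instances, gt_classes = test_utils.panoptic_segmentation_with_class_map(
    'team_gt_instance.png', instance_to_class)
```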
diff --git a/models/research/deeplab/evaluation/test_utils_test.py b/models/research/deeplab/evaluation/test_utils_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e9bed37e4bf721304e60d7fa12e6cfa9c4b7ef8
--- /dev/null
+++ b/models/research/deeplab/evaluation/test_utils_test.py
@@ -0,0 +1,74 @@
+# Lint as: python2, python3
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for test_utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+
+from absl.testing import absltest
+import numpy as np
+
+from deeplab.evaluation import test_utils
+
+
+class TestUtilsTest(absltest.TestCase):
+
+ def test_read_test_image(self):
+ image_array = test_utils.read_test_image('team_pred_class.png')
+ self.assertSequenceEqual(image_array.shape, (231, 345, 4))
+
+ def test_reads_segmentation_with_color_map(self):
+ rgb_to_semantic_label = {(0, 0, 0): 0, (0, 0, 255): 1, (255, 0, 0): 23}
+ labels = test_utils.read_segmentation_with_rgb_color_map(
+ 'team_pred_class.png', rgb_to_semantic_label)
+
+ input_image = test_utils.read_test_image('team_pred_class.png')
+ np.testing.assert_array_equal(
+ labels == 0,
+ np.logical_and(input_image[:, :, 0] == 0, input_image[:, :, 2] == 0))
+ np.testing.assert_array_equal(labels == 1, input_image[:, :, 2] == 255)
+ np.testing.assert_array_equal(labels == 23, input_image[:, :, 0] == 255)
+
+ def test_reads_gt_segmentation(self):
+ instance_label_to_semantic_label = {
+ 0: 0,
+ 47: 1,
+ 97: 1,
+ 133: 1,
+ 150: 1,
+ 174: 1,
+ 198: 23,
+ 215: 1,
+ 244: 1,
+ 255: 1,
+ }
+ instances, classes = test_utils.panoptic_segmentation_with_class_map(
+ 'team_gt_instance.png', instance_label_to_semantic_label)
+
+ expected_label_shape = (231, 345)
+ self.assertSequenceEqual(instances.shape, expected_label_shape)
+ self.assertSequenceEqual(classes.shape, expected_label_shape)
+ np.testing.assert_array_equal(instances == 0, classes == 0)
+ np.testing.assert_array_equal(instances == 198, classes == 23)
+ np.testing.assert_array_equal(
+ np.logical_and(instances != 0, instances != 198), classes == 1)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/models/research/deeplab/evaluation/testdata/README.md b/models/research/deeplab/evaluation/testdata/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..711b4767de830938b277d8d135175c2287d9c9db
--- /dev/null
+++ b/models/research/deeplab/evaluation/testdata/README.md
@@ -0,0 +1,14 @@
+# Segmentation Evaluation Test Data
+
+## Source Images
+
+* [team_input.png](team_input.png) \
+ Source:
+ https://ai.googleblog.com/2018/03/semantic-image-segmentation-with.html
+* [cat_input.jpg](cat_input.jpg) \
+ Source: https://www.flickr.com/photos/magdalena_b/4995858743
+* [bird_input.jpg](bird_input.jpg) \
+ Source: https://www.flickr.com/photos/chivinskia/40619099560
+* [congress_input.jpg](congress_input.jpg) \
+ Source:
+ https://cao.house.gov/sites/cao.house.gov/files/documents/SAR-Jan-Jun-2016.pdf
diff --git a/models/research/deeplab/evaluation/testdata/bird_gt.png b/models/research/deeplab/evaluation/testdata/bird_gt.png
new file mode 100644
index 0000000000000000000000000000000000000000..05d854915d1809abe3ba10f03c20e75706e0bb17
Binary files /dev/null and b/models/research/deeplab/evaluation/testdata/bird_gt.png differ
diff --git a/models/research/deeplab/evaluation/testdata/bird_pred_class.png b/models/research/deeplab/evaluation/testdata/bird_pred_class.png
new file mode 100644
index 0000000000000000000000000000000000000000..07351bf061115d0990486cbb086b6b9ec53e691b
Binary files /dev/null and b/models/research/deeplab/evaluation/testdata/bird_pred_class.png differ
diff --git a/models/research/deeplab/evaluation/testdata/bird_pred_instance.png b/models/research/deeplab/evaluation/testdata/bird_pred_instance.png
new file mode 100644
index 0000000000000000000000000000000000000000..faa1371f52510fb6f15fecb0eecc3441b2c8eadb
Binary files /dev/null and b/models/research/deeplab/evaluation/testdata/bird_pred_instance.png differ
diff --git a/models/research/deeplab/evaluation/testdata/cat_gt.png b/models/research/deeplab/evaluation/testdata/cat_gt.png
new file mode 100644
index 0000000000000000000000000000000000000000..41f60111f3de899a9e1ca3a646bea72d86b3009f
Binary files /dev/null and b/models/research/deeplab/evaluation/testdata/cat_gt.png differ
diff --git a/models/research/deeplab/evaluation/testdata/cat_pred_class.png b/models/research/deeplab/evaluation/testdata/cat_pred_class.png
new file mode 100644
index 0000000000000000000000000000000000000000..3728c68ced20312567e70540b667b53269000318
Binary files /dev/null and b/models/research/deeplab/evaluation/testdata/cat_pred_class.png differ
diff --git a/models/research/deeplab/evaluation/testdata/cat_pred_instance.png b/models/research/deeplab/evaluation/testdata/cat_pred_instance.png
new file mode 100644
index 0000000000000000000000000000000000000000..ebd9ba4855f5c88a3b336d50e21d864a37175bbe
Binary files /dev/null and b/models/research/deeplab/evaluation/testdata/cat_pred_instance.png differ
diff --git a/models/research/deeplab/evaluation/testdata/coco_gt.json b/models/research/deeplab/evaluation/testdata/coco_gt.json
new file mode 100644
index 0000000000000000000000000000000000000000..5f79bf184338b8ec1ed540fd388f2de1f6a9451b
--- /dev/null
+++ b/models/research/deeplab/evaluation/testdata/coco_gt.json
@@ -0,0 +1,214 @@
+{
+ "info": {
+ "description": "Test COCO-format dataset",
+ "url": "https://github.com/tensorflow/models/tree/master/research/deeplab",
+ "version": "1.0",
+ "year": 2019
+ },
+ "images": [
+ {
+ "id": 1,
+ "file_name": "bird.jpg",
+ "height": 159,
+ "width": 240,
+ "flickr_url": "https://www.flickr.com/photos/chivinskia/40619099560"
+ },
+ {
+ "id": 2,
+ "file_name": "cat.jpg",
+ "height": 330,
+ "width": 317,
+ "flickr_url": "https://www.flickr.com/photos/magdalena_b/4995858743"
+ },
+ {
+ "id": 3,
+ "file_name": "team.jpg",
+ "height": 231,
+ "width": 345
+ },
+ {
+ "id": 4,
+ "file_name": "congress.jpg",
+ "height": 267,
+ "width": 525
+ }
+ ],
+ "annotations": [
+ {
+ "image_id": 1,
+ "file_name": "bird.png",
+ "segments_info": [
+ {
+ "id": 255,
+ "area": 2913,
+ "category_id": 4,
+ "iscrowd": 0
+ },
+ {
+ "id": 2586368,
+ "area": 1580,
+ "category_id": 3,
+ "iscrowd": 0
+ },
+ {
+ "id": 16770360,
+ "area": 33667,
+ "category_id": 5,
+ "iscrowd": 0
+ }
+ ]
+ },
+ {
+ "image_id": 2,
+ "file_name": "cat.png",
+ "segments_info": [
+ {
+ "id": 16711691,
+ "area": 34454,
+ "category_id": 6,
+ "iscrowd": 0
+ }
+ ]
+ },
+ {
+ "image_id": 3,
+ "file_name": "team.png",
+ "segments_info": [
+ {
+ "id": 129,
+ "area": 5443,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 255,
+ "area": 3574,
+ "category_id": 2,
+ "iscrowd": 0
+ },
+ {
+ "id": 47615,
+ "area": 11483,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 65532,
+ "area": 7080,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 8585107,
+ "area": 11363,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 9011200,
+ "area": 7158,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 12858027,
+ "area": 6419,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 16053492,
+ "area": 4350,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 16711680,
+ "area": 5800,
+ "category_id": 1,
+ "iscrowd": 0
+ }
+ ]
+ },
+ {
+ "image_id": 4,
+ "file_name": "congress.png",
+ "segments_info": [
+ {
+ "id": 255,
+ "area": 243,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 65315,
+ "area": 553,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 65516,
+ "area": 652,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 9895680,
+ "area": 82774,
+ "category_id": 1,
+ "iscrowd": 1
+ },
+ {
+ "id": 16711739,
+ "area": 137,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 16711868,
+ "area": 179,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 16762624,
+ "area": 2742,
+ "category_id": 1,
+ "iscrowd": 0
+ }
+ ]
+ }
+ ],
+ "categories": [
+ {
+ "id": 1,
+ "name": "person",
+ "isthing": 1
+ },
+ {
+ "id": 2,
+ "name": "umbrella",
+ "isthing": 1
+ },
+ {
+ "id": 3,
+ "name": "tree-merged",
+ "isthing": 0
+ },
+ {
+ "id": 4,
+ "name": "bird",
+ "isthing": 1
+ },
+ {
+ "id": 5,
+ "name": "sky",
+ "isthing": 0
+ },
+ {
+ "id": 6,
+ "name": "cat",
+ "isthing": 1
+ }
+ ]
+}
diff --git a/models/research/deeplab/evaluation/testdata/coco_gt/bird.png b/models/research/deeplab/evaluation/testdata/coco_gt/bird.png
new file mode 100644
index 0000000000000000000000000000000000000000..9ef4ad9504126213bf2e3f1f49cdb65b189e6b95
Binary files /dev/null and b/models/research/deeplab/evaluation/testdata/coco_gt/bird.png differ
diff --git a/models/research/deeplab/evaluation/testdata/coco_gt/cat.png b/models/research/deeplab/evaluation/testdata/coco_gt/cat.png
new file mode 100644
index 0000000000000000000000000000000000000000..cb02530f2f912ef0d8252e327c6324211152c760
Binary files /dev/null and b/models/research/deeplab/evaluation/testdata/coco_gt/cat.png differ
diff --git a/models/research/deeplab/evaluation/testdata/coco_gt/congress.png b/models/research/deeplab/evaluation/testdata/coco_gt/congress.png
new file mode 100644
index 0000000000000000000000000000000000000000..a56b98d336172288b2c68284f5cc1373f515c342
Binary files /dev/null and b/models/research/deeplab/evaluation/testdata/coco_gt/congress.png differ
diff --git a/models/research/deeplab/evaluation/testdata/coco_gt/team.png b/models/research/deeplab/evaluation/testdata/coco_gt/team.png
new file mode 100644
index 0000000000000000000000000000000000000000..bde358d151a576049e993a0fd9ebb9661c7060a9
Binary files /dev/null and b/models/research/deeplab/evaluation/testdata/coco_gt/team.png differ
diff --git a/models/research/deeplab/evaluation/testdata/coco_pred.json b/models/research/deeplab/evaluation/testdata/coco_pred.json
new file mode 100644
index 0000000000000000000000000000000000000000..4aead17a65d7d9203eabdb9f46c334d9e5aa402c
--- /dev/null
+++ b/models/research/deeplab/evaluation/testdata/coco_pred.json
@@ -0,0 +1,208 @@
+{
+ "info": {
+ "description": "Test COCO-format dataset",
+ "url": "https://github.com/tensorflow/models/tree/master/research/deeplab",
+ "version": "1.0",
+ "year": 2019
+ },
+ "images": [
+ {
+ "id": 1,
+ "file_name": "bird.jpg",
+ "height": 159,
+ "width": 240,
+ "flickr_url": "https://www.flickr.com/photos/chivinskia/40619099560"
+ },
+ {
+ "id": 2,
+ "file_name": "cat.jpg",
+ "height": 330,
+ "width": 317,
+ "flickr_url": "https://www.flickr.com/photos/magdalena_b/4995858743"
+ },
+ {
+ "id": 3,
+ "file_name": "team.jpg",
+ "height": 231,
+ "width": 345
+ },
+ {
+ "id": 4,
+ "file_name": "congress.jpg",
+ "height": 267,
+ "width": 525
+ }
+ ],
+ "annotations": [
+ {
+ "image_id": 1,
+ "file_name": "bird.png",
+ "segments_info": [
+ {
+ "id": 55551,
+ "area": 3039,
+ "category_id": 4,
+ "iscrowd": 0
+ },
+ {
+ "id": 16216831,
+ "area": 33659,
+ "category_id": 5,
+ "iscrowd": 0
+ },
+ {
+ "id": 16760832,
+ "area": 1237,
+ "category_id": 3,
+ "iscrowd": 0
+ }
+ ]
+ },
+ {
+ "image_id": 2,
+ "file_name": "cat.png",
+ "segments_info": [
+ {
+ "id": 36493,
+ "area": 26910,
+ "category_id": 6,
+ "iscrowd": 0
+ }
+ ]
+ },
+ {
+ "image_id": 3,
+ "file_name": "team.png",
+ "segments_info": [
+ {
+ "id": 0,
+ "area": 22164,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 129,
+ "area": 3418,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 255,
+ "area": 12827,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 740608,
+ "area": 8606,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 2555695,
+ "area": 7636,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 2883541,
+ "area": 6844,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 14408667,
+ "area": 4766,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 16711820,
+ "area": 4767,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 16768768,
+ "area": 8667,
+ "category_id": 1,
+ "iscrowd": 0
+ }
+ ]
+ },
+ {
+ "image_id": 4,
+ "file_name": "congress.png",
+ "segments_info": [
+ {
+ "id": 255,
+ "area": 2599,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 37375,
+ "area": 386,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 62207,
+ "area": 384,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 5177088,
+ "area": 260,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 16711691,
+ "area": 1011,
+ "category_id": 1,
+ "iscrowd": 0
+ },
+ {
+ "id": 16774912,
+ "area": 803,
+ "category_id": 1,
+ "iscrowd": 0
+ }
+ ]
+ }
+ ],
+ "categories": [
+ {
+ "id": 1,
+ "name": "person",
+ "isthing": 1
+ },
+ {
+ "id": 2,
+ "name": "umbrella",
+ "isthing": 1
+ },
+ {
+ "id": 3,
+ "name": "tree-merged",
+ "isthing": 0
+ },
+ {
+ "id": 4,
+ "name": "bird",
+ "isthing": 1
+ },
+ {
+ "id": 5,
+ "name": "sky",
+ "isthing": 0
+ },
+ {
+ "id": 6,
+ "name": "cat",
+ "isthing": 1
+ }
+ ]
+}
diff --git a/models/research/deeplab/evaluation/testdata/coco_pred/bird.png b/models/research/deeplab/evaluation/testdata/coco_pred/bird.png
new file mode 100644
index 0000000000000000000000000000000000000000..c9b4cbcbf444a890e26ad9091f8496e2596c04ad
Binary files /dev/null and b/models/research/deeplab/evaluation/testdata/coco_pred/bird.png differ
diff --git a/models/research/deeplab/evaluation/testdata/coco_pred/cat.png b/models/research/deeplab/evaluation/testdata/coco_pred/cat.png
new file mode 100644
index 0000000000000000000000000000000000000000..324583271c4b11ef28e845e1cafb853383faf506
Binary files /dev/null and b/models/research/deeplab/evaluation/testdata/coco_pred/cat.png differ
diff --git a/models/research/deeplab/evaluation/testdata/coco_pred/congress.png b/models/research/deeplab/evaluation/testdata/coco_pred/congress.png
new file mode 100644
index 0000000000000000000000000000000000000000..fc7bb06050ed40f5c022f3cd7f0060c7fa84751a
Binary files /dev/null and b/models/research/deeplab/evaluation/testdata/coco_pred/congress.png differ
diff --git a/models/research/deeplab/evaluation/testdata/coco_pred/team.png b/models/research/deeplab/evaluation/testdata/coco_pred/team.png
new file mode 100644
index 0000000000000000000000000000000000000000..7300bf41f03a8ba08a1cb3f99821b69cddb318c2
Binary files /dev/null and b/models/research/deeplab/evaluation/testdata/coco_pred/team.png differ
diff --git a/models/research/deeplab/evaluation/testdata/team_gt_instance.png b/models/research/deeplab/evaluation/testdata/team_gt_instance.png
new file mode 100644
index 0000000000000000000000000000000000000000..97abb55273ce409a5fbaa85cb999f0725d457dbf
Binary files /dev/null and b/models/research/deeplab/evaluation/testdata/team_gt_instance.png differ
diff --git a/models/research/deeplab/evaluation/testdata/team_pred_class.png b/models/research/deeplab/evaluation/testdata/team_pred_class.png
new file mode 100644
index 0000000000000000000000000000000000000000..2ed78de2cbd923e6530f08fc2c47bf8377cfaf69
Binary files /dev/null and b/models/research/deeplab/evaluation/testdata/team_pred_class.png differ
diff --git a/models/research/deeplab/evaluation/testdata/team_pred_instance.png b/models/research/deeplab/evaluation/testdata/team_pred_instance.png
new file mode 100644
index 0000000000000000000000000000000000000000..264606a4d8822108481132ff9e990d826c64a274
Binary files /dev/null and b/models/research/deeplab/evaluation/testdata/team_pred_instance.png differ
diff --git a/models/research/deeplab/export_model.py b/models/research/deeplab/export_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7307b5a212f4445f78b31a933443b2dcbd505e6
--- /dev/null
+++ b/models/research/deeplab/export_model.py
@@ -0,0 +1,201 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Exports trained model to TensorFlow frozen graph."""
+
+import os
+import tensorflow as tf
+
+from tensorflow.contrib import quantize as contrib_quantize
+from tensorflow.python.tools import freeze_graph
+from deeplab import common
+from deeplab import input_preprocess
+from deeplab import model
+
+slim = tf.contrib.slim
+flags = tf.app.flags
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path')
+
+flags.DEFINE_string('export_path', None,
+ 'Path to output Tensorflow frozen graph.')
+
+flags.DEFINE_integer('num_classes', 21, 'Number of classes.')
+
+flags.DEFINE_multi_integer('crop_size', [513, 513],
+ 'Crop size [height, width].')
+
+# For `xception_65`, use atrous_rates = [12, 24, 36] if output_stride = 8, or
+# rates = [6, 12, 18] if output_stride = 16. For `mobilenet_v2`, use None. Note
+# one could use different atrous_rates/output_stride during training/evaluation.
+flags.DEFINE_multi_integer('atrous_rates', None,
+ 'Atrous rates for atrous spatial pyramid pooling.')
+
+flags.DEFINE_integer('output_stride', 8,
+ 'The ratio of input to output spatial resolution.')
+
+# Change to [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] for multi-scale inference.
+flags.DEFINE_multi_float('inference_scales', [1.0],
+ 'The scales to resize images for inference.')
+
+flags.DEFINE_bool('add_flipped_images', False,
+ 'Add flipped images during inference or not.')
+
+flags.DEFINE_integer(
+ 'quantize_delay_step', -1,
+ 'Steps to start quantized training. If < 0, will not quantize model.')
+
+flags.DEFINE_bool('save_inference_graph', False,
+ 'Save inference graph in text proto.')
+
+# Input name of the exported model.
+_INPUT_NAME = 'ImageTensor'
+
+# Output name of the exported predictions.
+_OUTPUT_NAME = 'SemanticPredictions'
+_RAW_OUTPUT_NAME = 'RawSemanticPredictions'
+
+# Output name of the exported probabilities.
+_OUTPUT_PROB_NAME = 'SemanticProbabilities'
+_RAW_OUTPUT_PROB_NAME = 'RawSemanticProbabilities'
+
+
+def _create_input_tensors():
+ """Creates and prepares input tensors for DeepLab model.
+
+ This method creates a 4-D uint8 image tensor 'ImageTensor' with shape
+ [1, None, None, 3]. The actual input tensor name to use during inference is
+ 'ImageTensor:0'.
+
+ Returns:
+ image: Preprocessed 4-D float32 tensor with shape [1, crop_height,
+ crop_width, 3].
+ original_image_size: Original image shape tensor [height, width].
+ resized_image_size: Resized image shape tensor [height, width].
+ """
+ # input_preprocess takes 4-D image tensor as input.
+ input_image = tf.placeholder(tf.uint8, [1, None, None, 3], name=_INPUT_NAME)
+ original_image_size = tf.shape(input_image)[1:3]
+
+ # Squeeze the dimension in axis=0 since `preprocess_image_and_label` assumes
+ # image to be 3-D.
+ image = tf.squeeze(input_image, axis=0)
+ resized_image, image, _ = input_preprocess.preprocess_image_and_label(
+ image,
+ label=None,
+ crop_height=FLAGS.crop_size[0],
+ crop_width=FLAGS.crop_size[1],
+ min_resize_value=FLAGS.min_resize_value,
+ max_resize_value=FLAGS.max_resize_value,
+ resize_factor=FLAGS.resize_factor,
+ is_training=False,
+ model_variant=FLAGS.model_variant)
+ resized_image_size = tf.shape(resized_image)[:2]
+
+ # Expand the dimension in axis=0, since the following operations assume the
+ # image to be 4-D.
+ image = tf.expand_dims(image, 0)
+
+ return image, original_image_size, resized_image_size
+
+
+def main(unused_argv):
+ tf.logging.set_verbosity(tf.logging.INFO)
+ tf.logging.info('Prepare to export model to: %s', FLAGS.export_path)
+
+ with tf.Graph().as_default():
+ image, image_size, resized_image_size = _create_input_tensors()
+
+ model_options = common.ModelOptions(
+ outputs_to_num_classes={common.OUTPUT_TYPE: FLAGS.num_classes},
+ crop_size=FLAGS.crop_size,
+ atrous_rates=FLAGS.atrous_rates,
+ output_stride=FLAGS.output_stride)
+
+ if tuple(FLAGS.inference_scales) == (1.0,):
+ tf.logging.info('Exported model performs single-scale inference.')
+ predictions = model.predict_labels(
+ image,
+ model_options=model_options,
+ image_pyramid=FLAGS.image_pyramid)
+ else:
+ tf.logging.info('Exported model performs multi-scale inference.')
+ if FLAGS.quantize_delay_step >= 0:
+ raise ValueError(
+ 'Quantize mode is not supported with multi-scale test.')
+ predictions = model.predict_labels_multi_scale(
+ image,
+ model_options=model_options,
+ eval_scales=FLAGS.inference_scales,
+ add_flipped_images=FLAGS.add_flipped_images)
+ raw_predictions = tf.identity(
+ tf.cast(predictions[common.OUTPUT_TYPE], tf.float32),
+ _RAW_OUTPUT_NAME)
+ raw_probabilities = tf.identity(
+ predictions[common.OUTPUT_TYPE + model.PROB_SUFFIX],
+ _RAW_OUTPUT_PROB_NAME)
+
+ # Crop the valid regions from the predictions.
+ semantic_predictions = raw_predictions[
+ :, :resized_image_size[0], :resized_image_size[1]]
+ semantic_probabilities = raw_probabilities[
+ :, :resized_image_size[0], :resized_image_size[1]]
+
+ # Resize back the prediction to the original image size.
+ def _resize_label(label, label_size):
+ # Expand dimension of label to [1, height, width, 1] for resize operation.
+ label = tf.expand_dims(label, 3)
+ resized_label = tf.image.resize_images(
+ label,
+ label_size,
+ method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
+ align_corners=True)
+ return tf.cast(tf.squeeze(resized_label, 3), tf.int32)
+ semantic_predictions = _resize_label(semantic_predictions, image_size)
+ semantic_predictions = tf.identity(semantic_predictions, name=_OUTPUT_NAME)
+
+ semantic_probabilities = tf.image.resize_bilinear(
+ semantic_probabilities, image_size, align_corners=True,
+ name=_OUTPUT_PROB_NAME)
+
+ if FLAGS.quantize_delay_step >= 0:
+ contrib_quantize.create_eval_graph()
+
+ saver = tf.train.Saver(tf.all_variables())
+
+ dirname = os.path.dirname(FLAGS.export_path)
+ tf.gfile.MakeDirs(dirname)
+ graph_def = tf.get_default_graph().as_graph_def(add_shapes=True)
+ freeze_graph.freeze_graph_with_def_protos(
+ graph_def,
+ saver.as_saver_def(),
+ FLAGS.checkpoint_path,
+ _OUTPUT_NAME + ',' + _OUTPUT_PROB_NAME,
+ restore_op_name=None,
+ filename_tensor_name=None,
+ output_graph=FLAGS.export_path,
+ clear_devices=True,
+ initializer_nodes=None)
+
+ if FLAGS.save_inference_graph:
+ tf.train.write_graph(graph_def, dirname, 'inference_graph.pbtxt')
+
+
+if __name__ == '__main__':
+ flags.mark_flag_as_required('checkpoint_path')
+ flags.mark_flag_as_required('export_path')
+ tf.app.run()
diff --git a/models/research/deeplab/g3doc/ade20k.md b/models/research/deeplab/g3doc/ade20k.md
new file mode 100644
index 0000000000000000000000000000000000000000..9505ab2cd99ef1b9a7eb8a53a7f909aa4a32977b
--- /dev/null
+++ b/models/research/deeplab/g3doc/ade20k.md
@@ -0,0 +1,107 @@
+# Running DeepLab on ADE20K Semantic Segmentation Dataset
+
+This page walks through the steps required to run DeepLab on the ADE20K dataset
+on a local machine.
+
+## Download dataset and convert to TFRecord
+
+We have prepared a script (under the folder `datasets`) to download and
+convert the ADE20K semantic segmentation dataset to TFRecord.
+
+```bash
+# From the tensorflow/models/research/deeplab/datasets directory.
+bash download_and_convert_ade20k.sh
+```
+
+The converted dataset will be saved at `./deeplab/datasets/ADE20K/tfrecord`.
+
+## Recommended Directory Structure for Training and Evaluation
+
+```
++ datasets
+ - build_data.py
+ - build_ade20k_data.py
+ - download_and_convert_ade20k.sh
+ + ADE20K
+ + tfrecord
+ + exp
+ + train_on_train_set
+ + train
+ + eval
+ + vis
+ + ADEChallengeData2016
+ + annotations
+ + training
+ + validation
+ + images
+ + training
+ + validation
+```
+
+where the folder `train_on_train_set` stores the train/eval/vis events and
+results (when training DeepLab on the ADE20K train set).
+
+## Running the train/eval/vis jobs
+
+A local training job using `xception_65` can be run with the following command:
+
+```bash
+# From tensorflow/models/research/
+python deeplab/train.py \
+ --logtostderr \
+ --training_number_of_steps=150000 \
+ --train_split="train" \
+ --model_variant="xception_65" \
+ --atrous_rates=6 \
+ --atrous_rates=12 \
+ --atrous_rates=18 \
+ --output_stride=16 \
+ --decoder_output_stride=4 \
+ --train_crop_size="513,513" \
+ --train_batch_size=4 \
+ --min_resize_value=513 \
+ --max_resize_value=513 \
+ --resize_factor=16 \
+ --dataset="ade20k" \
+ --tf_initial_checkpoint=${PATH_TO_INITIAL_CHECKPOINT} \
+ --train_logdir=${PATH_TO_TRAIN_DIR}\
+ --dataset_dir=${PATH_TO_DATASET}
+```
+
+where ${PATH\_TO\_INITIAL\_CHECKPOINT} is the path to the initial checkpoint,
+${PATH\_TO\_TRAIN\_DIR} is the directory to which training checkpoints and
+events will be written (it is recommended to set it to the
+`train_on_train_set/train` directory above), and ${PATH\_TO\_DATASET} is the
+directory in which the ADE20K dataset resides (the `tfrecord` directory above).
+
+**Note that for train.py:**
+
+1. In order to fine-tune the BN layers, one needs to use a large batch size
+   (> 12) and set fine_tune_batch_norm = True. Here, we simply use a small
+   batch size during training for the purpose of demonstration. If you have
+   limited GPU memory at hand, please fine-tune from our provided checkpoints,
+   whose batch-norm parameters have already been trained, and use a smaller
+   learning rate with fine_tune_batch_norm = False.
+
+2. Users should fine-tune `min_resize_value` and `max_resize_value` to get
+   better results. Note that `resize_factor` has to be equal to `output_stride`.
+
+3. Users should change `atrous_rates` from [6, 12, 18] to [12, 24, 36] if
+ setting output_stride=8.
+
+4. Users can skip the flag `decoder_output_stride` if they do not want to use
+   the decoder structure.
+
+## Running Tensorboard
+
+Progress for training and evaluation jobs can be inspected using Tensorboard. If
+using the recommended directory structure, Tensorboard can be run using the
+following command:
+
+```bash
+tensorboard --logdir=${PATH_TO_LOG_DIRECTORY}
+```
+
+where `${PATH_TO_LOG_DIRECTORY}` points to the directory that contains the train
+directory (e.g., the folder `train_on_train_set` in the above example). Please
+note it may take Tensorboard a couple of minutes to populate with data.
diff --git a/models/research/deeplab/g3doc/cityscapes.md b/models/research/deeplab/g3doc/cityscapes.md
new file mode 100644
index 0000000000000000000000000000000000000000..af703088e61b49aa81bf62b536469b410f0fb352
--- /dev/null
+++ b/models/research/deeplab/g3doc/cityscapes.md
@@ -0,0 +1,159 @@
+# Running DeepLab on Cityscapes Semantic Segmentation Dataset
+
+This page walks through the steps required to run DeepLab on Cityscapes on a
+local machine.
+
+## Download dataset and convert to TFRecord
+
+We have prepared a script (under the folder `datasets`) to convert the
+Cityscapes dataset to TFRecord. Users are required to download the dataset
+beforehand by registering on the [website](https://www.cityscapes-dataset.com/).
+
+```bash
+# From the tensorflow/models/research/deeplab/datasets directory.
+sh convert_cityscapes.sh
+```
+
+The converted dataset will be saved at ./deeplab/datasets/cityscapes/tfrecord.
+
+## Recommended Directory Structure for Training and Evaluation
+
+```
++ datasets
+ + cityscapes
+ + leftImg8bit
+ + gtFine
+ + tfrecord
+ + exp
+ + train_on_train_set
+ + train
+ + eval
+ + vis
+```
+
+where the folder `train_on_train_set` stores the train/eval/vis events and
+results (when training DeepLab on the Cityscapes train set).
+
+## Running the train/eval/vis jobs
+
+A local training job using `xception_65` can be run with the following command:
+
+```bash
+# From tensorflow/models/research/
+python deeplab/train.py \
+ --logtostderr \
+ --training_number_of_steps=90000 \
+ --train_split="train" \
+ --model_variant="xception_65" \
+ --atrous_rates=6 \
+ --atrous_rates=12 \
+ --atrous_rates=18 \
+ --output_stride=16 \
+ --decoder_output_stride=4 \
+ --train_crop_size="769,769" \
+ --train_batch_size=1 \
+ --dataset="cityscapes" \
+ --tf_initial_checkpoint=${PATH_TO_INITIAL_CHECKPOINT} \
+ --train_logdir=${PATH_TO_TRAIN_DIR} \
+ --dataset_dir=${PATH_TO_DATASET}
+```
+
+where ${PATH_TO_INITIAL_CHECKPOINT} is the path to the initial checkpoint
+(usually an ImageNet pretrained checkpoint), ${PATH_TO_TRAIN_DIR} is the
+directory to which training checkpoints and events will be written, and
+${PATH_TO_DATASET} is the directory in which the Cityscapes dataset resides.
+
+**Note that for {train,eval,vis}.py**:
+
+1. In order to reproduce our results, one needs to use a large batch size (> 8)
+   and set fine_tune_batch_norm = True. Here, we simply use a small batch size
+   during training for the purpose of demonstration. If you have limited
+   GPU memory at hand, please fine-tune from our provided checkpoints, whose
+   batch-norm parameters have already been trained, and use a smaller learning
+   rate with fine_tune_batch_norm = False.
+
+2. Users should change `atrous_rates` from [6, 12, 18] to [12, 24, 36] if
+ setting output_stride=8.
+
+3. Users can skip the flag `decoder_output_stride` if they do not want to use
+   the decoder structure.
+
+4. Change and add the following flags in order to use the provided dense
+   prediction cell. Note that `decoder_output_stride` must be set if you want
+   to use the provided checkpoints, which include the decoder module.
+
+```bash
+--model_variant="xception_71"
+--dense_prediction_cell_json="deeplab/core/dense_prediction_cell_branch5_top1_cityscapes.json"
+--decoder_output_stride=4
+```
+
+A local evaluation job using `xception_65` can be run with the following
+command:
+
+```bash
+# From tensorflow/models/research/
+python deeplab/eval.py \
+ --logtostderr \
+ --eval_split="val" \
+ --model_variant="xception_65" \
+ --atrous_rates=6 \
+ --atrous_rates=12 \
+ --atrous_rates=18 \
+ --output_stride=16 \
+ --decoder_output_stride=4 \
+ --eval_crop_size="1025,2049" \
+ --dataset="cityscapes" \
+ --checkpoint_dir=${PATH_TO_CHECKPOINT} \
+ --eval_logdir=${PATH_TO_EVAL_DIR} \
+ --dataset_dir=${PATH_TO_DATASET}
+```
+
+where ${PATH_TO_CHECKPOINT} is the path to the trained checkpoint (i.e., the
+path to train_logdir), ${PATH_TO_EVAL_DIR} is the directory to which evaluation
+events will be written, and ${PATH_TO_DATASET} is the directory in which the
+Cityscapes dataset resides.
+
+A local visualization job using `xception_65` can be run with the following
+command:
+
+```bash
+# From tensorflow/models/research/
+python deeplab/vis.py \
+ --logtostderr \
+ --vis_split="val" \
+ --model_variant="xception_65" \
+ --atrous_rates=6 \
+ --atrous_rates=12 \
+ --atrous_rates=18 \
+ --output_stride=16 \
+ --decoder_output_stride=4 \
+ --vis_crop_size="1025,2049" \
+ --dataset="cityscapes" \
+ --colormap_type="cityscapes" \
+ --checkpoint_dir=${PATH_TO_CHECKPOINT} \
+ --vis_logdir=${PATH_TO_VIS_DIR} \
+ --dataset_dir=${PATH_TO_DATASET}
+```
+
+where ${PATH_TO_CHECKPOINT} is the path to the trained checkpoint (i.e., the
+path to train_logdir), ${PATH_TO_VIS_DIR} is the directory to which the
+visualization results will be written, and ${PATH_TO_DATASET} is the directory
+in which the Cityscapes dataset resides. Note that if you would like to save the
+segmentation results for the evaluation server, set also_save_raw_predictions =
+True.
+
+## Running Tensorboard
+
+Progress for training and evaluation jobs can be inspected using Tensorboard. If
+using the recommended directory structure, Tensorboard can be run using the
+following command:
+
+```bash
+tensorboard --logdir=${PATH_TO_LOG_DIRECTORY}
+```
+
+where `${PATH_TO_LOG_DIRECTORY}` points to the directory that contains the
+train, eval, and vis directories (e.g., the folder `train_on_train_set` in the
+above example). Please note it may take Tensorboard a couple of minutes to populate
+with data.
diff --git a/models/research/deeplab/g3doc/export_model.md b/models/research/deeplab/g3doc/export_model.md
new file mode 100644
index 0000000000000000000000000000000000000000..c41649e609a39ccb2e7c7622e1d4e25f86d20cb7
--- /dev/null
+++ b/models/research/deeplab/g3doc/export_model.md
@@ -0,0 +1,23 @@
+# Export trained deeplab model to frozen inference graph
+
+After model training finishes, you can export it to a frozen TensorFlow
+inference graph proto. Your trained model checkpoint usually includes the
+following files:
+
+* model.ckpt-${CHECKPOINT_NUMBER}.data-00000-of-00001
+* model.ckpt-${CHECKPOINT_NUMBER}.index
+* model.ckpt-${CHECKPOINT_NUMBER}.meta
+
+After you have identified a candidate checkpoint to export, you can run the
+following command line to export it to a frozen graph:
+
+```bash
+# From tensorflow/models/research/
+# Assume all checkpoint files share the same path prefix `${CHECKPOINT_PATH}`.
+python deeplab/export_model.py \
+ --checkpoint_path=${CHECKPOINT_PATH} \
+ --export_path=${OUTPUT_DIR}/frozen_inference_graph.pb
+```
+
+Please also add other model specific flags as you use for training, such as
+`model_variant`, `add_image_level_feature`, etc.
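+
+Once exported, the frozen graph can be loaded and run directly. The sketch below
+is illustrative only (TensorFlow 1.x graph mode): the tensor names
+`ImageTensor:0` and `SemanticPredictions:0` match the `_INPUT_NAME` and
+`_OUTPUT_NAME` constants in `export_model.py`, while the image path and the use
+of PIL for decoding are assumptions, not part of the export script.
+
+```python
+import numpy as np
+from PIL import Image
+import tensorflow as tf
+
+graph_def = tf.GraphDef()
+with tf.gfile.GFile('frozen_inference_graph.pb', 'rb') as f:
+  graph_def.ParseFromString(f.read())
+
+with tf.Graph().as_default() as graph:
+  tf.import_graph_def(graph_def, name='')
+
+image = np.asarray(Image.open('example.jpg').convert('RGB'))
+with tf.Session(graph=graph) as sess:
+  # The exported model expects a 4-D uint8 batch of size 1.
+  seg_map = sess.run('SemanticPredictions:0',
+                     feed_dict={'ImageTensor:0': image[None, ...]})
+# seg_map has shape [1, height, width] and holds int32 class ids.
+```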
diff --git a/models/research/deeplab/g3doc/faq.md b/models/research/deeplab/g3doc/faq.md
new file mode 100644
index 0000000000000000000000000000000000000000..26ff4b3281cd624cb25292d89ef3fad55b8851f2
--- /dev/null
+++ b/models/research/deeplab/g3doc/faq.md
@@ -0,0 +1,87 @@
+# FAQ
+___
+Q1: What if I want to use other network backbones, such as ResNet [1], instead of only those provided ones (e.g., Xception)?
+
+A: The users could modify the provided core/feature_extractor.py to support more network backbones.
+___
+Q2: What if I want to train the model on other datasets?
+
+A: The users could modify the provided dataset/build_{cityscapes,voc2012}_data.py and dataset/segmentation_dataset.py to build their own dataset.
+___
+Q3: Where can I download the PASCAL VOC augmented training set?
+
+A: The PASCAL VOC augmented training set is provided by Bharath Hariharan et al. [2] Please refer to their [website](http://home.bharathh.info/pubs/codes/SBD/download.html) for details and consider citing their paper if using the dataset.
+___
+Q4: Why does the implementation not include DenseCRF [3]?
+
+A: We have not tried this. The interested users could take a look at Philipp Krähenbühl's [website](http://graphics.stanford.edu/projects/densecrf/) and [paper](https://arxiv.org/abs/1210.5644) for details.
+___
+Q5: What if I want to train the model and fine-tune the batch normalization parameters?
+
+A: Given the limited resources at hand, we suggest you simply fine-tune
+from our provided checkpoint, whose batch-norm parameters have been trained (i.e.,
+train with a smaller learning rate, set `fine_tune_batch_norm = false`, and
+employ more training iterations since the learning rate is small). If
+you really would like to train the batch-norm parameters yourself, we suggest:
+
+1. Set `output_stride = 16` or maybe even `32` (remember to change the flag
+`atrous_rates` accordingly, e.g., `atrous_rates = [3, 6, 9]` for
+`output_stride = 32`).
+
+2. Use as many GPUs as possible (change the flag `num_clones` in train.py) and
+set `train_batch_size` as large as possible.
+
+3. Adjust the `train_crop_size` in train.py. Maybe set it to be smaller, e.g.,
+513x513 (or even 321x321), so that you could use a larger batch size.
+
+4. Use a smaller network backbone, such as MobileNet-v2.
+
+___
+Q6: How can I train the model asynchronously?
+
+A: In train.py, users can set `num_replicas` (the number of machines for training) and `num_ps_tasks` (we usually set `num_ps_tasks` = `num_replicas` / 2). See slim.deployment.model_deploy for more details.
+___
+Q7: I could not reproduce the performance even with the provided checkpoints.
+
+A: Please try running
+
+```bash
+# Run the simple test with Xception_65 as network backbone.
+sh local_test.sh
+```
+
+or
+
+```bash
+# Run the simple test with MobileNet-v2 as network backbone.
+sh local_test_mobilenetv2.sh
+```
+
+First, make sure you can reproduce the results with our provided settings.
+After that, start making your changes one at a time to help debug.
+___
+Q8: What value of `eval_crop_size` should I use?
+
+A: Our model uses whole-image inference, meaning that we need to set `eval_crop_size` equal to `output_stride` * k + 1, where k is an integer chosen so that the resulting `eval_crop_size` is slightly larger than the largest
+image dimension in the dataset. For example, we have `eval_crop_size` = 513x513 for the PASCAL dataset, whose largest image dimension is 512. Similarly, we set `eval_crop_size` = 1025x2049 for Cityscapes images, whose
+image dimensions are all equal to 1024x2048.
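+
+As a quick sanity check of this rule, the snippet below (not part of the
+codebase) computes the crop size from the output stride and the largest image
+dimension:
+
+```bash
+# Smallest k such that output_stride * k + 1 exceeds the largest image dimension.
+OUTPUT_STRIDE=16
+LARGEST_DIM=512
+K=$(( (LARGEST_DIM + OUTPUT_STRIDE - 1) / OUTPUT_STRIDE ))
+EVAL_CROP=$(( OUTPUT_STRIDE * K + 1 ))
+echo "${EVAL_CROP}x${EVAL_CROP}"  # Prints 513x513 for PASCAL.
+```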
+___
+Q9: Why is multi-GPU training slow?
+
+A: Please try to use more threads to pre-process the inputs. For example, change [num_readers = 4](https://github.com/tensorflow/models/blob/master/research/deeplab/train.py#L457).
+___
+
+
+## References
+
+1. **Deep Residual Learning for Image Recognition**
+ Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+ [[link]](https://arxiv.org/abs/1512.03385), In CVPR, 2016.
+
+2. **Semantic Contours from Inverse Detectors**
+ Bharath Hariharan, Pablo Arbelaez, Lubomir Bourdev, Subhransu Maji, Jitendra Malik
+ [[link]](http://home.bharathh.info/pubs/codes/SBD/download.html), In ICCV, 2011.
+
+3. **Efficient Inference in Fully Connected CRFs with Gaussian Edge Potentials**
+ Philipp Krähenbühl, Vladlen Koltun
+ [[link]](http://graphics.stanford.edu/projects/densecrf/), In NIPS, 2011.
diff --git a/models/research/deeplab/g3doc/img/image1.jpg b/models/research/deeplab/g3doc/img/image1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..939b6f9cef3da337e1279246090f20bd78920bc8
Binary files /dev/null and b/models/research/deeplab/g3doc/img/image1.jpg differ
diff --git a/models/research/deeplab/g3doc/img/image2.jpg b/models/research/deeplab/g3doc/img/image2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5ec1b8ac278906921bd3b6efec8fbe2e9d8c429e
Binary files /dev/null and b/models/research/deeplab/g3doc/img/image2.jpg differ
diff --git a/models/research/deeplab/g3doc/img/image3.jpg b/models/research/deeplab/g3doc/img/image3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d788e3dc68d684ca6e282bdff66a32abc767214a
Binary files /dev/null and b/models/research/deeplab/g3doc/img/image3.jpg differ
diff --git a/models/research/deeplab/g3doc/img/image_info.txt b/models/research/deeplab/g3doc/img/image_info.txt
new file mode 100644
index 0000000000000000000000000000000000000000..583d113e7ebb4d81ca1cdc51c3317243600809ee
--- /dev/null
+++ b/models/research/deeplab/g3doc/img/image_info.txt
@@ -0,0 +1,13 @@
+Image provenance:
+
+image1.jpg: Philippe Put,
+ https://www.flickr.com/photos/34547181@N00/14499172124
+
+image2.jpg: Peretz Partensky
+ https://www.flickr.com/photos/ifl/3926001309
+
+image3.jpg: Peter Harrison
+ https://www.flickr.com/photos/devcentre/392585679
+
+
+vis[1-3].png: Showing original image together with DeepLab segmentation map.
diff --git a/models/research/deeplab/g3doc/img/vis1.png b/models/research/deeplab/g3doc/img/vis1.png
new file mode 100644
index 0000000000000000000000000000000000000000..41b8ecd89590dcf6b635e32c3af4d4b18fbafede
Binary files /dev/null and b/models/research/deeplab/g3doc/img/vis1.png differ
diff --git a/models/research/deeplab/g3doc/img/vis2.png b/models/research/deeplab/g3doc/img/vis2.png
new file mode 100644
index 0000000000000000000000000000000000000000..7fa7a4cacc4807f2ab1a9c802757d76e932a41c1
Binary files /dev/null and b/models/research/deeplab/g3doc/img/vis2.png differ
diff --git a/models/research/deeplab/g3doc/img/vis3.png b/models/research/deeplab/g3doc/img/vis3.png
new file mode 100644
index 0000000000000000000000000000000000000000..813b6340a61f63e3b838a91562bc0b914191ba47
Binary files /dev/null and b/models/research/deeplab/g3doc/img/vis3.png differ
diff --git a/models/research/deeplab/g3doc/installation.md b/models/research/deeplab/g3doc/installation.md
new file mode 100644
index 0000000000000000000000000000000000000000..8629aba42207fc6e35c907024485c0e7f29f5e10
--- /dev/null
+++ b/models/research/deeplab/g3doc/installation.md
@@ -0,0 +1,73 @@
+# Installation
+
+## Dependencies
+
+DeepLab depends on the following libraries:
+
+* Numpy
+* Pillow 1.0
+* tf Slim (which is included in the "tensorflow/models/research/" checkout)
+* Jupyter notebook
+* Matplotlib
+* Tensorflow
+
+For detailed steps to install Tensorflow, follow the [Tensorflow installation
+instructions](https://www.tensorflow.org/install/). A typical user can install
+Tensorflow using one of the following commands:
+
+```bash
+# For CPU
+pip install tensorflow
+# For GPU
+pip install tensorflow-gpu
+```
+
+The remaining libraries can be installed on Ubuntu 14.04 via apt-get and pip:
+
+```bash
+sudo apt-get install python-pil python-numpy
+pip install --user jupyter
+pip install --user matplotlib
+pip install --user PrettyTable
+```
+
+## Add Libraries to PYTHONPATH
+
+When running locally, the tensorflow/models/research/ directory should be
+appended to PYTHONPATH. This can be done by running the following from
+tensorflow/models/research/:
+
+```bash
+# From tensorflow/models/research/
+export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
+
+# [Optional] for panoptic evaluation, you might need panopticapi:
+# https://github.com/cocodataset/panopticapi
+# Please clone it to a local directory ${PANOPTICAPI_DIR}
+touch ${PANOPTICAPI_DIR}/panopticapi/__init__.py
+export PYTHONPATH=$PYTHONPATH:${PANOPTICAPI_DIR}/panopticapi
+```
+
+Note: This command needs to be run in every new terminal you start. If you wish
+to avoid running it manually, you can add it as a new line to the end of your
+~/.bashrc file.
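+
+For example, assuming the repository is checked out under $HOME/models (adjust
+the path to your setup):
+
+```bash
+echo 'export PYTHONPATH=$PYTHONPATH:$HOME/models/research:$HOME/models/research/slim' >> ~/.bashrc
+```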
+
+## Testing the Installation
+
+You can test whether you have successfully installed the Tensorflow DeepLab codebase
+by running the following commands:
+
+Quick test by running model_test.py:
+
+```bash
+# From tensorflow/models/research/
+python deeplab/model_test.py
+```
+
+Quickly run the whole code on the PASCAL VOC 2012 dataset:
+
+```bash
+# From tensorflow/models/research/deeplab
+sh local_test.sh
+```
+
diff --git a/models/research/deeplab/g3doc/model_zoo.md b/models/research/deeplab/g3doc/model_zoo.md
new file mode 100644
index 0000000000000000000000000000000000000000..76972dc796e77838004a6f36bef73ca5bb66aff5
--- /dev/null
+++ b/models/research/deeplab/g3doc/model_zoo.md
@@ -0,0 +1,254 @@
+# TensorFlow DeepLab Model Zoo
+
+We provide deeplab models pretrained on several datasets, including (1) PASCAL
+VOC 2012, (2) Cityscapes, and (3) ADE20K, for reproducing our results, as well
+as some checkpoints that are only pretrained on ImageNet for training your own
+models.
+
+## DeepLab models trained on PASCAL VOC 2012
+
+Un-tar'ed directory includes:
+
+* a frozen inference graph (`frozen_inference_graph.pb`). All frozen inference
+ graphs by default use output stride of 8, a single eval scale of 1.0 and
+ no left-right flips, unless otherwise specified. MobileNet-v2 based models
+ do not include the decoder module.
+
+* a checkpoint (`model.ckpt.data-00000-of-00001`, `model.ckpt.index`)
+
+### Model details
+
+We provide several checkpoints that have been pretrained on the VOC 2012 train_aug
+set or the train_aug + trainval set. In the former case, you can train your model
+with a smaller batch size and freeze batch normalization when limited GPU memory
+is available, since we have already fine-tuned the batch normalization for you.
+In the latter case, you can directly evaluate the checkpoints on the VOC 2012 test
+set or use them for a demo. Note that *MobileNet-v2* based models do not
+employ ASPP and decoder modules for fast computation.
+
+Checkpoint name | Network backbone | Pretrained dataset | ASPP | Decoder
+--------------------------- | :--------------: | :-----------------: | :---: | :-----:
+mobilenetv2_dm05_coco_voc_trainaug | MobileNet-v2 Depth-Multiplier = 0.5 | ImageNet MS-COCO VOC 2012 train_aug set| N/A | N/A
+mobilenetv2_dm05_coco_voc_trainval | MobileNet-v2 Depth-Multiplier = 0.5 | ImageNet MS-COCO VOC 2012 train_aug + trainval sets | N/A | N/A
+mobilenetv2_coco_voc_trainaug | MobileNet-v2 | ImageNet MS-COCO VOC 2012 train_aug set| N/A | N/A
+mobilenetv2_coco_voc_trainval | MobileNet-v2 | ImageNet MS-COCO VOC 2012 train_aug + trainval sets | N/A | N/A
+xception65_coco_voc_trainaug | Xception_65 | ImageNet MS-COCO VOC 2012 train_aug set| [6,12,18] for OS=16 [12,24,36] for OS=8 | OS = 4
+xception65_coco_voc_trainval | Xception_65 | ImageNet MS-COCO VOC 2012 train_aug + trainval sets | [6,12,18] for OS=16 [12,24,36] for OS=8 | OS = 4
+
+In the table, **OS** denotes output stride.
+
+Checkpoint name | Eval OS | Eval scales | Left-right Flip | Multiply-Adds | Runtime (sec) | PASCAL mIOU | File Size
+------------------------------------------------------------------------------------------------------------------------ | :-------: | :------------------------: | :-------------: | :------------------: | :------------: | :----------------------------: | :-------:
+[mobilenetv2_dm05_coco_voc_trainaug](http://download.tensorflow.org/models/deeplabv3_mnv2_dm05_pascal_trainaug_2018_10_01.tar.gz) | 16 | [1.0] | No | 0.88B | - | 70.19% (val) | 7.6MB
+[mobilenetv2_dm05_coco_voc_trainval](http://download.tensorflow.org/models/deeplabv3_mnv2_dm05_pascal_trainval_2018_10_01.tar.gz) | 8 | [1.0] | No | 2.84B | - | 71.83% (test) | 7.6MB
+[mobilenetv2_coco_voc_trainaug](http://download.tensorflow.org/models/deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz) | 16 8 | [1.0] [0.5:0.25:1.75] | No Yes | 2.75B 152.59B | 0.1 26.9 | 75.32% (val) 77.33 (val) | 23MB
+[mobilenetv2_coco_voc_trainval](http://download.tensorflow.org/models/deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz) | 8 | [0.5:0.25:1.75] | Yes | 152.59B | 26.9 | 80.25% (**test**) | 23MB
+[xception65_coco_voc_trainaug](http://download.tensorflow.org/models/deeplabv3_pascal_train_aug_2018_01_04.tar.gz) | 16 8 | [1.0] [0.5:0.25:1.75] | No Yes | 54.17B 3055.35B | 0.7 223.2 | 82.20% (val) 83.58% (val) | 439MB
+[xception65_coco_voc_trainval](http://download.tensorflow.org/models/deeplabv3_pascal_trainval_2018_01_04.tar.gz) | 8 | [0.5:0.25:1.75] | Yes | 3055.35B | 223.2 | 87.80% (**test**) | 439MB
+
+In the table, we report both computation complexity (in terms of Multiply-Adds
+and CPU Runtime) and segmentation performance (in terms of mIOU) on the PASCAL
+VOC val or test set. The reported runtime is calculated by tfprof on a
+workstation with CPU E5-1650 v3 @ 3.50GHz and 32GB memory. Note that applying
+multi-scale inputs and left-right flips increases the segmentation performance
+but also significantly increases the computation and thus may not be suitable
+for real-time applications.
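+
+For example, to download and extract one of the checkpoints listed above (the
+URL comes from the table; the target directory is your choice):
+
+```bash
+mkdir -p ${CHECKPOINT_DIR} && cd ${CHECKPOINT_DIR}
+wget http://download.tensorflow.org/models/deeplabv3_pascal_trainval_2018_01_04.tar.gz
+tar -xzf deeplabv3_pascal_trainval_2018_01_04.tar.gz
+# The un-tar'ed directory contains frozen_inference_graph.pb and model.ckpt.*
+```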
+
+## DeepLab models trained on Cityscapes
+
+### Model details
+
+We provide several checkpoints that have been pretrained on the Cityscapes
+train_fine set. Note that the *MobileNet-v2* based model has been pretrained on the
+MS-COCO dataset and does not employ ASPP and decoder modules for fast computation.
+
+Checkpoint name | Network backbone | Pretrained dataset | ASPP | Decoder
+------------------------------------- | :--------------: | :-------------------------------------: | :----------------------------------------------: | :-----:
+mobilenetv2_coco_cityscapes_trainfine | MobileNet-v2 | ImageNet MS-COCO Cityscapes train_fine set | N/A | N/A
+mobilenetv3_large_cityscapes_trainfine | MobileNet-v3 Large | Cityscapes train_fine set (No ImageNet) | N/A | OS = 8
+mobilenetv3_small_cityscapes_trainfine | MobileNet-v3 Small | Cityscapes train_fine set (No ImageNet) | N/A | OS = 8
+xception65_cityscapes_trainfine | Xception_65 | ImageNet Cityscapes train_fine set | [6, 12, 18] for OS=16 [12, 24, 36] for OS=8 | OS = 4
+xception71_dpc_cityscapes_trainfine | Xception_71 | ImageNet MS-COCO Cityscapes train_fine set | Dense Prediction Cell | OS = 4
+xception71_dpc_cityscapes_trainval | Xception_71 | ImageNet MS-COCO Cityscapes trainval_fine and coarse set | Dense Prediction Cell | OS = 4
+
+In the table, **OS** denotes output stride.
+
+Note that for MobileNet-v3 models, we use the following additional command-line flags:
+
+```
+--model_variant={ mobilenet_v3_large_seg | mobilenet_v3_small_seg }
+--image_pooling_crop_size=769,769
+--image_pooling_stride=4,5
+--add_image_level_feature=1
+--aspp_convs_filters=128
+--aspp_with_concat_projection=0
+--aspp_with_squeeze_and_excitation=1
+--decoder_use_sum_merge=1
+--decoder_filters=19
+--decoder_output_is_logits=1
+--image_se_uses_qsigmoid=1
+--decoder_output_stride=8
+--output_stride=32
+```
+
+Checkpoint name | Eval OS | Eval scales | Left-right Flip | Multiply-Adds | Runtime (sec) | Cityscapes mIOU | File Size
+-------------------------------------------------------------------------------------------------------------------------------- | :-------: | :-------------------------: | :-------------: | :-------------------: | :------------: | :----------------------------: | :-------:
+[mobilenetv2_coco_cityscapes_trainfine](http://download.tensorflow.org/models/deeplabv3_mnv2_cityscapes_train_2018_02_05.tar.gz) | 16 8 | [1.0] [0.75:0.25:1.25] | No Yes | 21.27B 433.24B | 0.8 51.12 | 70.71% (val) 73.57% (val) | 23MB
+[mobilenetv3_large_cityscapes_trainfine](http://download.tensorflow.org/models/deeplab_mnv3_large_cityscapes_trainfine_2019_11_15.tar.gz) | 32 | [1.0] | No | 15.95B | 0.6 | 72.41% (val) | 17MB
+[mobilenetv3_small_cityscapes_trainfine](http://download.tensorflow.org/models/deeplab_mnv3_small_cityscapes_trainfine_2019_11_15.tar.gz) | 32 | [1.0] | No | 4.63B | 0.4 | 68.99% (val) | 5MB
+[xception65_cityscapes_trainfine](http://download.tensorflow.org/models/deeplabv3_cityscapes_train_2018_02_06.tar.gz) | 16 8 | [1.0] [0.75:0.25:1.25] | No Yes | 418.64B 8677.92B | 5.0 422.8 | 78.79% (val) 80.42% (val) | 439MB
+[xception71_dpc_cityscapes_trainfine](http://download.tensorflow.org/models/deeplab_cityscapes_xception71_trainfine_2018_09_08.tar.gz) | 16 | [1.0] | No | 502.07B | - | 80.31% (val) | 445MB
+[xception71_dpc_cityscapes_trainval](http://download.tensorflow.org/models/deeplab_cityscapes_xception71_trainvalfine_2018_09_08.tar.gz) | 8 | [0.75:0.25:2] | Yes | - | - | 82.66% (**test**) | 446MB
+
+### EdgeTPU-DeepLab models on Cityscapes
+
+EdgeTPU is Google's machine learning accelerator architecture for edge devices
+(available in Coral devices and Pixel 4's Neural Core). Leveraging neural
+architecture search (NAS, also known as Auto-ML) algorithms,
+[EdgeTPU-Mobilenet](https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet)
+has been released, which yields higher hardware utilization, lower latency, and
+better accuracy than Mobilenet-v2/v3. We use EdgeTPU-Mobilenet as the
+backbone and provide checkpoints that have been pretrained on the Cityscapes
+train_fine set. We name these EdgeTPU-DeepLab models.
+
+Checkpoint name | Network backbone | Pretrained dataset | ASPP | Decoder
+-------------------- | :----------------: | :----------------: | :--: | :-----:
+EdgeTPU-DeepLab | EdgeMobilenet-1.0 | ImageNet | N/A | N/A
+EdgeTPU-DeepLab-slim | EdgeMobilenet-0.75 | ImageNet | N/A | N/A
+
+For EdgeTPU-DeepLab-slim, the backbone feature extractor has depth multiplier =
+0.75 and aspp_convs_filters = 128. We do not employ ASPP or decoder modules, in
+order to further reduce the latency. We employ the same train/eval flags used for
+the MobileNet-v2 DeepLab model. Flags changed for the EdgeTPU-DeepLab model are
+listed below.
+
+```
+--decoder_output_stride=''
+--aspp_convs_filters=256
+--model_variant=mobilenet_edgetpu
+```
+
+For EdgeTPU-DeepLab-slim, also include the following flags.
+
+```
+--depth_multiplier=0.75
+--aspp_convs_filters=128
+```
+
+Checkpoint name | Eval OS | Eval scales | Cityscapes mIOU | Multiply-Adds | Simulator latency on Pixel 4 EdgeTPU
+---------------------------------------------------------------------------------------------------- | :--------: | :---------: | :--------------------------: | :------------: | :----------------------------------:
+[EdgeTPU-DeepLab](http://download.tensorflow.org/models/edgetpu-deeplab_2020_03_09.tar.gz) | 32 16 | [1.0] | 70.6% (val) 74.1% (val) | 5.6B 7.1B | 13.8 ms 17.5 ms
+[EdgeTPU-DeepLab-slim](http://download.tensorflow.org/models/edgetpu-deeplab-slim_2020_03_09.tar.gz) | 32 16 | [1.0] | 70.0% (val) 73.2% (val) | 3.5B 4.3B | 9.9 ms 13.2 ms
+
+## DeepLab models trained on ADE20K
+
+### Model details
+
+We provide some checkpoints that have been pretrained on the ADE20K training set.
+Note that the model has only been pretrained on ImageNet, following the
+dataset rule.
+
+Checkpoint name | Network backbone | Pretrained dataset | ASPP | Decoder | Input size
+------------------------------------- | :--------------: | :-------------------------------------: | :----------------------------------------------: | :-----: | :-----:
+mobilenetv2_ade20k_train | MobileNet-v2 | ImageNet ADE20K training set | N/A | OS = 4 | 257x257
+xception65_ade20k_train | Xception_65 | ImageNet ADE20K training set | [6, 12, 18] for OS=16 [12, 24, 36] for OS=8 | OS = 4 | 513x513
+
+The input dimensions of ADE20K images vary widely. We resize inputs so that the longest side is 257 for MobileNet-v2 (faster inference) and 513 for Xception_65 (better performance). Note that we also include the decoder module in the MobileNet-v2 checkpoint.
+
+Checkpoint name | Eval OS | Eval scales | Left-right Flip | mIOU | Pixel-wise Accuracy | File Size
+------------------------------------- | :-------: | :-------------------------: | :-------------: | :-------------------: | :-------------------: | :-------:
+[mobilenetv2_ade20k_train](http://download.tensorflow.org/models/deeplabv3_mnv2_ade20k_train_2018_12_03.tar.gz) | 16 | [1.0] | No | 32.04% (val) | 75.41% (val) | 24.8MB
+[xception65_ade20k_train](http://download.tensorflow.org/models/deeplabv3_xception_ade20k_train_2018_05_29.tar.gz) | 8 | [0.5:0.25:1.75] | Yes | 45.65% (val) | 82.52% (val) | 439MB
+
+
+## Checkpoints pretrained on ImageNet
+
+Un-tar'ed directory includes:
+
+* model checkpoint (`model.ckpt.data-00000-of-00001`, `model.ckpt.index`).
+
+### Model details
+
+We also provide some checkpoints that are pretrained on ImageNet and/or COCO (as
+indicated by the suffix in the model name) so that you can use them for training
+your own models.
+
+* mobilenet_v2: We refer the interested users to the TensorFlow open source
+ [MobileNet-V2](https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet)
+ for details.
+
+* xception_{41,65,71}: We adapt the original Xception model to the task of
+ semantic segmentation with the following changes: (1) more layers, (2) all
+ max pooling operations are replaced by strided (atrous) separable
+ convolutions, and (3) extra batch-norm and ReLU after each 3x3 depthwise
+ convolution are added. We provide three Xception model variants with
+ different network depths.
+
+* resnet_v1_{50,101}_beta: We modify the original ResNet-101 [10], similarly to
+  PSPNet [11], by replacing the first 7x7 convolution with three 3x3
+  convolutions. See resnet_v1_beta.py for more details.
+
+Model name | File Size
+-------------------------------------------------------------------------------------- | :-------:
+[xception_41_imagenet](http://download.tensorflow.org/models/xception_41_2018_05_09.tar.gz ) | 288MB
+[xception_65_imagenet](http://download.tensorflow.org/models/deeplabv3_xception_2018_01_04.tar.gz) | 447MB
+[xception_65_imagenet_coco](http://download.tensorflow.org/models/xception_65_coco_pretrained_2018_10_02.tar.gz) | 292MB
+[xception_71_imagenet](http://download.tensorflow.org/models/xception_71_2018_05_09.tar.gz ) | 474MB
+[resnet_v1_50_beta_imagenet](http://download.tensorflow.org/models/resnet_v1_50_2018_05_04.tar.gz) | 274MB
+[resnet_v1_101_beta_imagenet](http://download.tensorflow.org/models/resnet_v1_101_2018_05_04.tar.gz) | 477MB
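+
+For example, to initialize training from the `xception_65_imagenet` checkpoint
+(URL from the table above; paths are placeholders, and the remaining training
+flags are described in g3doc/pascal.md):
+
+```bash
+wget http://download.tensorflow.org/models/deeplabv3_xception_2018_01_04.tar.gz
+tar -xzf deeplabv3_xception_2018_01_04.tar.gz
+python deeplab/train.py \
+  --model_variant="xception_65" \
+  --tf_initial_checkpoint=${PATH_TO_EXTRACTED_CHECKPOINT}/model.ckpt \
+  --train_logdir=${PATH_TO_TRAIN_DIR} \
+  --dataset_dir=${PATH_TO_DATASET}
+```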
+
+## References
+
+1. **Mobilenets: Efficient convolutional neural networks for mobile vision applications**
+ Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam
+ [[link]](https://arxiv.org/abs/1704.04861). arXiv:1704.04861, 2017.
+
+2. **Inverted Residuals and Linear Bottlenecks: Mobile Networks for Classification, Detection and Segmentation**
+ Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen
+ [[link]](https://arxiv.org/abs/1801.04381). arXiv:1801.04381, 2018.
+
+3. **Xception: Deep Learning with Depthwise Separable Convolutions**
+ François Chollet
+ [[link]](https://arxiv.org/abs/1610.02357). In the Proc. of CVPR, 2017.
+
+4. **Deformable Convolutional Networks -- COCO Detection and Segmentation Challenge 2017 Entry**
+ Haozhi Qi, Zheng Zhang, Bin Xiao, Han Hu, Bowen Cheng, Yichen Wei, Jifeng Dai
+ [[link]](http://presentations.cocodataset.org/COCO17-Detect-MSRA.pdf). ICCV COCO Challenge
+ Workshop, 2017.
+
+5. **The Pascal Visual Object Classes Challenge: A Retrospective**
+ Mark Everingham, S. M. Ali Eslami, Luc Van Gool, Christopher K. I. Williams, John M. Winn, Andrew Zisserman
+ [[link]](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/). IJCV, 2014.
+
+6. **Semantic Contours from Inverse Detectors**
+ Bharath Hariharan, Pablo Arbelaez, Lubomir Bourdev, Subhransu Maji, Jitendra Malik
+ [[link]](http://home.bharathh.info/pubs/codes/SBD/download.html). In the Proc. of ICCV, 2011.
+
+7. **The Cityscapes Dataset for Semantic Urban Scene Understanding**
+ Cordts, Marius, Mohamed Omran, Sebastian Ramos, Timo Rehfeld, Markus Enzweiler, Rodrigo Benenson, Uwe Franke, Stefan Roth, Bernt Schiele.
+ [[link]](https://www.cityscapes-dataset.com/). In the Proc. of CVPR, 2016.
+
+8. **Microsoft COCO: Common Objects in Context**
+ Tsung-Yi Lin, Michael Maire, Serge Belongie, Lubomir Bourdev, Ross Girshick, James Hays, Pietro Perona, Deva Ramanan, C. Lawrence Zitnick, Piotr Dollar
+ [[link]](http://cocodataset.org/). In the Proc. of ECCV, 2014.
+
+9. **ImageNet Large Scale Visual Recognition Challenge**
+ Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, Alexander C. Berg, Li Fei-Fei
+ [[link]](http://www.image-net.org/). IJCV, 2015.
+
+10. **Deep Residual Learning for Image Recognition**
+ Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+ [[link]](https://arxiv.org/abs/1512.03385). CVPR, 2016.
+
+11. **Pyramid Scene Parsing Network**
+ Hengshuang Zhao, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, Jiaya Jia
+ [[link]](https://arxiv.org/abs/1612.01105). In CVPR, 2017.
+
+12. **Scene Parsing through ADE20K Dataset**
+ Bolei Zhou, Hang Zhao, Xavier Puig, Sanja Fidler, Adela Barriuso, Antonio Torralba
+ [[link]](http://groups.csail.mit.edu/vision/datasets/ADE20K/). In CVPR,
+ 2017.
+
+13. **Searching for MobileNetV3**
+ Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam
+ [[link]](https://arxiv.org/abs/1905.02244). In ICCV, 2019.
diff --git a/models/research/deeplab/g3doc/pascal.md b/models/research/deeplab/g3doc/pascal.md
new file mode 100644
index 0000000000000000000000000000000000000000..f4bc84eabb83e90ff192077749784fb2560057dd
--- /dev/null
+++ b/models/research/deeplab/g3doc/pascal.md
@@ -0,0 +1,161 @@
+# Running DeepLab on PASCAL VOC 2012 Semantic Segmentation Dataset
+
+This page walks through the steps required to run DeepLab on PASCAL VOC 2012 on
+a local machine.
+
+## Download dataset and convert to TFRecord
+
+We have prepared the script (under the folder `datasets`) to download and
+convert the PASCAL VOC 2012 semantic segmentation dataset to TFRecord format.
+
+```bash
+# From the tensorflow/models/research/deeplab/datasets directory.
+sh download_and_convert_voc2012.sh
+```
+
+The converted dataset will be saved at
+./deeplab/datasets/pascal_voc_seg/tfrecord
+
+## Recommended Directory Structure for Training and Evaluation
+
+```
++ datasets
+ + pascal_voc_seg
+ + VOCdevkit
+ + VOC2012
+ + JPEGImages
+ + SegmentationClass
+ + tfrecord
+ + exp
+ + train_on_train_set
+ + train
+ + eval
+ + vis
+```
+
+where the folder `train_on_train_set` stores the train/eval/vis events and
+results (when training DeepLab on the PASCAL VOC 2012 train set).
+
+## Running the train/eval/vis jobs
+
+A local training job using `xception_65` can be run with the following command:
+
+```bash
+# From tensorflow/models/research/
+python deeplab/train.py \
+ --logtostderr \
+ --training_number_of_steps=30000 \
+ --train_split="train" \
+ --model_variant="xception_65" \
+ --atrous_rates=6 \
+ --atrous_rates=12 \
+ --atrous_rates=18 \
+ --output_stride=16 \
+ --decoder_output_stride=4 \
+ --train_crop_size="513,513" \
+ --train_batch_size=1 \
+ --dataset="pascal_voc_seg" \
+ --tf_initial_checkpoint=${PATH_TO_INITIAL_CHECKPOINT} \
+ --train_logdir=${PATH_TO_TRAIN_DIR} \
+ --dataset_dir=${PATH_TO_DATASET}
+```
+
+where ${PATH_TO_INITIAL_CHECKPOINT} is the path to the initial checkpoint
+(usually an ImageNet pretrained checkpoint), ${PATH_TO_TRAIN_DIR} is the
+directory in which training checkpoints and events will be written, and
+${PATH_TO_DATASET} is the directory in which the PASCAL VOC 2012 dataset
+resides.
+
+**Note that for {train,eval,vis}.py:**
+
+1. In order to reproduce our results, one needs to use a large batch size (> 12)
+   and set fine_tune_batch_norm = True. Here, we simply use a small batch size
+   during training for the purpose of demonstration. If the users have limited
+   GPU memory at hand, please fine-tune from our provided checkpoints whose
+   batch norm parameters have been trained, and use a smaller learning rate with
+   fine_tune_batch_norm = False.
+
+2. The users should change atrous_rates from [6, 12, 18] to [12, 24, 36] if
+   setting output_stride=8 (see the example command after these notes).
+
+3. The users could skip the flag `decoder_output_stride` if they do not want
+   to use the decoder structure.
+
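+As mentioned in note 2 above, an illustrative training command for
+output_stride=8 (identical to the earlier command except for the atrous rates
+and the output stride):
+
+```bash
+# From tensorflow/models/research/
+python deeplab/train.py \
+  --logtostderr \
+  --training_number_of_steps=30000 \
+  --train_split="train" \
+  --model_variant="xception_65" \
+  --atrous_rates=12 \
+  --atrous_rates=24 \
+  --atrous_rates=36 \
+  --output_stride=8 \
+  --decoder_output_stride=4 \
+  --train_crop_size="513,513" \
+  --train_batch_size=1 \
+  --dataset="pascal_voc_seg" \
+  --tf_initial_checkpoint=${PATH_TO_INITIAL_CHECKPOINT} \
+  --train_logdir=${PATH_TO_TRAIN_DIR} \
+  --dataset_dir=${PATH_TO_DATASET}
+```
+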
+A local evaluation job using `xception_65` can be run with the following
+command:
+
+```bash
+# From tensorflow/models/research/
+python deeplab/eval.py \
+ --logtostderr \
+ --eval_split="val" \
+ --model_variant="xception_65" \
+ --atrous_rates=6 \
+ --atrous_rates=12 \
+ --atrous_rates=18 \
+ --output_stride=16 \
+ --decoder_output_stride=4 \
+ --eval_crop_size="513,513" \
+ --dataset="pascal_voc_seg" \
+ --checkpoint_dir=${PATH_TO_CHECKPOINT} \
+ --eval_logdir=${PATH_TO_EVAL_DIR} \
+ --dataset_dir=${PATH_TO_DATASET}
+```
+
+where ${PATH_TO_CHECKPOINT} is the path to the trained checkpoint (i.e., the
+path to train_logdir), ${PATH_TO_EVAL_DIR} is the directory in which evaluation
+events will be written, and ${PATH_TO_DATASET} is the directory in which the
+PASCAL VOC 2012 dataset resides.
+
+A local visualization job using `xception_65` can be run with the following
+command:
+
+```bash
+# From tensorflow/models/research/
+python deeplab/vis.py \
+ --logtostderr \
+ --vis_split="val" \
+ --model_variant="xception_65" \
+ --atrous_rates=6 \
+ --atrous_rates=12 \
+ --atrous_rates=18 \
+ --output_stride=16 \
+ --decoder_output_stride=4 \
+ --vis_crop_size="513,513" \
+ --dataset="pascal_voc_seg" \
+ --checkpoint_dir=${PATH_TO_CHECKPOINT} \
+ --vis_logdir=${PATH_TO_VIS_DIR} \
+ --dataset_dir=${PATH_TO_DATASET}
+```
+
+where ${PATH_TO_CHECKPOINT} is the path to the trained checkpoint (i.e., the
+path to train_logdir), ${PATH_TO_VIS_DIR} is the directory in which visualization
+results will be written, and ${PATH_TO_DATASET} is the directory in which the
+PASCAL VOC 2012 dataset resides. Note that if the users would like to save the
+segmentation results for the evaluation server, set also_save_raw_predictions =
+True.
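+
+For example, the visualization command above with raw predictions also saved
+(only the last flag is new; it is defined in vis.py):
+
+```bash
+# From tensorflow/models/research/
+python deeplab/vis.py \
+  --logtostderr \
+  --vis_split="val" \
+  --model_variant="xception_65" \
+  --atrous_rates=6 \
+  --atrous_rates=12 \
+  --atrous_rates=18 \
+  --output_stride=16 \
+  --decoder_output_stride=4 \
+  --vis_crop_size="513,513" \
+  --dataset="pascal_voc_seg" \
+  --checkpoint_dir=${PATH_TO_CHECKPOINT} \
+  --vis_logdir=${PATH_TO_VIS_DIR} \
+  --dataset_dir=${PATH_TO_DATASET} \
+  --also_save_raw_predictions=true
+```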
+
+## Running Tensorboard
+
+Progress for training and evaluation jobs can be inspected using Tensorboard. If
+using the recommended directory structure, Tensorboard can be run using the
+following command:
+
+```bash
+tensorboard --logdir=${PATH_TO_LOG_DIRECTORY}
+```
+
+where `${PATH_TO_LOG_DIRECTORY}` points to the directory that contains the
+train, eval, and vis directories (e.g., the folder `train_on_train_set` in the
+above example). Please note it may take Tensorboard a couple of minutes to
+populate with data.
+
+## Example
+
+We provide a script to run the {train,eval,vis,export_model}.py on the PASCAL VOC
+2012 dataset as an example. See the code in local_test.sh for details.
+
+```bash
+# From tensorflow/models/research/deeplab
+sh local_test.sh
+```
diff --git a/models/research/deeplab/g3doc/quantize.md b/models/research/deeplab/g3doc/quantize.md
new file mode 100644
index 0000000000000000000000000000000000000000..d88a2e9a8acbac4a0de6e3ea2bed65cb44535665
--- /dev/null
+++ b/models/research/deeplab/g3doc/quantize.md
@@ -0,0 +1,110 @@
+# Quantize DeepLab model for faster on-device inference
+
+This page describes the steps required to quantize a DeepLab model and convert it
+to TFLite for on-device inference. The main steps include:
+
+1. Quantization-aware training
+1. Exporting model
+1. Converting to TFLite FlatBuffer
+
+We provide details for each step below.
+
+## Quantization-aware training
+
+DeepLab supports two approaches to quantize your model.
+
+1. **[Recommended]** Train a non-quantized model until convergence. Then
+   fine-tune the trained float model with quantization using a small learning
+   rate (on PASCAL we use 3e-5). This fine-tuning step usually
+   takes 2k to 5k steps to converge.
+
+1. Train a DeepLab float model with delayed quantization. Usually we delay
+   quantization until the last few thousand steps of training.
+
+In the current implementation, quantization is only supported with 1)
+`num_clones=1` for training and 2) single scale inference for evaluation,
+visualization and model export. To get the best performance for the quantized
+model, we strongly recommend training the float model with a larger `num_clones`
+and then fine-tuning the model with a single clone.
+
+The following command line quantizes a DeepLab model trained on the PASCAL VOC
+dataset using fine-tuning:
+
+```
+# From tensorflow/models/research/
+python deeplab/train.py \
+ --logtostderr \
+ --training_number_of_steps=3000 \
+ --train_split="train" \
+ --model_variant="mobilenet_v2" \
+ --output_stride=16 \
+ --train_crop_size="513,513" \
+ --train_batch_size=8 \
+ --base_learning_rate=3e-5 \
+ --dataset="pascal_voc_seg" \
+ --initialize_last_layer \
+ --quantize_delay_step=0 \
+ --tf_initial_checkpoint=${PATH_TO_TRAINED_FLOAT_MODEL} \
+ --train_logdir=${PATH_TO_TRAIN_DIR} \
+ --dataset_dir=${PATH_TO_DATASET}
+```
+
+## Converting to TFLite FlatBuffer
+
+First, use the following command line to export your trained model.
+
+```
+# From tensorflow/models/research/
+python deeplab/export_model.py \
+ --checkpoint_path=${CHECKPOINT_PATH} \
+ --quantize_delay_step=0 \
+ --export_path=${OUTPUT_DIR}/frozen_inference_graph.pb
+
+```
+
+The command line below shows how to convert the exported GraphDef to a TFLite model.
+
+```
+tflite_convert \
+ --graph_def_file=${OUTPUT_DIR}/frozen_inference_graph.pb \
+ --output_file=${OUTPUT_DIR}/frozen_inference_graph.tflite \
+ --output_format=TFLITE \
+ --input_shape=1,513,513,3 \
+ --input_arrays="MobilenetV2/MobilenetV2/input" \
+ --inference_type=QUANTIZED_UINT8 \
+ --inference_input_type=QUANTIZED_UINT8 \
+ --std_dev_values=128 \
+ --mean_values=128 \
+ --change_concat_input_ranges=true \
+ --output_arrays="ArgMax"
+```
+
+**[Important]** Note that the converted model expects 513x513 RGB input and doesn't
+include preprocessing (resize and pad the input image) or post-processing (crop the
+padded region and resize to the original input size). These steps can be implemented
+outside of the TFLite model.
+
+## Quantized model on PASCAL VOC
+
+We provide float and quantized checkpoints that have been pretrained on the VOC 2012
+train_aug set, using a MobileNet-v2 backbone with different depth multipliers.
+Quantized models usually show about a 1% drop in mIoU.
+
+For the quantized (8-bit) model, the un-tar'ed directory includes:
+
+* a frozen inference graph (frozen_inference_graph.pb)
+
+* a checkpoint (model.ckpt.data*, model.ckpt.index)
+
+* a converted TFlite FlatBuffer file (frozen_inference_graph.tflite)
+
+Checkpoint name | Eval OS | Eval scales | Left-right Flip | Multiply-Adds | Quantize | PASCAL mIOU | Folder Size | TFLite File Size
+-------------------------------------------------------------------------------------------------------------------------------------------- | :-----: | :---------: | :-------------: | :-----------: | :------: | :----------: | :-------: | :-------:
+[mobilenetv2_dm05_coco_voc_trainaug](http://download.tensorflow.org/models/deeplabv3_mnv2_dm05_pascal_trainaug_2018_10_01.tar.gz) | 16 | [1.0] | No | 0.88B | No | 70.19% (val) | 7.6MB | N/A
+[mobilenetv2_dm05_coco_voc_trainaug_8bit](http://download.tensorflow.org/models/deeplabv3_mnv2_dm05_pascal_train_aug_8bit_2019_04_26.tar.gz) | 16 | [1.0] | No | 0.88B | Yes | 69.65% (val) | 8.2MB | 751.1KB
+[mobilenetv2_coco_voc_trainaug](http://download.tensorflow.org/models/deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz) | 16 | [1.0] | No | 2.75B | No | 75.32% (val) | 23MB | N/A
+[mobilenetv2_coco_voc_trainaug_8bit](http://download.tensorflow.org/models/deeplabv3_mnv2_pascal_train_aug_8bit_2019_04_26.tar.gz) | 16 | [1.0] | No | 2.75B | Yes | 74.26% (val) | 24MB | 2.2MB
+
+Note that you might need the nightly build of TensorFlow (see
+[here](https://www.tensorflow.org/install) for install instructions) to convert the
+above quantized model to TFLite.
diff --git a/models/research/deeplab/input_preprocess.py b/models/research/deeplab/input_preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ca8bce4eb9104b22469419c4e6af4beaba9406a
--- /dev/null
+++ b/models/research/deeplab/input_preprocess.py
@@ -0,0 +1,139 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Prepares the data used for DeepLab training/evaluation."""
+import tensorflow as tf
+from deeplab.core import feature_extractor
+from deeplab.core import preprocess_utils
+
+
+# The probability of flipping the images and labels
+# left-right during training
+_PROB_OF_FLIP = 0.5
+
+
+def preprocess_image_and_label(image,
+ label,
+ crop_height,
+ crop_width,
+ min_resize_value=None,
+ max_resize_value=None,
+ resize_factor=None,
+ min_scale_factor=1.,
+ max_scale_factor=1.,
+ scale_factor_step_size=0,
+ ignore_label=255,
+ is_training=True,
+ model_variant=None):
+ """Preprocesses the image and label.
+
+ Args:
+ image: Input image.
+ label: Ground truth annotation label.
+ crop_height: The height value used to crop the image and label.
+ crop_width: The width value used to crop the image and label.
+ min_resize_value: Desired size of the smaller image side.
+ max_resize_value: Maximum allowed size of the larger image side.
+ resize_factor: Resized dimensions are multiple of factor plus one.
+ min_scale_factor: Minimum scale factor value.
+ max_scale_factor: Maximum scale factor value.
+ scale_factor_step_size: The step size from min scale factor to max scale
+ factor. The input is randomly scaled based on the value of
+ (min_scale_factor, max_scale_factor, scale_factor_step_size).
+ ignore_label: The label value which will be ignored for training and
+ evaluation.
+ is_training: If the preprocessing is used for training or not.
+ model_variant: Model variant (string) for choosing how to mean-subtract the
+ images. See feature_extractor.network_map for supported model variants.
+
+ Returns:
+ original_image: Original image (could be resized).
+ processed_image: Preprocessed image.
+ label: Preprocessed ground truth segmentation label.
+
+ Raises:
+ ValueError: Ground truth label not provided during training.
+ """
+ if is_training and label is None:
+ raise ValueError('During training, label must be provided.')
+ if model_variant is None:
+ tf.logging.warning('Default mean-subtraction is performed. Please specify '
+ 'a model_variant. See feature_extractor.network_map for '
+ 'supported model variants.')
+
+ # Keep reference to original image.
+ original_image = image
+
+ processed_image = tf.cast(image, tf.float32)
+
+ if label is not None:
+ label = tf.cast(label, tf.int32)
+
+ # Resize image and label to the desired range.
+ if min_resize_value or max_resize_value:
+ [processed_image, label] = (
+ preprocess_utils.resize_to_range(
+ image=processed_image,
+ label=label,
+ min_size=min_resize_value,
+ max_size=max_resize_value,
+ factor=resize_factor,
+ align_corners=True))
+ # The `original_image` becomes the resized image.
+ original_image = tf.identity(processed_image)
+
+ # Data augmentation by randomly scaling the inputs.
+ if is_training:
+ scale = preprocess_utils.get_random_scale(
+ min_scale_factor, max_scale_factor, scale_factor_step_size)
+ processed_image, label = preprocess_utils.randomly_scale_image_and_label(
+ processed_image, label, scale)
+ processed_image.set_shape([None, None, 3])
+
+ # Pad image and label to have dimensions >= [crop_height, crop_width]
+ image_shape = tf.shape(processed_image)
+ image_height = image_shape[0]
+ image_width = image_shape[1]
+
+ target_height = image_height + tf.maximum(crop_height - image_height, 0)
+ target_width = image_width + tf.maximum(crop_width - image_width, 0)
+
+ # Pad image with mean pixel value.
+ mean_pixel = tf.reshape(
+ feature_extractor.mean_pixel(model_variant), [1, 1, 3])
+ processed_image = preprocess_utils.pad_to_bounding_box(
+ processed_image, 0, 0, target_height, target_width, mean_pixel)
+
+ if label is not None:
+ label = preprocess_utils.pad_to_bounding_box(
+ label, 0, 0, target_height, target_width, ignore_label)
+
+ # Randomly crop the image and label.
+ if is_training and label is not None:
+ processed_image, label = preprocess_utils.random_crop(
+ [processed_image, label], crop_height, crop_width)
+
+ processed_image.set_shape([crop_height, crop_width, 3])
+
+ if label is not None:
+ label.set_shape([crop_height, crop_width, 1])
+
+ if is_training:
+ # Randomly left-right flip the image and label.
+ processed_image, label, _ = preprocess_utils.flip_dim(
+ [processed_image, label], _PROB_OF_FLIP, dim=1)
+
+ return original_image, processed_image, label
diff --git a/models/research/deeplab/local_test.sh b/models/research/deeplab/local_test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d5e4a5f42bb4241d4b6dd1b9d8a2619c4ca9dc8b
--- /dev/null
+++ b/models/research/deeplab/local_test.sh
@@ -0,0 +1,147 @@
+#!/bin/bash
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# This script is used to run local test on PASCAL VOC 2012. Users could also
+# modify from this script for their use case.
+#
+# Usage:
+# # From the tensorflow/models/research/deeplab directory.
+# sh ./local_test.sh
+#
+#
+
+# Exit immediately if a command exits with a non-zero status.
+set -e
+
+# Move one-level up to tensorflow/models/research directory.
+cd ..
+
+# Update PYTHONPATH.
+export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
+
+# Set up the working environment.
+CURRENT_DIR=$(pwd)
+WORK_DIR="${CURRENT_DIR}/deeplab"
+
+# Run model_test first to make sure the PYTHONPATH is correctly set.
+python "${WORK_DIR}"/model_test.py
+
+# Go to datasets folder and download PASCAL VOC 2012 segmentation dataset.
+DATASET_DIR="datasets"
+cd "${WORK_DIR}/${DATASET_DIR}"
+sh download_and_convert_voc2012.sh
+
+# Go back to original directory.
+cd "${CURRENT_DIR}"
+
+# Set up the working directories.
+PASCAL_FOLDER="pascal_voc_seg"
+EXP_FOLDER="exp/train_on_trainval_set"
+INIT_FOLDER="${WORK_DIR}/${DATASET_DIR}/${PASCAL_FOLDER}/init_models"
+TRAIN_LOGDIR="${WORK_DIR}/${DATASET_DIR}/${PASCAL_FOLDER}/${EXP_FOLDER}/train"
+EVAL_LOGDIR="${WORK_DIR}/${DATASET_DIR}/${PASCAL_FOLDER}/${EXP_FOLDER}/eval"
+VIS_LOGDIR="${WORK_DIR}/${DATASET_DIR}/${PASCAL_FOLDER}/${EXP_FOLDER}/vis"
+EXPORT_DIR="${WORK_DIR}/${DATASET_DIR}/${PASCAL_FOLDER}/${EXP_FOLDER}/export"
+mkdir -p "${INIT_FOLDER}"
+mkdir -p "${TRAIN_LOGDIR}"
+mkdir -p "${EVAL_LOGDIR}"
+mkdir -p "${VIS_LOGDIR}"
+mkdir -p "${EXPORT_DIR}"
+
+# Copy locally the trained checkpoint as the initial checkpoint.
+TF_INIT_ROOT="http://download.tensorflow.org/models"
+TF_INIT_CKPT="deeplabv3_pascal_train_aug_2018_01_04.tar.gz"
+cd "${INIT_FOLDER}"
+wget -nd -c "${TF_INIT_ROOT}/${TF_INIT_CKPT}"
+tar -xf "${TF_INIT_CKPT}"
+cd "${CURRENT_DIR}"
+
+PASCAL_DATASET="${WORK_DIR}/${DATASET_DIR}/${PASCAL_FOLDER}/tfrecord"
+
+# Train 10 iterations.
+NUM_ITERATIONS=10
+python "${WORK_DIR}"/train.py \
+ --logtostderr \
+ --train_split="trainval" \
+ --model_variant="xception_65" \
+ --atrous_rates=6 \
+ --atrous_rates=12 \
+ --atrous_rates=18 \
+ --output_stride=16 \
+ --decoder_output_stride=4 \
+ --train_crop_size="513,513" \
+ --train_batch_size=4 \
+ --training_number_of_steps="${NUM_ITERATIONS}" \
+ --fine_tune_batch_norm=true \
+ --tf_initial_checkpoint="${INIT_FOLDER}/deeplabv3_pascal_train_aug/model.ckpt" \
+ --train_logdir="${TRAIN_LOGDIR}" \
+ --dataset_dir="${PASCAL_DATASET}"
+
+# Run evaluation. This performs eval over the full val split (1449 images) and
+# will take a while.
+# Using the provided checkpoint, one should expect mIOU=82.20%.
+python "${WORK_DIR}"/eval.py \
+ --logtostderr \
+ --eval_split="val" \
+ --model_variant="xception_65" \
+ --atrous_rates=6 \
+ --atrous_rates=12 \
+ --atrous_rates=18 \
+ --output_stride=16 \
+ --decoder_output_stride=4 \
+ --eval_crop_size="513,513" \
+ --checkpoint_dir="${TRAIN_LOGDIR}" \
+ --eval_logdir="${EVAL_LOGDIR}" \
+ --dataset_dir="${PASCAL_DATASET}" \
+ --max_number_of_evaluations=1
+
+# Visualize the results.
+python "${WORK_DIR}"/vis.py \
+ --logtostderr \
+ --vis_split="val" \
+ --model_variant="xception_65" \
+ --atrous_rates=6 \
+ --atrous_rates=12 \
+ --atrous_rates=18 \
+ --output_stride=16 \
+ --decoder_output_stride=4 \
+ --vis_crop_size="513,513" \
+ --checkpoint_dir="${TRAIN_LOGDIR}" \
+ --vis_logdir="${VIS_LOGDIR}" \
+ --dataset_dir="${PASCAL_DATASET}" \
+ --max_number_of_iterations=1
+
+# Export the trained checkpoint.
+CKPT_PATH="${TRAIN_LOGDIR}/model.ckpt-${NUM_ITERATIONS}"
+EXPORT_PATH="${EXPORT_DIR}/frozen_inference_graph.pb"
+
+python "${WORK_DIR}"/export_model.py \
+ --logtostderr \
+ --checkpoint_path="${CKPT_PATH}" \
+ --export_path="${EXPORT_PATH}" \
+ --model_variant="xception_65" \
+ --atrous_rates=6 \
+ --atrous_rates=12 \
+ --atrous_rates=18 \
+ --output_stride=16 \
+ --decoder_output_stride=4 \
+ --num_classes=21 \
+ --crop_size=513 \
+ --crop_size=513 \
+ --inference_scales=1.0
+
+# Run inference with the exported checkpoint.
+# Please refer to the provided deeplab_demo.ipynb for an example.
diff --git a/models/research/deeplab/local_test_mobilenetv2.sh b/models/research/deeplab/local_test_mobilenetv2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c38646fdf6caa3934b7c8db66e53ffbd4f9fd8c6
--- /dev/null
+++ b/models/research/deeplab/local_test_mobilenetv2.sh
@@ -0,0 +1,129 @@
+#!/bin/bash
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# This script is used to run local test on PASCAL VOC 2012 using MobileNet-v2.
+# Users could also modify from this script for their use case.
+#
+# Usage:
+# # From the tensorflow/models/research/deeplab directory.
+# sh ./local_test_mobilenetv2.sh
+#
+#
+
+# Exit immediately if a command exits with a non-zero status.
+set -e
+
+# Move one-level up to tensorflow/models/research directory.
+cd ..
+
+# Update PYTHONPATH.
+export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
+
+# Set up the working environment.
+CURRENT_DIR=$(pwd)
+WORK_DIR="${CURRENT_DIR}/deeplab"
+
+# Run model_test first to make sure the PYTHONPATH is correctly set.
+python "${WORK_DIR}"/model_test.py -v
+
+# Go to datasets folder and download PASCAL VOC 2012 segmentation dataset.
+DATASET_DIR="datasets"
+cd "${WORK_DIR}/${DATASET_DIR}"
+sh download_and_convert_voc2012.sh
+
+# Go back to original directory.
+cd "${CURRENT_DIR}"
+
+# Set up the working directories.
+PASCAL_FOLDER="pascal_voc_seg"
+EXP_FOLDER="exp/train_on_trainval_set_mobilenetv2"
+INIT_FOLDER="${WORK_DIR}/${DATASET_DIR}/${PASCAL_FOLDER}/init_models"
+TRAIN_LOGDIR="${WORK_DIR}/${DATASET_DIR}/${PASCAL_FOLDER}/${EXP_FOLDER}/train"
+EVAL_LOGDIR="${WORK_DIR}/${DATASET_DIR}/${PASCAL_FOLDER}/${EXP_FOLDER}/eval"
+VIS_LOGDIR="${WORK_DIR}/${DATASET_DIR}/${PASCAL_FOLDER}/${EXP_FOLDER}/vis"
+EXPORT_DIR="${WORK_DIR}/${DATASET_DIR}/${PASCAL_FOLDER}/${EXP_FOLDER}/export"
+mkdir -p "${INIT_FOLDER}"
+mkdir -p "${TRAIN_LOGDIR}"
+mkdir -p "${EVAL_LOGDIR}"
+mkdir -p "${VIS_LOGDIR}"
+mkdir -p "${EXPORT_DIR}"
+
+# Copy locally the trained checkpoint as the initial checkpoint.
+TF_INIT_ROOT="http://download.tensorflow.org/models"
+CKPT_NAME="deeplabv3_mnv2_pascal_train_aug"
+TF_INIT_CKPT="${CKPT_NAME}_2018_01_29.tar.gz"
+cd "${INIT_FOLDER}"
+wget -nd -c "${TF_INIT_ROOT}/${TF_INIT_CKPT}"
+tar -xf "${TF_INIT_CKPT}"
+cd "${CURRENT_DIR}"
+
+PASCAL_DATASET="${WORK_DIR}/${DATASET_DIR}/${PASCAL_FOLDER}/tfrecord"
+
+# Train 10 iterations.
+NUM_ITERATIONS=10
+python "${WORK_DIR}"/train.py \
+ --logtostderr \
+ --train_split="trainval" \
+ --model_variant="mobilenet_v2" \
+ --output_stride=16 \
+ --train_crop_size="513,513" \
+ --train_batch_size=4 \
+ --training_number_of_steps="${NUM_ITERATIONS}" \
+ --fine_tune_batch_norm=true \
+ --tf_initial_checkpoint="${INIT_FOLDER}/${CKPT_NAME}/model.ckpt-30000" \
+ --train_logdir="${TRAIN_LOGDIR}" \
+ --dataset_dir="${PASCAL_DATASET}"
+
+# Run evaluation. This performs eval over the full val split (1449 images) and
+# will take a while.
+# Using the provided checkpoint, one should expect mIOU=75.34%.
+python "${WORK_DIR}"/eval.py \
+ --logtostderr \
+ --eval_split="val" \
+ --model_variant="mobilenet_v2" \
+ --eval_crop_size="513,513" \
+ --checkpoint_dir="${TRAIN_LOGDIR}" \
+ --eval_logdir="${EVAL_LOGDIR}" \
+ --dataset_dir="${PASCAL_DATASET}" \
+ --max_number_of_evaluations=1
+
+# Visualize the results.
+python "${WORK_DIR}"/vis.py \
+ --logtostderr \
+ --vis_split="val" \
+ --model_variant="mobilenet_v2" \
+ --vis_crop_size="513,513" \
+ --checkpoint_dir="${TRAIN_LOGDIR}" \
+ --vis_logdir="${VIS_LOGDIR}" \
+ --dataset_dir="${PASCAL_DATASET}" \
+ --max_number_of_iterations=1
+
+# Export the trained checkpoint.
+CKPT_PATH="${TRAIN_LOGDIR}/model.ckpt-${NUM_ITERATIONS}"
+EXPORT_PATH="${EXPORT_DIR}/frozen_inference_graph.pb"
+
+python "${WORK_DIR}"/export_model.py \
+ --logtostderr \
+ --checkpoint_path="${CKPT_PATH}" \
+ --export_path="${EXPORT_PATH}" \
+ --model_variant="mobilenet_v2" \
+ --num_classes=21 \
+ --crop_size=513 \
+ --crop_size=513 \
+ --inference_scales=1.0
+
+# Run inference with the exported checkpoint.
+# Please refer to the provided deeplab_demo.ipynb for an example.
diff --git a/models/research/deeplab/model.py b/models/research/deeplab/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..311aaa1acb13cb445053ac12fa09e354423e56df
--- /dev/null
+++ b/models/research/deeplab/model.py
@@ -0,0 +1,911 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Provides DeepLab model definition and helper functions.
+
+DeepLab is a deep learning system for semantic image segmentation with
+the following features:
+
+(1) Atrous convolution to explicitly control the resolution at which
+feature responses are computed within Deep Convolutional Neural Networks.
+
+(2) Atrous spatial pyramid pooling (ASPP) to robustly segment objects at
+multiple scales with filters at multiple sampling rates and effective
+fields-of-views.
+
+(3) ASPP module augmented with image-level feature and batch normalization.
+
+(4) A simple yet effective decoder module to recover the object boundaries.
+
+See the following papers for more details:
+
+"Encoder-Decoder with Atrous Separable Convolution for Semantic Image
+Segmentation"
+Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, Hartwig Adam.
+(https://arxiv.org/abs/1802.02611)
+
+"Rethinking Atrous Convolution for Semantic Image Segmentation,"
+Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam
+(https://arxiv.org/abs/1706.05587)
+
+"DeepLab: Semantic Image Segmentation with Deep Convolutional Nets,
+Atrous Convolution, and Fully Connected CRFs",
+Liang-Chieh Chen*, George Papandreou*, Iasonas Kokkinos, Kevin Murphy,
+Alan L Yuille (* equal contribution)
+(https://arxiv.org/abs/1606.00915)
+
+"Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected
+CRFs"
+Liang-Chieh Chen*, George Papandreou*, Iasonas Kokkinos, Kevin Murphy,
+Alan L. Yuille (* equal contribution)
+(https://arxiv.org/abs/1412.7062)
+"""
+import tensorflow as tf
+from tensorflow.contrib import slim as contrib_slim
+from deeplab.core import dense_prediction_cell
+from deeplab.core import feature_extractor
+from deeplab.core import utils
+
+slim = contrib_slim
+
+LOGITS_SCOPE_NAME = 'logits'
+MERGED_LOGITS_SCOPE = 'merged_logits'
+IMAGE_POOLING_SCOPE = 'image_pooling'
+ASPP_SCOPE = 'aspp'
+CONCAT_PROJECTION_SCOPE = 'concat_projection'
+DECODER_SCOPE = 'decoder'
+META_ARCHITECTURE_SCOPE = 'meta_architecture'
+
+PROB_SUFFIX = '_prob'
+
+_resize_bilinear = utils.resize_bilinear
+scale_dimension = utils.scale_dimension
+split_separable_conv2d = utils.split_separable_conv2d
+
+
+def get_extra_layer_scopes(last_layers_contain_logits_only=False):
+ """Gets the scopes for extra layers.
+
+ Args:
+ last_layers_contain_logits_only: Boolean, True if only consider logits as
+ the last layer (i.e., exclude ASPP module, decoder module and so on)
+
+ Returns:
+ A list of scopes for extra layers.
+ """
+ if last_layers_contain_logits_only:
+ return [LOGITS_SCOPE_NAME]
+ else:
+ return [
+ LOGITS_SCOPE_NAME,
+ IMAGE_POOLING_SCOPE,
+ ASPP_SCOPE,
+ CONCAT_PROJECTION_SCOPE,
+ DECODER_SCOPE,
+ META_ARCHITECTURE_SCOPE,
+ ]
+
+
+def predict_labels_multi_scale(images,
+ model_options,
+ eval_scales=(1.0,),
+ add_flipped_images=False):
+ """Predicts segmentation labels.
+
+ Args:
+ images: A tensor of size [batch, height, width, channels].
+ model_options: A ModelOptions instance to configure models.
+ eval_scales: The scales to resize images for evaluation.
+ add_flipped_images: Add flipped images for evaluation or not.
+
+ Returns:
+ A dictionary with keys specifying the output_type (e.g., semantic
+ prediction) and values storing Tensors representing predictions (argmax
+ over channels). Each prediction has size [batch, height, width].
+ """
+ outputs_to_predictions = {
+ output: []
+ for output in model_options.outputs_to_num_classes
+ }
+
+ for i, image_scale in enumerate(eval_scales):
+ with tf.variable_scope(tf.get_variable_scope(), reuse=True if i else None):
+ outputs_to_scales_to_logits = multi_scale_logits(
+ images,
+ model_options=model_options,
+ image_pyramid=[image_scale],
+ is_training=False,
+ fine_tune_batch_norm=False)
+
+ if add_flipped_images:
+ with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+ outputs_to_scales_to_logits_reversed = multi_scale_logits(
+ tf.reverse_v2(images, [2]),
+ model_options=model_options,
+ image_pyramid=[image_scale],
+ is_training=False,
+ fine_tune_batch_norm=False)
+
+ for output in sorted(outputs_to_scales_to_logits):
+ scales_to_logits = outputs_to_scales_to_logits[output]
+ logits = _resize_bilinear(
+ scales_to_logits[MERGED_LOGITS_SCOPE],
+ tf.shape(images)[1:3],
+ scales_to_logits[MERGED_LOGITS_SCOPE].dtype)
+ outputs_to_predictions[output].append(
+ tf.expand_dims(tf.nn.softmax(logits), 4))
+
+ if add_flipped_images:
+ scales_to_logits_reversed = (
+ outputs_to_scales_to_logits_reversed[output])
+ logits_reversed = _resize_bilinear(
+ tf.reverse_v2(scales_to_logits_reversed[MERGED_LOGITS_SCOPE], [2]),
+ tf.shape(images)[1:3],
+ scales_to_logits_reversed[MERGED_LOGITS_SCOPE].dtype)
+ outputs_to_predictions[output].append(
+ tf.expand_dims(tf.nn.softmax(logits_reversed), 4))
+
+ for output in sorted(outputs_to_predictions):
+ predictions = outputs_to_predictions[output]
+ # Compute average prediction across different scales and flipped images.
+ predictions = tf.reduce_mean(tf.concat(predictions, 4), axis=4)
+ outputs_to_predictions[output] = tf.argmax(predictions, 3)
+ outputs_to_predictions[output + PROB_SUFFIX] = tf.nn.softmax(predictions)
+
+ return outputs_to_predictions
+
+
+def predict_labels(images, model_options, image_pyramid=None):
+ """Predicts segmentation labels.
+
+ Args:
+ images: A tensor of size [batch, height, width, channels].
+ model_options: A ModelOptions instance to configure models.
+ image_pyramid: Input image scales for multi-scale feature extraction.
+
+ Returns:
+ A dictionary with keys specifying the output_type (e.g., semantic
+ prediction) and values storing Tensors representing predictions (argmax
+ over channels). Each prediction has size [batch, height, width].
+ """
+ outputs_to_scales_to_logits = multi_scale_logits(
+ images,
+ model_options=model_options,
+ image_pyramid=image_pyramid,
+ is_training=False,
+ fine_tune_batch_norm=False)
+
+ predictions = {}
+ for output in sorted(outputs_to_scales_to_logits):
+ scales_to_logits = outputs_to_scales_to_logits[output]
+ logits = scales_to_logits[MERGED_LOGITS_SCOPE]
+ # There are two ways to obtain the final prediction results: (1) bilinear
+ # upsampling the logits followed by argmax, or (2) argmax followed by
+ # nearest neighbor upsampling. The second option may introduce the "blocking
+ # effect" but is computationally efficient.
+ if model_options.prediction_with_upsampled_logits:
+ logits = _resize_bilinear(logits,
+ tf.shape(images)[1:3],
+ scales_to_logits[MERGED_LOGITS_SCOPE].dtype)
+ predictions[output] = tf.argmax(logits, 3)
+ predictions[output + PROB_SUFFIX] = tf.nn.softmax(logits)
+ else:
+ argmax_results = tf.argmax(logits, 3)
+ argmax_results = tf.image.resize_nearest_neighbor(
+ tf.expand_dims(argmax_results, 3),
+ tf.shape(images)[1:3],
+ align_corners=True,
+ name='resize_prediction')
+ predictions[output] = tf.squeeze(argmax_results, 3)
+ predictions[output + PROB_SUFFIX] = tf.image.resize_bilinear(
+ tf.nn.softmax(logits),
+ tf.shape(images)[1:3],
+ align_corners=True,
+ name='resize_prob')
+ return predictions
+
+
+def multi_scale_logits(images,
+ model_options,
+ image_pyramid,
+ weight_decay=0.0001,
+ is_training=False,
+ fine_tune_batch_norm=False,
+ nas_training_hyper_parameters=None):
+ """Gets the logits for multi-scale inputs.
+
+ The returned logits are all downsampled (due to max-pooling layers)
+ for both training and evaluation.
+
+ Args:
+ images: A tensor of size [batch, height, width, channels].
+ model_options: A ModelOptions instance to configure models.
+ image_pyramid: Input image scales for multi-scale feature extraction.
+ weight_decay: The weight decay for model variables.
+ is_training: Is training or not.
+ fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
+ nas_training_hyper_parameters: A dictionary storing hyper-parameters for
+ training nas models. Its keys are:
+ - `drop_path_keep_prob`: Probability to keep each path in the cell when
+ training.
+ - `total_training_steps`: Total training steps to help drop path
+ probability calculation.
+
+ Returns:
+ outputs_to_scales_to_logits: A map of maps from output_type (e.g.,
+ semantic prediction) to a dictionary of multi-scale logits names to
+ logits. For each output_type, the dictionary has keys which
+ correspond to the scales and values which correspond to the logits.
+ For example, if `scales` equals [1.0, 1.5], then the keys would
+ include 'merged_logits', 'logits_1.00' and 'logits_1.50'.
+
+ Raises:
+ ValueError: If model_options doesn't specify crop_size and its
+ add_image_level_feature = True, since add_image_level_feature requires
+ crop_size information.
+ """
+ # Setup default values.
+ if not image_pyramid:
+ image_pyramid = [1.0]
+ crop_height = (
+ model_options.crop_size[0]
+ if model_options.crop_size else tf.shape(images)[1])
+ crop_width = (
+ model_options.crop_size[1]
+ if model_options.crop_size else tf.shape(images)[2])
+ if model_options.image_pooling_crop_size:
+ image_pooling_crop_height = model_options.image_pooling_crop_size[0]
+ image_pooling_crop_width = model_options.image_pooling_crop_size[1]
+
+ # Compute the height, width for the output logits.
+ if model_options.decoder_output_stride:
+ logits_output_stride = min(model_options.decoder_output_stride)
+ else:
+ logits_output_stride = model_options.output_stride
+
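+  # All pyramid levels are later resized to this common resolution, i.e., the
+  # largest image scale (at least 1.0) divided by the logits output stride,
+  # before the logits are merged.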
+ logits_height = scale_dimension(
+ crop_height,
+ max(1.0, max(image_pyramid)) / logits_output_stride)
+ logits_width = scale_dimension(
+ crop_width,
+ max(1.0, max(image_pyramid)) / logits_output_stride)
+
+ # Compute the logits for each scale in the image pyramid.
+ outputs_to_scales_to_logits = {
+ k: {}
+ for k in model_options.outputs_to_num_classes
+ }
+
+ num_channels = images.get_shape().as_list()[-1]
+
+ for image_scale in image_pyramid:
+ if image_scale != 1.0:
+ scaled_height = scale_dimension(crop_height, image_scale)
+ scaled_width = scale_dimension(crop_width, image_scale)
+ scaled_crop_size = [scaled_height, scaled_width]
+ scaled_images = _resize_bilinear(images, scaled_crop_size, images.dtype)
+ if model_options.crop_size:
+ scaled_images.set_shape(
+ [None, scaled_height, scaled_width, num_channels])
+ # Adjust image_pooling_crop_size accordingly.
+ scaled_image_pooling_crop_size = None
+ if model_options.image_pooling_crop_size:
+ scaled_image_pooling_crop_size = [
+ scale_dimension(image_pooling_crop_height, image_scale),
+ scale_dimension(image_pooling_crop_width, image_scale)]
+ else:
+ scaled_crop_size = model_options.crop_size
+ scaled_images = images
+ scaled_image_pooling_crop_size = model_options.image_pooling_crop_size
+
+ updated_options = model_options._replace(
+ crop_size=scaled_crop_size,
+ image_pooling_crop_size=scaled_image_pooling_crop_size)
+ outputs_to_logits = _get_logits(
+ scaled_images,
+ updated_options,
+ weight_decay=weight_decay,
+ reuse=tf.AUTO_REUSE,
+ is_training=is_training,
+ fine_tune_batch_norm=fine_tune_batch_norm,
+ nas_training_hyper_parameters=nas_training_hyper_parameters)
+
+ # Resize the logits to have the same dimension before merging.
+ for output in sorted(outputs_to_logits):
+ outputs_to_logits[output] = _resize_bilinear(
+ outputs_to_logits[output], [logits_height, logits_width],
+ outputs_to_logits[output].dtype)
+
+ # Return when only one input scale.
+ if len(image_pyramid) == 1:
+ for output in sorted(model_options.outputs_to_num_classes):
+ outputs_to_scales_to_logits[output][
+ MERGED_LOGITS_SCOPE] = outputs_to_logits[output]
+ return outputs_to_scales_to_logits
+
+ # Save logits to the output map.
+ for output in sorted(model_options.outputs_to_num_classes):
+ outputs_to_scales_to_logits[output][
+ 'logits_%.2f' % image_scale] = outputs_to_logits[output]
+
+ # Merge the logits from all the multi-scale inputs.
+ for output in sorted(model_options.outputs_to_num_classes):
+ # Concatenate the multi-scale logits for each output type.
+ all_logits = [
+ tf.expand_dims(logits, axis=4)
+ for logits in outputs_to_scales_to_logits[output].values()
+ ]
+ all_logits = tf.concat(all_logits, 4)
+ merge_fn = (
+ tf.reduce_max
+ if model_options.merge_method == 'max' else tf.reduce_mean)
+ outputs_to_scales_to_logits[output][MERGED_LOGITS_SCOPE] = merge_fn(
+ all_logits, axis=4)
+
+ return outputs_to_scales_to_logits
+
+
+def extract_features(images,
+ model_options,
+ weight_decay=0.0001,
+ reuse=None,
+ is_training=False,
+ fine_tune_batch_norm=False,
+ nas_training_hyper_parameters=None):
+ """Extracts features by the particular model_variant.
+
+ Args:
+ images: A tensor of size [batch, height, width, channels].
+ model_options: A ModelOptions instance to configure models.
+ weight_decay: The weight decay for model variables.
+ reuse: Reuse the model variables or not.
+ is_training: Is training or not.
+ fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
+ nas_training_hyper_parameters: A dictionary storing hyper-parameters for
+ training nas models. Its keys are:
+ - `drop_path_keep_prob`: Probability to keep each path in the cell when
+ training.
+ - `total_training_steps`: Total training steps to help drop path
+ probability calculation.
+
+ Returns:
+ concat_logits: A tensor of size [batch, feature_height, feature_width,
+ feature_channels], where feature_height/feature_width are determined by
+ the images height/width and output_stride.
+ end_points: A dictionary from components of the network to the corresponding
+ activation.
+ """
+ features, end_points = feature_extractor.extract_features(
+ images,
+ output_stride=model_options.output_stride,
+ multi_grid=model_options.multi_grid,
+ model_variant=model_options.model_variant,
+ depth_multiplier=model_options.depth_multiplier,
+ divisible_by=model_options.divisible_by,
+ weight_decay=weight_decay,
+ reuse=reuse,
+ is_training=is_training,
+ preprocessed_images_dtype=model_options.preprocessed_images_dtype,
+ fine_tune_batch_norm=fine_tune_batch_norm,
+ nas_architecture_options=model_options.nas_architecture_options,
+ nas_training_hyper_parameters=nas_training_hyper_parameters,
+ use_bounded_activation=model_options.use_bounded_activation)
+
+ if not model_options.aspp_with_batch_norm:
+ return features, end_points
+ else:
+ if model_options.dense_prediction_cell_config is not None:
+ tf.logging.info('Using dense prediction cell config.')
+ dense_prediction_layer = dense_prediction_cell.DensePredictionCell(
+ config=model_options.dense_prediction_cell_config,
+ hparams={
+ 'conv_rate_multiplier': 16 // model_options.output_stride,
+ })
+ concat_logits = dense_prediction_layer.build_cell(
+ features,
+ output_stride=model_options.output_stride,
+ crop_size=model_options.crop_size,
+ image_pooling_crop_size=model_options.image_pooling_crop_size,
+ weight_decay=weight_decay,
+ reuse=reuse,
+ is_training=is_training,
+ fine_tune_batch_norm=fine_tune_batch_norm)
+ return concat_logits, end_points
+ else:
+      # The following code employs the DeepLabv3 ASPP module. Note that we
+      # could express the ASPP module as one particular dense prediction
+      # cell architecture. We do not do so, but keep the code below for
+      # backward compatibility.
+ batch_norm_params = utils.get_batch_norm_params(
+ decay=0.9997,
+ epsilon=1e-5,
+ scale=True,
+ is_training=(is_training and fine_tune_batch_norm),
+ sync_batch_norm_method=model_options.sync_batch_norm_method)
+ batch_norm = utils.get_batch_norm_fn(
+ model_options.sync_batch_norm_method)
+ activation_fn = (
+ tf.nn.relu6 if model_options.use_bounded_activation else tf.nn.relu)
+ with slim.arg_scope(
+ [slim.conv2d, slim.separable_conv2d],
+ weights_regularizer=slim.l2_regularizer(weight_decay),
+ activation_fn=activation_fn,
+ normalizer_fn=batch_norm,
+ padding='SAME',
+ stride=1,
+ reuse=reuse):
+ with slim.arg_scope([batch_norm], **batch_norm_params):
+ depth = model_options.aspp_convs_filters
+ branch_logits = []
+
+ if model_options.add_image_level_feature:
+ if model_options.crop_size is not None:
+ image_pooling_crop_size = model_options.image_pooling_crop_size
+ # If image_pooling_crop_size is not specified, use crop_size.
+ if image_pooling_crop_size is None:
+ image_pooling_crop_size = model_options.crop_size
+ pool_height = scale_dimension(
+ image_pooling_crop_size[0],
+ 1. / model_options.output_stride)
+ pool_width = scale_dimension(
+ image_pooling_crop_size[1],
+ 1. / model_options.output_stride)
+ image_feature = slim.avg_pool2d(
+ features, [pool_height, pool_width],
+ model_options.image_pooling_stride, padding='VALID')
+ resize_height = scale_dimension(
+ model_options.crop_size[0],
+ 1. / model_options.output_stride)
+ resize_width = scale_dimension(
+ model_options.crop_size[1],
+ 1. / model_options.output_stride)
+ else:
+ # If crop_size is None, we simply do global pooling.
+ pool_height = tf.shape(features)[1]
+ pool_width = tf.shape(features)[2]
+ image_feature = tf.reduce_mean(
+ features, axis=[1, 2], keepdims=True)
+ resize_height = pool_height
+ resize_width = pool_width
+ image_feature_activation_fn = tf.nn.relu
+ image_feature_normalizer_fn = batch_norm
+ if model_options.aspp_with_squeeze_and_excitation:
+ image_feature_activation_fn = tf.nn.sigmoid
+ if model_options.image_se_uses_qsigmoid:
+ image_feature_activation_fn = utils.q_sigmoid
+ image_feature_normalizer_fn = None
+ image_feature = slim.conv2d(
+ image_feature, depth, 1,
+ activation_fn=image_feature_activation_fn,
+ normalizer_fn=image_feature_normalizer_fn,
+ scope=IMAGE_POOLING_SCOPE)
+ image_feature = _resize_bilinear(
+ image_feature,
+ [resize_height, resize_width],
+ image_feature.dtype)
+ # Set shape for resize_height/resize_width if they are not Tensor.
+ if isinstance(resize_height, tf.Tensor):
+ resize_height = None
+ if isinstance(resize_width, tf.Tensor):
+ resize_width = None
+ image_feature.set_shape([None, resize_height, resize_width, depth])
+ if not model_options.aspp_with_squeeze_and_excitation:
+ branch_logits.append(image_feature)
+
+ # Employ a 1x1 convolution.
+ branch_logits.append(slim.conv2d(features, depth, 1,
+ scope=ASPP_SCOPE + str(0)))
+
+ if model_options.atrous_rates:
+ # Employ 3x3 convolutions with different atrous rates.
+ for i, rate in enumerate(model_options.atrous_rates, 1):
+ scope = ASPP_SCOPE + str(i)
+ if model_options.aspp_with_separable_conv:
+ aspp_features = split_separable_conv2d(
+ features,
+ filters=depth,
+ rate=rate,
+ weight_decay=weight_decay,
+ scope=scope)
+ else:
+ aspp_features = slim.conv2d(
+ features, depth, 3, rate=rate, scope=scope)
+ branch_logits.append(aspp_features)
+
+ # Merge branch logits.
+ concat_logits = tf.concat(branch_logits, 3)
+ if model_options.aspp_with_concat_projection:
+ concat_logits = slim.conv2d(
+ concat_logits, depth, 1, scope=CONCAT_PROJECTION_SCOPE)
+ concat_logits = slim.dropout(
+ concat_logits,
+ keep_prob=0.9,
+ is_training=is_training,
+ scope=CONCAT_PROJECTION_SCOPE + '_dropout')
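+          # With squeeze-and-excitation, the image-level feature is not
+          # concatenated as an extra branch above; instead it gates the
+          # projected ASPP output multiplicatively here.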
+ if (model_options.add_image_level_feature and
+ model_options.aspp_with_squeeze_and_excitation):
+ concat_logits *= image_feature
+
+ return concat_logits, end_points
+
+
+def _get_logits(images,
+ model_options,
+ weight_decay=0.0001,
+ reuse=None,
+ is_training=False,
+ fine_tune_batch_norm=False,
+ nas_training_hyper_parameters=None):
+ """Gets the logits by atrous/image spatial pyramid pooling.
+
+ Args:
+ images: A tensor of size [batch, height, width, channels].
+ model_options: A ModelOptions instance to configure models.
+ weight_decay: The weight decay for model variables.
+ reuse: Reuse the model variables or not.
+ is_training: Is training or not.
+ fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
+ nas_training_hyper_parameters: A dictionary storing hyper-parameters for
+ training nas models. Its keys are:
+ - `drop_path_keep_prob`: Probability to keep each path in the cell when
+ training.
+ - `total_training_steps`: Total training steps to help drop path
+ probability calculation.
+
+ Returns:
+ outputs_to_logits: A map from output_type to logits.
+ """
+ features, end_points = extract_features(
+ images,
+ model_options,
+ weight_decay=weight_decay,
+ reuse=reuse,
+ is_training=is_training,
+ fine_tune_batch_norm=fine_tune_batch_norm,
+ nas_training_hyper_parameters=nas_training_hyper_parameters)
+
+ if model_options.decoder_output_stride:
+ crop_size = model_options.crop_size
+ if crop_size is None:
+ crop_size = [tf.shape(images)[1], tf.shape(images)[2]]
+ features = refine_by_decoder(
+ features,
+ end_points,
+ crop_size=crop_size,
+ decoder_output_stride=model_options.decoder_output_stride,
+ decoder_use_separable_conv=model_options.decoder_use_separable_conv,
+ decoder_use_sum_merge=model_options.decoder_use_sum_merge,
+ decoder_filters=model_options.decoder_filters,
+ decoder_output_is_logits=model_options.decoder_output_is_logits,
+ model_variant=model_options.model_variant,
+ weight_decay=weight_decay,
+ reuse=reuse,
+ is_training=is_training,
+ fine_tune_batch_norm=fine_tune_batch_norm,
+ use_bounded_activation=model_options.use_bounded_activation)
+
+ outputs_to_logits = {}
+ for output in sorted(model_options.outputs_to_num_classes):
+ if model_options.decoder_output_is_logits:
+ outputs_to_logits[output] = tf.identity(features,
+ name=output)
+ else:
+ outputs_to_logits[output] = get_branch_logits(
+ features,
+ model_options.outputs_to_num_classes[output],
+ model_options.atrous_rates,
+ aspp_with_batch_norm=model_options.aspp_with_batch_norm,
+ kernel_size=model_options.logits_kernel_size,
+ weight_decay=weight_decay,
+ reuse=reuse,
+ scope_suffix=output)
+
+ return outputs_to_logits
+
+
+def refine_by_decoder(features,
+ end_points,
+ crop_size=None,
+ decoder_output_stride=None,
+ decoder_use_separable_conv=False,
+ decoder_use_sum_merge=False,
+ decoder_filters=256,
+ decoder_output_is_logits=False,
+ model_variant=None,
+ weight_decay=0.0001,
+ reuse=None,
+ is_training=False,
+ fine_tune_batch_norm=False,
+ use_bounded_activation=False,
+ sync_batch_norm_method='None'):
+ """Adds the decoder to obtain sharper segmentation results.
+
+ Args:
+ features: A tensor of size [batch, features_height, features_width,
+ features_channels].
+ end_points: A dictionary from components of the network to the corresponding
+ activation.
+ crop_size: A tuple [crop_height, crop_width] specifying whole patch crop
+ size.
+ decoder_output_stride: A list of integers specifying the output stride of
+ low-level features used in the decoder module.
+ decoder_use_separable_conv: Employ separable convolution for decoder or not.
+ decoder_use_sum_merge: Boolean, decoder uses simple sum merge or not.
+ decoder_filters: Integer, decoder filter size.
+ decoder_output_is_logits: Boolean, using decoder output as logits or not.
+ model_variant: Model variant for feature extraction.
+ weight_decay: The weight decay for model variables.
+ reuse: Reuse the model variables or not.
+ is_training: Is training or not.
+ fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
+ use_bounded_activation: Whether or not to use bounded activations. Bounded
+ activations better lend themselves to quantized inference.
+ sync_batch_norm_method: String, method used to sync batch norm. Currently
+ only support `None` (no sync batch norm) and `tpu` (use tpu code to
+ sync batch norm).
+
+ Returns:
+ Decoder output with size [batch, decoder_height, decoder_width,
+ decoder_channels].
+
+ Raises:
+ ValueError: If crop_size is None.
+ """
+ if crop_size is None:
+ raise ValueError('crop_size must be provided when using decoder.')
+ batch_norm_params = utils.get_batch_norm_params(
+ decay=0.9997,
+ epsilon=1e-5,
+ scale=True,
+ is_training=(is_training and fine_tune_batch_norm),
+ sync_batch_norm_method=sync_batch_norm_method)
+ batch_norm = utils.get_batch_norm_fn(sync_batch_norm_method)
+ decoder_depth = decoder_filters
+ projected_filters = 48
+ if decoder_use_sum_merge:
+ # When using sum merge, the projected filters must be equal to decoder
+ # filters.
+ projected_filters = decoder_filters
+ if decoder_output_is_logits:
+ # Overwrite the setting when decoder output is logits.
+ activation_fn = None
+ normalizer_fn = None
+ conv2d_kernel = 1
+ # Use original conv instead of separable conv.
+ decoder_use_separable_conv = False
+ else:
+ # Default setting when decoder output is not logits.
+ activation_fn = tf.nn.relu6 if use_bounded_activation else tf.nn.relu
+ normalizer_fn = batch_norm
+ conv2d_kernel = 3
+ with slim.arg_scope(
+ [slim.conv2d, slim.separable_conv2d],
+ weights_regularizer=slim.l2_regularizer(weight_decay),
+ activation_fn=activation_fn,
+ normalizer_fn=normalizer_fn,
+ padding='SAME',
+ stride=1,
+ reuse=reuse):
+ with slim.arg_scope([batch_norm], **batch_norm_params):
+ with tf.variable_scope(DECODER_SCOPE, DECODER_SCOPE, [features]):
+ decoder_features = features
+ decoder_stage = 0
+ scope_suffix = ''
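+        # At each decoder stage, the current features are fused with
+        # projected low-level backbone features extracted at the given
+        # output stride.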
+ for output_stride in decoder_output_stride:
+ feature_list = feature_extractor.networks_to_feature_maps[
+ model_variant][
+ feature_extractor.DECODER_END_POINTS][output_stride]
+          # If there is only one decoder stage, we do not change the scope
+          # name, in order to maintain backward compatibility.
+ if decoder_stage:
+ scope_suffix = '_{}'.format(decoder_stage)
+ for i, name in enumerate(feature_list):
+ decoder_features_list = [decoder_features]
+ # MobileNet and NAS variants use different naming convention.
+ if ('mobilenet' in model_variant or
+ model_variant.startswith('mnas') or
+ model_variant.startswith('nas')):
+ feature_name = name
+ else:
+ feature_name = '{}/{}'.format(
+ feature_extractor.name_scope[model_variant], name)
+ decoder_features_list.append(
+ slim.conv2d(
+ end_points[feature_name],
+ projected_filters,
+ 1,
+ scope='feature_projection' + str(i) + scope_suffix))
+ # Determine the output size.
+ decoder_height = scale_dimension(crop_size[0], 1.0 / output_stride)
+ decoder_width = scale_dimension(crop_size[1], 1.0 / output_stride)
+ # Resize to decoder_height/decoder_width.
+ for j, feature in enumerate(decoder_features_list):
+ decoder_features_list[j] = _resize_bilinear(
+ feature, [decoder_height, decoder_width], feature.dtype)
+ h = (None if isinstance(decoder_height, tf.Tensor)
+ else decoder_height)
+ w = (None if isinstance(decoder_width, tf.Tensor)
+ else decoder_width)
+ decoder_features_list[j].set_shape([None, h, w, None])
+ if decoder_use_sum_merge:
+ decoder_features = _decoder_with_sum_merge(
+ decoder_features_list,
+ decoder_depth,
+ conv2d_kernel=conv2d_kernel,
+ decoder_use_separable_conv=decoder_use_separable_conv,
+ weight_decay=weight_decay,
+ scope_suffix=scope_suffix)
+ else:
+ if not decoder_use_separable_conv:
+ scope_suffix = str(i) + scope_suffix
+ decoder_features = _decoder_with_concat_merge(
+ decoder_features_list,
+ decoder_depth,
+ decoder_use_separable_conv=decoder_use_separable_conv,
+ weight_decay=weight_decay,
+ scope_suffix=scope_suffix)
+ decoder_stage += 1
+ return decoder_features
+
+
+def _decoder_with_sum_merge(decoder_features_list,
+ decoder_depth,
+ conv2d_kernel=3,
+ decoder_use_separable_conv=True,
+ weight_decay=0.0001,
+ scope_suffix=''):
+ """Decoder with sum to merge features.
+
+ Args:
+ decoder_features_list: A list of decoder features.
+ decoder_depth: Integer, the filters used in the convolution.
+ conv2d_kernel: Integer, the convolution kernel size.
+ decoder_use_separable_conv: Boolean, use separable conv or not.
+ weight_decay: Weight decay for the model variables.
+ scope_suffix: String, used in the scope suffix.
+
+ Returns:
+ decoder features merged with sum.
+
+ Raises:
+ RuntimeError: If decoder_features_list have length not equal to 2.
+ """
+ if len(decoder_features_list) != 2:
+ raise RuntimeError('Expect decoder_features has length 2.')
+  # Only apply one convolution when the decoder uses sum merge.
+ if decoder_use_separable_conv:
+ decoder_features = split_separable_conv2d(
+ decoder_features_list[0],
+ filters=decoder_depth,
+ rate=1,
+ weight_decay=weight_decay,
+ scope='decoder_split_sep_conv0'+scope_suffix) + decoder_features_list[1]
+ else:
+ decoder_features = slim.conv2d(
+ decoder_features_list[0],
+ decoder_depth,
+ conv2d_kernel,
+ scope='decoder_conv0'+scope_suffix) + decoder_features_list[1]
+ return decoder_features
+
+
+def _decoder_with_concat_merge(decoder_features_list,
+ decoder_depth,
+ decoder_use_separable_conv=True,
+ weight_decay=0.0001,
+ scope_suffix=''):
+ """Decoder with concatenation to merge features.
+
+ This decoder method applies two convolutions to smooth the features obtained
+ by concatenating the input decoder_features_list.
+
+ This decoder module is proposed in the DeepLabv3+ paper.
+
+ Args:
+ decoder_features_list: A list of decoder features.
+ decoder_depth: Integer, the filters used in the convolution.
+ decoder_use_separable_conv: Boolean, use separable conv or not.
+ weight_decay: Weight decay for the model variables.
+ scope_suffix: String, used in the scope suffix.
+
+ Returns:
+ decoder features merged with concatenation.
+ """
+ if decoder_use_separable_conv:
+ decoder_features = split_separable_conv2d(
+ tf.concat(decoder_features_list, 3),
+ filters=decoder_depth,
+ rate=1,
+ weight_decay=weight_decay,
+ scope='decoder_conv0'+scope_suffix)
+ decoder_features = split_separable_conv2d(
+ decoder_features,
+ filters=decoder_depth,
+ rate=1,
+ weight_decay=weight_decay,
+ scope='decoder_conv1'+scope_suffix)
+ else:
+ num_convs = 2
+ decoder_features = slim.repeat(
+ tf.concat(decoder_features_list, 3),
+ num_convs,
+ slim.conv2d,
+ decoder_depth,
+ 3,
+ scope='decoder_conv'+scope_suffix)
+ return decoder_features
+
+
+def get_branch_logits(features,
+ num_classes,
+ atrous_rates=None,
+ aspp_with_batch_norm=False,
+ kernel_size=1,
+ weight_decay=0.0001,
+ reuse=None,
+ scope_suffix=''):
+ """Gets the logits from each model's branch.
+
+ The underlying model is branched out in the last layer when atrous
+ spatial pyramid pooling is employed, and all branches are sum-merged
+ to form the final logits.
+
+ Args:
+ features: A float tensor of shape [batch, height, width, channels].
+ num_classes: Number of classes to predict.
+ atrous_rates: A list of atrous convolution rates for last layer.
+ aspp_with_batch_norm: Use batch normalization layers for ASPP.
+ kernel_size: Kernel size for convolution.
+ weight_decay: Weight decay for the model variables.
+ reuse: Reuse model variables or not.
+ scope_suffix: Scope suffix for the model variables.
+
+ Returns:
+ Merged logits with shape [batch, height, width, num_classes].
+
+ Raises:
+ ValueError: Upon invalid input kernel_size value.
+ """
+ # When using batch normalization with ASPP, ASPP has been applied before
+ # in extract_features, and thus we simply apply 1x1 convolution here.
+ if aspp_with_batch_norm or atrous_rates is None:
+ if kernel_size != 1:
+ raise ValueError('Kernel size must be 1 when atrous_rates is None or '
+ 'using aspp_with_batch_norm. Gets %d.' % kernel_size)
+ atrous_rates = [1]
+
+ with slim.arg_scope(
+ [slim.conv2d],
+ weights_regularizer=slim.l2_regularizer(weight_decay),
+ weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
+ reuse=reuse):
+ with tf.variable_scope(LOGITS_SCOPE_NAME, LOGITS_SCOPE_NAME, [features]):
+ branch_logits = []
+ for i, rate in enumerate(atrous_rates):
+ scope = scope_suffix
+ if i:
+ scope += '_%d' % i
+
+ branch_logits.append(
+ slim.conv2d(
+ features,
+ num_classes,
+ kernel_size=kernel_size,
+ rate=rate,
+ activation_fn=None,
+ normalizer_fn=None,
+ scope=scope))
+
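+      # Sum-merge the logits from all atrous branches into the final logits.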
+ return tf.add_n(branch_logits)
diff --git a/models/research/deeplab/model_test.py b/models/research/deeplab/model_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8413d7395d022adb4f43223eb06a4bdc1aa53db
--- /dev/null
+++ b/models/research/deeplab/model_test.py
@@ -0,0 +1,148 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for DeepLab model and some helper functions."""
+
+import tensorflow as tf
+
+from deeplab import common
+from deeplab import model
+
+
+class DeeplabModelTest(tf.test.TestCase):
+
+ def testWrongDeepLabVariant(self):
+ model_options = common.ModelOptions([])._replace(
+ model_variant='no_such_variant')
+ with self.assertRaises(ValueError):
+ model._get_logits(images=[], model_options=model_options)
+
+ def testBuildDeepLabv2(self):
+ batch_size = 2
+ crop_size = [41, 41]
+
+ # Test with two image_pyramids.
+ image_pyramids = [[1], [0.5, 1]]
+
+ # Test two model variants.
+ model_variants = ['xception_65', 'mobilenet_v2']
+
+ # Test with two output_types.
+ outputs_to_num_classes = {'semantic': 3,
+ 'direction': 2}
+
+ expected_endpoints = [['merged_logits'],
+ ['merged_logits',
+ 'logits_0.50',
+ 'logits_1.00']]
+ expected_num_logits = [1, 3]
+
+ for model_variant in model_variants:
+ model_options = common.ModelOptions(outputs_to_num_classes)._replace(
+ add_image_level_feature=False,
+ aspp_with_batch_norm=False,
+ aspp_with_separable_conv=False,
+ model_variant=model_variant)
+
+ for i, image_pyramid in enumerate(image_pyramids):
+ g = tf.Graph()
+ with g.as_default():
+ with self.test_session(graph=g):
+ inputs = tf.random_uniform(
+ (batch_size, crop_size[0], crop_size[1], 3))
+ outputs_to_scales_to_logits = model.multi_scale_logits(
+ inputs, model_options, image_pyramid=image_pyramid)
+
+ # Check computed results for each output type.
+ for output in outputs_to_num_classes:
+ scales_to_logits = outputs_to_scales_to_logits[output]
+ self.assertListEqual(sorted(scales_to_logits.keys()),
+ sorted(expected_endpoints[i]))
+
+ # Expected number of logits = len(image_pyramid) + 1, since the
+ # last logits is merged from all the scales.
+ self.assertEqual(len(scales_to_logits), expected_num_logits[i])
+
+ def testForwardpassDeepLabv3plus(self):
+ crop_size = [33, 33]
+ outputs_to_num_classes = {'semantic': 3}
+
+ model_options = common.ModelOptions(
+ outputs_to_num_classes,
+ crop_size,
+ output_stride=16
+ )._replace(
+ add_image_level_feature=True,
+ aspp_with_batch_norm=True,
+ logits_kernel_size=1,
+ decoder_output_stride=[4],
+ model_variant='mobilenet_v2') # Employ MobileNetv2 for fast test.
+
+ g = tf.Graph()
+ with g.as_default():
+ with self.test_session(graph=g) as sess:
+ inputs = tf.random_uniform(
+ (1, crop_size[0], crop_size[1], 3))
+ outputs_to_scales_to_logits = model.multi_scale_logits(
+ inputs,
+ model_options,
+ image_pyramid=[1.0])
+
+ sess.run(tf.global_variables_initializer())
+ outputs_to_scales_to_logits = sess.run(outputs_to_scales_to_logits)
+
+ # Check computed results for each output type.
+ for output in outputs_to_num_classes:
+ scales_to_logits = outputs_to_scales_to_logits[output]
+ # Expect only one output.
+ self.assertEqual(len(scales_to_logits), 1)
+ for logits in scales_to_logits.values():
+ self.assertTrue(logits.any())
+
+ def testBuildDeepLabWithDensePredictionCell(self):
+ batch_size = 1
+ crop_size = [33, 33]
+ outputs_to_num_classes = {'semantic': 2}
+ expected_endpoints = ['merged_logits']
+ dense_prediction_cell_config = [
+ {'kernel': 3, 'rate': [1, 6], 'op': 'conv', 'input': -1},
+ {'kernel': 3, 'rate': [18, 15], 'op': 'conv', 'input': 0},
+ ]
+ model_options = common.ModelOptions(
+ outputs_to_num_classes,
+ crop_size,
+ output_stride=16)._replace(
+ aspp_with_batch_norm=True,
+ model_variant='mobilenet_v2',
+ dense_prediction_cell_config=dense_prediction_cell_config)
+ g = tf.Graph()
+ with g.as_default():
+ with self.test_session(graph=g):
+ inputs = tf.random_uniform(
+ (batch_size, crop_size[0], crop_size[1], 3))
+ outputs_to_scales_to_model_results = model.multi_scale_logits(
+ inputs,
+ model_options,
+ image_pyramid=[1.0])
+ for output in outputs_to_num_classes:
+ scales_to_model_results = outputs_to_scales_to_model_results[output]
+ self.assertListEqual(
+ list(scales_to_model_results), expected_endpoints)
+ self.assertEqual(len(scales_to_model_results), 1)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/deeplab/testing/info.md b/models/research/deeplab/testing/info.md
new file mode 100644
index 0000000000000000000000000000000000000000..b84d2adb1c5088ed2a6ec4799de7764d64f867b7
--- /dev/null
+++ b/models/research/deeplab/testing/info.md
@@ -0,0 +1,6 @@
+This directory contains testing data.
+
+# pascal_voc_seg
+This folder contains data specific to the pascal_voc_seg dataset. val-00000-of-00001.tfrecord contains
+three randomly generated images in the format defined by
+tensorflow/models/research/deeplab/datasets/build_voc2012_data.py.
diff --git a/models/research/deeplab/testing/pascal_voc_seg/val-00000-of-00001.tfrecord b/models/research/deeplab/testing/pascal_voc_seg/val-00000-of-00001.tfrecord
new file mode 100644
index 0000000000000000000000000000000000000000..de9dee50f7973c52305b2692f00a5d6f396f9fbe
--- /dev/null
+++ b/models/research/deeplab/testing/pascal_voc_seg/val-00000-of-00001.tfrecord
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:88314133dce131cdb6a93f37ec2e96c3efdb2f9a111defae284d1530fee3207a
+size 1137674
diff --git a/models/research/deeplab/train.py b/models/research/deeplab/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbe060dccd41793e3e843f4fcbe155576e42eb14
--- /dev/null
+++ b/models/research/deeplab/train.py
@@ -0,0 +1,464 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Training script for the DeepLab model.
+
+See model.py for more details and usage.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import six
+import tensorflow as tf
+from tensorflow.contrib import quantize as contrib_quantize
+from tensorflow.contrib import tfprof as contrib_tfprof
+from deeplab import common
+from deeplab import model
+from deeplab.datasets import data_generator
+from deeplab.utils import train_utils
+from deployment import model_deploy
+
+slim = tf.contrib.slim
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+# Settings for multi-GPUs/multi-replicas training.
+
+flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy.')
+
+flags.DEFINE_boolean('clone_on_cpu', False, 'Use CPUs to deploy clones.')
+
+flags.DEFINE_integer('num_replicas', 1, 'Number of worker replicas.')
+
+flags.DEFINE_integer('startup_delay_steps', 15,
+ 'Number of training steps between replicas startup.')
+
+flags.DEFINE_integer(
+ 'num_ps_tasks', 0,
+ 'The number of parameter servers. If the value is 0, then '
+ 'the parameters are handled locally by the worker.')
+
+flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')
+
+flags.DEFINE_integer('task', 0, 'The task ID.')
+
+# Settings for logging.
+
+flags.DEFINE_string('train_logdir', None,
+ 'Where the checkpoint and logs are stored.')
+
+flags.DEFINE_integer('log_steps', 10,
+ 'Display logging information at every log_steps.')
+
+flags.DEFINE_integer('save_interval_secs', 1200,
+ 'How often, in seconds, we save the model to disk.')
+
+flags.DEFINE_integer('save_summaries_secs', 600,
+ 'How often, in seconds, we compute the summaries.')
+
+flags.DEFINE_boolean(
+ 'save_summaries_images', False,
+ 'Save sample inputs, labels, and semantic predictions as '
+ 'images to summary.')
+
+# Settings for profiling.
+
+flags.DEFINE_string('profile_logdir', None,
+ 'Where the profile files are stored.')
+
+# Settings for training strategy.
+
+flags.DEFINE_enum('optimizer', 'momentum', ['momentum', 'adam'],
+ 'Which optimizer to use.')
+
+
+# Momentum optimizer flags
+
+flags.DEFINE_enum('learning_policy', 'poly', ['poly', 'step'],
+ 'Learning rate policy for training.')
+
+# Use 0.007 when training on PASCAL augmented training set, train_aug. When
+# fine-tuning on PASCAL trainval set, use learning rate=0.0001.
+flags.DEFINE_float('base_learning_rate', .0001,
+ 'The base learning rate for model training.')
+
+flags.DEFINE_float('decay_steps', 0.0,
+ 'Decay steps for polynomial learning rate schedule.')
+
+flags.DEFINE_float('end_learning_rate', 0.0,
+ 'End learning rate for polynomial learning rate schedule.')
+
+flags.DEFINE_float('learning_rate_decay_factor', 0.1,
+ 'The rate to decay the base learning rate.')
+
+flags.DEFINE_integer('learning_rate_decay_step', 2000,
+ 'Decay the base learning rate at a fixed step.')
+
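+# Under the 'poly' policy the learning rate decays roughly as
+#   (base_learning_rate - end_learning_rate)
+#       * (1 - global_step / decay_steps) ** learning_power
+#       + end_learning_rate;
+# see train_utils.get_model_learning_rate for the exact schedule.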
+flags.DEFINE_float('learning_power', 0.9,
+ 'The power value used in the poly learning policy.')
+
+flags.DEFINE_integer('training_number_of_steps', 30000,
+ 'The number of steps used for training')
+
+flags.DEFINE_float('momentum', 0.9, 'The momentum value to use')
+
+# Adam optimizer flags
+flags.DEFINE_float('adam_learning_rate', 0.001,
+ 'Learning rate for the adam optimizer.')
+flags.DEFINE_float('adam_epsilon', 1e-08, 'Adam optimizer epsilon.')
+
+# When fine_tune_batch_norm=True, use at least batch size larger than 12
+# (batch size more than 16 is better). Otherwise, one could use smaller batch
+# size and set fine_tune_batch_norm=False.
+flags.DEFINE_integer('train_batch_size', 8,
+ 'The number of images in each batch during training.')
+
+# For weight_decay, use 0.00004 for MobileNet-V2 or Xception model variants.
+# Use 0.0001 for ResNet model variants.
+flags.DEFINE_float('weight_decay', 0.00004,
+ 'The value of the weight decay for training.')
+
+flags.DEFINE_list('train_crop_size', '513,513',
+ 'Image crop size [height, width] during training.')
+
+flags.DEFINE_float(
+ 'last_layer_gradient_multiplier', 1.0,
+ 'The gradient multiplier for last layers, which is used to '
+ 'boost the gradient of last layers if the value > 1.')
+
+flags.DEFINE_boolean('upsample_logits', True,
+ 'Upsample logits during training.')
+
+# Hyper-parameters for NAS training strategy.
+
+flags.DEFINE_float(
+ 'drop_path_keep_prob', 1.0,
+ 'Probability to keep each path in the NAS cell when training.')
+
+# Settings for fine-tuning the network.
+
+flags.DEFINE_string('tf_initial_checkpoint', None,
+ 'The initial checkpoint in tensorflow format.')
+
+# Set to False if one does not want to re-use the trained classifier weights.
+flags.DEFINE_boolean('initialize_last_layer', True,
+ 'Initialize the last layer.')
+
+flags.DEFINE_boolean('last_layers_contain_logits_only', False,
+ 'Only consider logits as last layers or not.')
+
+flags.DEFINE_integer('slow_start_step', 0,
+                     'Train the model with a small learning rate for a few '
+                     'initial steps.')
+
+flags.DEFINE_float('slow_start_learning_rate', 1e-4,
+ 'Learning rate employed during slow start.')
+
+# Set to True if one wants to fine-tune the batch norm parameters in DeepLabv3.
+# Set to False and use small batch size to save GPU memory.
+flags.DEFINE_boolean('fine_tune_batch_norm', True,
+ 'Fine tune the batch norm parameters or not.')
+
+flags.DEFINE_float('min_scale_factor', 0.5,
+                   'Minimum scale factor for data augmentation.')
+
+flags.DEFINE_float('max_scale_factor', 2.,
+ 'Maximum scale factor for data augmentation.')
+
+flags.DEFINE_float('scale_factor_step_size', 0.25,
+ 'Scale factor step size for data augmentation.')
+
+# For `xception_65`, use atrous_rates = [12, 24, 36] if output_stride = 8, or
+# rates = [6, 12, 18] if output_stride = 16. For `mobilenet_v2`, use None. Note
+# one could use different atrous_rates/output_stride during training/evaluation.
+flags.DEFINE_multi_integer('atrous_rates', None,
+ 'Atrous rates for atrous spatial pyramid pooling.')
+
+flags.DEFINE_integer('output_stride', 16,
+ 'The ratio of input to output spatial resolution.')
+
+# Hard example mining related flags.
+flags.DEFINE_integer(
+ 'hard_example_mining_step', 0,
+ 'The training step in which exact hard example mining kicks off. Note we '
+ 'gradually reduce the mining percent to the specified '
+ 'top_k_percent_pixels. For example, if hard_example_mining_step=100K and '
+ 'top_k_percent_pixels=0.25, then mining percent will gradually reduce from '
+ '100% to 25% until 100K steps after which we only mine top 25% pixels.')
+
+flags.DEFINE_float(
+ 'top_k_percent_pixels', 1.0,
+ 'The top k percent pixels (in terms of the loss values) used to compute '
+ 'loss during training. This is useful for hard pixel mining.')
+
+# Quantization setting.
+flags.DEFINE_integer(
+ 'quantize_delay_step', -1,
+ 'Steps to start quantized training. If < 0, will not quantize model.')
+
+# Dataset settings.
+flags.DEFINE_string('dataset', 'pascal_voc_seg',
+ 'Name of the segmentation dataset.')
+
+flags.DEFINE_string('train_split', 'train',
+ 'Which split of the dataset to be used for training')
+
+flags.DEFINE_string('dataset_dir', None, 'Where the dataset resides.')
+
+
+def _build_deeplab(iterator, outputs_to_num_classes, ignore_label):
+ """Builds a clone of DeepLab.
+
+ Args:
+ iterator: An iterator of type tf.data.Iterator for images and labels.
+ outputs_to_num_classes: A map from output type to the number of classes. For
+ example, for the task of semantic segmentation with 21 semantic classes,
+ we would have outputs_to_num_classes['semantic'] = 21.
+ ignore_label: Ignore label.
+ """
+ samples = iterator.get_next()
+
+ # Add name to input and label nodes so we can add to summary.
+ samples[common.IMAGE] = tf.identity(samples[common.IMAGE], name=common.IMAGE)
+ samples[common.LABEL] = tf.identity(samples[common.LABEL], name=common.LABEL)
+
+ model_options = common.ModelOptions(
+ outputs_to_num_classes=outputs_to_num_classes,
+ crop_size=[int(sz) for sz in FLAGS.train_crop_size],
+ atrous_rates=FLAGS.atrous_rates,
+ output_stride=FLAGS.output_stride)
+
+ outputs_to_scales_to_logits = model.multi_scale_logits(
+ samples[common.IMAGE],
+ model_options=model_options,
+ image_pyramid=FLAGS.image_pyramid,
+ weight_decay=FLAGS.weight_decay,
+ is_training=True,
+ fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
+ nas_training_hyper_parameters={
+ 'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
+ 'total_training_steps': FLAGS.training_number_of_steps,
+ })
+
+ # Add name to graph node so we can add to summary.
+ output_type_dict = outputs_to_scales_to_logits[common.OUTPUT_TYPE]
+ output_type_dict[model.MERGED_LOGITS_SCOPE] = tf.identity(
+ output_type_dict[model.MERGED_LOGITS_SCOPE], name=common.OUTPUT_TYPE)
+
+ for output, num_classes in six.iteritems(outputs_to_num_classes):
+ train_utils.add_softmax_cross_entropy_loss_for_each_scale(
+ outputs_to_scales_to_logits[output],
+ samples[common.LABEL],
+ num_classes,
+ ignore_label,
+ loss_weight=model_options.label_weights,
+ upsample_logits=FLAGS.upsample_logits,
+ hard_example_mining_step=FLAGS.hard_example_mining_step,
+ top_k_percent_pixels=FLAGS.top_k_percent_pixels,
+ scope=output)
+
+
+def main(unused_argv):
+ tf.logging.set_verbosity(tf.logging.INFO)
+ # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
+ config = model_deploy.DeploymentConfig(
+ num_clones=FLAGS.num_clones,
+ clone_on_cpu=FLAGS.clone_on_cpu,
+ replica_id=FLAGS.task,
+ num_replicas=FLAGS.num_replicas,
+ num_ps_tasks=FLAGS.num_ps_tasks)
+
+ # Split the batch across GPUs.
+ assert FLAGS.train_batch_size % config.num_clones == 0, (
+      'Training batch size not divisible by number of clones (GPUs).')
+
+ clone_batch_size = FLAGS.train_batch_size // config.num_clones
+
+ tf.gfile.MakeDirs(FLAGS.train_logdir)
+ tf.logging.info('Training on %s set', FLAGS.train_split)
+
+ with tf.Graph().as_default() as graph:
+ with tf.device(config.inputs_device()):
+ dataset = data_generator.Dataset(
+ dataset_name=FLAGS.dataset,
+ split_name=FLAGS.train_split,
+ dataset_dir=FLAGS.dataset_dir,
+ batch_size=clone_batch_size,
+ crop_size=[int(sz) for sz in FLAGS.train_crop_size],
+ min_resize_value=FLAGS.min_resize_value,
+ max_resize_value=FLAGS.max_resize_value,
+ resize_factor=FLAGS.resize_factor,
+ min_scale_factor=FLAGS.min_scale_factor,
+ max_scale_factor=FLAGS.max_scale_factor,
+ scale_factor_step_size=FLAGS.scale_factor_step_size,
+ model_variant=FLAGS.model_variant,
+ num_readers=4,
+ is_training=True,
+ should_shuffle=True,
+ should_repeat=True)
+
+ # Create the global step on the device storing the variables.
+ with tf.device(config.variables_device()):
+ global_step = tf.train.get_or_create_global_step()
+
+ # Define the model and create clones.
+ model_fn = _build_deeplab
+ model_args = (dataset.get_one_shot_iterator(), {
+ common.OUTPUT_TYPE: dataset.num_of_classes
+ }, dataset.ignore_label)
+ clones = model_deploy.create_clones(config, model_fn, args=model_args)
+
+ # Gather update_ops from the first clone. These contain, for example,
+ # the updates for the batch_norm variables created by model_fn.
+ first_clone_scope = config.clone_scope(0)
+ update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
+
+ # Gather initial summaries.
+ summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
+
+ # Add summaries for model variables.
+ for model_var in tf.model_variables():
+ summaries.add(tf.summary.histogram(model_var.op.name, model_var))
+
+ # Add summaries for images, labels, semantic predictions
+ if FLAGS.save_summaries_images:
+ summary_image = graph.get_tensor_by_name(
+ ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
+ summaries.add(
+ tf.summary.image('samples/%s' % common.IMAGE, summary_image))
+
+ first_clone_label = graph.get_tensor_by_name(
+ ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
+ # Scale up summary image pixel values for better visualization.
+ pixel_scaling = max(1, 255 // dataset.num_of_classes)
+ summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
+ summaries.add(
+ tf.summary.image('samples/%s' % common.LABEL, summary_label))
+
+ first_clone_output = graph.get_tensor_by_name(
+ ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
+ predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)
+
+ summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
+ summaries.add(
+ tf.summary.image(
+ 'samples/%s' % common.OUTPUT_TYPE, summary_predictions))
+
+ # Add summaries for losses.
+ for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
+ summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))
+
+ # Build the optimizer based on the device specification.
+ with tf.device(config.optimizer_device()):
+ learning_rate = train_utils.get_model_learning_rate(
+ FLAGS.learning_policy,
+ FLAGS.base_learning_rate,
+ FLAGS.learning_rate_decay_step,
+ FLAGS.learning_rate_decay_factor,
+ FLAGS.training_number_of_steps,
+ FLAGS.learning_power,
+ FLAGS.slow_start_step,
+ FLAGS.slow_start_learning_rate,
+ decay_steps=FLAGS.decay_steps,
+ end_learning_rate=FLAGS.end_learning_rate)
+
+ summaries.add(tf.summary.scalar('learning_rate', learning_rate))
+
+ if FLAGS.optimizer == 'momentum':
+ optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
+ elif FLAGS.optimizer == 'adam':
+ optimizer = tf.train.AdamOptimizer(
+ learning_rate=FLAGS.adam_learning_rate, epsilon=FLAGS.adam_epsilon)
+ else:
+ raise ValueError('Unknown optimizer')
+
+ if FLAGS.quantize_delay_step >= 0:
+ if FLAGS.num_clones > 1:
+ raise ValueError('Quantization doesn\'t support multi-clone yet.')
+ contrib_quantize.create_training_graph(
+ quant_delay=FLAGS.quantize_delay_step)
+
+ startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
+
+ with tf.device(config.variables_device()):
+ total_loss, grads_and_vars = model_deploy.optimize_clones(
+ clones, optimizer)
+ total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
+ summaries.add(tf.summary.scalar('total_loss', total_loss))
+
+ # Modify the gradients for biases and last layer variables.
+ last_layers = model.get_extra_layer_scopes(
+ FLAGS.last_layers_contain_logits_only)
+ grad_mult = train_utils.get_model_gradient_multipliers(
+ last_layers, FLAGS.last_layer_gradient_multiplier)
+ if grad_mult:
+ grads_and_vars = slim.learning.multiply_gradients(
+ grads_and_vars, grad_mult)
+
+ # Create gradient update op.
+ grad_updates = optimizer.apply_gradients(
+ grads_and_vars, global_step=global_step)
+ update_ops.append(grad_updates)
+ update_op = tf.group(*update_ops)
+ with tf.control_dependencies([update_op]):
+ train_tensor = tf.identity(total_loss, name='train_op')
+
+ # Add the summaries from the first clone. These contain the summaries
+ # created by model_fn and either optimize_clones() or _gather_clone_loss().
+ summaries |= set(
+ tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
+
+ # Merge all summaries together.
+ summary_op = tf.summary.merge(list(summaries))
+
+    # Soft placement allows ops that have no GPU implementation to be placed
+    # on the CPU.
+ session_config = tf.ConfigProto(
+ allow_soft_placement=True, log_device_placement=False)
+
+ # Start the training.
+ profile_dir = FLAGS.profile_logdir
+ if profile_dir is not None:
+ tf.gfile.MakeDirs(profile_dir)
+
+ with contrib_tfprof.ProfileContext(
+ enabled=profile_dir is not None, profile_dir=profile_dir):
+ init_fn = None
+ if FLAGS.tf_initial_checkpoint:
+ init_fn = train_utils.get_model_init_fn(
+ FLAGS.train_logdir,
+ FLAGS.tf_initial_checkpoint,
+ FLAGS.initialize_last_layer,
+ last_layers,
+ ignore_missing_vars=True)
+
+ slim.learning.train(
+ train_tensor,
+ logdir=FLAGS.train_logdir,
+ log_every_n_steps=FLAGS.log_steps,
+ master=FLAGS.master,
+ number_of_steps=FLAGS.training_number_of_steps,
+ is_chief=(FLAGS.task == 0),
+ session_config=session_config,
+ startup_delay_steps=startup_delay_steps,
+ init_fn=init_fn,
+ summary_op=summary_op,
+ save_summaries_secs=FLAGS.save_summaries_secs,
+ save_interval_secs=FLAGS.save_interval_secs)
+
+
+if __name__ == '__main__':
+ flags.mark_flag_as_required('train_logdir')
+ flags.mark_flag_as_required('dataset_dir')
+ tf.app.run()
diff --git a/models/research/deeplab/utils/__init__.py b/models/research/deeplab/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/deeplab/utils/get_dataset_colormap.py b/models/research/deeplab/utils/get_dataset_colormap.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0502e3b3cdd4ee065701e5ee8d94d7f3701c576
--- /dev/null
+++ b/models/research/deeplab/utils/get_dataset_colormap.py
@@ -0,0 +1,416 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Visualizes the segmentation results via specified color map.
+
+Visualizes the semantic segmentation results by the color map
+defined by the different datasets. Supported colormaps are:
+
+* ADE20K (http://groups.csail.mit.edu/vision/datasets/ADE20K/).
+
+* Cityscapes dataset (https://www.cityscapes-dataset.com).
+
+* Mapillary Vistas (https://research.mapillary.com).
+
+* PASCAL VOC 2012 (http://host.robots.ox.ac.uk/pascal/VOC/).
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+from six.moves import range
+
+# Dataset names.
+_ADE20K = 'ade20k'
+_CITYSCAPES = 'cityscapes'
+_MAPILLARY_VISTAS = 'mapillary_vistas'
+_PASCAL = 'pascal'
+
+# Max number of entries in the colormap for each dataset.
+_DATASET_MAX_ENTRIES = {
+ _ADE20K: 151,
+ _CITYSCAPES: 256,
+ _MAPILLARY_VISTAS: 66,
+ _PASCAL: 512,
+}
+
+
+def create_ade20k_label_colormap():
+ """Creates a label colormap used in ADE20K segmentation benchmark.
+
+ Returns:
+ A colormap for visualizing segmentation results.
+ """
+ return np.asarray([
+ [0, 0, 0],
+ [120, 120, 120],
+ [180, 120, 120],
+ [6, 230, 230],
+ [80, 50, 50],
+ [4, 200, 3],
+ [120, 120, 80],
+ [140, 140, 140],
+ [204, 5, 255],
+ [230, 230, 230],
+ [4, 250, 7],
+ [224, 5, 255],
+ [235, 255, 7],
+ [150, 5, 61],
+ [120, 120, 70],
+ [8, 255, 51],
+ [255, 6, 82],
+ [143, 255, 140],
+ [204, 255, 4],
+ [255, 51, 7],
+ [204, 70, 3],
+ [0, 102, 200],
+ [61, 230, 250],
+ [255, 6, 51],
+ [11, 102, 255],
+ [255, 7, 71],
+ [255, 9, 224],
+ [9, 7, 230],
+ [220, 220, 220],
+ [255, 9, 92],
+ [112, 9, 255],
+ [8, 255, 214],
+ [7, 255, 224],
+ [255, 184, 6],
+ [10, 255, 71],
+ [255, 41, 10],
+ [7, 255, 255],
+ [224, 255, 8],
+ [102, 8, 255],
+ [255, 61, 6],
+ [255, 194, 7],
+ [255, 122, 8],
+ [0, 255, 20],
+ [255, 8, 41],
+ [255, 5, 153],
+ [6, 51, 255],
+ [235, 12, 255],
+ [160, 150, 20],
+ [0, 163, 255],
+ [140, 140, 140],
+ [250, 10, 15],
+ [20, 255, 0],
+ [31, 255, 0],
+ [255, 31, 0],
+ [255, 224, 0],
+ [153, 255, 0],
+ [0, 0, 255],
+ [255, 71, 0],
+ [0, 235, 255],
+ [0, 173, 255],
+ [31, 0, 255],
+ [11, 200, 200],
+ [255, 82, 0],
+ [0, 255, 245],
+ [0, 61, 255],
+ [0, 255, 112],
+ [0, 255, 133],
+ [255, 0, 0],
+ [255, 163, 0],
+ [255, 102, 0],
+ [194, 255, 0],
+ [0, 143, 255],
+ [51, 255, 0],
+ [0, 82, 255],
+ [0, 255, 41],
+ [0, 255, 173],
+ [10, 0, 255],
+ [173, 255, 0],
+ [0, 255, 153],
+ [255, 92, 0],
+ [255, 0, 255],
+ [255, 0, 245],
+ [255, 0, 102],
+ [255, 173, 0],
+ [255, 0, 20],
+ [255, 184, 184],
+ [0, 31, 255],
+ [0, 255, 61],
+ [0, 71, 255],
+ [255, 0, 204],
+ [0, 255, 194],
+ [0, 255, 82],
+ [0, 10, 255],
+ [0, 112, 255],
+ [51, 0, 255],
+ [0, 194, 255],
+ [0, 122, 255],
+ [0, 255, 163],
+ [255, 153, 0],
+ [0, 255, 10],
+ [255, 112, 0],
+ [143, 255, 0],
+ [82, 0, 255],
+ [163, 255, 0],
+ [255, 235, 0],
+ [8, 184, 170],
+ [133, 0, 255],
+ [0, 255, 92],
+ [184, 0, 255],
+ [255, 0, 31],
+ [0, 184, 255],
+ [0, 214, 255],
+ [255, 0, 112],
+ [92, 255, 0],
+ [0, 224, 255],
+ [112, 224, 255],
+ [70, 184, 160],
+ [163, 0, 255],
+ [153, 0, 255],
+ [71, 255, 0],
+ [255, 0, 163],
+ [255, 204, 0],
+ [255, 0, 143],
+ [0, 255, 235],
+ [133, 255, 0],
+ [255, 0, 235],
+ [245, 0, 255],
+ [255, 0, 122],
+ [255, 245, 0],
+ [10, 190, 212],
+ [214, 255, 0],
+ [0, 204, 255],
+ [20, 0, 255],
+ [255, 255, 0],
+ [0, 153, 255],
+ [0, 41, 255],
+ [0, 255, 204],
+ [41, 0, 255],
+ [41, 255, 0],
+ [173, 0, 255],
+ [0, 245, 255],
+ [71, 0, 255],
+ [122, 0, 255],
+ [0, 255, 184],
+ [0, 92, 255],
+ [184, 255, 0],
+ [0, 133, 255],
+ [255, 214, 0],
+ [25, 194, 194],
+ [102, 255, 0],
+ [92, 0, 255],
+ ])
+
+
+def create_cityscapes_label_colormap():
+ """Creates a label colormap used in CITYSCAPES segmentation benchmark.
+
+ Returns:
+ A colormap for visualizing segmentation results.
+ """
+ colormap = np.zeros((256, 3), dtype=np.uint8)
+ colormap[0] = [128, 64, 128]
+ colormap[1] = [244, 35, 232]
+ colormap[2] = [70, 70, 70]
+ colormap[3] = [102, 102, 156]
+ colormap[4] = [190, 153, 153]
+ colormap[5] = [153, 153, 153]
+ colormap[6] = [250, 170, 30]
+ colormap[7] = [220, 220, 0]
+ colormap[8] = [107, 142, 35]
+ colormap[9] = [152, 251, 152]
+ colormap[10] = [70, 130, 180]
+ colormap[11] = [220, 20, 60]
+ colormap[12] = [255, 0, 0]
+ colormap[13] = [0, 0, 142]
+ colormap[14] = [0, 0, 70]
+ colormap[15] = [0, 60, 100]
+ colormap[16] = [0, 80, 100]
+ colormap[17] = [0, 0, 230]
+ colormap[18] = [119, 11, 32]
+ return colormap
+
+
+def create_mapillary_vistas_label_colormap():
+ """Creates a label colormap used in Mapillary Vistas segmentation benchmark.
+
+ Returns:
+ A colormap for visualizing segmentation results.
+ """
+ return np.asarray([
+ [165, 42, 42],
+ [0, 192, 0],
+ [196, 196, 196],
+ [190, 153, 153],
+ [180, 165, 180],
+ [102, 102, 156],
+ [102, 102, 156],
+ [128, 64, 255],
+ [140, 140, 200],
+ [170, 170, 170],
+ [250, 170, 160],
+ [96, 96, 96],
+ [230, 150, 140],
+ [128, 64, 128],
+ [110, 110, 110],
+ [244, 35, 232],
+ [150, 100, 100],
+ [70, 70, 70],
+ [150, 120, 90],
+ [220, 20, 60],
+ [255, 0, 0],
+ [255, 0, 0],
+ [255, 0, 0],
+ [200, 128, 128],
+ [255, 255, 255],
+ [64, 170, 64],
+ [128, 64, 64],
+ [70, 130, 180],
+ [255, 255, 255],
+ [152, 251, 152],
+ [107, 142, 35],
+ [0, 170, 30],
+ [255, 255, 128],
+ [250, 0, 30],
+ [0, 0, 0],
+ [220, 220, 220],
+ [170, 170, 170],
+ [222, 40, 40],
+ [100, 170, 30],
+ [40, 40, 40],
+ [33, 33, 33],
+ [170, 170, 170],
+ [0, 0, 142],
+ [170, 170, 170],
+ [210, 170, 100],
+ [153, 153, 153],
+ [128, 128, 128],
+ [0, 0, 142],
+ [250, 170, 30],
+ [192, 192, 192],
+ [220, 220, 0],
+ [180, 165, 180],
+ [119, 11, 32],
+ [0, 0, 142],
+ [0, 60, 100],
+ [0, 0, 142],
+ [0, 0, 90],
+ [0, 0, 230],
+ [0, 80, 100],
+ [128, 64, 64],
+ [0, 0, 110],
+ [0, 0, 70],
+ [0, 0, 192],
+ [32, 32, 32],
+ [0, 0, 0],
+ [0, 0, 0],
+ ])
+
+
+def create_pascal_label_colormap():
+ """Creates a label colormap used in PASCAL VOC segmentation benchmark.
+
+ Returns:
+ A colormap for visualizing segmentation results.
+ """
+ colormap = np.zeros((_DATASET_MAX_ENTRIES[_PASCAL], 3), dtype=int)
+ ind = np.arange(_DATASET_MAX_ENTRIES[_PASCAL], dtype=int)
+
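+  # Each 3-bit group of the label index contributes one bit to each of the
+  # R, G, B channels, filled in from the most significant bit downwards;
+  # for example, label 5 = 0b101 maps to the color [128, 0, 128].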
+ for shift in reversed(list(range(8))):
+ for channel in range(3):
+ colormap[:, channel] |= bit_get(ind, channel) << shift
+ ind >>= 3
+
+ return colormap
+
+
+def get_ade20k_name():
+ return _ADE20K
+
+
+def get_cityscapes_name():
+ return _CITYSCAPES
+
+
+def get_mapillary_vistas_name():
+ return _MAPILLARY_VISTAS
+
+
+def get_pascal_name():
+ return _PASCAL
+
+
+def bit_get(val, idx):
+ """Gets the bit value.
+
+ Args:
+ val: Input value, int or numpy int array.
+ idx: Which bit of the input val.
+
+ Returns:
+ The "idx"-th bit of input val.
+ """
+ return (val >> idx) & 1
+
+
+def create_label_colormap(dataset=_PASCAL):
+ """Creates a label colormap for the specified dataset.
+
+ Args:
+ dataset: The colormap used in the dataset.
+
+ Returns:
+ A numpy array of the dataset colormap.
+
+ Raises:
+ ValueError: If the dataset is not supported.
+ """
+ if dataset == _ADE20K:
+ return create_ade20k_label_colormap()
+ elif dataset == _CITYSCAPES:
+ return create_cityscapes_label_colormap()
+ elif dataset == _MAPILLARY_VISTAS:
+ return create_mapillary_vistas_label_colormap()
+ elif dataset == _PASCAL:
+ return create_pascal_label_colormap()
+ else:
+ raise ValueError('Unsupported dataset.')
+
+
+def label_to_color_image(label, dataset=_PASCAL):
+ """Adds color defined by the dataset colormap to the label.
+
+ Args:
+ label: A 2D array with integer type, storing the segmentation label.
+ dataset: The colormap used in the dataset.
+
+ Returns:
+    result: An array of shape [height, width, 3], where each element is the
+    color from the dataset colormap indexed by the corresponding element in
+    the input label.
+
+ Raises:
+ ValueError: If label is not of rank 2 or its value is larger than color
+ map maximum entry.
+ """
+ if label.ndim != 2:
+ raise ValueError('Expect 2-D input label. Got {}'.format(label.shape))
+
+ if np.max(label) >= _DATASET_MAX_ENTRIES[dataset]:
+ raise ValueError(
+ 'label value too large: {} >= {}.'.format(
+ np.max(label), _DATASET_MAX_ENTRIES[dataset]))
+
+ colormap = create_label_colormap(dataset)
+ return colormap[label]
+
+
+def get_dataset_colormap_max_entries(dataset):
+ return _DATASET_MAX_ENTRIES[dataset]
diff --git a/models/research/deeplab/utils/get_dataset_colormap_test.py b/models/research/deeplab/utils/get_dataset_colormap_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..89adb2c7391ce087100558fcf256acb1ca45638b
--- /dev/null
+++ b/models/research/deeplab/utils/get_dataset_colormap_test.py
@@ -0,0 +1,97 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for get_dataset_colormap.py."""
+
+import numpy as np
+import tensorflow as tf
+
+from deeplab.utils import get_dataset_colormap
+
+
+class VisualizationUtilTest(tf.test.TestCase):
+
+ def testBitGet(self):
+ """Test that if the returned bit value is correct."""
+ self.assertEqual(1, get_dataset_colormap.bit_get(9, 0))
+ self.assertEqual(0, get_dataset_colormap.bit_get(9, 1))
+ self.assertEqual(0, get_dataset_colormap.bit_get(9, 2))
+ self.assertEqual(1, get_dataset_colormap.bit_get(9, 3))
+
+ def testPASCALLabelColorMapValue(self):
+ """Test the getd color map value."""
+ colormap = get_dataset_colormap.create_pascal_label_colormap()
+
+ # Only test a few sampled entries in the color map.
+ self.assertTrue(np.array_equal([128., 0., 128.], colormap[5, :]))
+ self.assertTrue(np.array_equal([128., 192., 128.], colormap[23, :]))
+ self.assertTrue(np.array_equal([128., 0., 192.], colormap[37, :]))
+ self.assertTrue(np.array_equal([224., 192., 192.], colormap[127, :]))
+ self.assertTrue(np.array_equal([192., 160., 192.], colormap[175, :]))
+
+ def testLabelToPASCALColorImage(self):
+ """Test the value of the converted label value."""
+ label = np.array([[0, 16, 16], [52, 7, 52]])
+ expected_result = np.array([
+ [[0, 0, 0], [0, 64, 0], [0, 64, 0]],
+ [[0, 64, 192], [128, 128, 128], [0, 64, 192]]
+ ])
+ colored_label = get_dataset_colormap.label_to_color_image(
+ label, get_dataset_colormap.get_pascal_name())
+ self.assertTrue(np.array_equal(expected_result, colored_label))
+
+ def testUnExpectedLabelValueForLabelToPASCALColorImage(self):
+ """Raise ValueError when input value exceeds range."""
+ label = np.array([[120], [600]])
+ with self.assertRaises(ValueError):
+ get_dataset_colormap.label_to_color_image(
+ label, get_dataset_colormap.get_pascal_name())
+
+ def testUnExpectedLabelDimensionForLabelToPASCALColorImage(self):
+ """Raise ValueError if input dimension is not correct."""
+ label = np.array([120])
+ with self.assertRaises(ValueError):
+ get_dataset_colormap.label_to_color_image(
+ label, get_dataset_colormap.get_pascal_name())
+
+ def testGetColormapForUnsupportedDataset(self):
+ with self.assertRaises(ValueError):
+ get_dataset_colormap.create_label_colormap('unsupported_dataset')
+
+ def testUnExpectedLabelDimensionForLabelToADE20KColorImage(self):
+ label = np.array([250])
+ with self.assertRaises(ValueError):
+ get_dataset_colormap.label_to_color_image(
+ label, get_dataset_colormap.get_ade20k_name())
+
+ def testFirstColorInADE20KColorMap(self):
+ label = np.array([[1, 3], [10, 20]])
+ expected_result = np.array([
+ [[120, 120, 120], [6, 230, 230]],
+ [[4, 250, 7], [204, 70, 3]]
+ ])
+ colored_label = get_dataset_colormap.label_to_color_image(
+ label, get_dataset_colormap.get_ade20k_name())
+ self.assertTrue(np.array_equal(colored_label, expected_result))
+
+ def testMapillaryVistasColorMapValue(self):
+ colormap = get_dataset_colormap.create_mapillary_vistas_label_colormap()
+ self.assertTrue(np.array_equal([190, 153, 153], colormap[3, :]))
+ self.assertTrue(np.array_equal([102, 102, 156], colormap[6, :]))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/deeplab/utils/save_annotation.py b/models/research/deeplab/utils/save_annotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..2444df79532d6ef999f470ab8eef5ab333491660
--- /dev/null
+++ b/models/research/deeplab/utils/save_annotation.py
@@ -0,0 +1,66 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Saves an annotation as one png image.
+
+This script saves an annotation as one png image, and has the option to add a
+colormap to the png image for better visualization.
+"""
+
+import numpy as np
+import PIL.Image as img
+import tensorflow as tf
+
+from deeplab.utils import get_dataset_colormap
+
+
+def save_annotation(label,
+ save_dir,
+ filename,
+ add_colormap=True,
+ normalize_to_unit_values=False,
+ scale_values=False,
+ colormap_type=get_dataset_colormap.get_pascal_name()):
+ """Saves the given label to image on disk.
+
+ Args:
+ label: The numpy array to be saved. The data will be converted
+ to uint8 and saved as png image.
+ save_dir: String, the directory to which the results will be saved.
+ filename: String, the image filename.
+ add_colormap: Boolean, add color map to the label or not.
+ normalize_to_unit_values: Boolean, normalize the input values to [0, 1].
+ scale_values: Boolean, scale the input values to [0, 255] for visualization.
+ colormap_type: String, colormap type for visualization.
+ """
+ # Add colormap for visualizing the prediction.
+ if add_colormap:
+ colored_label = get_dataset_colormap.label_to_color_image(
+ label, colormap_type)
+ else:
+ colored_label = label
+ if normalize_to_unit_values:
+ min_value = np.amin(colored_label)
+ max_value = np.amax(colored_label)
+ range_value = max_value - min_value
+ if range_value != 0:
+ colored_label = (colored_label - min_value) / range_value
+
+ if scale_values:
+ colored_label = 255. * colored_label
+
+ pil_image = img.fromarray(colored_label.astype(dtype=np.uint8))
+ with tf.gfile.Open('%s/%s.png' % (save_dir, filename), mode='w') as f:
+ pil_image.save(f, 'PNG')
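A minimal usage sketch of the helper above, assuming the `deeplab` package is on `PYTHONPATH` and a TF 1.x runtime (the helper relies on `tf.gfile`); the label values and the output directory are illustrative.

```python
import numpy as np
from deeplab.utils import save_annotation

# A fake PASCAL-style prediction: 21 classes on a 65x65 crop.
label = np.random.randint(0, 21, size=(65, 65))
save_annotation.save_annotation(
    label, save_dir='/tmp', filename='example_prediction', add_colormap=True)
# Writes /tmp/example_prediction.png with the PASCAL palette applied.
```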
diff --git a/models/research/deeplab/utils/train_utils.py b/models/research/deeplab/utils/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..14bbd6ee7e55533d94195fb4e7327e63e53a2800
--- /dev/null
+++ b/models/research/deeplab/utils/train_utils.py
@@ -0,0 +1,372 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for training."""
+
+import six
+import tensorflow as tf
+from tensorflow.contrib import framework as contrib_framework
+
+from deeplab.core import preprocess_utils
+from deeplab.core import utils
+
+
+def _div_maybe_zero(total_loss, num_present):
+ """Normalizes the total loss with the number of present pixels."""
+ return tf.to_float(num_present > 0) * tf.math.divide(
+ total_loss,
+ tf.maximum(1e-5, num_present))
+
+
+def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
+ labels,
+ num_classes,
+ ignore_label,
+ loss_weight=1.0,
+ upsample_logits=True,
+ hard_example_mining_step=0,
+ top_k_percent_pixels=1.0,
+ gt_is_matting_map=False,
+ scope=None):
+ """Adds softmax cross entropy loss for logits of each scale.
+
+ Args:
+ scales_to_logits: A map from logits names for different scales to logits.
+ The logits have shape [batch, logits_height, logits_width, num_classes].
+ labels: Groundtruth labels with shape [batch, image_height, image_width, 1].
+ num_classes: Integer, number of target classes.
+ ignore_label: Integer, label to ignore.
+ loss_weight: A float or a list of loss weights. If it is a float, it means
+ all the labels have the same weight. If it is a list of weights, then each
+ element in the list represents the weight for the label of its index, for
+ example, loss_weight = [0.1, 0.5] means the weight for label 0 is 0.1 and
+ the weight for label 1 is 0.5.
+ upsample_logits: Boolean, upsample logits or not.
+ hard_example_mining_step: An integer, the training step in which the hard
+ example mining kicks off. Note that we gradually reduce the mining
+ percent to the top_k_percent_pixels. For example, if
+ hard_example_mining_step = 100K and top_k_percent_pixels = 0.25, then
+ mining percent will gradually reduce from 100% to 25% until 100K steps
+ after which we only mine top 25% pixels.
+ top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its value
+ < 1.0, only compute the loss for the top k percent pixels (e.g., the top
+ 20% pixels). This is useful for hard pixel mining.
+ gt_is_matting_map: If true, the groundtruth is a matting map of confidence
+ score. If false, the groundtruth is an integer valued class mask.
+ scope: String, the scope for the loss.
+
+ Raises:
+ ValueError: Label or logits is None, or groundtruth is matting map while
+ label is not floating value.
+ """
+ if labels is None:
+ raise ValueError('No label for softmax cross entropy loss.')
+
+ # If input groundtruth is a matting map of confidence, check if the input
+ # labels are floating point values.
+ if gt_is_matting_map and not labels.dtype.is_floating:
+ raise ValueError('Labels must be floats if groundtruth is a matting map.')
+
+ for scale, logits in six.iteritems(scales_to_logits):
+ loss_scope = None
+ if scope:
+ loss_scope = '%s_%s' % (scope, scale)
+
+ if upsample_logits:
+ # Label is not downsampled, and instead we upsample logits.
+ logits = tf.image.resize_bilinear(
+ logits,
+ preprocess_utils.resolve_shape(labels, 4)[1:3],
+ align_corners=True)
+ scaled_labels = labels
+ else:
+ # Label is downsampled to the same size as logits.
+ # When gt_is_matting_map = true, label downsampling with nearest neighbor
+ # method may introduce artifacts. However, to prevent ignore_label from
+ # being interpolated with other labels, we still perform nearest neighbor
+ # interpolation.
+ # TODO(huizhongc): Change to bilinear interpolation by processing padded
+ # and non-padded label separately.
+ if gt_is_matting_map:
+ tf.logging.warning(
+ 'Label downsampling with nearest neighbor may introduce artifacts.')
+
+ scaled_labels = tf.image.resize_nearest_neighbor(
+ labels,
+ preprocess_utils.resolve_shape(logits, 4)[1:3],
+ align_corners=True)
+
+ scaled_labels = tf.reshape(scaled_labels, shape=[-1])
+ weights = utils.get_label_weight_mask(
+ scaled_labels, ignore_label, num_classes, label_weights=loss_weight)
+ # Dimension of keep_mask is equal to the total number of pixels.
+ keep_mask = tf.cast(
+ tf.not_equal(scaled_labels, ignore_label), dtype=tf.float32)
+
+ train_labels = None
+ logits = tf.reshape(logits, shape=[-1, num_classes])
+
+ if gt_is_matting_map:
+ # When the groundtruth is integer label mask, we can assign class
+ # dependent label weights to the loss. When the groundtruth is image
+ # matting confidence, we do not apply class-dependent label weight (i.e.,
+ # label_weight = 1.0).
+ if loss_weight != 1.0:
+ raise ValueError(
+ 'loss_weight must equal 1 if groundtruth is a matting map.')
+
+ # Assign label value 0 to ignore pixels. The exact label value of ignore
+ # pixel does not matter, because those ignore_value pixel losses will be
+ # multiplied to 0 weight.
+ train_labels = scaled_labels * keep_mask
+
+ train_labels = tf.expand_dims(train_labels, 1)
+ train_labels = tf.concat([1 - train_labels, train_labels], axis=1)
+ else:
+ train_labels = tf.one_hot(
+ scaled_labels, num_classes, on_value=1.0, off_value=0.0)
+
+ default_loss_scope = ('softmax_all_pixel_loss'
+ if top_k_percent_pixels == 1.0 else
+ 'softmax_hard_example_mining')
+ with tf.name_scope(loss_scope, default_loss_scope,
+ [logits, train_labels, weights]):
+ # Compute the loss for all pixels.
+ pixel_losses = tf.nn.softmax_cross_entropy_with_logits_v2(
+ labels=tf.stop_gradient(
+ train_labels, name='train_labels_stop_gradient'),
+ logits=logits,
+ name='pixel_losses')
+ weighted_pixel_losses = tf.multiply(pixel_losses, weights)
+
+ if top_k_percent_pixels == 1.0:
+ total_loss = tf.reduce_sum(weighted_pixel_losses)
+ num_present = tf.reduce_sum(keep_mask)
+ loss = _div_maybe_zero(total_loss, num_present)
+ tf.losses.add_loss(loss)
+ else:
+ num_pixels = tf.to_float(tf.shape(logits)[0])
+ # Compute the top_k_percent pixels based on current training step.
+ if hard_example_mining_step == 0:
+ # Directly focus on the top_k pixels.
+ top_k_pixels = tf.to_int32(top_k_percent_pixels * num_pixels)
+ else:
+ # Gradually reduce the mining percent to top_k_percent_pixels.
+ global_step = tf.to_float(tf.train.get_or_create_global_step())
+ ratio = tf.minimum(1.0, global_step / hard_example_mining_step)
+ top_k_pixels = tf.to_int32(
+ (ratio * top_k_percent_pixels + (1.0 - ratio)) * num_pixels)
+ top_k_losses, _ = tf.nn.top_k(weighted_pixel_losses,
+ k=top_k_pixels,
+ sorted=True,
+ name='top_k_percent_pixels')
+ total_loss = tf.reduce_sum(top_k_losses)
+ num_present = tf.reduce_sum(
+ tf.to_float(tf.not_equal(top_k_losses, 0.0)))
+ loss = _div_maybe_zero(total_loss, num_present)
+ tf.losses.add_loss(loss)
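The mining schedule in the branch above can be summarized with a plain-Python sketch (no TensorFlow, numbers purely illustrative): the fraction of pixels kept for the loss decays linearly from 100% to `top_k_percent_pixels` over `hard_example_mining_step` steps.

```python
def mined_fraction(global_step, hard_example_mining_step=100_000,
                   top_k_percent_pixels=0.25):
    # Linear ramp from keeping all pixels to keeping only the top-k fraction.
    ratio = min(1.0, global_step / hard_example_mining_step)
    return ratio * top_k_percent_pixels + (1.0 - ratio)

for step in (0, 50_000, 100_000, 200_000):
    print(step, mined_fraction(step))
# 0 -> 1.0, 50000 -> 0.625, 100000 -> 0.25, 200000 -> 0.25
```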
+
+
+def get_model_init_fn(train_logdir,
+ tf_initial_checkpoint,
+ initialize_last_layer,
+ last_layers,
+ ignore_missing_vars=False):
+ """Gets the function initializing model variables from a checkpoint.
+
+ Args:
+ train_logdir: Log directory for training.
+ tf_initial_checkpoint: TensorFlow checkpoint for initialization.
+ initialize_last_layer: Initialize last layer or not.
+ last_layers: Last layers of the model.
+ ignore_missing_vars: Ignore missing variables in the checkpoint.
+
+ Returns:
+ Initialization function.
+ """
+ if tf_initial_checkpoint is None:
+ tf.logging.info('Not initializing the model from a checkpoint.')
+ return None
+
+ if tf.train.latest_checkpoint(train_logdir):
+ tf.logging.info('Ignoring initialization; other checkpoint exists')
+ return None
+
+ tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)
+
+ # Variables that will not be restored.
+ exclude_list = ['global_step']
+ if not initialize_last_layer:
+ exclude_list.extend(last_layers)
+
+ variables_to_restore = contrib_framework.get_variables_to_restore(
+ exclude=exclude_list)
+
+ if variables_to_restore:
+ init_op, init_feed_dict = contrib_framework.assign_from_checkpoint(
+ tf_initial_checkpoint,
+ variables_to_restore,
+ ignore_missing_vars=ignore_missing_vars)
+ global_step = tf.train.get_or_create_global_step()
+
+ def restore_fn(sess):
+ sess.run(init_op, init_feed_dict)
+ sess.run([global_step])
+
+ return restore_fn
+
+ return None
+
+
+def get_model_gradient_multipliers(last_layers, last_layer_gradient_multiplier):
+ """Gets the gradient multipliers.
+
+ The gradient multipliers will adjust the learning rates for model
+ variables. For the task of semantic segmentation, the models are
+ usually fine-tuned from the models trained on the task of image
+ classification. To fine-tune the models, we usually set larger (e.g.,
+ 10 times larger) learning rate for the parameters of last layer.
+
+ Args:
+ last_layers: Scopes of last layers.
+ last_layer_gradient_multiplier: The gradient multiplier for last layers.
+
+ Returns:
+ The gradient multiplier map with variables as key, and multipliers as value.
+ """
+ gradient_multipliers = {}
+
+ for var in tf.model_variables():
+ # Double the learning rate for biases.
+ if 'biases' in var.op.name:
+ gradient_multipliers[var.op.name] = 2.
+
+ # Use larger learning rate for last layer variables.
+ for layer in last_layers:
+ if layer in var.op.name and 'biases' in var.op.name:
+ gradient_multipliers[var.op.name] = 2 * last_layer_gradient_multiplier
+ break
+ elif layer in var.op.name:
+ gradient_multipliers[var.op.name] = last_layer_gradient_multiplier
+ break
+
+ return gradient_multipliers
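A toy illustration of the multiplier rules coded above, in plain Python, assuming `last_layers = ['logits']` and `last_layer_gradient_multiplier = 10` (the variable names are made up): weights of non-final layers keep the default rate, biases are doubled, and final-layer variables get the larger multiplier.

```python
variables = ['resnet/conv1/weights', 'resnet/conv1/biases',
             'logits/weights', 'logits/biases']
last_layers, last_layer_gradient_multiplier = ['logits'], 10.0

multipliers = {}
for name in variables:
    if 'biases' in name:
        multipliers[name] = 2.0
    for layer in last_layers:
        if layer in name and 'biases' in name:
            multipliers[name] = 2 * last_layer_gradient_multiplier
            break
        elif layer in name:
            multipliers[name] = last_layer_gradient_multiplier
            break
print(multipliers)
# {'resnet/conv1/biases': 2.0, 'logits/weights': 10.0, 'logits/biases': 20.0}
```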
+
+
+def get_model_learning_rate(learning_policy,
+ base_learning_rate,
+ learning_rate_decay_step,
+ learning_rate_decay_factor,
+ training_number_of_steps,
+ learning_power,
+ slow_start_step,
+ slow_start_learning_rate,
+ slow_start_burnin_type='none',
+ decay_steps=0.0,
+ end_learning_rate=0.0,
+ boundaries=None,
+ boundary_learning_rates=None):
+ """Gets model's learning rate.
+
+ Computes the model's learning rate for different learning policy.
+ Right now, only "step" and "poly" are supported.
+ (1) The learning policy for "step" is computed as follows:
+ current_learning_rate = base_learning_rate *
+ learning_rate_decay_factor ^ (global_step / learning_rate_decay_step)
+ See tf.train.exponential_decay for details.
+ (2) The learning policy for "poly" is computed as follows:
+ current_learning_rate = base_learning_rate *
+ (1 - global_step / training_number_of_steps) ^ learning_power
+
+ Args:
+ learning_policy: Learning rate policy for training.
+ base_learning_rate: The base learning rate for model training.
+ learning_rate_decay_step: Decay the base learning rate at a fixed step.
+ learning_rate_decay_factor: The rate to decay the base learning rate.
+ training_number_of_steps: Number of steps for training.
+ learning_power: Power used for 'poly' learning policy.
+ slow_start_step: Training model with small learning rate for the first
+ few steps.
+ slow_start_learning_rate: The learning rate employed during slow start.
+ slow_start_burnin_type: The burnin type for the slow start stage. Can be
+ `none` which means no burnin or `linear` which means the learning rate
+ increases linearly from slow_start_learning_rate and reaches
+ base_learning_rate after slow_start_steps.
+ decay_steps: Float, `decay_steps` for polynomial learning rate.
+ end_learning_rate: Float, `end_learning_rate` for polynomial learning rate.
+ boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
+ increasing entries.
+ boundary_learning_rates: A list of `Tensor`s or `float`s or `int`s that
+ specifies the values for the intervals defined by `boundaries`. It should
+ have one more element than `boundaries`, and all elements should have the
+ same type.
+
+ Returns:
+ Learning rate for the specified learning policy.
+
+ Raises:
+ ValueError: If learning policy or slow start burnin type is not recognized.
+ ValueError: If `boundaries` and `boundary_learning_rates` are not set for
+ multi_steps learning rate decay.
+ """
+ global_step = tf.train.get_or_create_global_step()
+ adjusted_global_step = tf.maximum(global_step - slow_start_step, 0)
+ if decay_steps == 0.0:
+ tf.logging.info('Setting decay_steps to total training steps.')
+ decay_steps = training_number_of_steps - slow_start_step
+ if learning_policy == 'step':
+ learning_rate = tf.train.exponential_decay(
+ base_learning_rate,
+ adjusted_global_step,
+ learning_rate_decay_step,
+ learning_rate_decay_factor,
+ staircase=True)
+ elif learning_policy == 'poly':
+ learning_rate = tf.train.polynomial_decay(
+ base_learning_rate,
+ adjusted_global_step,
+ decay_steps=decay_steps,
+ end_learning_rate=end_learning_rate,
+ power=learning_power)
+ elif learning_policy == 'cosine':
+ learning_rate = tf.train.cosine_decay(
+ base_learning_rate,
+ adjusted_global_step,
+ training_number_of_steps - slow_start_step)
+ elif learning_policy == 'multi_steps':
+ if boundaries is None or boundary_learning_rates is None:
+ raise ValueError('Must set `boundaries` and `boundary_learning_rates` '
+ 'for multi_steps learning rate decay.')
+ learning_rate = tf.train.piecewise_constant_decay(
+ adjusted_global_step,
+ boundaries,
+ boundary_learning_rates)
+ else:
+ raise ValueError('Unknown learning policy.')
+
+ adjusted_slow_start_learning_rate = slow_start_learning_rate
+ if slow_start_burnin_type == 'linear':
+ # Do linear burnin. Increase linearly from slow_start_learning_rate and
+ # reach base_learning_rate after (global_step >= slow_start_steps).
+ adjusted_slow_start_learning_rate = (
+ slow_start_learning_rate +
+ (base_learning_rate - slow_start_learning_rate) *
+ tf.to_float(global_step) / slow_start_step)
+ elif slow_start_burnin_type != 'none':
+ raise ValueError('Unknown burnin type.')
+
+ # Employ small learning rate at the first few steps for warm start.
+ return tf.where(global_step < slow_start_step,
+ adjusted_slow_start_learning_rate, learning_rate)
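A plain-Python sketch of the two schedules spelled out in the docstring, with illustrative hyperparameters and no slow start (no TensorFlow required):

```python
base_lr, decay_step, decay_factor = 0.007, 2000, 0.1   # 'step' policy
training_steps, power = 30000, 0.9                      # 'poly' policy

def step_policy(global_step):
    # Staircase exponential decay, as in tf.train.exponential_decay.
    return base_lr * decay_factor ** (global_step // decay_step)

def poly_policy(global_step):
    return base_lr * (1.0 - global_step / training_steps) ** power

print(step_policy(4500))    # 0.007 * 0.1**2 = 7e-05
print(poly_policy(15000))   # 0.007 * 0.5**0.9 ~= 0.00375
```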
diff --git a/models/research/deeplab/vis.py b/models/research/deeplab/vis.py
new file mode 100644
index 0000000000000000000000000000000000000000..20808d37bf2f45f196a04391548c6745fcc6603b
--- /dev/null
+++ b/models/research/deeplab/vis.py
@@ -0,0 +1,327 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Segmentation results visualization on a given set of images.
+
+See model.py for more details and usage.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import os.path
+import time
+import numpy as np
+from six.moves import range
+import tensorflow as tf
+from tensorflow.contrib import quantize as contrib_quantize
+from tensorflow.contrib import training as contrib_training
+from deeplab import common
+from deeplab import model
+from deeplab.datasets import data_generator
+from deeplab.utils import save_annotation
+
+flags = tf.app.flags
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')
+
+# Settings for log directories.
+
+flags.DEFINE_string('vis_logdir', None, 'Where to write the event logs.')
+
+flags.DEFINE_string('checkpoint_dir', None, 'Directory of model checkpoints.')
+
+# Settings for visualizing the model.
+
+flags.DEFINE_integer('vis_batch_size', 1,
+ 'The number of images in each batch during evaluation.')
+
+flags.DEFINE_list('vis_crop_size', '513,513',
+ 'Crop size [height, width] for visualization.')
+
+flags.DEFINE_integer('eval_interval_secs', 60 * 5,
+ 'How often (in seconds) to run evaluation.')
+
+# For `xception_65`, use atrous_rates = [12, 24, 36] if output_stride = 8, or
+# rates = [6, 12, 18] if output_stride = 16. For `mobilenet_v2`, use None. Note
+# one could use different atrous_rates/output_stride during training/evaluation.
+flags.DEFINE_multi_integer('atrous_rates', None,
+ 'Atrous rates for atrous spatial pyramid pooling.')
+
+flags.DEFINE_integer('output_stride', 16,
+ 'The ratio of input to output spatial resolution.')
+
+# Change to [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] for multi-scale test.
+flags.DEFINE_multi_float('eval_scales', [1.0],
+ 'The scales to resize images for evaluation.')
+
+# Change to True for adding flipped images during test.
+flags.DEFINE_bool('add_flipped_images', False,
+ 'Add flipped images for evaluation or not.')
+
+flags.DEFINE_integer(
+ 'quantize_delay_step', -1,
+ 'Steps to start quantized training. If < 0, will not quantize model.')
+
+# Dataset settings.
+
+flags.DEFINE_string('dataset', 'pascal_voc_seg',
+ 'Name of the segmentation dataset.')
+
+flags.DEFINE_string('vis_split', 'val',
+ 'Which split of the dataset is used for visualizing results.')
+
+flags.DEFINE_string('dataset_dir', None, 'Where the dataset resides.')
+
+flags.DEFINE_enum('colormap_type', 'pascal', ['pascal', 'cityscapes', 'ade20k'],
+ 'Visualization colormap type.')
+
+flags.DEFINE_boolean('also_save_raw_predictions', False,
+ 'Also save raw predictions.')
+
+flags.DEFINE_integer('max_number_of_iterations', 0,
+ 'Maximum number of visualization iterations. Will loop '
+ 'indefinitely upon nonpositive values.')
+
+# The folder where semantic segmentation predictions are saved.
+_SEMANTIC_PREDICTION_SAVE_FOLDER = 'segmentation_results'
+
+# The folder where raw semantic segmentation predictions are saved.
+_RAW_SEMANTIC_PREDICTION_SAVE_FOLDER = 'raw_segmentation_results'
+
+# The format to save image.
+_IMAGE_FORMAT = '%06d_image'
+
+# The format to save prediction
+_PREDICTION_FORMAT = '%06d_prediction'
+
+# To evaluate Cityscapes results on the evaluation server, the labels used
+# during training should be mapped to the labels for evaluation.
+_CITYSCAPES_TRAIN_ID_TO_EVAL_ID = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22,
+ 23, 24, 25, 26, 27, 28, 31, 32, 33]
+
+
+def _convert_train_id_to_eval_id(prediction, train_id_to_eval_id):
+ """Converts the predicted label for evaluation.
+
+ There are cases where the training labels are not equal to the evaluation
+ labels. This function is used to perform the conversion so that we could
+ evaluate the results on the evaluation server.
+
+ Args:
+ prediction: Semantic segmentation prediction.
+ train_id_to_eval_id: A list mapping from train id to evaluation id.
+
+ Returns:
+ Semantic segmentation prediction whose labels have been changed.
+ """
+ converted_prediction = prediction.copy()
+ for train_id, eval_id in enumerate(train_id_to_eval_id):
+ converted_prediction[prediction == train_id] = eval_id
+
+ return converted_prediction
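A toy illustration of the remapping, using the Cityscapes table defined above (values copied here so the snippet stands alone): train id 0 becomes eval id 7, train id 1 becomes 8, and so on.

```python
import numpy as np

train_id_to_eval_id = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22,
                       23, 24, 25, 26, 27, 28, 31, 32, 33]

prediction = np.array([[0, 1], [2, 18]])
converted = prediction.copy()
for train_id, eval_id in enumerate(train_id_to_eval_id):
    # Mask against the original prediction so earlier writes never cascade.
    converted[prediction == train_id] = eval_id
print(converted)   # [[ 7  8]
                   #  [11 33]]
```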
+
+
+def _process_batch(sess, original_images, semantic_predictions, image_names,
+ image_heights, image_widths, image_id_offset, save_dir,
+ raw_save_dir, train_id_to_eval_id=None):
+ """Evaluates one single batch qualitatively.
+
+ Args:
+ sess: TensorFlow session.
+ original_images: One batch of original images.
+ semantic_predictions: One batch of semantic segmentation predictions.
+ image_names: Image names.
+ image_heights: Image heights.
+ image_widths: Image widths.
+ image_id_offset: Image id offset for indexing images.
+ save_dir: The directory where the predictions will be saved.
+ raw_save_dir: The directory where the raw predictions will be saved.
+ train_id_to_eval_id: A list mapping from train id to eval id.
+ """
+ (original_images,
+ semantic_predictions,
+ image_names,
+ image_heights,
+ image_widths) = sess.run([original_images, semantic_predictions,
+ image_names, image_heights, image_widths])
+
+ num_image = semantic_predictions.shape[0]
+ for i in range(num_image):
+ image_height = np.squeeze(image_heights[i])
+ image_width = np.squeeze(image_widths[i])
+ original_image = np.squeeze(original_images[i])
+ semantic_prediction = np.squeeze(semantic_predictions[i])
+ crop_semantic_prediction = semantic_prediction[:image_height, :image_width]
+
+ # Save image.
+ save_annotation.save_annotation(
+ original_image, save_dir, _IMAGE_FORMAT % (image_id_offset + i),
+ add_colormap=False)
+
+ # Save prediction.
+ save_annotation.save_annotation(
+ crop_semantic_prediction, save_dir,
+ _PREDICTION_FORMAT % (image_id_offset + i), add_colormap=True,
+ colormap_type=FLAGS.colormap_type)
+
+ if FLAGS.also_save_raw_predictions:
+ image_filename = os.path.basename(image_names[i])
+
+ if train_id_to_eval_id is not None:
+ crop_semantic_prediction = _convert_train_id_to_eval_id(
+ crop_semantic_prediction,
+ train_id_to_eval_id)
+ save_annotation.save_annotation(
+ crop_semantic_prediction, raw_save_dir, image_filename,
+ add_colormap=False)
+
+
+def main(unused_argv):
+ tf.logging.set_verbosity(tf.logging.INFO)
+
+ # Get dataset-dependent information.
+ dataset = data_generator.Dataset(
+ dataset_name=FLAGS.dataset,
+ split_name=FLAGS.vis_split,
+ dataset_dir=FLAGS.dataset_dir,
+ batch_size=FLAGS.vis_batch_size,
+ crop_size=[int(sz) for sz in FLAGS.vis_crop_size],
+ min_resize_value=FLAGS.min_resize_value,
+ max_resize_value=FLAGS.max_resize_value,
+ resize_factor=FLAGS.resize_factor,
+ model_variant=FLAGS.model_variant,
+ is_training=False,
+ should_shuffle=False,
+ should_repeat=False)
+
+ train_id_to_eval_id = None
+ if dataset.dataset_name == data_generator.get_cityscapes_dataset_name():
+ tf.logging.info('Cityscapes requires converting train_id to eval_id.')
+ train_id_to_eval_id = _CITYSCAPES_TRAIN_ID_TO_EVAL_ID
+
+ # Prepare for visualization.
+ tf.gfile.MakeDirs(FLAGS.vis_logdir)
+ save_dir = os.path.join(FLAGS.vis_logdir, _SEMANTIC_PREDICTION_SAVE_FOLDER)
+ tf.gfile.MakeDirs(save_dir)
+ raw_save_dir = os.path.join(
+ FLAGS.vis_logdir, _RAW_SEMANTIC_PREDICTION_SAVE_FOLDER)
+ tf.gfile.MakeDirs(raw_save_dir)
+
+ tf.logging.info('Visualizing on %s set', FLAGS.vis_split)
+
+ with tf.Graph().as_default():
+ samples = dataset.get_one_shot_iterator().get_next()
+
+ model_options = common.ModelOptions(
+ outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_of_classes},
+ crop_size=[int(sz) for sz in FLAGS.vis_crop_size],
+ atrous_rates=FLAGS.atrous_rates,
+ output_stride=FLAGS.output_stride)
+
+ if tuple(FLAGS.eval_scales) == (1.0,):
+ tf.logging.info('Performing single-scale test.')
+ predictions = model.predict_labels(
+ samples[common.IMAGE],
+ model_options=model_options,
+ image_pyramid=FLAGS.image_pyramid)
+ else:
+ tf.logging.info('Performing multi-scale test.')
+ if FLAGS.quantize_delay_step >= 0:
+ raise ValueError(
+ 'Quantize mode is not supported with multi-scale test.')
+ predictions = model.predict_labels_multi_scale(
+ samples[common.IMAGE],
+ model_options=model_options,
+ eval_scales=FLAGS.eval_scales,
+ add_flipped_images=FLAGS.add_flipped_images)
+ predictions = predictions[common.OUTPUT_TYPE]
+
+ if FLAGS.min_resize_value and FLAGS.max_resize_value:
+ # Only support batch_size = 1, since we assume the dimensions of the
+ # original image after tf.squeeze are [height, width, 3].
+ assert FLAGS.vis_batch_size == 1
+
+ # Reverse the resizing and padding operations performed in preprocessing.
+ # First, we slice the valid regions (i.e., remove padded region) and then
+ # we resize the predictions back.
+ original_image = tf.squeeze(samples[common.ORIGINAL_IMAGE])
+ original_image_shape = tf.shape(original_image)
+ predictions = tf.slice(
+ predictions,
+ [0, 0, 0],
+ [1, original_image_shape[0], original_image_shape[1]])
+ resized_shape = tf.to_int32([tf.squeeze(samples[common.HEIGHT]),
+ tf.squeeze(samples[common.WIDTH])])
+ predictions = tf.squeeze(
+ tf.image.resize_images(tf.expand_dims(predictions, 3),
+ resized_shape,
+ method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
+ align_corners=True), 3)
+
+ tf.train.get_or_create_global_step()
+ if FLAGS.quantize_delay_step >= 0:
+ contrib_quantize.create_eval_graph()
+
+ num_iteration = 0
+ max_num_iteration = FLAGS.max_number_of_iterations
+
+ checkpoints_iterator = contrib_training.checkpoints_iterator(
+ FLAGS.checkpoint_dir, min_interval_secs=FLAGS.eval_interval_secs)
+ for checkpoint_path in checkpoints_iterator:
+ num_iteration += 1
+ tf.logging.info(
+ 'Starting visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
+ time.gmtime()))
+ tf.logging.info('Visualizing with model %s', checkpoint_path)
+
+ scaffold = tf.train.Scaffold(init_op=tf.global_variables_initializer())
+ session_creator = tf.train.ChiefSessionCreator(
+ scaffold=scaffold,
+ master=FLAGS.master,
+ checkpoint_filename_with_path=checkpoint_path)
+ with tf.train.MonitoredSession(
+ session_creator=session_creator, hooks=None) as sess:
+ batch = 0
+ image_id_offset = 0
+
+ while not sess.should_stop():
+ tf.logging.info('Visualizing batch %d', batch + 1)
+ _process_batch(sess=sess,
+ original_images=samples[common.ORIGINAL_IMAGE],
+ semantic_predictions=predictions,
+ image_names=samples[common.IMAGE_NAME],
+ image_heights=samples[common.HEIGHT],
+ image_widths=samples[common.WIDTH],
+ image_id_offset=image_id_offset,
+ save_dir=save_dir,
+ raw_save_dir=raw_save_dir,
+ train_id_to_eval_id=train_id_to_eval_id)
+ image_id_offset += FLAGS.vis_batch_size
+ batch += 1
+
+ tf.logging.info(
+ 'Finished visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
+ time.gmtime()))
+ if max_num_iteration > 0 and num_iteration >= max_num_iteration:
+ break
+
+if __name__ == '__main__':
+ flags.mark_flag_as_required('checkpoint_dir')
+ flags.mark_flag_as_required('vis_logdir')
+ flags.mark_flag_as_required('dataset_dir')
+ tf.app.run()
diff --git a/models/research/delf/.gitignore b/models/research/delf/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..b61ddd100012ab30e0bf438f3a5ab01ea3f44281
--- /dev/null
+++ b/models/research/delf/.gitignore
@@ -0,0 +1,4 @@
+*pyc
+*~
+*pb2.py
+*pb2.pyc
diff --git a/models/research/delf/DETECTION.md b/models/research/delf/DETECTION.md
new file mode 100644
index 0000000000000000000000000000000000000000..7fa7570f74dc58622151bee37f9a2a5697b896de
--- /dev/null
+++ b/models/research/delf/DETECTION.md
@@ -0,0 +1,69 @@
+## Quick start: landmark detection
+
+[arXiv paper](https://arxiv.org/abs/1812.01584)
+
+### Install DELF library
+
+To be able to use this code, please follow
+[these instructions](INSTALL_INSTRUCTIONS.md) to properly install the DELF
+library.
+
+### Download Oxford buildings dataset
+
+To illustrate detector usage, please download the Oxford buildings dataset, by
+following the instructions
+[here](EXTRACTION_MATCHING.md#download-oxford-buildings-dataset). Then, create
+the file `list_images_detector.txt` as follows:
+
+```bash
+# From tensorflow/models/research/delf/delf/python/examples/
+echo data/oxford5k_images/all_souls_000002.jpg >> list_images_detector.txt
+echo data/oxford5k_images/all_souls_000035.jpg >> list_images_detector.txt
+```
+
+### Download detector model
+
+Also, you will need to download the pre-trained detector model:
+
+```bash
+# From tensorflow/models/research/delf/delf/python/examples/
+mkdir parameters && cd parameters
+wget http://storage.googleapis.com/delf/d2r_frcnn_20190411.tar.gz
+tar -xvzf d2r_frcnn_20190411.tar.gz
+```
+
+**Note**: this is the Faster-RCNN based model. We also release a MobileNet-SSD
+model, see the [README](README.md#pre-trained-models) for download link. The
+instructions should work seamlessly for both models.
+
+### Detecting landmarks
+
+Now that you have everything in place, running this command should detect boxes
+for the images `all_souls_000002.jpg` and `all_souls_000035.jpg`, with a
+threshold of 0.8, and produce visualizations.
+
+```bash
+# From tensorflow/models/research/delf/delf/python/examples/
+python3 extract_boxes.py \
+ --detector_path parameters/d2r_frcnn_20190411 \
+ --detector_thresh 0.8 \
+ --list_images_path list_images_detector.txt \
+ --output_dir data/oxford5k_boxes \
+ --output_viz_dir data/oxford5k_boxes_viz
+```
+
+Two images are generated in the `data/oxford5k_boxes_viz` directory; they should
+look similar to these:
+
+
+
+
+### Troubleshooting
+
+#### `matplotlib`
+
+`matplotlib` may complain with a message such as `no display name and no
+$DISPLAY environment variable`. To fix this, one option is to add the line
+`backend : Agg` to the file `.config/matplotlib/matplotlibrc`. On this problem,
+see the discussion
+[here](https://stackoverflow.com/questions/37604289/tkinter-tclerror-no-display-name-and-no-display-environment-variable).
diff --git a/models/research/delf/EXTRACTION_MATCHING.md b/models/research/delf/EXTRACTION_MATCHING.md
new file mode 100644
index 0000000000000000000000000000000000000000..53159638587282658129aa13ac165fbd7d3803ea
--- /dev/null
+++ b/models/research/delf/EXTRACTION_MATCHING.md
@@ -0,0 +1,87 @@
+## Quick start: DELF extraction and matching
+
+[arXiv paper](https://arxiv.org/abs/1612.06321)
+
+### Install DELF library
+
+To be able to use this code, please follow
+[these instructions](INSTALL_INSTRUCTIONS.md) to properly install the DELF
+library.
+
+### Download Oxford buildings dataset
+
+To illustrate DELF usage, please download the Oxford buildings dataset. To
+follow these instructions closely, please download the dataset to the
+`tensorflow/models/research/delf/delf/python/examples` directory, as in the
+following commands:
+
+```bash
+# From tensorflow/models/research/delf/delf/python/examples/
+mkdir data && cd data
+wget http://www.robots.ox.ac.uk/~vgg/data/oxbuildings/oxbuild_images.tgz
+mkdir oxford5k_images oxford5k_features
+tar -xvzf oxbuild_images.tgz -C oxford5k_images/
+cd ../
+echo data/oxford5k_images/hertford_000056.jpg >> list_images.txt
+echo data/oxford5k_images/oxford_000317.jpg >> list_images.txt
+```
+
+### Download pre-trained DELF model
+
+Also, you will need to download the trained DELF model:
+
+```bash
+# From tensorflow/models/research/delf/delf/python/examples/
+mkdir parameters && cd parameters
+wget http://storage.googleapis.com/delf/delf_gld_20190411.tar.gz
+tar -xvzf delf_gld_20190411.tar.gz
+```
+
+### DELF feature extraction
+
+Now that you have everything in place, running this command should extract DELF
+features for the images `hertford_000056.jpg` and `oxford_000317.jpg`:
+
+```bash
+# From tensorflow/models/research/delf/delf/python/examples/
+python3 extract_features.py \
+ --config_path delf_config_example.pbtxt \
+ --list_images_path list_images.txt \
+ --output_dir data/oxford5k_features
+```
+
+### Image matching using DELF features
+
+After feature extraction, run this command to perform feature matching between
+the images `hertford_000056.jpg` and `oxford_000317.jpg`:
+
+```bash
+python3 match_images.py \
+ --image_1_path data/oxford5k_images/hertford_000056.jpg \
+ --image_2_path data/oxford5k_images/oxford_000317.jpg \
+ --features_1_path data/oxford5k_features/hertford_000056.delf \
+ --features_2_path data/oxford5k_features/oxford_000317.delf \
+ --output_image matched_images.png
+```
+
+The image `matched_images.png` is generated and should look similar to this one:
+
+
+
+### Troubleshooting
+
+#### `matplotlib`
+
+`matplotlib` may complain with a message such as `no display name and no
+$DISPLAY environment variable`. To fix this, one option is to add the line
+`backend : Agg` to the file `.config/matplotlib/matplotlibrc`. On this problem,
+see the discussion
+[here](https://stackoverflow.com/questions/37604289/tkinter-tclerror-no-display-name-and-no-display-environment-variable).
+
+#### `skimage`
+
+By default, scikit-image 0.13.x or 0.14.1 is installed if you followed the
+instructions. According to
+[this discussion](https://github.com/scikit-image/scikit-image/issues/3649#issuecomment-455273659),
+if you run into scikit-image related issues, upgrading to a version above 0.14.1
+with `pip3 install -U scikit-image` should fix them.
diff --git a/models/research/delf/INSTALL_INSTRUCTIONS.md b/models/research/delf/INSTALL_INSTRUCTIONS.md
new file mode 100644
index 0000000000000000000000000000000000000000..4f66e9389fdd126dd769a8282a69482b989c9c9e
--- /dev/null
+++ b/models/research/delf/INSTALL_INSTRUCTIONS.md
@@ -0,0 +1,122 @@
+## DELF installation
+
+### Tensorflow
+
+[TensorFlow 2.1](https://github.com/tensorflow/tensorflow/releases/tag/v2.1.0)
+[Python 3.6](https://www.python.org/downloads/release/python-360/)
+
+For detailed steps to install Tensorflow, follow the
+[Tensorflow installation instructions](https://www.tensorflow.org/install/). A
+typical user can install Tensorflow using one of the following commands:
+
+```bash
+# For CPU:
+pip3 install 'tensorflow'
+# For GPU:
+pip3 install 'tensorflow-gpu'
+```
+
+### TF-Slim
+
+Note: currently, we need to install the latest version from source, to avoid
+using previous versions which relied on tf.contrib (which is now deprecated).
+
+```bash
+git clone git@github.com:google-research/tf-slim.git
+cd tf-slim
+pip3 install .
+```
+
+Note that these commands assume you are cloning using SSH. If you are using
+HTTPS instead, use `git clone https://github.com/google-research/tf-slim.git`
+instead. See
+[this link](https://help.github.com/en/github/using-git/which-remote-url-should-i-use)
+for more information.
+
+### Protobuf
+
+The DELF library uses [protobuf](https://github.com/google/protobuf) (the python
+version) to configure feature extraction and its format. You will need the
+`protoc` compiler, version >= 3.3. The easiest way to get it is to download
+directly. For Linux, this can be done as (see
+[here](https://github.com/google/protobuf/releases) for other platforms):
+
+```bash
+wget https://github.com/google/protobuf/releases/download/v3.3.0/protoc-3.3.0-linux-x86_64.zip
+unzip protoc-3.3.0-linux-x86_64.zip
+PATH_TO_PROTOC=`pwd`
+```
+
+### Python dependencies
+
+Install python library dependencies:
+
+```bash
+pip3 install matplotlib numpy scikit-image scipy
+sudo apt-get install python3-tk
+```
+
+### `tensorflow/models`
+
+Now, clone `tensorflow/models` and install the required libraries (note that the
+`object_detection` library requires you to add `tensorflow/models/research/` to
+your `PYTHONPATH`, as instructed
+[here](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md)):
+
+```bash
+git clone git@github.com:tensorflow/models.git
+
+# Setup the object_detection module by editing PYTHONPATH.
+cd ..
+# From tensorflow/models/research/
+export PYTHONPATH=$PYTHONPATH:`pwd`
+```
+
+Note that these commands assume you are cloning using SSH. If you are using
+HTTPS instead, use `git clone https://github.com/tensorflow/models.git` instead.
+See
+[this link](https://help.github.com/en/github/using-git/which-remote-url-should-i-use)
+for more information.
+
+Then, compile DELF's protobufs. Use `PATH_TO_PROTOC` as the directory where you
+downloaded the `protoc` compiler.
+
+```bash
+# From tensorflow/models/research/delf/
+${PATH_TO_PROTOC?}/bin/protoc delf/protos/*.proto --python_out=.
+```
+
+Finally, install the DELF package. This may also install some other dependencies
+under the hood.
+
+```bash
+# From tensorflow/models/research/delf/
+pip3 install -e . # Install "delf" package.
+```
+
+At this point, running
+
+```bash
+python3 -c 'import delf'
+```
+
+should just return without complaints. This indicates that the DELF package is
+loaded successfully.
+
+### Troubleshooting
+
+#### `pip3 install`
+
+Issues might be observed if using `pip3 install` with the `-e` option (editable
+mode). You may try simply removing the `-e` from the commands above. Also,
+depending on your machine setup, you might need to prepend `sudo` and run
+`sudo pip3 install` instead.
+
+#### Cloning github repositories
+
+The default commands above assume you are cloning using SSH. If you are using
+HTTPS instead, use for example `git clone
+https://github.com/tensorflow/models.git` instead of `git clone
+git@github.com:tensorflow/models.git`. See
+[this link](https://help.github.com/en/github/using-git/which-remote-url-should-i-use)
+for more information.
diff --git a/models/research/delf/README.md b/models/research/delf/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f10852759c3455ae2990475ea917a4e45ee96264
--- /dev/null
+++ b/models/research/delf/README.md
@@ -0,0 +1,324 @@
+# Deep Local and Global Image Features
+
+[TensorFlow 2.1](https://github.com/tensorflow/tensorflow/releases/tag/v2.1.0)
+[Python 3.6](https://www.python.org/downloads/release/python-360/)
+
+This project presents code for deep local and global image feature methods,
+which are particularly useful for the computer vision tasks of instance-level
+recognition and retrieval. These were introduced in the
+[DELF](https://arxiv.org/abs/1612.06321),
+[Detect-to-Retrieve](https://arxiv.org/abs/1812.01584),
+[DELG](https://arxiv.org/abs/2001.05027) and
+[Google Landmarks Dataset v2](https://arxiv.org/abs/2004.01804) papers.
+
+We provide Tensorflow code for building and training models, and python code for
+image retrieval and local feature matching. Pre-trained models for the landmark
+recognition domain are also provided.
+
+If you make use of this codebase, please consider citing the following papers:
+
+DELF:
+[arXiv paper](https://arxiv.org/abs/1612.06321)
+
+```
+"Large-Scale Image Retrieval with Attentive Deep Local Features",
+H. Noh, A. Araujo, J. Sim, T. Weyand and B. Han,
+Proc. ICCV'17
+```
+
+Detect-to-Retrieve:
+[arXiv paper](https://arxiv.org/abs/1812.01584)
+
+```
+"Detect-to-Retrieve: Efficient Regional Aggregation for Image Search",
+M. Teichmann*, A. Araujo*, M. Zhu and J. Sim,
+Proc. CVPR'19
+```
+
+DELG:
+[arXiv paper](https://arxiv.org/abs/2001.05027)
+
+```
+"Unifying Deep Local and Global Features for Image Search",
+B. Cao*, A. Araujo* and J. Sim,
+arxiv:2001.05027
+```
+
+GLDv2:
+[arXiv paper](https://arxiv.org/abs/2004.01804)
+
+```
+"Google Landmarks Dataset v2 - A Large-Scale Benchmark for Instance-Level Recognition and Retrieval",
+T. Weyand*, A. Araujo*, B. Cao and J. Sim,
+Proc. CVPR'20
+```
+
+## News
+
+- [Apr'20] Check out our CVPR'20 paper: ["Google Landmarks Dataset v2 - A
+ Large-Scale Benchmark for Instance-Level Recognition and
+ Retrieval"](https://arxiv.org/abs/2004.01804)
+- [Jan'20] Check out our new paper:
+ ["Unifying Deep Local and Global Features for Image Search"](https://arxiv.org/abs/2001.05027)
+- [Jun'19] DELF achieved 2nd place in
+ [CVPR Visual Localization challenge (Local Features track)](https://sites.google.com/corp/view/ltvl2019).
+ See our slides
+ [here](https://docs.google.com/presentation/d/e/2PACX-1vTswzoXelqFqI_pCEIVl2uazeyGr7aKNklWHQCX-CbQ7MB17gaycqIaDTguuUCRm6_lXHwCdrkP7n1x/pub?start=false&loop=false&delayms=3000).
+- [Apr'19] Check out our CVPR'19 paper:
+ ["Detect-to-Retrieve: Efficient Regional Aggregation for Image Search"](https://arxiv.org/abs/1812.01584)
+- [Jun'18] DELF achieved state-of-the-art results in a CVPR'18 image retrieval
+ paper: [Radenovic et al., "Revisiting Oxford and Paris: Large-Scale Image
+ Retrieval Benchmarking"](https://arxiv.org/abs/1803.11285).
+- [Apr'18] DELF was featured in
+ [ModelDepot](https://modeldepot.io/mikeshi/delf/overview)
+- [Mar'18] DELF is now available in
+ [TF-Hub](https://www.tensorflow.org/hub/modules/google/delf/1)
+
+## Datasets
+
+We have two Google-Landmarks dataset versions:
+
+- Initial version (v1) can be found
+ [here](https://www.kaggle.com/google/google-landmarks-dataset). It includes
+ the Google Landmark Boxes which were described in the Detect-to-Retrieve
+ paper.
+- Second version (v2) has been released as part of two Kaggle challenges:
+ [Landmark Recognition](https://www.kaggle.com/c/landmark-recognition-2019)
+ and [Landmark Retrieval](https://www.kaggle.com/c/landmark-retrieval-2019).
+ It can be downloaded from CVDF
+ [here](https://github.com/cvdfoundation/google-landmark). See also
+ [the CVPR'20 paper](https://arxiv.org/abs/2004.01804) on this new dataset
+ version.
+
+If you make use of these datasets in your research, please consider citing the
+papers mentioned above.
+
+## Installation
+
+To be able to use this code, please follow
+[these instructions](INSTALL_INSTRUCTIONS.md) to properly install the DELF
+library.
+
+## Quick start
+
+### Pre-trained models
+
+We release several pre-trained models. See instructions in the following
+sections for examples on how to use the models.
+
+**DELF pre-trained on the Google-Landmarks dataset v1**
+([link](http://storage.googleapis.com/delf/delf_gld_20190411.tar.gz)). Presented
+in the [Detect-to-Retrieve paper](https://arxiv.org/abs/1812.01584). Boosts
+performance by ~4% mAP compared to ICCV'17 DELF model.
+
+**DELG pre-trained on the Google-Landmarks dataset v1**
+([link](http://storage.googleapis.com/delf/delg_gld_20200520.tar.gz)). Presented
+in the [DELG paper](https://arxiv.org/abs/2001.05027).
+
+**RN101-ArcFace pre-trained on the Google-Landmarks dataset v2 (train-clean)**
+([link](https://storage.googleapis.com/delf/rn101_af_gldv2clean_20200521.tar.gz)).
+Presented in the [GLDv2 paper](https://arxiv.org/abs/2004.01804).
+
+**DELF pre-trained on Landmarks-Clean/Landmarks-Full dataset**
+([link](http://storage.googleapis.com/delf/delf_v1_20171026.tar.gz)). Presented
+in the [DELF paper](https://arxiv.org/abs/1612.06321), model was trained on the
+dataset released by the [DIR paper](https://arxiv.org/abs/1604.01325).
+
+**Faster-RCNN detector pre-trained on Google Landmark Boxes**
+([link](http://storage.googleapis.com/delf/d2r_frcnn_20190411.tar.gz)).
+Presented in the [Detect-to-Retrieve paper](https://arxiv.org/abs/1812.01584).
+
+**MobileNet-SSD detector pre-trained on Google Landmark Boxes**
+([link](http://storage.googleapis.com/delf/d2r_mnetssd_20190411.tar.gz)).
+Presented in the [Detect-to-Retrieve paper](https://arxiv.org/abs/1812.01584).
+
+Besides these, we also release pre-trained codebooks for local feature
+aggregation. See the
+[Detect-to-Retrieve instructions](delf/python/detect_to_retrieve/DETECT_TO_RETRIEVE_INSTRUCTIONS.md)
+for details.
+
+### DELF extraction and matching
+
+Please follow [these instructions](EXTRACTION_MATCHING.md). At the end, you
+should obtain a nice figure showing local feature matches, as:
+
+
+
+### DELF training
+
+Please follow [these instructions](delf/python/training/README.md).
+
+### DELG
+
+Please follow [these instructions](delf/python/delg/DELG_INSTRUCTIONS.md). At
+the end, you should obtain image retrieval results on the Revisited Oxford/Paris
+datasets.
+
+### GLDv2 baseline
+
+Please follow
+[these instructions](delf/python/google_landmarks_dataset/README.md). At the
+end, you should obtain image retrieval results on the Revisited Oxford/Paris
+datasets.
+
+### Landmark detection
+
+Please follow [these instructions](DETECTION.md). At the end, you should obtain
+a nice figure showing a detection, as:
+
+
+
+### Detect-to-Retrieve
+
+Please follow
+[these instructions](delf/python/detect_to_retrieve/DETECT_TO_RETRIEVE_INSTRUCTIONS.md).
+At the end, you should obtain image retrieval results on the Revisited
+Oxford/Paris datasets.
+
+## Code overview
+
+DELF/D2R/DELG/GLD code is located under the `delf` directory. There are two
+directories therein, `protos` and `python`.
+
+### `delf/protos`
+
+This directory contains protobufs:
+
+- `aggregation_config.proto`: protobuf for configuring local feature
+ aggregation.
+- `box.proto`: protobuf for serializing detected boxes.
+- `datum.proto`: general-purpose protobuf for serializing float tensors.
+- `delf_config.proto`: protobuf for configuring DELF/DELG extraction.
+- `feature.proto`: protobuf for serializing DELF features.
+
+### `delf/python`
+
+This directory contains files for several different purposes:
+
+- `box_io.py`, `datum_io.py`, `feature_io.py` are helper files for reading and
+ writing tensors and features.
+- `delf_v1.py` contains code to create DELF models.
+- `feature_aggregation_extractor.py` contains a module to perform local
+ feature aggregation.
+- `feature_aggregation_similarity.py` contains a module to perform similarity
+ computation for aggregated local features.
+- `feature_extractor.py` contains the code to extract features using DELF.
+ This is particularly useful for extracting features over multiple scales,
+ with keypoint selection based on attention scores, and PCA/whitening
+ post-processing.
+
+The subdirectory `delf/python/examples` contains sample scripts to run DELF
+feature extraction/matching, and object detection:
+
+- `delf_config_example.pbtxt` shows an example instantiation of the DelfConfig
+ proto, used for DELF feature extraction.
+- `detector.py` is a module to construct an object detector function.
+- `extract_boxes.py` enables object detection from a list of images.
+- `extract_features.py` enables DELF extraction from a list of images.
+- `extractor.py` is a module to construct a DELF/DELG local feature extraction
+ function.
+- `match_images.py` supports image matching using DELF features extracted
+ using `extract_features.py`.
+
+The subdirectory `delf/python/delg` contains sample scripts/configs related to
+the DELG paper:
+
+- `delg_gld_config.pbtxt` gives the DelfConfig used in DELG paper.
+- `extract_features.py` for local+global feature extraction on Revisited
+ datasets.
+- `perform_retrieval.py` for performing retrieval/evaluating methods on
+ Revisited datasets.
+
+The subdirectory `delf/python/detect_to_retrieve` contains sample
+scripts/configs related to the Detect-to-Retrieve paper:
+
+- `aggregation_extraction.py` is a library to extract/save feature
+ aggregation.
+- `boxes_and_features_extraction.py` is a library to extract/save boxes and
+ DELF features.
+- `cluster_delf_features.py` for local feature clustering.
+- `dataset.py` for parsing/evaluating results on Revisited Oxford/Paris
+ datasets.
+- `delf_gld_config.pbtxt` gives the DelfConfig used in Detect-to-Retrieve
+ paper.
+- `extract_aggregation.py` for aggregated local feature extraction.
+- `extract_index_boxes_and_features.py` for index image local feature
+ extraction / bounding box detection on Revisited datasets.
+- `extract_query_features.py` for query image local feature extraction on
+ Revisited datasets.
+- `image_reranking.py` is a module to re-rank images with geometric
+ verification.
+- `perform_retrieval.py` for performing retrieval/evaluating methods using
+ aggregated local features on Revisited datasets.
+- `index_aggregation_config.pbtxt`, `query_aggregation_config.pbtxt` give
+ AggregationConfig's for Detect-to-Retrieve experiments.
+
+The subdirectory `delf/python/google_landmarks_dataset` contains sample
+scripts/modules for computing GLD metrics / reproducing results from the GLDv2
+paper:
+
+- `compute_recognition_metrics.py` performs recognition metric computation
+ given input predictions and solution files.
+- `compute_retrieval_metrics.py` performs retrieval metric computation given
+ input predictions and solution files.
+- `dataset_file_io.py` is a module for dataset-related file IO.
+- `metrics.py` is a module for GLD metric computation.
+- `rn101_af_gldv2clean_config.pbtxt` gives the DelfConfig used in the
+ ResNet101-ArcFace (trained on GLDv2-train-clean) baseline used in the GLDv2
+ paper.
+
+The subdirectory `delf/python/training` contains sample scripts/modules for
+performing DELF training:
+
+- `datasets/googlelandmarks.py` is the dataset module used for training.
+- `model/delf_model.py` is the model module used for training.
+- `model/export_model.py` is a script for exporting trained models in the
+ format used by the inference code.
+- `model/export_model_utils.py` is a module with utilities for model
+ exporting.
+- `model/resnet50.py` is a module with a backbone RN50 implementation.
+- `build_image_dataset.py` converts downloaded dataset into TFRecords format
+ for training.
+- `train.py` is the main training script.
+
+Besides these, other files in the different subdirectories contain tests for the
+various modules.
+
+## Maintainers
+
+André Araujo (@andrefaraujo)
+
+## Release history
+
+### May, 2020
+
+- Codebase is now Python3-first
+- DELG model/code released
+- GLDv2 baseline model released
+
+**Thanks to contributors**: Barbara Fusinska and André Araujo.
+
+### April, 2020 (version 2.0)
+
+- Initial DELF training code released.
+- Codebase is now fully compatible with TF 2.1.
+
+**Thanks to contributors**: Arun Mukundan, Yuewei Na and André Araujo.
+
+### April, 2019
+
+Detect-to-Retrieve code released.
+
+Includes pre-trained models to detect landmark boxes, and DELF model pre-trained
+on Google Landmarks v1 dataset.
+
+**Thanks to contributors**: André Araujo, Marvin Teichmann, Menglong Zhu,
+Jack Sim.
+
+### October, 2017
+
+Initial release containing DELF-v1 code, including feature extraction and
+matching examples. Pre-trained DELF model from ICCV'17 paper is released.
+
+**Thanks to contributors**: André Araujo, Hyeonwoo Noh, Youlong Cheng,
+Jack Sim.
diff --git a/models/research/delf/delf/__init__.py b/models/research/delf/delf/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a52df3c4546414e61f479357d06b65d4c132c753
--- /dev/null
+++ b/models/research/delf/delf/__init__.py
@@ -0,0 +1,39 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Module to extract deep local features."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import
+from delf.protos import aggregation_config_pb2
+from delf.protos import box_pb2
+from delf.protos import datum_pb2
+from delf.protos import delf_config_pb2
+from delf.protos import feature_pb2
+from delf.python import box_io
+from delf.python import datum_io
+from delf.python import feature_aggregation_extractor
+from delf.python import feature_aggregation_similarity
+from delf.python import feature_extractor
+from delf.python import feature_io
+from delf.python import utils
+from delf.python.examples import detector
+from delf.python.examples import extractor
+from delf.python import detect_to_retrieve
+from delf.python import training
+from delf.python.training import model
+from delf.python.training import datasets
+# pylint: enable=unused-import
diff --git a/models/research/delf/delf/protos/__init__.py b/models/research/delf/delf/protos/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/delf/delf/protos/aggregation_config.proto b/models/research/delf/delf/protos/aggregation_config.proto
new file mode 100644
index 0000000000000000000000000000000000000000..b1d5953d43ffc84f435c57be7145c7c17ae01186
--- /dev/null
+++ b/models/research/delf/delf/protos/aggregation_config.proto
@@ -0,0 +1,63 @@
+// Protocol buffer for feature aggregation configuration.
+//
+// Used for both extraction and comparison of aggregated representations. Note
+// that some options are only relevant for the former or the latter.
+//
+// For more details, please refer to the paper:
+// "Detect-to-Retrieve: Efficient Regional Aggregation for Image Search",
+// Proc. CVPR'19 (https://arxiv.org/abs/1812.01584).
+
+syntax = "proto2";
+
+package delf.protos;
+
+message AggregationConfig {
+ // Number of codewords (ie, visual words) in the codebook.
+ optional int32 codebook_size = 1 [default = 65536];
+
+ // Dimensionality of local features (eg, 128 for DELF used in
+ // Detect-to-Retrieve paper).
+ optional int32 feature_dimensionality = 2 [default = 128];
+
+ // Type of aggregation to use.
+ // For example, to use R-ASMK*, `aggregation_type` should be set to ASMK_STAR
+ // and `use_regional_aggregation` should be set to true.
+ enum AggregationType {
+ INVALID = 0;
+ VLAD = 1;
+ ASMK = 2;
+ ASMK_STAR = 3;
+ }
+ optional AggregationType aggregation_type = 3 [default = ASMK_STAR];
+
+ // L2 normalization option.
+ // - For vanilla aggregated kernels (eg, VLAD/ASMK/ASMK*), this should be
+ // set to true.
+ // - For regional aggregated kernels (ie, if `use_regional_aggregation` is
+ // true, leading to R-VLAD/R-ASMK/R-ASMK*), this should be set to false.
+ // Note that it is used differently depending on the `aggregation_type`:
+ // - For VLAD, this option is only used for extraction.
+ // - For ASMK/ASMK*, this option is only used for comparisons.
+ optional bool use_l2_normalization = 4 [default = true];
+
+ // Additional options used only for extraction.
+ // - Path to codebook checkpoint for aggregation.
+ optional string codebook_path = 5;
+ // - Number of visual words to assign each feature.
+ optional int32 num_assignments = 6 [default = 1];
+ // - Whether to use regional aggregation.
+ optional bool use_regional_aggregation = 7 [default = false];
+ // - Batch size to use for local features when computing aggregated
+ // representations. Particularly useful if `codebook_size` and
+ // `feature_dimensionality` are large, to avoid OOM. A value of zero or
+ // lower indicates that no batching is used.
+ optional int32 feature_batch_size = 10 [default = 100];
+
+ // Additional options used only for comparison.
+ // Only relevant if `aggregation_type` is ASMK or ASMK_STAR.
+ // - Power-law exponent for similarity of visual word descriptors.
+ optional float alpha = 8 [default = 3.0];
+ // - Threshold above which similarity of visual word descriptors are
+ // considered; below this, similarity is set to zero.
+ optional float tau = 9 [default = 0.0];
+}
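+
+// Example text-format AggregationConfig (a minimal sketch; the values and the
+// codebook path are illustrative only, not part of any released config):
+//
+//   codebook_size: 65536
+//   feature_dimensionality: 128
+//   aggregation_type: ASMK_STAR
+//   use_l2_normalization: false
+//   codebook_path: "/path/to/codebook"
+//   num_assignments: 1
+//   use_regional_aggregation: true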
diff --git a/models/research/delf/delf/protos/box.proto b/models/research/delf/delf/protos/box.proto
new file mode 100644
index 0000000000000000000000000000000000000000..28da7fb71410262f9be98e206e756c82ba3beb38
--- /dev/null
+++ b/models/research/delf/delf/protos/box.proto
@@ -0,0 +1,24 @@
+// Protocol buffer for serializing detected bounding boxes.
+
+syntax = "proto2";
+
+package delf.protos;
+
+message Box {
+ // Coordinates: [ymin, xmin, ymax, xmax] corresponds to
+ // [top, left, bottom, right].
+ optional float ymin = 1;
+ optional float xmin = 2;
+ optional float ymax = 3;
+ optional float xmax = 4;
+
+ // Detection score. Usually, the higher the more confident.
+ optional float score = 5;
+
+ // Indicates which class the box corresponds to.
+ optional int32 class_index = 6;
+}
+
+message Boxes {
+ repeated Box box = 1;
+}
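+
+// Example text-format Boxes message (a minimal sketch; values are
+// illustrative only):
+//
+//   box {
+//     ymin: 0.1
+//     xmin: 0.2
+//     ymax: 0.8
+//     xmax: 0.9
+//     score: 0.75
+//     class_index: 3
+//   }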
diff --git a/models/research/delf/delf/protos/datum.proto b/models/research/delf/delf/protos/datum.proto
new file mode 100644
index 0000000000000000000000000000000000000000..6806e56b25e912bfb6a87280432c8566dba0c41a
--- /dev/null
+++ b/models/research/delf/delf/protos/datum.proto
@@ -0,0 +1,66 @@
+// Protocol buffer for serializing arbitrary tensors.
+// Note: currently, floating point and uint32 values are supported.
+
+syntax = "proto2";
+
+package delf.protos;
+
+// A DatumProto is a data structure used to serialize a tensor of arbitrary
+// shape. It contains an array of values and its shape, represented as a
+// sequence of integer values. Values are stored in row-major order.
+//
+// Example:
+// 3 x 2 array
+//
+// [1.1, 2.2]
+// [3.3, 4.4]
+// [5.5, 6.6]
+//
+// can be represented with the following DatumProto:
+//
+// DatumProto {
+// shape {
+// dim: 3
+// dim: 2
+// }
+// float_list {
+// value: 1.1
+// value: 2.2
+// value: 3.3
+// value: 4.4
+// value: 5.5
+// value: 6.6
+// }
+// }
+
+// DatumShape is the array of dimensions of the tensor.
+message DatumShape {
+ repeated int64 dim = 1 [packed = true];
+}
+
+// FloatList is a container of tensor values, which are saved as a list of
+// floating point values.
+message FloatList {
+ repeated float value = 1 [packed = true];
+}
+
+// Uint32List is a container of tensor values, which are saved as a list of
+// uint32 values.
+message Uint32List {
+ repeated uint32 value = 1 [packed = true];
+}
+
+message DatumProto {
+ optional DatumShape shape = 1;
+ oneof kind_oneof {
+ FloatList float_list = 2;
+ Uint32List uint32_list = 3;
+ }
+}
+
+// Groups two DatumProto's.
+message DatumPairProto {
+ optional DatumProto first = 1;
+ optional DatumProto second = 2;
+}
diff --git a/models/research/delf/delf/protos/delf_config.proto b/models/research/delf/delf/protos/delf_config.proto
new file mode 100644
index 0000000000000000000000000000000000000000..10ae0a614cbdd483f08f1f9f806a9d3adbe6b46d
--- /dev/null
+++ b/models/research/delf/delf/protos/delf_config.proto
@@ -0,0 +1,121 @@
+// Protocol buffer for configuring DELF feature extraction.
+
+syntax = "proto2";
+
+package delf.protos;
+
+message DelfPcaParameters {
+ // Path to PCA mean file.
+ optional string mean_path = 1; // Required.
+
+ // Path to PCA matrix file.
+ optional string projection_matrix_path = 2; // Required.
+
+ // Dimensionality of feature after PCA.
+ optional int32 pca_dim = 3; // Required.
+
+ // If whitening is to be used, this must be set to true.
+ optional bool use_whitening = 4 [default = false];
+
+ // Path to PCA variances file, used for whitening. This is used only if
+ // use_whitening is set to true.
+ optional string pca_variances_path = 5;
+}
+
+message DelfLocalFeatureConfig {
+ // If PCA is to be used, this must be set to true.
+ optional bool use_pca = 1 [default = true];
+
+ // Target layer name for DELF model. This is used to obtain receptive field
+ // parameters used for localizing features with respect to the input image.
+ optional string layer_name = 2 [default = ""];
+
+ // Intersection over union threshold for the non-max suppression (NMS)
+ // operation. If two features overlap by at most this amount, both are kept.
+ // Otherwise, the one with largest attention score is kept. This should be a
+ // number between 0.0 (no region is selected) and 1.0 (all regions are
+ // selected and NMS is not performed).
+ optional float iou_threshold = 3 [default = 1.0];
+
+ // Maximum number of features that will be selected. The features with largest
+ // scores (eg, largest attention score if score_type is "Att") are the
+ // selected ones.
+ optional int32 max_feature_num = 4 [default = 1000];
+
+ // Threshold to be used for feature selection: no feature with score lower
+ // than this number will be selected).
+ optional float score_threshold = 5 [default = 100.0];
+
+ // PCA parameters for DELF local feature. This is used only if use_pca is
+ // true.
+ optional DelfPcaParameters pca_parameters = 6;
+}
+
+message DelfGlobalFeatureConfig {
+ // If PCA is to be used, this must be set to true.
+ optional bool use_pca = 1 [default = true];
+
+ // PCA parameters for DELF global feature. This is used only if use_pca is
+ // true.
+ optional DelfPcaParameters pca_parameters = 2;
+
+ // Denotes indices of DelfConfig's scales that will be used for global
+ // descriptor extraction. For example, if DelfConfig's image_scales are
+ // [0.25, 0.5, 1.0] and image_scales_ind is [0, 2], global descriptor
+ // extraction will use solely scales [0.25, 1.0]. Note that local feature
+  // extraction will still use [0.25, 0.5, 1.0] in this case. If empty
+  // (default), all scales are used.
+ repeated int32 image_scales_ind = 3;
+}
+
+message DelfConfig {
+ // Whether to extract local features when using the model.
+ // At least one of {use_local_features, use_global_features} must be true.
+ optional bool use_local_features = 7 [default = true];
+ // Configuration used for local features. Note: this is used only if
+ // use_local_features is true.
+ optional DelfLocalFeatureConfig delf_local_config = 3;
+
+ // Whether to extract global features when using the model.
+ // At least one of {use_local_features, use_global_features} must be true.
+ optional bool use_global_features = 8 [default = false];
+ // Configuration used for global features. Note: this is used only if
+ // use_global_features is true.
+ optional DelfGlobalFeatureConfig delf_global_config = 9;
+
+ // Path to DELF model.
+ optional string model_path = 1; // Required.
+
+ // Image scales to be used.
+ repeated float image_scales = 2;
+
+ // Image resizing options.
+ // - The maximum/minimum image size (in terms of height or width) to be used
+ // when extracting DELF features. If set to -1 (default), no upper/lower
+ // bound for image size. If use_square_images option is false (default):
+ // * If the height *OR* width is larger than max_image_size, it will be
+ // resized to max_image_size, and the other dimension will be resized by
+ // preserving the aspect ratio.
+ // * If both height *AND* width are smaller than min_image_size, the larger
+ // side is set to min_image_size.
+ // - If use_square_images option is true, it needs to be resized to square
+ // resolution. To be more specific:
+ // * If the height *OR* width is larger than max_image_size, it is resized
+ // to square resolution of max_image_size.
+ // * If both height *AND* width are smaller than min_image_size, it is
+ // resized to square resolution of min_image_size.
+ // * Else, if the input image's resolution is not square, it is resized to
+ // square resolution of the larger side.
+  // Image resizing is useful to ensure that the input to the image pyramid has
+  // a reasonable number of pixels, which can have a large impact on image
+  // matching performance.
+ // When using local features, note that the feature locations and scales will
+ // be consistent with the original image input size.
+  // Note that when both max_image_size and min_image_size are specified (a
+  // valid use case), as long as max_image_size >= min_image_size there is no
+  // conflicting scenario (i.e., enlarging and shrinking are never triggered at
+  // the same time). Bilinear interpolation is used.
+ optional int32 max_image_size = 4 [default = -1];
+ optional int32 min_image_size = 5 [default = -1];
+ optional bool use_square_images = 6 [default = false];
+}
diff --git a/models/research/delf/delf/protos/feature.proto b/models/research/delf/delf/protos/feature.proto
new file mode 100644
index 0000000000000000000000000000000000000000..64c342fe2c36b9170de10628b8ddc83ee3cfb2c6
--- /dev/null
+++ b/models/research/delf/delf/protos/feature.proto
@@ -0,0 +1,22 @@
+// Protocol buffer for serializing the DELF feature information.
+
+syntax = "proto2";
+
+package delf.protos;
+
+import "delf/protos/datum.proto";
+
+// DelfFeature stores a single DELF local feature: its descriptor (as a
+// DatumProto) together with its location, scale, orientation and strength.
+message DelfFeature {
+ optional DatumProto descriptor = 1;
+ optional float x = 2;
+ optional float y = 3;
+ optional float scale = 4;
+ optional float orientation = 5;
+ optional float strength = 6;
+}
+
+message DelfFeatures {
+ repeated DelfFeature feature = 1;
+}
diff --git a/models/research/delf/delf/python/__init__.py b/models/research/delf/delf/python/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/delf/delf/python/box_io.py b/models/research/delf/delf/python/box_io.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b0f0d2c973d5b83f9110f651f5c5541fad049b7
--- /dev/null
+++ b/models/research/delf/delf/python/box_io.py
@@ -0,0 +1,151 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python interface for Boxes proto.
+
+Support read and write of Boxes from/to numpy arrays and file.
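+
+Minimal usage sketch (file path and values are illustrative only):
+
+  import numpy as np
+  from delf import box_io
+
+  boxes = np.array([[0.1, 0.2, 0.8, 0.9]])   # [ymin, xmin, ymax, xmax]
+  scores = np.array([0.75])
+  class_indices = np.array([3])
+  box_io.WriteToFile('/tmp/example.boxes', boxes, scores, class_indices)
+  boxes_read, scores_read, class_indices_read = box_io.ReadFromFile(
+      '/tmp/example.boxes')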
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from delf import box_pb2
+
+
+def ArraysToBoxes(boxes, scores, class_indices):
+ """Converts `boxes` to Boxes proto.
+
+ Args:
+ boxes: [N, 4] float array denoting bounding box coordinates, in format [top,
+ left, bottom, right].
+ scores: [N] float array with detection scores.
+ class_indices: [N] int array with class indices.
+
+ Returns:
+ boxes_proto: Boxes object.
+ """
+ num_boxes = len(scores)
+ assert num_boxes == boxes.shape[0]
+ assert num_boxes == len(class_indices)
+
+ boxes_proto = box_pb2.Boxes()
+ for i in range(num_boxes):
+ boxes_proto.box.add(
+ ymin=boxes[i, 0],
+ xmin=boxes[i, 1],
+ ymax=boxes[i, 2],
+ xmax=boxes[i, 3],
+ score=scores[i],
+ class_index=class_indices[i])
+
+ return boxes_proto
+
+
+def BoxesToArrays(boxes_proto):
+ """Converts data saved in Boxes proto to numpy arrays.
+
+ If there are no boxes, the function returns three empty arrays.
+
+ Args:
+ boxes_proto: Boxes proto object.
+
+ Returns:
+ boxes: [N, 4] float array denoting bounding box coordinates, in format [top,
+ left, bottom, right].
+ scores: [N] float array with detection scores.
+ class_indices: [N] int array with class indices.
+ """
+ num_boxes = len(boxes_proto.box)
+ if num_boxes == 0:
+ return np.array([]), np.array([]), np.array([])
+
+ boxes = np.zeros([num_boxes, 4])
+ scores = np.zeros([num_boxes])
+ class_indices = np.zeros([num_boxes])
+
+ for i in range(num_boxes):
+ box_proto = boxes_proto.box[i]
+ boxes[i] = [box_proto.ymin, box_proto.xmin, box_proto.ymax, box_proto.xmax]
+ scores[i] = box_proto.score
+ class_indices[i] = box_proto.class_index
+
+ return boxes, scores, class_indices
+
+
+def SerializeToString(boxes, scores, class_indices):
+ """Converts numpy arrays to serialized Boxes.
+
+ Args:
+ boxes: [N, 4] float array denoting bounding box coordinates, in format [top,
+ left, bottom, right].
+ scores: [N] float array with detection scores.
+ class_indices: [N] int array with class indices.
+
+ Returns:
+ Serialized Boxes string.
+ """
+ boxes_proto = ArraysToBoxes(boxes, scores, class_indices)
+ return boxes_proto.SerializeToString()
+
+
+def ParseFromString(string):
+ """Converts serialized Boxes proto string to numpy arrays.
+
+ Args:
+ string: Serialized Boxes string.
+
+ Returns:
+ boxes: [N, 4] float array denoting bounding box coordinates, in format [top,
+ left, bottom, right].
+ scores: [N] float array with detection scores.
+ class_indices: [N] int array with class indices.
+ """
+ boxes_proto = box_pb2.Boxes()
+ boxes_proto.ParseFromString(string)
+ return BoxesToArrays(boxes_proto)
+
+
+def ReadFromFile(file_path):
+ """Helper function to load data from a Boxes proto format in a file.
+
+ Args:
+ file_path: Path to file containing data.
+
+ Returns:
+ boxes: [N, 4] float array denoting bounding box coordinates, in format [top,
+ left, bottom, right].
+ scores: [N] float array with detection scores.
+ class_indices: [N] int array with class indices.
+ """
+ with tf.io.gfile.GFile(file_path, 'rb') as f:
+ return ParseFromString(f.read())
+
+
+def WriteToFile(file_path, boxes, scores, class_indices):
+ """Helper function to write data to a file in Boxes proto format.
+
+ Args:
+ file_path: Path to file that will be written.
+ boxes: [N, 4] float array denoting bounding box coordinates, in format [top,
+ left, bottom, right].
+ scores: [N] float array with detection scores.
+ class_indices: [N] int array with class indices.
+ """
+ serialized_data = SerializeToString(boxes, scores, class_indices)
+ with tf.io.gfile.GFile(file_path, 'w') as f:
+ f.write(serialized_data)
diff --git a/models/research/delf/delf/python/box_io_test.py b/models/research/delf/delf/python/box_io_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c659185daeec7efc9097ff72637d0d8f7c38664b
--- /dev/null
+++ b/models/research/delf/delf/python/box_io_test.py
@@ -0,0 +1,82 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for box_io, the python interface of Boxes proto."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import flags
+import numpy as np
+import tensorflow as tf
+
+from delf import box_io
+
+FLAGS = flags.FLAGS
+
+
+class BoxesIoTest(tf.test.TestCase):
+
+ def _create_data(self):
+ """Creates data to be used in tests.
+
+ Returns:
+      boxes: [N, 4] float array denoting bounding box coordinates, in format
+        [top, left, bottom, right].
+ scores: [N] float array with detection scores.
+ class_indices: [N] int array with class indices.
+ """
+ boxes = np.arange(24, dtype=np.float32).reshape(6, 4)
+ scores = np.arange(6, dtype=np.float32)
+ class_indices = np.arange(6, dtype=np.int32)
+
+ return boxes, scores, class_indices
+
+ def testConversionAndBack(self):
+ boxes, scores, class_indices = self._create_data()
+
+ serialized = box_io.SerializeToString(boxes, scores, class_indices)
+ parsed_data = box_io.ParseFromString(serialized)
+
+ self.assertAllEqual(boxes, parsed_data[0])
+ self.assertAllEqual(scores, parsed_data[1])
+ self.assertAllEqual(class_indices, parsed_data[2])
+
+ def testWriteAndReadToFile(self):
+ boxes, scores, class_indices = self._create_data()
+
+ filename = os.path.join(FLAGS.test_tmpdir, 'test.boxes')
+ box_io.WriteToFile(filename, boxes, scores, class_indices)
+ data_read = box_io.ReadFromFile(filename)
+
+ self.assertAllEqual(boxes, data_read[0])
+ self.assertAllEqual(scores, data_read[1])
+ self.assertAllEqual(class_indices, data_read[2])
+
+ def testWriteAndReadToFileEmptyFile(self):
+ filename = os.path.join(FLAGS.test_tmpdir, 'test.box')
+ box_io.WriteToFile(filename, np.array([]), np.array([]), np.array([]))
+ data_read = box_io.ReadFromFile(filename)
+
+ self.assertAllEqual(np.array([]), data_read[0])
+ self.assertAllEqual(np.array([]), data_read[1])
+ self.assertAllEqual(np.array([]), data_read[2])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/delf/delf/python/datum_io.py b/models/research/delf/delf/python/datum_io.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0d4cbfd11a140c6805c1fa017b7328cf3d04e38
--- /dev/null
+++ b/models/research/delf/delf/python/datum_io.py
@@ -0,0 +1,221 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python interface for DatumProto.
+
+DatumProto is a protocol buffer used to serialize tensors of arbitrary shape.
+Please refer to datum.proto for details.
+
+Support read and write of DatumProto from/to NumPy array and file.
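+
+Minimal usage sketch (file path and values are illustrative only):
+
+  import numpy as np
+  from delf import datum_io
+
+  data = np.arange(6, dtype='float32').reshape(3, 2)
+  datum_io.WriteToFile(data, '/tmp/example.datum')
+  data_read = datum_io.ReadFromFile('/tmp/example.datum')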
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from delf import datum_pb2
+
+
+def ArrayToDatum(arr):
+ """Converts NumPy array to DatumProto.
+
+ Supports arrays of types:
+ - float16 (it is converted into a float32 in DatumProto)
+ - float32
+ - float64 (it is converted into a float32 in DatumProto)
+ - uint8 (it is converted into a uint32 in DatumProto)
+ - uint16 (it is converted into a uint32 in DatumProto)
+ - uint32
+ - uint64 (it is converted into a uint32 in DatumProto)
+
+ Args:
+ arr: NumPy array of arbitrary shape.
+
+ Returns:
+ datum: DatumProto object.
+
+ Raises:
+ ValueError: If array type is unsupported.
+ """
+ datum = datum_pb2.DatumProto()
+ if arr.dtype in ('float16', 'float32', 'float64'):
+ datum.float_list.value.extend(arr.astype('float32').flat)
+ elif arr.dtype in ('uint8', 'uint16', 'uint32', 'uint64'):
+ datum.uint32_list.value.extend(arr.astype('uint32').flat)
+ else:
+ raise ValueError('Unsupported array type: %s' % arr.dtype)
+
+ datum.shape.dim.extend(arr.shape)
+ return datum
+
+
+def ArraysToDatumPair(arr_1, arr_2):
+ """Converts numpy arrays to DatumPairProto.
+
+ Supports same formats as `ArrayToDatum`, see documentation therein.
+
+ Args:
+ arr_1: NumPy array of arbitrary shape.
+ arr_2: NumPy array of arbitrary shape.
+
+ Returns:
+ datum_pair: DatumPairProto object.
+ """
+ datum_pair = datum_pb2.DatumPairProto()
+ datum_pair.first.CopyFrom(ArrayToDatum(arr_1))
+ datum_pair.second.CopyFrom(ArrayToDatum(arr_2))
+
+ return datum_pair
+
+
+def DatumToArray(datum):
+ """Converts data saved in DatumProto to NumPy array.
+
+ Args:
+ datum: DatumProto object.
+
+ Returns:
+ NumPy array of arbitrary shape.
+ """
+ if datum.HasField('float_list'):
+ return np.array(datum.float_list.value).astype('float32').reshape(
+ datum.shape.dim)
+ elif datum.HasField('uint32_list'):
+ return np.array(datum.uint32_list.value).astype('uint32').reshape(
+ datum.shape.dim)
+ else:
+ raise ValueError('Input DatumProto does not have float_list or uint32_list')
+
+
+def DatumPairToArrays(datum_pair):
+ """Converts data saved in DatumPairProto to NumPy arrays.
+
+ Args:
+ datum_pair: DatumPairProto object.
+
+ Returns:
+ Two NumPy arrays of arbitrary shape.
+ """
+ first_datum = DatumToArray(datum_pair.first)
+ second_datum = DatumToArray(datum_pair.second)
+ return first_datum, second_datum
+
+
+def SerializeToString(arr):
+ """Converts NumPy array to serialized DatumProto.
+
+ Args:
+ arr: NumPy array of arbitrary shape.
+
+ Returns:
+ Serialized DatumProto string.
+ """
+ datum = ArrayToDatum(arr)
+ return datum.SerializeToString()
+
+
+def SerializePairToString(arr_1, arr_2):
+ """Converts pair of NumPy arrays to serialized DatumPairProto.
+
+ Args:
+ arr_1: NumPy array of arbitrary shape.
+ arr_2: NumPy array of arbitrary shape.
+
+ Returns:
+ Serialized DatumPairProto string.
+ """
+ datum_pair = ArraysToDatumPair(arr_1, arr_2)
+ return datum_pair.SerializeToString()
+
+
+def ParseFromString(string):
+ """Converts serialized DatumProto string to NumPy array.
+
+ Args:
+ string: Serialized DatumProto string.
+
+ Returns:
+ NumPy array.
+ """
+ datum = datum_pb2.DatumProto()
+ datum.ParseFromString(string)
+ return DatumToArray(datum)
+
+
+def ParsePairFromString(string):
+ """Converts serialized DatumPairProto string to NumPy arrays.
+
+ Args:
+ string: Serialized DatumProto string.
+
+ Returns:
+ Two NumPy arrays.
+ """
+ datum_pair = datum_pb2.DatumPairProto()
+ datum_pair.ParseFromString(string)
+ return DatumPairToArrays(datum_pair)
+
+
+def ReadFromFile(file_path):
+ """Helper function to load data from a DatumProto format in a file.
+
+ Args:
+ file_path: Path to file containing data.
+
+ Returns:
+ data: NumPy array.
+ """
+ with tf.io.gfile.GFile(file_path, 'rb') as f:
+ return ParseFromString(f.read())
+
+
+def ReadPairFromFile(file_path):
+ """Helper function to load data from a DatumPairProto format in a file.
+
+ Args:
+ file_path: Path to file containing data.
+
+ Returns:
+ Two NumPy arrays.
+ """
+ with tf.io.gfile.GFile(file_path, 'rb') as f:
+ return ParsePairFromString(f.read())
+
+
+def WriteToFile(data, file_path):
+ """Helper function to write data to a file in DatumProto format.
+
+ Args:
+ data: NumPy array.
+ file_path: Path to file that will be written.
+ """
+ serialized_data = SerializeToString(data)
+ with tf.io.gfile.GFile(file_path, 'w') as f:
+ f.write(serialized_data)
+
+
+def WritePairToFile(arr_1, arr_2, file_path):
+ """Helper function to write pair of arrays to a file in DatumPairProto format.
+
+ Args:
+ arr_1: NumPy array of arbitrary shape.
+ arr_2: NumPy array of arbitrary shape.
+ file_path: Path to file that will be written.
+ """
+ serialized_data = SerializePairToString(arr_1, arr_2)
+ with tf.io.gfile.GFile(file_path, 'w') as f:
+ f.write(serialized_data)
diff --git a/models/research/delf/delf/python/datum_io_test.py b/models/research/delf/delf/python/datum_io_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3587a10017f93af49715a9becc1ed72da8ebe69
--- /dev/null
+++ b/models/research/delf/delf/python/datum_io_test.py
@@ -0,0 +1,97 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for datum_io, the python interface of DatumProto."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import flags
+import numpy as np
+import tensorflow as tf
+
+from delf import datum_io
+
+FLAGS = flags.FLAGS
+
+
+class DatumIoTest(tf.test.TestCase):
+
+ def Conversion2dTestWithType(self, dtype):
+ original_data = np.arange(9).reshape(3, 3).astype(dtype)
+ serialized = datum_io.SerializeToString(original_data)
+ retrieved_data = datum_io.ParseFromString(serialized)
+ self.assertTrue(np.array_equal(original_data, retrieved_data))
+
+ def Conversion3dTestWithType(self, dtype):
+ original_data = np.arange(24).reshape(2, 3, 4).astype(dtype)
+ serialized = datum_io.SerializeToString(original_data)
+ retrieved_data = datum_io.ParseFromString(serialized)
+ self.assertTrue(np.array_equal(original_data, retrieved_data))
+
+ # This test covers the following functions: ArrayToDatum, SerializeToString,
+ # ParseFromString, DatumToArray.
+ def testConversion2dWithType(self):
+ self.Conversion2dTestWithType(np.uint16)
+ self.Conversion2dTestWithType(np.uint32)
+ self.Conversion2dTestWithType(np.uint64)
+ self.Conversion2dTestWithType(np.float16)
+ self.Conversion2dTestWithType(np.float32)
+ self.Conversion2dTestWithType(np.float64)
+
+ # This test covers the following functions: ArrayToDatum, SerializeToString,
+ # ParseFromString, DatumToArray.
+ def testConversion3dWithType(self):
+ self.Conversion3dTestWithType(np.uint16)
+ self.Conversion3dTestWithType(np.uint32)
+ self.Conversion3dTestWithType(np.uint64)
+ self.Conversion3dTestWithType(np.float16)
+ self.Conversion3dTestWithType(np.float32)
+ self.Conversion3dTestWithType(np.float64)
+
+ def testConversionWithUnsupportedType(self):
+ with self.assertRaisesRegex(ValueError, 'Unsupported array type'):
+ self.Conversion3dTestWithType(int)
+
+ # This test covers the following functions: ArrayToDatum, SerializeToString,
+ # WriteToFile, ReadFromFile, ParseFromString, DatumToArray.
+ def testWriteAndReadToFile(self):
+ data = np.array([[[-1.0, 125.0, -2.5], [14.5, 3.5, 0.0]],
+ [[20.0, 0.0, 30.0], [25.5, 36.0, 42.0]]])
+ filename = os.path.join(FLAGS.test_tmpdir, 'test.datum')
+ datum_io.WriteToFile(data, filename)
+ data_read = datum_io.ReadFromFile(filename)
+ self.assertAllEqual(data_read, data)
+
+ # This test covers the following functions: ArraysToDatumPair,
+ # SerializePairToString, WritePairToFile, ReadPairFromFile,
+ # ParsePairFromString, DatumPairToArrays.
+ def testWriteAndReadPairToFile(self):
+ data_1 = np.array([[[-1.0, 125.0, -2.5], [14.5, 3.5, 0.0]],
+ [[20.0, 0.0, 30.0], [25.5, 36.0, 42.0]]])
+ data_2 = np.array(
+ [[[255, 0, 5], [10, 300, 0]], [[20, 1, 100], [255, 360, 420]]],
+ dtype='uint32')
+ filename = os.path.join(FLAGS.test_tmpdir, 'test.datum_pair')
+ datum_io.WritePairToFile(data_1, data_2, filename)
+ data_read_1, data_read_2 = datum_io.ReadPairFromFile(filename)
+ self.assertAllEqual(data_read_1, data_1)
+ self.assertAllEqual(data_read_2, data_2)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/delf/delf/python/delg/DELG_INSTRUCTIONS.md b/models/research/delf/delf/python/delg/DELG_INSTRUCTIONS.md
new file mode 100644
index 0000000000000000000000000000000000000000..2b62ac29003e3d4b13a3ccc8fad5b43e236f96da
--- /dev/null
+++ b/models/research/delf/delf/python/delg/DELG_INSTRUCTIONS.md
@@ -0,0 +1,159 @@
+## DELG instructions
+
+[](https://arxiv.org/abs/2001.05027)
+
+These instructions can be used to reproduce the results from the
+[DELG paper](https://arxiv.org/abs/2001.05027) for the Revisited Oxford/Paris
+datasets.
+
+### Install DELF library
+
+To be able to use this code, please follow
+[these instructions](../../../INSTALL_INSTRUCTIONS.md) to properly install the
+DELF library.
+
+### Download datasets
+
+```bash
+mkdir -p ~/delg/data && cd ~/delg/data
+
+# Oxford dataset.
+wget http://www.robots.ox.ac.uk/~vgg/data/oxbuildings/oxbuild_images.tgz
+mkdir oxford5k_images
+tar -xvzf oxbuild_images.tgz -C oxford5k_images/
+
+# Paris dataset. Download and move all images to same directory.
+wget http://www.robots.ox.ac.uk/~vgg/data/parisbuildings/paris_1.tgz
+wget http://www.robots.ox.ac.uk/~vgg/data/parisbuildings/paris_2.tgz
+mkdir paris6k_images_tmp
+tar -xvzf paris_1.tgz -C paris6k_images_tmp/
+tar -xvzf paris_2.tgz -C paris6k_images_tmp/
+mkdir paris6k_images
+mv paris6k_images_tmp/paris/*/*.jpg paris6k_images/
+
+# Revisited annotations.
+wget http://cmp.felk.cvut.cz/revisitop/data/datasets/roxford5k/gnd_roxford5k.mat
+wget http://cmp.felk.cvut.cz/revisitop/data/datasets/rparis6k/gnd_rparis6k.mat
+```
+
+### Download model
+
+This is necessary to reproduce the main paper results:
+
+```bash
+# From models/research/delf/delf/python/delg
+mkdir parameters && cd parameters
+
+# DELG-GLD model.
+wget http://storage.googleapis.com/delf/delg_gld_20200520.tar.gz
+tar -xvzf delg_gld_20200520.tar.gz
+```
+
+### Feature extraction
+
+We present here commands for extraction on `roxford5k`. To extract on `rparis6k`
+instead, please edit the arguments accordingly (especially the
+`dataset_file_path` argument).
+
+#### Query feature extraction
+
+For query feature extraction, the cropped query image should be used to extract
+features, according to the Revisited Oxford/Paris experimental protocol. Note
+that this is done in the `extract_features` script, when setting
+`image_set=query`.
+
+Query feature extraction can be run as follows:
+
+```bash
+# From models/research/delf/delf/python/delg
+python3 extract_features.py \
+ --delf_config_path delg_gld_config.pbtxt \
+ --dataset_file_path ~/delg/data/gnd_roxford5k.mat \
+ --images_dir ~/delg/data/oxford5k_images \
+ --image_set query \
+ --output_features_dir ~/delg/data/oxford5k_features/query
+```
+
+#### Index feature extraction
+
+Run index feature extraction as follows:
+
+```bash
+# From models/research/delf/delf/python/delg
+python3 extract_features.py \
+ --delf_config_path delg_gld_config.pbtxt \
+ --dataset_file_path ~/delg/data/gnd_roxford5k.mat \
+ --images_dir ~/delg/data/oxford5k_images \
+ --image_set index \
+ --output_features_dir ~/delg/data/oxford5k_features/index
+```
+
+### Perform retrieval
+
+To run retrieval on `roxford5k`, the following command can be used:
+
+```bash
+# From models/research/delf/delf/python/delg
+python3 perform_retrieval.py \
+ --dataset_file_path ~/delg/data/gnd_roxford5k.mat \
+ --query_features_dir ~/delg/data/oxford5k_features/query \
+ --index_features_dir ~/delg/data/oxford5k_features/index \
+ --output_dir ~/delg/results/oxford5k
+```
+
+A file named `metrics.txt` will be written to the path given in
+`output_dir`, with retrieval metrics for an experiment where geometric
+verification is not used. The contents should look approximately like:
+
+```
+hard
+ mAP=45.11
+ mP@k[ 1 5 10] [85.71 72.29 60.14]
+ mR@k[ 1 5 10] [19.15 29.72 36.32]
+medium
+ mAP=69.71
+ mP@k[ 1 5 10] [95.71 92. 86.86]
+ mR@k[ 1 5 10] [10.17 25.94 33.83]
+```
+
+which are the results presented in Table 3 of the paper.
+
+If you want to run retrieval with geometric verification, set
+`use_geometric_verification` to `True`. This is much slower since (1) in this
+code example the re-ranking loads DELF local features from disk, and (2)
+re-ranking needs to be performed separately for each dataset protocol, because
+the junk images of each protocol must be removed during re-ranking. Here is an
+example command:
+
+```bash
+# From models/research/delf/delf/python/delg
+python3 perform_retrieval.py \
+ --dataset_file_path ~/delg/data/gnd_roxford5k.mat \
+ --query_features_dir ~/delg/data/oxford5k_features/query \
+ --index_features_dir ~/delg/data/oxford5k_features/index \
+ --use_geometric_verification \
+ --output_dir ~/delg/results/oxford5k_with_gv
+```
+
+The `metrics.txt` should now show:
+
+```
+hard
+ mAP=45.11
+ mP@k[ 1 5 10] [85.71 72.29 60.14]
+ mR@k[ 1 5 10] [19.15 29.72 36.32]
+hard_after_gv
+ mAP=53.72
+ mP@k[ 1 5 10] [91.43 83.81 74.38]
+ mR@k[ 1 5 10] [19.45 34.45 44.64]
+medium
+ mAP=69.71
+ mP@k[ 1 5 10] [95.71 92. 86.86]
+ mR@k[ 1 5 10] [10.17 25.94 33.83]
+medium_after_gv
+ mAP=75.42
+ mP@k[ 1 5 10] [97.14 95.24 93.81]
+ mR@k[ 1 5 10] [10.21 27.21 37.72]
+```
+
+which, again, are the results presented in Table 3 of the paper.
diff --git a/models/research/delf/delf/python/delg/delg_gld_config.pbtxt b/models/research/delf/delf/python/delg/delg_gld_config.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..a659a0a3ee502c31f7d4b71fd634803f94d425b7
--- /dev/null
+++ b/models/research/delf/delf/python/delg/delg_gld_config.pbtxt
@@ -0,0 +1,22 @@
+use_local_features: true
+use_global_features: true
+model_path: "parameters/delg_gld_20200520"
+image_scales: 0.25
+image_scales: 0.35355338
+image_scales: 0.5
+image_scales: 0.70710677
+image_scales: 1.0
+image_scales: 1.4142135
+image_scales: 2.0
+delf_local_config {
+ use_pca: false
+ max_feature_num: 1000
+ score_threshold: 175.0
+}
+delf_global_config {
+ use_pca: false
+ image_scales_ind: 3
+ image_scales_ind: 4
+ image_scales_ind: 5
+}
+max_image_size: 1024
diff --git a/models/research/delf/delf/python/delg/extract_features.py b/models/research/delf/delf/python/delg/extract_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad65d66e69ddaa032d1201b34a2f10a04fe61eb5
--- /dev/null
+++ b/models/research/delf/delf/python/delg/extract_features.py
@@ -0,0 +1,162 @@
+# Copyright 2020 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Extracts DELG features for images from Revisited Oxford/Paris datasets.
+
+Note that query images are cropped before feature extraction, as required by the
+evaluation protocols of these datasets.
+
+The types of extracted features (local and/or global) depend on the input
+DelfConfig.
+
+The program checks if features already exist, and skips computation for those.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+
+from absl import app
+from absl import flags
+import numpy as np
+import tensorflow as tf
+
+from google.protobuf import text_format
+from delf import delf_config_pb2
+from delf import datum_io
+from delf import feature_io
+from delf import utils
+from delf.python.detect_to_retrieve import dataset
+from delf import extractor
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string(
+ 'delf_config_path', '/tmp/delf_config_example.pbtxt',
+ 'Path to DelfConfig proto text file with configuration to be used for DELG '
+ 'extraction. Local features are extracted if use_local_features is True; '
+ 'global features are extracted if use_global_features is True.')
+flags.DEFINE_string(
+ 'dataset_file_path', '/tmp/gnd_roxford5k.mat',
+ 'Dataset file for Revisited Oxford or Paris dataset, in .mat format.')
+flags.DEFINE_string(
+ 'images_dir', '/tmp/images',
+ 'Directory where dataset images are located, all in .jpg format.')
+flags.DEFINE_enum('image_set', 'query', ['query', 'index'],
+ 'Whether to extract features from query or index images.')
+flags.DEFINE_string(
+ 'output_features_dir', '/tmp/features',
+ "Directory where DELG features will be written to. Each image's features "
+ 'will be written to files with same name but different extension: the '
+ 'global feature is written to a file with extension .delg_global and the '
+ 'local features are written to a file with extension .delg_local.')
+
+# Extensions.
+_DELG_GLOBAL_EXTENSION = '.delg_global'
+_DELG_LOCAL_EXTENSION = '.delg_local'
+_IMAGE_EXTENSION = '.jpg'
+
+# Pace to report extraction log.
+_STATUS_CHECK_ITERATIONS = 50
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise RuntimeError('Too many command-line arguments.')
+
+ # Read list of images from dataset file.
+ print('Reading list of images from dataset file...')
+ query_list, index_list, ground_truth = dataset.ReadDatasetFile(
+ FLAGS.dataset_file_path)
+ if FLAGS.image_set == 'query':
+ image_list = query_list
+ else:
+ image_list = index_list
+ num_images = len(image_list)
+ print('done! Found %d images' % num_images)
+
+ # Parse DelfConfig proto.
+ config = delf_config_pb2.DelfConfig()
+ with tf.io.gfile.GFile(FLAGS.delf_config_path, 'r') as f:
+ text_format.Parse(f.read(), config)
+
+ # Create output directory if necessary.
+ if not tf.io.gfile.exists(FLAGS.output_features_dir):
+ tf.io.gfile.makedirs(FLAGS.output_features_dir)
+
+ extractor_fn = extractor.MakeExtractor(config)
+
+ start = time.time()
+ for i in range(num_images):
+ if i == 0:
+ print('Starting to extract features...')
+ elif i % _STATUS_CHECK_ITERATIONS == 0:
+ elapsed = (time.time() - start)
+ print('Processing image %d out of %d, last %d '
+ 'images took %f seconds' %
+ (i, num_images, _STATUS_CHECK_ITERATIONS, elapsed))
+ start = time.time()
+
+ image_name = image_list[i]
+ input_image_filename = os.path.join(FLAGS.images_dir,
+ image_name + _IMAGE_EXTENSION)
+
+ # Compose output file name and decide if image should be skipped.
+ should_skip_global = True
+ should_skip_local = True
+ if config.use_global_features:
+ output_global_feature_filename = os.path.join(
+ FLAGS.output_features_dir, image_name + _DELG_GLOBAL_EXTENSION)
+ if not tf.io.gfile.exists(output_global_feature_filename):
+ should_skip_global = False
+ if config.use_local_features:
+ output_local_feature_filename = os.path.join(
+ FLAGS.output_features_dir, image_name + _DELG_LOCAL_EXTENSION)
+ if not tf.io.gfile.exists(output_local_feature_filename):
+ should_skip_local = False
+ if should_skip_global and should_skip_local:
+ print('Skipping %s' % image_name)
+ continue
+
+ pil_im = utils.RgbLoader(input_image_filename)
+ resize_factor = 1.0
+ if FLAGS.image_set == 'query':
+ # Crop query image according to bounding box.
+ original_image_size = max(pil_im.size)
+ bbox = [int(round(b)) for b in ground_truth[i]['bbx']]
+ pil_im = pil_im.crop(bbox)
+ cropped_image_size = max(pil_im.size)
+ resize_factor = cropped_image_size / original_image_size
+
+ im = np.array(pil_im)
+
+ # Extract and save features.
+ extracted_features = extractor_fn(im, resize_factor)
+ if config.use_global_features:
+ global_descriptor = extracted_features['global_descriptor']
+ datum_io.WriteToFile(global_descriptor, output_global_feature_filename)
+ if config.use_local_features:
+ locations = extracted_features['local_features']['locations']
+ descriptors = extracted_features['local_features']['descriptors']
+ feature_scales = extracted_features['local_features']['scales']
+ attention = extracted_features['local_features']['attention']
+ feature_io.WriteToFile(output_local_feature_filename, locations,
+ feature_scales, descriptors, attention)
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/models/research/delf/delf/python/delg/measure_latency.py b/models/research/delf/delf/python/delg/measure_latency.py
new file mode 100644
index 0000000000000000000000000000000000000000..21ffbda4179a191139ae35244c8ae34693594fd9
--- /dev/null
+++ b/models/research/delf/delf/python/delg/measure_latency.py
@@ -0,0 +1,108 @@
+# Copyright 2020 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Times DELF/G extraction."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from absl import app
+from absl import flags
+import numpy as np
+from six.moves import range
+import tensorflow as tf
+
+from google.protobuf import text_format
+from delf import delf_config_pb2
+from delf import utils
+from delf import extractor
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string(
+ 'delf_config_path', '/tmp/delf_config_example.pbtxt',
+ 'Path to DelfConfig proto text file with configuration to be used for DELG '
+ 'extraction. Local features are extracted if use_local_features is True; '
+ 'global features are extracted if use_global_features is True.')
+flags.DEFINE_string('list_images_path', '/tmp/list_images.txt',
+ 'Path to list of images whose features will be extracted.')
+flags.DEFINE_integer('repeat_per_image', 10,
+ 'Number of times to repeat extraction per image.')
+
+# Pace to report extraction log.
+_STATUS_CHECK_ITERATIONS = 100
+
+
+def _ReadImageList(list_path):
+ """Helper function to read image paths.
+
+ Args:
+ list_path: Path to list of images, one image path per line.
+
+ Returns:
+ image_paths: List of image paths.
+ """
+ with tf.io.gfile.GFile(list_path, 'r') as f:
+ image_paths = f.readlines()
+ image_paths = [entry.rstrip() for entry in image_paths]
+ return image_paths
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise RuntimeError('Too many command-line arguments.')
+
+ # Read list of images.
+ print('Reading list of images...')
+ image_paths = _ReadImageList(FLAGS.list_images_path)
+ num_images = len(image_paths)
+ print(f'done! Found {num_images} images')
+
+ # Load images in memory.
+ print('Loading images, %d times per image...' % FLAGS.repeat_per_image)
+ im_array = []
+ for filename in image_paths:
+ im = np.array(utils.RgbLoader(filename))
+ for _ in range(FLAGS.repeat_per_image):
+ im_array.append(im)
+ np.random.shuffle(im_array)
+ print('done!')
+
+ # Parse DelfConfig proto.
+ config = delf_config_pb2.DelfConfig()
+ with tf.io.gfile.GFile(FLAGS.delf_config_path, 'r') as f:
+ text_format.Parse(f.read(), config)
+
+ extractor_fn = extractor.MakeExtractor(config)
+
+ start = time.time()
+ for i, im in enumerate(im_array):
+ if i == 0:
+ print('Starting to extract DELF features from images...')
+ elif i % _STATUS_CHECK_ITERATIONS == 0:
+ elapsed = (time.time() - start)
+ print(f'Processing image {i} out of {len(im_array)}, last '
+            f'{_STATUS_CHECK_ITERATIONS} images took {elapsed} seconds, '
+            f'i.e. {elapsed/_STATUS_CHECK_ITERATIONS} secs/image.')
+ start = time.time()
+
+    # Extract features (nothing is saved; this script only measures latency).
+ extracted_features = extractor_fn(im)
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/models/research/delf/delf/python/delg/perform_retrieval.py b/models/research/delf/delf/python/delg/perform_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb53abb1a9e15a5d5a040be42213f325ab345163
--- /dev/null
+++ b/models/research/delf/delf/python/delg/perform_retrieval.py
@@ -0,0 +1,215 @@
+# Copyright 2020 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Performs DELG-based image retrieval on Revisited Oxford/Paris datasets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+
+from absl import app
+from absl import flags
+import numpy as np
+import tensorflow as tf
+
+from delf import datum_io
+from delf.python.detect_to_retrieve import dataset
+from delf.python.detect_to_retrieve import image_reranking
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string(
+ 'dataset_file_path', '/tmp/gnd_roxford5k.mat',
+ 'Dataset file for Revisited Oxford or Paris dataset, in .mat format.')
+flags.DEFINE_string('query_features_dir', '/tmp/features/query',
+ 'Directory where query DELG features are located.')
+flags.DEFINE_string('index_features_dir', '/tmp/features/index',
+ 'Directory where index DELG features are located.')
+flags.DEFINE_boolean(
+ 'use_geometric_verification', False,
+ 'If True, performs re-ranking using local feature-based geometric '
+ 'verification.')
+flags.DEFINE_float(
+ 'local_feature_distance_threshold', 1.0,
+ 'Optional, only used if `use_geometric_verification` is True. '
+ 'Distance threshold below which a pair of local descriptors is considered '
+ 'a potential match, and will be fed into RANSAC.')
+flags.DEFINE_float(
+ 'ransac_residual_threshold', 20.0,
+ 'Optional, only used if `use_geometric_verification` is True. '
+ 'Residual error threshold for considering matches as inliers, used in '
+ 'RANSAC algorithm.')
+flags.DEFINE_string(
+ 'output_dir', '/tmp/retrieval',
+ 'Directory where retrieval output will be written to. A file containing '
+ "metrics for this run is saved therein, with file name 'metrics.txt'.")
+
+# Extensions.
+_DELG_GLOBAL_EXTENSION = '.delg_global'
+_DELG_LOCAL_EXTENSION = '.delg_local'
+
+# Precision-recall ranks to use in metric computation.
+_PR_RANKS = (1, 5, 10)
+
+# Pace to log.
+_STATUS_CHECK_LOAD_ITERATIONS = 50
+
+# Output file names.
+_METRICS_FILENAME = 'metrics.txt'
+
+
+def _ReadDelgGlobalDescriptors(input_dir, image_list):
+ """Reads DELG global features.
+
+ Args:
+ input_dir: Directory where features are located.
+ image_list: List of image names for which to load features.
+
+ Returns:
+ global_descriptors: NumPy array of shape (len(image_list), D), where D
+ corresponds to the global descriptor dimensionality.
+ """
+ num_images = len(image_list)
+ global_descriptors = []
+ print('Starting to collect global descriptors for %d images...' % num_images)
+ start = time.time()
+ for i in range(num_images):
+ if i > 0 and i % _STATUS_CHECK_LOAD_ITERATIONS == 0:
+ elapsed = (time.time() - start)
+ print('Reading global descriptors for image %d out of %d, last %d '
+ 'images took %f seconds' %
+ (i, num_images, _STATUS_CHECK_LOAD_ITERATIONS, elapsed))
+ start = time.time()
+
+ descriptor_filename = image_list[i] + _DELG_GLOBAL_EXTENSION
+ descriptor_fullpath = os.path.join(input_dir, descriptor_filename)
+ global_descriptors.append(datum_io.ReadFromFile(descriptor_fullpath))
+
+ return np.array(global_descriptors)
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise RuntimeError('Too many command-line arguments.')
+
+ # Parse dataset to obtain query/index images, and ground-truth.
+ print('Parsing dataset...')
+ query_list, index_list, ground_truth = dataset.ReadDatasetFile(
+ FLAGS.dataset_file_path)
+ num_query_images = len(query_list)
+ num_index_images = len(index_list)
+ (_, medium_ground_truth,
+ hard_ground_truth) = dataset.ParseEasyMediumHardGroundTruth(ground_truth)
+ print('done! Found %d queries and %d index images' %
+ (num_query_images, num_index_images))
+
+ # Read global features.
+ query_global_features = _ReadDelgGlobalDescriptors(FLAGS.query_features_dir,
+ query_list)
+ index_global_features = _ReadDelgGlobalDescriptors(FLAGS.index_features_dir,
+ index_list)
+
+ # Compute similarity between query and index images, potentially re-ranking
+ # with geometric verification.
+ ranks_before_gv = np.zeros([num_query_images, num_index_images],
+ dtype='int32')
+ if FLAGS.use_geometric_verification:
+ medium_ranks_after_gv = np.zeros([num_query_images, num_index_images],
+ dtype='int32')
+ hard_ranks_after_gv = np.zeros([num_query_images, num_index_images],
+ dtype='int32')
+ for i in range(num_query_images):
+ print('Performing retrieval with query %d (%s)...' % (i, query_list[i]))
+ start = time.time()
+
+ # Compute similarity between global descriptors.
+ similarities = np.dot(index_global_features, query_global_features[i])
+ ranks_before_gv[i] = np.argsort(-similarities)
+
+ # Re-rank using geometric verification.
+ if FLAGS.use_geometric_verification:
+ medium_ranks_after_gv[i] = image_reranking.RerankByGeometricVerification(
+ input_ranks=ranks_before_gv[i],
+ initial_scores=similarities,
+ query_name=query_list[i],
+ index_names=index_list,
+ query_features_dir=FLAGS.query_features_dir,
+ index_features_dir=FLAGS.index_features_dir,
+ junk_ids=set(medium_ground_truth[i]['junk']),
+ local_feature_extension=_DELG_LOCAL_EXTENSION,
+ ransac_seed=0,
+ feature_distance_threshold=FLAGS.local_feature_distance_threshold,
+ ransac_residual_threshold=FLAGS.ransac_residual_threshold)
+ hard_ranks_after_gv[i] = image_reranking.RerankByGeometricVerification(
+ input_ranks=ranks_before_gv[i],
+ initial_scores=similarities,
+ query_name=query_list[i],
+ index_names=index_list,
+ query_features_dir=FLAGS.query_features_dir,
+ index_features_dir=FLAGS.index_features_dir,
+ junk_ids=set(hard_ground_truth[i]['junk']),
+ local_feature_extension=_DELG_LOCAL_EXTENSION,
+ ransac_seed=0,
+ feature_distance_threshold=FLAGS.local_feature_distance_threshold,
+ ransac_residual_threshold=FLAGS.ransac_residual_threshold)
+
+ elapsed = (time.time() - start)
+ print('done! Retrieval for query %d took %f seconds' % (i, elapsed))
+
+ # Create output directory if necessary.
+ if not tf.io.gfile.exists(FLAGS.output_dir):
+ tf.io.gfile.makedirs(FLAGS.output_dir)
+
+ # Compute metrics.
+ medium_metrics = dataset.ComputeMetrics(ranks_before_gv, medium_ground_truth,
+ _PR_RANKS)
+ hard_metrics = dataset.ComputeMetrics(ranks_before_gv, hard_ground_truth,
+ _PR_RANKS)
+ if FLAGS.use_geometric_verification:
+ medium_metrics_after_gv = dataset.ComputeMetrics(medium_ranks_after_gv,
+ medium_ground_truth,
+ _PR_RANKS)
+ hard_metrics_after_gv = dataset.ComputeMetrics(hard_ranks_after_gv,
+ hard_ground_truth, _PR_RANKS)
+
+ # Write metrics to file.
+ mean_average_precision_dict = {
+ 'medium': medium_metrics[0],
+ 'hard': hard_metrics[0]
+ }
+ mean_precisions_dict = {'medium': medium_metrics[1], 'hard': hard_metrics[1]}
+ mean_recalls_dict = {'medium': medium_metrics[2], 'hard': hard_metrics[2]}
+ if FLAGS.use_geometric_verification:
+ mean_average_precision_dict.update({
+ 'medium_after_gv': medium_metrics_after_gv[0],
+ 'hard_after_gv': hard_metrics_after_gv[0]
+ })
+ mean_precisions_dict.update({
+ 'medium_after_gv': medium_metrics_after_gv[1],
+ 'hard_after_gv': hard_metrics_after_gv[1]
+ })
+ mean_recalls_dict.update({
+ 'medium_after_gv': medium_metrics_after_gv[2],
+ 'hard_after_gv': hard_metrics_after_gv[2]
+ })
+ dataset.SaveMetricsFile(mean_average_precision_dict, mean_precisions_dict,
+ mean_recalls_dict, _PR_RANKS,
+ os.path.join(FLAGS.output_dir, _METRICS_FILENAME))
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/models/research/delf/delf/python/detect_to_retrieve/DETECT_TO_RETRIEVE_INSTRUCTIONS.md b/models/research/delf/delf/python/detect_to_retrieve/DETECT_TO_RETRIEVE_INSTRUCTIONS.md
new file mode 100644
index 0000000000000000000000000000000000000000..2d18a328997ace5ee01f60a4b2c95714a18eb7d9
--- /dev/null
+++ b/models/research/delf/delf/python/detect_to_retrieve/DETECT_TO_RETRIEVE_INSTRUCTIONS.md
@@ -0,0 +1,231 @@
+## Detect-to-Retrieve instructions
+
+[arXiv:1812.01584](https://arxiv.org/abs/1812.01584)
+
+These instructions can be used to reproduce the results from the
+[Detect-to-Retrieve paper](https://arxiv.org/abs/1812.01584) for the Revisited
+Oxford/Paris datasets.
+
+### Install DELF library
+
+To be able to use this code, please follow
+[these instructions](../../../INSTALL_INSTRUCTIONS.md) to properly install the
+DELF library.
+
+### Download datasets
+
+```bash
+mkdir -p ~/detect_to_retrieve/data && cd ~/detect_to_retrieve/data
+
+# Oxford dataset.
+wget http://www.robots.ox.ac.uk/~vgg/data/oxbuildings/oxbuild_images.tgz
+mkdir oxford5k_images
+tar -xvzf oxbuild_images.tgz -C oxford5k_images/
+
+# Paris dataset. Download and move all images to same directory.
+wget http://www.robots.ox.ac.uk/~vgg/data/parisbuildings/paris_1.tgz
+wget http://www.robots.ox.ac.uk/~vgg/data/parisbuildings/paris_2.tgz
+mkdir paris6k_images_tmp
+tar -xvzf paris_1.tgz -C paris6k_images_tmp/
+tar -xvzf paris_2.tgz -C paris6k_images_tmp/
+mkdir paris6k_images
+mv paris6k_images_tmp/paris/*/*.jpg paris6k_images/
+
+# Revisited annotations.
+wget http://cmp.felk.cvut.cz/revisitop/data/datasets/roxford5k/gnd_roxford5k.mat
+wget http://cmp.felk.cvut.cz/revisitop/data/datasets/rparis6k/gnd_rparis6k.mat
+```
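+
+If you want to sanity-check the downloaded annotations, a minimal sketch
+(assuming the DELF library from this repository is already installed) can load a
+ground-truth file and report the dataset sizes:
+
+```python
+# Minimal sketch: load the Revisited Oxford annotations downloaded above and
+# print how many query/index images and ground-truth entries they contain.
+import os
+
+from delf.python.detect_to_retrieve import dataset
+
+dataset_path = os.path.expanduser('~/detect_to_retrieve/data/gnd_roxford5k.mat')
+query_list, index_list, ground_truth = dataset.ReadDatasetFile(dataset_path)
+print('%d query images, %d index images, %d ground-truth entries' %
+      (len(query_list), len(index_list), len(ground_truth)))
+```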
+
+### Download models
+
+These are necessary to reproduce the main paper results:
+
+```bash
+# From models/research/delf/delf/python/detect_to_retrieve
+mkdir parameters && cd parameters
+
+# DELF-GLD model.
+wget http://storage.googleapis.com/delf/delf_gld_20190411.tar.gz
+tar -xvzf delf_gld_20190411.tar.gz
+
+# Faster-RCNN detector model.
+wget http://storage.googleapis.com/delf/d2r_frcnn_20190411.tar.gz
+tar -xvzf d2r_frcnn_20190411.tar.gz
+
+# Codebooks.
+# Note: you should use codebook trained on rparis6k for roxford5k retrieval
+# experiments, and vice-versa.
+wget http://storage.googleapis.com/delf/rparis6k_codebook_65536.tar.gz
+mkdir rparis6k_codebook_65536
+tar -xvzf rparis6k_codebook_65536.tar.gz -C rparis6k_codebook_65536/
+wget http://storage.googleapis.com/delf/roxford5k_codebook_65536.tar.gz
+mkdir roxford5k_codebook_65536
+tar -xvzf roxford5k_codebook_65536.tar.gz -C roxford5k_codebook_65536/
+```
+
+We also make available other models/parameters that can be used to reproduce
+more results from the paper:
+
+- [MobileNet-SSD trained detector](http://storage.googleapis.com/delf/d2r_mnetssd_20190411.tar.gz).
+- Codebooks with 1024 centroids:
+ [rparis6k](http://storage.googleapis.com/delf/rparis6k_codebook_1024.tar.gz),
+ [roxford5k](http://storage.googleapis.com/delf/roxford5k_codebook_1024.tar.gz).
+
+### Feature extraction
+
+We present here commands for extraction on `roxford5k`. To extract on `rparis6k`
+instead, please edit the arguments accordingly (especially the
+`dataset_file_path` argument).
+
+#### Query feature extraction
+
+For query feature extraction, the cropped query image should be used to extract
+features, according to the Revisited Oxford/Paris experimental protocol. Note
+that this is handled by the `extract_query_features.py` script.
+
+Query feature extraction can be run as follows:
+
+```bash
+# From models/research/delf/delf/python/detect_to_retrieve
+python3 extract_query_features.py \
+ --delf_config_path delf_gld_config.pbtxt \
+ --dataset_file_path ~/detect_to_retrieve/data/gnd_roxford5k.mat \
+ --images_dir ~/detect_to_retrieve/data/oxford5k_images \
+ --output_features_dir ~/detect_to_retrieve/data/oxford5k_features/query
+```
+
+#### Index feature extraction and box detection
+
+Index feature extraction / box detection can be run as follows:
+
+```bash
+# From models/research/delf/delf/python/detect_to_retrieve
+python3 extract_index_boxes_and_features.py \
+ --delf_config_path delf_gld_config.pbtxt \
+ --detector_model_dir parameters/d2r_frcnn_20190411 \
+ --detector_thresh 0.1 \
+ --dataset_file_path ~/detect_to_retrieve/data/gnd_roxford5k.mat \
+ --images_dir ~/detect_to_retrieve/data/oxford5k_images \
+ --output_boxes_dir ~/detect_to_retrieve/data/oxford5k_boxes/index \
+ --output_features_dir ~/detect_to_retrieve/data/oxford5k_features/index_0.1 \
+ --output_index_mapping ~/detect_to_retrieve/data/oxford5k_features/index_mapping_0.1.csv
+```
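+
+The CSV written to `output_index_mapping` records, for each saved `.delf` file,
+the index image it came from and the detected box it corresponds to (a `box_id`
+of -1 denotes features extracted from the whole image). Its first rows look
+approximately like the following (image names are illustrative):
+
+```
+name,index_image_id,box_id
+all_souls_000013,0,-1
+all_souls_000013_0,0,0
+all_souls_000013_1,0,1
+```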
+
+### R-ASMK* aggregation extraction
+
+We present here commands for aggregation extraction on `roxford5k`. To extract
+on `rparis6k` instead, please edit the arguments accordingly. In particular,
+note that feature aggregation on `roxford5k` should use a codebook trained on
+`rparis6k`, and vice-versa (this can be edited in the
+`query_aggregation_config.pbtxt` and `index_aggregation_config.pbtxt` files).
+
+#### Query
+
+Run query feature aggregation as follows:
+
+```bash
+# From models/research/delf/delf/python/detect_to_retrieve
+python3 extract_aggregation.py \
+ --use_query_images True \
+ --aggregation_config_path query_aggregation_config.pbtxt \
+ --dataset_file_path ~/detect_to_retrieve/data/gnd_roxford5k.mat \
+ --features_dir ~/detect_to_retrieve/data/oxford5k_features/query \
+ --output_aggregation_dir ~/detect_to_retrieve/data/oxford5k_aggregation/query
+```
+
+#### Index
+
+Run index feature aggregation as follows:
+
+```bash
+# From models/research/delf/delf/python/detect_to_retrieve
+python3 extract_aggregation.py \
+ --aggregation_config_path index_aggregation_config.pbtxt \
+ --dataset_file_path ~/detect_to_retrieve/data/gnd_roxford5k.mat \
+ --features_dir ~/detect_to_retrieve/data/oxford5k_features/index_0.1 \
+ --index_mapping_path ~/detect_to_retrieve/data/oxford5k_features/index_mapping_0.1.csv \
+ --output_aggregation_dir ~/detect_to_retrieve/data/oxford5k_aggregation/index_0.1
+```
+
+### Perform retrieval
+
+Currently, we support retrieval via brute-force comparison of aggregated
+features.
+
+To run retrieval on `roxford5k`, the following command can be used:
+
+```bash
+# From models/research/delf/delf/python/detect_to_retrieve
+python3 perform_retrieval.py \
+ --index_aggregation_config_path index_aggregation_config.pbtxt \
+ --query_aggregation_config_path query_aggregation_config.pbtxt \
+ --dataset_file_path ~/detect_to_retrieve/data/gnd_roxford5k.mat \
+ --index_aggregation_dir ~/detect_to_retrieve/data/oxford5k_aggregation/index_0.1 \
+ --query_aggregation_dir ~/detect_to_retrieve/data/oxford5k_aggregation/query \
+ --output_dir ~/detect_to_retrieve/results/oxford5k
+```
+
+A file named `metrics.txt` will be written to the path given in
+`output_dir`, with retrieval metrics for an experiment where geometric
+verification is not used. The contents should look approximately like:
+
+```
+hard
+mAP=47.61
+mP@k[ 1 5 10] [84.29 73.71 64.43]
+mR@k[ 1 5 10] [18.84 29.44 36.82]
+medium
+mAP=73.3
+mP@k[ 1 5 10] [97.14 94.57 90.14]
+mR@k[ 1 5 10] [10.14 26.2 34.75]
+```
+
+which are the results presented in Table 2 of the paper (with small numerical
+precision differences).
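+
+To consume these numbers programmatically, a minimal sketch (assuming the DELF
+library is installed, so that `delf.python.detect_to_retrieve.dataset` is
+importable) can read the metrics file back:
+
+```python
+# Minimal sketch: read back the metrics written by perform_retrieval.py above.
+import os
+
+from delf.python.detect_to_retrieve import dataset
+
+metrics_path = os.path.expanduser(
+    '~/detect_to_retrieve/results/oxford5k/metrics.txt')
+(mean_average_precision, pr_ranks, mean_precisions,
+ mean_recalls) = dataset.ReadMetricsFile(metrics_path)
+for protocol in sorted(mean_average_precision):
+  print('%s: mAP=%.2f, mP@k%s=%s, mR@k%s=%s' %
+        (protocol, 100.0 * mean_average_precision[protocol], pr_ranks,
+         100.0 * mean_precisions[protocol], pr_ranks,
+         100.0 * mean_recalls[protocol]))
+```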
+
+If you want to run retrieval with geometric verification, set
+`use_geometric_verification` to `True` and also set the
+`index_features_dir`/`query_features_dir` arguments. This is much slower, since
+(1) in this code example the re-ranking loads DELF local features from disk, and
+(2) re-ranking needs to be performed separately for each dataset protocol,
+because the junk images from each protocol must be removed when re-ranking. Here
+is an example command:
+
+```bash
+# From models/research/delf/delf/python/detect_to_retrieve
+python3 perform_retrieval.py \
+ --index_aggregation_config_path index_aggregation_config.pbtxt \
+ --query_aggregation_config_path query_aggregation_config.pbtxt \
+ --dataset_file_path ~/detect_to_retrieve/data/gnd_roxford5k.mat \
+ --index_aggregation_dir ~/detect_to_retrieve/data/oxford5k_aggregation/index_0.1 \
+ --query_aggregation_dir ~/detect_to_retrieve/data/oxford5k_aggregation/query \
+ --use_geometric_verification True \
+ --index_features_dir ~/detect_to_retrieve/data/oxford5k_features/index_0.1 \
+ --query_features_dir ~/detect_to_retrieve/data/oxford5k_features/query \
+ --output_dir ~/detect_to_retrieve/results/oxford5k_with_gv
+```
+
+### Clustering
+
+In the code example above, we used a pre-trained DELF codebook. We also provide
+code for re-training the codebook if desired.
+
+Note that for the time being this can only run on CPU, since the main ops in
+K-means are not registered for GPU usage in Tensorflow.
+
+```bash
+# From models/research/delf/delf/python/detect_to_retrieve
+python3 cluster_delf_features.py \
+ --dataset_file_path ~/detect_to_retrieve/data/gnd_rparis6k.mat \
+ --features_dir ~/detect_to_retrieve/data/paris6k_features/index_0.1 \
+ --num_clusters 1024 \
+ --num_iterations 50 \
+ --output_cluster_dir ~/detect_to_retrieve/data/paris6k_clusters_1024
+```
+
+### Next steps
+
+To make retrieval more scalable and handle larger datasets more smoothly, we are
+considering providing code for inverted index building and retrieval. Please
+reach out if you would like to help with that -- feel free to submit a pull
+request.
diff --git a/models/research/delf/delf/python/detect_to_retrieve/__init__.py b/models/research/delf/delf/python/detect_to_retrieve/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..06972a7d06738da1dc50e832c4e8443b0e6fb5b6
--- /dev/null
+++ b/models/research/delf/delf/python/detect_to_retrieve/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Module for Detect-to-Retrieve technique."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import
+from delf.python.detect_to_retrieve import aggregation_extraction
+from delf.python.detect_to_retrieve import boxes_and_features_extraction
+from delf.python.detect_to_retrieve import dataset
+# pylint: enable=unused-import
diff --git a/models/research/delf/delf/python/detect_to_retrieve/aggregation_extraction.py b/models/research/delf/delf/python/detect_to_retrieve/aggregation_extraction.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ddab944b8a3365209b8e92af38241d297974122
--- /dev/null
+++ b/models/research/delf/delf/python/detect_to_retrieve/aggregation_extraction.py
@@ -0,0 +1,193 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Library to extract/save feature aggregation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import csv
+import os
+import time
+
+import numpy as np
+import tensorflow as tf
+
+from google.protobuf import text_format
+from delf import aggregation_config_pb2
+from delf import datum_io
+from delf import feature_aggregation_extractor
+from delf import feature_io
+
+# Aliases for aggregation types.
+_VLAD = aggregation_config_pb2.AggregationConfig.VLAD
+_ASMK = aggregation_config_pb2.AggregationConfig.ASMK
+_ASMK_STAR = aggregation_config_pb2.AggregationConfig.ASMK_STAR
+
+# Extensions.
+_DELF_EXTENSION = '.delf'
+_VLAD_EXTENSION_SUFFIX = 'vlad'
+_ASMK_EXTENSION_SUFFIX = 'asmk'
+_ASMK_STAR_EXTENSION_SUFFIX = 'asmk_star'
+
+# Pace to report extraction log.
+_STATUS_CHECK_ITERATIONS = 50
+
+
+def _ReadMappingBasenameToBoxNames(input_path, index_image_names):
+ """Reads mapping from image name to DELF file names for each box.
+
+ Args:
+ input_path: Path to CSV file containing mapping.
+ index_image_names: List containing index image names, in order, for the
+ dataset under consideration.
+
+ Returns:
+ images_to_box_feature_files: Dict. key=string (image name); value=list of
+ strings (file names containing DELF features for boxes).
+ """
+ images_to_box_feature_files = {}
+ with tf.io.gfile.GFile(input_path, 'r') as f:
+ reader = csv.DictReader(f)
+ for row in reader:
+ index_image_name = index_image_names[int(row['index_image_id'])]
+ if index_image_name not in images_to_box_feature_files:
+ images_to_box_feature_files[index_image_name] = []
+
+ images_to_box_feature_files[index_image_name].append(row['name'])
+
+ return images_to_box_feature_files
+
+
+def ExtractAggregatedRepresentationsToFiles(image_names, features_dir,
+ aggregation_config_path,
+ mapping_path,
+ output_aggregation_dir):
+ """Extracts aggregated feature representations, saving them to files.
+
+ It checks if the aggregated representation for an image already exists,
+ and skips computation for those.
+
+ Args:
+ image_names: List of image names. These are used to compose input file names
+ for the feature files, and the output file names for aggregated
+ representations.
+ features_dir: Directory where DELF features are located.
+ aggregation_config_path: Path to AggregationConfig proto text file with
+ configuration to be used for extraction.
+ mapping_path: Optional CSV file which maps each .delf file name to the index
+ image ID and detected box ID. If regional aggregation is performed, this
+ should be set. Otherwise, this is ignored.
+ output_aggregation_dir: Directory where aggregation output will be written
+ to.
+
+ Raises:
+ ValueError: If AggregationConfig is malformed, or `mapping_path` is
+ missing.
+ """
+ num_images = len(image_names)
+
+ # Parse AggregationConfig proto, and select output extension.
+ config = aggregation_config_pb2.AggregationConfig()
+ with tf.io.gfile.GFile(aggregation_config_path, 'r') as f:
+ text_format.Merge(f.read(), config)
+ output_extension = '.'
+ if config.use_regional_aggregation:
+ output_extension += 'r'
+ if config.aggregation_type == _VLAD:
+ output_extension += _VLAD_EXTENSION_SUFFIX
+ elif config.aggregation_type == _ASMK:
+ output_extension += _ASMK_EXTENSION_SUFFIX
+ elif config.aggregation_type == _ASMK_STAR:
+ output_extension += _ASMK_STAR_EXTENSION_SUFFIX
+ else:
+ raise ValueError('Invalid aggregation type: %d' % config.aggregation_type)
+
+ # Read index mapping path, if provided.
+ if mapping_path:
+ images_to_box_feature_files = _ReadMappingBasenameToBoxNames(
+ mapping_path, image_names)
+
+ # Create output directory if necessary.
+ if not tf.io.gfile.exists(output_aggregation_dir):
+ tf.io.gfile.makedirs(output_aggregation_dir)
+
+ extractor = feature_aggregation_extractor.ExtractAggregatedRepresentation(
+ config)
+
+ start = time.time()
+ for i in range(num_images):
+ if i == 0:
+ print('Starting to extract aggregation from images...')
+ elif i % _STATUS_CHECK_ITERATIONS == 0:
+ elapsed = (time.time() - start)
+ print('Processing image %d out of %d, last %d '
+ 'images took %f seconds' %
+ (i, num_images, _STATUS_CHECK_ITERATIONS, elapsed))
+ start = time.time()
+
+ image_name = image_names[i]
+
+ # Compose output file name, skip extraction for this image if it already
+ # exists.
+ output_aggregation_filename = os.path.join(output_aggregation_dir,
+ image_name + output_extension)
+ if tf.io.gfile.exists(output_aggregation_filename):
+ print('Skipping %s' % image_name)
+ continue
+
+ # Load DELF features.
+ if config.use_regional_aggregation:
+ if not mapping_path:
+ raise ValueError(
+ 'Requested regional aggregation, but mapping_path was not '
+ 'provided')
+ descriptors_list = []
+ num_features_per_box = []
+ for box_feature_file in images_to_box_feature_files[image_name]:
+ delf_filename = os.path.join(features_dir,
+ box_feature_file + _DELF_EXTENSION)
+ _, _, box_descriptors, _, _ = feature_io.ReadFromFile(delf_filename)
+ # If `box_descriptors` is empty, reshape it such that it can be
+ # concatenated with other descriptors.
+ if not box_descriptors.shape[0]:
+ box_descriptors = np.reshape(box_descriptors,
+ [0, config.feature_dimensionality])
+ descriptors_list.append(box_descriptors)
+ num_features_per_box.append(box_descriptors.shape[0])
+
+ descriptors = np.concatenate(descriptors_list)
+ else:
+ input_delf_filename = os.path.join(features_dir,
+ image_name + _DELF_EXTENSION)
+ _, _, descriptors, _, _ = feature_io.ReadFromFile(input_delf_filename)
+ # If `descriptors` is empty, reshape it to avoid extraction failure.
+ if not descriptors.shape[0]:
+ descriptors = np.reshape(descriptors,
+ [0, config.feature_dimensionality])
+ num_features_per_box = None
+
+ # Extract and save aggregation. If using VLAD, only
+ # `aggregated_descriptors` needs to be saved.
+ (aggregated_descriptors,
+ feature_visual_words) = extractor.Extract(descriptors,
+ num_features_per_box)
+ if config.aggregation_type == _VLAD:
+ datum_io.WriteToFile(aggregated_descriptors,
+ output_aggregation_filename)
+ else:
+ datum_io.WritePairToFile(aggregated_descriptors,
+ feature_visual_words.astype('uint32'),
+ output_aggregation_filename)
diff --git a/models/research/delf/delf/python/detect_to_retrieve/boxes_and_features_extraction.py b/models/research/delf/delf/python/detect_to_retrieve/boxes_and_features_extraction.py
new file mode 100644
index 0000000000000000000000000000000000000000..1faef983b2e0413e2f2746c5d56b5e62045e5a39
--- /dev/null
+++ b/models/research/delf/delf/python/detect_to_retrieve/boxes_and_features_extraction.py
@@ -0,0 +1,202 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Library to extract/save boxes and DELF features."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import csv
+import math
+import os
+import time
+
+import numpy as np
+import tensorflow as tf
+
+from google.protobuf import text_format
+from delf import delf_config_pb2
+from delf import box_io
+from delf import feature_io
+from delf import utils
+from delf import detector
+from delf import extractor
+
+# Extension of feature files.
+_BOX_EXTENSION = '.boxes'
+_DELF_EXTENSION = '.delf'
+
+# Pace to report extraction log.
+_STATUS_CHECK_ITERATIONS = 100
+
+
+def _WriteMappingBasenameToIds(index_names_ids_and_boxes, output_path):
+ """Helper function to write CSV mapping from DELF file name to IDs.
+
+ Args:
+ index_names_ids_and_boxes: List containing 3-element lists with name, image
+ ID and box ID.
+ output_path: Output CSV path.
+ """
+ with tf.io.gfile.GFile(output_path, 'w') as f:
+ csv_writer = csv.DictWriter(
+ f, fieldnames=['name', 'index_image_id', 'box_id'])
+ csv_writer.writeheader()
+ for name_imid_boxid in index_names_ids_and_boxes:
+ csv_writer.writerow({
+ 'name': name_imid_boxid[0],
+ 'index_image_id': name_imid_boxid[1],
+ 'box_id': name_imid_boxid[2],
+ })
+
+
+def ExtractBoxesAndFeaturesToFiles(image_names, image_paths, delf_config_path,
+ detector_model_dir, detector_thresh,
+ output_features_dir, output_boxes_dir,
+ output_mapping):
+ """Extracts boxes and features, saving them to files.
+
+  Boxes are saved to <image_name>.boxes files. DELF features are extracted for
+  the entire image and saved into <image_name>.delf files. In addition, DELF
+  features are extracted for each high-confidence bounding box in the image, and
+  saved into files named <image_name>_0.delf, <image_name>_1.delf, etc.
+
+ It checks if descriptors/boxes already exist, and skips computation for those.
+
+ Args:
+ image_names: List of image names. These are used to compose output file
+ names for boxes and features.
+ image_paths: List of image paths. image_paths[i] is the path for the image
+ named by image_names[i]. `image_names` and `image_paths` must have the
+ same number of elements.
+ delf_config_path: Path to DelfConfig proto text file.
+ detector_model_dir: Directory where detector SavedModel is located.
+ detector_thresh: Threshold used to decide if an image's detected box
+ undergoes feature extraction.
+ output_features_dir: Directory where DELF features will be written to.
+ output_boxes_dir: Directory where detected boxes will be written to.
+ output_mapping: CSV file which maps each .delf file name to the image ID and
+ detected box ID.
+
+ Raises:
+ ValueError: If len(image_names) and len(image_paths) are different.
+ """
+ num_images = len(image_names)
+ if len(image_paths) != num_images:
+ raise ValueError(
+ 'image_names and image_paths have different number of items')
+
+ # Parse DelfConfig proto.
+ config = delf_config_pb2.DelfConfig()
+ with tf.io.gfile.GFile(delf_config_path, 'r') as f:
+ text_format.Merge(f.read(), config)
+
+ # Create output directories if necessary.
+ if not tf.io.gfile.exists(output_features_dir):
+ tf.io.gfile.makedirs(output_features_dir)
+ if not tf.io.gfile.exists(output_boxes_dir):
+ tf.io.gfile.makedirs(output_boxes_dir)
+ if not tf.io.gfile.exists(os.path.dirname(output_mapping)):
+ tf.io.gfile.makedirs(os.path.dirname(output_mapping))
+
+ names_ids_and_boxes = []
+ detector_fn = detector.MakeDetector(detector_model_dir)
+ delf_extractor_fn = extractor.MakeExtractor(config)
+
+ start = time.time()
+ for i in range(num_images):
+ if i == 0:
+ print('Starting to extract features/boxes...')
+ elif i % _STATUS_CHECK_ITERATIONS == 0:
+ elapsed = (time.time() - start)
+ print('Processing image %d out of %d, last %d '
+ 'images took %f seconds' %
+ (i, num_images, _STATUS_CHECK_ITERATIONS, elapsed))
+ start = time.time()
+
+ image_name = image_names[i]
+ output_feature_filename_whole_image = os.path.join(
+ output_features_dir, image_name + _DELF_EXTENSION)
+ output_box_filename = os.path.join(output_boxes_dir,
+ image_name + _BOX_EXTENSION)
+
+ pil_im = utils.RgbLoader(image_paths[i])
+ width, height = pil_im.size
+
+ # Extract and save boxes.
+ if tf.io.gfile.exists(output_box_filename):
+ print('Skipping box computation for %s' % image_name)
+ (boxes_out, scores_out,
+ class_indices_out) = box_io.ReadFromFile(output_box_filename)
+ else:
+ (boxes_out, scores_out,
+ class_indices_out) = detector_fn(np.expand_dims(pil_im, 0))
+ # Using only one image per batch.
+ boxes_out = boxes_out[0]
+ scores_out = scores_out[0]
+ class_indices_out = class_indices_out[0]
+ box_io.WriteToFile(output_box_filename, boxes_out, scores_out,
+ class_indices_out)
+
+ # Select boxes with scores greater than threshold. Those will be the
+ # ones with extracted DELF features (besides the whole image, whose DELF
+ # features are extracted in all cases).
+ num_delf_files = 1
+ selected_boxes = []
+ for box_ind, box in enumerate(boxes_out):
+ if scores_out[box_ind] >= detector_thresh:
+ selected_boxes.append(box)
+ num_delf_files += len(selected_boxes)
+
+ # Extract and save DELF features.
+ for delf_file_ind in range(num_delf_files):
+ if delf_file_ind == 0:
+ box_name = image_name
+ output_feature_filename = output_feature_filename_whole_image
+ else:
+ box_name = image_name + '_' + str(delf_file_ind - 1)
+ output_feature_filename = os.path.join(output_features_dir,
+ box_name + _DELF_EXTENSION)
+
+ names_ids_and_boxes.append([box_name, i, delf_file_ind - 1])
+
+ if tf.io.gfile.exists(output_feature_filename):
+ print('Skipping DELF computation for %s' % box_name)
+ continue
+
+ if delf_file_ind >= 1:
+ bbox_for_cropping = selected_boxes[delf_file_ind - 1]
+ bbox_for_cropping_pil_convention = [
+ int(math.floor(bbox_for_cropping[1] * width)),
+ int(math.floor(bbox_for_cropping[0] * height)),
+ int(math.ceil(bbox_for_cropping[3] * width)),
+ int(math.ceil(bbox_for_cropping[2] * height))
+ ]
+ pil_cropped_im = pil_im.crop(bbox_for_cropping_pil_convention)
+ im = np.array(pil_cropped_im)
+ else:
+ im = np.array(pil_im)
+
+ extracted_features = delf_extractor_fn(im)
+ locations_out = extracted_features['local_features']['locations']
+ descriptors_out = extracted_features['local_features']['descriptors']
+ feature_scales_out = extracted_features['local_features']['scales']
+ attention_out = extracted_features['local_features']['attention']
+
+ feature_io.WriteToFile(output_feature_filename, locations_out,
+ feature_scales_out, descriptors_out, attention_out)
+
+ # Save mapping from output DELF name to image id and box id.
+ _WriteMappingBasenameToIds(names_ids_and_boxes, output_mapping)
diff --git a/models/research/delf/delf/python/detect_to_retrieve/cluster_delf_features.py b/models/research/delf/delf/python/detect_to_retrieve/cluster_delf_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ddda8e4d0cae7950e76383950aab976249f3461
--- /dev/null
+++ b/models/research/delf/delf/python/detect_to_retrieve/cluster_delf_features.py
@@ -0,0 +1,213 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Clusters DELF features using the K-means algorithm.
+
+All DELF local feature descriptors for a given dataset's index images are loaded
+as the input.
+
+Note that:
+- we only use features extracted from whole images (no features from boxes are
+ used).
+- the codebook should be trained on Paris images for Oxford retrieval
+ experiments, and vice-versa.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+import time
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.platform import app
+from delf import feature_io
+from delf.python.detect_to_retrieve import dataset
+
+cmd_args = None
+
+# Extensions.
+_DELF_EXTENSION = '.delf'
+
+# Default DELF dimensionality.
+_DELF_DIM = 128
+
+# Pace to report log when collecting features.
+_STATUS_CHECK_ITERATIONS = 100
+
+
+class _IteratorInitHook(tf.estimator.SessionRunHook):
+ """Hook to initialize data iterator after session is created."""
+
+ def __init__(self):
+ super(_IteratorInitHook, self).__init__()
+ self.iterator_initializer_fn = None
+
+ def after_create_session(self, session, coord):
+ """Initialize the iterator after the session has been created."""
+ del coord
+ self.iterator_initializer_fn(session)
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise RuntimeError('Too many command-line arguments.')
+
+ # Process output directory.
+ if tf.io.gfile.exists(cmd_args.output_cluster_dir):
+ raise RuntimeError(
+ 'output_cluster_dir = %s already exists. This may indicate that a '
+ 'previous run already wrote checkpoints in this directory, which would '
+        'lead to incorrect training. Please re-run this script, specifying a '
+        'nonexistent directory.' % cmd_args.output_cluster_dir)
+ else:
+ tf.io.gfile.makedirs(cmd_args.output_cluster_dir)
+
+ # Read list of index images from dataset file.
+ print('Reading list of index images from dataset file...')
+ _, index_list, _ = dataset.ReadDatasetFile(cmd_args.dataset_file_path)
+ num_images = len(index_list)
+ print('done! Found %d images' % num_images)
+
+ # Loop over list of index images and collect DELF features.
+ features_for_clustering = []
+  start = time.time()
+ print('Starting to collect features from index images...')
+ for i in range(num_images):
+ if i > 0 and i % _STATUS_CHECK_ITERATIONS == 0:
+      elapsed = (time.time() - start)
+ print('Processing index image %d out of %d, last %d '
+ 'images took %f seconds' %
+ (i, num_images, _STATUS_CHECK_ITERATIONS, elapsed))
+      start = time.time()
+
+ features_filename = index_list[i] + _DELF_EXTENSION
+ features_fullpath = os.path.join(cmd_args.features_dir, features_filename)
+ _, _, features, _, _ = feature_io.ReadFromFile(features_fullpath)
+ if features.size != 0:
+ assert features.shape[1] == _DELF_DIM
+ for feature in features:
+ features_for_clustering.append(feature)
+
+ features_for_clustering = np.array(features_for_clustering, dtype=np.float32)
+ print('All features were loaded! There are %d features, each with %d '
+ 'dimensions' %
+ (features_for_clustering.shape[0], features_for_clustering.shape[1]))
+
+ # Run K-means clustering.
+ def _get_input_fn():
+ """Helper function to create input function and hook for training.
+
+ Returns:
+ input_fn: Input function for k-means Estimator training.
+ init_hook: Hook used to load data during training.
+ """
+ init_hook = _IteratorInitHook()
+
+ def _input_fn():
+ """Produces tf.data.Dataset object for k-means training.
+
+ Returns:
+ Tensor with the data for training.
+ """
+ features_placeholder = tf.compat.v1.placeholder(
+ tf.float32, features_for_clustering.shape)
+ delf_dataset = tf.data.Dataset.from_tensor_slices((features_placeholder))
+ delf_dataset = delf_dataset.shuffle(1000).batch(
+ features_for_clustering.shape[0])
+ iterator = delf_dataset.make_initializable_iterator()
+
+ def _initializer_fn(sess):
+ """Initialize dataset iterator, feed in the data."""
+ sess.run(
+ iterator.initializer,
+ feed_dict={features_placeholder: features_for_clustering})
+
+ init_hook.iterator_initializer_fn = _initializer_fn
+ return iterator.get_next()
+
+ return _input_fn, init_hook
+
+ input_fn, init_hook = _get_input_fn()
+
+ kmeans = tf.compat.v1.estimator.experimental.KMeans(
+ num_clusters=cmd_args.num_clusters,
+ model_dir=cmd_args.output_cluster_dir,
+ use_mini_batch=False,
+ )
+
+ print('Starting K-means clustering...')
+  start = time.time()
+ for i in range(cmd_args.num_iterations):
+ kmeans.train(input_fn, hooks=[init_hook])
+ average_sum_squared_error = kmeans.evaluate(
+ input_fn, hooks=[init_hook])['score'] / features_for_clustering.shape[0]
+    elapsed = (time.time() - start)
+ print('K-means iteration %d (out of %d) took %f seconds, '
+ 'average-sum-of-squares: %f' %
+ (i, cmd_args.num_iterations, elapsed, average_sum_squared_error))
+    start = time.time()
+
+ print('K-means clustering finished!')
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.register('type', 'bool', lambda v: v.lower() == 'true')
+ parser.add_argument(
+ '--dataset_file_path',
+ type=str,
+ default='/tmp/gnd_roxford5k.mat',
+ help="""
+ Dataset file for Revisited Oxford or Paris dataset, in .mat format. The
+ list of index images loaded from this file is used to collect local
+ features, which are assumed to be in .delf file format.
+ """)
+ parser.add_argument(
+ '--features_dir',
+ type=str,
+ default='/tmp/features',
+ help="""
+ Directory where DELF feature files are to be found.
+ """)
+ parser.add_argument(
+ '--num_clusters',
+ type=int,
+ default=1024,
+ help="""
+ Number of clusters to use.
+ """)
+ parser.add_argument(
+ '--num_iterations',
+ type=int,
+ default=50,
+ help="""
+ Number of iterations to use.
+ """)
+ parser.add_argument(
+ '--output_cluster_dir',
+ type=str,
+ default='/tmp/cluster',
+ help="""
+ Directory where clustering outputs are written to. This directory should
+ not exist before running this script; it will be created during
+ clustering.
+ """)
+ cmd_args, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/models/research/delf/delf/python/detect_to_retrieve/dataset.py b/models/research/delf/delf/python/detect_to_retrieve/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a1e6b247895aa7bd8022d3a2fb87b878bbb3b38
--- /dev/null
+++ b/models/research/delf/delf/python/detect_to_retrieve/dataset.py
@@ -0,0 +1,469 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python library to parse ground-truth/evaluate on Revisited datasets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from scipy.io import matlab
+import tensorflow as tf
+
+_GROUND_TRUTH_KEYS = ['easy', 'hard', 'junk']
+
+
+def ReadDatasetFile(dataset_file_path):
+ """Reads dataset file in Revisited Oxford/Paris ".mat" format.
+
+ Args:
+ dataset_file_path: Path to dataset file, in .mat format.
+
+ Returns:
+ query_list: List of query image names.
+ index_list: List of index image names.
+ ground_truth: List containing ground-truth information for dataset. Each
+ entry is a dict corresponding to the ground-truth information for a query.
+ The dict may have keys 'easy', 'hard', or 'junk', mapping to a NumPy
+ array of integers; additionally, it has a key 'bbx' mapping to a NumPy
+ array of floats with bounding box coordinates.
+ """
+ with tf.io.gfile.GFile(dataset_file_path, 'rb') as f:
+ cfg = matlab.loadmat(f)
+
+ # Parse outputs according to the specificities of the dataset file.
+ query_list = [str(im_array[0]) for im_array in np.squeeze(cfg['qimlist'])]
+ index_list = [str(im_array[0]) for im_array in np.squeeze(cfg['imlist'])]
+ ground_truth_raw = np.squeeze(cfg['gnd'])
+ ground_truth = []
+ for query_ground_truth_raw in ground_truth_raw:
+ query_ground_truth = {}
+ for ground_truth_key in _GROUND_TRUTH_KEYS:
+ if ground_truth_key in query_ground_truth_raw.dtype.names:
+ adjusted_labels = query_ground_truth_raw[ground_truth_key] - 1
+ query_ground_truth[ground_truth_key] = adjusted_labels.flatten()
+
+ query_ground_truth['bbx'] = np.squeeze(query_ground_truth_raw['bbx'])
+ ground_truth.append(query_ground_truth)
+
+ return query_list, index_list, ground_truth
+
+
+def _ParseGroundTruth(ok_list, junk_list):
+ """Constructs dictionary of ok/junk indices for a data subset and query.
+
+ Args:
+ ok_list: List of NumPy arrays containing true positive indices for query.
+ junk_list: List of NumPy arrays containing ignored indices for query.
+
+ Returns:
+ ok_junk_dict: Dict mapping 'ok' and 'junk' strings to NumPy array of
+ indices.
+ """
+ ok_junk_dict = {}
+ ok_junk_dict['ok'] = np.concatenate(ok_list)
+ ok_junk_dict['junk'] = np.concatenate(junk_list)
+ return ok_junk_dict
+
+
+def ParseEasyMediumHardGroundTruth(ground_truth):
+ """Parses easy/medium/hard ground-truth from Revisited datasets.
+
+ Args:
+ ground_truth: Usually the output from ReadDatasetFile(). List containing
+ ground-truth information for dataset. Each entry is a dict corresponding
+ to the ground-truth information for a query. The dict must have keys
+ 'easy', 'hard', and 'junk', mapping to a NumPy array of integers.
+
+ Returns:
+ easy_ground_truth: List containing ground-truth information for easy subset
+ of dataset. Each entry is a dict corresponding to the ground-truth
+ information for a query. The dict has keys 'ok' and 'junk', mapping to a
+ NumPy array of integers.
+ medium_ground_truth: Same as `easy_ground_truth`, but for the medium subset.
+ hard_ground_truth: Same as `easy_ground_truth`, but for the hard subset.
+ """
+ num_queries = len(ground_truth)
+
+ easy_ground_truth = []
+ medium_ground_truth = []
+ hard_ground_truth = []
+ for i in range(num_queries):
+ easy_ground_truth.append(
+ _ParseGroundTruth([ground_truth[i]['easy']],
+ [ground_truth[i]['junk'], ground_truth[i]['hard']]))
+ medium_ground_truth.append(
+ _ParseGroundTruth([ground_truth[i]['easy'], ground_truth[i]['hard']],
+ [ground_truth[i]['junk']]))
+ hard_ground_truth.append(
+ _ParseGroundTruth([ground_truth[i]['hard']],
+ [ground_truth[i]['junk'], ground_truth[i]['easy']]))
+
+ return easy_ground_truth, medium_ground_truth, hard_ground_truth
+
+
+def AdjustPositiveRanks(positive_ranks, junk_ranks):
+ """Adjusts positive ranks based on junk ranks.
+
+ Args:
+ positive_ranks: Sorted 1D NumPy integer array.
+ junk_ranks: Sorted 1D NumPy integer array.
+
+ Returns:
+ adjusted_positive_ranks: Sorted 1D NumPy array.
+ """
+ if not junk_ranks.size:
+ return positive_ranks
+
+ adjusted_positive_ranks = positive_ranks
+ j = 0
+ for i, positive_index in enumerate(positive_ranks):
+ while (j < len(junk_ranks) and positive_index > junk_ranks[j]):
+ j += 1
+
+ adjusted_positive_ranks[i] -= j
+
+ return adjusted_positive_ranks
+
+
+def ComputeAveragePrecision(positive_ranks):
+ """Computes average precision according to dataset convention.
+
+ It assumes that `positive_ranks` contains the ranks for all expected positive
+ index images to be retrieved. If `positive_ranks` is empty, returns
+ `average_precision` = 0.
+
+ Note that average precision computation here does NOT use the finite sum
+ method (see
+ https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision)
+ which is common in information retrieval literature. Instead, the method
+ implemented here integrates over the precision-recall curve by averaging two
+ adjacent precision points, then multiplying by the recall step. This is the
+ convention for the Revisited Oxford/Paris datasets.
+
+ Args:
+ positive_ranks: Sorted 1D NumPy integer array, zero-indexed.
+
+ Returns:
+ average_precision: Float.
+ """
+ average_precision = 0.0
+
+ num_expected_positives = len(positive_ranks)
+ if not num_expected_positives:
+ return average_precision
+
+ recall_step = 1.0 / num_expected_positives
+ for i, rank in enumerate(positive_ranks):
+ if not rank:
+ left_precision = 1.0
+ else:
+ left_precision = i / rank
+
+ right_precision = (i + 1) / (rank + 1)
+ average_precision += (left_precision + right_precision) * recall_step / 2
+
+ return average_precision
+
+
+def ComputePRAtRanks(positive_ranks, desired_pr_ranks):
+ """Computes precision/recall at desired ranks.
+
+ It assumes that `positive_ranks` contains the ranks for all expected positive
+ index images to be retrieved. If `positive_ranks` is empty, return all-zeros
+ `precisions`/`recalls`.
+
+ If a desired rank is larger than the last positive rank, its precision is
+ computed based on the last positive rank. For example, if `desired_pr_ranks`
+ is [10] and `positive_ranks` = [0, 7] --> `precisions` = [0.25], `recalls` =
+ [1.0].
+
+ Args:
+ positive_ranks: 1D NumPy integer array, zero-indexed.
+ desired_pr_ranks: List of integers containing the desired precision/recall
+ ranks to be reported. Eg, if precision@1/recall@1 and
+ precision@10/recall@10 are desired, this should be set to [1, 10].
+
+ Returns:
+ precisions: Precision @ `desired_pr_ranks` (NumPy array of
+ floats, with shape [len(desired_pr_ranks)]).
+ recalls: Recall @ `desired_pr_ranks` (NumPy array of floats, with
+ shape [len(desired_pr_ranks)]).
+ """
+ num_desired_pr_ranks = len(desired_pr_ranks)
+ precisions = np.zeros([num_desired_pr_ranks])
+ recalls = np.zeros([num_desired_pr_ranks])
+
+ num_expected_positives = len(positive_ranks)
+ if not num_expected_positives:
+ return precisions, recalls
+
+ positive_ranks_one_indexed = positive_ranks + 1
+ for i, desired_pr_rank in enumerate(desired_pr_ranks):
+ recalls[i] = np.sum(
+ positive_ranks_one_indexed <= desired_pr_rank) / num_expected_positives
+
+ # If `desired_pr_rank` is larger than last positive's rank, only compute
+ # precision with respect to last positive's position.
+ precision_rank = min(max(positive_ranks_one_indexed), desired_pr_rank)
+ precisions[i] = np.sum(
+ positive_ranks_one_indexed <= precision_rank) / precision_rank
+
+ return precisions, recalls
+
+
+def ComputeMetrics(sorted_index_ids, ground_truth, desired_pr_ranks):
+ """Computes metrics for retrieval results on the Revisited datasets.
+
+ If there are no valid ground-truth index images for a given query, the metric
+ results for the given query (`average_precisions`, `precisions` and `recalls`)
+ are set to NaN, and they are not taken into account when computing the
+ aggregated metrics (`mean_average_precision`, `mean_precisions` and
+ `mean_recalls`) over all queries.
+
+ Args:
+ sorted_index_ids: Integer NumPy array of shape [#queries, #index_images].
+ For each query, contains an array denoting the most relevant index images,
+ sorted from most to least relevant.
+ ground_truth: List containing ground-truth information for dataset. Each
+ entry is a dict corresponding to the ground-truth information for a query.
+ The dict has keys 'ok' and 'junk', mapping to a NumPy array of integers.
+ desired_pr_ranks: List of integers containing the desired precision/recall
+ ranks to be reported. Eg, if precision@1/recall@1 and
+ precision@10/recall@10 are desired, this should be set to [1, 10]. The
+ largest item should be <= #index_images.
+
+ Returns:
+ mean_average_precision: Mean average precision (float).
+ mean_precisions: Mean precision @ `desired_pr_ranks` (NumPy array of
+ floats, with shape [len(desired_pr_ranks)]).
+ mean_recalls: Mean recall @ `desired_pr_ranks` (NumPy array of floats, with
+ shape [len(desired_pr_ranks)]).
+ average_precisions: Average precision for each query (NumPy array of floats,
+ with shape [#queries]).
+ precisions: Precision @ `desired_pr_ranks`, for each query (NumPy array of
+ floats, with shape [#queries, len(desired_pr_ranks)]).
+ recalls: Recall @ `desired_pr_ranks`, for each query (NumPy array of
+ floats, with shape [#queries, len(desired_pr_ranks)]).
+
+ Raises:
+ ValueError: If largest desired PR rank in `desired_pr_ranks` >
+ #index_images.
+ """
+ num_queries, num_index_images = sorted_index_ids.shape
+ num_desired_pr_ranks = len(desired_pr_ranks)
+
+ sorted_desired_pr_ranks = sorted(desired_pr_ranks)
+
+ if sorted_desired_pr_ranks[-1] > num_index_images:
+ raise ValueError(
+ 'Requested PR ranks up to %d, however there are only %d images' %
+ (sorted_desired_pr_ranks[-1], num_index_images))
+
+ # Instantiate all outputs, then loop over each query and gather metrics.
+ mean_average_precision = 0.0
+ mean_precisions = np.zeros([num_desired_pr_ranks])
+ mean_recalls = np.zeros([num_desired_pr_ranks])
+ average_precisions = np.zeros([num_queries])
+ precisions = np.zeros([num_queries, num_desired_pr_ranks])
+ recalls = np.zeros([num_queries, num_desired_pr_ranks])
+ num_empty_gt_queries = 0
+ for i in range(num_queries):
+ ok_index_images = ground_truth[i]['ok']
+ junk_index_images = ground_truth[i]['junk']
+
+ if not ok_index_images.size:
+ average_precisions[i] = float('nan')
+ precisions[i, :] = float('nan')
+ recalls[i, :] = float('nan')
+ num_empty_gt_queries += 1
+ continue
+
+ positive_ranks = np.arange(num_index_images)[np.in1d(
+ sorted_index_ids[i], ok_index_images)]
+ junk_ranks = np.arange(num_index_images)[np.in1d(sorted_index_ids[i],
+ junk_index_images)]
+
+ adjusted_positive_ranks = AdjustPositiveRanks(positive_ranks, junk_ranks)
+
+ average_precisions[i] = ComputeAveragePrecision(adjusted_positive_ranks)
+ precisions[i, :], recalls[i, :] = ComputePRAtRanks(adjusted_positive_ranks,
+ desired_pr_ranks)
+
+ mean_average_precision += average_precisions[i]
+ mean_precisions += precisions[i, :]
+ mean_recalls += recalls[i, :]
+
+ # Normalize aggregated metrics by number of queries.
+ num_valid_queries = num_queries - num_empty_gt_queries
+ mean_average_precision /= num_valid_queries
+ mean_precisions /= num_valid_queries
+ mean_recalls /= num_valid_queries
+
+ return (mean_average_precision, mean_precisions, mean_recalls,
+ average_precisions, precisions, recalls)
+
+
+def SaveMetricsFile(mean_average_precision, mean_precisions, mean_recalls,
+ pr_ranks, output_path):
+ """Saves aggregated retrieval metrics to text file.
+
+ Args:
+ mean_average_precision: Dict mapping each dataset protocol to a float.
+ mean_precisions: Dict mapping each dataset protocol to a NumPy array of
+ floats with shape [len(pr_ranks)].
+ mean_recalls: Dict mapping each dataset protocol to a NumPy array of floats
+ with shape [len(pr_ranks)].
+ pr_ranks: List of integers.
+ output_path: Full file path.
+ """
+ with tf.io.gfile.GFile(output_path, 'w') as f:
+ for k in sorted(mean_average_precision.keys()):
+ f.write('{}\n mAP={}\n mP@k{} {}\n mR@k{} {}\n'.format(
+ k, np.around(mean_average_precision[k] * 100, decimals=2),
+ np.array(pr_ranks), np.around(mean_precisions[k] * 100, decimals=2),
+ np.array(pr_ranks), np.around(mean_recalls[k] * 100, decimals=2)))
+
+
+def _ParseSpaceSeparatedStringsInBrackets(line, prefixes, ind):
+ """Parses line containing space-separated strings in brackets.
+
+ Args:
+ line: String, containing line in metrics file with mP@k or mR@k figures.
+ prefixes: Tuple/list of strings, containing valid prefixes.
+ ind: Integer indicating which field within brackets is parsed.
+
+ Yields:
+ entry: String format entry.
+
+ Raises:
+ ValueError: If input line does not contain a valid prefix.
+ """
+ for prefix in prefixes:
+ if line.startswith(prefix):
+ line = line[len(prefix):]
+ break
+ else:
+ raise ValueError('Line %s is malformed, cannot find valid prefixes' % line)
+
+ for entry in line.split('[')[ind].split(']')[0].split():
+ yield entry
+
+
+def _ParsePrRanks(line):
+ """Parses PR ranks from mP@k line in metrics file.
+
+ Args:
+ line: String, containing line in metrics file with mP@k figures.
+
+ Returns:
+ pr_ranks: List of integers, containing used ranks.
+
+ Raises:
+ ValueError: If input line is malformed.
+ """
+ return [
+ int(pr_rank) for pr_rank in _ParseSpaceSeparatedStringsInBrackets(
+ line, [' mP@k['], 0) if pr_rank
+ ]
+
+
+def _ParsePrScores(line, num_pr_ranks):
+ """Parses PR scores from line in metrics file.
+
+ Args:
+ line: String, containing line in metrics file with mP@k or mR@k figures.
+ num_pr_ranks: Integer, number of scores that should be in output list.
+
+ Returns:
+ pr_scores: List of floats, containing scores.
+
+ Raises:
+ ValueError: If input line is malformed.
+ """
+ pr_scores = [
+ float(pr_score) for pr_score in _ParseSpaceSeparatedStringsInBrackets(
+ line, (' mP@k[', ' mR@k['), 1) if pr_score
+ ]
+
+ if len(pr_scores) != num_pr_ranks:
+ raise ValueError('Line %s is malformed, expected %d scores but found %d' %
+ (line, num_pr_ranks, len(pr_scores)))
+
+ return pr_scores
+
+
+def ReadMetricsFile(metrics_path):
+ """Reads aggregated retrieval metrics from text file.
+
+ Args:
+ metrics_path: Full file path, containing aggregated retrieval metrics.
+
+ Returns:
+ mean_average_precision: Dict mapping each dataset protocol to a float.
+ pr_ranks: List of integer ranks used in aggregated recall/precision metrics.
+ mean_precisions: Dict mapping each dataset protocol to a NumPy array of
+ floats with shape [len(`pr_ranks`)].
+ mean_recalls: Dict mapping each dataset protocol to a NumPy array of floats
+ with shape [len(`pr_ranks`)].
+
+ Raises:
+ ValueError: If input file is malformed.
+ """
+ with tf.io.gfile.GFile(metrics_path, 'r') as f:
+ file_contents_stripped = [l.rstrip() for l in f]
+
+ if len(file_contents_stripped) % 4:
+ raise ValueError(
+ 'Malformed input %s: number of lines must be a multiple of 4, '
+ 'but it is %d' % (metrics_path, len(file_contents_stripped)))
+
+ mean_average_precision = {}
+ pr_ranks = []
+ mean_precisions = {}
+ mean_recalls = {}
+ protocols = set()
+ for i in range(0, len(file_contents_stripped), 4):
+ protocol = file_contents_stripped[i]
+ if protocol in protocols:
+ raise ValueError(
+ 'Malformed input %s: protocol %s is found a second time' %
+ (metrics_path, protocol))
+ protocols.add(protocol)
+
+ # Parse mAP.
+ mean_average_precision[protocol] = float(
+ file_contents_stripped[i + 1].split('=')[1]) / 100.0
+
+ # Parse (or check consistency of) pr_ranks.
+ parsed_pr_ranks = _ParsePrRanks(file_contents_stripped[i + 2])
+ if not pr_ranks:
+ pr_ranks = parsed_pr_ranks
+ else:
+ if parsed_pr_ranks != pr_ranks:
+ raise ValueError('Malformed input %s: inconsistent PR ranks' %
+ metrics_path)
+
+ # Parse mean precisions.
+ mean_precisions[protocol] = np.array(
+ _ParsePrScores(file_contents_stripped[i + 2], len(pr_ranks)),
+ dtype=float) / 100.0
+
+ # Parse mean recalls.
+ mean_recalls[protocol] = np.array(
+ _ParsePrScores(file_contents_stripped[i + 3], len(pr_ranks)),
+ dtype=float) / 100.0
+
+ return mean_average_precision, pr_ranks, mean_precisions, mean_recalls
diff --git a/models/research/delf/delf/python/detect_to_retrieve/dataset_test.py b/models/research/delf/delf/python/detect_to_retrieve/dataset_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e742703b04210787ede0bfc945a9f305d59efc7
--- /dev/null
+++ b/models/research/delf/delf/python/detect_to_retrieve/dataset_test.py
@@ -0,0 +1,288 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the python library parsing Revisited Oxford/Paris datasets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import flags
+import numpy as np
+import tensorflow as tf
+
+from delf.python.detect_to_retrieve import dataset
+
+FLAGS = flags.FLAGS
+
+
+class DatasetTest(tf.test.TestCase):
+
+ def testParseEasyMediumHardGroundTruth(self):
+ # Define input.
+ ground_truth = [{
+ 'easy': np.array([10, 56, 100]),
+ 'hard': np.array([0]),
+ 'junk': np.array([6, 90])
+ }, {
+ 'easy': np.array([], dtype='int64'),
+ 'hard': [5],
+ 'junk': [99, 100]
+ }, {
+ 'easy': [33],
+ 'hard': [66, 99],
+ 'junk': np.array([], dtype='int64')
+ }]
+
+ # Run tested function.
+ (easy_ground_truth, medium_ground_truth,
+ hard_ground_truth) = dataset.ParseEasyMediumHardGroundTruth(ground_truth)
+
+ # Define expected outputs.
+ expected_easy_ground_truth = [{
+ 'ok': np.array([10, 56, 100]),
+ 'junk': np.array([6, 90, 0])
+ }, {
+ 'ok': np.array([], dtype='int64'),
+ 'junk': np.array([99, 100, 5])
+ }, {
+ 'ok': np.array([33]),
+ 'junk': np.array([66, 99])
+ }]
+ expected_medium_ground_truth = [{
+ 'ok': np.array([10, 56, 100, 0]),
+ 'junk': np.array([6, 90])
+ }, {
+ 'ok': np.array([5]),
+ 'junk': np.array([99, 100])
+ }, {
+ 'ok': np.array([33, 66, 99]),
+ 'junk': np.array([], dtype='int64')
+ }]
+ expected_hard_ground_truth = [{
+ 'ok': np.array([0]),
+ 'junk': np.array([6, 90, 10, 56, 100])
+ }, {
+ 'ok': np.array([5]),
+ 'junk': np.array([99, 100])
+ }, {
+ 'ok': np.array([66, 99]),
+ 'junk': np.array([33])
+ }]
+
+ # Compare actual versus expected.
+ def _AssertListOfDictsOfArraysAreEqual(ground_truth, expected_ground_truth):
+ """Helper function to compare ground-truth data.
+
+ Args:
+ ground_truth: List of dicts of arrays.
+ expected_ground_truth: List of dicts of arrays.
+ """
+ self.assertEqual(len(ground_truth), len(expected_ground_truth))
+
+ for i, ground_truth_entry in enumerate(ground_truth):
+ self.assertEqual(sorted(ground_truth_entry.keys()), ['junk', 'ok'])
+ self.assertAllEqual(ground_truth_entry['junk'],
+ expected_ground_truth[i]['junk'])
+ self.assertAllEqual(ground_truth_entry['ok'],
+ expected_ground_truth[i]['ok'])
+
+ _AssertListOfDictsOfArraysAreEqual(easy_ground_truth,
+ expected_easy_ground_truth)
+ _AssertListOfDictsOfArraysAreEqual(medium_ground_truth,
+ expected_medium_ground_truth)
+ _AssertListOfDictsOfArraysAreEqual(hard_ground_truth,
+ expected_hard_ground_truth)
+
+ def testAdjustPositiveRanksWorks(self):
+ # Define inputs.
+ positive_ranks = np.array([0, 2, 6, 10, 20])
+ junk_ranks = np.array([1, 8, 9, 30])
+
+ # Run tested function.
+ adjusted_positive_ranks = dataset.AdjustPositiveRanks(
+ positive_ranks, junk_ranks)
+
+ # Define expected output.
+ expected_adjusted_positive_ranks = [0, 1, 5, 7, 17]
+
+ # Compare actual versus expected.
+ self.assertAllEqual(adjusted_positive_ranks,
+ expected_adjusted_positive_ranks)
+
+ def testComputeAveragePrecisionWorks(self):
+ # Define input.
+ positive_ranks = [0, 2, 5]
+
+ # Run tested function.
+ average_precision = dataset.ComputeAveragePrecision(positive_ranks)
+
+ # Define expected output.
+ expected_average_precision = 0.677778
+
+ # Compare actual versus expected.
+ self.assertAllClose(average_precision, expected_average_precision)
+
+ def testComputePRAtRanksWorks(self):
+ # Define inputs.
+ positive_ranks = np.array([0, 2, 5])
+ desired_pr_ranks = np.array([1, 5, 10])
+
+ # Run tested function.
+ precisions, recalls = dataset.ComputePRAtRanks(positive_ranks,
+ desired_pr_ranks)
+
+ # Define expected outputs.
+ expected_precisions = [1.0, 0.4, 0.5]
+ expected_recalls = [0.333333, 0.666667, 1.0]
+
+ # Compare actual versus expected.
+ self.assertAllClose(precisions, expected_precisions)
+ self.assertAllClose(recalls, expected_recalls)
+
+ def testComputeMetricsWorks(self):
+ # Define inputs: 3 queries. For the last one, there are no expected images
+ # to be retrieved
+ sorted_index_ids = np.array([[4, 2, 0, 1, 3], [0, 2, 4, 1, 3],
+ [0, 1, 2, 3, 4]])
+ ground_truth = [{
+ 'ok': np.array([0, 1]),
+ 'junk': np.array([2])
+ }, {
+ 'ok': np.array([0, 4]),
+ 'junk': np.array([], dtype='int64')
+ }, {
+ 'ok': np.array([], dtype='int64'),
+ 'junk': np.array([], dtype='int64')
+ }]
+ desired_pr_ranks = [1, 2, 5]
+
+ # Run tested function.
+ (mean_average_precision, mean_precisions, mean_recalls, average_precisions,
+ precisions, recalls) = dataset.ComputeMetrics(sorted_index_ids,
+ ground_truth,
+ desired_pr_ranks)
+
+ # Define expected outputs.
+ expected_mean_average_precision = 0.604167
+ expected_mean_precisions = [0.5, 0.5, 0.666667]
+ expected_mean_recalls = [0.25, 0.5, 1.0]
+ expected_average_precisions = [0.416667, 0.791667, float('nan')]
+ expected_precisions = [[0.0, 0.5, 0.666667], [1.0, 0.5, 0.666667],
+ [float('nan'),
+ float('nan'),
+ float('nan')]]
+ expected_recalls = [[0.0, 0.5, 1.0], [0.5, 0.5, 1.0],
+ [float('nan'), float('nan'),
+ float('nan')]]
+
+ # Compare actual versus expected.
+ self.assertAllClose(mean_average_precision, expected_mean_average_precision)
+ self.assertAllClose(mean_precisions, expected_mean_precisions)
+ self.assertAllClose(mean_recalls, expected_mean_recalls)
+ self.assertAllClose(average_precisions, expected_average_precisions)
+ self.assertAllClose(precisions, expected_precisions)
+ self.assertAllClose(recalls, expected_recalls)
+
+ def testSaveMetricsFileWorks(self):
+ # Define inputs.
+ mean_average_precision = {'hard': 0.7, 'medium': 0.9}
+ mean_precisions = {
+ 'hard': np.array([1.0, 0.8]),
+ 'medium': np.array([1.0, 1.0])
+ }
+ mean_recalls = {
+ 'hard': np.array([0.5, 0.8]),
+ 'medium': np.array([0.5, 1.0])
+ }
+ pr_ranks = [1, 5]
+ output_path = os.path.join(FLAGS.test_tmpdir, 'metrics.txt')
+
+ # Run tested function.
+ dataset.SaveMetricsFile(mean_average_precision, mean_precisions,
+ mean_recalls, pr_ranks, output_path)
+
+ # Define expected results.
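+    # Note the saved file reports percentages: the fractional inputs above
+    # (e.g. mAP 0.7, precisions [1.0, 0.8]) appear scaled by 100 below.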
+ expected_metrics = ('hard\n'
+ ' mAP=70.0\n'
+ ' mP@k[1 5] [100. 80.]\n'
+ ' mR@k[1 5] [50. 80.]\n'
+ 'medium\n'
+ ' mAP=90.0\n'
+ ' mP@k[1 5] [100. 100.]\n'
+ ' mR@k[1 5] [ 50. 100.]\n')
+
+ # Parse actual results, and compare to expected.
+ with tf.io.gfile.GFile(output_path) as f:
+ metrics = f.read()
+
+ self.assertEqual(metrics, expected_metrics)
+
+ def testSaveAndReadMetricsWorks(self):
+ # Define inputs.
+ mean_average_precision = {'hard': 0.7, 'medium': 0.9}
+ mean_precisions = {
+ 'hard': np.array([1.0, 0.8]),
+ 'medium': np.array([1.0, 1.0])
+ }
+ mean_recalls = {
+ 'hard': np.array([0.5, 0.8]),
+ 'medium': np.array([0.5, 1.0])
+ }
+ pr_ranks = [1, 5]
+ output_path = os.path.join(FLAGS.test_tmpdir, 'metrics.txt')
+
+ # Run tested functions.
+ dataset.SaveMetricsFile(mean_average_precision, mean_precisions,
+ mean_recalls, pr_ranks, output_path)
+ (read_mean_average_precision, read_pr_ranks, read_mean_precisions,
+ read_mean_recalls) = dataset.ReadMetricsFile(output_path)
+
+ # Compares actual and expected metrics.
+ self.assertEqual(read_mean_average_precision, mean_average_precision)
+ self.assertEqual(read_pr_ranks, pr_ranks)
+ self.assertEqual(read_mean_precisions.keys(), mean_precisions.keys())
+ self.assertAllEqual(read_mean_precisions['hard'], mean_precisions['hard'])
+ self.assertAllEqual(read_mean_precisions['medium'],
+ mean_precisions['medium'])
+ self.assertEqual(read_mean_recalls.keys(), mean_recalls.keys())
+ self.assertAllEqual(read_mean_recalls['hard'], mean_recalls['hard'])
+ self.assertAllEqual(read_mean_recalls['medium'], mean_recalls['medium'])
+
+ def testReadMetricsWithRepeatedProtocolFails(self):
+ # Define inputs.
+ input_path = os.path.join(FLAGS.test_tmpdir, 'metrics.txt')
+ with tf.io.gfile.GFile(input_path, 'w') as f:
+ f.write('hard\n'
+ ' mAP=70.0\n'
+ ' mP@k[1 5] [ 100. 80.]\n'
+ ' mR@k[1 5] [ 50. 80.]\n'
+ 'medium\n'
+ ' mAP=90.0\n'
+ ' mP@k[1 5] [ 100. 100.]\n'
+ ' mR@k[1 5] [ 50. 100.]\n'
+ 'medium\n'
+ ' mAP=90.0\n'
+ ' mP@k[1 5] [ 100. 100.]\n'
+ ' mR@k[1 5] [ 50. 100.]\n')
+
+ # Run tested functions.
+ with self.assertRaisesRegex(ValueError, 'Malformed input'):
+ dataset.ReadMetricsFile(input_path)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/delf/delf/python/detect_to_retrieve/delf_gld_config.pbtxt b/models/research/delf/delf/python/detect_to_retrieve/delf_gld_config.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..046aed766ce8cee4b6309c8385d451cf20ad633a
--- /dev/null
+++ b/models/research/delf/delf/python/detect_to_retrieve/delf_gld_config.pbtxt
@@ -0,0 +1,25 @@
+model_path: "parameters/delf_gld_20190411/model"
+image_scales: .25
+image_scales: .3536
+image_scales: .5
+image_scales: .7071
+image_scales: 1.0
+image_scales: 1.4142
+image_scales: 2.0
+delf_local_config {
+ use_pca: true
+ # Note that for the exported model provided as an example, layer_name and
+ # iou_threshold are hard-coded in the checkpoint. So, the layer_name and
+ # iou_threshold variables here have no effect on the provided
+ # extract_features.py script.
+ layer_name: "resnet_v1_50/block3"
+ iou_threshold: 1.0
+ max_feature_num: 1000
+ score_threshold: 100.0
+ pca_parameters {
+ mean_path: "parameters/delf_gld_20190411/pca/mean.datum"
+ projection_matrix_path: "parameters/delf_gld_20190411/pca/pca_proj_mat.datum"
+ pca_dim: 128
+ use_whitening: false
+ }
+}
diff --git a/models/research/delf/delf/python/detect_to_retrieve/extract_aggregation.py b/models/research/delf/delf/python/detect_to_retrieve/extract_aggregation.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a0fb3e6c62c0adc583ad3b30b809f36742d586
--- /dev/null
+++ b/models/research/delf/delf/python/detect_to_retrieve/extract_aggregation.py
@@ -0,0 +1,113 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Extracts aggregation for images from Revisited Oxford/Paris datasets.
+
+The program checks if the aggregated representation for an image already exists,
+and skips computation for those.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+from tensorflow.python.platform import app
+from delf.python.detect_to_retrieve import aggregation_extraction
+from delf.python.detect_to_retrieve import dataset
+
+cmd_args = None
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise RuntimeError('Too many command-line arguments.')
+
+ # Read list of images from dataset file.
+ print('Reading list of images from dataset file...')
+ query_list, index_list, _ = dataset.ReadDatasetFile(
+ cmd_args.dataset_file_path)
+ if cmd_args.use_query_images:
+ image_list = query_list
+ else:
+ image_list = index_list
+ num_images = len(image_list)
+ print('done! Found %d images' % num_images)
+
+ aggregation_extraction.ExtractAggregatedRepresentationsToFiles(
+ image_names=image_list,
+ features_dir=cmd_args.features_dir,
+ aggregation_config_path=cmd_args.aggregation_config_path,
+ mapping_path=cmd_args.index_mapping_path,
+ output_aggregation_dir=cmd_args.output_aggregation_dir)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.register('type', 'bool', lambda v: v.lower() == 'true')
+ parser.add_argument(
+ '--aggregation_config_path',
+ type=str,
+ default='/tmp/aggregation_config.pbtxt',
+ help="""
+ Path to AggregationConfig proto text file with configuration to be used
+ for extraction.
+ """)
+ parser.add_argument(
+ '--dataset_file_path',
+ type=str,
+ default='/tmp/gnd_roxford5k.mat',
+ help="""
+ Dataset file for Revisited Oxford or Paris dataset, in .mat format.
+ """)
+ parser.add_argument(
+ '--use_query_images',
+ type=lambda x: (str(x).lower() == 'true'),
+ default=False,
+ help="""
+ If True, processes the query images of the dataset. If False, processes
+ the database (ie, index) images.
+ """)
+ parser.add_argument(
+ '--features_dir',
+ type=str,
+ default='/tmp/features',
+ help="""
+ Directory where image features are located, all in .delf format.
+ """)
+ parser.add_argument(
+ '--index_mapping_path',
+ type=str,
+ default='',
+ help="""
+ Optional CSV file which maps each .delf file name to the index image ID
+ and detected box ID. If regional aggregation is performed, this should be
+ set. Otherwise, this is ignored.
+ Usually this file is obtained as an output from the
+ `extract_index_boxes_and_features.py` script.
+ """)
+ parser.add_argument(
+ '--output_aggregation_dir',
+ type=str,
+ default='/tmp/aggregation',
+ help="""
+ Directory where aggregation output will be written to. Each image's
+ features will be written to a file with same name, and extension replaced
+ by one of
+ ['.vlad', '.asmk', '.asmk_star', '.rvlad', '.rasmk', '.rasmk_star'].
+ """)
+ cmd_args, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/models/research/delf/delf/python/detect_to_retrieve/extract_index_boxes_and_features.py b/models/research/delf/delf/python/detect_to_retrieve/extract_index_boxes_and_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b891de4b0b093aa723c0dce547c2722ee475d7e
--- /dev/null
+++ b/models/research/delf/delf/python/detect_to_retrieve/extract_index_boxes_and_features.py
@@ -0,0 +1,151 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Extracts DELF and boxes from the Revisited Oxford/Paris index datasets.
+
+Each image's boxes are saved to a .boxes file, and DELF features extracted for
+the entire image are saved into a .delf file, both named after the image. In
+addition, DELF features are extracted for each high-confidence bounding box in
+the image, and saved into files with the image name plus the suffixes _0.delf,
+_1.delf, etc.
+
+The program checks if descriptors/boxes already exist, and skips computation for
+those.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+
+from tensorflow.python.platform import app
+from delf.python.detect_to_retrieve import boxes_and_features_extraction
+from delf.python.detect_to_retrieve import dataset
+
+cmd_args = None
+
+_IMAGE_EXTENSION = '.jpg'
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise RuntimeError('Too many command-line arguments.')
+
+ # Read list of index images from dataset file.
+ print('Reading list of index images from dataset file...')
+ _, index_list, _ = dataset.ReadDatasetFile(cmd_args.dataset_file_path)
+ num_images = len(index_list)
+ print('done! Found %d images' % num_images)
+
+ # Compose list of image paths.
+ image_paths = [
+ os.path.join(cmd_args.images_dir, index_image_name + _IMAGE_EXTENSION)
+ for index_image_name in index_list
+ ]
+
+ # Extract boxes/features and save them to files.
+ boxes_and_features_extraction.ExtractBoxesAndFeaturesToFiles(
+ image_names=index_list,
+ image_paths=image_paths,
+ delf_config_path=cmd_args.delf_config_path,
+ detector_model_dir=cmd_args.detector_model_dir,
+ detector_thresh=cmd_args.detector_thresh,
+ output_features_dir=cmd_args.output_features_dir,
+ output_boxes_dir=cmd_args.output_boxes_dir,
+ output_mapping=cmd_args.output_index_mapping)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.register('type', 'bool', lambda v: v.lower() == 'true')
+ parser.add_argument(
+ '--delf_config_path',
+ type=str,
+ default='/tmp/delf_config_example.pbtxt',
+ help="""
+ Path to DelfConfig proto text file with configuration to be used for DELF
+ extraction.
+ """)
+ parser.add_argument(
+ '--detector_model_dir',
+ type=str,
+ default='/tmp/detector_model',
+ help="""
+ Directory where detector SavedModel is located.
+ """)
+ parser.add_argument(
+ '--detector_thresh',
+ type=float,
+ default=0.1,
+ help="""
+ Threshold used to decide if an image's detected box undergoes feature
+ extraction. For all detected boxes with detection score larger than this,
+ a .delf file is saved containing the box features. Note that this
+ threshold is used only to select which boxes are used in feature
+ extraction; all detected boxes are actually saved in the .boxes file, even
+ those with score lower than detector_thresh.
+ """)
+ parser.add_argument(
+ '--dataset_file_path',
+ type=str,
+ default='/tmp/gnd_roxford5k.mat',
+ help="""
+ Dataset file for Revisited Oxford or Paris dataset, in .mat format.
+ """)
+ parser.add_argument(
+ '--images_dir',
+ type=str,
+ default='/tmp/images',
+ help="""
+ Directory where dataset images are located, all in .jpg format.
+ """)
+ parser.add_argument(
+ '--output_boxes_dir',
+ type=str,
+ default='/tmp/boxes',
+ help="""
+ Directory where detected boxes will be written to. Each image's boxes
+ will be written to a file with same name, and extension replaced by
+ .boxes.
+ """)
+ parser.add_argument(
+ '--output_features_dir',
+ type=str,
+ default='/tmp/features',
+ help="""
+      Directory where DELF features will be written to. Each image's features
+      will be written to a file with the same name, and extension replaced by
+      .delf. In addition, DELF features are extracted for each high-confidence
+      bounding box in the image, and saved into files with the image name plus
+      the suffixes _0.delf, _1.delf, etc.
+      """)
+ parser.add_argument(
+ '--output_index_mapping',
+ type=str,
+ default='/tmp/index_mapping.csv',
+ help="""
+ CSV file which maps each .delf file name to the index image ID and
+ detected box ID. The format is 'name,index_image_id,box_id', including a
+ header. The 'name' refers to the .delf file name without extension.
+
+ For example, a few lines may be like:
+ 'radcliffe_camera_000158,2,-1'
+ 'radcliffe_camera_000158_0,2,0'
+ 'radcliffe_camera_000158_1,2,1'
+ 'radcliffe_camera_000158_2,2,2'
+ """)
+ cmd_args, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/models/research/delf/delf/python/detect_to_retrieve/extract_query_features.py b/models/research/delf/delf/python/detect_to_retrieve/extract_query_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0812b191265ec6e5350acf989432747d196a519
--- /dev/null
+++ b/models/research/delf/delf/python/detect_to_retrieve/extract_query_features.py
@@ -0,0 +1,137 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Extracts DELF features for query images from Revisited Oxford/Paris datasets.
+
+Note that query images are cropped before feature extraction, as required by the
+evaluation protocols of these datasets.
+
+The program checks if descriptors already exist, and skips computation for
+those.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+import time
+
+import numpy as np
+import tensorflow as tf
+
+from google.protobuf import text_format
+from tensorflow.python.platform import app
+from delf import delf_config_pb2
+from delf import feature_io
+from delf import utils
+from delf.python.detect_to_retrieve import dataset
+from delf import extractor
+
+cmd_args = None
+
+# Extensions.
+_DELF_EXTENSION = '.delf'
+_IMAGE_EXTENSION = '.jpg'
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise RuntimeError('Too many command-line arguments.')
+
+ # Read list of query images from dataset file.
+ print('Reading list of query images and boxes from dataset file...')
+ query_list, _, ground_truth = dataset.ReadDatasetFile(
+ cmd_args.dataset_file_path)
+ num_images = len(query_list)
+ print(f'done! Found {num_images} images')
+
+ # Parse DelfConfig proto.
+ config = delf_config_pb2.DelfConfig()
+ with tf.io.gfile.GFile(cmd_args.delf_config_path, 'r') as f:
+ text_format.Merge(f.read(), config)
+
+ # Create output directory if necessary.
+ if not tf.io.gfile.exists(cmd_args.output_features_dir):
+ tf.io.gfile.makedirs(cmd_args.output_features_dir)
+
+ extractor_fn = extractor.MakeExtractor(config)
+
+ start = time.time()
+ for i in range(num_images):
+ query_image_name = query_list[i]
+ input_image_filename = os.path.join(cmd_args.images_dir,
+ query_image_name + _IMAGE_EXTENSION)
+ output_feature_filename = os.path.join(
+ cmd_args.output_features_dir, query_image_name + _DELF_EXTENSION)
+ if tf.io.gfile.exists(output_feature_filename):
+ print(f'Skipping {query_image_name}')
+ continue
+
+ # Crop query image according to bounding box.
+ bbox = [int(round(b)) for b in ground_truth[i]['bbx']]
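+    # The 'bbx' entry is assumed to be (left, upper, right, lower) in pixels,
+    # which is the box format expected by PIL's Image.crop.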
+ im = np.array(utils.RgbLoader(input_image_filename).crop(bbox))
+
+ # Extract and save features.
+ extracted_features = extractor_fn(im)
+ locations_out = extracted_features['local_features']['locations']
+ descriptors_out = extracted_features['local_features']['descriptors']
+ feature_scales_out = extracted_features['local_features']['scales']
+ attention_out = extracted_features['local_features']['attention']
+
+ feature_io.WriteToFile(output_feature_filename, locations_out,
+ feature_scales_out, descriptors_out,
+ attention_out)
+
+ elapsed = (time.time() - start)
+ print('Processed %d query images in %f seconds' % (num_images, elapsed))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.register('type', 'bool', lambda v: v.lower() == 'true')
+ parser.add_argument(
+ '--delf_config_path',
+ type=str,
+ default='/tmp/delf_config_example.pbtxt',
+ help="""
+ Path to DelfConfig proto text file with configuration to be used for DELF
+ extraction.
+ """)
+ parser.add_argument(
+ '--dataset_file_path',
+ type=str,
+ default='/tmp/gnd_roxford5k.mat',
+ help="""
+ Dataset file for Revisited Oxford or Paris dataset, in .mat format.
+ """)
+ parser.add_argument(
+ '--images_dir',
+ type=str,
+ default='/tmp/images',
+ help="""
+ Directory where dataset images are located, all in .jpg format.
+ """)
+ parser.add_argument(
+ '--output_features_dir',
+ type=str,
+ default='/tmp/features',
+ help="""
+ Directory where DELF features will be written to. Each image's features
+ will be written to a file with same name, and extension replaced by .delf.
+ """)
+ cmd_args, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/models/research/delf/delf/python/detect_to_retrieve/image_reranking.py b/models/research/delf/delf/python/detect_to_retrieve/image_reranking.py
new file mode 100644
index 0000000000000000000000000000000000000000..60c29cc18a4436815c721855da0ca4577b06e6c4
--- /dev/null
+++ b/models/research/delf/delf/python/detect_to_retrieve/image_reranking.py
@@ -0,0 +1,279 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Library to re-rank images based on geometric verification."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import io
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy import spatial
+from skimage import feature
+from skimage import measure
+from skimage import transform
+
+from delf import feature_io
+
+# Extensions.
+_DELF_EXTENSION = '.delf'
+
+# Pace at which to log re-ranking status.
+_STATUS_CHECK_GV_ITERATIONS = 10
+
+# Re-ranking / geometric verification parameters.
+_NUM_TO_RERANK = 100
+_NUM_RANSAC_TRIALS = 1000
+_MIN_RANSAC_SAMPLES = 3
+
+
+def MatchFeatures(query_locations,
+ query_descriptors,
+ index_image_locations,
+ index_image_descriptors,
+ ransac_seed=None,
+ feature_distance_threshold=0.9,
+ ransac_residual_threshold=10.0,
+ query_im_array=None,
+ index_im_array=None,
+ query_im_scale_factors=None,
+ index_im_scale_factors=None):
+ """Matches local features using geometric verification.
+
+  First, finds putative local feature matches by matching `query_descriptors`
+  against a KD-tree built from `index_image_descriptors`. Then, attempts to fit
+  an affine transformation between the putative feature correspondences using
+  their locations.
+
+ Args:
+ query_locations: Locations of local features for query image. NumPy array of
+ shape [#query_features, 2].
+ query_descriptors: Descriptors of local features for query image. NumPy
+ array of shape [#query_features, depth].
+ index_image_locations: Locations of local features for index image. NumPy
+ array of shape [#index_image_features, 2].
+ index_image_descriptors: Descriptors of local features for index image.
+ NumPy array of shape [#index_image_features, depth].
+ ransac_seed: Seed used by RANSAC. If None (default), no seed is provided.
+ feature_distance_threshold: Distance threshold below which a pair of
+ features is considered a potential match, and will be fed into RANSAC.
+ ransac_residual_threshold: Residual error threshold for considering matches
+ as inliers, used in RANSAC algorithm.
+ query_im_array: Optional. If not None, contains a NumPy array with the query
+ image, used to produce match visualization, if there is a match.
+ index_im_array: Optional. Same as `query_im_array`, but for index image.
+ query_im_scale_factors: Optional. If not None, contains a NumPy array with
+ the query image scales, used to produce match visualization, if there is a
+ match. If None and a visualization will be produced, [1.0, 1.0] is used
+ (ie, feature locations are not scaled).
+ index_im_scale_factors: Optional. Same as `query_im_scale_factors`, but for
+ index image.
+
+ Returns:
+ score: Number of inliers of match. If no match is found, returns 0.
+ match_viz_bytes: Encoded image bytes with visualization of the match, if
+ there is one, and if `query_im_array` and `index_im_array` are properly
+ set. Otherwise, it's an empty bytes string.
+
+ Raises:
+ ValueError: If local descriptors from query and index images have different
+ dimensionalities.
+ """
+ num_features_query = query_locations.shape[0]
+ num_features_index_image = index_image_locations.shape[0]
+ if not num_features_query or not num_features_index_image:
+ return 0, b''
+
+ local_feature_dim = query_descriptors.shape[1]
+ if index_image_descriptors.shape[1] != local_feature_dim:
+ raise ValueError(
+ 'Local feature dimensionality is not consistent for query and index '
+ 'images.')
+
+ # Find nearest-neighbor matches using a KD tree.
+ index_image_tree = spatial.cKDTree(index_image_descriptors)
+ _, indices = index_image_tree.query(
+ query_descriptors, distance_upper_bound=feature_distance_threshold)
+
+ # Select feature locations for putative matches.
+ query_locations_to_use = np.array([
+ query_locations[i,]
+ for i in range(num_features_query)
+ if indices[i] != num_features_index_image
+ ])
+ index_image_locations_to_use = np.array([
+ index_image_locations[indices[i],]
+ for i in range(num_features_query)
+ if indices[i] != num_features_index_image
+ ])
+
+  # If there are not enough putative matches, return early with a score of 0.
+ if query_locations_to_use.shape[0] <= _MIN_RANSAC_SAMPLES:
+ return 0, b''
+
+ # Perform geometric verification using RANSAC.
+ _, inliers = measure.ransac(
+ (index_image_locations_to_use, query_locations_to_use),
+ transform.AffineTransform,
+ min_samples=_MIN_RANSAC_SAMPLES,
+ residual_threshold=ransac_residual_threshold,
+ max_trials=_NUM_RANSAC_TRIALS,
+ random_state=ransac_seed)
+ match_viz_bytes = b''
+
+ if inliers is None:
+ inliers = []
+ elif query_im_array is not None and index_im_array is not None:
+ if query_im_scale_factors is None:
+ query_im_scale_factors = [1.0, 1.0]
+ if index_im_scale_factors is None:
+ index_im_scale_factors = [1.0, 1.0]
+ inlier_idxs = np.nonzero(inliers)[0]
+ _, ax = plt.subplots()
+ ax.axis('off')
+ ax.xaxis.set_major_locator(plt.NullLocator())
+ ax.yaxis.set_major_locator(plt.NullLocator())
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+ plt.margins(0, 0)
+ feature.plot_matches(
+ ax,
+ query_im_array,
+ index_im_array,
+ query_locations_to_use * query_im_scale_factors,
+ index_image_locations_to_use * index_im_scale_factors,
+ np.column_stack((inlier_idxs, inlier_idxs)),
+ only_matches=True)
+
+ match_viz_io = io.BytesIO()
+ plt.savefig(match_viz_io, format='jpeg', bbox_inches='tight', pad_inches=0)
+ match_viz_bytes = match_viz_io.getvalue()
+
+ return sum(inliers), match_viz_bytes
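+
+
+# A minimal usage sketch for MatchFeatures (hypothetical file names; in practice
+# locations/descriptors are read from .delf files via feature_io.ReadFromFile):
+#
+#   query_locations, _, query_descriptors, _, _ = feature_io.ReadFromFile(
+#       '/tmp/query_features/some_query.delf')
+#   index_locations, _, index_descriptors, _, _ = feature_io.ReadFromFile(
+#       '/tmp/index_features/some_index_image.delf')
+#   score, _ = MatchFeatures(query_locations, query_descriptors,
+#                            index_locations, index_descriptors, ransac_seed=0)
+#   # `score` is the number of RANSAC inliers; 0 means no geometric match.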
+
+
+def RerankByGeometricVerification(input_ranks,
+ initial_scores,
+ query_name,
+ index_names,
+ query_features_dir,
+ index_features_dir,
+ junk_ids,
+ local_feature_extension=_DELF_EXTENSION,
+ ransac_seed=None,
+ feature_distance_threshold=0.9,
+ ransac_residual_threshold=10.0):
+ """Re-ranks retrieval results using geometric verification.
+
+ Args:
+ input_ranks: 1D NumPy array with indices of top-ranked index images, sorted
+ from the most to the least similar.
+ initial_scores: 1D NumPy array with initial similarity scores between query
+ and index images. Entry i corresponds to score for image i.
+ query_name: Name for query image (string).
+ index_names: List of names for index images (strings).
+ query_features_dir: Directory where query local feature file is located
+ (string).
+ index_features_dir: Directory where index local feature files are located
+ (string).
+ junk_ids: Set with indices of junk images which should not be considered
+ during re-ranking.
+ local_feature_extension: String, extension to use for loading local feature
+ files.
+ ransac_seed: Seed used by RANSAC. If None (default), no seed is provided.
+ feature_distance_threshold: Distance threshold below which a pair of local
+ features is considered a potential match, and will be fed into RANSAC.
+ ransac_residual_threshold: Residual error threshold for considering matches
+ as inliers, used in RANSAC algorithm.
+
+ Returns:
+ output_ranks: 1D NumPy array with index image indices, sorted from the most
+ to the least similar according to the geometric verification and initial
+ scores.
+
+ Raises:
+ ValueError: If `input_ranks`, `initial_scores` and `index_names` do not have
+ the same number of entries.
+ """
+ num_index_images = len(index_names)
+ if len(input_ranks) != num_index_images:
+ raise ValueError('input_ranks and index_names have different number of '
+ 'elements: %d vs %d' %
+ (len(input_ranks), len(index_names)))
+ if len(initial_scores) != num_index_images:
+ raise ValueError('initial_scores and index_names have different number of '
+ 'elements: %d vs %d' %
+ (len(initial_scores), len(index_names)))
+
+ # Filter out junk images from list that will be re-ranked.
+ input_ranks_for_gv = []
+ for ind in input_ranks:
+ if ind not in junk_ids:
+ input_ranks_for_gv.append(ind)
+ num_to_rerank = min(_NUM_TO_RERANK, len(input_ranks_for_gv))
+
+ # Load query image features.
+ query_features_path = os.path.join(query_features_dir,
+ query_name + local_feature_extension)
+ query_locations, _, query_descriptors, _, _ = feature_io.ReadFromFile(
+ query_features_path)
+
+ # Initialize list containing number of inliers and initial similarity scores.
+ inliers_and_initial_scores = []
+ for i in range(num_index_images):
+ inliers_and_initial_scores.append([0, initial_scores[i]])
+
+ # Loop over top-ranked images and get results.
+ print('Starting to re-rank')
+ for i in range(num_to_rerank):
+ if i > 0 and i % _STATUS_CHECK_GV_ITERATIONS == 0:
+ print('Re-ranking: i = %d out of %d' % (i, num_to_rerank))
+
+ index_image_id = input_ranks_for_gv[i]
+
+ # Load index image features.
+ index_image_features_path = os.path.join(
+ index_features_dir,
+ index_names[index_image_id] + local_feature_extension)
+ (index_image_locations, _, index_image_descriptors, _,
+ _) = feature_io.ReadFromFile(index_image_features_path)
+
+ inliers_and_initial_scores[index_image_id][0], _ = MatchFeatures(
+ query_locations,
+ query_descriptors,
+ index_image_locations,
+ index_image_descriptors,
+ ransac_seed=ransac_seed,
+ feature_distance_threshold=feature_distance_threshold,
+ ransac_residual_threshold=ransac_residual_threshold)
+
+ # Sort based on (inliers_score, initial_score).
+ def _InliersInitialScoresSorting(k):
+ """Helper function to sort list based on two entries.
+
+ Args:
+ k: Index into `inliers_and_initial_scores`.
+
+ Returns:
+ Tuple containing inlier score and initial score.
+ """
+ return (inliers_and_initial_scores[k][0], inliers_and_initial_scores[k][1])
+
+ output_ranks = sorted(
+ range(num_index_images), key=_InliersInitialScoresSorting, reverse=True)
+
+ return output_ranks
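+
+
+# A minimal usage sketch for RerankByGeometricVerification (hypothetical
+# directories; variable names follow the call made in perform_retrieval.py):
+#
+#   reranked = RerankByGeometricVerification(
+#       input_ranks=np.argsort(-similarities),
+#       initial_scores=similarities,
+#       query_name=query_list[i],
+#       index_names=index_list,
+#       query_features_dir='/tmp/query_features',
+#       index_features_dir='/tmp/index_features',
+#       junk_ids=set(medium_ground_truth[i]['junk']))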
diff --git a/models/research/delf/delf/python/detect_to_retrieve/index_aggregation_config.pbtxt b/models/research/delf/delf/python/detect_to_retrieve/index_aggregation_config.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..ba7ba4e4956637152d952aff1cccd66da42800f4
--- /dev/null
+++ b/models/research/delf/delf/python/detect_to_retrieve/index_aggregation_config.pbtxt
@@ -0,0 +1,10 @@
+codebook_size: 65536
+feature_dimensionality: 128
+aggregation_type: ASMK_STAR
+use_l2_normalization: false
+codebook_path: "parameters/rparis6k_codebook_65536/k65536_codebook_tfckpt/codebook"
+num_assignments: 1
+use_regional_aggregation: true
+feature_batch_size: 100
+alpha: 3.0
+tau: 0.0
diff --git a/models/research/delf/delf/python/detect_to_retrieve/perform_retrieval.py b/models/research/delf/delf/python/detect_to_retrieve/perform_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2034dfb285118f4ed8928f996e031365a3ffbbf
--- /dev/null
+++ b/models/research/delf/delf/python/detect_to_retrieve/perform_retrieval.py
@@ -0,0 +1,301 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Performs image retrieval on Revisited Oxford/Paris datasets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+import time
+
+import numpy as np
+import tensorflow as tf
+
+from google.protobuf import text_format
+from tensorflow.python.platform import app
+from delf import aggregation_config_pb2
+from delf import datum_io
+from delf import feature_aggregation_similarity
+from delf.python.detect_to_retrieve import dataset
+from delf.python.detect_to_retrieve import image_reranking
+
+cmd_args = None
+
+# Aliases for aggregation types.
+_VLAD = aggregation_config_pb2.AggregationConfig.VLAD
+_ASMK = aggregation_config_pb2.AggregationConfig.ASMK
+_ASMK_STAR = aggregation_config_pb2.AggregationConfig.ASMK_STAR
+
+# Extensions.
+_VLAD_EXTENSION_SUFFIX = 'vlad'
+_ASMK_EXTENSION_SUFFIX = 'asmk'
+_ASMK_STAR_EXTENSION_SUFFIX = 'asmk_star'
+
+# Precision-recall ranks to use in metric computation.
+_PR_RANKS = (1, 5, 10)
+
+# Pace at which to log status while loading descriptors.
+_STATUS_CHECK_LOAD_ITERATIONS = 50
+
+# Output file names.
+_METRICS_FILENAME = 'metrics.txt'
+
+
+def _ReadAggregatedDescriptors(input_dir, image_list, config):
+ """Reads aggregated descriptors.
+
+ Args:
+ input_dir: Directory where aggregated descriptors are located.
+ image_list: List of image names for which to load descriptors.
+ config: AggregationConfig used for images.
+
+ Returns:
+ aggregated_descriptors: List containing #images items, each a 1D NumPy
+ array.
+ visual_words: If using VLAD aggregation, returns an empty list. Otherwise,
+ returns a list containing #images items, each a 1D NumPy array.
+ """
+ # Compose extension of aggregated descriptors.
+ extension = '.'
+ if config.use_regional_aggregation:
+ extension += 'r'
+ if config.aggregation_type == _VLAD:
+ extension += _VLAD_EXTENSION_SUFFIX
+ elif config.aggregation_type == _ASMK:
+ extension += _ASMK_EXTENSION_SUFFIX
+ elif config.aggregation_type == _ASMK_STAR:
+ extension += _ASMK_STAR_EXTENSION_SUFFIX
+ else:
+ raise ValueError('Invalid aggregation type: %d' % config.aggregation_type)
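+  # For example, regional ASMK* descriptors get the '.rasmk_star' extension,
+  # while non-regional VLAD descriptors get '.vlad'.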
+
+ num_images = len(image_list)
+ aggregated_descriptors = []
+ visual_words = []
+ print('Starting to collect descriptors for %d images...' % num_images)
+  start = time.time()
+ for i in range(num_images):
+ if i > 0 and i % _STATUS_CHECK_LOAD_ITERATIONS == 0:
+      elapsed = (time.time() - start)
+ print('Reading descriptors for image %d out of %d, last %d '
+ 'images took %f seconds' %
+ (i, num_images, _STATUS_CHECK_LOAD_ITERATIONS, elapsed))
+      start = time.time()
+
+ descriptors_filename = image_list[i] + extension
+ descriptors_fullpath = os.path.join(input_dir, descriptors_filename)
+ if config.aggregation_type == _VLAD:
+ aggregated_descriptors.append(datum_io.ReadFromFile(descriptors_fullpath))
+ else:
+ d, v = datum_io.ReadPairFromFile(descriptors_fullpath)
+ if config.aggregation_type == _ASMK_STAR:
+ d = d.astype('uint8')
+
+ aggregated_descriptors.append(d)
+ visual_words.append(v)
+
+ return aggregated_descriptors, visual_words
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise RuntimeError('Too many command-line arguments.')
+
+ # Parse dataset to obtain query/index images, and ground-truth.
+ print('Parsing dataset...')
+ query_list, index_list, ground_truth = dataset.ReadDatasetFile(
+ cmd_args.dataset_file_path)
+ num_query_images = len(query_list)
+ num_index_images = len(index_list)
+ (_, medium_ground_truth,
+ hard_ground_truth) = dataset.ParseEasyMediumHardGroundTruth(ground_truth)
+ print('done! Found %d queries and %d index images' %
+ (num_query_images, num_index_images))
+
+ # Parse AggregationConfig protos.
+ query_config = aggregation_config_pb2.AggregationConfig()
+ with tf.io.gfile.GFile(cmd_args.query_aggregation_config_path, 'r') as f:
+ text_format.Merge(f.read(), query_config)
+ index_config = aggregation_config_pb2.AggregationConfig()
+ with tf.io.gfile.GFile(cmd_args.index_aggregation_config_path, 'r') as f:
+ text_format.Merge(f.read(), index_config)
+
+ # Read aggregated descriptors.
+ query_aggregated_descriptors, query_visual_words = _ReadAggregatedDescriptors(
+ cmd_args.query_aggregation_dir, query_list, query_config)
+ index_aggregated_descriptors, index_visual_words = _ReadAggregatedDescriptors(
+ cmd_args.index_aggregation_dir, index_list, index_config)
+
+ # Create similarity computer.
+ similarity_computer = (
+ feature_aggregation_similarity.SimilarityAggregatedRepresentation(
+ index_config))
+
+ # Compute similarity between query and index images, potentially re-ranking
+ # with geometric verification.
+ ranks_before_gv = np.zeros([num_query_images, num_index_images],
+ dtype='int32')
+ if cmd_args.use_geometric_verification:
+ medium_ranks_after_gv = np.zeros([num_query_images, num_index_images],
+ dtype='int32')
+ hard_ranks_after_gv = np.zeros([num_query_images, num_index_images],
+ dtype='int32')
+ for i in range(num_query_images):
+ print('Performing retrieval with query %d (%s)...' % (i, query_list[i]))
+    start = time.time()
+
+ # Compute similarity between aggregated descriptors.
+ similarities = np.zeros([num_index_images])
+ for j in range(num_index_images):
+ similarities[j] = similarity_computer.ComputeSimilarity(
+ query_aggregated_descriptors[i], index_aggregated_descriptors[j],
+ query_visual_words[i], index_visual_words[j])
+
+ ranks_before_gv[i] = np.argsort(-similarities)
+
+ # Re-rank using geometric verification.
+ if cmd_args.use_geometric_verification:
+ medium_ranks_after_gv[i] = image_reranking.RerankByGeometricVerification(
+ ranks_before_gv[i], similarities, query_list[i], index_list,
+ cmd_args.query_features_dir, cmd_args.index_features_dir,
+ set(medium_ground_truth[i]['junk']))
+ hard_ranks_after_gv[i] = image_reranking.RerankByGeometricVerification(
+ ranks_before_gv[i], similarities, query_list[i], index_list,
+ cmd_args.query_features_dir, cmd_args.index_features_dir,
+ set(hard_ground_truth[i]['junk']))
+
+    elapsed = (time.time() - start)
+ print('done! Retrieval for query %d took %f seconds' % (i, elapsed))
+
+ # Create output directory if necessary.
+ if not tf.io.gfile.exists(cmd_args.output_dir):
+ tf.io.gfile.makedirs(cmd_args.output_dir)
+
+ # Compute metrics.
+ medium_metrics = dataset.ComputeMetrics(ranks_before_gv, medium_ground_truth,
+ _PR_RANKS)
+ hard_metrics = dataset.ComputeMetrics(ranks_before_gv, hard_ground_truth,
+ _PR_RANKS)
+ if cmd_args.use_geometric_verification:
+ medium_metrics_after_gv = dataset.ComputeMetrics(medium_ranks_after_gv,
+ medium_ground_truth,
+ _PR_RANKS)
+ hard_metrics_after_gv = dataset.ComputeMetrics(hard_ranks_after_gv,
+ hard_ground_truth, _PR_RANKS)
+
+ # Write metrics to file.
+ mean_average_precision_dict = {
+ 'medium': medium_metrics[0],
+ 'hard': hard_metrics[0]
+ }
+ mean_precisions_dict = {'medium': medium_metrics[1], 'hard': hard_metrics[1]}
+ mean_recalls_dict = {'medium': medium_metrics[2], 'hard': hard_metrics[2]}
+ if cmd_args.use_geometric_verification:
+ mean_average_precision_dict.update({
+ 'medium_after_gv': medium_metrics_after_gv[0],
+ 'hard_after_gv': hard_metrics_after_gv[0]
+ })
+ mean_precisions_dict.update({
+ 'medium_after_gv': medium_metrics_after_gv[1],
+ 'hard_after_gv': hard_metrics_after_gv[1]
+ })
+ mean_recalls_dict.update({
+ 'medium_after_gv': medium_metrics_after_gv[2],
+ 'hard_after_gv': hard_metrics_after_gv[2]
+ })
+ dataset.SaveMetricsFile(mean_average_precision_dict, mean_precisions_dict,
+ mean_recalls_dict, _PR_RANKS,
+ os.path.join(cmd_args.output_dir, _METRICS_FILENAME))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.register('type', 'bool', lambda v: v.lower() == 'true')
+ parser.add_argument(
+ '--index_aggregation_config_path',
+ type=str,
+ default='/tmp/index_aggregation_config.pbtxt',
+ help="""
+ Path to index AggregationConfig proto text file. This is used to load the
+ aggregated descriptors from the index, and to define the parameters used
+ in computing similarity for aggregated descriptors.
+ """)
+ parser.add_argument(
+ '--query_aggregation_config_path',
+ type=str,
+ default='/tmp/query_aggregation_config.pbtxt',
+ help="""
+ Path to query AggregationConfig proto text file. This is only used to load
+ the aggregated descriptors for the queries.
+ """)
+ parser.add_argument(
+ '--dataset_file_path',
+ type=str,
+ default='/tmp/gnd_roxford5k.mat',
+ help="""
+ Dataset file for Revisited Oxford or Paris dataset, in .mat format.
+ """)
+ parser.add_argument(
+ '--index_aggregation_dir',
+ type=str,
+ default='/tmp/index_aggregation',
+ help="""
+ Directory where index aggregated descriptors are located.
+ """)
+ parser.add_argument(
+ '--query_aggregation_dir',
+ type=str,
+ default='/tmp/query_aggregation',
+ help="""
+ Directory where query aggregated descriptors are located.
+ """)
+ parser.add_argument(
+ '--use_geometric_verification',
+ type=lambda x: (str(x).lower() == 'true'),
+ default=False,
+ help="""
+ If True, performs re-ranking using local feature-based geometric
+ verification.
+ """)
+ parser.add_argument(
+ '--index_features_dir',
+ type=str,
+ default='/tmp/index_features',
+ help="""
+ Only used if `use_geometric_verification` is True.
+ Directory where index local image features are located, all in .delf
+ format.
+ """)
+ parser.add_argument(
+ '--query_features_dir',
+ type=str,
+ default='/tmp/query_features',
+ help="""
+ Only used if `use_geometric_verification` is True.
+ Directory where query local image features are located, all in .delf
+ format.
+ """)
+ parser.add_argument(
+ '--output_dir',
+ type=str,
+ default='/tmp/retrieval',
+ help="""
+ Directory where retrieval output will be written to. A file containing
+ metrics for this run is saved therein, with file name "metrics.txt".
+ """)
+ cmd_args, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/models/research/delf/delf/python/detect_to_retrieve/query_aggregation_config.pbtxt b/models/research/delf/delf/python/detect_to_retrieve/query_aggregation_config.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..39a917eef4389baa9a7f08722aed0a2e0cb6dd1f
--- /dev/null
+++ b/models/research/delf/delf/python/detect_to_retrieve/query_aggregation_config.pbtxt
@@ -0,0 +1,7 @@
+codebook_size: 65536
+feature_dimensionality: 128
+aggregation_type: ASMK_STAR
+codebook_path: "parameters/rparis6k_codebook_65536/k65536_codebook_tfckpt/codebook"
+num_assignments: 1
+use_regional_aggregation: false
+feature_batch_size: 100
diff --git a/models/research/delf/delf/python/examples/__init__.py b/models/research/delf/delf/python/examples/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/delf/delf/python/examples/delf_config_example.pbtxt b/models/research/delf/delf/python/examples/delf_config_example.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..ff2d9c0023c41accd2cff51c0117d4ff01def2a0
--- /dev/null
+++ b/models/research/delf/delf/python/examples/delf_config_example.pbtxt
@@ -0,0 +1,25 @@
+model_path: "parameters/delf_gld_20190411/model/"
+image_scales: .25
+image_scales: .3536
+image_scales: .5
+image_scales: .7071
+image_scales: 1.0
+image_scales: 1.4142
+image_scales: 2.0
+delf_local_config {
+ use_pca: true
+ # Note that for the exported model provided as an example, layer_name and
+ # iou_threshold are hard-coded in the checkpoint. So, the layer_name and
+ # iou_threshold variables here have no effect on the provided
+ # extract_features.py script.
+ layer_name: "resnet_v1_50/block3"
+ iou_threshold: 1.0
+ max_feature_num: 1000
+ score_threshold: 100.0
+ pca_parameters {
+ mean_path: "parameters/delf_gld_20190411/pca/mean.datum"
+ projection_matrix_path: "parameters/delf_gld_20190411/pca/pca_proj_mat.datum"
+ pca_dim: 40
+ use_whitening: false
+ }
+}
diff --git a/models/research/delf/delf/python/examples/detection_example_1.jpg b/models/research/delf/delf/python/examples/detection_example_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..afdb388f0dea74de9ac259cfd8a0f9b3b17779e8
Binary files /dev/null and b/models/research/delf/delf/python/examples/detection_example_1.jpg differ
diff --git a/models/research/delf/delf/python/examples/detection_example_2.jpg b/models/research/delf/delf/python/examples/detection_example_2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5baf54a80888810f0a0daa279eb29936f15c8457
Binary files /dev/null and b/models/research/delf/delf/python/examples/detection_example_2.jpg differ
diff --git a/models/research/delf/delf/python/examples/detector.py b/models/research/delf/delf/python/examples/detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd8aef1cf7fef2aea7ae1e28e793f0ab172e915d
--- /dev/null
+++ b/models/research/delf/delf/python/examples/detector.py
@@ -0,0 +1,55 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Module to construct object detector function."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+def MakeDetector(model_dir):
+ """Creates a function to detect objects in an image.
+
+ Args:
+ model_dir: Directory where SavedModel is located.
+
+ Returns:
+ Function that receives an image and returns detection results.
+ """
+ model = tf.saved_model.load(model_dir)
+
+ # Input and output tensors.
+ feeds = ['input_images:0']
+ fetches = ['detection_boxes:0', 'detection_scores:0', 'detection_classes:0']
+
+ model = model.prune(feeds=feeds, fetches=fetches)
+
+ def DetectorFn(images):
+ """Receives an image and returns detected boxes.
+
+ Args:
+      images: Uint8 array with shape (batch, height, width, 3) containing a
+        batch of RGB images.
+
+ Returns:
+ Tuple (boxes, scores, class_indices).
+ """
+ boxes, scores, class_indices = model(tf.convert_to_tensor(images))
+
+ return boxes.numpy(), scores.numpy(), class_indices.numpy()
+
+ return DetectorFn
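+
+
+# A minimal usage sketch (hypothetical paths; assumes a detector SavedModel
+# exported with the tensor names pruned above):
+#
+#   import numpy as np
+#   from delf import utils
+#   detector_fn = MakeDetector('/tmp/detector_model')
+#   image_batch = np.expand_dims(np.array(utils.RgbLoader('query.jpg')), 0)
+#   boxes, scores, class_indices = detector_fn(image_batch)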
diff --git a/models/research/delf/delf/python/examples/extract_boxes.py b/models/research/delf/delf/python/examples/extract_boxes.py
new file mode 100644
index 0000000000000000000000000000000000000000..8851c44fb9a051104adde50a4c869a28cfd513da
--- /dev/null
+++ b/models/research/delf/delf/python/examples/extract_boxes.py
@@ -0,0 +1,234 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Extracts bounding boxes from a list of images, saving them to files.
+
+The images must be in JPG format. The program checks if boxes already
+exist, and skips computation for those.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+import time
+
+import matplotlib.patches as patches
+import matplotlib.pyplot as plt
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.platform import app
+from delf import box_io
+from delf import utils
+from delf import detector
+
+cmd_args = None
+
+# Extension/suffix of produced files.
+_BOX_EXT = '.boxes'
+_VIZ_SUFFIX = '_viz.jpg'
+
+# Used for plotting boxes.
+_BOX_EDGE_COLORS = ['r', 'y', 'b', 'm', 'k', 'g', 'c', 'w']
+
+# Pace to report extraction log.
+_STATUS_CHECK_ITERATIONS = 100
+
+
+def _ReadImageList(list_path):
+ """Helper function to read image paths.
+
+ Args:
+ list_path: Path to list of images, one image path per line.
+
+ Returns:
+ image_paths: List of image paths.
+ """
+ with tf.io.gfile.GFile(list_path, 'r') as f:
+ image_paths = f.readlines()
+ image_paths = [entry.rstrip() for entry in image_paths]
+ return image_paths
+
+
+def _FilterBoxesByScore(boxes, scores, class_indices, score_threshold):
+ """Filter boxes based on detection scores.
+
+ Boxes with detection score >= score_threshold are returned.
+
+ Args:
+ boxes: [N, 4] float array denoting bounding box coordinates, in format [top,
+ left, bottom, right].
+ scores: [N] float array with detection scores.
+ class_indices: [N] int array with class indices.
+ score_threshold: Float detection score threshold to use.
+
+ Returns:
+ selected_boxes: selected `boxes`.
+ selected_scores: selected `scores`.
+ selected_class_indices: selected `class_indices`.
+ """
+ selected_boxes = []
+ selected_scores = []
+ selected_class_indices = []
+ for i, box in enumerate(boxes):
+ if scores[i] >= score_threshold:
+ selected_boxes.append(box)
+ selected_scores.append(scores[i])
+ selected_class_indices.append(class_indices[i])
+
+ return np.array(selected_boxes), np.array(selected_scores), np.array(
+ selected_class_indices)
+
+
+def _PlotBoxesAndSaveImage(image, boxes, output_path):
+ """Plot boxes on image and save to output path.
+
+ Args:
+ image: Numpy array containing image.
+ boxes: [N, 4] float array denoting bounding box coordinates, in format [top,
+ left, bottom, right].
+ output_path: String containing output path.
+ """
+ height = image.shape[0]
+ width = image.shape[1]
+
+ fig, ax = plt.subplots(1)
+ ax.imshow(image)
+ for i, box in enumerate(boxes):
+ scaled_box = [
+ box[0] * height, box[1] * width, box[2] * height, box[3] * width
+ ]
+ rect = patches.Rectangle([scaled_box[1], scaled_box[0]],
+ scaled_box[3] - scaled_box[1],
+ scaled_box[2] - scaled_box[0],
+ linewidth=3,
+ edgecolor=_BOX_EDGE_COLORS[i %
+ len(_BOX_EDGE_COLORS)],
+ facecolor='none')
+ ax.add_patch(rect)
+
+ ax.axis('off')
+ plt.savefig(output_path, bbox_inches='tight')
+ plt.close(fig)
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise RuntimeError('Too many command-line arguments.')
+
+ # Read list of images.
+ print('Reading list of images...')
+ image_paths = _ReadImageList(cmd_args.list_images_path)
+ num_images = len(image_paths)
+ print(f'done! Found {num_images} images')
+
+ # Create output directories if necessary.
+ if not tf.io.gfile.exists(cmd_args.output_dir):
+ tf.io.gfile.makedirs(cmd_args.output_dir)
+ if cmd_args.output_viz_dir and not tf.io.gfile.exists(
+ cmd_args.output_viz_dir):
+ tf.io.gfile.makedirs(cmd_args.output_viz_dir)
+
+ detector_fn = detector.MakeDetector(cmd_args.detector_path)
+
+ start = time.time()
+ for i, image_path in enumerate(image_paths):
+ # Report progress once in a while.
+ if i == 0:
+ print('Starting to detect objects in images...')
+ elif i % _STATUS_CHECK_ITERATIONS == 0:
+ elapsed = (time.time() - start)
+ print(
+ f'Processing image {i} out of {num_images}, last '
+ f'{_STATUS_CHECK_ITERATIONS} images took {elapsed} seconds'
+ )
+ start = time.time()
+
+ # If descriptor already exists, skip its computation.
+ base_boxes_filename, _ = os.path.splitext(os.path.basename(image_path))
+ out_boxes_filename = base_boxes_filename + _BOX_EXT
+ out_boxes_fullpath = os.path.join(cmd_args.output_dir,
+ out_boxes_filename)
+ if tf.io.gfile.exists(out_boxes_fullpath):
+ print(f'Skipping {image_path}')
+ continue
+
+ im = np.expand_dims(np.array(utils.RgbLoader(image_paths[i])), 0)
+
+ # Extract and save boxes.
+ (boxes_out, scores_out, class_indices_out) = detector_fn(im)
+ (selected_boxes, selected_scores,
+ selected_class_indices) = _FilterBoxesByScore(boxes_out[0],
+ scores_out[0],
+ class_indices_out[0],
+ cmd_args.detector_thresh)
+
+ box_io.WriteToFile(out_boxes_fullpath, selected_boxes, selected_scores,
+ selected_class_indices)
+ if cmd_args.output_viz_dir:
+ out_viz_filename = base_boxes_filename + _VIZ_SUFFIX
+ out_viz_fullpath = os.path.join(cmd_args.output_viz_dir,
+ out_viz_filename)
+ _PlotBoxesAndSaveImage(im[0], selected_boxes, out_viz_fullpath)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.register('type', 'bool', lambda v: v.lower() == 'true')
+ parser.add_argument(
+ '--detector_path',
+ type=str,
+ default='/tmp/d2r_frcnn_20190411/',
+ help="""
+ Path to exported detector model.
+ """)
+ parser.add_argument(
+ '--detector_thresh',
+ type=float,
+ default=.0,
+ help="""
+ Detector threshold. Any box with confidence score lower than this is not
+ returned.
+ """)
+ parser.add_argument(
+ '--list_images_path',
+ type=str,
+ default='list_images.txt',
+ help="""
+ Path to list of images to undergo object detection.
+ """)
+ parser.add_argument(
+ '--output_dir',
+ type=str,
+ default='test_boxes',
+ help="""
+ Directory where bounding boxes will be written to. Each image's boxes
+ will be written to a file with same name, and extension replaced by
+ .boxes.
+ """)
+ parser.add_argument(
+ '--output_viz_dir',
+ type=str,
+ default='',
+ help="""
+ Optional. If set, a visualization of the detected boxes overlaid on the
+ image is produced, and saved to this directory. Each image is saved with
+ _viz.jpg suffix.
+ """)
+ cmd_args, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/models/research/delf/delf/python/examples/extract_features.py b/models/research/delf/delf/python/examples/extract_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..05fd77316070d39722e133dbd544f5b53791f6d0
--- /dev/null
+++ b/models/research/delf/delf/python/examples/extract_features.py
@@ -0,0 +1,144 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Extracts DELF features from a list of images, saving them to file.
+
+The images must be in JPG format. The program checks if descriptors already
+exist, and skips computation for those.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+import time
+
+import numpy as np
+from six.moves import range
+import tensorflow as tf
+
+from google.protobuf import text_format
+from tensorflow.python.platform import app
+from delf import delf_config_pb2
+from delf import feature_io
+from delf import utils
+from delf import extractor
+
+cmd_args = None
+
+# Extension of feature files.
+_DELF_EXT = '.delf'
+
+# Pace to report extraction log.
+_STATUS_CHECK_ITERATIONS = 100
+
+
+def _ReadImageList(list_path):
+ """Helper function to read image paths.
+
+ Args:
+ list_path: Path to list of images, one image path per line.
+
+ Returns:
+ image_paths: List of image paths.
+ """
+ with tf.io.gfile.GFile(list_path, 'r') as f:
+ image_paths = f.readlines()
+ image_paths = [entry.rstrip() for entry in image_paths]
+ return image_paths
+
+
+def main(unused_argv):
+ # Read list of images.
+ print('Reading list of images...')
+ image_paths = _ReadImageList(cmd_args.list_images_path)
+ num_images = len(image_paths)
+ print(f'done! Found {num_images} images')
+
+ # Parse DelfConfig proto.
+ config = delf_config_pb2.DelfConfig()
+ with tf.io.gfile.GFile(cmd_args.config_path, 'r') as f:
+ text_format.Merge(f.read(), config)
+
+ # Create output directory if necessary.
+ if not tf.io.gfile.exists(cmd_args.output_dir):
+ tf.io.gfile.makedirs(cmd_args.output_dir)
+
+ extractor_fn = extractor.MakeExtractor(config)
+
+ start = time.time()
+ for i in range(num_images):
+ # Report progress once in a while.
+ if i == 0:
+ print('Starting to extract DELF features from images...')
+ elif i % _STATUS_CHECK_ITERATIONS == 0:
+ elapsed = (time.time() - start)
+ print(
+ f'Processing image {i} out of {num_images}, last '
+ f'{_STATUS_CHECK_ITERATIONS} images took {elapsed} seconds'
+ )
+ start = time.time()
+
+ # If descriptor already exists, skip its computation.
+ out_desc_filename = os.path.splitext(os.path.basename(
+ image_paths[i]))[0] + _DELF_EXT
+ out_desc_fullpath = os.path.join(cmd_args.output_dir, out_desc_filename)
+ if tf.io.gfile.exists(out_desc_fullpath):
+ print(f'Skipping {image_paths[i]}')
+ continue
+
+ im = np.array(utils.RgbLoader(image_paths[i]))
+
+ # Extract and save features.
+ extracted_features = extractor_fn(im)
+ locations_out = extracted_features['local_features']['locations']
+ descriptors_out = extracted_features['local_features']['descriptors']
+ feature_scales_out = extracted_features['local_features']['scales']
+ attention_out = extracted_features['local_features']['attention']
+
+ feature_io.WriteToFile(out_desc_fullpath, locations_out, feature_scales_out,
+ descriptors_out, attention_out)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.register('type', 'bool', lambda v: v.lower() == 'true')
+ parser.add_argument(
+ '--config_path',
+ type=str,
+ default='delf_config_example.pbtxt',
+ help="""
+ Path to DelfConfig proto text file with configuration to be used for DELF
+ extraction.
+ """)
+ parser.add_argument(
+ '--list_images_path',
+ type=str,
+ default='list_images.txt',
+ help="""
+ Path to list of images whose DELF features will be extracted.
+ """)
+ parser.add_argument(
+ '--output_dir',
+ type=str,
+ default='test_features',
+ help="""
+ Directory where DELF features will be written to. Each image's features
+ will be written to a file with same name, and extension replaced by .delf.
+ """)
+ cmd_args, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/models/research/delf/delf/python/examples/extractor.py b/models/research/delf/delf/python/examples/extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd63ab38362a9c6f9ccc5d3bfca1fd007045d261
--- /dev/null
+++ b/models/research/delf/delf/python/examples/extractor.py
@@ -0,0 +1,277 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Module to construct DELF feature extractor."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from PIL import Image
+import tensorflow as tf
+
+from delf import datum_io
+from delf import feature_extractor
+
+# Minimum dimensions below which DELF features are not extracted (empty
+# features are returned). This applies after any resizing is performed.
+_MIN_HEIGHT = 10
+_MIN_WIDTH = 10
+
+
+def ResizeImage(image, config, resize_factor=1.0):
+ """Resizes image according to config.
+
+ Args:
+ image: Uint8 array with shape (height, width, 3).
+ config: DelfConfig proto containing the model configuration.
+ resize_factor: Optional float resize factor for the input image. If given,
+ the maximum and minimum allowed image sizes in `config` are scaled by this
+ factor. Must be non-negative.
+
+ Returns:
+ resized_image: Uint8 array with resized image.
+ scale_factors: 2D float array, with factors used for resizing along height
+ and width (If upscaling, larger than 1; if downscaling, smaller than 1).
+
+ Raises:
+ ValueError: If `image` has incorrect number of dimensions/channels.
+ """
+ if resize_factor < 0.0:
+ raise ValueError('negative resize_factor is not allowed: %f' %
+ resize_factor)
+  if image.ndim != 3:
+    raise ValueError('image has incorrect number of dimensions: %d' %
+                     image.ndim)
+ height, width, channels = image.shape
+
+ # Take into account resize factor.
+ max_image_size = resize_factor * config.max_image_size
+ min_image_size = resize_factor * config.min_image_size
+
+ if channels != 3:
+ raise ValueError('image has incorrect number of channels: %d' % channels)
+
+ largest_side = max(width, height)
+
+ if max_image_size >= 0 and largest_side > max_image_size:
+ scale_factor = max_image_size / largest_side
+ elif min_image_size >= 0 and largest_side < min_image_size:
+ scale_factor = min_image_size / largest_side
+ elif config.use_square_images and (height != width):
+ scale_factor = 1.0
+ else:
+ # No resizing needed, early return.
+ return image, np.ones(2, dtype=float)
+
+ # Note that new_shape is in (width, height) format (PIL convention), while
+ # scale_factors are in (height, width) convention (NumPy convention).
+ if config.use_square_images:
+ new_shape = (int(round(largest_side * scale_factor)),
+ int(round(largest_side * scale_factor)))
+ else:
+ new_shape = (int(round(width * scale_factor)),
+ int(round(height * scale_factor)))
+
+ scale_factors = np.array([new_shape[1] / height, new_shape[0] / width],
+ dtype=float)
+
+ pil_image = Image.fromarray(image)
+ resized_image = np.array(pil_image.resize(new_shape, resample=Image.BILINEAR))
+
+ return resized_image, scale_factors
+
+
+def MakeExtractor(config):
+ """Creates a function to extract global and/or local features from an image.
+
+ Args:
+ config: DelfConfig proto containing the model configuration.
+
+ Returns:
+ Function that receives an image and returns features.
+ """
+ # Load model.
+ model = tf.saved_model.load(config.model_path)
+
+ # Input/output end-points/tensors.
+ feeds = ['input_image:0', 'input_scales:0']
+ fetches = []
+ image_scales_tensor = tf.convert_to_tensor(list(config.image_scales))
+
+ # Custom configuration needed when local features are used.
+ if config.use_local_features:
+ # Extra input/output end-points/tensors.
+ feeds.append('input_abs_thres:0')
+ feeds.append('input_max_feature_num:0')
+ fetches.append('boxes:0')
+ fetches.append('features:0')
+ fetches.append('scales:0')
+ fetches.append('scores:0')
+ score_threshold_tensor = tf.constant(
+ config.delf_local_config.score_threshold)
+ max_feature_num_tensor = tf.constant(
+ config.delf_local_config.max_feature_num)
+
+ # If using PCA, pre-load required parameters.
+ local_pca_parameters = {}
+ if config.delf_local_config.use_pca:
+ local_pca_parameters['mean'] = tf.constant(
+ datum_io.ReadFromFile(
+ config.delf_local_config.pca_parameters.mean_path),
+ dtype=tf.float32)
+ local_pca_parameters['matrix'] = tf.constant(
+ datum_io.ReadFromFile(
+ config.delf_local_config.pca_parameters.projection_matrix_path),
+ dtype=tf.float32)
+ local_pca_parameters[
+ 'dim'] = config.delf_local_config.pca_parameters.pca_dim
+ local_pca_parameters['use_whitening'] = (
+ config.delf_local_config.pca_parameters.use_whitening)
+ if config.delf_local_config.pca_parameters.use_whitening:
+ local_pca_parameters['variances'] = tf.squeeze(
+ tf.constant(
+ datum_io.ReadFromFile(
+ config.delf_local_config.pca_parameters.pca_variances_path),
+ dtype=tf.float32))
+ else:
+ local_pca_parameters['variances'] = None
+
+ # Custom configuration needed when global features are used.
+ if config.use_global_features:
+ # Extra output end-point.
+ fetches.append('global_descriptors:0')
+
+ # If using PCA, pre-load required parameters.
+ global_pca_parameters = {}
+ if config.delf_global_config.use_pca:
+ global_pca_parameters['mean'] = tf.constant(
+ datum_io.ReadFromFile(
+ config.delf_global_config.pca_parameters.mean_path),
+ dtype=tf.float32)
+ global_pca_parameters['matrix'] = tf.constant(
+ datum_io.ReadFromFile(
+ config.delf_global_config.pca_parameters.projection_matrix_path),
+ dtype=tf.float32)
+ global_pca_parameters[
+ 'dim'] = config.delf_global_config.pca_parameters.pca_dim
+ global_pca_parameters['use_whitening'] = (
+ config.delf_global_config.pca_parameters.use_whitening)
+ if config.delf_global_config.pca_parameters.use_whitening:
+ global_pca_parameters['variances'] = tf.squeeze(
+ tf.constant(
+ datum_io.ReadFromFile(config.delf_global_config.pca_parameters
+ .pca_variances_path),
+ dtype=tf.float32))
+ else:
+ global_pca_parameters['variances'] = None
+
+ model = model.prune(feeds=feeds, fetches=fetches)
+
+ def ExtractorFn(image, resize_factor=1.0):
+ """Receives an image and returns DELF global and/or local features.
+
+ If image is too small, returns empty features.
+
+ Args:
+ image: Uint8 array with shape (height, width, 3) containing the RGB image.
+ resize_factor: Optional float resize factor for the input image. If given,
+ the maximum and minimum allowed image sizes in the config are scaled by
+ this factor.
+
+ Returns:
+ extracted_features: A dict containing the extracted global descriptors
+ (key 'global_descriptor' mapping to a [D] float array), and/or local
+ features (key 'local_features' mapping to a dict with keys 'locations',
+ 'descriptors', 'scales', 'attention').
+ """
+
+ resized_image, scale_factors = ResizeImage(
+ image, config, resize_factor=resize_factor)
+
+ # If the image is too small, returns empty features.
+ if resized_image.shape[0] < _MIN_HEIGHT or resized_image.shape[
+ 1] < _MIN_WIDTH:
+ extracted_features = {'global_descriptor': np.array([])}
+ if config.use_local_features:
+ extracted_features.update({
+ 'local_features': {
+ 'locations': np.array([]),
+ 'descriptors': np.array([]),
+ 'scales': np.array([]),
+ 'attention': np.array([]),
+ }
+ })
+ return extracted_features
+
+ # Input tensors.
+ image_tensor = tf.convert_to_tensor(resized_image)
+
+ # Extracted features.
+ extracted_features = {}
+ output = None
+
+ if config.use_local_features:
+ output = model(image_tensor, image_scales_tensor, score_threshold_tensor,
+ max_feature_num_tensor)
+ else:
+ output = model(image_tensor, image_scales_tensor)
+
+ # Post-process extracted features: normalize, PCA (optional), pooling.
+ if config.use_global_features:
+ raw_global_descriptors = output[-1]
+ if config.delf_global_config.image_scales_ind:
+ raw_global_descriptors_selected_scales = tf.gather(
+ raw_global_descriptors,
+ list(config.delf_global_config.image_scales_ind))
+ else:
+ raw_global_descriptors_selected_scales = raw_global_descriptors
+ global_descriptors_per_scale = feature_extractor.PostProcessDescriptors(
+ raw_global_descriptors_selected_scales,
+ config.delf_global_config.use_pca, global_pca_parameters)
+ unnormalized_global_descriptor = tf.reduce_sum(
+ global_descriptors_per_scale, axis=0, name='sum_pooling')
+ global_descriptor = tf.nn.l2_normalize(
+ unnormalized_global_descriptor, axis=0, name='final_l2_normalization')
+ extracted_features.update({
+ 'global_descriptor': global_descriptor.numpy(),
+ })
+
+ if config.use_local_features:
+ boxes = output[0]
+ raw_local_descriptors = output[1]
+ feature_scales = output[2]
+ attention_with_extra_dim = output[3]
+
+ attention = tf.reshape(attention_with_extra_dim,
+ [tf.shape(attention_with_extra_dim)[0]])
+ locations, local_descriptors = (
+ feature_extractor.DelfFeaturePostProcessing(
+ boxes, raw_local_descriptors, config.delf_local_config.use_pca,
+ local_pca_parameters))
+ locations /= scale_factors
+
+ extracted_features.update({
+ 'local_features': {
+ 'locations': locations.numpy(),
+ 'descriptors': local_descriptors.numpy(),
+ 'scales': feature_scales.numpy(),
+ 'attention': attention.numpy(),
+ }
+ })
+
+ return extracted_features
+
+ return ExtractorFn
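A hedged usage sketch of MakeExtractor, mirroring how extract_features.py drives it; the config and image file names below are assumptions for illustration.

# Sketch only: 'delf_config_example.pbtxt' and 'query.jpg' are assumed paths.
import numpy as np
from google.protobuf import text_format
from delf import delf_config_pb2, extractor, utils

config = delf_config_pb2.DelfConfig()
with open('delf_config_example.pbtxt') as f:
    text_format.Merge(f.read(), config)

extractor_fn = extractor.MakeExtractor(config)
image = np.array(utils.RgbLoader('query.jpg'))        # uint8 RGB array, shape (H, W, 3)
extracted = extractor_fn(image)
locations = extracted['local_features']['locations']  # keypoints in original-image coordinates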
diff --git a/models/research/delf/delf/python/examples/extractor_test.py b/models/research/delf/delf/python/examples/extractor_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa560c75a5ca7f8a48247eb7636643e2369c0e5e
--- /dev/null
+++ b/models/research/delf/delf/python/examples/extractor_test.py
@@ -0,0 +1,103 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for DELF feature extractor."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+
+from delf import delf_config_pb2
+from delf import extractor
+
+
+class ExtractorTest(tf.test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ('Max-1Min-1', -1, -1, 1.0, False, [4, 2, 3], [1.0, 1.0]),
+ ('Max-1Min-1Square', -1, -1, 1.0, True, [4, 4, 3], [1.0, 2.0]),
+ ('Max2Min-1', 2, -1, 1.0, False, [2, 1, 3], [0.5, 0.5]),
+ ('Max2Min-1Square', 2, -1, 1.0, True, [2, 2, 3], [0.5, 1.0]),
+ ('Max8Min-1', 8, -1, 1.0, False, [4, 2, 3], [1.0, 1.0]),
+ ('Max8Min-1Square', 8, -1, 1.0, True, [4, 4, 3], [1.0, 2.0]),
+ ('Max-1Min1', -1, 1, 1.0, False, [4, 2, 3], [1.0, 1.0]),
+ ('Max-1Min1Square', -1, 1, 1.0, True, [4, 4, 3], [1.0, 2.0]),
+ ('Max-1Min8', -1, 8, 1.0, False, [8, 4, 3], [2.0, 2.0]),
+ ('Max-1Min8Square', -1, 8, 1.0, True, [8, 8, 3], [2.0, 4.0]),
+ ('Max16Min8', 16, 8, 1.0, False, [8, 4, 3], [2.0, 2.0]),
+ ('Max16Min8Square', 16, 8, 1.0, True, [8, 8, 3], [2.0, 4.0]),
+ ('Max2Min2', 2, 2, 1.0, False, [2, 1, 3], [0.5, 0.5]),
+ ('Max2Min2Square', 2, 2, 1.0, True, [2, 2, 3], [0.5, 1.0]),
+ ('Max-1Min-1Factor0.5', -1, -1, 0.5, False, [4, 2, 3], [1.0, 1.0]),
+ ('Max-1Min-1Factor0.5Square', -1, -1, 0.5, True, [4, 4, 3], [1.0, 2.0]),
+ ('Max2Min-1Factor2.0', 2, -1, 2.0, False, [4, 2, 3], [1.0, 1.0]),
+ ('Max2Min-1Factor2.0Square', 2, -1, 2.0, True, [4, 4, 3], [1.0, 2.0]),
+ ('Max-1Min8Factor0.5', -1, 8, 0.5, False, [4, 2, 3], [1.0, 1.0]),
+ ('Max-1Min8Factor0.5Square', -1, 8, 0.5, True, [4, 4, 3], [1.0, 2.0]),
+ ('Max-1Min8Factor0.25', -1, 8, 0.25, False, [4, 2, 3], [1.0, 1.0]),
+ ('Max-1Min8Factor0.25Square', -1, 8, 0.25, True, [4, 4, 3], [1.0, 2.0]),
+ ('Max2Min2Factor2.0', 2, 2, 2.0, False, [4, 2, 3], [1.0, 1.0]),
+ ('Max2Min2Factor2.0Square', 2, 2, 2.0, True, [4, 4, 3], [1.0, 2.0]),
+ ('Max16Min8Factor0.5', 16, 8, 0.5, False, [4, 2, 3], [1.0, 1.0]),
+ ('Max16Min8Factor0.5Square', 16, 8, 0.5, True, [4, 4, 3], [1.0, 2.0]),
+ )
+ def testResizeImageWorks(self, max_image_size, min_image_size, resize_factor,
+ square_output, expected_shape,
+ expected_scale_factors):
+ # Construct image of size 4x2x3.
+ image = np.array([[[0, 0, 0], [1, 1, 1]], [[2, 2, 2], [3, 3, 3]],
+ [[4, 4, 4], [5, 5, 5]], [[6, 6, 6], [7, 7, 7]]],
+ dtype='uint8')
+
+ # Set up config.
+ config = delf_config_pb2.DelfConfig(
+ max_image_size=max_image_size,
+ min_image_size=min_image_size,
+ use_square_images=square_output)
+
+ resized_image, scale_factors = extractor.ResizeImage(
+ image, config, resize_factor)
+ self.assertAllEqual(resized_image.shape, expected_shape)
+ self.assertAllClose(scale_factors, expected_scale_factors)
+
+ @parameterized.named_parameters(
+ ('Max2Min2', 2, 2, 1.0, False, [2, 1, 3], [0.666666, 0.5]),
+ ('Max2Min2Square', 2, 2, 1.0, True, [2, 2, 3], [0.666666, 1.0]),
+ )
+ def testResizeImageRoundingWorks(self, max_image_size, min_image_size,
+ resize_factor, square_output, expected_shape,
+ expected_scale_factors):
+ # Construct image of size 3x2x3.
+ image = np.array([[[0, 0, 0], [1, 1, 1]], [[2, 2, 2], [3, 3, 3]],
+ [[4, 4, 4], [5, 5, 5]]],
+ dtype='uint8')
+
+ # Set up config.
+ config = delf_config_pb2.DelfConfig(
+ max_image_size=max_image_size,
+ min_image_size=min_image_size,
+ use_square_images=square_output)
+
+ resized_image, scale_factors = extractor.ResizeImage(
+ image, config, resize_factor)
+ self.assertAllEqual(resized_image.shape, expected_shape)
+ self.assertAllClose(scale_factors, expected_scale_factors)
+
+
+if __name__ == '__main__':
+ tf.test.main()
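The 'Max2Min2' rounding case above can be checked by hand; a short sketch of the arithmetic ResizeImage performs for the 3x2 test image with max_image_size = 2:

# Worked check of testResizeImageRoundingWorks ('Max2Min2').
height, width = 3, 2
scale = 2 / max(height, width)                              # 0.666...
new_w, new_h = round(width * scale), round(height * scale)  # 1, 2
scale_factors = [new_h / height, new_w / width]             # [0.666666, 0.5], as asserted above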
diff --git a/models/research/delf/delf/python/examples/match_images.py b/models/research/delf/delf/python/examples/match_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb030739cb9067bf3be50f999368af622f083b54
--- /dev/null
+++ b/models/research/delf/delf/python/examples/match_images.py
@@ -0,0 +1,143 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Matches two images using their DELF features.
+
+The matching is done using feature-based nearest-neighbor search, followed by
+geometric verification using RANSAC.
+
+The DELF features can be extracted using the extract_features.py script.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+import matplotlib
+# Needed before pyplot import for matplotlib to work properly.
+matplotlib.use('Agg')
+import matplotlib.image as mpimg # pylint: disable=g-import-not-at-top
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy import spatial
+from skimage import feature
+from skimage import measure
+from skimage import transform
+
+from tensorflow.python.platform import app
+from delf import feature_io
+
+cmd_args = None
+
+_DISTANCE_THRESHOLD = 0.8
+
+
+def main(unused_argv):
+ # Read features.
+ locations_1, _, descriptors_1, _, _ = feature_io.ReadFromFile(
+ cmd_args.features_1_path)
+ num_features_1 = locations_1.shape[0]
+ print(f"Loaded image 1's {num_features_1} features")
+ locations_2, _, descriptors_2, _, _ = feature_io.ReadFromFile(
+ cmd_args.features_2_path)
+ num_features_2 = locations_2.shape[0]
+ print(f"Loaded image 2's {num_features_2} features")
+
+ # Find nearest-neighbor matches using a KD tree.
+ d1_tree = spatial.cKDTree(descriptors_1)
+ _, indices = d1_tree.query(
+ descriptors_2, distance_upper_bound=_DISTANCE_THRESHOLD)
+
+ # Select feature locations for putative matches.
+ locations_2_to_use = np.array([
+ locations_2[i,]
+ for i in range(num_features_2)
+ if indices[i] != num_features_1
+ ])
+ locations_1_to_use = np.array([
+ locations_1[indices[i],]
+ for i in range(num_features_2)
+ if indices[i] != num_features_1
+ ])
+
+ # Perform geometric verification using RANSAC.
+ _, inliers = measure.ransac((locations_1_to_use, locations_2_to_use),
+ transform.AffineTransform,
+ min_samples=3,
+ residual_threshold=20,
+ max_trials=1000)
+
+ print(f'Found {sum(inliers)} inliers')
+
+ # Visualize correspondences, and save to file.
+ _, ax = plt.subplots()
+ img_1 = mpimg.imread(cmd_args.image_1_path)
+ img_2 = mpimg.imread(cmd_args.image_2_path)
+ inlier_idxs = np.nonzero(inliers)[0]
+ feature.plot_matches(
+ ax,
+ img_1,
+ img_2,
+ locations_1_to_use,
+ locations_2_to_use,
+ np.column_stack((inlier_idxs, inlier_idxs)),
+ matches_color='b')
+ ax.axis('off')
+ ax.set_title('DELF correspondences')
+ plt.savefig(cmd_args.output_image)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.register('type', 'bool', lambda v: v.lower() == 'true')
+ parser.add_argument(
+ '--image_1_path',
+ type=str,
+ default='test_images/image_1.jpg',
+ help="""
+ Path to test image 1.
+ """)
+ parser.add_argument(
+ '--image_2_path',
+ type=str,
+ default='test_images/image_2.jpg',
+ help="""
+ Path to test image 2.
+ """)
+ parser.add_argument(
+ '--features_1_path',
+ type=str,
+ default='test_features/image_1.delf',
+ help="""
+ Path to DELF features from image 1.
+ """)
+ parser.add_argument(
+ '--features_2_path',
+ type=str,
+ default='test_features/image_2.delf',
+ help="""
+ Path to DELF features from image 2.
+ """)
+ parser.add_argument(
+ '--output_image',
+ type=str,
+ default='test_match.png',
+ help="""
+ Path where an image showing the matches will be saved.
+ """)
+ cmd_args, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
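The putative-match filter above (indices[i] != num_features_1) relies on a SciPy convention: when no neighbor lies within distance_upper_bound, cKDTree.query returns the tree size as the index. A tiny sketch of that behavior:

# Sketch: cKDTree.query marks "no neighbor within the bound" with index == len(data).
import numpy as np
from scipy import spatial

tree = spatial.cKDTree(np.array([[0.0], [1.0]]))
dist, idx = tree.query(np.array([[10.0]]), distance_upper_bound=0.5)
print(dist, idx)  # [inf] [2] -> index 2 equals the tree size, i.e. no match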
diff --git a/models/research/delf/delf/python/examples/matched_images_example.jpg b/models/research/delf/delf/python/examples/matched_images_example.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bbd0061ac02d460be3bc24a3f7f736d418b1da88
Binary files /dev/null and b/models/research/delf/delf/python/examples/matched_images_example.jpg differ
diff --git a/models/research/delf/delf/python/feature_aggregation_extractor.py b/models/research/delf/delf/python/feature_aggregation_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..f230642ea950d5393005583334836630328198c9
--- /dev/null
+++ b/models/research/delf/delf/python/feature_aggregation_extractor.py
@@ -0,0 +1,475 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Local feature aggregation extraction.
+
+For more details, please refer to the paper:
+"Detect-to-Retrieve: Efficient Regional Aggregation for Image Search",
+Proc. CVPR'19 (https://arxiv.org/abs/1812.01584).
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from delf import aggregation_config_pb2
+
+_CLUSTER_CENTERS_VAR_NAME = "clusters"
+_NORM_SQUARED_TOLERANCE = 1e-12
+
+# Aliases for aggregation types.
+_VLAD = aggregation_config_pb2.AggregationConfig.VLAD
+_ASMK = aggregation_config_pb2.AggregationConfig.ASMK
+_ASMK_STAR = aggregation_config_pb2.AggregationConfig.ASMK_STAR
+
+
+class ExtractAggregatedRepresentation(object):
+ """Class for extraction of aggregated local feature representation.
+
+ Args:
+ aggregation_config: AggregationConfig object defining type of aggregation to
+ use.
+
+ Raises:
+ ValueError: If aggregation type is invalid.
+ """
+
+ def __init__(self, aggregation_config):
+ self._codebook_size = aggregation_config.codebook_size
+ self._feature_dimensionality = aggregation_config.feature_dimensionality
+ self._aggregation_type = aggregation_config.aggregation_type
+ self._feature_batch_size = aggregation_config.feature_batch_size
+ self._codebook_path = aggregation_config.codebook_path
+ self._use_regional_aggregation = aggregation_config.use_regional_aggregation
+ self._use_l2_normalization = aggregation_config.use_l2_normalization
+ self._num_assignments = aggregation_config.num_assignments
+
+ if self._aggregation_type not in [_VLAD, _ASMK, _ASMK_STAR]:
+ raise ValueError("Invalid aggregation type: %d" % self._aggregation_type)
+
+ # Load codebook
+ codebook = tf.Variable(
+ tf.zeros([self._codebook_size, self._feature_dimensionality],
+ dtype=tf.float32),
+ name=_CLUSTER_CENTERS_VAR_NAME)
+ ckpt = tf.train.Checkpoint(codebook=codebook)
+ ckpt.restore(self._codebook_path)
+
+ self._codebook = codebook
+
+ def Extract(self, features, num_features_per_region=None):
+ """Extracts aggregated representation.
+
+ Args:
+ features: [N, D] float numpy array with N local feature descriptors.
+ num_features_per_region: Required only if computing regional aggregated
+ representations, otherwise optional. List of number of features per
+ region, such that sum(num_features_per_region) = N. It indicates which
+ features correspond to each region.
+
+ Returns:
+ aggregated_descriptors: 1-D numpy array.
+ feature_visual_words: Used only for ASMK/ASMK* aggregation type. 1-D
+ numpy array denoting visual words corresponding to the
+ `aggregated_descriptors`.
+
+ Raises:
+ ValueError: If inputs are misconfigured.
+ """
+ features = tf.cast(features, dtype=tf.float32)
+
+ if num_features_per_region is None:
+ # Use dummy value since it is unused.
+ num_features_per_region = []
+ else:
+ num_features_per_region = tf.cast(num_features_per_region, dtype=tf.int32)
+      if (len(num_features_per_region) and
+          sum(num_features_per_region) != features.shape[0]):
+ raise ValueError(
+ "Incorrect arguments: sum(num_features_per_region) and "
+ "features.shape[0] are different: %d vs %d" %
+ (sum(num_features_per_region), features.shape[0]))
+
+ # Extract features based on desired options.
+ if self._aggregation_type == _VLAD:
+ # Feature visual words are unused in the case of VLAD, so just return
+ # dummy constant.
+ feature_visual_words = tf.constant(-1, dtype=tf.int32)
+ if self._use_regional_aggregation:
+ aggregated_descriptors = self._ComputeRvlad(
+ features,
+ num_features_per_region,
+ self._codebook,
+ use_l2_normalization=self._use_l2_normalization,
+ num_assignments=self._num_assignments)
+ else:
+ aggregated_descriptors = self._ComputeVlad(
+ features,
+ self._codebook,
+ use_l2_normalization=self._use_l2_normalization,
+ num_assignments=self._num_assignments)
+ elif (self._aggregation_type == _ASMK or
+ self._aggregation_type == _ASMK_STAR):
+ if self._use_regional_aggregation:
+ (aggregated_descriptors,
+ feature_visual_words) = self._ComputeRasmk(
+ features,
+ num_features_per_region,
+ self._codebook,
+ num_assignments=self._num_assignments)
+ else:
+ (aggregated_descriptors,
+ feature_visual_words) = self._ComputeAsmk(
+ features,
+ self._codebook,
+ num_assignments=self._num_assignments)
+
+ feature_visual_words_output = feature_visual_words.numpy()
+
+ # If using ASMK*/RASMK*, binarize the aggregated descriptors.
+ if self._aggregation_type == _ASMK_STAR:
+ reshaped_aggregated_descriptors = np.reshape(
+ aggregated_descriptors, [-1, self._feature_dimensionality])
+ packed_descriptors = np.packbits(
+ reshaped_aggregated_descriptors > 0, axis=1)
+ aggregated_descriptors_output = np.reshape(packed_descriptors, [-1])
+ else:
+ aggregated_descriptors_output = aggregated_descriptors.numpy()
+
+ return aggregated_descriptors_output, feature_visual_words_output
+
+ def _ComputeVlad(self,
+ features,
+ codebook,
+ use_l2_normalization=True,
+ num_assignments=1):
+ """Compute VLAD representation.
+
+ Args:
+ features: [N, D] float tensor.
+ codebook: [K, D] float tensor.
+ use_l2_normalization: If False, does not L2-normalize after aggregation.
+ num_assignments: Number of visual words to assign a feature to.
+
+ Returns:
+ vlad: [K*D] float tensor.
+ """
+
+ def _ComputeVladEmptyFeatures():
+ """Computes VLAD if `features` is empty.
+
+ Returns:
+ [K*D] all-zeros tensor.
+ """
+ return tf.zeros([self._codebook_size * self._feature_dimensionality],
+ dtype=tf.float32)
+
+ def _ComputeVladNonEmptyFeatures():
+ """Computes VLAD if `features` is not empty.
+
+ Returns:
+ [K*D] tensor with VLAD descriptor.
+ """
+ num_features = tf.shape(features)[0]
+
+ # Find nearest visual words for each feature. Possibly batch the local
+ # features to avoid OOM.
+ if self._feature_batch_size <= 0:
+ actual_batch_size = num_features
+ else:
+ actual_batch_size = self._feature_batch_size
+
+ def _BatchNearestVisualWords(ind, selected_visual_words):
+ """Compute nearest neighbor visual words for a batch of features.
+
+ Args:
+ ind: Integer index denoting feature.
+ selected_visual_words: Partial set of visual words.
+
+ Returns:
+ output_ind: Next index.
+ output_selected_visual_words: Updated set of visual words, including
+ the visual words for the new batch.
+ """
+ # Handle case of last batch, where there may be fewer than
+ # `actual_batch_size` features.
+ batch_size_to_use = tf.cond(
+ tf.greater(ind + actual_batch_size, num_features),
+ true_fn=lambda: num_features - ind,
+ false_fn=lambda: actual_batch_size)
+
+ # Denote B = batch_size_to_use.
+ # K*B x D.
+ tiled_features = tf.reshape(
+ tf.tile(
+ tf.slice(features, [ind, 0],
+ [batch_size_to_use, self._feature_dimensionality]),
+ [1, self._codebook_size]), [-1, self._feature_dimensionality])
+ # K*B x D.
+ tiled_codebook = tf.reshape(
+ tf.tile(tf.reshape(codebook, [1, -1]), [batch_size_to_use, 1]),
+ [-1, self._feature_dimensionality])
+ # B x K.
+ squared_distances = tf.reshape(
+ tf.reduce_sum(
+ tf.math.squared_difference(tiled_features, tiled_codebook),
+ axis=1), [batch_size_to_use, self._codebook_size])
+ # B x K.
+ nearest_visual_words = tf.argsort(squared_distances)
+ # B x num_assignments.
+ batch_selected_visual_words = tf.slice(
+ nearest_visual_words, [0, 0], [batch_size_to_use, num_assignments])
+ selected_visual_words = tf.concat(
+ [selected_visual_words, batch_selected_visual_words], axis=0)
+
+ return ind + batch_size_to_use, selected_visual_words
+
+ ind_batch = tf.constant(0, dtype=tf.int32)
+ keep_going = lambda j, selected_visual_words: tf.less(j, num_features)
+ selected_visual_words = tf.zeros([0, num_assignments], dtype=tf.int32)
+ _, selected_visual_words = tf.while_loop(
+ cond=keep_going,
+ body=_BatchNearestVisualWords,
+ loop_vars=[ind_batch, selected_visual_words],
+ shape_invariants=[
+ ind_batch.get_shape(),
+ tf.TensorShape([None, num_assignments])
+ ],
+ parallel_iterations=1,
+ back_prop=False)
+
+ # Helper function to collect residuals for relevant visual words.
+ def _ConstructVladFromAssignments(ind, vlad):
+ """Add contributions of a feature to a VLAD descriptor.
+
+ Args:
+ ind: Integer index denoting feature.
+ vlad: Partial VLAD descriptor.
+
+ Returns:
+ output_ind: Next index (ie, ind+1).
+ output_vlad: VLAD descriptor updated to take into account contribution
+ from ind-th feature.
+ """
+ diff = tf.tile(
+ tf.expand_dims(features[ind],
+ axis=0), [num_assignments, 1]) - tf.gather(
+ codebook, selected_visual_words[ind])
+ return ind + 1, tf.tensor_scatter_nd_add(
+ vlad, tf.expand_dims(selected_visual_words[ind], axis=1),
+ tf.cast(diff, dtype=tf.float32))
+
+ ind_vlad = tf.constant(0, dtype=tf.int32)
+ keep_going = lambda j, vlad: tf.less(j, num_features)
+ vlad = tf.zeros([self._codebook_size, self._feature_dimensionality],
+ dtype=tf.float32)
+ _, vlad = tf.while_loop(
+ cond=keep_going,
+ body=_ConstructVladFromAssignments,
+ loop_vars=[ind_vlad, vlad],
+ back_prop=False)
+
+ vlad = tf.reshape(vlad,
+ [self._codebook_size * self._feature_dimensionality])
+ if use_l2_normalization:
+ vlad = tf.math.l2_normalize(vlad, epsilon=_NORM_SQUARED_TOLERANCE)
+
+ return vlad
+
+ return tf.cond(
+ tf.greater(tf.size(features), 0),
+ true_fn=_ComputeVladNonEmptyFeatures,
+ false_fn=_ComputeVladEmptyFeatures)
+
+ def _ComputeRvlad(self,
+ features,
+ num_features_per_region,
+ codebook,
+ use_l2_normalization=False,
+ num_assignments=1):
+ """Compute R-VLAD representation.
+
+ Args:
+ features: [N, D] float tensor.
+ num_features_per_region: [R] int tensor. Contains number of features per
+ region, such that sum(num_features_per_region) = N. It indicates which
+ features correspond to each region.
+ codebook: [K, D] float tensor.
+ use_l2_normalization: If True, performs L2-normalization after regional
+ aggregation; if False (default), performs componentwise division by R
+ after regional aggregation.
+ num_assignments: Number of visual words to assign a feature to.
+
+ Returns:
+ rvlad: [K*D] float tensor.
+ """
+
+ def _ComputeRvladEmptyRegions():
+ """Computes R-VLAD if `num_features_per_region` is empty.
+
+ Returns:
+ [K*D] all-zeros tensor.
+ """
+ return tf.zeros([self._codebook_size * self._feature_dimensionality],
+ dtype=tf.float32)
+
+ def _ComputeRvladNonEmptyRegions():
+ """Computes R-VLAD if `num_features_per_region` is not empty.
+
+ Returns:
+ [K*D] tensor with R-VLAD descriptor.
+ """
+
+ # Helper function to compose initial R-VLAD from image regions.
+ def _ConstructRvladFromVlad(ind, rvlad):
+ """Add contributions from different regions into R-VLAD.
+
+ Args:
+ ind: Integer index denoting region.
+ rvlad: Partial R-VLAD descriptor.
+
+ Returns:
+ output_ind: Next index (ie, ind+1).
+ output_rvlad: R-VLAD descriptor updated to take into account
+ contribution from ind-th region.
+ """
+ return ind + 1, rvlad + self._ComputeVlad(
+ tf.slice(
+ features, [tf.reduce_sum(num_features_per_region[:ind]), 0],
+ [num_features_per_region[ind], self._feature_dimensionality]),
+ codebook,
+ num_assignments=num_assignments)
+
+ i = tf.constant(0, dtype=tf.int32)
+ num_regions = tf.shape(num_features_per_region)[0]
+ keep_going = lambda j, rvlad: tf.less(j, num_regions)
+ rvlad = tf.zeros([self._codebook_size * self._feature_dimensionality],
+ dtype=tf.float32)
+ _, rvlad = tf.while_loop(
+ cond=keep_going,
+ body=_ConstructRvladFromVlad,
+ loop_vars=[i, rvlad],
+ back_prop=False,
+ parallel_iterations=1)
+
+ if use_l2_normalization:
+ rvlad = tf.math.l2_normalize(rvlad, epsilon=_NORM_SQUARED_TOLERANCE)
+ else:
+ rvlad /= tf.cast(num_regions, dtype=tf.float32)
+
+ return rvlad
+
+ return tf.cond(
+ tf.greater(tf.size(num_features_per_region), 0),
+ true_fn=_ComputeRvladNonEmptyRegions,
+ false_fn=_ComputeRvladEmptyRegions)
+
+ def _PerCentroidNormalization(self, unnormalized_vector):
+ """Perform per-centroid normalization.
+
+ Args:
+ unnormalized_vector: [KxD] float tensor.
+
+ Returns:
+ per_centroid_normalized_vector: [KxD] float tensor, with normalized
+ aggregated residuals. Some residuals may be all-zero.
+ visual_words: Int tensor containing indices of visual words which are
+ present for the set of features.
+ """
+ unnormalized_vector = tf.reshape(
+ unnormalized_vector,
+ [self._codebook_size, self._feature_dimensionality])
+ per_centroid_norms = tf.norm(unnormalized_vector, axis=1)
+
+ visual_words = tf.reshape(
+ tf.where(
+ tf.greater(
+ per_centroid_norms,
+ tf.cast(tf.sqrt(_NORM_SQUARED_TOLERANCE), dtype=tf.float32))),
+ [-1])
+
+ per_centroid_normalized_vector = tf.math.l2_normalize(
+ unnormalized_vector, axis=1, epsilon=_NORM_SQUARED_TOLERANCE)
+
+ return per_centroid_normalized_vector, visual_words
+
+ def _ComputeAsmk(self, features, codebook, num_assignments=1):
+ """Compute ASMK representation.
+
+ Args:
+ features: [N, D] float tensor.
+ codebook: [K, D] float tensor.
+ num_assignments: Number of visual words to assign a feature to.
+
+ Returns:
+ normalized_residuals: 1-dimensional float tensor with concatenated
+ residuals which are non-zero. Note that the dimensionality is
+ input-dependent.
+ visual_words: 1-dimensional int tensor of sorted visual word ids.
+ Dimensionality is shape(normalized_residuals)[0] / D.
+ """
+ unnormalized_vlad = self._ComputeVlad(
+ features,
+ codebook,
+ use_l2_normalization=False,
+ num_assignments=num_assignments)
+
+ per_centroid_normalized_vlad, visual_words = self._PerCentroidNormalization(
+ unnormalized_vlad)
+
+ normalized_residuals = tf.reshape(
+ tf.gather(per_centroid_normalized_vlad, visual_words),
+ [tf.shape(visual_words)[0] * self._feature_dimensionality])
+
+ return normalized_residuals, visual_words
+
+ def _ComputeRasmk(self,
+ features,
+ num_features_per_region,
+ codebook,
+ num_assignments=1):
+ """Compute R-ASMK representation.
+
+ Args:
+ features: [N, D] float tensor.
+ num_features_per_region: [R] int tensor. Contains number of features per
+ region, such that sum(num_features_per_region) = N. It indicates which
+ features correspond to each region.
+ codebook: [K, D] float tensor.
+ num_assignments: Number of visual words to assign a feature to.
+
+ Returns:
+ normalized_residuals: 1-dimensional float tensor with concatenated
+ residuals which are non-zero. Note that the dimensionality is
+ input-dependent.
+ visual_words: 1-dimensional int tensor of sorted visual word ids.
+ Dimensionality is shape(normalized_residuals)[0] / D.
+ """
+ unnormalized_rvlad = self._ComputeRvlad(
+ features,
+ num_features_per_region,
+ codebook,
+ use_l2_normalization=False,
+ num_assignments=num_assignments)
+
+ (per_centroid_normalized_rvlad,
+ visual_words) = self._PerCentroidNormalization(unnormalized_rvlad)
+
+ normalized_residuals = tf.reshape(
+ tf.gather(per_centroid_normalized_rvlad, visual_words),
+ [tf.shape(visual_words)[0] * self._feature_dimensionality])
+
+ return normalized_residuals, visual_words
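A hedged usage sketch of ExtractAggregatedRepresentation; the sizes, codebook path, and random descriptors below are assumptions for illustration only. The codebook checkpoint is expected to hold a `codebook` variable, as in the test file that follows.

# Sketch only: configuration values and codebook path are assumptions.
import numpy as np
from delf import aggregation_config_pb2, feature_aggregation_extractor

config = aggregation_config_pb2.AggregationConfig()
config.codebook_size = 1024
config.feature_dimensionality = 40
config.aggregation_type = aggregation_config_pb2.AggregationConfig.ASMK
config.codebook_path = '/path/to/codebook_checkpoint'  # assumed; written via tf.train.Checkpoint(codebook=...)
config.num_assignments = 1

aggregator = feature_aggregation_extractor.ExtractAggregatedRepresentation(config)
descriptors = np.random.rand(500, 40).astype('float32')  # stand-in for DELF local descriptors
aggregated, visual_words = aggregator.Extract(descriptors)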
diff --git a/models/research/delf/delf/python/feature_aggregation_extractor_test.py b/models/research/delf/delf/python/feature_aggregation_extractor_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfba92a2b1b4847460fefbba5ef41fa8cce1ba42
--- /dev/null
+++ b/models/research/delf/delf/python/feature_aggregation_extractor_test.py
@@ -0,0 +1,494 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for DELF feature aggregation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import flags
+import numpy as np
+import tensorflow as tf
+
+from delf import aggregation_config_pb2
+from delf import feature_aggregation_extractor
+
+FLAGS = flags.FLAGS
+
+
+class FeatureAggregationTest(tf.test.TestCase):
+
+ def _CreateCodebook(self, checkpoint_path):
+ """Creates codebook used in tests.
+
+ Args:
+ checkpoint_path: Directory where codebook is saved to.
+ """
+ codebook = tf.Variable(
+ [[0.5, 0.5], [0.0, 0.0], [1.0, 0.0], [-0.5, -0.5], [0.0, 1.0]],
+ name='clusters',
+ dtype=tf.float32)
+ ckpt = tf.train.Checkpoint(codebook=codebook)
+ ckpt.write(checkpoint_path)
+
+ def setUp(self):
+ self._codebook_path = os.path.join(FLAGS.test_tmpdir, 'test_codebook')
+ self._CreateCodebook(self._codebook_path)
+
+ def testComputeNormalizedVladWorks(self):
+ # Construct inputs.
+ # 3 2-D features.
+ features = np.array([[1.0, 0.0], [-1.0, 0.0], [1.0, 2.0]], dtype=float)
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.VLAD
+ config.use_l2_normalization = True
+ config.codebook_path = self._codebook_path
+ config.num_assignments = 1
+
+ # Run tested function.
+ extractor = feature_aggregation_extractor.ExtractAggregatedRepresentation(
+ config)
+ vlad, extra_output = extractor.Extract(features)
+
+ # Define expected results.
+ exp_vlad = [
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.316228, 0.316228, 0.632456, 0.632456
+ ]
+ exp_extra_output = -1
+
+ # Compare actual and expected results.
+ self.assertAllClose(vlad, exp_vlad)
+ self.assertAllEqual(extra_output, exp_extra_output)
+
+ def testComputeNormalizedVladWithBatchingWorks(self):
+ # Construct inputs.
+ # 3 2-D features.
+ features = np.array([[1.0, 0.0], [-1.0, 0.0], [1.0, 2.0]], dtype=float)
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.VLAD
+ config.use_l2_normalization = True
+ config.codebook_path = self._codebook_path
+ config.num_assignments = 1
+ config.feature_batch_size = 2
+
+ # Run tested function.
+ extractor = feature_aggregation_extractor.ExtractAggregatedRepresentation(
+ config)
+ vlad, extra_output = extractor.Extract(features)
+
+ # Define expected results.
+ exp_vlad = [
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.316228, 0.316228, 0.632456, 0.632456
+ ]
+ exp_extra_output = -1
+
+ # Compare actual and expected results.
+ self.assertAllClose(vlad, exp_vlad)
+ self.assertAllEqual(extra_output, exp_extra_output)
+
+ def testComputeUnnormalizedVladWorks(self):
+ # Construct inputs.
+ # 3 2-D features.
+ features = np.array([[1.0, 0.0], [-1.0, 0.0], [1.0, 2.0]], dtype=float)
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.VLAD
+ config.use_l2_normalization = False
+ config.codebook_path = self._codebook_path
+ config.num_assignments = 1
+
+ # Run tested function.
+ extractor = feature_aggregation_extractor.ExtractAggregatedRepresentation(
+ config)
+ vlad, extra_output = extractor.Extract(features)
+
+ # Define expected results.
+ exp_vlad = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.5, 0.5, 1.0, 1.0]
+ exp_extra_output = -1
+
+ # Compare actual and expected results.
+ self.assertAllEqual(vlad, exp_vlad)
+ self.assertAllEqual(extra_output, exp_extra_output)
+
+ def testComputeUnnormalizedVladMultipleAssignmentWorks(self):
+ # Construct inputs.
+ # 3 2-D features.
+ features = np.array([[1.0, 0.0], [-1.0, 0.0], [1.0, 2.0]], dtype=float)
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.VLAD
+ config.use_l2_normalization = False
+ config.codebook_path = self._codebook_path
+ config.num_assignments = 3
+
+ # Run tested function.
+ extractor = feature_aggregation_extractor.ExtractAggregatedRepresentation(
+ config)
+ vlad, extra_output = extractor.Extract(features)
+
+ # Define expected results.
+ exp_vlad = [1.0, 1.0, 0.0, 0.0, 0.0, 2.0, -0.5, 0.5, 0.0, 0.0]
+ exp_extra_output = -1
+
+ # Compare actual and expected results.
+ self.assertAllEqual(vlad, exp_vlad)
+ self.assertAllEqual(extra_output, exp_extra_output)
+
+ def testComputeVladEmptyFeaturesWorks(self):
+ # Construct inputs.
+ # Empty feature array.
+ features = np.array([[]])
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.VLAD
+ config.codebook_path = self._codebook_path
+
+ # Run tested function.
+ extractor = feature_aggregation_extractor.ExtractAggregatedRepresentation(
+ config)
+ vlad, extra_output = extractor.Extract(features)
+
+ # Define expected results.
+ exp_vlad = np.zeros([10], dtype=float)
+ exp_extra_output = -1
+
+ # Compare actual and expected results.
+ self.assertAllEqual(vlad, exp_vlad)
+ self.assertAllEqual(extra_output, exp_extra_output)
+
+ def testComputeUnnormalizedRvladWorks(self):
+ # Construct inputs.
+ # 4 2-D features: 3 in first region, 1 in second region.
+ features = np.array([[1.0, 0.0], [-1.0, 0.0], [1.0, 2.0], [0.0, 2.0]],
+ dtype=float)
+ num_features_per_region = np.array([3, 1])
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.VLAD
+ config.use_l2_normalization = False
+ config.codebook_path = self._codebook_path
+ config.num_assignments = 1
+ config.use_regional_aggregation = True
+
+ # Run tested function.
+ extractor = feature_aggregation_extractor.ExtractAggregatedRepresentation(
+ config)
+ rvlad, extra_output = extractor.Extract(features, num_features_per_region)
+
+ # Define expected results.
+ exp_rvlad = [
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.158114, 0.158114, 0.316228, 0.816228
+ ]
+ exp_extra_output = -1
+
+ # Compare actual and expected results.
+ self.assertAllClose(rvlad, exp_rvlad)
+ self.assertAllEqual(extra_output, exp_extra_output)
+
+ def testComputeNormalizedRvladWorks(self):
+ # Construct inputs.
+ # 4 2-D features: 3 in first region, 1 in second region.
+ features = np.array([[1.0, 0.0], [-1.0, 0.0], [1.0, 2.0], [0.0, 2.0]],
+ dtype=float)
+ num_features_per_region = np.array([3, 1])
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.VLAD
+ config.use_l2_normalization = True
+ config.codebook_path = self._codebook_path
+ config.num_assignments = 1
+ config.use_regional_aggregation = True
+
+ # Run tested function.
+ extractor = feature_aggregation_extractor.ExtractAggregatedRepresentation(
+ config)
+ rvlad, extra_output = extractor.Extract(features, num_features_per_region)
+
+ # Define expected results.
+ exp_rvlad = [
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.175011, 0.175011, 0.350021, 0.903453
+ ]
+ exp_extra_output = -1
+
+ # Compare actual and expected results.
+ self.assertAllClose(rvlad, exp_rvlad)
+ self.assertAllEqual(extra_output, exp_extra_output)
+
+ def testComputeRvladEmptyRegionsWorks(self):
+ # Construct inputs.
+ # Empty feature array.
+ features = np.array([[]])
+ num_features_per_region = np.array([])
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.VLAD
+ config.codebook_path = self._codebook_path
+ config.use_regional_aggregation = True
+
+ # Run tested function.
+ extractor = feature_aggregation_extractor.ExtractAggregatedRepresentation(
+ config)
+ rvlad, extra_output = extractor.Extract(features, num_features_per_region)
+
+ # Define expected results.
+ exp_rvlad = np.zeros([10], dtype=float)
+ exp_extra_output = -1
+
+ # Compare actual and expected results.
+ self.assertAllEqual(rvlad, exp_rvlad)
+ self.assertAllEqual(extra_output, exp_extra_output)
+
+ def testComputeUnnormalizedRvladSomeEmptyRegionsWorks(self):
+ # Construct inputs.
+ # 4 2-D features: 0 in first region, 3 in second region, 0 in third region,
+ # 1 in fourth region.
+ features = np.array([[1.0, 0.0], [-1.0, 0.0], [1.0, 2.0], [0.0, 2.0]],
+ dtype=float)
+ num_features_per_region = np.array([0, 3, 0, 1])
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.VLAD
+ config.use_l2_normalization = False
+ config.codebook_path = self._codebook_path
+ config.num_assignments = 1
+ config.use_regional_aggregation = True
+
+ # Run tested function.
+ extractor = feature_aggregation_extractor.ExtractAggregatedRepresentation(
+ config)
+ rvlad, extra_output = extractor.Extract(features, num_features_per_region)
+
+ # Define expected results.
+ exp_rvlad = [
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.079057, 0.079057, 0.158114, 0.408114
+ ]
+ exp_extra_output = -1
+
+ # Compare actual and expected results.
+ self.assertAllClose(rvlad, exp_rvlad)
+ self.assertAllEqual(extra_output, exp_extra_output)
+
+ def testComputeNormalizedRvladSomeEmptyRegionsWorks(self):
+ # Construct inputs.
+ # 4 2-D features: 0 in first region, 3 in second region, 0 in third region,
+ # 1 in fourth region.
+ features = np.array([[1.0, 0.0], [-1.0, 0.0], [1.0, 2.0], [0.0, 2.0]],
+ dtype=float)
+ num_features_per_region = np.array([0, 3, 0, 1])
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.VLAD
+ config.use_l2_normalization = True
+ config.codebook_path = self._codebook_path
+ config.num_assignments = 1
+ config.use_regional_aggregation = True
+
+ # Run tested function.
+ extractor = feature_aggregation_extractor.ExtractAggregatedRepresentation(
+ config)
+ rvlad, extra_output = extractor.Extract(features, num_features_per_region)
+
+ # Define expected results.
+ exp_rvlad = [
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.175011, 0.175011, 0.350021, 0.903453
+ ]
+ exp_extra_output = -1
+
+ # Compare actual and expected results.
+ self.assertAllClose(rvlad, exp_rvlad)
+ self.assertAllEqual(extra_output, exp_extra_output)
+
+ def testComputeRvladMisconfiguredFeatures(self):
+ # Construct inputs.
+ # 4 2-D features: 3 in first region, 1 in second region.
+ features = np.array([[1.0, 0.0], [-1.0, 0.0], [1.0, 2.0], [0.0, 2.0]],
+ dtype=float)
+ # Misconfigured number of features; there are only 4 features, but
+ # sum(num_features_per_region) = 5.
+ num_features_per_region = np.array([3, 2])
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.VLAD
+ config.codebook_path = self._codebook_path
+ config.use_regional_aggregation = True
+
+ # Run tested function.
+ extractor = feature_aggregation_extractor.ExtractAggregatedRepresentation(
+ config)
+ with self.assertRaisesRegex(
+ ValueError,
+ r'Incorrect arguments: sum\(num_features_per_region\) and '
+ r'features.shape\[0\] are different'):
+ extractor.Extract(features, num_features_per_region)
+
+ def testComputeAsmkWorks(self):
+ # Construct inputs.
+ # 3 2-D features.
+ features = np.array([[1.0, 0.0], [-1.0, 0.0], [1.0, 2.0]], dtype=float)
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.ASMK
+ config.codebook_path = self._codebook_path
+ config.num_assignments = 1
+
+ # Run tested function.
+ extractor = feature_aggregation_extractor.ExtractAggregatedRepresentation(
+ config)
+ asmk, visual_words = extractor.Extract(features)
+
+ # Define expected results.
+ exp_asmk = [-0.707107, 0.707107, 0.707107, 0.707107]
+ exp_visual_words = [3, 4]
+
+ # Compare actual and expected results.
+ self.assertAllClose(asmk, exp_asmk)
+ self.assertAllEqual(visual_words, exp_visual_words)
+
+ def testComputeAsmkStarWorks(self):
+ # Construct inputs.
+ # 3 2-D features.
+ features = np.array([[1.0, 0.0], [-1.0, 0.0], [1.0, 2.0]], dtype=float)
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.ASMK_STAR
+ config.codebook_path = self._codebook_path
+ config.num_assignments = 1
+
+ # Run tested function.
+ extractor = feature_aggregation_extractor.ExtractAggregatedRepresentation(
+ config)
+ asmk_star, visual_words = extractor.Extract(features)
+
+ # Define expected results.
+ exp_asmk_star = [64, 192]
+ exp_visual_words = [3, 4]
+
+ # Compare actual and expected results.
+ self.assertAllEqual(asmk_star, exp_asmk_star)
+ self.assertAllEqual(visual_words, exp_visual_words)
+
+ def testComputeAsmkMultipleAssignmentWorks(self):
+ # Construct inputs.
+ # 3 2-D features.
+ features = np.array([[1.0, 0.0], [-1.0, 0.0], [1.0, 2.0]], dtype=float)
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.ASMK
+ config.codebook_path = self._codebook_path
+ config.num_assignments = 3
+
+ # Run tested function.
+ extractor = feature_aggregation_extractor.ExtractAggregatedRepresentation(
+ config)
+ asmk, visual_words = extractor.Extract(features)
+
+ # Define expected results.
+ exp_asmk = [0.707107, 0.707107, 0.0, 1.0, -0.707107, 0.707107]
+ exp_visual_words = [0, 2, 3]
+
+ # Compare actual and expected results.
+ self.assertAllClose(asmk, exp_asmk)
+ self.assertAllEqual(visual_words, exp_visual_words)
+
+ def testComputeRasmkWorks(self):
+ # Construct inputs.
+ # 4 2-D features: 3 in first region, 1 in second region.
+ features = np.array([[1.0, 0.0], [-1.0, 0.0], [1.0, 2.0], [0.0, 2.0]],
+ dtype=float)
+ num_features_per_region = np.array([3, 1])
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.ASMK
+ config.codebook_path = self._codebook_path
+ config.num_assignments = 1
+ config.use_regional_aggregation = True
+
+ # Run tested function.
+ extractor = feature_aggregation_extractor.ExtractAggregatedRepresentation(
+ config)
+ rasmk, visual_words = extractor.Extract(features, num_features_per_region)
+
+ # Define expected results.
+ exp_rasmk = [-0.707107, 0.707107, 0.361261, 0.932465]
+ exp_visual_words = [3, 4]
+
+ # Compare actual and expected results.
+ self.assertAllClose(rasmk, exp_rasmk)
+ self.assertAllEqual(visual_words, exp_visual_words)
+
+ def testComputeRasmkStarWorks(self):
+ # Construct inputs.
+ # 4 2-D features: 3 in first region, 1 in second region.
+ features = np.array([[1.0, 0.0], [-1.0, 0.0], [1.0, 2.0], [0.0, 2.0]],
+ dtype=float)
+ num_features_per_region = np.array([3, 1])
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.ASMK_STAR
+ config.codebook_path = self._codebook_path
+ config.num_assignments = 1
+ config.use_regional_aggregation = True
+
+ # Run tested function.
+ extractor = feature_aggregation_extractor.ExtractAggregatedRepresentation(
+ config)
+ rasmk_star, visual_words = extractor.Extract(features,
+ num_features_per_region)
+
+ # Define expected results.
+ exp_rasmk_star = [64, 192]
+ exp_visual_words = [3, 4]
+
+ # Compare actual and expected results.
+ self.assertAllEqual(rasmk_star, exp_rasmk_star)
+ self.assertAllEqual(visual_words, exp_visual_words)
+
+ def testComputeUnknownAggregation(self):
+ # Construct inputs.
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = 0
+ config.codebook_path = self._codebook_path
+ config.use_regional_aggregation = True
+
+ # Run tested function.
+ with self.assertRaisesRegex(ValueError, 'Invalid aggregation type'):
+ feature_aggregation_extractor.ExtractAggregatedRepresentation(
+ config)
+
+
+if __name__ == '__main__':
+ tf.test.main()
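The expected VLAD values in the tests above can be reproduced with plain NumPy; this sketch re-derives exp_vlad for the unnormalized and normalized cases using the same 5x2 codebook.

# Worked check of testComputeUnnormalizedVladWorks / testComputeNormalizedVladWorks.
import numpy as np

features = np.array([[1.0, 0.0], [-1.0, 0.0], [1.0, 2.0]])
codebook = np.array([[0.5, 0.5], [0.0, 0.0], [1.0, 0.0], [-0.5, -0.5], [0.0, 1.0]])

vlad = np.zeros_like(codebook)
for f in features:
    nearest = np.argmin(np.sum((codebook - f) ** 2, axis=1))
    vlad[nearest] += f - codebook[nearest]     # accumulate residual to the nearest centroid

print(vlad.ravel())                            # [0 0 0 0 0 0 -0.5 0.5 1 1] == exp_vlad (unnormalized)
print(vlad.ravel() / np.linalg.norm(vlad))     # matches the L2-normalized expected values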
diff --git a/models/research/delf/delf/python/feature_aggregation_similarity.py b/models/research/delf/delf/python/feature_aggregation_similarity.py
new file mode 100644
index 0000000000000000000000000000000000000000..991c95c767c6bed5d0db38226a0cf361eee18c2f
--- /dev/null
+++ b/models/research/delf/delf/python/feature_aggregation_similarity.py
@@ -0,0 +1,265 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Local feature aggregation similarity computation.
+
+For more details, please refer to the paper:
+"Detect-to-Retrieve: Efficient Regional Aggregation for Image Search",
+Proc. CVPR'19 (https://arxiv.org/abs/1812.01584).
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from delf import aggregation_config_pb2
+
+# Aliases for aggregation types.
+_VLAD = aggregation_config_pb2.AggregationConfig.VLAD
+_ASMK = aggregation_config_pb2.AggregationConfig.ASMK
+_ASMK_STAR = aggregation_config_pb2.AggregationConfig.ASMK_STAR
+
+
+class SimilarityAggregatedRepresentation(object):
+ """Class for computing similarity of aggregated local feature representations.
+
+ Args:
+ aggregation_config: AggregationConfig object defining type of aggregation to
+ use.
+
+ Raises:
+ ValueError: If aggregation type is invalid.
+ """
+
+ def __init__(self, aggregation_config):
+ self._feature_dimensionality = aggregation_config.feature_dimensionality
+ self._aggregation_type = aggregation_config.aggregation_type
+
+ # Only relevant if using ASMK/ASMK*. Otherwise, ignored.
+ self._use_l2_normalization = aggregation_config.use_l2_normalization
+ self._alpha = aggregation_config.alpha
+ self._tau = aggregation_config.tau
+
+ # Only relevant if using ASMK*. Otherwise, ignored.
+ self._number_bits = np.array([bin(n).count('1') for n in range(256)])
+
+ def ComputeSimilarity(self,
+ aggregated_descriptors_1,
+ aggregated_descriptors_2,
+ feature_visual_words_1=None,
+ feature_visual_words_2=None):
+ """Computes similarity between aggregated descriptors.
+
+ Args:
+ aggregated_descriptors_1: 1-D NumPy array.
+ aggregated_descriptors_2: 1-D NumPy array.
+ feature_visual_words_1: Used only for ASMK/ASMK* aggregation type. 1-D
+ sorted NumPy integer array denoting visual words corresponding to
+ `aggregated_descriptors_1`.
+ feature_visual_words_2: Used only for ASMK/ASMK* aggregation type. 1-D
+ sorted NumPy integer array denoting visual words corresponding to
+ `aggregated_descriptors_2`.
+
+ Returns:
+ similarity: Float. The larger, the more similar.
+
+ Raises:
+ ValueError: If aggregation type is invalid.
+ """
+ if self._aggregation_type == _VLAD:
+ similarity = np.dot(aggregated_descriptors_1, aggregated_descriptors_2)
+ elif self._aggregation_type == _ASMK:
+ similarity = self._AsmkSimilarity(
+ aggregated_descriptors_1,
+ aggregated_descriptors_2,
+ feature_visual_words_1,
+ feature_visual_words_2,
+ binarized=False)
+ elif self._aggregation_type == _ASMK_STAR:
+ similarity = self._AsmkSimilarity(
+ aggregated_descriptors_1,
+ aggregated_descriptors_2,
+ feature_visual_words_1,
+ feature_visual_words_2,
+ binarized=True)
+ else:
+ raise ValueError('Invalid aggregation type: %d' % self._aggregation_type)
+
+ return similarity
+
+ def _CheckAsmkDimensionality(self, aggregated_descriptors, num_visual_words,
+ descriptor_name):
+ """Checks that ASMK dimensionality is as expected.
+
+ Args:
+ aggregated_descriptors: 1-D NumPy array.
+ num_visual_words: Integer.
+ descriptor_name: String.
+
+ Raises:
+ ValueError: If descriptor dimensionality is incorrect.
+ """
+ if len(aggregated_descriptors
+ ) / num_visual_words != self._feature_dimensionality:
+ raise ValueError(
+ 'Feature dimensionality for aggregated descriptor %s is invalid: %d;'
+ ' expected %d.' % (descriptor_name, len(aggregated_descriptors) /
+ num_visual_words, self._feature_dimensionality))
+
+ def _SigmaFn(self, x):
+ """Selectivity ASMK/ASMK* similarity function.
+
+ Args:
+ x: Scalar or 1-D NumPy array.
+
+ Returns:
+ result: Same type as input, with output of selectivity function.
+ """
+ if np.isscalar(x):
+ if x > self._tau:
+ result = np.sign(x) * np.power(np.absolute(x), self._alpha)
+ else:
+ result = 0.0
+ else:
+ result = np.zeros_like(x)
+ above_tau = np.nonzero(x > self._tau)
+ result[above_tau] = np.sign(x[above_tau]) * np.power(
+ np.absolute(x[above_tau]), self._alpha)
+
+ return result
+
+ def _BinaryNormalizedInnerProduct(self, descriptors_1, descriptors_2):
+ """Computes normalized binary inner product.
+
+ Args:
+ descriptors_1: 1-D NumPy integer array.
+ descriptors_2: 1-D NumPy integer array.
+
+ Returns:
+ inner_product: Float.
+
+ Raises:
+ ValueError: If the dimensionality of descriptors is different.
+ """
+ num_descriptors = len(descriptors_1)
+ if num_descriptors != len(descriptors_2):
+ raise ValueError(
+ 'Descriptors have incompatible dimensionality: %d vs %d' %
+ (len(descriptors_1), len(descriptors_2)))
+
+ h = 0
+ for i in range(num_descriptors):
+ h += self._number_bits[np.bitwise_xor(descriptors_1[i], descriptors_2[i])]
+
+ # If local feature dimensionality is lower than 8, then use that to compute
+ # proper binarized inner product.
+ bits_per_descriptor = min(self._feature_dimensionality, 8)
+
+ total_num_bits = bits_per_descriptor * num_descriptors
+
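+    # h is the Hamming distance between the binarized descriptors;
+    # 1 - 2 * h / total_num_bits maps the normalized distance from [0, 1] to a
+    # correlation-like similarity in [-1, 1].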
+ return 1.0 - 2.0 * h / total_num_bits
+
+ def _AsmkSimilarity(self,
+ aggregated_descriptors_1,
+ aggregated_descriptors_2,
+ visual_words_1,
+ visual_words_2,
+ binarized=False):
+ """Compute ASMK-based similarity.
+
+ If `aggregated_descriptors_1` or `aggregated_descriptors_2` is empty, we
+ return a similarity of -1.0.
+
+ If binarized is True, `aggregated_descriptors_1` and
+ `aggregated_descriptors_2` must be of type uint8.
+
+ Args:
+ aggregated_descriptors_1: 1-D NumPy array.
+ aggregated_descriptors_2: 1-D NumPy array.
+ visual_words_1: 1-D sorted NumPy integer array denoting visual words
+ corresponding to `aggregated_descriptors_1`.
+ visual_words_2: 1-D sorted NumPy integer array denoting visual words
+ corresponding to `aggregated_descriptors_2`.
+ binarized: If True, compute ASMK* similarity.
+
+ Returns:
+ similarity: Float. The larger, the more similar.
+
+ Raises:
+ ValueError: If input descriptor dimensionality is inconsistent, or if
+ descriptor type is unsupported.
+ """
+ num_visual_words_1 = len(visual_words_1)
+ num_visual_words_2 = len(visual_words_2)
+
+ if not num_visual_words_1 or not num_visual_words_2:
+ return -1.0
+
+    # Parse the dimensionality used per visual word. It must be the same for
+    # both aggregated descriptors. If using ASMK, it must also be equal to
+    # self._feature_dimensionality.
+ if binarized:
+ if aggregated_descriptors_1.dtype != 'uint8':
+ raise ValueError('Incorrect input descriptor type: %s' %
+ aggregated_descriptors_1.dtype)
+ if aggregated_descriptors_2.dtype != 'uint8':
+ raise ValueError('Incorrect input descriptor type: %s' %
+ aggregated_descriptors_2.dtype)
+
+ per_visual_word_dimensionality = int(
+ len(aggregated_descriptors_1) / num_visual_words_1)
+ if len(aggregated_descriptors_2
+ ) / num_visual_words_2 != per_visual_word_dimensionality:
+ raise ValueError('ASMK* dimensionality is inconsistent.')
+ else:
+ per_visual_word_dimensionality = self._feature_dimensionality
+ self._CheckAsmkDimensionality(aggregated_descriptors_1,
+ num_visual_words_1, '1')
+ self._CheckAsmkDimensionality(aggregated_descriptors_2,
+ num_visual_words_2, '2')
+
+ aggregated_descriptors_1_reshape = np.reshape(
+ aggregated_descriptors_1,
+ [num_visual_words_1, per_visual_word_dimensionality])
+ aggregated_descriptors_2_reshape = np.reshape(
+ aggregated_descriptors_2,
+ [num_visual_words_2, per_visual_word_dimensionality])
+
+    # Loop over visual words by merging the two sorted lists; only visual words
+    # present in both aggregated descriptors contribute to the similarity.
+ unnormalized_similarity = 0.0
+ ind_1 = 0
+ ind_2 = 0
+ while ind_1 < num_visual_words_1 and ind_2 < num_visual_words_2:
+ if visual_words_1[ind_1] == visual_words_2[ind_2]:
+ if binarized:
+ inner_product = self._BinaryNormalizedInnerProduct(
+ aggregated_descriptors_1_reshape[ind_1],
+ aggregated_descriptors_2_reshape[ind_2])
+ else:
+ inner_product = np.dot(aggregated_descriptors_1_reshape[ind_1],
+ aggregated_descriptors_2_reshape[ind_2])
+ unnormalized_similarity += self._SigmaFn(inner_product)
+ ind_1 += 1
+ ind_2 += 1
+ elif visual_words_1[ind_1] > visual_words_2[ind_2]:
+ ind_2 += 1
+ else:
+ ind_1 += 1
+
+ final_similarity = unnormalized_similarity
+ if self._use_l2_normalization:
+ final_similarity /= np.sqrt(num_visual_words_1 * num_visual_words_2)
+
+ return final_similarity
diff --git a/models/research/delf/delf/python/feature_aggregation_similarity_test.py b/models/research/delf/delf/python/feature_aggregation_similarity_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2f01b1d2a7b36b87773714f0cc027a98a36324f
--- /dev/null
+++ b/models/research/delf/delf/python/feature_aggregation_similarity_test.py
@@ -0,0 +1,137 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for DELF feature aggregation similarity."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from delf import aggregation_config_pb2
+from delf import feature_aggregation_similarity
+
+
+class FeatureAggregationSimilarityTest(tf.test.TestCase):
+
+ def testComputeVladSimilarityWorks(self):
+ # Construct inputs.
+ vlad_1 = np.array([0, 1, 2, 3, 4])
+ vlad_2 = np.array([5, 6, 7, 8, 9])
+ config = aggregation_config_pb2.AggregationConfig()
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.VLAD
+
+ # Run tested function.
+ similarity_computer = (
+ feature_aggregation_similarity.SimilarityAggregatedRepresentation(
+ config))
+ similarity = similarity_computer.ComputeSimilarity(vlad_1, vlad_2)
+
+ # Define expected results.
+ exp_similarity = 80
+
+ # Compare actual and expected results.
+ self.assertAllEqual(similarity, exp_similarity)
+
+ def testComputeAsmkSimilarityWorks(self):
+ # Construct inputs.
+ aggregated_descriptors_1 = np.array([
+ 0.0, 0.0, -0.707107, -0.707107, 0.5, 0.866025, 0.816497, 0.577350, 1.0,
+ 0.0
+ ])
+ visual_words_1 = np.array([0, 1, 2, 3, 4])
+ aggregated_descriptors_2 = np.array(
+ [0.0, 1.0, 1.0, 0.0, 0.707107, 0.707107])
+ visual_words_2 = np.array([1, 2, 4])
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.ASMK
+ config.use_l2_normalization = True
+
+ # Run tested function.
+ similarity_computer = (
+ feature_aggregation_similarity.SimilarityAggregatedRepresentation(
+ config))
+ similarity = similarity_computer.ComputeSimilarity(
+ aggregated_descriptors_1, aggregated_descriptors_2, visual_words_1,
+ visual_words_2)
+
+ # Define expected results.
+ exp_similarity = 0.123562
+
+ # Compare actual and expected results.
+ self.assertAllClose(similarity, exp_similarity)
+
+ def testComputeAsmkSimilarityNoNormalizationWorks(self):
+ # Construct inputs.
+ aggregated_descriptors_1 = np.array([
+ 0.0, 0.0, -0.707107, -0.707107, 0.5, 0.866025, 0.816497, 0.577350, 1.0,
+ 0.0
+ ])
+ visual_words_1 = np.array([0, 1, 2, 3, 4])
+ aggregated_descriptors_2 = np.array(
+ [0.0, 1.0, 1.0, 0.0, 0.707107, 0.707107])
+ visual_words_2 = np.array([1, 2, 4])
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.ASMK
+ config.use_l2_normalization = False
+
+ # Run tested function.
+ similarity_computer = (
+ feature_aggregation_similarity.SimilarityAggregatedRepresentation(
+ config))
+ similarity = similarity_computer.ComputeSimilarity(
+ aggregated_descriptors_1, aggregated_descriptors_2, visual_words_1,
+ visual_words_2)
+
+ # Define expected results.
+ exp_similarity = 0.478554
+
+ # Compare actual and expected results.
+ self.assertAllClose(similarity, exp_similarity)
+
+ def testComputeAsmkStarSimilarityWorks(self):
+ # Construct inputs.
+ aggregated_descriptors_1 = np.array([0, 0, 3, 3, 3], dtype='uint8')
+ visual_words_1 = np.array([0, 1, 2, 3, 4])
+ aggregated_descriptors_2 = np.array([1, 2, 3], dtype='uint8')
+ visual_words_2 = np.array([1, 2, 4])
+ config = aggregation_config_pb2.AggregationConfig()
+ config.codebook_size = 5
+ config.feature_dimensionality = 2
+ config.aggregation_type = aggregation_config_pb2.AggregationConfig.ASMK_STAR
+ config.use_l2_normalization = True
+
+ # Run tested function.
+ similarity_computer = (
+ feature_aggregation_similarity.SimilarityAggregatedRepresentation(
+ config))
+ similarity = similarity_computer.ComputeSimilarity(
+ aggregated_descriptors_1, aggregated_descriptors_2, visual_words_1,
+ visual_words_2)
+
+ # Define expected results.
+ exp_similarity = 0.258199
+
+ # Compare actual and expected results.
+ self.assertAllClose(similarity, exp_similarity)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/delf/delf/python/feature_extractor.py b/models/research/delf/delf/python/feature_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9545337f18724520e260af4e36ffa6ee35bce4c6
--- /dev/null
+++ b/models/research/delf/delf/python/feature_extractor.py
@@ -0,0 +1,175 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""DELF feature extractor."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+def NormalizePixelValues(image,
+ pixel_value_offset=128.0,
+ pixel_value_scale=128.0):
+ """Normalize image pixel values.
+
+ Args:
+ image: a uint8 tensor.
+ pixel_value_offset: a Python float, offset for normalizing pixel values.
+ pixel_value_scale: a Python float, scale for normalizing pixel values.
+
+ Returns:
+ image: a float32 tensor of the same shape as the input image.
+ """
+ image = tf.cast(image, dtype=tf.float32)
+ image = tf.truediv(tf.subtract(image, pixel_value_offset), pixel_value_scale)
+ return image
+
+
+def CalculateReceptiveBoxes(height, width, rf, stride, padding):
+ """Calculate receptive boxes for each feature point.
+
+ Args:
+ height: The height of feature map.
+ width: The width of feature map.
+ rf: The receptive field size.
+ stride: The effective stride between two adjacent feature points.
+ padding: The effective padding size.
+
+ Returns:
+ rf_boxes: [N, 4] receptive boxes tensor. Here N equals to height x width.
+ Each box is represented by [ymin, xmin, ymax, xmax].
+ """
+ x, y = tf.meshgrid(tf.range(width), tf.range(height))
+ coordinates = tf.reshape(tf.stack([y, x], axis=2), [-1, 2])
+ # [y,x,y,x]
+ point_boxes = tf.cast(
+ tf.concat([coordinates, coordinates], 1), dtype=tf.float32)
+ bias = [-padding, -padding, -padding + rf - 1, -padding + rf - 1]
+ rf_boxes = stride * point_boxes + bias
+ return rf_boxes
+
+
+def CalculateKeypointCenters(boxes):
+ """Helper function to compute feature centers, from RF boxes.
+
+ Args:
+ boxes: [N, 4] float tensor.
+
+ Returns:
+ centers: [N, 2] float tensor.
+ """
+ return tf.divide(
+ tf.add(
+ tf.gather(boxes, [0, 1], axis=1), tf.gather(boxes, [2, 3], axis=1)),
+ 2.0)
+
+
+def ApplyPcaAndWhitening(data,
+ pca_matrix,
+ pca_mean,
+ output_dim,
+ use_whitening=False,
+ pca_variances=None):
+ """Applies PCA/whitening to data.
+
+ Args:
+ data: [N, dim] float tensor containing data which undergoes PCA/whitening.
+ pca_matrix: [dim, dim] float tensor PCA matrix, row-major.
+ pca_mean: [dim] float tensor, mean to subtract before projection.
+ output_dim: Number of dimensions to use in output data, of type int.
+ use_whitening: Whether whitening is to be used.
+ pca_variances: [dim] float tensor containing PCA variances. Only used if
+ use_whitening is True.
+
+ Returns:
+ output: [N, output_dim] float tensor with output of PCA/whitening operation.
+ """
+ output = tf.matmul(
+ tf.subtract(data, pca_mean),
+ tf.slice(pca_matrix, [0, 0], [output_dim, -1]),
+ transpose_b=True,
+ name='pca_matmul')
+
+ # Apply whitening if desired.
+ if use_whitening:
+ output = tf.divide(
+ output,
+ tf.sqrt(tf.slice(pca_variances, [0], [output_dim])),
+ name='whitening')
+
+ return output
+
+
+def PostProcessDescriptors(descriptors, use_pca, pca_parameters=None):
+ """Post-process descriptors.
+
+ Args:
+ descriptors: [N, input_dim] float tensor.
+ use_pca: Whether to use PCA.
+ pca_parameters: Only used if `use_pca` is True. Dict containing PCA
+ parameter tensors, with keys 'mean', 'matrix', 'dim', 'use_whitening',
+ 'variances'.
+
+ Returns:
+ final_descriptors: [N, output_dim] float tensor with descriptors after
+ normalization and (possibly) PCA/whitening.
+ """
+ # L2-normalize, and if desired apply PCA (followed by L2-normalization).
+ final_descriptors = tf.nn.l2_normalize(
+ descriptors, axis=1, name='l2_normalization')
+
+ if use_pca:
+ # Apply PCA, and whitening if desired.
+ final_descriptors = ApplyPcaAndWhitening(final_descriptors,
+ pca_parameters['matrix'],
+ pca_parameters['mean'],
+ pca_parameters['dim'],
+ pca_parameters['use_whitening'],
+ pca_parameters['variances'])
+
+ # Re-normalize.
+ final_descriptors = tf.nn.l2_normalize(
+ final_descriptors, axis=1, name='pca_l2_normalization')
+
+ return final_descriptors
+
+
+def DelfFeaturePostProcessing(boxes, descriptors, use_pca, pca_parameters=None):
+ """Extract DELF features from input image.
+
+ Args:
+ boxes: [N, 4] float tensor which denotes the selected receptive box. N is
+ the number of final feature points which pass through keypoint selection
+ and NMS steps.
+ descriptors: [N, input_dim] float tensor.
+ use_pca: Whether to use PCA.
+ pca_parameters: Only used if `use_pca` is True. Dict containing PCA
+ parameter tensors, with keys 'mean', 'matrix', 'dim', 'use_whitening',
+ 'variances'.
+
+ Returns:
+ locations: [N, 2] float tensor which denotes the selected keypoint
+ locations.
+ final_descriptors: [N, output_dim] float tensor with DELF descriptors after
+ normalization and (possibly) PCA/whitening.
+ """
+
+ # Get center of descriptor boxes, corresponding to feature locations.
+ locations = CalculateKeypointCenters(boxes)
+ final_descriptors = PostProcessDescriptors(descriptors, use_pca,
+ pca_parameters)
+
+ return locations, final_descriptors
diff --git a/models/research/delf/delf/python/feature_extractor_test.py b/models/research/delf/delf/python/feature_extractor_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..0caa51c4321ae30866d2c1247626843a907c5a2d
--- /dev/null
+++ b/models/research/delf/delf/python/feature_extractor_test.py
@@ -0,0 +1,75 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for DELF feature extractor."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from delf import feature_extractor
+
+
+class FeatureExtractorTest(tf.test.TestCase):
+
+ def testNormalizePixelValues(self):
+ image = tf.constant(
+ [[[3, 255, 0], [34, 12, 5]], [[45, 5, 65], [56, 77, 89]]],
+ dtype=tf.uint8)
+ normalized_image = feature_extractor.NormalizePixelValues(
+ image, pixel_value_offset=5.0, pixel_value_scale=2.0)
+ exp_normalized_image = [[[-1.0, 125.0, -2.5], [14.5, 3.5, 0.0]],
+ [[20.0, 0.0, 30.0], [25.5, 36.0, 42.0]]]
+
+ self.assertAllEqual(normalized_image, exp_normalized_image)
+
+ def testCalculateReceptiveBoxes(self):
+ boxes = feature_extractor.CalculateReceptiveBoxes(
+ height=1, width=2, rf=291, stride=32, padding=145)
+ exp_boxes = [[-145., -145., 145., 145.], [-145., -113., 145., 177.]]
+
+ self.assertAllEqual(exp_boxes, boxes)
+
+ def testCalculateKeypointCenters(self):
+ boxes = [[-10.0, 0.0, 11.0, 21.0], [-2.5, 5.0, 18.5, 26.0],
+ [45.0, -2.5, 66.0, 18.5]]
+ centers = feature_extractor.CalculateKeypointCenters(boxes)
+
+ exp_centers = [[0.5, 10.5], [8.0, 15.5], [55.5, 8.0]]
+
+ self.assertAllEqual(exp_centers, centers)
+
+ def testPcaWhitening(self):
+ data = tf.constant([[1.0, 2.0, -2.0], [-5.0, 0.0, 3.0], [-1.0, 2.0, 0.0],
+ [0.0, 4.0, -1.0]])
+ pca_matrix = tf.constant([[2.0, 0.0, -1.0], [0.0, 1.0, 1.0],
+ [-1.0, 1.0, 3.0]])
+ pca_mean = tf.constant([1.0, 2.0, 3.0])
+ output_dim = 2
+ use_whitening = True
+ pca_variances = tf.constant([4.0, 1.0])
+
+ output = feature_extractor.ApplyPcaAndWhitening(data, pca_matrix, pca_mean,
+ output_dim, use_whitening,
+ pca_variances)
+
+ exp_output = [[2.5, -5.0], [-6.0, -2.0], [-0.5, -3.0], [1.0, -2.0]]
+
+ self.assertAllEqual(exp_output, output)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/delf/delf/python/feature_io.py b/models/research/delf/delf/python/feature_io.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b68586b8543b08bf16d345a65345be7cb6d8a67
--- /dev/null
+++ b/models/research/delf/delf/python/feature_io.py
@@ -0,0 +1,196 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python interface for DelfFeatures proto.
+
+Support read and write of DelfFeatures from/to numpy arrays and file.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from delf import feature_pb2
+from delf import datum_io
+
+
+def ArraysToDelfFeatures(locations,
+ scales,
+ descriptors,
+ attention,
+ orientations=None):
+ """Converts DELF features to DelfFeatures proto.
+
+ Args:
+ locations: [N, 2] float array which denotes the selected keypoint locations.
+ N is the number of features.
+ scales: [N] float array with feature scales.
+ descriptors: [N, depth] float array with DELF descriptors.
+ attention: [N] float array with attention scores.
+ orientations: [N] float array with orientations. If None, all orientations
+ are set to zero.
+
+ Returns:
+ delf_features: DelfFeatures object.
+ """
+ num_features = len(attention)
+ assert num_features == locations.shape[0]
+ assert num_features == len(scales)
+ assert num_features == descriptors.shape[0]
+
+ if orientations is None:
+ orientations = np.zeros([num_features], dtype=np.float32)
+ else:
+ assert num_features == len(orientations)
+
+ delf_features = feature_pb2.DelfFeatures()
+ for i in range(num_features):
+ delf_feature = delf_features.feature.add()
+ delf_feature.y = locations[i, 0]
+ delf_feature.x = locations[i, 1]
+ delf_feature.scale = scales[i]
+ delf_feature.orientation = orientations[i]
+ delf_feature.strength = attention[i]
+ delf_feature.descriptor.CopyFrom(datum_io.ArrayToDatum(descriptors[i,]))
+
+ return delf_features
+
+
+def DelfFeaturesToArrays(delf_features):
+ """Converts data saved in DelfFeatures to numpy arrays.
+
+ If there are no features, the function returns four empty arrays.
+
+ Args:
+ delf_features: DelfFeatures object.
+
+ Returns:
+ locations: [N, 2] float array which denotes the selected keypoint
+ locations. N is the number of features.
+ scales: [N] float array with feature scales.
+ descriptors: [N, depth] float array with DELF descriptors.
+ attention: [N] float array with attention scores.
+ orientations: [N] float array with orientations.
+ """
+ num_features = len(delf_features.feature)
+ if num_features == 0:
+ return np.array([]), np.array([]), np.array([]), np.array([]), np.array([])
+
+ # Figure out descriptor dimensionality by parsing first one.
+ descriptor_dim = len(
+ datum_io.DatumToArray(delf_features.feature[0].descriptor))
+ locations = np.zeros([num_features, 2])
+ scales = np.zeros([num_features])
+ descriptors = np.zeros([num_features, descriptor_dim])
+ attention = np.zeros([num_features])
+ orientations = np.zeros([num_features])
+
+ for i in range(num_features):
+ delf_feature = delf_features.feature[i]
+ locations[i, 0] = delf_feature.y
+ locations[i, 1] = delf_feature.x
+ scales[i] = delf_feature.scale
+ descriptors[i,] = datum_io.DatumToArray(delf_feature.descriptor)
+ attention[i] = delf_feature.strength
+ orientations[i] = delf_feature.orientation
+
+ return locations, scales, descriptors, attention, orientations
+
+
+def SerializeToString(locations,
+ scales,
+ descriptors,
+ attention,
+ orientations=None):
+ """Converts numpy arrays to serialized DelfFeatures.
+
+ Args:
+ locations: [N, 2] float array which denotes the selected keypoint locations.
+ N is the number of features.
+ scales: [N] float array with feature scales.
+ descriptors: [N, depth] float array with DELF descriptors.
+ attention: [N] float array with attention scores.
+ orientations: [N] float array with orientations. If None, all orientations
+ are set to zero.
+
+ Returns:
+ Serialized DelfFeatures string.
+ """
+ delf_features = ArraysToDelfFeatures(locations, scales, descriptors,
+ attention, orientations)
+ return delf_features.SerializeToString()
+
+
+def ParseFromString(string):
+ """Converts serialized DelfFeatures string to numpy arrays.
+
+ Args:
+ string: Serialized DelfFeatures string.
+
+ Returns:
+ locations: [N, 2] float array which denotes the selected keypoint
+ locations. N is the number of features.
+ scales: [N] float array with feature scales.
+ descriptors: [N, depth] float array with DELF descriptors.
+ attention: [N] float array with attention scores.
+ orientations: [N] float array with orientations.
+ """
+ delf_features = feature_pb2.DelfFeatures()
+ delf_features.ParseFromString(string)
+ return DelfFeaturesToArrays(delf_features)
+
+
+def ReadFromFile(file_path):
+ """Helper function to load data from a DelfFeatures format in a file.
+
+ Args:
+ file_path: Path to file containing data.
+
+ Returns:
+ locations: [N, 2] float array which denotes the selected keypoint
+ locations. N is the number of features.
+ scales: [N] float array with feature scales.
+ descriptors: [N, depth] float array with DELF descriptors.
+ attention: [N] float array with attention scores.
+ orientations: [N] float array with orientations.
+ """
+ with tf.io.gfile.GFile(file_path, 'rb') as f:
+ return ParseFromString(f.read())
+
+
+def WriteToFile(file_path,
+ locations,
+ scales,
+ descriptors,
+ attention,
+ orientations=None):
+ """Helper function to write data to a file in DelfFeatures format.
+
+ Args:
+ file_path: Path to file that will be written.
+ locations: [N, 2] float array which denotes the selected keypoint locations.
+ N is the number of features.
+ scales: [N] float array with feature scales.
+ descriptors: [N, depth] float array with DELF descriptors.
+ attention: [N] float array with attention scores.
+ orientations: [N] float array with orientations. If None, all orientations
+ are set to zero.
+ """
+ serialized_data = SerializeToString(locations, scales, descriptors, attention,
+ orientations)
+ with tf.io.gfile.GFile(file_path, 'w') as f:
+ f.write(serialized_data)
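+
+
+# Illustrative round trip using the helpers above (array names and the file
+# path are placeholders):
+#
+#   WriteToFile('/tmp/example.delf', locations, scales, descriptors, attention,
+#               orientations)
+#   (locations, scales, descriptors, attention,
+#    orientations) = ReadFromFile('/tmp/example.delf')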
diff --git a/models/research/delf/delf/python/feature_io_test.py b/models/research/delf/delf/python/feature_io_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b68d3b241cf561de9362c84ec05e148d22ee0f2
--- /dev/null
+++ b/models/research/delf/delf/python/feature_io_test.py
@@ -0,0 +1,112 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for feature_io, the python interface of DelfFeatures."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import flags
+import numpy as np
+import tensorflow as tf
+
+from delf import feature_io
+
+FLAGS = flags.FLAGS
+
+
+def create_data():
+ """Creates data to be used in tests.
+
+ Returns:
+ locations: [N, 2] float array which denotes the selected keypoint
+ locations. N is the number of features.
+ scales: [N] float array with feature scales.
+ descriptors: [N, depth] float array with DELF descriptors.
+ attention: [N] float array with attention scores.
+ orientations: [N] float array with orientations.
+ """
+ locations = np.arange(8, dtype=np.float32).reshape(4, 2)
+ scales = np.arange(4, dtype=np.float32)
+ attention = np.arange(4, dtype=np.float32)
+ orientations = np.arange(4, dtype=np.float32)
+ descriptors = np.zeros([4, 1024])
+ descriptors[0,] = np.arange(1024)
+ descriptors[1,] = np.zeros([1024])
+ descriptors[2,] = np.ones([1024])
+ descriptors[3,] = -np.ones([1024])
+
+ return locations, scales, descriptors, attention, orientations
+
+
+class DelfFeaturesIoTest(tf.test.TestCase):
+
+ def testConversionAndBack(self):
+ locations, scales, descriptors, attention, orientations = create_data()
+
+ serialized = feature_io.SerializeToString(locations, scales, descriptors,
+ attention, orientations)
+ parsed_data = feature_io.ParseFromString(serialized)
+
+ self.assertAllEqual(locations, parsed_data[0])
+ self.assertAllEqual(scales, parsed_data[1])
+ self.assertAllEqual(descriptors, parsed_data[2])
+ self.assertAllEqual(attention, parsed_data[3])
+ self.assertAllEqual(orientations, parsed_data[4])
+
+ def testConversionAndBackNoOrientations(self):
+ locations, scales, descriptors, attention, _ = create_data()
+
+ serialized = feature_io.SerializeToString(locations, scales, descriptors,
+ attention)
+ parsed_data = feature_io.ParseFromString(serialized)
+
+ self.assertAllEqual(locations, parsed_data[0])
+ self.assertAllEqual(scales, parsed_data[1])
+ self.assertAllEqual(descriptors, parsed_data[2])
+ self.assertAllEqual(attention, parsed_data[3])
+ self.assertAllEqual(np.zeros([4]), parsed_data[4])
+
+ def testWriteAndReadToFile(self):
+ locations, scales, descriptors, attention, orientations = create_data()
+
+ filename = os.path.join(FLAGS.test_tmpdir, 'test.delf')
+ feature_io.WriteToFile(filename, locations, scales, descriptors, attention,
+ orientations)
+ data_read = feature_io.ReadFromFile(filename)
+
+ self.assertAllEqual(locations, data_read[0])
+ self.assertAllEqual(scales, data_read[1])
+ self.assertAllEqual(descriptors, data_read[2])
+ self.assertAllEqual(attention, data_read[3])
+ self.assertAllEqual(orientations, data_read[4])
+
+ def testWriteAndReadToFileEmptyFile(self):
+ filename = os.path.join(FLAGS.test_tmpdir, 'test.delf')
+ feature_io.WriteToFile(filename, np.array([]), np.array([]), np.array([]),
+ np.array([]), np.array([]))
+ data_read = feature_io.ReadFromFile(filename)
+
+ self.assertAllEqual(np.array([]), data_read[0])
+ self.assertAllEqual(np.array([]), data_read[1])
+ self.assertAllEqual(np.array([]), data_read[2])
+ self.assertAllEqual(np.array([]), data_read[3])
+ self.assertAllEqual(np.array([]), data_read[4])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/delf/delf/python/google_landmarks_dataset/README.md b/models/research/delf/delf/python/google_landmarks_dataset/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..485c1a946b5b21ddb369cc1bc8645534abbfad1e
--- /dev/null
+++ b/models/research/delf/delf/python/google_landmarks_dataset/README.md
@@ -0,0 +1,123 @@
+## GLDv2 code/models
+
+[Paper (arXiv)](https://arxiv.org/abs/2004.01804)
+
+These instructions can be used to reproduce results from the
+[GLDv2 paper](https://arxiv.org/abs/2004.01804). We present here results on the
+Revisited Oxford/Paris datasets since they are smaller and quicker to
+reproduce -- but note that a very similar procedure can be used to obtain
+results on the GLDv2 retrieval or recognition datasets.
+
+Note that this directory also contains code to compute GLDv2 metrics: see
+`compute_retrieval_metrics.py`, `compute_recognition_metrics.py` and associated
+file reading / metric computation modules.
+
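+As an illustration, the metrics scripts can be invoked as follows (the CSV
+paths below are placeholders; both files must include a header row):
+
+```bash
+# From models/research/delf/delf/python/google_landmarks_dataset
+python3 compute_retrieval_metrics.py \
+  --predictions_path /tmp/predictions.csv \
+  --solution_path /tmp/solution.csv
+```
+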
+For more details on the dataset, please refer to its
+[website](https://github.com/cvdfoundation/google-landmark).
+
+### Install DELF library
+
+To be able to use this code, please follow
+[these instructions](../../../INSTALL_INSTRUCTIONS.md) to properly install the
+DELF library.
+
+### Download Revisited Oxford/Paris datasets
+
+```bash
+mkdir -p ~/revisitop/data && cd ~/revisitop/data
+
+# Oxford dataset.
+wget http://www.robots.ox.ac.uk/~vgg/data/oxbuildings/oxbuild_images.tgz
+mkdir oxford5k_images
+tar -xvzf oxbuild_images.tgz -C oxford5k_images/
+
+# Paris dataset. Download and move all images to the same directory.
+wget http://www.robots.ox.ac.uk/~vgg/data/parisbuildings/paris_1.tgz
+wget http://www.robots.ox.ac.uk/~vgg/data/parisbuildings/paris_2.tgz
+mkdir paris6k_images_tmp
+tar -xvzf paris_1.tgz -C paris6k_images_tmp/
+tar -xvzf paris_2.tgz -C paris6k_images_tmp/
+mkdir paris6k_images
+mv paris6k_images_tmp/paris/*/*.jpg paris6k_images/
+
+# Revisited annotations.
+wget http://cmp.felk.cvut.cz/revisitop/data/datasets/roxford5k/gnd_roxford5k.mat
+wget http://cmp.felk.cvut.cz/revisitop/data/datasets/rparis6k/gnd_rparis6k.mat
+```
+
+### Download model
+
+```bash
+# From models/research/delf/delf/python/google_landmarks_dataset
+mkdir parameters && cd parameters
+
+# RN101-ArcFace model trained on GLDv2-clean.
+wget https://storage.googleapis.com/delf/rn101_af_gldv2clean_20200521.tar.gz
+tar -xvzf rn101_af_gldv2clean_20200521.tar.gz
+```
+
+### Feature extraction
+
+We present here commands for extraction on `roxford5k`. To extract on `rparis6k`
+instead, please edit the arguments accordingly (especially the
+`dataset_file_path` argument).
+
+#### Query feature extraction
+
+In the Revisited Oxford/Paris experimental protocol, query images must be
+cropped before feature extraction (this is done in the `extract_features`
+script when setting `image_set=query`). Note that this is specific to these
+datasets, and not required for the GLDv2 retrieval/recognition datasets.
+
+Run query feature extraction as follows:
+
+```bash
+# From models/research/delf/delf/python/google_landmarks_dataset
+python3 ../delg/extract_features.py \
+ --delf_config_path rn101_af_gldv2clean_config.pbtxt \
+ --dataset_file_path ~/revisitop/data/gnd_roxford5k.mat \
+ --images_dir ~/revisitop/data/oxford5k_images \
+ --image_set query \
+ --output_features_dir ~/revisitop/data/oxford5k_features/query
+```
+
+#### Index feature extraction
+
+Run index feature extraction as follows:
+
+```bash
+# From models/research/delf/delf/python/google_landmarks_dataset
+python3 ../delg/extract_features.py \
+ --delf_config_path rn101_af_gldv2clean_config.pbtxt \
+ --dataset_file_path ~/revisitop/data/gnd_roxford5k.mat \
+ --images_dir ~/revisitop/data/oxford5k_images \
+ --image_set index \
+ --output_features_dir ~/revisitop/data/oxford5k_features/index
+```
+
+### Perform retrieval
+
+To run retrieval on `roxford5k`, the following command can be used:
+
+```bash
+# From models/research/delf/delf/python/google_landmarks_dataset
+python3 ../delg/perform_retrieval.py \
+ --dataset_file_path ~/revisitop/data/gnd_roxford5k.mat \
+ --query_features_dir ~/revisitop/data/oxford5k_features/query \
+ --index_features_dir ~/revisitop/data/oxford5k_features/index \
+ --output_dir ~/revisitop/results/oxford5k
+```
+
+A file named `metrics.txt` will be written to the path given in `output_dir`.
+Its contents should look approximately like:
+
+```
+hard
+ mAP=55.54
+ mP@k[ 1 5 10] [88.57 80.86 70.14]
+ mR@k[ 1 5 10] [19.46 33.65 42.44]
+medium
+ mAP=76.23
+ mP@k[ 1 5 10] [95.71 92.86 90.43]
+ mR@k[ 1 5 10] [10.17 25.96 35.29]
+```
diff --git a/models/research/delf/delf/python/google_landmarks_dataset/compute_recognition_metrics.py b/models/research/delf/delf/python/google_landmarks_dataset/compute_recognition_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..f80cf47de7487d4cd584c969d994f7d3f1135cae
--- /dev/null
+++ b/models/research/delf/delf/python/google_landmarks_dataset/compute_recognition_metrics.py
@@ -0,0 +1,99 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Computes metrics for Google Landmarks Recognition dataset predictions.
+
+Metrics are written to stdout.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+from tensorflow.python.platform import app
+from delf.python.google_landmarks_dataset import dataset_file_io
+from delf.python.google_landmarks_dataset import metrics
+
+cmd_args = None
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise RuntimeError('Too many command-line arguments.')
+
+ # Read solution.
+ print('Reading solution...')
+ public_solution, private_solution, ignored_ids = dataset_file_io.ReadSolution(
+ cmd_args.solution_path, dataset_file_io.RECOGNITION_TASK_ID)
+ print('done!')
+
+ # Read predictions.
+ print('Reading predictions...')
+ public_predictions, private_predictions = dataset_file_io.ReadPredictions(
+ cmd_args.predictions_path, set(public_solution.keys()),
+ set(private_solution.keys()), set(ignored_ids),
+ dataset_file_io.RECOGNITION_TASK_ID)
+ print('done!')
+
+ # Global Average Precision.
+ print('**********************************************')
+ print('(Public) Global Average Precision: %f' %
+ metrics.GlobalAveragePrecision(public_predictions, public_solution))
+ print('(Private) Global Average Precision: %f' %
+ metrics.GlobalAveragePrecision(private_predictions, private_solution))
+
+ # Global Average Precision ignoring non-landmark queries.
+ print('**********************************************')
+ print(
+ '(Public) Global Average Precision ignoring non-landmark queries: %f' %
+ metrics.GlobalAveragePrecision(
+ public_predictions, public_solution, ignore_non_gt_test_images=True))
+ print(
+ '(Private) Global Average Precision ignoring non-landmark queries: %f' %
+ metrics.GlobalAveragePrecision(
+ private_predictions, private_solution,
+ ignore_non_gt_test_images=True))
+
+ # Top-1 accuracy.
+ print('**********************************************')
+ print('(Public) Top-1 accuracy: %.2f' %
+ (100.0 * metrics.Top1Accuracy(public_predictions, public_solution)))
+ print('(Private) Top-1 accuracy: %.2f' %
+ (100.0 * metrics.Top1Accuracy(private_predictions, private_solution)))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.register('type', 'bool', lambda v: v.lower() == 'true')
+ parser.add_argument(
+ '--predictions_path',
+ type=str,
+ default='/tmp/predictions.csv',
+ help="""
+ Path to CSV predictions file, formatted with columns 'id,landmarks' (the
+ file should include a header).
+ """)
+ parser.add_argument(
+ '--solution_path',
+ type=str,
+ default='/tmp/solution.csv',
+ help="""
+ Path to CSV solution file, formatted with columns 'id,landmarks,Usage'
+ (the file should include a header).
+ """)
+ cmd_args, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/models/research/delf/delf/python/google_landmarks_dataset/compute_retrieval_metrics.py b/models/research/delf/delf/python/google_landmarks_dataset/compute_retrieval_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..adcee356e5d64d094236cda9656c86164c24faf8
--- /dev/null
+++ b/models/research/delf/delf/python/google_landmarks_dataset/compute_retrieval_metrics.py
@@ -0,0 +1,106 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Computes metrics for Google Landmarks Retrieval dataset predictions.
+
+Metrics are written to stdout.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+from tensorflow.python.platform import app
+from delf.python.google_landmarks_dataset import dataset_file_io
+from delf.python.google_landmarks_dataset import metrics
+
+cmd_args = None
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise RuntimeError('Too many command-line arguments.')
+
+ # Read solution.
+ print('Reading solution...')
+ public_solution, private_solution, ignored_ids = dataset_file_io.ReadSolution(
+ cmd_args.solution_path, dataset_file_io.RETRIEVAL_TASK_ID)
+ print('done!')
+
+ # Read predictions.
+ print('Reading predictions...')
+ public_predictions, private_predictions = dataset_file_io.ReadPredictions(
+ cmd_args.predictions_path, set(public_solution.keys()),
+ set(private_solution.keys()), set(ignored_ids),
+ dataset_file_io.RETRIEVAL_TASK_ID)
+ print('done!')
+
+ # Mean average precision.
+ print('**********************************************')
+ print('(Public) Mean Average Precision: %f' %
+ metrics.MeanAveragePrecision(public_predictions, public_solution))
+ print('(Private) Mean Average Precision: %f' %
+ metrics.MeanAveragePrecision(private_predictions, private_solution))
+
+ # Mean precision@k.
+ print('**********************************************')
+ public_precisions = 100.0 * metrics.MeanPrecisions(public_predictions,
+ public_solution)
+ private_precisions = 100.0 * metrics.MeanPrecisions(private_predictions,
+ private_solution)
+ print('(Public) Mean precisions: P@1: %.2f, P@5: %.2f, P@10: %.2f, '
+ 'P@50: %.2f, P@100: %.2f' %
+ (public_precisions[0], public_precisions[4], public_precisions[9],
+ public_precisions[49], public_precisions[99]))
+ print('(Private) Mean precisions: P@1: %.2f, P@5: %.2f, P@10: %.2f, '
+ 'P@50: %.2f, P@100: %.2f' %
+ (private_precisions[0], private_precisions[4], private_precisions[9],
+ private_precisions[49], private_precisions[99]))
+
+ # Mean/median position of first correct.
+ print('**********************************************')
+ public_mean_position, public_median_position = metrics.MeanMedianPosition(
+ public_predictions, public_solution)
+ private_mean_position, private_median_position = metrics.MeanMedianPosition(
+ private_predictions, private_solution)
+ print('(Public) Mean position: %.2f, median position: %.2f' %
+ (public_mean_position, public_median_position))
+ print('(Private) Mean position: %.2f, median position: %.2f' %
+ (private_mean_position, private_median_position))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.register('type', 'bool', lambda v: v.lower() == 'true')
+ parser.add_argument(
+ '--predictions_path',
+ type=str,
+ default='/tmp/predictions.csv',
+ help="""
+ Path to CSV predictions file, formatted with columns 'id,images' (the
+ file should include a header).
+ """)
+ parser.add_argument(
+ '--solution_path',
+ type=str,
+ default='/tmp/solution.csv',
+ help="""
+ Path to CSV solution file, formatted with columns 'id,images,Usage'
+ (the file should include a header).
+ """)
+ cmd_args, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/models/research/delf/delf/python/google_landmarks_dataset/dataset_file_io.py b/models/research/delf/delf/python/google_landmarks_dataset/dataset_file_io.py
new file mode 100644
index 0000000000000000000000000000000000000000..93f2785d78f03b5b112bbba635b4778f2e9b9a08
--- /dev/null
+++ b/models/research/delf/delf/python/google_landmarks_dataset/dataset_file_io.py
@@ -0,0 +1,159 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""IO module for files from Landmark recognition/retrieval challenges."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import csv
+
+import tensorflow as tf
+
+RECOGNITION_TASK_ID = 'recognition'
+RETRIEVAL_TASK_ID = 'retrieval'
+
+
+def ReadSolution(file_path, task):
+ """Reads solution from file, for a given task.
+
+ Args:
+ file_path: Path to CSV file with solution. File contains a header.
+ task: Type of challenge task. Supported values: 'recognition', 'retrieval'.
+
+ Returns:
+ public_solution: Dict mapping test image ID to list of ground-truth IDs, for
+ the Public subset of test images. If `task` == 'recognition', the IDs are
+ integers corresponding to landmark IDs. If `task` == 'retrieval', the IDs
+ are strings corresponding to index image IDs.
+ private_solution: Same as `public_solution`, but for the private subset of
+ test images.
+ ignored_ids: List of test images that are ignored in scoring.
+
+ Raises:
+ ValueError: If Usage field is not Public, Private or Ignored; or if `task`
+ is not supported.
+ """
+ public_solution = {}
+ private_solution = {}
+ ignored_ids = []
+ with tf.io.gfile.GFile(file_path, 'r') as csv_file:
+ reader = csv.reader(csv_file)
+ next(reader, None) # Skip header.
+ for row in reader:
+ test_id = row[0]
+ if row[2] == 'Ignored':
+ ignored_ids.append(test_id)
+ else:
+ ground_truth_ids = []
+ if task == RECOGNITION_TASK_ID:
+ if row[1]:
+ for landmark_id in row[1].split(' '):
+ ground_truth_ids.append(int(landmark_id))
+ elif task == RETRIEVAL_TASK_ID:
+ for image_id in row[1].split(' '):
+ ground_truth_ids.append(image_id)
+ else:
+ raise ValueError('Unrecognized task: %s' % task)
+
+ if row[2] == 'Public':
+ public_solution[test_id] = ground_truth_ids
+ elif row[2] == 'Private':
+ private_solution[test_id] = ground_truth_ids
+ else:
+ raise ValueError('Test image %s has unrecognized Usage tag %s' %
+ (row[0], row[2]))
+
+ return public_solution, private_solution, ignored_ids
+
+
+def ReadPredictions(file_path, public_ids, private_ids, ignored_ids, task):
+ """Reads predictions from file, for a given task.
+
+ Args:
+ file_path: Path to CSV file with predictions. File contains a header.
+ public_ids: Set (or list) of test image IDs in Public subset of test images.
+ private_ids: Same as `public_ids`, but for the private subset of test
+ images.
+ ignored_ids: Set (or list) of test image IDs that are ignored in scoring and
+ are associated to no ground-truth.
+ task: Type of challenge task. Supported values: 'recognition', 'retrieval'.
+
+ Returns:
+ public_predictions: Dict mapping test image ID to prediction, for the Public
+ subset of test images. If `task` == 'recognition', the prediction is a
+ dict with keys 'class' (integer) and 'score' (float). If `task` ==
+ 'retrieval', the prediction is a list of strings corresponding to index
+ image IDs.
+ private_predictions: Same as `public_predictions`, but for the private
+ subset of test images.
+
+ Raises:
+ ValueError:
+ - If test image ID is unrecognized/repeated;
+ - If `task` is not supported;
+ - If prediction is malformed.
+ """
+ public_predictions = {}
+ private_predictions = {}
+ with tf.io.gfile.GFile(file_path, 'r') as csv_file:
+ reader = csv.reader(csv_file)
+ next(reader, None) # Skip header.
+ for row in reader:
+ # Skip row if empty.
+ if not row:
+ continue
+
+ test_id = row[0]
+
+ # Makes sure this query has not yet been seen.
+ if test_id in public_predictions:
+ raise ValueError('Test image %s is repeated.' % test_id)
+ if test_id in private_predictions:
+        raise ValueError('Test image %s is repeated.' % test_id)
+
+ # If ignored, skip it.
+ if test_id in ignored_ids:
+ continue
+
+ # Only parse result if there is a prediction.
+ if row[1]:
+ prediction_split = row[1].split(' ')
+ # Remove empty spaces at end (if any).
+ if not prediction_split[-1]:
+ prediction_split = prediction_split[:-1]
+
+ if task == RECOGNITION_TASK_ID:
+ if len(prediction_split) != 2:
+ raise ValueError('Prediction is malformed: there should only be 2 '
+ 'elements in second column, but found %d for test '
+ 'image %s' % (len(prediction_split), test_id))
+
+ landmark_id = int(prediction_split[0])
+ score = float(prediction_split[1])
+ prediction_entry = {'class': landmark_id, 'score': score}
+ elif task == RETRIEVAL_TASK_ID:
+ prediction_entry = prediction_split
+ else:
+ raise ValueError('Unrecognized task: %s' % task)
+
+ if test_id in public_ids:
+ public_predictions[test_id] = prediction_entry
+ elif test_id in private_ids:
+ private_predictions[test_id] = prediction_entry
+ else:
+ raise ValueError('test_id %s is unrecognized' % test_id)
+
+ return public_predictions, private_predictions
diff --git a/models/research/delf/delf/python/google_landmarks_dataset/dataset_file_io_test.py b/models/research/delf/delf/python/google_landmarks_dataset/dataset_file_io_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..0101d989fba85e487842e65b3b1aec4e728c101c
--- /dev/null
+++ b/models/research/delf/delf/python/google_landmarks_dataset/dataset_file_io_test.py
@@ -0,0 +1,170 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for dataset file IO module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import flags
+import tensorflow as tf
+
+from delf.python.google_landmarks_dataset import dataset_file_io
+
+FLAGS = flags.FLAGS
+
+
+class DatasetFileIoTest(tf.test.TestCase):
+
+ def testReadRecognitionSolutionWorks(self):
+ # Define inputs.
+ file_path = os.path.join(FLAGS.test_tmpdir,
+ 'recognition_solution.csv')
+ with tf.io.gfile.GFile(file_path, 'w') as f:
+ f.write('id,landmarks,Usage\n')
+ f.write('0123456789abcdef,0 12,Public\n')
+ f.write('0223456789abcdef,,Public\n')
+ f.write('0323456789abcdef,100,Ignored\n')
+ f.write('0423456789abcdef,1,Private\n')
+ f.write('0523456789abcdef,,Ignored\n')
+
+ # Run tested function.
+ (public_solution, private_solution,
+ ignored_ids) = dataset_file_io.ReadSolution(
+ file_path, dataset_file_io.RECOGNITION_TASK_ID)
+
+ # Define expected results.
+ expected_public_solution = {
+ '0123456789abcdef': [0, 12],
+ '0223456789abcdef': []
+ }
+ expected_private_solution = {
+ '0423456789abcdef': [1],
+ }
+ expected_ignored_ids = ['0323456789abcdef', '0523456789abcdef']
+
+ # Compare actual and expected results.
+ self.assertEqual(public_solution, expected_public_solution)
+ self.assertEqual(private_solution, expected_private_solution)
+ self.assertEqual(ignored_ids, expected_ignored_ids)
+
+ def testReadRetrievalSolutionWorks(self):
+ # Define inputs.
+ file_path = os.path.join(FLAGS.test_tmpdir,
+ 'retrieval_solution.csv')
+ with tf.io.gfile.GFile(file_path, 'w') as f:
+ f.write('id,images,Usage\n')
+ f.write('0123456789abcdef,None,Ignored\n')
+ f.write('0223456789abcdef,fedcba9876543210 fedcba9876543200,Public\n')
+ f.write('0323456789abcdef,fedcba9876543200,Private\n')
+ f.write('0423456789abcdef,fedcba9876543220,Private\n')
+ f.write('0523456789abcdef,None,Ignored\n')
+
+ # Run tested function.
+ (public_solution, private_solution,
+ ignored_ids) = dataset_file_io.ReadSolution(
+ file_path, dataset_file_io.RETRIEVAL_TASK_ID)
+
+ # Define expected results.
+ expected_public_solution = {
+ '0223456789abcdef': ['fedcba9876543210', 'fedcba9876543200'],
+ }
+ expected_private_solution = {
+ '0323456789abcdef': ['fedcba9876543200'],
+ '0423456789abcdef': ['fedcba9876543220'],
+ }
+ expected_ignored_ids = ['0123456789abcdef', '0523456789abcdef']
+
+ # Compare actual and expected results.
+ self.assertEqual(public_solution, expected_public_solution)
+ self.assertEqual(private_solution, expected_private_solution)
+ self.assertEqual(ignored_ids, expected_ignored_ids)
+
+ def testReadRecognitionPredictionsWorks(self):
+ # Define inputs.
+ file_path = os.path.join(FLAGS.test_tmpdir,
+ 'recognition_predictions.csv')
+ with tf.io.gfile.GFile(file_path, 'w') as f:
+ f.write('id,landmarks\n')
+ f.write('0123456789abcdef,12 0.1 \n')
+ f.write('0423456789abcdef,0 19.0\n')
+ f.write('0223456789abcdef,\n')
+ f.write('\n')
+ f.write('0523456789abcdef,14 0.01\n')
+ public_ids = ['0123456789abcdef', '0223456789abcdef']
+ private_ids = ['0423456789abcdef']
+ ignored_ids = ['0323456789abcdef', '0523456789abcdef']
+
+ # Run tested function.
+ public_predictions, private_predictions = dataset_file_io.ReadPredictions(
+ file_path, public_ids, private_ids, ignored_ids,
+ dataset_file_io.RECOGNITION_TASK_ID)
+
+ # Define expected results.
+ expected_public_predictions = {
+ '0123456789abcdef': {
+ 'class': 12,
+ 'score': 0.1
+ }
+ }
+ expected_private_predictions = {
+ '0423456789abcdef': {
+ 'class': 0,
+ 'score': 19.0
+ }
+ }
+
+ # Compare actual and expected results.
+ self.assertEqual(public_predictions, expected_public_predictions)
+ self.assertEqual(private_predictions, expected_private_predictions)
+
+ def testReadRetrievalPredictionsWorks(self):
+ # Define inputs.
+ file_path = os.path.join(FLAGS.test_tmpdir,
+ 'retrieval_predictions.csv')
+ with tf.io.gfile.GFile(file_path, 'w') as f:
+ f.write('id,images\n')
+ f.write('0123456789abcdef,fedcba9876543250 \n')
+ f.write('0423456789abcdef,fedcba9876543260\n')
+ f.write('0223456789abcdef,fedcba9876543210 fedcba9876543200 '
+ 'fedcba9876543220\n')
+ f.write('\n')
+ f.write('0523456789abcdef,\n')
+ public_ids = ['0223456789abcdef']
+ private_ids = ['0323456789abcdef', '0423456789abcdef']
+ ignored_ids = ['0123456789abcdef', '0523456789abcdef']
+
+ # Run tested function.
+ public_predictions, private_predictions = dataset_file_io.ReadPredictions(
+ file_path, public_ids, private_ids, ignored_ids,
+ dataset_file_io.RETRIEVAL_TASK_ID)
+
+ # Define expected results.
+ expected_public_predictions = {
+ '0223456789abcdef': [
+ 'fedcba9876543210', 'fedcba9876543200', 'fedcba9876543220'
+ ]
+ }
+ expected_private_predictions = {'0423456789abcdef': ['fedcba9876543260']}
+
+ # Compare actual and expected results.
+ self.assertEqual(public_predictions, expected_public_predictions)
+ self.assertEqual(private_predictions, expected_private_predictions)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/delf/delf/python/google_landmarks_dataset/metrics.py b/models/research/delf/delf/python/google_landmarks_dataset/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..1516be9d8569cecfe470d9ec98ce273a72cce84f
--- /dev/null
+++ b/models/research/delf/delf/python/google_landmarks_dataset/metrics.py
@@ -0,0 +1,254 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python module to compute metrics for Google Landmarks dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+
+def _CountPositives(solution):
+ """Counts number of test images with non-empty ground-truth in `solution`.
+
+ Args:
+ solution: Dict mapping test image ID to list of ground-truth IDs.
+
+ Returns:
+ count: Number of test images with non-empty ground-truth.
+ """
+ count = 0
+ for v in solution.values():
+ if v:
+ count += 1
+
+ return count
+
+
+def GlobalAveragePrecision(predictions,
+ recognition_solution,
+ ignore_non_gt_test_images=False):
+ """Computes global average precision for recognition prediction.
+
+ Args:
+ predictions: Dict mapping test image ID to a dict with keys 'class'
+ (integer) and 'score' (float).
+ recognition_solution: Dict mapping test image ID to list of ground-truth
+ landmark IDs.
+ ignore_non_gt_test_images: If True, ignore test images which do not have
+ associated ground-truth landmark IDs. For the Google Landmark Recognition
+ challenge, this should be set to False.
+
+ Returns:
+ gap: Global average precision score (float).
+ """
+ # Compute number of expected results.
+ num_positives = _CountPositives(recognition_solution)
+
+ gap = 0.0
+ total_predictions = 0
+ correct_predictions = 0
+
+ # Sort predictions according to Kaggle's convention:
+ # - first by score (descending);
+ # - then by key (ascending);
+ # - then by class (ascending).
+ sorted_predictions_by_key_class = sorted(
+ predictions.items(), key=lambda item: (item[0], item[1]['class']))
+ sorted_predictions = sorted(
+ sorted_predictions_by_key_class,
+ key=lambda item: item[1]['score'],
+ reverse=True)
+
+ # Loop over sorted predictions (descending order) and compute GAPs.
+ for key, prediction in sorted_predictions:
+ if ignore_non_gt_test_images and not recognition_solution[key]:
+ continue
+
+ total_predictions += 1
+ if prediction['class'] in recognition_solution[key]:
+ correct_predictions += 1
+ gap += correct_predictions / total_predictions
+
+ gap /= num_positives
+
+ return gap
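+
+# Illustrative example (hypothetical values, not from the dataset): with
+# solution = {'a': [1], 'b': [3], 'c': [5]} and predictions =
+# {'a': {'class': 1, 'score': 0.9}, 'b': {'class': 2, 'score': 0.5}},
+# the predictions sorted by score are ['a', 'b']; only 'a' is correct, at
+# rank 1, so GAP = (1/1) / num_positives = 1 / 3 (about 0.333).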
+
+
+def Top1Accuracy(predictions, recognition_solution):
+ """Computes top-1 accuracy for recognition prediction.
+
+ Note that test images without ground-truth are ignored.
+
+ Args:
+ predictions: Dict mapping test image ID to a dict with keys 'class'
+ (integer) and 'score' (float).
+ recognition_solution: Dict mapping test image ID to list of ground-truth
+ landmark IDs.
+
+ Returns:
+ accuracy: Top-1 accuracy (float).
+ """
+  # Loop over test images in the solution. For each image with at least one
+  # ground-truth class label, check whether the prediction is correct.
+ num_correct_predictions = 0
+ num_test_images_with_ground_truth = 0
+ for key, ground_truth in recognition_solution.items():
+ if ground_truth:
+ num_test_images_with_ground_truth += 1
+ if key in predictions:
+ if predictions[key]['class'] in ground_truth:
+ num_correct_predictions += 1
+
+ return num_correct_predictions / num_test_images_with_ground_truth
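+
+# Illustrative example (hypothetical values): with solution
+# {'a': [1], 'b': [3], 'c': [5]} and predictions {'a': {'class': 1, 'score': 0.9},
+# 'b': {'class': 2, 'score': 0.5}}, only 'a' is predicted correctly among the
+# three images with ground-truth, so top-1 accuracy = 1 / 3.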
+
+
+def MeanAveragePrecision(predictions, retrieval_solution, max_predictions=100):
+ """Computes mean average precision for retrieval prediction.
+
+ Args:
+ predictions: Dict mapping test image ID to a list of strings corresponding
+ to index image IDs.
+ retrieval_solution: Dict mapping test image ID to list of ground-truth image
+ IDs.
+ max_predictions: Maximum number of predictions per query to take into
+ account. For the Google Landmark Retrieval challenge, this should be set
+ to 100.
+
+ Returns:
+ mean_ap: Mean average precision score (float).
+
+ Raises:
+ ValueError: If a test image in `predictions` is not included in
+      `retrieval_solution`.
+ """
+ # Compute number of test images.
+ num_test_images = len(retrieval_solution.keys())
+
+ # Loop over predictions for each query and compute mAP.
+ mean_ap = 0.0
+ for key, prediction in predictions.items():
+ if key not in retrieval_solution:
+ raise ValueError('Test image %s is not part of retrieval_solution' % key)
+
+ # Loop over predicted images, keeping track of those which were already
+ # used (duplicates are skipped).
+ ap = 0.0
+ already_predicted = set()
+ num_expected_retrieved = min(len(retrieval_solution[key]), max_predictions)
+ num_correct = 0
+ for i in range(min(len(prediction), max_predictions)):
+ if prediction[i] not in already_predicted:
+ if prediction[i] in retrieval_solution[key]:
+ num_correct += 1
+ ap += num_correct / (i + 1)
+ already_predicted.add(prediction[i])
+
+ ap /= num_expected_retrieved
+ mean_ap += ap
+
+ mean_ap /= num_test_images
+
+ return mean_ap
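+
+# Illustrative example (hypothetical values): with retrieval_solution =
+# {'q1': ['x'], 'q2': ['y']} and predictions = {'q1': ['z', 'x']}, query 'q1'
+# retrieves its single relevant image at rank 2, giving AP = (1/2) / 1 = 0.5;
+# 'q2' has no predictions and contributes 0, so mAP = (0.5 + 0) / 2 = 0.25.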
+
+
+def MeanPrecisions(predictions, retrieval_solution, max_predictions=100):
+ """Computes mean precisions for retrieval prediction.
+
+ Args:
+ predictions: Dict mapping test image ID to a list of strings corresponding
+ to index image IDs.
+ retrieval_solution: Dict mapping test image ID to list of ground-truth image
+ IDs.
+ max_predictions: Maximum number of predictions per query to take into
+ account.
+
+ Returns:
+ mean_precisions: NumPy array with mean precisions at ranks 1 through
+ `max_predictions`.
+
+ Raises:
+ ValueError: If a test image in `predictions` is not included in
+      `retrieval_solution`.
+ """
+ # Compute number of test images.
+ num_test_images = len(retrieval_solution.keys())
+
+ # Loop over predictions for each query and compute precisions@k.
+ precisions = np.zeros((num_test_images, max_predictions))
+ count_test_images = 0
+ for key, prediction in predictions.items():
+ if key not in retrieval_solution:
+ raise ValueError('Test image %s is not part of retrieval_solution' % key)
+
+ # Loop over predicted images, keeping track of those which were already
+ # used (duplicates are skipped).
+ already_predicted = set()
+ num_correct = 0
+ for i in range(max_predictions):
+ if i < len(prediction):
+ if prediction[i] not in already_predicted:
+ if prediction[i] in retrieval_solution[key]:
+ num_correct += 1
+ already_predicted.add(prediction[i])
+ precisions[count_test_images, i] = num_correct / (i + 1)
+ count_test_images += 1
+
+ mean_precisions = np.mean(precisions, axis=0)
+
+ return mean_precisions
+
+
+def MeanMedianPosition(predictions, retrieval_solution, max_predictions=100):
+ """Computes mean and median positions of first correct image.
+
+ Args:
+ predictions: Dict mapping test image ID to a list of strings corresponding
+ to index image IDs.
+ retrieval_solution: Dict mapping test image ID to list of ground-truth image
+ IDs.
+ max_predictions: Maximum number of predictions per query to take into
+ account.
+
+ Returns:
+ mean_position: Float.
+ median_position: Float.
+
+ Raises:
+ ValueError: If a test image in `predictions` is not included in
+      `retrieval_solution`.
+ """
+ # Compute number of test images.
+ num_test_images = len(retrieval_solution.keys())
+
+ # Loop over predictions for each query to find first correct ranked image.
+ positions = (max_predictions + 1) * np.ones((num_test_images))
+ count_test_images = 0
+ for key, prediction in predictions.items():
+ if key not in retrieval_solution:
+ raise ValueError('Test image %s is not part of retrieval_solution' % key)
+
+ for i in range(min(len(prediction), max_predictions)):
+ if prediction[i] in retrieval_solution[key]:
+ positions[count_test_images] = i + 1
+ break
+
+ count_test_images += 1
+
+ mean_position = np.mean(positions)
+ median_position = np.median(positions)
+
+ return mean_position, median_position
diff --git a/models/research/delf/delf/python/google_landmarks_dataset/metrics_test.py b/models/research/delf/delf/python/google_landmarks_dataset/metrics_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..50838cae2b5bfaa8f6f0c5cbfab2a07aa20b7c52
--- /dev/null
+++ b/models/research/delf/delf/python/google_landmarks_dataset/metrics_test.py
@@ -0,0 +1,219 @@
+# Copyright 2019 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Google Landmarks dataset metric computation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from delf.python.google_landmarks_dataset import metrics
+
+
+def _CreateRecognitionSolution():
+ """Creates recognition solution to be used in tests.
+
+ Returns:
+ solution: Dict mapping test image ID to list of ground-truth landmark IDs.
+ """
+ return {
+ '0123456789abcdef': [0, 12],
+ '0223456789abcdef': [100, 200, 300],
+ '0323456789abcdef': [1],
+ '0423456789abcdef': [],
+ '0523456789abcdef': [],
+ }
+
+
+def _CreateRecognitionPredictions():
+ """Creates recognition predictions to be used in tests.
+
+ Returns:
+ predictions: Dict mapping test image ID to a dict with keys 'class'
+ (integer) and 'score' (float).
+ """
+ return {
+ '0223456789abcdef': {
+ 'class': 0,
+ 'score': 0.01
+ },
+ '0323456789abcdef': {
+ 'class': 1,
+ 'score': 10.0
+ },
+ '0423456789abcdef': {
+ 'class': 150,
+ 'score': 15.0
+ },
+ }
+
+
+def _CreateRetrievalSolution():
+ """Creates retrieval solution to be used in tests.
+
+ Returns:
+ solution: Dict mapping test image ID to list of ground-truth image IDs.
+ """
+ return {
+ '0123456789abcdef': ['fedcba9876543210', 'fedcba9876543220'],
+ '0223456789abcdef': ['fedcba9876543210'],
+ '0323456789abcdef': [
+ 'fedcba9876543230', 'fedcba9876543240', 'fedcba9876543250'
+ ],
+ '0423456789abcdef': ['fedcba9876543230'],
+ }
+
+
+def _CreateRetrievalPredictions():
+ """Creates retrieval predictions to be used in tests.
+
+ Returns:
+ predictions: Dict mapping test image ID to a list with predicted index image
+ ids.
+ """
+ return {
+ '0223456789abcdef': ['fedcba9876543200', 'fedcba9876543210'],
+ '0323456789abcdef': ['fedcba9876543240'],
+ '0423456789abcdef': ['fedcba9876543230', 'fedcba9876543240'],
+ }
+
+
+class MetricsTest(tf.test.TestCase):
+
+ def testGlobalAveragePrecisionWorks(self):
+ # Define input.
+ predictions = _CreateRecognitionPredictions()
+ solution = _CreateRecognitionSolution()
+
+ # Run tested function.
+ gap = metrics.GlobalAveragePrecision(predictions, solution)
+
+ # Define expected results.
+ expected_gap = 0.166667
+
+ # Compare actual and expected results.
+ self.assertAllClose(gap, expected_gap)
+
+ def testGlobalAveragePrecisionIgnoreNonGroundTruthWorks(self):
+ # Define input.
+ predictions = _CreateRecognitionPredictions()
+ solution = _CreateRecognitionSolution()
+
+ # Run tested function.
+ gap = metrics.GlobalAveragePrecision(
+ predictions, solution, ignore_non_gt_test_images=True)
+
+ # Define expected results.
+ expected_gap = 0.333333
+
+ # Compare actual and expected results.
+ self.assertAllClose(gap, expected_gap)
+
+ def testTop1AccuracyWorks(self):
+ # Define input.
+ predictions = _CreateRecognitionPredictions()
+ solution = _CreateRecognitionSolution()
+
+ # Run tested function.
+ accuracy = metrics.Top1Accuracy(predictions, solution)
+
+ # Define expected results.
+ expected_accuracy = 0.333333
+
+ # Compare actual and expected results.
+ self.assertAllClose(accuracy, expected_accuracy)
+
+ def testMeanAveragePrecisionWorks(self):
+ # Define input.
+ predictions = _CreateRetrievalPredictions()
+ solution = _CreateRetrievalSolution()
+
+ # Run tested function.
+ mean_ap = metrics.MeanAveragePrecision(predictions, solution)
+
+ # Define expected results.
+ expected_mean_ap = 0.458333
+
+ # Compare actual and expected results.
+ self.assertAllClose(mean_ap, expected_mean_ap)
+
+ def testMeanAveragePrecisionMaxPredictionsWorks(self):
+ # Define input.
+ predictions = _CreateRetrievalPredictions()
+ solution = _CreateRetrievalSolution()
+
+ # Run tested function.
+ mean_ap = metrics.MeanAveragePrecision(
+ predictions, solution, max_predictions=1)
+
+ # Define expected results.
+ expected_mean_ap = 0.5
+
+ # Compare actual and expected results.
+ self.assertAllClose(mean_ap, expected_mean_ap)
+
+ def testMeanPrecisionsWorks(self):
+ # Define input.
+ predictions = _CreateRetrievalPredictions()
+ solution = _CreateRetrievalSolution()
+
+ # Run tested function.
+ mean_precisions = metrics.MeanPrecisions(
+ predictions, solution, max_predictions=2)
+
+ # Define expected results.
+ expected_mean_precisions = [0.5, 0.375]
+
+ # Compare actual and expected results.
+ self.assertAllClose(mean_precisions, expected_mean_precisions)
+
+ def testMeanMedianPositionWorks(self):
+ # Define input.
+ predictions = _CreateRetrievalPredictions()
+ solution = _CreateRetrievalSolution()
+
+ # Run tested function.
+ mean_position, median_position = metrics.MeanMedianPosition(
+ predictions, solution)
+
+ # Define expected results.
+ expected_mean_position = 26.25
+ expected_median_position = 1.5
+
+ # Compare actual and expected results.
+ self.assertAllClose(mean_position, expected_mean_position)
+ self.assertAllClose(median_position, expected_median_position)
+
+ def testMeanMedianPositionMaxPredictionsWorks(self):
+ # Define input.
+ predictions = _CreateRetrievalPredictions()
+ solution = _CreateRetrievalSolution()
+
+ # Run tested function.
+ mean_position, median_position = metrics.MeanMedianPosition(
+ predictions, solution, max_predictions=1)
+
+ # Define expected results.
+ expected_mean_position = 1.5
+ expected_median_position = 1.5
+
+ # Compare actual and expected results.
+ self.assertAllClose(mean_position, expected_mean_position)
+ self.assertAllClose(median_position, expected_median_position)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/delf/delf/python/google_landmarks_dataset/rn101_af_gldv2clean_config.pbtxt b/models/research/delf/delf/python/google_landmarks_dataset/rn101_af_gldv2clean_config.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..992cb0fd142ba8c6d89c763d5e323a0d33e7a3a5
--- /dev/null
+++ b/models/research/delf/delf/python/google_landmarks_dataset/rn101_af_gldv2clean_config.pbtxt
@@ -0,0 +1,10 @@
+use_local_features: false
+use_global_features: true
+model_path: "parameters/rn101_af_gldv2clean_20200521"
+image_scales: 0.70710677
+image_scales: 1.0
+image_scales: 1.4142135
+delf_global_config {
+ use_pca: false
+}
+max_image_size: 1024
diff --git a/models/research/delf/delf/python/training/README.md b/models/research/delf/delf/python/training/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a836370fb7830392715c45298987c40e24859032
--- /dev/null
+++ b/models/research/delf/delf/python/training/README.md
@@ -0,0 +1,128 @@
+# DELF Training Instructions
+
+This README documents the end-to-end process for training a landmark detection and retrieval
+model using the DELF library on the [Google Landmarks Dataset v2](https://github.com/cvdfoundation/google-landmark) (GLDv2). This can be achieved by following these steps:
+1. Install the DELF Python library.
+2. Download the raw images of the GLDv2 dataset.
+3. Prepare the training data.
+4. Run the training.
+
+The next sections will cover each of these steps in greater detail.
+
+## Prerequisites
+
+Clone the [TensorFlow Model Garden](https://github.com/tensorflow/models) repository and move
+into the `models/research/delf/delf/python/training` folder.
+```
+git clone https://github.com/tensorflow/models.git
+cd models/research/delf/delf/python/training
+```
+
+## Install the DELF Library
+
+The DELF Python library can be installed by running the [`install_delf.sh`](./install_delf.sh)
+script using the command:
+```
+bash install_delf.sh
+```
+The script installs both the DELF library and its dependencies in the following sequence:
+* Install TensorFlow 2.2 and TensorFlow 2.2 for GPU.
+* Install the [TF-Slim](https://github.com/google-research/tf-slim) library from source.
+* Download [protoc](https://github.com/protocolbuffers/protobuf) and compile the DELF Protocol
+Buffers.
+* Install the matplotlib, numpy, scikit-image, scipy and python3-tk Python libraries.
+* Install the [TensorFlow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection) from the cloned TensorFlow Model Garden repository.
+* Install the DELF package.
+
+*Please note that the current installation only works on 64-bit Linux architectures due to the
+`protoc` binary downloaded by the installation script. If you wish to install the DELF library on
+other architectures please update the [`install_delf.sh`](./install_delf.sh) script by referencing
+the desired `protoc` [binary release](https://github.com/protocolbuffers/protobuf/releases).*
+
+## Download the GLDv2 Training Data
+
+The [GLDv2](https://github.com/cvdfoundation/google-landmark) images are grouped into three datasets: TRAIN, INDEX and TEST. Images in each dataset are grouped into `*.tar` files and individually
+referenced in `*.csv` files containing training metadata and licensing information. The number of
+`*.tar` files per dataset is as follows:
+* TRAIN: 500 files.
+* INDEX: 100 files.
+* TEST: 20 files.
+
+To download the GLDv2 images, run the [`download_dataset.sh`](./download_dataset.sh) script as in
+the following example:
+```
+bash download_dataset.sh 500 100 20
+```
+The script takes the following parameters, in order:
+* The number of image files from the TRAIN dataset to download (maximum 500).
+* The number of image files from the INDEX dataset to download (maximum 100).
+* The number of image files from the TEST dataset to download (maximum 20).
+
+The script downloads the GLDv2 images under the following directory structure:
+* gldv2_dataset/
+ * train/ - Contains raw images from the TRAIN dataset.
+ * index/ - Contains raw images from the INDEX dataset.
+ * test/ - Contains raw images from the TEST dataset.
+
+Each of the three folders `gldv2_dataset/train/`, `gldv2_dataset/index/` and `gldv2_dataset/test/`
+contains the following:
+* The downloaded `*.tar` files.
+* The corresponding MD5 checksum files, `*.txt`.
+* The unpacked content of the downloaded files. (*Images are organized in folders and subfolders
+based on the first, second and third character in their file name.*)
+* The CSV files containing training and licensing metadata of the downloaded images.
+
+*Please note that due to the large size of the GLDv2 dataset, the download can take up to 12
+hours and up to 1 TB of disk space. To save bandwidth and disk space, you may want to start by
+downloading only the TRAIN dataset, the only one required for training; this saves roughly
+95 GB, the combined size of the INDEX and TEST datasets. To further save disk space, the
+`*.tar` files can be deleted after downloading and unpacking them, for instance with a command
+like the one shown below.*
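+
+A minimal cleanup command for the TRAIN split, assuming the default `gldv2_dataset` root folder,
+could be:
+```
+rm gldv2_dataset/train/*.tar
+```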
+
+## Prepare the Data for Training
+
+Preparing the data for training consists of creating [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord)
+files from the raw GLDv2 images grouped into TRAIN and VALIDATION splits. The training set
+produced contains only the *clean* subset of the GLDv2 dataset. The [CVPR'20 paper](https://arxiv.org/abs/2004.01804)
+introducing the GLDv2 dataset contains a detailed description of the *clean* subset.
+
+Generating the TFRecord files containing the TRAIN and VALIDATION splits of the *clean* GLDv2
+subset can be achieved by running the [`build_image_dataset.py`](./build_image_dataset.py)
+script. Assuming that the GLDv2 images have been downloaded to the `gldv2_dataset` folder, the
+script can be run as follows:
+```
+python3 build_image_dataset.py \
+ --train_csv_path=gldv2_dataset/train/train.csv \
+ --train_clean_csv_path=gldv2_dataset/train/train_clean.csv \
+ --train_directory=gldv2_dataset/train/*/*/*/ \
+ --output_directory=gldv2_dataset/tfrecord/ \
+ --num_shards=128 \
+ --generate_train_validation_splits \
+ --validation_split_size=0.2
+```
+*Please refer to the source code of the [`build_image_dataset.py`](./build_image_dataset.py) script for a detailed description of its parameters.*
+
+The TFRecord files written in the `OUTPUT_DIRECTORY` will be prefixed as follows:
+* TRAIN split: `train-*`
+* VALIDATION split: `validation-*`
+
+The same script can be used to generate TFRecord files for the TEST split for post-training
+evaluation purposes. This can be achieved by adding the parameters:
+```
+  --test_csv_path=gldv2_dataset/test/test.csv \
+ --test_directory=gldv2_dataset/test/*/*/*/ \
+```
+In this scenario, the TFRecord files of the TEST split written in the `OUTPUT_DIRECTORY` will be
+named according to the pattern `test-*`.
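+
+For reference, a complete invocation generating TRAIN, VALIDATION and TEST TFRecords (assuming
+the directory layout produced by the download script) could look like:
+```
+python3 build_image_dataset.py \
+  --train_csv_path=gldv2_dataset/train/train.csv \
+  --train_clean_csv_path=gldv2_dataset/train/train_clean.csv \
+  --train_directory=gldv2_dataset/train/*/*/*/ \
+  --test_csv_path=gldv2_dataset/test/test.csv \
+  --test_directory=gldv2_dataset/test/*/*/*/ \
+  --output_directory=gldv2_dataset/tfrecord/ \
+  --num_shards=128 \
+  --generate_train_validation_splits \
+  --validation_split_size=0.2
+```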
+
+*Please note that due to the large size of the GLDv2 dataset, the generation of the TFRecord
+files can take up to 12 hours and up to 500 GB of disk space.*
+
+## Running the Training
+
+Assuming the TFRecord files were generated in the `gldv2_dataset/tfrecord/` directory, running
+the following command should start training a model:
+
+```
+python3 train.py \
+ --train_file_pattern=gldv2_dataset/tfrecord/train* \
+ --validation_file_pattern=gldv2_dataset/tfrecord/validation*
+```
diff --git a/models/research/delf/delf/python/training/__init__.py b/models/research/delf/delf/python/training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c87f3d895c72593403f71e2768b31307f3db5ea6
--- /dev/null
+++ b/models/research/delf/delf/python/training/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2020 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Module for DELF training."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import
+from delf.python.training import build_image_dataset
+# pylint: enable=unused-import
diff --git a/models/research/delf/delf/python/training/build_image_dataset.py b/models/research/delf/delf/python/training/build_image_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..5df58df0b80d506330e3560b46b8835c283a2bf8
--- /dev/null
+++ b/models/research/delf/delf/python/training/build_image_dataset.py
@@ -0,0 +1,473 @@
+#!/usr/bin/python
+# Copyright 2020 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converts landmark image data to TFRecords file format with Example protos.
+
+The image data set is expected to reside in JPEG files ending in '.jpg'.
+
+This script converts the training and testing data into
+a sharded data set consisting of TFRecord files
+ train_directory/train-00000-of-00128
+ train_directory/train-00001-of-00128
+ ...
+ train_directory/train-00127-of-00128
+and
+ test_directory/test-00000-of-00128
+ test_directory/test-00001-of-00128
+ ...
+ test_directory/test-00127-of-00128
+where we have selected 128 shards for both data sets. Each record
+within the TFRecord file is a serialized Example proto. The Example proto
+contains the following fields:
+ image/encoded: string containing JPEG encoded image in RGB colorspace
+ image/height: integer, image height in pixels
+ image/width: integer, image width in pixels
+ image/colorspace: string, specifying the colorspace, always 'RGB'
+ image/channels: integer, specifying the number of channels, always 3
+ image/format: string, specifying the format, always 'JPEG'
+ image/filename: string, the unique id of the image file
+ e.g. '97c0a12e07ae8dd5' or '650c989dd3493748'
+Furthermore, if the data set type is training, it would contain one more field:
+ image/class/label: integer, the landmark_id from the input training csv file.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import csv
+import os
+
+from absl import app
+from absl import flags
+
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('train_directory', '/tmp/', 'Training data directory.')
+flags.DEFINE_string('test_directory', None,
+ '(Optional) Testing data directory. Required only if '
+ 'test_csv_path is not None.')
+flags.DEFINE_string('output_directory', '/tmp/', 'Output data directory.')
+flags.DEFINE_string('train_csv_path', '/tmp/train.csv',
+ 'Training data csv file path.')
+flags.DEFINE_string('train_clean_csv_path', None,
+ ('(Optional) Clean training data csv file path. '
+ 'If provided, filters images keeping the ones listed in '
+ 'this file. In this case, also outputs a CSV file '
+ 'relabeling.csv mapping new labels to old ones.'))
+flags.DEFINE_string('test_csv_path', None,
+                    '(Optional) Testing data csv file path. If None or absent, '
+                    'TFRecords for the images in the test dataset are not '
+                    'generated.')
+flags.DEFINE_integer('num_shards', 128, 'Number of shards in output data.')
+flags.DEFINE_boolean('generate_train_validation_splits', False,
+                     '(Optional) Whether to split the train dataset into '
+ 'TRAIN and VALIDATION splits.')
+flags.DEFINE_float('validation_split_size', 0.2,
+                   '(Optional) The size of the VALIDATION split as a fraction '
+ 'of the train dataset.')
+flags.DEFINE_integer('seed', 0,
+                     '(Optional) The seed to be used while shuffling the train '
+                     'dataset when generating the TRAIN and VALIDATION splits. '
+ 'Recommended for splits reproducibility purposes.')
+
+_FILE_IDS_KEY = 'file_ids'
+_IMAGE_PATHS_KEY = 'image_paths'
+_LABELS_KEY = 'labels'
+_TEST_SPLIT = 'test'
+_TRAIN_SPLIT = 'train'
+_VALIDATION_SPLIT = 'validation'
+
+
+def _get_all_image_files_and_labels(name, csv_path, image_dir):
+ """Process input and get the image file paths, image ids and the labels.
+
+ Args:
+ name: 'train' or 'test'.
+ csv_path: path to the Google-landmark Dataset csv Data Sources files.
+ image_dir: directory that stores downloaded images.
+ Returns:
+ image_paths: the paths to all images in the image_dir.
+ file_ids: the unique ids of images.
+ labels: the landmark id of all images. When name='test', the returned labels
+ will be an empty list.
+ Raises:
+ ValueError: if input name is not supported.
+ """
+ image_paths = tf.io.gfile.glob(os.path.join(image_dir, '*.jpg'))
+ file_ids = [os.path.basename(os.path.normpath(f))[:-4] for f in image_paths]
+ if name == _TRAIN_SPLIT:
+ with tf.io.gfile.GFile(csv_path, 'rb') as csv_file:
+ df = pd.read_csv(csv_file)
+ df = df.set_index('id')
+ labels = [int(df.loc[fid]['landmark_id']) for fid in file_ids]
+ elif name == _TEST_SPLIT:
+ labels = []
+ else:
+ raise ValueError('Unsupported dataset split name: %s' % name)
+ return image_paths, file_ids, labels
+
+
+def _get_clean_train_image_files_and_labels(csv_path, image_dir):
+ """Get image file paths, image ids and labels for the clean training split.
+
+ Args:
+ csv_path: path to the Google-landmark Dataset v2 CSV Data Sources files
+ of the clean train dataset. Assumes CSV header landmark_id;images.
+ image_dir: directory that stores downloaded images.
+
+ Returns:
+ image_paths: the paths to all images in the image_dir.
+ file_ids: the unique ids of images.
+ labels: the landmark id of all images.
+ relabeling: relabeling rules created to replace actual labels with
+ a continuous set of labels.
+ """
+ # Load the content of the CSV file (landmark_id/label -> images).
+ with tf.io.gfile.GFile(csv_path, 'rb') as csv_file:
+ df = pd.read_csv(csv_file)
+
+ # Create the dictionary (key = image_id, value = {label, file_id}).
+ images = {}
+ for _, row in df.iterrows():
+ label = row['landmark_id']
+ for file_id in row['images'].split(' '):
+ images[file_id] = {}
+ images[file_id]['label'] = label
+ images[file_id]['file_id'] = file_id
+
+ # Add the full image path to the dictionary of images.
+ image_paths = tf.io.gfile.glob(os.path.join(image_dir, '*.jpg'))
+ for image_path in image_paths:
+ file_id = os.path.basename(os.path.normpath(image_path))[:-4]
+ if file_id in images:
+ images[file_id]['image_path'] = image_path
+
+ # Explode the dictionary into lists (1 per image attribute).
+ image_paths = []
+ file_ids = []
+ labels = []
+ for _, value in images.items():
+ image_paths.append(value['image_path'])
+ file_ids.append(value['file_id'])
+ labels.append(value['label'])
+
+ # Relabel image labels to contiguous values.
+ unique_labels = sorted(set(labels))
+ relabeling = {label: index for index, label in enumerate(unique_labels)}
+ new_labels = [relabeling[label] for label in labels]
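+  # Illustrative example (hypothetical label values): labels = [138982, 138982, 17]
+  # would yield relabeling = {17: 0, 138982: 1} and new_labels = [1, 1, 0].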
+ return image_paths, file_ids, new_labels, relabeling
+
+
+def _process_image(filename):
+ """Process a single image file.
+
+ Args:
+ filename: string, path to an image file e.g., '/path/to/example.jpg'.
+
+ Returns:
+ image_buffer: string, JPEG encoding of RGB image.
+ height: integer, image height in pixels.
+ width: integer, image width in pixels.
+ Raises:
+ ValueError: if parsed image has wrong number of dimensions or channels.
+ """
+ # Read the image file.
+ with tf.io.gfile.GFile(filename, 'rb') as f:
+ image_data = f.read()
+
+ # Decode the RGB JPEG.
+ image = tf.io.decode_jpeg(image_data, channels=3)
+
+ # Check that image converted to RGB
+ if len(image.shape) != 3:
+    raise ValueError('The parsed image number of dimensions is not 3 but %d' %
+                     len(image.shape))
+ height = image.shape[0]
+ width = image.shape[1]
+ if image.shape[2] != 3:
+ raise ValueError('The parsed image channels is not 3 but %d' %
+ (image.shape[2]))
+
+ return image_data, height, width
+
+
+def _int64_feature(value):
+ """Returns an int64_list from a bool / enum / int / uint."""
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+
+def _bytes_feature(value):
+ """Returns a bytes_list from a string / byte."""
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def _convert_to_example(file_id, image_buffer, height, width, label=None):
+ """Build an Example proto for the given inputs.
+
+ Args:
+ file_id: string, unique id of an image file, e.g., '97c0a12e07ae8dd5'.
+ image_buffer: string, JPEG encoding of RGB image.
+ height: integer, image height in pixels.
+ width: integer, image width in pixels.
+ label: integer, the landmark id and prediction label.
+
+ Returns:
+ Example proto.
+ """
+ colorspace = 'RGB'
+ channels = 3
+ image_format = 'JPEG'
+ features = {
+ 'image/height': _int64_feature(height),
+ 'image/width': _int64_feature(width),
+ 'image/colorspace': _bytes_feature(colorspace.encode('utf-8')),
+ 'image/channels': _int64_feature(channels),
+ 'image/format': _bytes_feature(image_format.encode('utf-8')),
+ 'image/id': _bytes_feature(file_id.encode('utf-8')),
+ 'image/encoded': _bytes_feature(image_buffer)
+ }
+ if label is not None:
+ features['image/class/label'] = _int64_feature(label)
+ example = tf.train.Example(features=tf.train.Features(feature=features))
+
+ return example
+
+
+def _write_tfrecord(output_prefix, image_paths, file_ids, labels):
+ """Read image files and write image and label data into TFRecord files.
+
+ Args:
+ output_prefix: string, the prefix of output files, e.g. 'train'.
+ image_paths: list of strings, the paths to images to be converted.
+ file_ids: list of strings, the image unique ids.
+ labels: list of integers, the landmark ids of images. It is an empty list
+ when output_prefix='test'.
+
+ Raises:
+ ValueError: if the length of input images, ids and labels don't match
+ """
+ if output_prefix == _TEST_SPLIT:
+ labels = [None] * len(image_paths)
+ if not len(image_paths) == len(file_ids) == len(labels):
+    raise ValueError('length of image_paths, file_ids, labels should be the' +
+ ' same. But they are %d, %d, %d, respectively' %
+ (len(image_paths), len(file_ids), len(labels)))
+
+ spacing = np.linspace(0, len(image_paths), FLAGS.num_shards + 1, dtype=np.int)
+
+ for shard in range(FLAGS.num_shards):
+ output_file = os.path.join(
+ FLAGS.output_directory,
+ '%s-%.5d-of-%.5d' % (output_prefix, shard, FLAGS.num_shards))
+ writer = tf.io.TFRecordWriter(output_file)
+ print('Processing shard ', shard, ' and writing file ', output_file)
+ for i in range(spacing[shard], spacing[shard + 1]):
+ image_buffer, height, width = _process_image(image_paths[i])
+ example = _convert_to_example(file_ids[i], image_buffer, height, width,
+ labels[i])
+ writer.write(example.SerializeToString())
+ writer.close()
+
+
+def _write_relabeling_rules(relabeling_rules):
+ """Write to a file the relabeling rules when the clean train dataset is used.
+
+ Args:
+ relabeling_rules: dictionary of relabeling rules applied when the clean
+ train dataset is used (key = old_label, value = new_label).
+ """
+ relabeling_file_name = os.path.join(FLAGS.output_directory,
+ 'relabeling.csv')
+ with tf.io.gfile.GFile(relabeling_file_name, 'w') as relabeling_file:
+ csv_writer = csv.writer(relabeling_file, delimiter=',')
+ csv_writer.writerow(['new_label', 'old_label'])
+ for old_label, new_label in relabeling_rules.items():
+ csv_writer.writerow([new_label, old_label])
+
+
+def _build_train_and_validation_splits(image_paths, file_ids, labels,
+ validation_split_size, seed):
+ """Create TRAIN and VALIDATION splits containg all labels in equal proportion.
+
+ Args:
+ image_paths: list of paths to the image files in the train dataset.
+ file_ids: list of image file ids in the train dataset.
+ labels: list of image labels in the train dataset.
+ validation_split_size: size of the VALIDATION split as a ratio of the train
+ dataset.
+ seed: seed to use for shuffling the dataset for reproducibility purposes.
+
+ Returns:
+ splits : tuple containing the TRAIN and VALIDATION splits.
+ Raises:
+ ValueError: if the image attributes arrays don't all have the same length,
+ which makes the shuffling impossible.
+ """
+ # Ensure all image attribute arrays have the same length.
+ total_images = len(file_ids)
+ if not (len(image_paths) == total_images and len(labels) == total_images):
+ raise ValueError('Inconsistencies between number of file_ids (%d), number '
+                     'of image_paths (%d) and number of labels (%d). Cannot '
+                     'shuffle the train dataset.' % (total_images,
+ len(image_paths),
+ len(labels)))
+
+ # Stack all image attributes arrays in a single 2D array of dimensions
+  # (3, number of images) and group by label the indices of datapoints in the
+ # image attributes arrays. Explicitly convert label types from 'int' to 'str'
+ # to avoid implicit conversion during stacking with image_paths and file_ids
+ # which are 'str'.
+ labels_str = [str(label) for label in labels]
+ image_attrs = np.stack((image_paths, file_ids, labels_str))
+ image_attrs_idx_by_label = {}
+ for index, label in enumerate(labels):
+ if label not in image_attrs_idx_by_label:
+ image_attrs_idx_by_label[label] = []
+ image_attrs_idx_by_label[label].append(index)
+
+ # Create subsets of image attributes by label, shuffle them separately and
+ # split each subset into TRAIN and VALIDATION splits based on the size of the
+ # validation split.
+ splits = {
+ _VALIDATION_SPLIT: [],
+ _TRAIN_SPLIT: []
+ }
+ rs = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(seed)))
+ for label, indexes in image_attrs_idx_by_label.items():
+ # Create the subset for the current label.
+ image_attrs_label = image_attrs[:, indexes]
+ images_per_label = image_attrs_label.shape[1]
+ # Shuffle the current label subset.
+ columns_indices = np.arange(images_per_label)
+ rs.shuffle(columns_indices)
+ image_attrs_label = image_attrs_label[:, columns_indices]
+ # Split the current label subset into TRAIN and VALIDATION splits and add
+ # each split to the list of all splits.
+ cutoff_idx = max(1, int(validation_split_size * images_per_label))
+ splits[_VALIDATION_SPLIT].append(image_attrs_label[:, 0 : cutoff_idx])
+ splits[_TRAIN_SPLIT].append(image_attrs_label[:, cutoff_idx : ])
+
+ validation_split = np.concatenate(splits[_VALIDATION_SPLIT], axis=1)
+ train_split = np.concatenate(splits[_TRAIN_SPLIT], axis=1)
+
+ # Unstack the image attribute arrays in the TRAIN and VALIDATION splits and
+  # convert them back to lists. Convert labels back to 'int' from 'str',
+  # reversing the explicit 'int' to 'str' conversion done for stacking.
+ return (
+ {
+ _IMAGE_PATHS_KEY: validation_split[0, :].tolist(),
+ _FILE_IDS_KEY: validation_split[1, :].tolist(),
+ _LABELS_KEY: [int(label) for label in validation_split[2, :].tolist()]
+ }, {
+ _IMAGE_PATHS_KEY: train_split[0, :].tolist(),
+ _FILE_IDS_KEY: train_split[1, :].tolist(),
+ _LABELS_KEY: [int(label) for label in train_split[2, :].tolist()]
+ })
+
+
+def _build_train_tfrecord_dataset(csv_path,
+ clean_csv_path,
+ image_dir,
+ generate_train_validation_splits,
+ validation_split_size,
+ seed):
+ """Build a TFRecord dataset for the train split.
+
+ Args:
+ csv_path: path to the train Google-landmark Dataset csv Data Sources files.
+ clean_csv_path: path to the Google-landmark Dataset v2 CSV Data Sources
+ files of the clean train dataset.
+ image_dir: directory that stores downloaded images.
+    generate_train_validation_splits: whether to split the train dataset into
+ TRAIN and VALIDATION splits.
+ validation_split_size: size of the VALIDATION split as a ratio of the train
+ dataset. Only used if 'generate_train_validation_splits' is True.
+ seed: seed to use for shuffling the dataset for reproducibility purposes.
+ Only used if 'generate_train_validation_splits' is True.
+
+ Returns:
+ Nothing. After the function call, sharded TFRecord files are materialized.
+ Raises:
+ ValueError: if the size of the VALIDATION split is outside (0,1) when TRAIN
+ and VALIDATION splits need to be generated.
+ """
+ # Make sure the size of the VALIDATION split is inside (0, 1) if we need to
+ # generate the TRAIN and VALIDATION splits.
+ if generate_train_validation_splits:
+ if validation_split_size <= 0 or validation_split_size >= 1:
+      raise ValueError('Invalid VALIDATION split size. Expected inside (0,1) '
+ 'but received %f.' % validation_split_size)
+
+ if clean_csv_path:
+ # Load clean train images and labels and write the relabeling rules.
+ (image_paths, file_ids, labels,
+ relabeling_rules) = _get_clean_train_image_files_and_labels(clean_csv_path,
+ image_dir)
+ _write_relabeling_rules(relabeling_rules)
+ else:
+ # Load all train images.
+ image_paths, file_ids, labels = _get_all_image_files_and_labels(
+ _TRAIN_SPLIT, csv_path, image_dir)
+
+ if generate_train_validation_splits:
+ # Generate the TRAIN and VALIDATION splits and write them to TFRecord.
+ validation_split, train_split = _build_train_and_validation_splits(
+ image_paths, file_ids, labels, validation_split_size, seed)
+ _write_tfrecord(_VALIDATION_SPLIT,
+ validation_split[_IMAGE_PATHS_KEY],
+ validation_split[_FILE_IDS_KEY],
+ validation_split[_LABELS_KEY])
+ _write_tfrecord(_TRAIN_SPLIT,
+ train_split[_IMAGE_PATHS_KEY],
+ train_split[_FILE_IDS_KEY],
+ train_split[_LABELS_KEY])
+ else:
+ # Write to TFRecord a single split, TRAIN.
+ _write_tfrecord(_TRAIN_SPLIT, image_paths, file_ids, labels)
+
+
+def _build_test_tfrecord_dataset(csv_path, image_dir):
+ """Build a TFRecord dataset for the 'test' split.
+
+ Args:
+ csv_path: path to the 'test' Google-landmark Dataset csv Data Sources files.
+ image_dir: directory that stores downloaded images.
+
+ Returns:
+ Nothing. After the function call, sharded TFRecord files are materialized.
+ """
+ image_paths, file_ids, labels = _get_all_image_files_and_labels(
+ _TEST_SPLIT, csv_path, image_dir)
+ _write_tfrecord(_TEST_SPLIT, image_paths, file_ids, labels)
+
+
+def main(unused_argv):
+ _build_train_tfrecord_dataset(FLAGS.train_csv_path,
+ FLAGS.train_clean_csv_path,
+ FLAGS.train_directory,
+ FLAGS.generate_train_validation_splits,
+ FLAGS.validation_split_size,
+ FLAGS.seed)
+ if FLAGS.test_csv_path is not None:
+ _build_test_tfrecord_dataset(FLAGS.test_csv_path, FLAGS.test_directory)
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/models/research/delf/delf/python/training/datasets/__init__.py b/models/research/delf/delf/python/training/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e0a672716945394cce4b2c69ee3d086192da87c
--- /dev/null
+++ b/models/research/delf/delf/python/training/datasets/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2020 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Module exposing datasets for training."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import
+from delf.python.training.datasets import googlelandmarks
+# pylint: enable=unused-import
diff --git a/models/research/delf/delf/python/training/datasets/googlelandmarks.py b/models/research/delf/delf/python/training/datasets/googlelandmarks.py
new file mode 100644
index 0000000000000000000000000000000000000000..f289cc166460f3a2fd9f157bc672ea0a464a2995
--- /dev/null
+++ b/models/research/delf/delf/python/training/datasets/googlelandmarks.py
@@ -0,0 +1,187 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Google Landmarks Dataset(GLD).
+
+Placeholder for Google Landmarks dataset.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import tensorflow as tf
+
+
+class _GoogleLandmarksInfo(object):
+ """Metadata about the Google Landmarks dataset."""
+ num_classes = {
+ 'gld_v1': 14951,
+ 'gld_v2': 203094,
+ 'gld_v2_clean': 81313
+ }
+
+
+class _DataAugmentationParams(object):
+ """Default parameters for augmentation."""
+ # The following are used for training.
+ min_object_covered = 0.1
+ aspect_ratio_range_min = 3. / 4
+ aspect_ratio_range_max = 4. / 3
+ area_range_min = 0.08
+ area_range_max = 1.0
+ max_attempts = 100
+ update_labels = False
+ # 'central_fraction' is used for central crop in inference.
+ central_fraction = 0.875
+
+ random_reflection = False
+ input_rows = 321
+ input_cols = 321
+
+
+def NormalizeImages(images, pixel_value_scale=0.5, pixel_value_offset=0.5):
+ """Normalize pixel values in image.
+
+ Output is computed as
+ normalized_images = (images - pixel_value_offset) / pixel_value_scale.
+
+ Args:
+ images: `Tensor`, images to normalize.
+ pixel_value_scale: float, scale.
+ pixel_value_offset: float, offset.
+
+ Returns:
+ normalized_images: `Tensor`, normalized images.
+ """
+ images = tf.cast(images, tf.float32)
+ normalized_images = tf.math.divide(
+ tf.subtract(images, pixel_value_offset), pixel_value_scale)
+ return normalized_images
+
+
+def _ImageNetCrop(image):
+ """Imagenet-style crop with random bbox and aspect ratio.
+
+ Args:
+ image: a `Tensor`, image to crop.
+
+ Returns:
+ cropped_image: `Tensor`, cropped image.
+ """
+
+ params = _DataAugmentationParams()
+ bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
+ (bbox_begin, bbox_size, _) = tf.image.sample_distorted_bounding_box(
+ tf.shape(image),
+ bounding_boxes=bbox,
+ min_object_covered=params.min_object_covered,
+ aspect_ratio_range=(params.aspect_ratio_range_min,
+ params.aspect_ratio_range_max),
+ area_range=(params.area_range_min, params.area_range_max),
+ max_attempts=params.max_attempts,
+ use_image_if_no_bounding_boxes=True)
+ cropped_image = tf.slice(image, bbox_begin, bbox_size)
+ cropped_image.set_shape([None, None, 3])
+
+ cropped_image = tf.image.resize(
+ cropped_image, [params.input_rows, params.input_cols], method='area')
+ if params.random_reflection:
+ cropped_image = tf.image.random_flip_left_right(cropped_image)
+
+ return cropped_image
+
+
+def _ParseFunction(example, name_to_features, image_size, augmentation):
+ """Parse a single TFExample to get the image and label and process the image.
+
+ Args:
+ example: a `TFExample`.
+ name_to_features: a `dict`. The mapping from feature names to its type.
+ image_size: an `int`. The image size for the decoded image, on each side.
+ augmentation: a `boolean`. True if the image will be augmented.
+
+ Returns:
+ image: a `Tensor`. The processed image.
+ label: a `Tensor`. The ground-truth label.
+ """
+ parsed_example = tf.io.parse_single_example(example, name_to_features)
+ # Parse to get image.
+ image = parsed_example['image/encoded']
+ image = tf.io.decode_jpeg(image)
+ if augmentation:
+ image = _ImageNetCrop(image)
+ else:
+ image = tf.image.resize(image, [image_size, image_size])
+ image.set_shape([image_size, image_size, 3])
+ # Parse to get label.
+ label = parsed_example['image/class/label']
+ return image, label
+
+
+def CreateDataset(file_pattern,
+ image_size=321,
+ batch_size=32,
+ augmentation=False,
+ seed=0):
+ """Creates a dataset.
+
+ Args:
+ file_pattern: str, file pattern of the dataset files.
+ image_size: int, image size.
+ batch_size: int, batch size.
+ augmentation: bool, whether to apply augmentation.
+ seed: int, seed for shuffling the dataset.
+
+ Returns:
+ tf.data.TFRecordDataset.
+ """
+
+ filenames = tf.io.gfile.glob(file_pattern)
+
+ dataset = tf.data.TFRecordDataset(filenames)
+ dataset = dataset.repeat().shuffle(buffer_size=100, seed=seed)
+
+ # Create a description of the features.
+ feature_description = {
+ 'image/height': tf.io.FixedLenFeature([], tf.int64, default_value=0),
+ 'image/width': tf.io.FixedLenFeature([], tf.int64, default_value=0),
+ 'image/channels': tf.io.FixedLenFeature([], tf.int64, default_value=0),
+ 'image/format': tf.io.FixedLenFeature([], tf.string, default_value=''),
+ 'image/filename': tf.io.FixedLenFeature([], tf.string, default_value=''),
+ 'image/encoded': tf.io.FixedLenFeature([], tf.string, default_value=''),
+ 'image/class/label': tf.io.FixedLenFeature([], tf.int64, default_value=0),
+ }
+
+ customized_parse_func = functools.partial(
+ _ParseFunction,
+ name_to_features=feature_description,
+ image_size=image_size,
+ augmentation=augmentation)
+ dataset = dataset.map(customized_parse_func)
+ dataset = dataset.batch(batch_size)
+
+ return dataset
+
+
+def GoogleLandmarksInfo():
+ """Returns metadata information on the Google Landmarks dataset.
+
+ Returns:
+ object _GoogleLandmarksInfo containing metadata about the GLD dataset.
+ """
+ return _GoogleLandmarksInfo()
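+
+# Minimal usage sketch (the file pattern is an assumption; it should match the
+# TFRecords produced by build_image_dataset.py):
+#   dataset = CreateDataset('gldv2_dataset/tfrecord/train*', image_size=321,
+#                           batch_size=32, augmentation=True)
+#   images, labels = next(iter(dataset))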
diff --git a/models/research/delf/delf/python/training/download_dataset.sh b/models/research/delf/delf/python/training/download_dataset.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ecbd905eccde6b4056f4b3cc0a011695debb3390
--- /dev/null
+++ b/models/research/delf/delf/python/training/download_dataset.sh
@@ -0,0 +1,161 @@
+#!/bin/bash
+
+# Copyright 2020 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# This script downloads the Google Landmarks v2 dataset. To download the dataset
+# run the script like in the following example:
+# bash download_dataset.sh 500 100 20
+#
+# The script takes the following parameters, in order:
+# - number of image files from the TRAIN split to download (maximum 500)
+# - number of image files from the INDEX split to download (maximum 100)
+# - number of image files from the TEST split to download (maximum 20)
+
+image_files_train=$1 # Number of image files to download from the TRAIN split
+image_files_index=$2 # Number of image files to download from the INDEX split
+image_files_test=$3 # Number of image files to download from the TEST split
+
+splits=("train" "test" "index")
+dataset_root_folder=gldv2_dataset
+
+metadata_url="https://s3.amazonaws.com/google-landmark/metadata"
+ground_truth_url="https://s3.amazonaws.com/google-landmark/ground_truth"
+csv_train=(${metadata_url}/train.csv ${metadata_url}/train_clean.csv ${metadata_url}/train_attribution.csv ${metadata_url}/train_label_to_category.csv)
+csv_index=(${metadata_url}/index.csv ${metadata_url}/index_image_to_landmark.csv ${metadata_url}/index_label_to_category.csv)
+csv_test=(${metadata_url}/test.csv ${ground_truth_url}/recognition_solution_v2.1.csv ${ground_truth_url}/retrieval_solution_v2.1.csv)
+
+images_tar_file_base_url="https://s3.amazonaws.com/google-landmark"
+images_md5_file_base_url="https://s3.amazonaws.com/google-landmark/md5sum"
+num_processes=6
+
+make_folder() {
+ # Creates a folder and checks if it exists. Exits if folder creation fails.
+ local folder=$1
+ if [ -d "${folder}" ]; then
+ echo "Folder ${folder} already exists. Skipping folder creation."
+ else
+ echo "Creating folder ${folder}."
+ if mkdir ${folder}; then
+ echo "Successfully created folder ${folder}."
+ else
+ echo "Failed to create folder ${folder}. Exiting."
+ exit 1
+ fi
+ fi
+}
+
+download_file() {
+ # Downloads a file from an URL into a specified folder.
+ local file_url=$1
+ local folder=$2
+ local file_path="${folder}/`basename ${file_url}`"
+ echo "Downloading file ${file_url} to folder ${folder}."
+ pushd . > /dev/null
+ cd ${folder}
+ curl -Os ${file_url}
+ popd > /dev/null
+}
+
+validate_md5_checksum() {
+ # Validate the MD5 checksum of a downloaded file.
+ local content_file=$1
+ local md5_file=$2
+ echo "Checking MD5 checksum of file ${content_file} against ${md5_file}"
+ if [[ "${OSTYPE}" == "linux-gnu" ]]; then
+ content_md5=`md5sum ${content_file}`
+ elif [[ "${OSTYPE}" == "darwin"* ]]; then
+ content_md5=`md5 -r "${content_file}"`
+ fi
+ content_md5=`cut -d' ' -f1<<<"${content_md5}"`
+  expected_md5=`cut -d' ' -f1 "${md5_file}"`
+  [[ "${content_md5}" != "" && "${content_md5}" = "${expected_md5}" ]]
+}
+
+download_image_files() {
+  # Downloads the image archives of a split in batches of ${num_processes},
+  # fetching each archive with the download_image_file helper.
+  local split=$1
+  local split_folder=$2
+  local image_files_var="image_files_${split}"
+  local max_idx=$(expr ${!image_files_var} - 1)
+  for ((i = 0; i <= max_idx; i += num_processes)); do
+    local curr_max_idx=$(expr ${i} + ${num_processes} - 1)
+    local last_idx=$((${curr_max_idx}>${max_idx}?${max_idx}:${curr_max_idx}))
+    for j in $(seq ${i} 1 ${last_idx}); do download_image_file "${split}" "${j}" "${split_folder}" & done
+    wait
+  done
+}
+
+download_csv_files() {
+ # Downloads all medatada CSV files of a split.
+ local split=$1
+ local split_folder=$2
+ local csv_list="csv_${split}[*]"
+ for csv_file in ${!csv_list}; do
+ download_file "${csv_file}" "${split_folder}"
+ done
+}
+
+download_split() {
+ # Downloads all artifacts, metadata CSV files and image files of a single split.
+ local split=$1
+ local split_folder=${dataset_root_folder}/${split}
+ make_folder "${split_folder}"
+ download_csv_files "${split}" "${split_folder}"
+ download_image_files "${split}" "${split_folder}"
+}
+
+download_all_splits() {
+ # Downloads all artifacts, metadata CSV files and image files of all splits.
+ make_folder "${dataset_root_folder}"
+ for split in "${splits[@]}"; do
+ download_split "$split"
+ done
+}
+
+download_all_splits
+
+exit 0
diff --git a/models/research/delf/delf/python/training/install_delf.sh b/models/research/delf/delf/python/training/install_delf.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4feb464aa7def067028e65281d906e006f4533a2
--- /dev/null
+++ b/models/research/delf/delf/python/training/install_delf.sh
@@ -0,0 +1,153 @@
+#!/bin/bash
+
+# Copyright 2020 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# This script installs the DELF package along with its dependencies. To install
+# the DELF package run the script like in the following example:
+# bash install_delf.sh
+
+protoc_folder="protoc"
+protoc_url="https://github.com/google/protobuf/releases/download/v3.3.0/protoc-3.3.0-linux-x86_64.zip"
+tf_slim_git_repo="https://github.com/google-research/tf-slim.git"
+
+handle_exit_code() {
+ # Fail gracefully in case of an exit code different than 0.
+ exit_code=$1
+ error_message=$2
+ if [ ${exit_code} -ne 0 ]; then
+ echo "${error_message} Exiting."
+ exit 1
+ fi
+}
+
+install_tensorflow() {
+ # Install TensorFlow 2.2.
+ echo "Installing TensorFlow 2.2"
+ pip3 install --upgrade tensorflow==2.2.0
+ local exit_code=$?
+ handle_exit_code ${exit_code} "Unable to install Tensorflow 2.2."
+ echo "Installing TensorFlow 2.2 for GPU"
+ pip3 install --upgrade tensorflow-gpu==2.2.0
+ local exit_code=$?
+ handle_exit_code ${exit_code} "Unable to install Tensorflow for GPU 2.2.0."
+}
+
+install_tf_slim() {
+ # Install TF-Slim from source.
+ echo "Installing TF-Slim from source: ${git_repo}"
+ git clone ${tf_slim_git_repo}
+ local exit_code=$?
+ handle_exit_code ${exit_code} "Unable to clone TF-Slim repository ${tf_slim_git_repo}."
+ pushd . > /dev/null
+ cd tf-slim
+ pip3 install .
+ popd > /dev/null
+ rm -rf tf-slim
+}
+
+download_protoc() {
+ # Installs the Protobuf compiler protoc.
+ echo "Downloading Protobuf compiler from ${protoc_url}"
+ curl -L -Os ${protoc_url}
+ local exit_code=$?
+  handle_exit_code ${exit_code} "Unable to download Protobuf compiler from ${protoc_url}."
+
+ mkdir ${protoc_folder}
+ local protoc_archive=`basename ${protoc_url}`
+ unzip ${protoc_archive} -d ${protoc_folder}
+ local exit_code=$?
+ handle_exit_code ${exit_code} "Unable to unzip Protobuf compiler from ${protoc_archive}."
+
+ rm ${protoc_archive}
+}
+
+compile_delf_protos() {
+  # Compiles DELF protobufs from tensorflow/models/research/delf using the protoc compiler.
+ echo "Compiling DELF Protobufs"
+ PATH_TO_PROTOC="`pwd`/${protoc_folder}"
+ pushd . > /dev/null
+ cd ../../..
+ ${PATH_TO_PROTOC}/bin/protoc delf/protos/*.proto --python_out=.
+ local exit_code=$?
+ handle_exit_code ${exit_code} "Unable to compile DELF Protobufs."
+ popd > /dev/null
+}
+
+cleanup_protoc() {
+ # Removes the downloaded Protobuf compiler protoc after the installation of the DELF package.
+ echo "Cleaning up Protobuf compiler download"
+ rm -rf ${protoc_folder}
+}
+
+install_python_libraries() {
+ # Installs Python libraries upon which the DELF package has dependencies.
+ echo "Installing matplotlib, numpy, scikit-image, scipy and python3-tk"
+ pip3 install matplotlib numpy scikit-image scipy
+ local exit_code=$?
+ handle_exit_code ${exit_code} "Unable to install at least one of: matplotlib numpy scikit-image scipy."
+ sudo apt-get -y install python3-tk
+ local exit_code=$?
+ handle_exit_code ${exit_code} "Unable to install python3-tk."
+}
+
+install_object_detection() {
+ # Installs the object detection package from tensorflow/models/research.
+ echo "Installing object detection"
+ pushd . > /dev/null
+ cd ../../../..
+ export PYTHONPATH=$PYTHONPATH:`pwd`
+ pip3 install .
+ local exit_code=$?
+ handle_exit_code ${exit_code} "Unable to install the object_detection package."
+ popd > /dev/null
+}
+
+install_delf_package() {
+ # Installs the DELF package from tensorflow/models/research/delf/delf.
+ echo "Installing DELF package"
+ pushd . > /dev/null
+ cd ../../..
+ pip3 install -e .
+ local exit_code=$?
+ handle_exit_code ${exit_code} "Unable to install the DELF package."
+ popd > /dev/null
+}
+
+post_install_check() {
+ # Checks the DELF package has been successfully installed.
+ echo "Checking DELF package installation"
+ python3 -c 'import delf'
+ local exit_code=$?
+ handle_exit_code ${exit_code} "DELF package installation check failed."
+ echo "Installation successful."
+}
+
+install_delf() {
+ # Orchestrates DELF package installation.
+ install_tensorflow
+ install_tf_slim
+ download_protoc
+ compile_delf_protos
+ cleanup_protoc
+ install_python_libraries
+ install_object_detection
+ install_delf_package
+ post_install_check
+}
+
+install_delf
+
+exit 0
diff --git a/models/research/delf/delf/python/training/model/__init__.py b/models/research/delf/delf/python/training/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcc888bd8a65e9ba48f15e4082064e7285ac2591
--- /dev/null
+++ b/models/research/delf/delf/python/training/model/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2020 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""DELF model module, used for training and exporting."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import
+from delf.python.training.model import delf_model
+from delf.python.training.model import export_model_utils
+from delf.python.training.model import resnet50
+# pylint: enable=unused-import
diff --git a/models/research/delf/delf/python/training/model/delf_model.py b/models/research/delf/delf/python/training/model/delf_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..27409de99c52dcb0f0eb00ca9ae0602a2be0d30b
--- /dev/null
+++ b/models/research/delf/delf/python/training/model/delf_model.py
@@ -0,0 +1,141 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""DELF model implementation based on the following paper.
+
+ Large-Scale Image Retrieval with Attentive Deep Local Features
+ https://arxiv.org/abs/1612.06321
+"""
+
+import tensorflow as tf
+
+from delf.python.training.model import resnet50 as resnet
+
+layers = tf.keras.layers
+reg = tf.keras.regularizers
+
+_DECAY = 0.0001
+
+
+class AttentionModel(tf.keras.Model):
+ """Instantiates attention model.
+
+ Uses two [kernel_size x kernel_size] convolutions and softplus as activation
+ to compute an attention map with the same resolution as the featuremap.
+  Features are l2-normalized and aggregated using the attention probabilities
+  as weights.
+ """
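+
+  # Illustrative usage (assumes a [batch, h, w, 1024] conv_4 featuremap as
+  # input, which is how the Delf model below feeds this block):
+  #   attention = AttentionModel()
+  #   feat, prob, score = attention(
+  #       tf.zeros([2, 41, 41, 1024]), training=False)
+  #   # feat: [2, 1024] attention-pooled descriptor;
+  #   # prob, score: [2, 41, 41, 1] attention map and raw scores.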
+
+ def __init__(self, kernel_size=1, decay=_DECAY, name='attention'):
+ """Initialization of attention model.
+
+ Args:
+ kernel_size: int, kernel size of convolutions.
+ decay: float, decay for l2 regularization of kernel weights.
+ name: str, name to identify model.
+ """
+ super(AttentionModel, self).__init__(name=name)
+
+ # First convolutional layer (called with relu activation).
+ self.conv1 = layers.Conv2D(
+ 512,
+ kernel_size,
+ kernel_regularizer=reg.l2(decay),
+ padding='same',
+ name='attn_conv1')
+ self.bn_conv1 = layers.BatchNormalization(axis=3, name='bn_conv1')
+
+ # Second convolutional layer, with softplus activation.
+ self.conv2 = layers.Conv2D(
+ 1,
+ kernel_size,
+ kernel_regularizer=reg.l2(decay),
+ padding='same',
+ name='attn_conv2')
+ self.activation_layer = layers.Activation('softplus')
+
+ def call(self, inputs, training=True):
+ x = self.conv1(inputs)
+ x = self.bn_conv1(x, training=training)
+ x = tf.nn.relu(x)
+
+ score = self.conv2(x)
+ prob = self.activation_layer(score)
+
+ # L2-normalize the featuremap before pooling.
+ inputs = tf.nn.l2_normalize(inputs, axis=-1)
+ feat = tf.reduce_mean(tf.multiply(inputs, prob), [1, 2], keepdims=False)
+
+ return feat, prob, score
+
+
+class Delf(tf.keras.Model):
+ """Instantiates Keras DELF model using ResNet50 as backbone.
+
+ This class implements the [DELF](https://arxiv.org/abs/1612.06321) model for
+ extracting local features from images. The backbone is a ResNet50 network
+ that extracts featuremaps from both conv_4 and conv_5 layers. Activations
+ from conv_4 are used to compute an attention map of the same resolution.
+ """
+
+ def __init__(self, block3_strides=True, name='DELF'):
+ """Initialization of DELF model.
+
+ Args:
+ block3_strides: bool, whether to add strides to the output of block3.
+ name: str, name to identify model.
+ """
+ super(Delf, self).__init__(name=name)
+
+ # Backbone using Keras ResNet50.
+ self.backbone = resnet.ResNet50(
+ 'channels_last',
+ name='backbone',
+ include_top=False,
+ pooling='avg',
+ block3_strides=block3_strides,
+ average_pooling=False)
+
+ # Attention model.
+ self.attention = AttentionModel(name='attention')
+
+ # Define classifiers for training backbone and attention models.
+ def init_classifiers(self, num_classes):
+ self.num_classes = num_classes
+ self.desc_classification = layers.Dense(
+ num_classes, activation=None, kernel_regularizer=None, name='desc_fc')
+
+ self.attn_classification = layers.Dense(
+ num_classes, activation=None, kernel_regularizer=None, name='att_fc')
+
+ # Weights to optimize for descriptor fine tuning.
+ @property
+ def desc_trainable_weights(self):
+ return (self.backbone.trainable_weights +
+ self.desc_classification.trainable_weights)
+
+ # Weights to optimize for attention model training.
+ @property
+ def attn_trainable_weights(self):
+ return (self.attention.trainable_weights +
+ self.attn_classification.trainable_weights)
+
+ def call(self, input_image, training=True):
+ blocks = {'block3': None}
+ self.backbone(input_image, intermediates_dict=blocks, training=training)
+
+ features = blocks['block3']
+ _, probs, _ = self.attention(features, training=training)
+
+ return probs, features
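+
+
+# Minimal usage sketch (illustrative only):
+#   model = Delf(block3_strides=False)
+#   probs, features = model(tf.zeros([1, 321, 321, 3]), training=False)
+#   # probs: attention probabilities over the block3 featuremap;
+#   # features: the block3 featuremap itself.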
diff --git a/models/research/delf/delf/python/training/model/delf_model_test.py b/models/research/delf/delf/python/training/model/delf_model_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4cbcef555db3cd6e6395aee69f1479863916bd4
--- /dev/null
+++ b/models/research/delf/delf/python/training/model/delf_model_test.py
@@ -0,0 +1,115 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the DELF model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import tensorflow as tf
+
+from delf.python.training.model import delf_model
+
+
+class DelfTest(tf.test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ('block3_stridesTrue', True),
+ ('block3_stridesFalse', False),
+ )
+ def test_build_model(self, block3_strides):
+ image_size = 321
+ num_classes = 1000
+ batch_size = 2
+ input_shape = (batch_size, image_size, image_size, 3)
+
+ model = delf_model.Delf(block3_strides=block3_strides, name='DELF')
+ model.init_classifiers(num_classes)
+
+ images = tf.random.uniform(input_shape, minval=-1.0, maxval=1.0, seed=0)
+ blocks = {}
+
+ # Get global feature by pooling block4 features.
+ desc_prelogits = model.backbone(
+ images, intermediates_dict=blocks, training=False)
+ desc_logits = model.desc_classification(desc_prelogits)
+ self.assertAllEqual(desc_prelogits.shape, (batch_size, 2048))
+ self.assertAllEqual(desc_logits.shape, (batch_size, num_classes))
+
+ features = blocks['block3']
+ attn_prelogits, _, _ = model.attention(features)
+ attn_logits = model.attn_classification(attn_prelogits)
+ self.assertAllEqual(attn_prelogits.shape, (batch_size, 1024))
+ self.assertAllEqual(attn_logits.shape, (batch_size, num_classes))
+
+ @parameterized.named_parameters(
+ ('block3_stridesTrue', True),
+ ('block3_stridesFalse', False),
+ )
+ def test_train_step(self, block3_strides):
+
+ image_size = 321
+ num_classes = 1000
+ batch_size = 2
+ clip_val = 10.0
+ input_shape = (batch_size, image_size, image_size, 3)
+
+ model = delf_model.Delf(block3_strides=block3_strides, name='DELF')
+ model.init_classifiers(num_classes)
+
+ optimizer = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9)
+
+ images = tf.random.uniform(input_shape, minval=0.0, maxval=1.0, seed=0)
+ labels = tf.random.uniform((batch_size,),
+ minval=0,
+ maxval=model.num_classes - 1,
+ dtype=tf.int64)
+
+ loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
+ from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
+
+ def compute_loss(labels, predictions):
+ per_example_loss = loss_object(labels, predictions)
+ return tf.nn.compute_average_loss(
+ per_example_loss, global_batch_size=batch_size)
+
+ with tf.GradientTape() as desc_tape:
+ blocks = {}
+ desc_prelogits = model.backbone(
+ images, intermediates_dict=blocks, training=False)
+ desc_logits = model.desc_classification(desc_prelogits)
+ desc_loss = compute_loss(labels, desc_logits)
+
+ gradients = desc_tape.gradient(desc_loss, model.desc_trainable_weights)
+ clipped, _ = tf.clip_by_global_norm(gradients, clip_norm=clip_val)
+ optimizer.apply_gradients(zip(clipped, model.desc_trainable_weights))
+
+ with tf.GradientTape() as attn_tape:
+ block3 = blocks['block3']
+ block3 = tf.stop_gradient(block3)
+ attn_prelogits, _, _ = model.attention(block3, training=True)
+ attn_logits = model.attn_classification(attn_prelogits)
+ attn_loss = compute_loss(labels, attn_logits)
+
+ gradients = attn_tape.gradient(attn_loss, model.attn_trainable_weights)
+ clipped, _ = tf.clip_by_global_norm(gradients, clip_norm=clip_val)
+ optimizer.apply_gradients(zip(clipped, model.attn_trainable_weights))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/delf/delf/python/training/model/export_model.py b/models/research/delf/delf/python/training/model/export_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4af69a231641ef2cc69a08fb9a5ba5c31655c26c
--- /dev/null
+++ b/models/research/delf/delf/python/training/model/export_model.py
@@ -0,0 +1,137 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Export DELF tensorflow inference model.
+
+This model includes feature extraction, receptive field calculation and
+key-point selection and outputs the selected feature descriptors.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import app
+from absl import flags
+import tensorflow as tf
+
+from delf.python.training.model import delf_model
+from delf.python.training.model import export_model_utils
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('ckpt_path', '/tmp/delf-logdir/delf-weights',
+ 'Path to saved checkpoint.')
+flags.DEFINE_string('export_path', None, 'Path where model will be exported.')
+flags.DEFINE_boolean('block3_strides', False,
+ 'Whether to apply strides after block3.')
+flags.DEFINE_float('iou', 1.0, 'IOU for non-max suppression.')
+
+
+def _build_tensor_info(tensor_dict):
+  """Replaces each value in the dict by its tensor info.
+
+  Args:
+    tensor_dict: A dictionary mapping names to tensors.
+
+  Returns:
+    dict: New dictionary mapping the same names to the corresponding
+      TensorInfo protos.
+ """
+ return {
+ k: tf.compat.v1.saved_model.utils.build_tensor_info(t)
+ for k, t in tensor_dict.items()
+ }
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise app.UsageError('Too many command-line arguments.')
+
+ export_path = FLAGS.export_path
+ if os.path.exists(export_path):
+ raise ValueError('Export_path already exists.')
+
+ with tf.Graph().as_default() as g, tf.compat.v1.Session(graph=g) as sess:
+
+ # Setup the DELF model for extraction.
+ model = delf_model.Delf(block3_strides=FLAGS.block3_strides, name='DELF')
+
+ # Initial forward pass to build model.
+ images = tf.zeros((1, 321, 321, 3), dtype=tf.float32)
+ model(images)
+
+ stride_factor = 2.0 if FLAGS.block3_strides else 1.0
+
+ # Setup the multiscale keypoint extraction.
+ input_image = tf.compat.v1.placeholder(
+ tf.uint8, shape=(None, None, 3), name='input_image')
+ input_abs_thres = tf.compat.v1.placeholder(
+ tf.float32, shape=(), name='input_abs_thres')
+ input_scales = tf.compat.v1.placeholder(
+ tf.float32, shape=[None], name='input_scales')
+ input_max_feature_num = tf.compat.v1.placeholder(
+ tf.int32, shape=(), name='input_max_feature_num')
+
+ extracted_features = export_model_utils.ExtractLocalFeatures(
+ input_image, input_scales, input_max_feature_num, input_abs_thres,
+ FLAGS.iou, lambda x: model(x, training=False), stride_factor)
+
+ # Load the weights.
+ checkpoint_path = FLAGS.ckpt_path
+ model.load_weights(checkpoint_path)
+ print('Checkpoint loaded from ', checkpoint_path)
+
+ named_input_tensors = {
+ 'input_image': input_image,
+ 'input_scales': input_scales,
+ 'input_abs_thres': input_abs_thres,
+ 'input_max_feature_num': input_max_feature_num,
+ }
+
+ # Outputs to the exported model.
+ named_output_tensors = {}
+ named_output_tensors['boxes'] = tf.identity(
+ extracted_features[0], name='boxes')
+ named_output_tensors['features'] = tf.identity(
+ extracted_features[1], name='features')
+ named_output_tensors['scales'] = tf.identity(
+ extracted_features[2], name='scales')
+ named_output_tensors['scores'] = tf.identity(
+ extracted_features[3], name='scores')
+
+ # Export the model.
+ signature_def = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
+ inputs=_build_tensor_info(named_input_tensors),
+ outputs=_build_tensor_info(named_output_tensors))
+
+ print('Exporting trained model to:', export_path)
+ builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_path)
+
+ init_op = None
+ builder.add_meta_graph_and_variables(
+ sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
+ signature_def_map={
+ tf.compat.v1.saved_model.signature_constants
+ .DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+ signature_def
+ },
+ main_op=init_op)
+ builder.save()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/models/research/delf/delf/python/training/model/export_model_utils.py b/models/research/delf/delf/python/training/model/export_model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4302aca139802e99d80bfd4e1fc27e353abdfbb
--- /dev/null
+++ b/models/research/delf/delf/python/training/model/export_model_utils.py
@@ -0,0 +1,171 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Helper functions for DELF model exporting."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from delf import feature_extractor
+from delf.python.training.datasets import googlelandmarks as gld
+from object_detection.core import box_list
+from object_detection.core import box_list_ops
+
+
+def ExtractLocalFeatures(image, image_scales, max_feature_num, abs_thres, iou,
+ attention_model_fn, stride_factor):
+ """Extract local features for input image.
+
+ Args:
+ image: image tensor of type tf.uint8 with shape [h, w, channels].
+ image_scales: 1D float tensor which contains float scales used for image
+ pyramid construction.
+ max_feature_num: int tensor denotes the maximum selected feature points.
+ abs_thres: float tensor denotes the score threshold for feature selection.
+ iou: float scalar denotes the iou threshold for NMS.
+ attention_model_fn: model function. Follows the signature:
+ * Args:
+ * `images`: Image tensor which is re-scaled.
+ * Returns:
+ * `attention_prob`: attention map after the non-linearity.
+ * `feature_map`: feature map after ResNet convolution.
+ stride_factor: integer accounting for striding after block3.
+
+ Returns:
+ boxes: [N, 4] float tensor which denotes the selected receptive box. N is
+ the number of final feature points which pass through keypoint selection
+ and NMS steps.
+ features: [N, depth] float tensor.
+ feature_scales: [N] float tensor. It is the inverse of the input image
+ scales such that larger image scales correspond to larger image regions,
+ which is compatible with keypoints detected with other techniques, for
+ example Congas.
+ scores: [N, 1] float tensor denotes the attention score.
+
+ """
+ original_image_shape_float = tf.gather(
+ tf.dtypes.cast(tf.shape(image), tf.float32), [0, 1])
+
+ image_tensor = gld.NormalizeImages(
+ image, pixel_value_offset=128.0, pixel_value_scale=128.0)
+ image_tensor = tf.expand_dims(image_tensor, 0, name='image/expand_dims')
+
+ # Hard code the feature depth and receptive field parameters for now.
+ rf, stride, padding = [291.0, 16.0 * stride_factor, 145.0]
+ feature_depth = 1024
+
+ def _ProcessSingleScale(scale_index, boxes, features, scales, scores):
+ """Resizes the image and run feature extraction and keypoint selection.
+
+ This function will be passed into tf.while_loop() and be called
+ repeatedly. The input boxes are collected from the previous iteration
+ [0: scale_index -1]. We get the current scale by
+ image_scales[scale_index], and run resize image, feature extraction and
+ keypoint selection. Then we will get a new set of selected_boxes for
+ current scale. In the end, we concat the previous boxes with current
+ selected_boxes as the output.
+ Args:
+ scale_index: A valid index in the image_scales.
+ boxes: Box tensor with the shape of [N, 4].
+ features: Feature tensor with the shape of [N, depth].
+ scales: Scale tensor with the shape of [N].
+ scores: Attention score tensor with the shape of [N].
+
+ Returns:
+ scale_index: The next scale index for processing.
+ boxes: Concatenated box tensor with the shape of [K, 4]. K >= N.
+ features: Concatenated feature tensor with the shape of [K, depth].
+ scales: Concatenated scale tensor with the shape of [K].
+ scores: Concatenated score tensor with the shape of [K].
+ """
+ scale = tf.gather(image_scales, scale_index)
+ new_image_size = tf.dtypes.cast(
+ tf.round(original_image_shape_float * scale), tf.int32)
+ resized_image = tf.image.resize(image_tensor, new_image_size)
+
+ attention_prob, feature_map = attention_model_fn(resized_image)
+ attention_prob = tf.squeeze(attention_prob, axis=[0])
+ feature_map = tf.squeeze(feature_map, axis=[0])
+
+ rf_boxes = feature_extractor.CalculateReceptiveBoxes(
+ tf.shape(feature_map)[0],
+ tf.shape(feature_map)[1], rf, stride, padding)
+
+ # Re-project back to the original image space.
+ rf_boxes = tf.divide(rf_boxes, scale)
+ attention_prob = tf.reshape(attention_prob, [-1])
+ feature_map = tf.reshape(feature_map, [-1, feature_depth])
+
+ # Use attention score to select feature vectors.
+ indices = tf.reshape(tf.where(attention_prob >= abs_thres), [-1])
+ selected_boxes = tf.gather(rf_boxes, indices)
+ selected_features = tf.gather(feature_map, indices)
+ selected_scores = tf.gather(attention_prob, indices)
+ selected_scales = tf.ones_like(selected_scores, tf.float32) / scale
+
+ # Concat with the previous result from different scales.
+ boxes = tf.concat([boxes, selected_boxes], 0)
+ features = tf.concat([features, selected_features], 0)
+ scales = tf.concat([scales, selected_scales], 0)
+ scores = tf.concat([scores, selected_scores], 0)
+
+ return scale_index + 1, boxes, features, scales, scores
+
+ output_boxes = tf.zeros([0, 4], dtype=tf.float32)
+ output_features = tf.zeros([0, feature_depth], dtype=tf.float32)
+ output_scales = tf.zeros([0], dtype=tf.float32)
+ output_scores = tf.zeros([0], dtype=tf.float32)
+
+ # Process the first scale separately, the following scales will reuse the
+ # graph variables.
+ (_, output_boxes, output_features, output_scales,
+ output_scores) = _ProcessSingleScale(0, output_boxes, output_features,
+ output_scales, output_scores)
+
+ i = tf.constant(1, dtype=tf.int32)
+ num_scales = tf.shape(image_scales)[0]
+ keep_going = lambda j, b, f, scales, scores: tf.less(j, num_scales)
+
+ (_, output_boxes, output_features, output_scales,
+ output_scores) = tf.while_loop(
+ cond=keep_going,
+ body=_ProcessSingleScale,
+ loop_vars=[
+ i, output_boxes, output_features, output_scales, output_scores
+ ],
+ shape_invariants=[
+ i.get_shape(),
+ tf.TensorShape([None, 4]),
+ tf.TensorShape([None, feature_depth]),
+ tf.TensorShape([None]),
+ tf.TensorShape([None])
+ ],
+ back_prop=False)
+
+ feature_boxes = box_list.BoxList(output_boxes)
+ feature_boxes.add_field('features', output_features)
+ feature_boxes.add_field('scales', output_scales)
+ feature_boxes.add_field('scores', output_scores)
+
+ nms_max_boxes = tf.minimum(max_feature_num, feature_boxes.num_boxes())
+ final_boxes = box_list_ops.non_max_suppression(feature_boxes, iou,
+ nms_max_boxes)
+
+ return final_boxes.get(), final_boxes.get_field(
+ 'features'), final_boxes.get_field('scales'), tf.expand_dims(
+ final_boxes.get_field('scores'), 1)
diff --git a/models/research/delf/delf/python/training/model/resnet50.py b/models/research/delf/delf/python/training/model/resnet50.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c4d7c2f68dea12d74fcd32a8b52fd1285e92b59
--- /dev/null
+++ b/models/research/delf/delf/python/training/model/resnet50.py
@@ -0,0 +1,358 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ResNet50 backbone used in DELF model.
+
+Copied over from tensorflow/python/eager/benchmarks/resnet50/resnet50.py,
+because that code does not support dependencies.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import tensorflow as tf
+
+layers = tf.keras.layers
+
+
+class _IdentityBlock(tf.keras.Model):
+ """_IdentityBlock is the block that has no conv layer at shortcut.
+
+ Args:
+ kernel_size: the kernel size of middle conv layer at main path
+ filters: list of integers, the filters of 3 conv layer at main path
+ stage: integer, current stage label, used for generating layer names
+ block: 'a','b'..., current block label, used for generating layer names
+ data_format: data_format for the input ('channels_first' or
+ 'channels_last').
+ """
+
+ def __init__(self, kernel_size, filters, stage, block, data_format):
+ super(_IdentityBlock, self).__init__(name='')
+ filters1, filters2, filters3 = filters
+
+ conv_name_base = 'res' + str(stage) + block + '_branch'
+ bn_name_base = 'bn' + str(stage) + block + '_branch'
+ bn_axis = 1 if data_format == 'channels_first' else 3
+
+ self.conv2a = layers.Conv2D(
+ filters1, (1, 1), name=conv_name_base + '2a', data_format=data_format)
+ self.bn2a = layers.BatchNormalization(
+ axis=bn_axis, name=bn_name_base + '2a')
+
+ self.conv2b = layers.Conv2D(
+ filters2,
+ kernel_size,
+ padding='same',
+ data_format=data_format,
+ name=conv_name_base + '2b')
+ self.bn2b = layers.BatchNormalization(
+ axis=bn_axis, name=bn_name_base + '2b')
+
+ self.conv2c = layers.Conv2D(
+ filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format)
+ self.bn2c = layers.BatchNormalization(
+ axis=bn_axis, name=bn_name_base + '2c')
+
+ def call(self, input_tensor, training=False):
+ x = self.conv2a(input_tensor)
+ x = self.bn2a(x, training=training)
+ x = tf.nn.relu(x)
+
+ x = self.conv2b(x)
+ x = self.bn2b(x, training=training)
+ x = tf.nn.relu(x)
+
+ x = self.conv2c(x)
+ x = self.bn2c(x, training=training)
+
+ x += input_tensor
+ return tf.nn.relu(x)
+
+
+class _ConvBlock(tf.keras.Model):
+ """_ConvBlock is the block that has a conv layer at shortcut.
+
+ Args:
+ kernel_size: the kernel size of middle conv layer at main path
+ filters: list of integers, the filters of 3 conv layer at main path
+ stage: integer, current stage label, used for generating layer names
+ block: 'a','b'..., current block label, used for generating layer names
+ data_format: data_format for the input ('channels_first' or
+ 'channels_last').
+ strides: strides for the convolution. Note that from stage 3, the first
+ conv layer at main path is with strides=(2,2), and the shortcut should
+ have strides=(2,2) as well.
+ """
+
+ def __init__(self,
+ kernel_size,
+ filters,
+ stage,
+ block,
+ data_format,
+ strides=(2, 2)):
+ super(_ConvBlock, self).__init__(name='')
+ filters1, filters2, filters3 = filters
+
+ conv_name_base = 'res' + str(stage) + block + '_branch'
+ bn_name_base = 'bn' + str(stage) + block + '_branch'
+ bn_axis = 1 if data_format == 'channels_first' else 3
+
+ self.conv2a = layers.Conv2D(
+ filters1, (1, 1),
+ strides=strides,
+ name=conv_name_base + '2a',
+ data_format=data_format)
+ self.bn2a = layers.BatchNormalization(
+ axis=bn_axis, name=bn_name_base + '2a')
+
+ self.conv2b = layers.Conv2D(
+ filters2,
+ kernel_size,
+ padding='same',
+ name=conv_name_base + '2b',
+ data_format=data_format)
+ self.bn2b = layers.BatchNormalization(
+ axis=bn_axis, name=bn_name_base + '2b')
+
+ self.conv2c = layers.Conv2D(
+ filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format)
+ self.bn2c = layers.BatchNormalization(
+ axis=bn_axis, name=bn_name_base + '2c')
+
+ self.conv_shortcut = layers.Conv2D(
+ filters3, (1, 1),
+ strides=strides,
+ name=conv_name_base + '1',
+ data_format=data_format)
+ self.bn_shortcut = layers.BatchNormalization(
+ axis=bn_axis, name=bn_name_base + '1')
+
+ def call(self, input_tensor, training=False):
+ x = self.conv2a(input_tensor)
+ x = self.bn2a(x, training=training)
+ x = tf.nn.relu(x)
+
+ x = self.conv2b(x)
+ x = self.bn2b(x, training=training)
+ x = tf.nn.relu(x)
+
+ x = self.conv2c(x)
+ x = self.bn2c(x, training=training)
+
+ shortcut = self.conv_shortcut(input_tensor)
+ shortcut = self.bn_shortcut(shortcut, training=training)
+
+ x += shortcut
+ return tf.nn.relu(x)
+
+
+# pylint: disable=not-callable
+class ResNet50(tf.keras.Model):
+ """Instantiates the ResNet50 architecture.
+
+ Args:
+ data_format: format for the image. Either 'channels_first' or
+ 'channels_last'. 'channels_first' is typically faster on GPUs while
+ 'channels_last' is typically faster on CPUs. See
+ https://www.tensorflow.org/performance/performance_guide#data_formats
+ name: Prefix applied to names of variables created in the model.
+ include_top: whether to include the fully-connected layer at the top of the
+ network.
+ pooling: Optional pooling mode for feature extraction when `include_top` is
+ False. 'None' means that the output of the model will be the 4D tensor
+ output of the last convolutional layer. 'avg' means that global average
+ pooling will be applied to the output of the last convolutional layer, and
+ thus the output of the model will be a 2D tensor. 'max' means that global
+ max pooling will be applied.
+ block3_strides: whether to add a stride of 2 to block3 to make it compatible
+ with tf.slim ResNet implementation.
+ average_pooling: whether to do average pooling of block4 features before
+ global pooling.
+ classes: optional number of classes to classify images into, only to be
+ specified if `include_top` is True.
+
+ Raises:
+ ValueError: in case of invalid argument for data_format.
+ """
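+
+  # Illustrative usage with intermediate featuremaps (as the DELF model uses
+  # this backbone):
+  #   model = ResNet50('channels_last', include_top=False, pooling='avg',
+  #                    block3_strides=False, average_pooling=False)
+  #   blocks = {}
+  #   global_feat = model(tf.zeros([1, 321, 321, 3]), training=False,
+  #                       intermediates_dict=blocks)
+  #   # blocks now maps 'block0'...'block4' to intermediate featuremaps.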
+
+ def __init__(self,
+ data_format,
+ name='',
+ include_top=True,
+ pooling=None,
+ block3_strides=False,
+ average_pooling=True,
+ classes=1000):
+ super(ResNet50, self).__init__(name=name)
+
+ valid_channel_values = ('channels_first', 'channels_last')
+ if data_format not in valid_channel_values:
+ raise ValueError('Unknown data_format: %s. Valid values: %s' %
+ (data_format, valid_channel_values))
+ self.include_top = include_top
+ self.block3_strides = block3_strides
+ self.average_pooling = average_pooling
+ self.pooling = pooling
+
+ def conv_block(filters, stage, block, strides=(2, 2)):
+ return _ConvBlock(
+ 3,
+ filters,
+ stage=stage,
+ block=block,
+ data_format=data_format,
+ strides=strides)
+
+ def id_block(filters, stage, block):
+ return _IdentityBlock(
+ 3, filters, stage=stage, block=block, data_format=data_format)
+
+ self.conv1 = layers.Conv2D(
+ 64, (7, 7),
+ strides=(2, 2),
+ data_format=data_format,
+ padding='same',
+ name='conv1')
+ bn_axis = 1 if data_format == 'channels_first' else 3
+ self.bn_conv1 = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')
+ self.max_pool = layers.MaxPooling2D((3, 3),
+ strides=(2, 2),
+ data_format=data_format)
+
+ self.l2a = conv_block([64, 64, 256], stage=2, block='a', strides=(1, 1))
+ self.l2b = id_block([64, 64, 256], stage=2, block='b')
+ self.l2c = id_block([64, 64, 256], stage=2, block='c')
+
+ self.l3a = conv_block([128, 128, 512], stage=3, block='a')
+ self.l3b = id_block([128, 128, 512], stage=3, block='b')
+ self.l3c = id_block([128, 128, 512], stage=3, block='c')
+ self.l3d = id_block([128, 128, 512], stage=3, block='d')
+
+ self.l4a = conv_block([256, 256, 1024], stage=4, block='a')
+ self.l4b = id_block([256, 256, 1024], stage=4, block='b')
+ self.l4c = id_block([256, 256, 1024], stage=4, block='c')
+ self.l4d = id_block([256, 256, 1024], stage=4, block='d')
+ self.l4e = id_block([256, 256, 1024], stage=4, block='e')
+ self.l4f = id_block([256, 256, 1024], stage=4, block='f')
+
+ # Striding layer that can be used on top of block3 to produce feature maps
+ # with the same resolution as the TF-Slim implementation.
+ if self.block3_strides:
+ self.subsampling_layer = layers.MaxPooling2D((1, 1),
+ strides=(2, 2),
+ data_format=data_format)
+ self.l5a = conv_block([512, 512, 2048],
+ stage=5,
+ block='a',
+ strides=(1, 1))
+ else:
+ self.l5a = conv_block([512, 512, 2048], stage=5, block='a')
+ self.l5b = id_block([512, 512, 2048], stage=5, block='b')
+ self.l5c = id_block([512, 512, 2048], stage=5, block='c')
+
+ self.avg_pool = layers.AveragePooling2D((7, 7),
+ strides=(7, 7),
+ data_format=data_format)
+
+ if self.include_top:
+ self.flatten = layers.Flatten()
+ self.fc1000 = layers.Dense(classes, name='fc1000')
+ else:
+ reduction_indices = [1, 2] if data_format == 'channels_last' else [2, 3]
+ reduction_indices = tf.constant(reduction_indices)
+ if pooling == 'avg':
+ self.global_pooling = functools.partial(
+ tf.reduce_mean, axis=reduction_indices, keepdims=False)
+ elif pooling == 'max':
+ self.global_pooling = functools.partial(
+ tf.reduce_max, axis=reduction_indices, keepdims=False)
+ else:
+ self.global_pooling = None
+
+ def call(self, inputs, training=True, intermediates_dict=None):
+ """Call the ResNet50 model.
+
+ Args:
+ inputs: Images to compute features for.
+ training: Whether model is in training phase.
+ intermediates_dict: `None` or dictionary. If not None, accumulate feature
+        maps from intermediate blocks into the dictionary.
+
+ Returns:
+ Tensor with featuremap.
+ """
+
+ x = self.conv1(inputs)
+ x = self.bn_conv1(x, training=training)
+ x = tf.nn.relu(x)
+ if intermediates_dict is not None:
+ intermediates_dict['block0'] = x
+
+ x = self.max_pool(x)
+ if intermediates_dict is not None:
+ intermediates_dict['block0mp'] = x
+
+ # Block 1 (equivalent to "conv2" in Resnet paper).
+ x = self.l2a(x, training=training)
+ x = self.l2b(x, training=training)
+ x = self.l2c(x, training=training)
+ if intermediates_dict is not None:
+ intermediates_dict['block1'] = x
+
+ # Block 2 (equivalent to "conv3" in Resnet paper).
+ x = self.l3a(x, training=training)
+ x = self.l3b(x, training=training)
+ x = self.l3c(x, training=training)
+ x = self.l3d(x, training=training)
+ if intermediates_dict is not None:
+ intermediates_dict['block2'] = x
+
+ # Block 3 (equivalent to "conv4" in Resnet paper).
+ x = self.l4a(x, training=training)
+ x = self.l4b(x, training=training)
+ x = self.l4c(x, training=training)
+ x = self.l4d(x, training=training)
+ x = self.l4e(x, training=training)
+ x = self.l4f(x, training=training)
+
+ if self.block3_strides:
+ x = self.subsampling_layer(x)
+ if intermediates_dict is not None:
+ intermediates_dict['block3'] = x
+ else:
+ if intermediates_dict is not None:
+ intermediates_dict['block3'] = x
+
+ x = self.l5a(x, training=training)
+ x = self.l5b(x, training=training)
+ x = self.l5c(x, training=training)
+
+ if self.average_pooling:
+ x = self.avg_pool(x)
+ if intermediates_dict is not None:
+ intermediates_dict['block4'] = x
+ else:
+ if intermediates_dict is not None:
+ intermediates_dict['block4'] = x
+
+ if self.include_top:
+ return self.fc1000(self.flatten(x))
+ elif self.global_pooling:
+ return self.global_pooling(x)
+ else:
+ return x
diff --git a/models/research/delf/delf/python/training/train.py b/models/research/delf/delf/python/training/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b0d0a6cdaea696398ae50fcdadbead91899539f
--- /dev/null
+++ b/models/research/delf/delf/python/training/train.py
@@ -0,0 +1,442 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Training script for DELF on Google Landmarks Dataset.
+
+Script to train DELF using classification loss on Google Landmarks Dataset
+using MirroredStrategy so that it can run on multiple GPUs.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf
+import tensorflow_probability as tfp
+
+# Placeholder for internal import. Do not remove this line.
+from delf.python.training.datasets import googlelandmarks as gld
+from delf.python.training.model import delf_model
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_boolean('debug', False, 'Debug mode.')
+flags.DEFINE_string('logdir', '/tmp/delf', 'TensorBoard logdir.')
+flags.DEFINE_string('train_file_pattern', '/tmp/data/train*',
+ 'File pattern of training dataset files.')
+flags.DEFINE_string('validation_file_pattern', '/tmp/data/validation*',
+ 'File pattern of validation dataset files.')
+flags.DEFINE_enum('dataset_version', 'gld_v1',
+ ['gld_v1', 'gld_v2', 'gld_v2_clean'],
+                  'Google Landmarks dataset version, used to determine the '
+ 'number of classes.')
+flags.DEFINE_integer('seed', 0, 'Seed to training dataset.')
+flags.DEFINE_float('initial_lr', 0.001, 'Initial learning rate.')
+flags.DEFINE_integer('batch_size', 32, 'Global batch size.')
+flags.DEFINE_integer('max_iters', 500000, 'Maximum iterations.')
+flags.DEFINE_boolean('block3_strides', False, 'Whether to use block3_strides.')
+flags.DEFINE_boolean('use_augmentation', True,
+ 'Whether to use ImageNet style augmentation.')
+
+
+def _record_accuracy(metric, logits, labels):
+ """Record accuracy given predicted logits and ground-truth labels."""
+ softmax_probabilities = tf.keras.layers.Softmax()(logits)
+ metric.update_state(labels, softmax_probabilities)
+
+
+def _attention_summaries(scores, global_step):
+ """Record statistics of the attention score."""
+ tf.summary.scalar('attention/max', tf.reduce_max(scores), step=global_step)
+ tf.summary.scalar('attention/min', tf.reduce_min(scores), step=global_step)
+ tf.summary.scalar('attention/mean', tf.reduce_mean(scores), step=global_step)
+ tf.summary.scalar(
+ 'attention/percent_25',
+ tfp.stats.percentile(scores, 25.0),
+ step=global_step)
+ tf.summary.scalar(
+ 'attention/percent_50',
+ tfp.stats.percentile(scores, 50.0),
+ step=global_step)
+ tf.summary.scalar(
+ 'attention/percent_75',
+ tfp.stats.percentile(scores, 75.0),
+ step=global_step)
+
+
+def create_model(num_classes):
+ """Define DELF model, and initialize classifiers."""
+ model = delf_model.Delf(block3_strides=FLAGS.block3_strides, name='DELF')
+ model.init_classifiers(num_classes)
+ return model
+
+
+def _learning_rate_schedule(global_step_value, max_iters, initial_lr):
+ """Calculates learning_rate with linear decay.
+
+ Args:
+ global_step_value: int, global step.
+ max_iters: int, maximum iterations.
+ initial_lr: float, initial learning rate.
+
+ Returns:
+ lr: float, learning rate.
+ """
+ lr = initial_lr * (1.0 - global_step_value / max_iters)
+ return lr
+
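+# For example, with the default initial_lr=0.001 and max_iters=500000 this
+# schedule yields lr=0.001 at step 0, lr=0.0005 at step 250000, and decays
+# linearly towards 0 as the step approaches max_iters.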
+
+def main(argv):
+ if len(argv) > 1:
+ raise app.UsageError('Too many command-line arguments.')
+
+ #-------------------------------------------------------------
+ # Log flags used.
+ logging.info('Running training script with\n')
+ logging.info('logdir= %s', FLAGS.logdir)
+ logging.info('initial_lr= %f', FLAGS.initial_lr)
+ logging.info('block3_strides= %s', str(FLAGS.block3_strides))
+
+ # ------------------------------------------------------------
+ # Create the strategy.
+ strategy = tf.distribute.MirroredStrategy()
+ logging.info('Number of devices: %d', strategy.num_replicas_in_sync)
+ if FLAGS.debug:
+ print('Number of devices:', strategy.num_replicas_in_sync)
+
+ max_iters = FLAGS.max_iters
+ global_batch_size = FLAGS.batch_size
+ image_size = 321
+ num_eval = 1000
+ report_interval = 100
+ eval_interval = 1000
+ save_interval = 20000
+
+ initial_lr = FLAGS.initial_lr
+
+ clip_val = tf.constant(10.0)
+
+ if FLAGS.debug:
+ global_batch_size = 4
+ max_iters = 4
+ num_eval = 1
+ save_interval = 1
+ report_interval = 1
+
+ # Determine the number of classes based on the version of the dataset.
+ gld_info = gld.GoogleLandmarksInfo()
+ num_classes = gld_info.num_classes[FLAGS.dataset_version]
+
+ # ------------------------------------------------------------
+ # Create the distributed train/validation sets.
+ train_dataset = gld.CreateDataset(
+ file_pattern=FLAGS.train_file_pattern,
+ batch_size=global_batch_size,
+ image_size=image_size,
+ augmentation=FLAGS.use_augmentation,
+ seed=FLAGS.seed)
+ validation_dataset = gld.CreateDataset(
+ file_pattern=FLAGS.validation_file_pattern,
+ batch_size=global_batch_size,
+ image_size=image_size,
+ augmentation=False,
+ seed=FLAGS.seed)
+
+ train_iterator = strategy.make_dataset_iterator(train_dataset)
+ validation_iterator = strategy.make_dataset_iterator(validation_dataset)
+
+ train_iterator.initialize()
+ validation_iterator.initialize()
+
+ # Create a checkpoint directory to store the checkpoints.
+ checkpoint_prefix = os.path.join(FLAGS.logdir, 'delf_tf2-ckpt')
+
+ # ------------------------------------------------------------
+ # Finally, we do everything in distributed scope.
+ with strategy.scope():
+ # Compute loss.
+ # Set reduction to `none` so we can do the reduction afterwards and divide
+ # by global batch size.
+ loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
+ from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
+
+ def compute_loss(labels, predictions):
+ per_example_loss = loss_object(labels, predictions)
+ return tf.nn.compute_average_loss(
+ per_example_loss, global_batch_size=global_batch_size)
+
+ # Set up metrics.
+ desc_validation_loss = tf.keras.metrics.Mean(name='desc_validation_loss')
+ attn_validation_loss = tf.keras.metrics.Mean(name='attn_validation_loss')
+ desc_train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
+ name='desc_train_accuracy')
+ attn_train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
+ name='attn_train_accuracy')
+ desc_validation_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
+ name='desc_validation_accuracy')
+ attn_validation_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
+ name='attn_validation_accuracy')
+
+ # ------------------------------------------------------------
+ # Setup DELF model and optimizer.
+ model = create_model(num_classes)
+ logging.info('Model, datasets loaded.\nnum_classes= %d', num_classes)
+
+ optimizer = tf.keras.optimizers.SGD(learning_rate=initial_lr, momentum=0.9)
+
+ # Setup summary writer.
+ summary_writer = tf.summary.create_file_writer(
+ os.path.join(FLAGS.logdir, 'train_logs'), flush_millis=10000)
+
+ # Setup checkpoint directory.
+ checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
+ manager = tf.train.CheckpointManager(
+ checkpoint, checkpoint_prefix, max_to_keep=3)
+
+ # ------------------------------------------------------------
+ # Train step to run on one GPU.
+ def train_step(inputs):
+ """Train one batch."""
+ images, labels = inputs
+ # Temporary workaround to avoid some corrupted labels.
+ labels = tf.clip_by_value(labels, 0, model.num_classes)
+
+ global_step = optimizer.iterations
+ tf.summary.scalar(
+ 'image_range/max', tf.reduce_max(images), step=global_step)
+ tf.summary.scalar(
+ 'image_range/min', tf.reduce_min(images), step=global_step)
+
+ def _backprop_loss(tape, loss, weights):
+ """Backpropogate losses using clipped gradients.
+
+ Args:
+ tape: gradient tape.
+ loss: scalar Tensor, loss value.
+ weights: keras model weights.
+ """
+ gradients = tape.gradient(loss, weights)
+ clipped, _ = tf.clip_by_global_norm(gradients, clip_norm=clip_val)
+ optimizer.apply_gradients(zip(clipped, weights))
+
+ # Record gradients and loss through backbone.
+ with tf.GradientTape() as desc_tape:
+
+ blocks = {}
+ prelogits = model.backbone(
+ images, intermediates_dict=blocks, training=True)
+
+ # Report sparsity.
+ activations_zero_fractions = {
+ 'sparsity/%s' % k: tf.nn.zero_fraction(v)
+ for k, v in blocks.items()
+ }
+ for k, v in activations_zero_fractions.items():
+ tf.summary.scalar(k, v, step=global_step)
+
+ # Apply descriptor classifier.
+ logits = model.desc_classification(prelogits)
+
+ desc_loss = compute_loss(labels, logits)
+
+ # Backprop only through backbone weights.
+ _backprop_loss(desc_tape, desc_loss, model.desc_trainable_weights)
+
+ # Record descriptor train accuracy.
+ _record_accuracy(desc_train_accuracy, logits, labels)
+
+ # Record gradients and loss through attention block.
+ with tf.GradientTape() as attn_tape:
+ block3 = blocks['block3'] # pytype: disable=key-error
+
+ # Stopping gradients according to DELG paper:
+ # (https://arxiv.org/abs/2001.05027).
+ block3 = tf.stop_gradient(block3)
+
+ prelogits, scores, _ = model.attention(block3, training=True)
+ _attention_summaries(scores, global_step)
+
+ # Apply attention block classifier.
+ logits = model.attn_classification(prelogits)
+
+ attn_loss = compute_loss(labels, logits)
+
+ # Backprop only through attention weights.
+ _backprop_loss(attn_tape, attn_loss, model.attn_trainable_weights)
+
+ # Record attention train accuracy.
+ _record_accuracy(attn_train_accuracy, logits, labels)
+
+ return desc_loss, attn_loss
+
+ # ------------------------------------------------------------
+ def validation_step(inputs):
+ """Validate one batch."""
+ images, labels = inputs
+ labels = tf.clip_by_value(labels, 0, model.num_classes)
+
+ # Get descriptor predictions.
+ blocks = {}
+ prelogits = model.backbone(
+ images, intermediates_dict=blocks, training=False)
+ logits = model.desc_classification(prelogits, training=False)
+ softmax_probabilities = tf.keras.layers.Softmax()(logits)
+
+ validation_loss = loss_object(labels, logits)
+ desc_validation_loss.update_state(validation_loss)
+ desc_validation_accuracy.update_state(labels, softmax_probabilities)
+
+ # Get attention predictions.
+ block3 = blocks['block3'] # pytype: disable=key-error
+ prelogits, _, _ = model.attention(block3, training=False)
+
+ logits = model.attn_classification(prelogits, training=False)
+ softmax_probabilities = tf.keras.layers.Softmax()(logits)
+
+ validation_loss = loss_object(labels, logits)
+ attn_validation_loss.update_state(validation_loss)
+ attn_validation_accuracy.update_state(labels, softmax_probabilities)
+
+ return desc_validation_accuracy.result(), attn_validation_accuracy.result(
+ )
+
+ # `run` replicates the provided computation and runs it
+ # with the distributed input.
+ @tf.function
+ def distributed_train_step(dataset_inputs):
+ """Get the actual losses."""
+      # Each of desc/attn is the per-replica cross-entropy loss.
+ desc_per_replica_loss, attn_per_replica_loss = (
+ strategy.run(train_step, args=(dataset_inputs,)))
+
+ # Reduce over the replicas.
+ desc_global_loss = strategy.reduce(
+ tf.distribute.ReduceOp.SUM, desc_per_replica_loss, axis=None)
+ attn_global_loss = strategy.reduce(
+ tf.distribute.ReduceOp.SUM, attn_per_replica_loss, axis=None)
+
+ return desc_global_loss, attn_global_loss
+
+ @tf.function
+ def distributed_validation_step(dataset_inputs):
+ return strategy.run(validation_step, args=(dataset_inputs,))
+
+ # ------------------------------------------------------------
+ # *** TRAIN LOOP ***
+ with summary_writer.as_default():
+ with tf.summary.record_if(
+ tf.math.equal(0, optimizer.iterations % report_interval)):
+
+ global_step_value = optimizer.iterations.numpy()
+ while global_step_value < max_iters:
+
+ # input_batch : images(b, h, w, c), labels(b,).
+ try:
+ input_batch = train_iterator.get_next()
+ except tf.errors.OutOfRangeError:
+ # Break if we run out of data in the dataset.
+ logging.info('Stopping training at global step %d, no more data',
+ global_step_value)
+ break
+
+ # Set learning rate for optimizer to use.
+ global_step = optimizer.iterations
+ global_step_value = global_step.numpy()
+
+ learning_rate = _learning_rate_schedule(global_step_value, max_iters,
+ initial_lr)
+ optimizer.learning_rate = learning_rate
+ tf.summary.scalar(
+ 'learning_rate', optimizer.learning_rate, step=global_step)
+
+ # Run the training step over num_gpu gpus.
+ desc_dist_loss, attn_dist_loss = distributed_train_step(input_batch)
+
+ # Log losses and accuracies to tensorboard.
+ tf.summary.scalar(
+ 'loss/desc/crossentropy', desc_dist_loss, step=global_step)
+ tf.summary.scalar(
+ 'loss/attn/crossentropy', attn_dist_loss, step=global_step)
+ tf.summary.scalar(
+ 'train_accuracy/desc',
+ desc_train_accuracy.result(),
+ step=global_step)
+ tf.summary.scalar(
+ 'train_accuracy/attn',
+ attn_train_accuracy.result(),
+ step=global_step)
+
+ # Print to console if running locally.
+ if FLAGS.debug:
+ if global_step_value % report_interval == 0:
+ print(global_step.numpy())
+ print('desc:', desc_dist_loss.numpy())
+ print('attn:', attn_dist_loss.numpy())
+
+ # Validate once in {eval_interval*n, n \in N} steps.
+ if global_step_value % eval_interval == 0:
+ for i in range(num_eval):
+ try:
+ validation_batch = validation_iterator.get_next()
+ desc_validation_result, attn_validation_result = (
+ distributed_validation_step(validation_batch))
+ except tf.errors.OutOfRangeError:
+ logging.info('Stopping eval at batch %d, no more data', i)
+ break
+
+ # Log validation results to tensorboard.
+ tf.summary.scalar(
+ 'validation/desc', desc_validation_result, step=global_step)
+ tf.summary.scalar(
+ 'validation/attn', attn_validation_result, step=global_step)
+
+ logging.info('\nValidation(%f)\n', global_step_value)
+ logging.info(': desc: %f\n', desc_validation_result.numpy())
+ logging.info(': attn: %f\n', attn_validation_result.numpy())
+ # Print to console.
+ if FLAGS.debug:
+ print('Validation: desc:', desc_validation_result.numpy())
+ print(' : attn:', attn_validation_result.numpy())
+
+ # Save checkpoint once (each save_interval*n, n \in N) steps.
+ if global_step_value % save_interval == 0:
+ save_path = manager.save()
+            logging.info('Saved (%d) at %s', global_step_value, save_path)
+
+ file_path = '%s/delf_weights' % FLAGS.logdir
+ model.save_weights(file_path, save_format='tf')
+            logging.info('Saved weights (%d) at %s', global_step_value, file_path)
+
+ # Reset metrics for next step.
+ desc_train_accuracy.reset_states()
+ attn_train_accuracy.reset_states()
+ desc_validation_loss.reset_states()
+ attn_validation_loss.reset_states()
+ desc_validation_accuracy.reset_states()
+ attn_validation_accuracy.reset_states()
+
+ if global_step.numpy() > max_iters:
+ break
+
+ logging.info('Finished training for %d steps.', max_iters)
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/models/research/delf/delf/python/utils.py b/models/research/delf/delf/python/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbab2d8c7f1f423991c98851ad509e4684b738b7
--- /dev/null
+++ b/models/research/delf/delf/python/utils.py
@@ -0,0 +1,41 @@
+# Copyright 2020 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Helper functions for DELF."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from PIL import Image
+from PIL import ImageFile
+import tensorflow as tf
+
+# To avoid PIL crashing for truncated (corrupted) images.
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+def RgbLoader(path):
+ """Helper function to read image with PIL.
+
+ Args:
+ path: Path to image to be loaded.
+
+ Returns:
+ PIL image in RGB format.
+ """
+ with tf.io.gfile.GFile(path, 'rb') as f:
+ img = Image.open(f)
+ return img.convert('RGB')
+
diff --git a/models/research/delf/setup.py b/models/research/delf/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..7aec6f0065a476dbc83b28145916f2981df4bd82
--- /dev/null
+++ b/models/research/delf/setup.py
@@ -0,0 +1,37 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Setup script for delf."""
+
+from setuptools import setup, find_packages
+
+install_requires = [
+ 'absl-py >= 0.7.1',
+ 'protobuf >= 3.8.0',
+ 'pandas >= 0.24.2',
+ 'numpy >= 1.16.1',
+ 'scipy >= 1.2.2',
+ 'tensorflow >= 2.0.0b1',
+ 'tf_slim >= 1.1',
+ 'tensorflow_probability >= 0.9.0',
+]
+
+setup(
+ name='delf',
+ version='2.0',
+ include_package_data=True,
+ packages=find_packages(),
+ install_requires=install_requires,
+ description='DELF (DEep Local Features)',
+)
diff --git a/models/research/domain_adaptation/README.md b/models/research/domain_adaptation/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e8a2b83794f11ed3711e6bc26254a90cb5469440
--- /dev/null
+++ b/models/research/domain_adaptation/README.md
@@ -0,0 +1,124 @@
+## Introduction
+This is the code used for two domain adaptation papers.
+
+The `domain_separation` directory contains code for the "Domain Separation
+Networks" paper by Bousmalis K., Trigeorgis G., et al. which was presented at
+NIPS 2016. The paper can be found here: https://arxiv.org/abs/1608.06019.
+
+The `pixel_domain_adaptation` directory contains the code used for the
+"Unsupervised Pixel-Level Domain Adaptation with Generative Adversarial
+Networks" paper by Bousmalis K., et al. (presented at CVPR 2017). The paper can
+be found here: https://arxiv.org/abs/1612.05424. PixelDA aims to perform domain
+adaptation by transferring the visual style of the target domain (which has few
+or no labels) to a source domain (which has many labels). This is accomplished
+using a Generative Adversarial Network (GAN).
+
+### Other implementations
+* [Simplified-DSN](https://github.com/AmirHussein96/Simplified-DSN):
+ An unofficial implementation of the [Domain Separation Networks paper](https://arxiv.org/abs/1608.06019).
+
+## Contact
+The domain separation code was open-sourced
+by [Konstantinos Bousmalis](https://github.com/bousmalis)
+(konstantinos@google.com), while the pixel level domain adaptation code was
+open-sourced by [David Dohan](https://github.com/dmrd) (ddohan@google.com).
+
+## Installation
+You will need to have the following installed on your machine before trying out the DSN code.
+
+* TensorFlow 1.x: https://www.tensorflow.org/install/
+* Bazel: https://bazel.build/
+
+## Initial setup
+In order to run the MNIST to MNIST-M experiments, you will need to set the
+data directory:
+
+```
+$ export DSN_DATA_DIR=/your/dir
+```
+
+Add models and models/slim to your `$PYTHONPATH` (assumes $PWD is /models):
+
+```
+$ export PYTHONPATH=$PYTHONPATH:$PWD:$PWD/slim
+```
+
+## Getting the datasets
+
+You can fetch the MNIST data by running
+
+```
+$ bazel run slim:download_and_convert_data -- --dataset_dir $DSN_DATA_DIR --dataset_name=mnist
+```
+
+The MNIST-M dataset is available online [here](http://bit.ly/2nrlUAJ). Once it is downloaded and extracted into your data directory, create TFRecord files by running:
+```
+$ bazel run domain_adaptation/datasets:download_and_convert_mnist_m -- --dataset_dir $DSN_DATA_DIR
+```
+
+## Running PixelDA from MNIST to MNIST-M
+You can run PixelDA as follows (use TensorBoard to examine the results):
+
+```
+$ bazel run domain_adaptation/pixel_domain_adaptation:pixelda_train -- --dataset_dir $DSN_DATA_DIR --source_dataset mnist --target_dataset mnist_m
+```
+
+And evaluation as:
+```
+$ bazel run domain_adaptation/pixel_domain_adaptation:pixelda_eval -- --dataset_dir $DSN_DATA_DIR --source_dataset mnist --target_dataset mnist_m --target_split_name test
+```
+
+The MNIST-M results in the paper were run with the following hparams flag:
+```
+--hparams arch=resnet,domain_loss_weight=0.135603587834,num_training_examples=16000000,style_transfer_loss_weight=0.0113173311334,task_loss_in_g_weight=0.0100959947002,task_tower=mnist,task_tower_in_g_step=true
+```
+
+### A note on terminology/language of the code:
+
+The components of the network can be grouped into two parts
+which correspond to elements which are jointly optimized: The generator
+component and the discriminator component.
+
+The generator component takes either an image or noise vector and produces an
+output image.
+
+The discriminator component takes the generated images and the target images
+and attempts to discriminate between them.
+
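+The following is a minimal, purely illustrative sketch (it is not the code in
+`pixel_domain_adaptation`) of what these two jointly optimized components look
+like when written with `tf.contrib.slim`; all layer sizes and function names
+are assumptions made for the example.
+
+```
+import tensorflow as tf
+
+slim = tf.contrib.slim
+
+
+def toy_generator(source_images, noise, scope='generator'):
+  """Maps a source image plus a noise vector to a target-styled image."""
+  with tf.variable_scope(scope):
+    batch_size, height, width, _ = source_images.get_shape().as_list()
+    # Broadcast the noise vector to a spatial map, concatenate it with the
+    # source image, and decode back to an RGB image in [-1, 1].
+    noise_maps = tf.tile(
+        tf.reshape(noise, [batch_size, 1, 1, -1]), [1, height, width, 1])
+    net = tf.concat([source_images, noise_maps], axis=3)
+    net = slim.conv2d(net, 64, [3, 3])
+    return slim.conv2d(net, 3, [1, 1], activation_fn=tf.nn.tanh)
+
+
+def toy_discriminator(images, scope='discriminator', reuse=None):
+  """Produces a single logit: real target image vs. generated image."""
+  with tf.variable_scope(scope, reuse=reuse):
+    net = slim.conv2d(images, 64, [3, 3], stride=2)
+    net = slim.flatten(net)
+    return slim.fully_connected(net, 1, activation_fn=None)
+```
+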
+## Running DSN code for adapting MNIST to MNIST-M
+
+First, build the binaries with Bazel:
+
+```
+$ bazel build -c opt domain_adaptation/domain_separation/...
+```
+
+You can then train with the following command:
+
+```
+$ ./bazel-bin/domain_adaptation/domain_separation/dsn_train \
+ --similarity_loss=dann_loss \
+ --basic_tower=dann_mnist \
+ --source_dataset=mnist \
+ --target_dataset=mnist_m \
+ --learning_rate=0.0117249 \
+ --gamma_weight=0.251175 \
+ --weight_decay=1e-6 \
+ --layers_to_regularize=fc3 \
+ --nouse_separation \
+ --master="" \
+ --dataset_dir=${DSN_DATA_DIR} \
+ -v --use_logging
+```
+
+Evaluation can be invoked with the following command:
+
+```
+$ ./bazel-bin/domain_adaptation/domain_separation/dsn_eval \
+ -v --dataset mnist_m --split test --num_examples=9001 \
+ --dataset_dir=${DSN_DATA_DIR}
+```
diff --git a/models/research/domain_adaptation/WORKSPACE b/models/research/domain_adaptation/WORKSPACE
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/domain_adaptation/__init__.py b/models/research/domain_adaptation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/domain_adaptation/datasets/BUILD b/models/research/domain_adaptation/datasets/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..067a79374fbcedaa6fcd90293e5365aaad4c18c6
--- /dev/null
+++ b/models/research/domain_adaptation/datasets/BUILD
@@ -0,0 +1,45 @@
+# Domain Adaptation Scenarios Datasets
+
+package(
+ default_visibility = [
+ ":internal",
+ ],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package_group(
+ name = "internal",
+ packages = [
+ "//domain_adaptation/...",
+ ],
+)
+
+py_library(
+ name = "dataset_factory",
+ srcs = ["dataset_factory.py"],
+ deps = [
+ ":mnist_m",
+ "//slim:mnist",
+ ],
+)
+
+py_binary(
+ name = "download_and_convert_mnist_m",
+ srcs = ["download_and_convert_mnist_m.py"],
+ deps = [
+
+ "//slim:dataset_utils",
+ ],
+)
+
+py_binary(
+ name = "mnist_m",
+ srcs = ["mnist_m.py"],
+ deps = [
+
+ "//slim:dataset_utils",
+ ],
+)
diff --git a/models/research/domain_adaptation/datasets/__init__.py b/models/research/domain_adaptation/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/domain_adaptation/datasets/dataset_factory.py b/models/research/domain_adaptation/datasets/dataset_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ca1b41c412a78d25053fc786c8f81072fe90adb
--- /dev/null
+++ b/models/research/domain_adaptation/datasets/dataset_factory.py
@@ -0,0 +1,107 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A factory-pattern class which returns image/label pairs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+import tensorflow as tf
+
+from slim.datasets import mnist
+from domain_adaptation.datasets import mnist_m
+
+slim = tf.contrib.slim
+
+
+def get_dataset(dataset_name,
+ split_name,
+ dataset_dir,
+ file_pattern=None,
+ reader=None):
+ """Given a dataset name and a split_name returns a Dataset.
+
+ Args:
+ dataset_name: String, the name of the dataset.
+ split_name: A train/test split name.
+ dataset_dir: The directory where the dataset files are stored.
+ file_pattern: The file pattern to use for matching the dataset source files.
+ reader: The subclass of tf.ReaderBase. If left as `None`, then the default
+ reader defined by each dataset is used.
+
+ Returns:
+ A tf-slim `Dataset` class.
+
+ Raises:
+ ValueError: if `dataset_name` isn't recognized.
+ """
+ dataset_name_to_module = {'mnist': mnist, 'mnist_m': mnist_m}
+ if dataset_name not in dataset_name_to_module:
+ raise ValueError('Unknown dataset name: %s.' % dataset_name)
+
+ return dataset_name_to_module[dataset_name].get_split(split_name, dataset_dir,
+ file_pattern, reader)
+
+
+def provide_batch(dataset_name, split_name, dataset_dir, num_readers,
+ batch_size, num_preprocessing_threads):
+ """Provides a batch of images and corresponding labels.
+
+ Args:
+ dataset_name: String, the name of the dataset.
+ split_name: A train/test split name.
+ dataset_dir: The directory where the dataset files are stored.
+ num_readers: The number of readers used by DatasetDataProvider.
+ batch_size: The size of the batch requested.
+ num_preprocessing_threads: The number of preprocessing threads for
+ tf.train.batch.
+
+ Returns:
+ A batch of
+ images: tensor of [batch_size, height, width, channels].
+ labels: dictionary of labels.
+ """
+ dataset = get_dataset(dataset_name, split_name, dataset_dir)
+ provider = slim.dataset_data_provider.DatasetDataProvider(
+ dataset,
+ num_readers=num_readers,
+ common_queue_capacity=20 * batch_size,
+ common_queue_min=10 * batch_size)
+ [image, label] = provider.get(['image', 'label'])
+
+ # Convert images to float32
+ image = tf.image.convert_image_dtype(image, tf.float32)
+ image -= 0.5
+ image *= 2
+
+ # Load the data.
+ labels = {}
+ images, labels['classes'] = tf.train.batch(
+ [image, label],
+ batch_size=batch_size,
+ num_threads=num_preprocessing_threads,
+ capacity=5 * batch_size)
+ labels['classes'] = slim.one_hot_encoding(labels['classes'],
+ dataset.num_classes)
+
+ # Convert mnist to RGB and 32x32 so that it can match mnist_m.
+ if dataset_name == 'mnist':
+ images = tf.image.grayscale_to_rgb(images)
+ images = tf.image.resize_images(images, [32, 32])
+ return images, labels
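+
+# Illustrative usage (all values below are examples only); dsn_train.py and
+# dsn_eval.py call this function through provide_batch_fn():
+#
+#   images, labels = provide_batch(
+#       'mnist_m', 'train', dataset_dir='/tmp/dsn_data',
+#       num_readers=4, batch_size=32, num_preprocessing_threads=4)
+#   # images: [32, 32, 32, 3] float32 scaled to [-1, 1];
+#   # labels['classes']: [32, 10] one-hot class labels.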
diff --git a/models/research/domain_adaptation/datasets/download_and_convert_mnist_m.py b/models/research/domain_adaptation/datasets/download_and_convert_mnist_m.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b5004d3d8aaf54656389e517c50f38299714bc7
--- /dev/null
+++ b/models/research/domain_adaptation/datasets/download_and_convert_mnist_m.py
@@ -0,0 +1,237 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Downloads and converts MNIST-M data to TFRecords of TF-Example protos.
+
+This module downloads the MNIST-M data, uncompresses it, reads the files
+that make up the MNIST-M data and creates two TFRecord datasets: one for train
+and one for test. Each TFRecord dataset is comprised of a set of TF-Example
+protocol buffers, each of which contain a single image and label.
+
+The script should take about a minute to run.
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import random
+import sys
+
+# Dependency imports
+import numpy as np
+from six.moves import urllib
+import tensorflow as tf
+
+from slim.datasets import dataset_utils
+
+tf.app.flags.DEFINE_string(
+ 'dataset_dir', None,
+ 'The directory where the output TFRecords and temporary files are saved.')
+
+FLAGS = tf.app.flags.FLAGS
+
+_IMAGE_SIZE = 32
+_NUM_CHANNELS = 3
+
+# The number of images in the training set.
+_NUM_TRAIN_SAMPLES = 59001
+
+# The number of images to be kept from the training set for the validation set.
+_NUM_VALIDATION = 1000
+
+# The number of images in the test set.
+_NUM_TEST_SAMPLES = 9001
+
+# Seed for repeatability.
+_RANDOM_SEED = 0
+
+# The names of the classes.
+_CLASS_NAMES = [
+ 'zero',
+ 'one',
+ 'two',
+ 'three',
+ 'four',
+ 'five',
+ 'six',
+ 'seven',
+ 'eight',
+ 'nine',
+]
+
+
+class ImageReader(object):
+ """Helper class that provides TensorFlow image coding utilities."""
+
+ def __init__(self):
+ # Initializes function that decodes RGB PNG data.
+ self._decode_png_data = tf.placeholder(dtype=tf.string)
+ self._decode_png = tf.image.decode_png(self._decode_png_data, channels=3)
+
+ def read_image_dims(self, sess, image_data):
+ image = self.decode_png(sess, image_data)
+ return image.shape[0], image.shape[1]
+
+ def decode_png(self, sess, image_data):
+ image = sess.run(
+ self._decode_png, feed_dict={self._decode_png_data: image_data})
+ assert len(image.shape) == 3
+ assert image.shape[2] == 3
+ return image
+
+
+def _convert_dataset(split_name, filenames, filename_to_class_id, dataset_dir):
+ """Converts the given filenames to a TFRecord dataset.
+
+ Args:
+ split_name: The name of the split, one of 'train', 'valid' or 'test'.
+ filenames: A list of absolute paths to png images.
+ filename_to_class_id: A dictionary from filenames (strings) to class ids
+ (integers).
+ dataset_dir: The directory where the converted datasets are stored.
+ """
+ print('Converting the {} split.'.format(split_name))
+ # Train and validation splits are both in the train directory.
+ if split_name in ['train', 'valid']:
+ png_directory = os.path.join(dataset_dir, 'mnist_m', 'mnist_m_train')
+ elif split_name == 'test':
+ png_directory = os.path.join(dataset_dir, 'mnist_m', 'mnist_m_test')
+
+ with tf.Graph().as_default():
+ image_reader = ImageReader()
+
+ with tf.Session('') as sess:
+ output_filename = _get_output_filename(dataset_dir, split_name)
+
+ with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
+ for filename in filenames:
+ # Read the filename:
+ image_data = tf.gfile.FastGFile(
+ os.path.join(png_directory, filename), 'rb').read()
+ height, width = image_reader.read_image_dims(sess, image_data)
+
+ class_id = filename_to_class_id[filename]
+ example = dataset_utils.image_to_tfexample(image_data, 'png', height,
+ width, class_id)
+ tfrecord_writer.write(example.SerializeToString())
+
+ sys.stdout.write('\n')
+ sys.stdout.flush()
+
+
+def _extract_labels(label_filename):
+ """Extract the labels into a dict of filenames to int labels.
+
+ Args:
+ label_filename: The filename of the MNIST-M labels.
+
+ Returns:
+ A dictionary of filenames to int labels.
+ """
+ print('Extracting labels from: ', label_filename)
+ label_file = tf.gfile.FastGFile(label_filename, 'r').readlines()
+ label_lines = [line.rstrip('\n').split() for line in label_file]
+ labels = {}
+ for line in label_lines:
+ assert len(line) == 2
+ labels[line[0]] = int(line[1])
+ return labels
+
+
+def _get_output_filename(dataset_dir, split_name):
+ """Creates the output filename.
+
+ Args:
+ dataset_dir: The directory where the temporary files are stored.
+ split_name: The name of the train/test split.
+
+ Returns:
+ An absolute file path.
+ """
+ return '%s/mnist_m_%s.tfrecord' % (dataset_dir, split_name)
+
+
+def _get_filenames(dataset_dir):
+ """Returns a list of filenames and inferred class names.
+
+ Args:
+ dataset_dir: A directory containing a set of PNG-encoded MNIST-M images.
+
+ Returns:
+ A list of image file paths, relative to `dataset_dir`.
+ """
+ photo_filenames = []
+ for filename in os.listdir(dataset_dir):
+ photo_filenames.append(filename)
+ return photo_filenames
+
+
+def run(dataset_dir):
+ """Runs the download and conversion operation.
+
+ Args:
+ dataset_dir: The dataset directory where the dataset is stored.
+ """
+ if not tf.gfile.Exists(dataset_dir):
+ tf.gfile.MakeDirs(dataset_dir)
+
+ train_filename = _get_output_filename(dataset_dir, 'train')
+ testing_filename = _get_output_filename(dataset_dir, 'test')
+
+ if tf.gfile.Exists(train_filename) and tf.gfile.Exists(testing_filename):
+ print('Dataset files already exist. Exiting without re-creating them.')
+ return
+
+ # TODO(konstantinos): Add download and cleanup functionality
+
+ train_validation_filenames = _get_filenames(
+ os.path.join(dataset_dir, 'mnist_m', 'mnist_m_train'))
+ test_filenames = _get_filenames(
+ os.path.join(dataset_dir, 'mnist_m', 'mnist_m_test'))
+
+ # Divide into train and validation:
+ random.seed(_RANDOM_SEED)
+ random.shuffle(train_validation_filenames)
+ train_filenames = train_validation_filenames[_NUM_VALIDATION:]
+ validation_filenames = train_validation_filenames[:_NUM_VALIDATION]
+
+ train_validation_filenames_to_class_ids = _extract_labels(
+ os.path.join(dataset_dir, 'mnist_m', 'mnist_m_train_labels.txt'))
+ test_filenames_to_class_ids = _extract_labels(
+ os.path.join(dataset_dir, 'mnist_m', 'mnist_m_test_labels.txt'))
+
+ # Convert the train, validation, and test sets.
+ _convert_dataset('train', train_filenames,
+ train_validation_filenames_to_class_ids, dataset_dir)
+ _convert_dataset('valid', validation_filenames,
+ train_validation_filenames_to_class_ids, dataset_dir)
+ _convert_dataset('test', test_filenames, test_filenames_to_class_ids,
+ dataset_dir)
+
+ # Finally, write the labels file:
+ labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
+ dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
+
+ print('\nFinished converting the MNIST-M dataset!')
+
+
+def main(_):
+ assert FLAGS.dataset_dir
+ run(FLAGS.dataset_dir)
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/domain_adaptation/datasets/mnist_m.py b/models/research/domain_adaptation/datasets/mnist_m.py
new file mode 100644
index 0000000000000000000000000000000000000000..fab6c443cf3d2e9783d19bf52c81b7aa62d56a38
--- /dev/null
+++ b/models/research/domain_adaptation/datasets/mnist_m.py
@@ -0,0 +1,98 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Provides data for the MNIST-M dataset.
+
+The dataset scripts used to create the dataset can be found at:
+tensorflow_models/domain_adaptation/datasets/download_and_convert_mnist_m.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+# Dependency imports
+import tensorflow as tf
+
+from slim.datasets import dataset_utils
+
+slim = tf.contrib.slim
+
+_FILE_PATTERN = 'mnist_m_%s.tfrecord'
+
+_SPLITS_TO_SIZES = {'train': 58001, 'valid': 1000, 'test': 9001}
+
+_NUM_CLASSES = 10
+
+_ITEMS_TO_DESCRIPTIONS = {
+ 'image': 'A [32 x 32 x 3] RGB image.',
+ 'label': 'A single integer between 0 and 9',
+}
+
+
+def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
+ """Gets a dataset tuple with instructions for reading MNIST.
+
+ Args:
+ split_name: A train/test split name.
+ dataset_dir: The base directory of the dataset sources.
+
+ Returns:
+ A `Dataset` namedtuple.
+
+ Raises:
+ ValueError: if `split_name` is not a valid train/test split.
+ """
+ if split_name not in _SPLITS_TO_SIZES:
+ raise ValueError('split name %s was not recognized.' % split_name)
+
+ if not file_pattern:
+ file_pattern = _FILE_PATTERN
+ file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
+
+ # Allowing None in the signature so that dataset_factory can use the default.
+ if reader is None:
+ reader = tf.TFRecordReader
+
+ keys_to_features = {
+ 'image/encoded':
+ tf.FixedLenFeature((), tf.string, default_value=''),
+ 'image/format':
+ tf.FixedLenFeature((), tf.string, default_value='png'),
+ 'image/class/label':
+ tf.FixedLenFeature(
+ [1], tf.int64, default_value=tf.zeros([1], dtype=tf.int64)),
+ }
+
+ items_to_handlers = {
+ 'image': slim.tfexample_decoder.Image(shape=[32, 32, 3], channels=3),
+ 'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[]),
+ }
+
+ decoder = slim.tfexample_decoder.TFExampleDecoder(
+ keys_to_features, items_to_handlers)
+
+ labels_to_names = None
+ if dataset_utils.has_labels(dataset_dir):
+ labels_to_names = dataset_utils.read_label_file(dataset_dir)
+
+ return slim.dataset.Dataset(
+ data_sources=file_pattern,
+ reader=reader,
+ decoder=decoder,
+ num_samples=_SPLITS_TO_SIZES[split_name],
+ num_classes=_NUM_CLASSES,
+ items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
+ labels_to_names=labels_to_names)
diff --git a/models/research/domain_adaptation/domain_separation/BUILD b/models/research/domain_adaptation/domain_separation/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..14dceda27e49d74eaaaeae21676183b78c72b9c2
--- /dev/null
+++ b/models/research/domain_adaptation/domain_separation/BUILD
@@ -0,0 +1,157 @@
+# Domain Separation Networks
+
+package(
+ default_visibility = [
+ ":internal",
+ ],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package_group(
+ name = "internal",
+ packages = [
+ "//domain_adaptation/...",
+ ],
+)
+
+py_library(
+ name = "models",
+ srcs = [
+ "models.py",
+ ],
+ deps = [
+ ":utils",
+ ],
+)
+
+py_library(
+ name = "losses",
+ srcs = [
+ "losses.py",
+ ],
+ deps = [
+ ":grl_op_grads_py",
+ ":grl_op_shapes_py",
+ ":grl_ops",
+ ":utils",
+ ],
+)
+
+py_test(
+ name = "losses_test",
+ srcs = [
+ "losses_test.py",
+ ],
+ deps = [
+ ":losses",
+ ":utils",
+ ],
+)
+
+py_library(
+ name = "dsn",
+ srcs = [
+ "dsn.py",
+ ],
+ deps = [
+ ":grl_op_grads_py",
+ ":grl_op_shapes_py",
+ ":grl_ops",
+ ":losses",
+ ":models",
+ ":utils",
+ ],
+)
+
+py_test(
+ name = "dsn_test",
+ srcs = [
+ "dsn_test.py",
+ ],
+ deps = [
+ ":dsn",
+ ],
+)
+
+py_binary(
+ name = "dsn_train",
+ srcs = [
+ "dsn_train.py",
+ ],
+ deps = [
+ ":dsn",
+ ":models",
+ "//domain_adaptation/datasets:dataset_factory",
+ ],
+)
+
+py_binary(
+ name = "dsn_eval",
+ srcs = [
+ "dsn_eval.py",
+ ],
+ deps = [
+ ":dsn",
+ ":models",
+ "//domain_adaptation/datasets:dataset_factory",
+ ],
+)
+
+py_test(
+ name = "models_test",
+ srcs = [
+ "models_test.py",
+ ],
+ deps = [
+ ":models",
+ "//domain_adaptation/datasets:dataset_factory",
+ ],
+)
+
+py_library(
+ name = "utils",
+ srcs = [
+ "utils.py",
+ ],
+ deps = [
+ ],
+)
+
+py_library(
+ name = "grl_op_grads_py",
+ srcs = [
+ "grl_op_grads.py",
+ ],
+ deps = [
+ ":grl_ops",
+ ],
+)
+
+py_library(
+ name = "grl_op_shapes_py",
+ srcs = [
+ "grl_op_shapes.py",
+ ],
+ deps = [
+ ],
+)
+
+py_library(
+ name = "grl_ops",
+ srcs = ["grl_ops.py"],
+ data = ["_grl_ops.so"],
+)
+
+py_test(
+ name = "grl_ops_test",
+ size = "small",
+ srcs = ["grl_ops_test.py"],
+ deps = [
+ ":grl_op_grads_py",
+ ":grl_op_shapes_py",
+ ":grl_ops",
+ ],
+)
diff --git a/models/research/domain_adaptation/domain_separation/__init__.py b/models/research/domain_adaptation/domain_separation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/domain_adaptation/domain_separation/_grl_ops.so b/models/research/domain_adaptation/domain_separation/_grl_ops.so
new file mode 100644
index 0000000000000000000000000000000000000000..4c35473760a76dcb743d58f45eddccecb5f5161e
Binary files /dev/null and b/models/research/domain_adaptation/domain_separation/_grl_ops.so differ
diff --git a/models/research/domain_adaptation/domain_separation/dsn.py b/models/research/domain_adaptation/domain_separation/dsn.py
new file mode 100644
index 0000000000000000000000000000000000000000..3018e8a791840ae465bad493913235cc04c31cff
--- /dev/null
+++ b/models/research/domain_adaptation/domain_separation/dsn.py
@@ -0,0 +1,355 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions to create a DSN model and add the different losses to it.
+
+Specifically, in this file we define the:
+ - Shared Encoding Similarity Loss Module, with:
+ - The MMD Similarity method
+ - The Correlation Similarity method
+ - The Gradient Reversal (Domain-Adversarial) method
+ - Difference Loss Module
+ - Reconstruction Loss Module
+ - Task Loss Module
+"""
+from functools import partial
+
+import tensorflow as tf
+
+import losses
+import models
+import utils
+
+slim = tf.contrib.slim
+
+
+################################################################################
+# HELPER FUNCTIONS
+################################################################################
+def dsn_loss_coefficient(params):
+ """The global_step-dependent weight that specifies when to kick in DSN losses.
+
+ Args:
+ params: A dictionary of parameters. Expecting 'domain_separation_startpoint'
+
+ Returns:
+ A weight that effectively enables or disables the DSN-related losses,
+ i.e. similarity, difference, and reconstruction losses.
+ """
+ return tf.where(
+ tf.less(slim.get_or_create_global_step(),
+ params['domain_separation_startpoint']), 1e-10, 1.0)
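+
+# For example, with params = {'domain_separation_startpoint': 10000} the
+# coefficient is 1e-10 (the DSN losses are effectively disabled) for the first
+# 10000 global steps and 1.0 from step 10000 onwards.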
+
+
+################################################################################
+# MODEL CREATION
+################################################################################
+def create_model(source_images, source_labels, domain_selection_mask,
+ target_images, target_labels, similarity_loss, params,
+ basic_tower_name):
+ """Creates a DSN model.
+
+ Args:
+ source_images: images from the source domain, a tensor of size
+ [batch_size, height, width, channels]
+ source_labels: a dictionary with the name, tensor pairs. 'classes' is one-
+ hot for the number of classes.
+ domain_selection_mask: a boolean tensor of size [batch_size, ] which denotes
+ the labeled images that belong to the source domain.
+ target_images: images from the target domain, a tensor of size
+ [batch_size, height, width, channels].
+ target_labels: a dictionary with the name, tensor pairs.
+ similarity_loss: The type of method to use for encouraging
+ the codes from the shared encoder to be similar.
+ params: A dictionary of parameters. Expecting 'weight_decay',
+ 'layers_to_regularize', 'use_separation', 'domain_separation_startpoint',
+ 'alpha_weight', 'beta_weight', 'gamma_weight', 'recon_loss_name',
+ 'decoder_name', 'encoder_name'
+ basic_tower_name: the name of the tower to use for the shared encoder.
+
+ Raises:
+ ValueError: if the arch is not one of the available architectures.
+ """
+ network = getattr(models, basic_tower_name)
+ num_classes = source_labels['classes'].get_shape().as_list()[1]
+
+ # Make sure we are using the appropriate number of classes.
+ network = partial(network, num_classes=num_classes)
+
+ # Add the classification/pose estimation loss to the source domain.
+ source_endpoints = add_task_loss(source_images, source_labels, network,
+ params)
+
+ if similarity_loss == 'none':
+ # No domain adaptation, we can stop here.
+ return
+
+ with tf.variable_scope('towers', reuse=True):
+ target_logits, target_endpoints = network(
+ target_images, weight_decay=params['weight_decay'], prefix='target')
+
+ # Plot target accuracy of the train set.
+ target_accuracy = utils.accuracy(
+ tf.argmax(target_logits, 1), tf.argmax(target_labels['classes'], 1))
+
+ if 'quaternions' in target_labels:
+ target_quaternion_loss = losses.log_quaternion_loss(
+ target_labels['quaternions'], target_endpoints['quaternion_pred'],
+ params)
+ tf.summary.scalar('eval/Target quaternions', target_quaternion_loss)
+
+ tf.summary.scalar('eval/Target accuracy', target_accuracy)
+
+ source_shared = source_endpoints[params['layers_to_regularize']]
+ target_shared = target_endpoints[params['layers_to_regularize']]
+
+ # When using the semisupervised model we include labeled target data in the
+ # source classifier. We do not want to include these target domain samples
+ # when we compute the similarity loss.
+ indices = tf.range(0, source_shared.get_shape().as_list()[0])
+ indices = tf.boolean_mask(indices, domain_selection_mask)
+ add_similarity_loss(similarity_loss,
+ tf.gather(source_shared, indices),
+ tf.gather(target_shared, indices), params)
+
+ if params['use_separation']:
+ add_autoencoders(
+ source_images,
+ source_shared,
+ target_images,
+ target_shared,
+ params=params,)
+
+
+def add_similarity_loss(method_name,
+ source_samples,
+ target_samples,
+ params,
+ scope=None):
+ """Adds a loss encouraging the shared encoding from each domain to be similar.
+
+ Args:
+ method_name: the name of the encoding similarity method to use. Valid
+ options include 'dann_loss', 'mmd_loss' or 'correlation_loss'.
+ source_samples: a tensor of shape [num_samples, num_features].
+ target_samples: a tensor of shape [num_samples, num_features].
+ params: a dictionary of parameters. Expecting 'gamma_weight'.
+ scope: optional name scope for summary tags.
+ Raises:
+ ValueError: if `method_name` is not recognized.
+ """
+ weight = dsn_loss_coefficient(params) * params['gamma_weight']
+ method = getattr(losses, method_name)
+ method(source_samples, target_samples, weight, scope)
+
+
+def add_reconstruction_loss(recon_loss_name, images, recons, weight, domain):
+ """Adds a reconstruction loss.
+
+ Args:
+ recon_loss_name: The name of the reconstruction loss.
+ images: A `Tensor` of size [batch_size, height, width, 3].
+ recons: A `Tensor` whose size matches `images`.
+ weight: A scalar coefficient for the loss.
+ domain: The name of the domain being reconstructed.
+
+ Raises:
+ ValueError: If `recon_loss_name` is not recognized.
+ """
+ if recon_loss_name == 'sum_of_pairwise_squares':
+ loss_fn = tf.contrib.losses.mean_pairwise_squared_error
+ elif recon_loss_name == 'sum_of_squares':
+ loss_fn = tf.contrib.losses.mean_squared_error
+ else:
+ raise ValueError('recon_loss_name value [%s] not recognized.' %
+ recon_loss_name)
+
+ loss = loss_fn(recons, images, weight)
+ assert_op = tf.Assert(tf.is_finite(loss), [loss])
+ with tf.control_dependencies([assert_op]):
+ tf.summary.scalar('losses/%s Recon Loss' % domain, loss)
+
+
+def add_autoencoders(source_data, source_shared, target_data, target_shared,
+ params):
+ """Adds the encoders/decoders for our domain separation model w/ incoherence.
+
+ Args:
+ source_data: images from the source domain, a tensor of size
+ [batch_size, height, width, channels]
+ source_shared: a tensor with first dimension batch_size
+ target_data: images from the target domain, a tensor of size
+ [batch_size, height, width, channels]
+ target_shared: a tensor with first dimension batch_size
+ params: A dictionary of parameters. Expecting 'layers_to_regularize',
+ 'beta_weight', 'alpha_weight', 'recon_loss_name', 'decoder_name',
+ 'encoder_name', 'weight_decay'
+ """
+
+ def normalize_images(images):
+ images -= tf.reduce_min(images)
+ return images / tf.reduce_max(images)
+
+ def concat_operation(shared_repr, private_repr):
+ return shared_repr + private_repr
+
+ mu = dsn_loss_coefficient(params)
+
+ # The layer to concatenate the networks at.
+ concat_layer = params['layers_to_regularize']
+
+ # The coefficient for modulating the private/shared difference loss.
+ difference_loss_weight = params['beta_weight'] * mu
+
+ # The reconstruction weight.
+ recon_loss_weight = params['alpha_weight'] * mu
+
+ # The reconstruction loss to use.
+ recon_loss_name = params['recon_loss_name']
+
+ # The decoder/encoder to use.
+ decoder_name = params['decoder_name']
+ encoder_name = params['encoder_name']
+
+ _, height, width, _ = source_data.get_shape().as_list()
+ code_size = source_shared.get_shape().as_list()[-1]
+ weight_decay = params['weight_decay']
+
+ encoder_fn = getattr(models, encoder_name)
+ # Private encoders for the source and target domains.
+ with tf.variable_scope('source_encoder'):
+ source_endpoints = encoder_fn(
+ source_data, code_size, weight_decay=weight_decay)
+
+ with tf.variable_scope('target_encoder'):
+ target_endpoints = encoder_fn(
+ target_data, code_size, weight_decay=weight_decay)
+
+ decoder_fn = getattr(models, decoder_name)
+
+ decoder = partial(
+ decoder_fn,
+ height=height,
+ width=width,
+ channels=source_data.get_shape().as_list()[-1],
+ weight_decay=weight_decay)
+
+ # Extract the private codes and reconstruct each domain from shared + private.
+ source_private = source_endpoints[concat_layer]
+ target_private = target_endpoints[concat_layer]
+ with tf.variable_scope('decoder'):
+ source_recons = decoder(concat_operation(source_shared, source_private))
+
+ with tf.variable_scope('decoder', reuse=True):
+ source_private_recons = decoder(
+ concat_operation(tf.zeros_like(source_private), source_private))
+ source_shared_recons = decoder(
+ concat_operation(source_shared, tf.zeros_like(source_shared)))
+
+ with tf.variable_scope('decoder', reuse=True):
+ target_recons = decoder(concat_operation(target_shared, target_private))
+ target_shared_recons = decoder(
+ concat_operation(target_shared, tf.zeros_like(target_shared)))
+ target_private_recons = decoder(
+ concat_operation(tf.zeros_like(target_private), target_private))
+
+ losses.difference_loss(
+ source_private,
+ source_shared,
+ weight=difference_loss_weight,
+ name='Source')
+ losses.difference_loss(
+ target_private,
+ target_shared,
+ weight=difference_loss_weight,
+ name='Target')
+
+ add_reconstruction_loss(recon_loss_name, source_data, source_recons,
+ recon_loss_weight, 'source')
+ add_reconstruction_loss(recon_loss_name, target_data, target_recons,
+ recon_loss_weight, 'target')
+
+ # Add summaries
+ source_reconstructions = tf.concat(
+ axis=2,
+ values=map(normalize_images, [
+ source_data, source_recons, source_shared_recons,
+ source_private_recons
+ ]))
+ target_reconstructions = tf.concat(
+ axis=2,
+ values=map(normalize_images, [
+ target_data, target_recons, target_shared_recons,
+ target_private_recons
+ ]))
+ tf.summary.image(
+ 'Source Images:Recons:RGB',
+ source_reconstructions[:, :, :, :3],
+ max_outputs=10)
+ tf.summary.image(
+ 'Target Images:Recons:RGB',
+ target_reconstructions[:, :, :, :3],
+ max_outputs=10)
+
+ if source_reconstructions.get_shape().as_list()[3] == 4:
+ tf.summary.image(
+ 'Source Images:Recons:Depth',
+ source_reconstructions[:, :, :, 3:4],
+ max_outputs=10)
+ tf.summary.image(
+ 'Target Images:Recons:Depth',
+ target_reconstructions[:, :, :, 3:4],
+ max_outputs=10)
+
+
+def add_task_loss(source_images, source_labels, basic_tower, params):
+ """Adds a classification and/or pose estimation loss to the model.
+
+ Args:
+ source_images: images from the source domain, a tensor of size
+ [batch_size, height, width, channels]
+ source_labels: labels from the source domain, a tensor of size [batch_size],
+ or a tuple of (quaternions, class_labels).
+ basic_tower: a function that creates the single tower of the model.
+ params: A dictionary of parameters. Expecting 'weight_decay', 'pose_weight'.
+ Returns:
+ The source endpoints.
+
+ Raises:
+ RuntimeError: if basic tower does not support pose estimation.
+ """
+ with tf.variable_scope('towers'):
+ source_logits, source_endpoints = basic_tower(
+ source_images, weight_decay=params['weight_decay'], prefix='Source')
+
+ if 'quaternions' in source_labels: # We have pose estimation as well
+ if 'quaternion_pred' not in source_endpoints:
+ raise RuntimeError('Please use a model for estimation e.g. pose_mini')
+
+ loss = losses.log_quaternion_loss(source_labels['quaternions'],
+ source_endpoints['quaternion_pred'],
+ params)
+
+ assert_op = tf.Assert(tf.is_finite(loss), [loss])
+ with tf.control_dependencies([assert_op]):
+ quaternion_loss = loss
+ tf.summary.histogram('log_quaternion_loss_hist', quaternion_loss)
+ slim.losses.add_loss(quaternion_loss * params['pose_weight'])
+ tf.summary.scalar('losses/quaternion_loss', quaternion_loss)
+
+ classification_loss = tf.losses.softmax_cross_entropy(
+ source_labels['classes'], source_logits)
+
+ tf.summary.scalar('losses/classification_loss', classification_loss)
+ return source_endpoints
diff --git a/models/research/domain_adaptation/domain_separation/dsn_eval.py b/models/research/domain_adaptation/domain_separation/dsn_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6cccdfcc17e8f18e8381530b5c8f41501bda29b
--- /dev/null
+++ b/models/research/domain_adaptation/domain_separation/dsn_eval.py
@@ -0,0 +1,161 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# pylint: disable=line-too-long
+"""Evaluation for Domain Separation Networks (DSNs)."""
+# pylint: enable=line-too-long
+import math
+
+import numpy as np
+from six.moves import xrange
+import tensorflow as tf
+
+from domain_adaptation.datasets import dataset_factory
+from domain_adaptation.domain_separation import losses
+from domain_adaptation.domain_separation import models
+
+slim = tf.contrib.slim
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_integer('batch_size', 32,
+ 'The number of images in each batch.')
+
+tf.app.flags.DEFINE_string('master', '',
+ 'BNS name of the TensorFlow master to use.')
+
+tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/da/',
+ 'Directory where the model was written to.')
+
+tf.app.flags.DEFINE_string(
+ 'eval_dir', '/tmp/da/',
+ 'Directory where we should write the tf summaries to.')
+
+tf.app.flags.DEFINE_string('dataset_dir', None,
+ 'The directory where the dataset files are stored.')
+
+tf.app.flags.DEFINE_string('dataset', 'mnist_m',
+ 'Which dataset to test on: "mnist", "mnist_m".')
+
+tf.app.flags.DEFINE_string('split', 'valid',
+ 'Which portion to test on: "valid", "test".')
+
+tf.app.flags.DEFINE_integer('num_examples', 1000, 'Number of test examples.')
+
+tf.app.flags.DEFINE_string('basic_tower', 'dann_mnist',
+ 'The basic tower building block.')
+
+tf.app.flags.DEFINE_bool('enable_precision_recall', False,
+ 'If True, precision and recall for each class will '
+ 'be added to the metrics.')
+
+tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')
+
+
+def quaternion_metric(predictions, labels):
+ params = {'batch_size': FLAGS.batch_size, 'use_logging': False}
+ logcost = losses.log_quaternion_loss_batch(predictions, labels, params)
+ return slim.metrics.streaming_mean(logcost)
+
+
+def angle_diff(true_q, pred_q):
+ angles = 2 * (
+ 180.0 /
+ np.pi) * np.arccos(np.abs(np.sum(np.multiply(pred_q, true_q), axis=1)))
+ return angles
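+
+# angle_diff returns, per example, the rotation angle in degrees between two
+# unit quaternions; identical (or sign-flipped) orientations give 0 degrees.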
+
+
+def provide_batch_fn():
+ """ The provide_batch function to use. """
+ return dataset_factory.provide_batch
+
+
+def main(_):
+ g = tf.Graph()
+ with g.as_default():
+ # Load the data.
+ images, labels = provide_batch_fn()(
+ FLAGS.dataset, FLAGS.split, FLAGS.dataset_dir, 4, FLAGS.batch_size, 4)
+
+ num_classes = labels['classes'].get_shape().as_list()[1]
+
+ tf.summary.image('eval_images', images, max_outputs=3)
+
+ # Define the model:
+ with tf.variable_scope('towers'):
+ basic_tower = getattr(models, FLAGS.basic_tower)
+ predictions, endpoints = basic_tower(
+ images,
+ num_classes=num_classes,
+ is_training=False,
+ batch_norm_params=None)
+ metric_names_to_values = {}
+
+ # Define the metrics:
+ if 'quaternions' in labels: # Also have to evaluate pose estimation!
+ quaternion_loss = quaternion_metric(labels['quaternions'],
+ endpoints['quaternion_pred'])
+
+ angle_errors, = tf.py_func(
+ angle_diff, [labels['quaternions'], endpoints['quaternion_pred']],
+ [tf.float32])
+
+ metric_names_to_values[
+ 'Angular mean error'] = slim.metrics.streaming_mean(angle_errors)
+ metric_names_to_values['Quaternion Loss'] = quaternion_loss
+
+ accuracy = tf.contrib.metrics.streaming_accuracy(
+ tf.argmax(predictions, 1), tf.argmax(labels['classes'], 1))
+
+ predictions = tf.argmax(predictions, 1)
+ labels = tf.argmax(labels['classes'], 1)
+ metric_names_to_values['Accuracy'] = accuracy
+
+ if FLAGS.enable_precision_recall:
+ for i in xrange(num_classes):
+ index_map = tf.one_hot(i, depth=num_classes)
+ name = 'PR/Precision_{}'.format(i)
+ metric_names_to_values[name] = slim.metrics.streaming_precision(
+ tf.gather(index_map, predictions), tf.gather(index_map, labels))
+ name = 'PR/Recall_{}'.format(i)
+ metric_names_to_values[name] = slim.metrics.streaming_recall(
+ tf.gather(index_map, predictions), tf.gather(index_map, labels))
+
+ names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
+ metric_names_to_values)
+
+ # Create the summary ops such that they also print out to std output:
+ summary_ops = []
+ for metric_name, metric_value in names_to_values.items():
+ op = tf.summary.scalar(metric_name, metric_value)
+ op = tf.Print(op, [metric_value], metric_name)
+ summary_ops.append(op)
+
+ # This ensures that we make a single pass over all of the data.
+ num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size))
+
+ # Setup the global step.
+ slim.get_or_create_global_step()
+ slim.evaluation.evaluation_loop(
+ FLAGS.master,
+ checkpoint_dir=FLAGS.checkpoint_dir,
+ logdir=FLAGS.eval_dir,
+ num_evals=num_batches,
+ eval_op=list(names_to_updates.values()),
+ summary_op=tf.summary.merge(summary_ops))
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/domain_adaptation/domain_separation/dsn_test.py b/models/research/domain_adaptation/domain_separation/dsn_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d687398a9b9356455f739417bc96ddb2ca5ad40
--- /dev/null
+++ b/models/research/domain_adaptation/domain_separation/dsn_test.py
@@ -0,0 +1,157 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for DSN model assembly functions."""
+
+import numpy as np
+import tensorflow as tf
+
+import dsn
+
+
+class HelperFunctionsTest(tf.test.TestCase):
+
+ def testBasicDomainSeparationStartPoint(self):
+ with self.test_session() as sess:
+ # Test for when global_step < domain_separation_startpoint
+ step = tf.contrib.slim.get_or_create_global_step()
+ sess.run(tf.global_variables_initializer()) # global_step = 0
+ params = {'domain_separation_startpoint': 2}
+ weight = dsn.dsn_loss_coefficient(params)
+ weight_np = sess.run(weight)
+ self.assertAlmostEqual(weight_np, 1e-10)
+
+ step_op = tf.assign_add(step, 1)
+ step_np = sess.run(step_op) # global_step = 1
+ weight = dsn.dsn_loss_coefficient(params)
+ weight_np = sess.run(weight)
+ self.assertAlmostEqual(weight_np, 1e-10)
+
+ # Test for when global_step >= domain_separation_startpoint
+ step_np = sess.run(step_op) # global_step = 2
+ tf.logging.info(step_np)
+ weight = dsn.dsn_loss_coefficient(params)
+ weight_np = sess.run(weight)
+ self.assertAlmostEqual(weight_np, 1.0)
+
+
+class DsnModelAssemblyTest(tf.test.TestCase):
+
+ def _testBuildDefaultModel(self):
+ images = tf.to_float(np.random.rand(32, 28, 28, 1))
+ labels = {}
+ labels['classes'] = tf.one_hot(
+ tf.to_int32(np.random.randint(0, 9, (32))), 10)
+
+ params = {
+ 'use_separation': True,
+ 'layers_to_regularize': 'fc3',
+ 'weight_decay': 0.0,
+ 'ps_tasks': 1,
+ 'domain_separation_startpoint': 1,
+ 'alpha_weight': 1,
+ 'beta_weight': 1,
+ 'gamma_weight': 1,
+ 'recon_loss_name': 'sum_of_squares',
+ 'decoder_name': 'small_decoder',
+ 'encoder_name': 'default_encoder',
+ }
+ return images, labels, params
+
+ def testBuildModelDann(self):
+ images, labels, params = self._testBuildDefaultModel()
+
+ with self.test_session():
+ dsn.create_model(images, labels,
+ tf.cast(tf.ones([32,]), tf.bool), images, labels,
+ 'dann_loss', params, 'dann_mnist')
+ loss_tensors = tf.contrib.losses.get_losses()
+ self.assertEqual(len(loss_tensors), 6)
+
+ def testBuildModelDannSumOfPairwiseSquares(self):
+ images, labels, params = self._testBuildDefaultModel()
+
+ with self.test_session():
+ dsn.create_model(images, labels,
+ tf.cast(tf.ones([32,]), tf.bool), images, labels,
+ 'dann_loss', params, 'dann_mnist')
+ loss_tensors = tf.contrib.losses.get_losses()
+ self.assertEqual(len(loss_tensors), 6)
+
+ def testBuildModelDannMultiPSTasks(self):
+ images, labels, params = self._testBuildDefaultModel()
+ params['ps_tasks'] = 10
+ with self.test_session():
+ dsn.create_model(images, labels,
+ tf.cast(tf.ones([32,]), tf.bool), images, labels,
+ 'dann_loss', params, 'dann_mnist')
+ loss_tensors = tf.contrib.losses.get_losses()
+ self.assertEqual(len(loss_tensors), 6)
+
+ def testBuildModelMmd(self):
+ images, labels, params = self._testBuildDefaultModel()
+
+ with self.test_session():
+ dsn.create_model(images, labels,
+ tf.cast(tf.ones([32,]), tf.bool), images, labels,
+ 'mmd_loss', params, 'dann_mnist')
+ loss_tensors = tf.contrib.losses.get_losses()
+ self.assertEqual(len(loss_tensors), 6)
+
+ def testBuildModelCorr(self):
+ images, labels, params = self._testBuildDefaultModel()
+
+ with self.test_session():
+ dsn.create_model(images, labels,
+ tf.cast(tf.ones([32,]), tf.bool), images, labels,
+ 'correlation_loss', params, 'dann_mnist')
+ loss_tensors = tf.contrib.losses.get_losses()
+ self.assertEqual(len(loss_tensors), 6)
+
+ def testBuildModelNoDomainAdaptation(self):
+ images, labels, params = self._testBuildDefaultModel()
+ params['use_separation'] = False
+ with self.test_session():
+ dsn.create_model(images, labels,
+ tf.cast(tf.ones([32,]), tf.bool), images, labels, 'none',
+ params, 'dann_mnist')
+ loss_tensors = tf.contrib.losses.get_losses()
+ self.assertEqual(len(loss_tensors), 1)
+ self.assertEqual(len(tf.contrib.losses.get_regularization_losses()), 0)
+
+ def testBuildModelNoAdaptationWeightDecay(self):
+ images, labels, params = self._testBuildDefaultModel()
+ params['use_separation'] = False
+ params['weight_decay'] = 1e-5
+ with self.test_session():
+ dsn.create_model(images, labels,
+ tf.cast(tf.ones([32,]), tf.bool), images, labels, 'none',
+ params, 'dann_mnist')
+ loss_tensors = tf.contrib.losses.get_losses()
+ self.assertEqual(len(loss_tensors), 1)
+ self.assertTrue(len(tf.contrib.losses.get_regularization_losses()) >= 1)
+
+ def testBuildModelNoSeparation(self):
+ images, labels, params = self._testBuildDefaultModel()
+ params['use_separation'] = False
+ with self.test_session():
+ dsn.create_model(images, labels,
+ tf.cast(tf.ones([32,]), tf.bool), images, labels,
+ 'dann_loss', params, 'dann_mnist')
+ loss_tensors = tf.contrib.losses.get_losses()
+ self.assertEqual(len(loss_tensors), 2)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/domain_adaptation/domain_separation/dsn_train.py b/models/research/domain_adaptation/domain_separation/dsn_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e364ad3037b041125a3523370b3b040478f0d8e
--- /dev/null
+++ b/models/research/domain_adaptation/domain_separation/dsn_train.py
@@ -0,0 +1,278 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Training for Domain Separation Networks (DSNs)."""
+from __future__ import division
+
+import tensorflow as tf
+
+from domain_adaptation.datasets import dataset_factory
+import dsn
+
+slim = tf.contrib.slim
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_integer('batch_size', 32,
+ 'The number of images in each batch.')
+
+tf.app.flags.DEFINE_string('source_dataset', 'pose_synthetic',
+ 'Source dataset to train on.')
+
+tf.app.flags.DEFINE_string('target_dataset', 'pose_real',
+ 'Target dataset to train on.')
+
+tf.app.flags.DEFINE_string('target_labeled_dataset', 'none',
+ 'Labeled target dataset for semi-supervised training, or "none".')
+
+tf.app.flags.DEFINE_string('dataset_dir', None,
+ 'The directory where the dataset files are stored.')
+
+tf.app.flags.DEFINE_string('master', '',
+ 'BNS name of the TensorFlow master to use.')
+
+tf.app.flags.DEFINE_string('train_log_dir', '/tmp/da/',
+ 'Directory where to write event logs.')
+
+tf.app.flags.DEFINE_string(
+ 'layers_to_regularize', 'fc3',
+ 'Comma-separated list of layer names to use MMD regularization on.')
+
+tf.app.flags.DEFINE_float('learning_rate', .01, 'The learning rate')
+
+tf.app.flags.DEFINE_float('alpha_weight', 1e-6,
+ 'The coefficient for scaling the reconstruction '
+ 'loss.')
+
+tf.app.flags.DEFINE_float(
+ 'beta_weight', 1e-6,
+ 'The coefficient for scaling the private/shared difference loss.')
+
+tf.app.flags.DEFINE_float(
+ 'gamma_weight', 1e-6,
+ 'The coefficient for scaling the shared encoding similarity loss.')
+
+tf.app.flags.DEFINE_float('pose_weight', 0.125,
+ 'The coefficient for scaling the pose loss.')
+
+tf.app.flags.DEFINE_float(
+ 'weight_decay', 1e-6,
+ 'The coefficient for the L2 regularization applied for all weights.')
+
+tf.app.flags.DEFINE_integer(
+ 'save_summaries_secs', 60,
+ 'The frequency with which summaries are saved, in seconds.')
+
+tf.app.flags.DEFINE_integer(
+ 'save_interval_secs', 60,
+ 'The frequency with which the model is saved, in seconds.')
+
+tf.app.flags.DEFINE_integer(
+ 'max_number_of_steps', None,
+ 'The maximum number of gradient steps. Use None to train indefinitely.')
+
+tf.app.flags.DEFINE_integer(
+ 'domain_separation_startpoint', 1,
+ 'The global step to add the domain separation losses.')
+
+tf.app.flags.DEFINE_integer(
+ 'bipartite_assignment_top_k', 3,
+ 'The number of top-k matches to use in bipartite matching adaptation.')
+
+tf.app.flags.DEFINE_float('decay_rate', 0.95, 'Learning rate decay factor.')
+
+tf.app.flags.DEFINE_integer('decay_steps', 20000, 'Learning rate decay steps.')
+
+tf.app.flags.DEFINE_float('momentum', 0.9, 'The momentum value.')
+
+tf.app.flags.DEFINE_bool('use_separation', False,
+ 'Use our domain separation model.')
+
+tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')
+
+tf.app.flags.DEFINE_integer(
+ 'ps_tasks', 0,
+ 'The number of parameter servers. If the value is 0, then the parameters '
+ 'are handled locally by the worker.')
+
+tf.app.flags.DEFINE_integer(
+ 'num_readers', 4,
+ 'The number of parallel readers that read data from the dataset.')
+
+tf.app.flags.DEFINE_integer('num_preprocessing_threads', 4,
+ 'The number of threads used to create the batches.')
+
+tf.app.flags.DEFINE_integer(
+ 'task', 0,
+ 'The Task ID. This value is used when training with multiple workers to '
+ 'identify each worker.')
+
+tf.app.flags.DEFINE_string('decoder_name', 'small_decoder',
+ 'The decoder to use.')
+tf.app.flags.DEFINE_string('encoder_name', 'default_encoder',
+ 'The encoder to use.')
+
+################################################################################
+# Flags that control the architecture and losses
+################################################################################
+tf.app.flags.DEFINE_string(
+ 'similarity_loss', 'grl',
+ 'The method to use for encouraging the common encoder codes to be '
+ 'similar, one of "grl", "mmd", "corr".')
+
+tf.app.flags.DEFINE_string('recon_loss_name', 'sum_of_pairwise_squares',
+ 'The name of the reconstruction loss.')
+
+tf.app.flags.DEFINE_string('basic_tower', 'pose_mini',
+ 'The basic tower building block.')
+
+def provide_batch_fn():
+ """ The provide_batch function to use. """
+ return dataset_factory.provide_batch
+
+def main(_):
+ model_params = {
+ 'use_separation': FLAGS.use_separation,
+ 'domain_separation_startpoint': FLAGS.domain_separation_startpoint,
+ 'layers_to_regularize': FLAGS.layers_to_regularize,
+ 'alpha_weight': FLAGS.alpha_weight,
+ 'beta_weight': FLAGS.beta_weight,
+ 'gamma_weight': FLAGS.gamma_weight,
+ 'pose_weight': FLAGS.pose_weight,
+ 'recon_loss_name': FLAGS.recon_loss_name,
+ 'decoder_name': FLAGS.decoder_name,
+ 'encoder_name': FLAGS.encoder_name,
+ 'weight_decay': FLAGS.weight_decay,
+ 'batch_size': FLAGS.batch_size,
+ 'use_logging': FLAGS.use_logging,
+ 'ps_tasks': FLAGS.ps_tasks,
+ 'task': FLAGS.task,
+ }
+ g = tf.Graph()
+ with g.as_default():
+ with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
+ # Load the data.
+ source_images, source_labels = provide_batch_fn()(
+ FLAGS.source_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
+ FLAGS.batch_size, FLAGS.num_preprocessing_threads)
+ target_images, target_labels = provide_batch_fn()(
+ FLAGS.target_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
+ FLAGS.batch_size, FLAGS.num_preprocessing_threads)
+
+ # In the unsupervised case all the samples in the labeled
+ # domain are from the source domain.
+ domain_selection_mask = tf.fill((source_images.get_shape().as_list()[0],),
+ True)
+
+ # When using the semisupervised model we include labeled target data in
+ # the source labelled data.
+ if FLAGS.target_labeled_dataset != 'none':
+ # 1000 is the maximum number of labelled target samples that exists in
+ # the datasets.
+ target_semi_images, target_semi_labels = provide_batch_fn()(
+     FLAGS.target_labeled_dataset, 'train', FLAGS.dataset_dir,
+     FLAGS.num_readers, FLAGS.batch_size, FLAGS.num_preprocessing_threads)
+
+ # Calculate the proportion of source domain samples in the semi-
+ # supervised setting, so that the proportion is set accordingly in the
+ # batches.
+ proportion = float(source_labels['num_train_samples']) / (
+ source_labels['num_train_samples'] +
+ target_semi_labels['num_train_samples'])
+
+ rnd_tensor = tf.random_uniform(
+ (target_semi_images.get_shape().as_list()[0],))
+
+ domain_selection_mask = rnd_tensor < proportion
+ source_images = tf.where(domain_selection_mask, source_images,
+ target_semi_images)
+ source_class_labels = tf.where(domain_selection_mask,
+ source_labels['classes'],
+ target_semi_labels['classes'])
+
+ if 'quaternions' in source_labels:
+ source_pose_labels = tf.where(domain_selection_mask,
+ source_labels['quaternions'],
+ target_semi_labels['quaternions'])
+ (source_images, source_class_labels, source_pose_labels,
+ domain_selection_mask) = tf.train.shuffle_batch(
+ [
+ source_images, source_class_labels, source_pose_labels,
+ domain_selection_mask
+ ],
+ FLAGS.batch_size,
+ 50000,
+ 5000,
+ num_threads=1,
+ enqueue_many=True)
+
+ else:
+ (source_images, source_class_labels,
+ domain_selection_mask) = tf.train.shuffle_batch(
+ [source_images, source_class_labels, domain_selection_mask],
+ FLAGS.batch_size,
+ 50000,
+ 5000,
+ num_threads=1,
+ enqueue_many=True)
+ # Rebuild source_labels from the shuffled batch, keeping pose labels if present.
+ new_source_labels = {'classes': source_class_labels}
+ if 'quaternions' in source_labels:
+   new_source_labels['quaternions'] = source_pose_labels
+ source_labels = new_source_labels
+
+ slim.get_or_create_global_step()
+ tf.summary.image('source_images', source_images, max_outputs=3)
+ tf.summary.image('target_images', target_images, max_outputs=3)
+
+ dsn.create_model(
+ source_images,
+ source_labels,
+ domain_selection_mask,
+ target_images,
+ target_labels,
+ FLAGS.similarity_loss,
+ model_params,
+ basic_tower_name=FLAGS.basic_tower)
+
+ # Configure the optimization scheme:
+ learning_rate = tf.train.exponential_decay(
+ FLAGS.learning_rate,
+ slim.get_or_create_global_step(),
+ FLAGS.decay_steps,
+ FLAGS.decay_rate,
+ staircase=True,
+ name='learning_rate')
+
+ tf.summary.scalar('learning_rate', learning_rate)
+ tf.summary.scalar('total_loss', tf.losses.get_total_loss())
+
+ opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
+ tf.logging.set_verbosity(tf.logging.INFO)
+ # Run training.
+ loss_tensor = slim.learning.create_train_op(
+          tf.losses.get_total_loss(),
+ opt,
+ summarize_gradients=True,
+ colocate_gradients_with_ops=True)
+ slim.learning.train(
+ train_op=loss_tensor,
+ logdir=FLAGS.train_log_dir,
+ master=FLAGS.master,
+ is_chief=FLAGS.task == 0,
+ number_of_steps=FLAGS.max_number_of_steps,
+ save_summaries_secs=FLAGS.save_summaries_secs,
+ save_interval_secs=FLAGS.save_interval_secs)
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/domain_adaptation/domain_separation/grl_op_grads.py b/models/research/domain_adaptation/domain_separation/grl_op_grads.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcd85ba2b5e7912bffe646a73558af8184812ea6
--- /dev/null
+++ b/models/research/domain_adaptation/domain_separation/grl_op_grads.py
@@ -0,0 +1,34 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Gradients for operators defined in grl_ops.py."""
+import tensorflow as tf
+
+
+@tf.RegisterGradient("GradientReversal")
+def _GradientReversalGrad(_, grad):
+ """The gradients for `gradient_reversal`.
+
+ Args:
+ _: The `gradient_reversal` `Operation` that we are differentiating,
+ which we can use to find the inputs and outputs of the original op.
+ grad: Gradient with respect to the output of the `gradient_reversal` op.
+
+ Returns:
+ Gradient with respect to the input of `gradient_reversal`, which is simply
+ the negative of the input gradient.
+
+ """
+ return tf.negative(grad)
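+
+# Minimal sketch of the registered gradient's effect (assumes the compiled
+# '_grl_ops.so' kernel is available so that grl_ops is importable):
+#
+#   y = grl_ops.gradient_reversal(x)   # forward pass: y == x
+#   loss = tf.reduce_sum(y)
+#   tf.gradients(loss, x)              # == [-tf.ones_like(x)]: sign is flipped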
diff --git a/models/research/domain_adaptation/domain_separation/grl_op_kernels.cc b/models/research/domain_adaptation/domain_separation/grl_op_kernels.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ba30128f11e9e88c702d3a80593d930519f346fe
--- /dev/null
+++ b/models/research/domain_adaptation/domain_separation/grl_op_kernels.cc
@@ -0,0 +1,47 @@
+/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file contains the implementations of the ops registered in
+// grl_ops.cc.
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.pb.h"
+
+namespace tensorflow {
+
+// The gradient reversal op is used in domain adversarial training. It behaves
+// as the identity op during forward propagation, and multiplies the incoming
+// gradient by -1 during backward propagation.
+class GradientReversalOp : public OpKernel {
+ public:
+ explicit GradientReversalOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ // Gradient reversal op behaves as the identity op during forward
+ // propagation. Compute() function copied from the IdentityOp::Compute()
+ // function here: third_party/tensorflow/core/kernels/identity_op.h.
+ void Compute(OpKernelContext* context) override {
+ if (IsRefType(context->input_dtype(0))) {
+ context->forward_ref_input_to_ref_output(0, 0);
+ } else {
+ context->set_output(0, context->input(0));
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("GradientReversal").Device(DEVICE_CPU),
+ GradientReversalOp);
+
+} // namespace tensorflow
diff --git a/models/research/domain_adaptation/domain_separation/grl_op_shapes.py b/models/research/domain_adaptation/domain_separation/grl_op_shapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..52773c680af265beca9125e48bf68152b8a34e56
--- /dev/null
+++ b/models/research/domain_adaptation/domain_separation/grl_op_shapes.py
@@ -0,0 +1,16 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Shape inference for operators defined in grl_ops.cc."""
diff --git a/models/research/domain_adaptation/domain_separation/grl_ops.cc b/models/research/domain_adaptation/domain_separation/grl_ops.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d441c2b484215605db65a043be6cfa0ab90da2c3
--- /dev/null
+++ b/models/research/domain_adaptation/domain_separation/grl_ops.cc
@@ -0,0 +1,36 @@
+/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Contains custom ops.
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+// This custom op is used by adversarial training.
+REGISTER_OP("GradientReversal")
+ .Input("input: float")
+ .Output("output: float")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+This op copies the input to the output during forward propagation, and
+negates the incoming gradient during backward propagation.
+
+input: Tensor.
+output: Tensor, copied from input.
+)doc");
+
+} // namespace tensorflow
diff --git a/models/research/domain_adaptation/domain_separation/grl_ops.py b/models/research/domain_adaptation/domain_separation/grl_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..50447247b10caf3e41f3c0fb1c6f943dd3d9de6e
--- /dev/null
+++ b/models/research/domain_adaptation/domain_separation/grl_ops.py
@@ -0,0 +1,28 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""GradientReversal op Python library."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+
+import tensorflow as tf
+
+tf.logging.info(tf.resource_loader.get_data_files_path())
+_grl_ops_module = tf.load_op_library(
+ os.path.join(tf.resource_loader.get_data_files_path(),
+ '_grl_ops.so'))
+gradient_reversal = _grl_ops_module.gradient_reversal
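+
+# Note (assumption about the build layout): tf.load_op_library() above expects
+# a compiled kernel named '_grl_ops.so' (built from grl_ops.cc and
+# grl_op_kernels.cc) next to this file; if that artifact is missing the import
+# fails with tf.errors.NotFoundError.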
diff --git a/models/research/domain_adaptation/domain_separation/grl_ops_test.py b/models/research/domain_adaptation/domain_separation/grl_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..b431a6c02b60ade92a653d2ee8108c0586c70fbb
--- /dev/null
+++ b/models/research/domain_adaptation/domain_separation/grl_ops_test.py
@@ -0,0 +1,73 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for grl_ops."""
+
+#from models.domain_adaptation.domain_separation import grl_op_grads # pylint: disable=unused-import
+#from models.domain_adaptation.domain_separation import grl_op_shapes # pylint: disable=unused-import
+import tensorflow as tf
+
+import grl_op_grads
+import grl_ops
+
+FLAGS = tf.app.flags.FLAGS
+
+
+class GRLOpsTest(tf.test.TestCase):
+
+ def testGradientReversalOp(self):
+ with tf.Graph().as_default():
+ with self.test_session():
+ # Test that in forward prop, gradient reversal op acts as the
+ # identity operation.
+ examples = tf.constant([5.0, 4.0, 3.0, 2.0, 1.0])
+ output = grl_ops.gradient_reversal(examples)
+ expected_output = examples
+ self.assertAllEqual(output.eval(), expected_output.eval())
+
+ # Test that shape inference works as expected.
+ self.assertAllEqual(output.get_shape(), expected_output.get_shape())
+
+ # Test that in backward prop, gradient reversal op multiplies
+ # gradients by -1.
+ examples = tf.constant([[1.0]])
+ w = tf.get_variable(name='w', shape=[1, 1])
+ b = tf.get_variable(name='b', shape=[1])
+ init_op = tf.global_variables_initializer()
+ init_op.run()
+ features = tf.nn.xw_plus_b(examples, w, b)
+ # Construct two outputs: features layer passes directly to output1, but
+ # features layer passes through a gradient reversal layer before
+ # reaching output2.
+ output1 = features
+ output2 = grl_ops.gradient_reversal(features)
+ gold = tf.constant([1.0])
+ loss1 = gold - output1
+ loss2 = gold - output2
+ opt = tf.train.GradientDescentOptimizer(learning_rate=0.01)
+ grads_and_vars_1 = opt.compute_gradients(loss1,
+ tf.trainable_variables())
+ grads_and_vars_2 = opt.compute_gradients(loss2,
+ tf.trainable_variables())
+ self.assertAllEqual(len(grads_and_vars_1), len(grads_and_vars_2))
+ for i in range(len(grads_and_vars_1)):
+ g1 = grads_and_vars_1[i][0]
+ g2 = grads_and_vars_2[i][0]
+ # Verify that gradients of loss1 are the negative of gradients of
+ # loss2.
+ self.assertAllEqual(tf.negative(g1).eval(), g2.eval())
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/domain_adaptation/domain_separation/losses.py b/models/research/domain_adaptation/domain_separation/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d882340de10e4dd64d44f9357e8bfc5b1dd4712
--- /dev/null
+++ b/models/research/domain_adaptation/domain_separation/losses.py
@@ -0,0 +1,290 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Domain Adaptation Loss Functions.
+
+The following domain adaptation loss functions are defined:
+
+- Maximum Mean Discrepancy (MMD).
+ Relevant paper:
+ Gretton, Arthur, et al.,
+ "A kernel two-sample test."
+ The Journal of Machine Learning Research, 2012
+
+- Correlation Loss on a batch.
+"""
+from functools import partial
+import tensorflow as tf
+
+import grl_op_grads # pylint: disable=unused-import
+import grl_op_shapes # pylint: disable=unused-import
+import grl_ops
+import utils
+slim = tf.contrib.slim
+
+
+################################################################################
+# SIMILARITY LOSS
+################################################################################
+def maximum_mean_discrepancy(x, y, kernel=utils.gaussian_kernel_matrix):
+ r"""Computes the Maximum Mean Discrepancy (MMD) of two samples: x and y.
+
+ Maximum Mean Discrepancy (MMD) is a distance-measure between the samples of
+ the distributions of x and y. Here we use the kernel two sample estimate
+ using the empirical mean of the two distributions.
+
+ MMD^2(P, Q) = || \E{\phi(x)} - \E{\phi(y)} ||^2
+ = \E{ K(x, x) } + \E{ K(y, y) } - 2 \E{ K(x, y) },
+
+ where K = <\phi(x), \phi(y)>,
+ is the desired kernel function, in this case a radial basis kernel.
+
+ Args:
+ x: a tensor of shape [num_samples, num_features]
+ y: a tensor of shape [num_samples, num_features]
+ kernel: a function which computes the kernel in MMD. Defaults to the
+ GaussianKernelMatrix.
+
+ Returns:
+ a scalar denoting the squared maximum mean discrepancy loss.
+ """
+ with tf.name_scope('MaximumMeanDiscrepancy'):
+ # \E{ K(x, x) } + \E{ K(y, y) } - 2 \E{ K(x, y) }
+ cost = tf.reduce_mean(kernel(x, x))
+ cost += tf.reduce_mean(kernel(y, y))
+ cost -= 2 * tf.reduce_mean(kernel(x, y))
+
+ # We do not allow the loss to become negative.
+ cost = tf.where(cost > 0, cost, 0, name='value')
+ return cost
+
+
+def mmd_loss(source_samples, target_samples, weight, scope=None):
+ """Adds a similarity loss term, the MMD between two representations.
+
+ This Maximum Mean Discrepancy (MMD) loss is calculated with a number of
+ different Gaussian kernels.
+
+ Args:
+ source_samples: a tensor of shape [num_samples, num_features].
+ target_samples: a tensor of shape [num_samples, num_features].
+ weight: the weight of the MMD loss.
+ scope: optional name scope for summary tags.
+
+ Returns:
+ a scalar tensor representing the MMD loss value.
+ """
+ sigmas = [
+ 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
+ 1e3, 1e4, 1e5, 1e6
+ ]
+ gaussian_kernel = partial(
+ utils.gaussian_kernel_matrix, sigmas=tf.constant(sigmas))
+
+ loss_value = maximum_mean_discrepancy(
+ source_samples, target_samples, kernel=gaussian_kernel)
+ loss_value = tf.maximum(1e-4, loss_value) * weight
+ assert_op = tf.Assert(tf.is_finite(loss_value), [loss_value])
+ with tf.control_dependencies([assert_op]):
+ tag = 'MMD Loss'
+ if scope:
+ tag = scope + tag
+ tf.summary.scalar(tag, loss_value)
+ tf.losses.add_loss(loss_value)
+
+ return loss_value
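+
+# Minimal usage sketch (the feature tensors are illustrative, not part of this
+# module):
+#
+#   source_feats = tf.random_normal([32, 128])
+#   target_feats = tf.random_normal([32, 128])
+#   similarity_loss = mmd_loss(source_feats, target_feats, weight=0.25)
+#
+# The value is also registered via tf.losses.add_loss(), so it is picked up by
+# tf.losses.get_total_loss() during training.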
+
+
+def correlation_loss(source_samples, target_samples, weight, scope=None):
+ """Adds a similarity loss term, the correlation between two representations.
+
+ Args:
+ source_samples: a tensor of shape [num_samples, num_features]
+ target_samples: a tensor of shape [num_samples, num_features]
+ weight: a scalar weight for the loss.
+ scope: optional name scope for summary tags.
+
+ Returns:
+ a scalar tensor representing the correlation loss value.
+ """
+ with tf.name_scope('corr_loss'):
+ source_samples -= tf.reduce_mean(source_samples, 0)
+ target_samples -= tf.reduce_mean(target_samples, 0)
+
+ source_samples = tf.nn.l2_normalize(source_samples, 1)
+ target_samples = tf.nn.l2_normalize(target_samples, 1)
+
+ source_cov = tf.matmul(tf.transpose(source_samples), source_samples)
+ target_cov = tf.matmul(tf.transpose(target_samples), target_samples)
+
+ corr_loss = tf.reduce_mean(tf.square(source_cov - target_cov)) * weight
+
+ assert_op = tf.Assert(tf.is_finite(corr_loss), [corr_loss])
+ with tf.control_dependencies([assert_op]):
+ tag = 'Correlation Loss'
+ if scope:
+ tag = scope + tag
+ tf.summary.scalar(tag, corr_loss)
+ tf.losses.add_loss(corr_loss)
+
+ return corr_loss
+
+
+def dann_loss(source_samples, target_samples, weight, scope=None):
+ """Adds the domain adversarial (DANN) loss.
+
+ Args:
+ source_samples: a tensor of shape [num_samples, num_features].
+ target_samples: a tensor of shape [num_samples, num_features].
+ weight: the weight of the loss.
+ scope: optional name scope for summary tags.
+
+ Returns:
+    a scalar tensor representing the domain adversarial (DANN) loss value.
+ """
+ with tf.variable_scope('dann'):
+ batch_size = tf.shape(source_samples)[0]
+ samples = tf.concat(axis=0, values=[source_samples, target_samples])
+ samples = slim.flatten(samples)
+
+ domain_selection_mask = tf.concat(
+ axis=0, values=[tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))])
+
+ # Perform the gradient reversal and be careful with the shape.
+ grl = grl_ops.gradient_reversal(samples)
+ grl = tf.reshape(grl, (-1, samples.get_shape().as_list()[1]))
+
+ grl = slim.fully_connected(grl, 100, scope='fc1')
+ logits = slim.fully_connected(grl, 1, activation_fn=None, scope='fc2')
+
+ domain_predictions = tf.sigmoid(logits)
+
+ domain_loss = tf.losses.log_loss(
+ domain_selection_mask, domain_predictions, weights=weight)
+
+ domain_accuracy = utils.accuracy(
+ tf.round(domain_predictions), domain_selection_mask)
+
+ assert_op = tf.Assert(tf.is_finite(domain_loss), [domain_loss])
+ with tf.control_dependencies([assert_op]):
+ tag_loss = 'losses/domain_loss'
+ tag_accuracy = 'losses/domain_accuracy'
+ if scope:
+ tag_loss = scope + tag_loss
+ tag_accuracy = scope + tag_accuracy
+
+ tf.summary.scalar(tag_loss, domain_loss)
+ tf.summary.scalar(tag_accuracy, domain_accuracy)
+
+ return domain_loss
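+
+# Minimal usage sketch (tensor names are illustrative): the shared encodings of
+# a source batch and a target batch are passed in directly, and the gradient
+# reversal inside the domain classifier pushes the shared encoder towards
+# domain-invariant features:
+#
+#   adv_loss = dann_loss(source_shared_feats, target_shared_feats, weight=1.0)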
+
+
+################################################################################
+# DIFFERENCE LOSS
+################################################################################
+def difference_loss(private_samples, shared_samples, weight=1.0, name=''):
+ """Adds the difference loss between the private and shared representations.
+
+ Args:
+ private_samples: a tensor of shape [num_samples, num_features].
+ shared_samples: a tensor of shape [num_samples, num_features].
+ weight: the weight of the incoherence loss.
+ name: the name of the tf summary.
+ """
+ private_samples -= tf.reduce_mean(private_samples, 0)
+ shared_samples -= tf.reduce_mean(shared_samples, 0)
+
+ private_samples = tf.nn.l2_normalize(private_samples, 1)
+ shared_samples = tf.nn.l2_normalize(shared_samples, 1)
+
+ correlation_matrix = tf.matmul(
+ private_samples, shared_samples, transpose_a=True)
+
+ cost = tf.reduce_mean(tf.square(correlation_matrix)) * weight
+ cost = tf.where(cost > 0, cost, 0, name='value')
+
+ tf.summary.scalar('losses/Difference Loss {}'.format(name),
+ cost)
+ assert_op = tf.Assert(tf.is_finite(cost), [cost])
+ with tf.control_dependencies([assert_op]):
+ tf.losses.add_loss(cost)
+
+
+################################################################################
+# TASK LOSS
+################################################################################
+def log_quaternion_loss_batch(predictions, labels, params):
+ """A helper function to compute the error between quaternions.
+
+ Args:
+ predictions: A Tensor of size [batch_size, 4].
+ labels: A Tensor of size [batch_size, 4].
+ params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.
+
+ Returns:
+ A Tensor of size [batch_size], denoting the error between the quaternions.
+ """
+ use_logging = params['use_logging']
+ assertions = []
+ if use_logging:
+ assertions.append(
+ tf.Assert(
+ tf.reduce_all(
+ tf.less(
+ tf.abs(tf.reduce_sum(tf.square(predictions), [1]) - 1),
+ 1e-4)),
+ ['The l2 norm of each prediction quaternion vector should be 1.']))
+ assertions.append(
+ tf.Assert(
+ tf.reduce_all(
+ tf.less(
+ tf.abs(tf.reduce_sum(tf.square(labels), [1]) - 1), 1e-4)),
+ ['The l2 norm of each label quaternion vector should be 1.']))
+
+ with tf.control_dependencies(assertions):
+ product = tf.multiply(predictions, labels)
+ internal_dot_products = tf.reduce_sum(product, [1])
+
+ if use_logging:
+ internal_dot_products = tf.Print(
+ internal_dot_products,
+ [internal_dot_products, tf.shape(internal_dot_products)],
+ 'internal_dot_products:')
+
+ logcost = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
+ return logcost
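+
+# Worked example: for a prediction identical to its label the absolute dot
+# product is 1, so the per-sample cost is log(1e-4 + 1 - 1) = log(1e-4)
+# ~= -9.21, the minimum attainable value; orthogonal unit quaternions give
+# log(1e-4 + 1) ~= 0.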
+
+
+def log_quaternion_loss(predictions, labels, params):
+ """A helper function to compute the mean error between batches of quaternions.
+
+ The caller is expected to add the loss to the graph.
+
+ Args:
+ predictions: A Tensor of size [batch_size, 4].
+ labels: A Tensor of size [batch_size, 4].
+ params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.
+
+ Returns:
+ A Tensor of size 1, denoting the mean error between batches of quaternions.
+ """
+ use_logging = params['use_logging']
+ logcost = log_quaternion_loss_batch(predictions, labels, params)
+ logcost = tf.reduce_sum(logcost, [0])
+ batch_size = params['batch_size']
+ logcost = tf.multiply(logcost, 1.0 / batch_size, name='log_quaternion_loss')
+ if use_logging:
+ logcost = tf.Print(
+ logcost, [logcost], '[logcost]', name='log_quaternion_loss_print')
+ return logcost
diff --git a/models/research/domain_adaptation/domain_separation/losses_test.py b/models/research/domain_adaptation/domain_separation/losses_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..46e50301be56f5977adcb3fb00587f076934b785
--- /dev/null
+++ b/models/research/domain_adaptation/domain_separation/losses_test.py
@@ -0,0 +1,110 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for DSN losses."""
+from functools import partial
+
+import numpy as np
+import tensorflow as tf
+
+import losses
+import utils
+
+
+def MaximumMeanDiscrepancySlow(x, y, sigmas):
+ num_samples = x.get_shape().as_list()[0]
+
+ def AverageGaussianKernel(x, y, sigmas):
+ result = 0
+ for sigma in sigmas:
+ dist = tf.reduce_sum(tf.square(x - y))
+ result += tf.exp((-1.0 / (2.0 * sigma)) * dist)
+ return result / num_samples**2
+
+ total = 0
+
+ for i in range(num_samples):
+ for j in range(num_samples):
+ total += AverageGaussianKernel(x[i, :], x[j, :], sigmas)
+ total += AverageGaussianKernel(y[i, :], y[j, :], sigmas)
+ total += -2 * AverageGaussianKernel(x[i, :], y[j, :], sigmas)
+
+ return total
+
+
+class LogQuaternionLossTest(tf.test.TestCase):
+
+ def test_log_quaternion_loss_batch(self):
+ with self.test_session():
+ predictions = tf.random_uniform((10, 4), seed=1)
+ predictions = tf.nn.l2_normalize(predictions, 1)
+ labels = tf.random_uniform((10, 4), seed=1)
+ labels = tf.nn.l2_normalize(labels, 1)
+ params = {'batch_size': 10, 'use_logging': False}
+ x = losses.log_quaternion_loss_batch(predictions, labels, params)
+ self.assertTrue(((10,) == tf.shape(x).eval()).all())
+
+
+class MaximumMeanDiscrepancyTest(tf.test.TestCase):
+
+ def test_mmd_name(self):
+ with self.test_session():
+ x = tf.random_uniform((2, 3), seed=1)
+ kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
+ loss = losses.maximum_mean_discrepancy(x, x, kernel)
+
+ self.assertEquals(loss.op.name, 'MaximumMeanDiscrepancy/value')
+
+ def test_mmd_is_zero_when_inputs_are_same(self):
+ with self.test_session():
+ x = tf.random_uniform((2, 3), seed=1)
+ kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
+ self.assertEquals(0, losses.maximum_mean_discrepancy(x, x, kernel).eval())
+
+ def test_fast_mmd_is_similar_to_slow_mmd(self):
+ with self.test_session():
+ x = tf.constant(np.random.normal(size=(2, 3)), tf.float32)
+ y = tf.constant(np.random.rand(2, 3), tf.float32)
+
+ cost_old = MaximumMeanDiscrepancySlow(x, y, [1.]).eval()
+ kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
+ cost_new = losses.maximum_mean_discrepancy(x, y, kernel).eval()
+
+ self.assertAlmostEqual(cost_old, cost_new, delta=1e-5)
+
+ def test_multiple_sigmas(self):
+ with self.test_session():
+ x = tf.constant(np.random.normal(size=(2, 3)), tf.float32)
+ y = tf.constant(np.random.rand(2, 3), tf.float32)
+
+ sigmas = tf.constant([2., 5., 10, 20, 30])
+ kernel = partial(utils.gaussian_kernel_matrix, sigmas=sigmas)
+ cost_old = MaximumMeanDiscrepancySlow(x, y, [2., 5., 10, 20, 30]).eval()
+ cost_new = losses.maximum_mean_discrepancy(x, y, kernel=kernel).eval()
+
+ self.assertAlmostEqual(cost_old, cost_new, delta=1e-5)
+
+ def test_mmd_is_zero_when_distributions_are_same(self):
+
+ with self.test_session():
+ x = tf.random_uniform((1000, 10), seed=1)
+ y = tf.random_uniform((1000, 10), seed=3)
+
+ kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([100.]))
+ loss = losses.maximum_mean_discrepancy(x, y, kernel=kernel).eval()
+
+ self.assertAlmostEqual(0, loss, delta=1e-4)
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/domain_adaptation/domain_separation/models.py b/models/research/domain_adaptation/domain_separation/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..04ccaf82eb9b31a6ea78871204c7df70eca3fbfd
--- /dev/null
+++ b/models/research/domain_adaptation/domain_separation/models.py
@@ -0,0 +1,443 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains different architectures for the different DSN parts.
+
+We define here the modules that can be used in the different parts of the DSN
+model.
+- shared encoder (dsn_cropped_linemod, dann_xxxx)
+- private encoder (default_encoder)
+- decoder (large_decoder, gtsrb_decoder, small_decoder)
+"""
+import tensorflow as tf
+
+#from models.domain_adaptation.domain_separation
+import utils
+
+slim = tf.contrib.slim
+
+
+def default_batch_norm_params(is_training=False):
+ """Returns default batch normalization parameters for DSNs.
+
+ Args:
+ is_training: whether or not the model is training.
+
+ Returns:
+ a dictionary that maps batch norm parameter names (strings) to values.
+ """
+ return {
+ # Decay for the moving averages.
+ 'decay': 0.5,
+ # epsilon to prevent 0s in variance.
+ 'epsilon': 0.001,
+ 'is_training': is_training
+ }
+
+
+################################################################################
+# PRIVATE ENCODERS
+################################################################################
+def default_encoder(images, code_size, batch_norm_params=None,
+ weight_decay=0.0):
+ """Encodes the given images to codes of the given size.
+
+ Args:
+ images: a tensor of size [batch_size, height, width, 1].
+ code_size: the number of hidden units in the code layer of the classifier.
+ batch_norm_params: a dictionary that maps batch norm parameter names to
+ values.
+ weight_decay: the value for the weight decay coefficient.
+
+ Returns:
+ end_points: the code of the input.
+ """
+ end_points = {}
+ with slim.arg_scope(
+ [slim.conv2d, slim.fully_connected],
+ weights_regularizer=slim.l2_regularizer(weight_decay),
+ activation_fn=tf.nn.relu,
+ normalizer_fn=slim.batch_norm,
+ normalizer_params=batch_norm_params):
+ with slim.arg_scope([slim.conv2d], kernel_size=[5, 5], padding='SAME'):
+ net = slim.conv2d(images, 32, scope='conv1')
+ net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
+ net = slim.conv2d(net, 64, scope='conv2')
+ net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
+
+ net = slim.flatten(net)
+ end_points['flatten'] = net
+ net = slim.fully_connected(net, code_size, scope='fc1')
+ end_points['fc3'] = net
+ return end_points
+
+
+################################################################################
+# DECODERS
+################################################################################
+def large_decoder(codes,
+ height,
+ width,
+ channels,
+ batch_norm_params=None,
+ weight_decay=0.0):
+ """Decodes the codes to a fixed output size.
+
+ Args:
+ codes: a tensor of size [batch_size, code_size].
+ height: the height of the output images.
+ width: the width of the output images.
+ channels: the number of the output channels.
+ batch_norm_params: a dictionary that maps batch norm parameter names to
+ values.
+ weight_decay: the value for the weight decay coefficient.
+
+ Returns:
+ recons: the reconstruction tensor of shape [batch_size, height, width, 3].
+ """
+ with slim.arg_scope(
+ [slim.conv2d, slim.fully_connected],
+ weights_regularizer=slim.l2_regularizer(weight_decay),
+ activation_fn=tf.nn.relu,
+ normalizer_fn=slim.batch_norm,
+ normalizer_params=batch_norm_params):
+ net = slim.fully_connected(codes, 600, scope='fc1')
+ batch_size = net.get_shape().as_list()[0]
+ net = tf.reshape(net, [batch_size, 10, 10, 6])
+
+ net = slim.conv2d(net, 32, [5, 5], scope='conv1_1')
+
+ net = tf.image.resize_nearest_neighbor(net, (16, 16))
+
+ net = slim.conv2d(net, 32, [5, 5], scope='conv2_1')
+
+ net = tf.image.resize_nearest_neighbor(net, (32, 32))
+
+ net = slim.conv2d(net, 32, [5, 5], scope='conv3_2')
+
+ output_size = [height, width]
+ net = tf.image.resize_nearest_neighbor(net, output_size)
+
+ with slim.arg_scope([slim.conv2d], kernel_size=[3, 3]):
+ net = slim.conv2d(net, channels, activation_fn=None, scope='conv4_1')
+
+ return net
+
+
+def gtsrb_decoder(codes,
+ height,
+ width,
+ channels,
+ batch_norm_params=None,
+ weight_decay=0.0):
+ """Decodes the codes to a fixed output size. This decoder is specific to GTSRB
+
+ Args:
+ codes: a tensor of size [batch_size, 100].
+ height: the height of the output images.
+ width: the width of the output images.
+ channels: the number of the output channels.
+ batch_norm_params: a dictionary that maps batch norm parameter names to
+ values.
+ weight_decay: the value for the weight decay coefficient.
+
+ Returns:
+ recons: the reconstruction tensor of shape [batch_size, height, width, 3].
+
+ Raises:
+ ValueError: When the input code size is not 100.
+ """
+ batch_size, code_size = codes.get_shape().as_list()
+ if code_size != 100:
+ raise ValueError('The code size used as an input to the GTSRB decoder is '
+ 'expected to be 100.')
+
+ with slim.arg_scope(
+ [slim.conv2d, slim.fully_connected],
+ weights_regularizer=slim.l2_regularizer(weight_decay),
+ activation_fn=tf.nn.relu,
+ normalizer_fn=slim.batch_norm,
+ normalizer_params=batch_norm_params):
+ net = codes
+ net = tf.reshape(net, [batch_size, 10, 10, 1])
+ net = slim.conv2d(net, 32, [3, 3], scope='conv1_1')
+
+ # First upsampling 20x20
+ net = tf.image.resize_nearest_neighbor(net, [20, 20])
+
+ net = slim.conv2d(net, 32, [3, 3], scope='conv2_1')
+
+ output_size = [height, width]
+ # Final upsampling 40 x 40
+ net = tf.image.resize_nearest_neighbor(net, output_size)
+
+ with slim.arg_scope([slim.conv2d], kernel_size=[3, 3]):
+ net = slim.conv2d(net, 16, scope='conv3_1')
+ net = slim.conv2d(net, channels, activation_fn=None, scope='conv3_2')
+
+ return net
+
+
+def small_decoder(codes,
+ height,
+ width,
+ channels,
+ batch_norm_params=None,
+ weight_decay=0.0):
+ """Decodes the codes to a fixed output size.
+
+ Args:
+ codes: a tensor of size [batch_size, code_size].
+ height: the height of the output images.
+ width: the width of the output images.
+ channels: the number of the output channels.
+ batch_norm_params: a dictionary that maps batch norm parameter names to
+ values.
+ weight_decay: the value for the weight decay coefficient.
+
+ Returns:
+ recons: the reconstruction tensor of shape [batch_size, height, width, 3].
+ """
+ with slim.arg_scope(
+ [slim.conv2d, slim.fully_connected],
+ weights_regularizer=slim.l2_regularizer(weight_decay),
+ activation_fn=tf.nn.relu,
+ normalizer_fn=slim.batch_norm,
+ normalizer_params=batch_norm_params):
+ net = slim.fully_connected(codes, 300, scope='fc1')
+ batch_size = net.get_shape().as_list()[0]
+ net = tf.reshape(net, [batch_size, 10, 10, 3])
+
+ net = slim.conv2d(net, 16, [3, 3], scope='conv1_1')
+ net = slim.conv2d(net, 16, [3, 3], scope='conv1_2')
+
+ output_size = [height, width]
+ net = tf.image.resize_nearest_neighbor(net, output_size)
+
+ with slim.arg_scope([slim.conv2d], kernel_size=[3, 3]):
+ net = slim.conv2d(net, 16, scope='conv2_1')
+ net = slim.conv2d(net, channels, activation_fn=None, scope='conv2_2')
+
+ return net
+
+
+################################################################################
+# SHARED ENCODERS
+################################################################################
+def dann_mnist(images,
+ weight_decay=0.0,
+ prefix='model',
+ num_classes=10,
+ **kwargs):
+ """Creates a convolution MNIST model.
+
+ Note that this model implements the architecture for MNIST proposed in:
+ Y. Ganin et al., Domain-Adversarial Training of Neural Networks (DANN),
+ JMLR 2015
+
+ Args:
+ images: the MNIST digits, a tensor of size [batch_size, 28, 28, 1].
+ weight_decay: the value for the weight decay coefficient.
+ prefix: name of the model to use when prefixing tags.
+ num_classes: the number of output classes to use.
+ **kwargs: Placeholder for keyword arguments used by other shared encoders.
+
+ Returns:
+ the output logits, a tensor of size [batch_size, num_classes].
+ a dictionary with key/values the layer names and tensors.
+ """
+ end_points = {}
+
+ with slim.arg_scope(
+ [slim.conv2d, slim.fully_connected],
+ weights_regularizer=slim.l2_regularizer(weight_decay),
+ activation_fn=tf.nn.relu,):
+ with slim.arg_scope([slim.conv2d], padding='SAME'):
+ end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1')
+ end_points['pool1'] = slim.max_pool2d(
+ end_points['conv1'], [2, 2], 2, scope='pool1')
+ end_points['conv2'] = slim.conv2d(
+ end_points['pool1'], 48, [5, 5], scope='conv2')
+ end_points['pool2'] = slim.max_pool2d(
+ end_points['conv2'], [2, 2], 2, scope='pool2')
+ end_points['fc3'] = slim.fully_connected(
+ slim.flatten(end_points['pool2']), 100, scope='fc3')
+ end_points['fc4'] = slim.fully_connected(
+ slim.flatten(end_points['fc3']), 100, scope='fc4')
+
+ logits = slim.fully_connected(
+ end_points['fc4'], num_classes, activation_fn=None, scope='fc5')
+
+ return logits, end_points
+
+
+def dann_svhn(images,
+ weight_decay=0.0,
+ prefix='model',
+ num_classes=10,
+ **kwargs):
+ """Creates the convolutional SVHN model.
+
+  Note that this model implements the architecture for SVHN proposed in:
+ Y. Ganin et al., Domain-Adversarial Training of Neural Networks (DANN),
+ JMLR 2015
+
+ Args:
+ images: the SVHN digits, a tensor of size [batch_size, 32, 32, 3].
+ weight_decay: the value for the weight decay coefficient.
+ prefix: name of the model to use when prefixing tags.
+ num_classes: the number of output classes to use.
+ **kwargs: Placeholder for keyword arguments used by other shared encoders.
+
+ Returns:
+ the output logits, a tensor of size [batch_size, num_classes].
+ a dictionary with key/values the layer names and tensors.
+ """
+
+ end_points = {}
+
+ with slim.arg_scope(
+ [slim.conv2d, slim.fully_connected],
+ weights_regularizer=slim.l2_regularizer(weight_decay),
+ activation_fn=tf.nn.relu,):
+ with slim.arg_scope([slim.conv2d], padding='SAME'):
+
+ end_points['conv1'] = slim.conv2d(images, 64, [5, 5], scope='conv1')
+ end_points['pool1'] = slim.max_pool2d(
+ end_points['conv1'], [3, 3], 2, scope='pool1')
+ end_points['conv2'] = slim.conv2d(
+ end_points['pool1'], 64, [5, 5], scope='conv2')
+ end_points['pool2'] = slim.max_pool2d(
+ end_points['conv2'], [3, 3], 2, scope='pool2')
+ end_points['conv3'] = slim.conv2d(
+ end_points['pool2'], 128, [5, 5], scope='conv3')
+
+ end_points['fc3'] = slim.fully_connected(
+ slim.flatten(end_points['conv3']), 3072, scope='fc3')
+ end_points['fc4'] = slim.fully_connected(
+ slim.flatten(end_points['fc3']), 2048, scope='fc4')
+
+ logits = slim.fully_connected(
+ end_points['fc4'], num_classes, activation_fn=None, scope='fc5')
+
+ return logits, end_points
+
+
+def dann_gtsrb(images,
+ weight_decay=0.0,
+ prefix='model',
+ num_classes=43,
+ **kwargs):
+ """Creates the convolutional GTSRB model.
+
+  Note that this model implements the architecture for GTSRB proposed in:
+ Y. Ganin et al., Domain-Adversarial Training of Neural Networks (DANN),
+ JMLR 2015
+
+ Args:
+ images: the GTSRB images, a tensor of size [batch_size, 40, 40, 3].
+ weight_decay: the value for the weight decay coefficient.
+ prefix: name of the model to use when prefixing tags.
+ num_classes: the number of output classes to use.
+ **kwargs: Placeholder for keyword arguments used by other shared encoders.
+
+ Returns:
+ the output logits, a tensor of size [batch_size, num_classes].
+ a dictionary with key/values the layer names and tensors.
+ """
+
+ end_points = {}
+
+ with slim.arg_scope(
+ [slim.conv2d, slim.fully_connected],
+ weights_regularizer=slim.l2_regularizer(weight_decay),
+ activation_fn=tf.nn.relu,):
+ with slim.arg_scope([slim.conv2d], padding='SAME'):
+
+ end_points['conv1'] = slim.conv2d(images, 96, [5, 5], scope='conv1')
+ end_points['pool1'] = slim.max_pool2d(
+ end_points['conv1'], [2, 2], 2, scope='pool1')
+ end_points['conv2'] = slim.conv2d(
+ end_points['pool1'], 144, [3, 3], scope='conv2')
+ end_points['pool2'] = slim.max_pool2d(
+ end_points['conv2'], [2, 2], 2, scope='pool2')
+ end_points['conv3'] = slim.conv2d(
+ end_points['pool2'], 256, [5, 5], scope='conv3')
+ end_points['pool3'] = slim.max_pool2d(
+ end_points['conv3'], [2, 2], 2, scope='pool3')
+
+ end_points['fc3'] = slim.fully_connected(
+ slim.flatten(end_points['pool3']), 512, scope='fc3')
+
+ logits = slim.fully_connected(
+ end_points['fc3'], num_classes, activation_fn=None, scope='fc4')
+
+ return logits, end_points
+
+
+def dsn_cropped_linemod(images,
+ weight_decay=0.0,
+ prefix='model',
+ num_classes=11,
+ batch_norm_params=None,
+ is_training=False):
+ """Creates the convolutional pose estimation model for Cropped Linemod.
+
+ Args:
+ images: the Cropped Linemod samples, a tensor of size
+ [batch_size, 64, 64, 4].
+ weight_decay: the value for the weight decay coefficient.
+ prefix: name of the model to use when prefixing tags.
+ num_classes: the number of output classes to use.
+ batch_norm_params: a dictionary that maps batch norm parameter names to
+ values.
+ is_training: specifies whether or not we're currently training the model.
+ This variable will determine the behaviour of the dropout layer.
+
+ Returns:
+ the output logits, a tensor of size [batch_size, num_classes].
+ a dictionary with key/values the layer names and tensors.
+ """
+
+ end_points = {}
+
+ tf.summary.image('{}/input_images'.format(prefix), images)
+ with slim.arg_scope(
+ [slim.conv2d, slim.fully_connected],
+ weights_regularizer=slim.l2_regularizer(weight_decay),
+ activation_fn=tf.nn.relu,
+ normalizer_fn=slim.batch_norm if batch_norm_params else None,
+ normalizer_params=batch_norm_params):
+ with slim.arg_scope([slim.conv2d], padding='SAME'):
+ end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1')
+ end_points['pool1'] = slim.max_pool2d(
+ end_points['conv1'], [2, 2], 2, scope='pool1')
+ end_points['conv2'] = slim.conv2d(
+ end_points['pool1'], 64, [5, 5], scope='conv2')
+ end_points['pool2'] = slim.max_pool2d(
+ end_points['conv2'], [2, 2], 2, scope='pool2')
+ net = slim.flatten(end_points['pool2'])
+ end_points['fc3'] = slim.fully_connected(net, 128, scope='fc3')
+ net = slim.dropout(
+ end_points['fc3'], 0.5, is_training=is_training, scope='dropout')
+
+ with tf.variable_scope('quaternion_prediction'):
+ predicted_quaternion = slim.fully_connected(
+ net, 4, activation_fn=tf.nn.tanh)
+ predicted_quaternion = tf.nn.l2_normalize(predicted_quaternion, 1)
+ logits = slim.fully_connected(
+ net, num_classes, activation_fn=None, scope='fc4')
+ end_points['quaternion_pred'] = predicted_quaternion
+
+ return logits, end_points
diff --git a/models/research/domain_adaptation/domain_separation/models_test.py b/models/research/domain_adaptation/domain_separation/models_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..69d1a27259022569cc5865e49dd6bba5675d834f
--- /dev/null
+++ b/models/research/domain_adaptation/domain_separation/models_test.py
@@ -0,0 +1,167 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for DSN components."""
+
+import numpy as np
+import tensorflow as tf
+
+#from models.domain_adaptation.domain_separation
+import models
+
+
+class SharedEncodersTest(tf.test.TestCase):
+
+ def _testSharedEncoder(self,
+ input_shape=[5, 28, 28, 1],
+ model=models.dann_mnist,
+ is_training=True):
+ images = tf.to_float(np.random.rand(*input_shape))
+
+ with self.test_session() as sess:
+ logits, _ = model(images)
+ sess.run(tf.global_variables_initializer())
+ logits_np = sess.run(logits)
+ return logits_np
+
+ def testBuildGRLMnistModel(self):
+ logits = self._testSharedEncoder(model=getattr(models,
+ 'dann_mnist'))
+ self.assertEqual(logits.shape, (5, 10))
+ self.assertTrue(np.any(logits))
+
+ def testBuildGRLSvhnModel(self):
+ logits = self._testSharedEncoder(model=getattr(models,
+ 'dann_svhn'))
+ self.assertEqual(logits.shape, (5, 10))
+ self.assertTrue(np.any(logits))
+
+ def testBuildGRLGtsrbModel(self):
+ logits = self._testSharedEncoder([5, 40, 40, 3],
+ getattr(models, 'dann_gtsrb'))
+ self.assertEqual(logits.shape, (5, 43))
+ self.assertTrue(np.any(logits))
+
+ def testBuildPoseModel(self):
+ logits = self._testSharedEncoder([5, 64, 64, 4],
+ getattr(models, 'dsn_cropped_linemod'))
+ self.assertEqual(logits.shape, (5, 11))
+ self.assertTrue(np.any(logits))
+
+ def testBuildPoseModelWithBatchNorm(self):
+ images = tf.to_float(np.random.rand(10, 64, 64, 4))
+
+ with self.test_session() as sess:
+ logits, _ = getattr(models, 'dsn_cropped_linemod')(
+ images, batch_norm_params=models.default_batch_norm_params(True))
+ sess.run(tf.global_variables_initializer())
+ logits_np = sess.run(logits)
+ self.assertEqual(logits_np.shape, (10, 11))
+ self.assertTrue(np.any(logits_np))
+
+
+class EncoderTest(tf.test.TestCase):
+
+ def _testEncoder(self, batch_norm_params=None, channels=1):
+ images = tf.to_float(np.random.rand(10, 28, 28, channels))
+
+ with self.test_session() as sess:
+ end_points = models.default_encoder(
+ images, 128, batch_norm_params=batch_norm_params)
+ sess.run(tf.global_variables_initializer())
+ private_code = sess.run(end_points['fc3'])
+ self.assertEqual(private_code.shape, (10, 128))
+ self.assertTrue(np.any(private_code))
+ self.assertTrue(np.all(np.isfinite(private_code)))
+
+ def testEncoder(self):
+ self._testEncoder()
+
+ def testEncoderMultiChannel(self):
+ self._testEncoder(None, 4)
+
+ def testEncoderIsTrainingBatchNorm(self):
+ self._testEncoder(models.default_batch_norm_params(True))
+
+ def testEncoderBatchNorm(self):
+ self._testEncoder(models.default_batch_norm_params(False))
+
+
+class DecoderTest(tf.test.TestCase):
+
+ def _testDecoder(self,
+ height=64,
+ width=64,
+ channels=4,
+ batch_norm_params=None,
+ decoder=models.small_decoder):
+ codes = tf.to_float(np.random.rand(32, 100))
+
+ with self.test_session() as sess:
+ output = decoder(
+ codes,
+ height=height,
+ width=width,
+ channels=channels,
+ batch_norm_params=batch_norm_params)
+ sess.run(tf.global_variables_initializer())
+ output_np = sess.run(output)
+ self.assertEqual(output_np.shape, (32, height, width, channels))
+ self.assertTrue(np.any(output_np))
+ self.assertTrue(np.all(np.isfinite(output_np)))
+
+ def testSmallDecoder(self):
+ self._testDecoder(28, 28, 4, None, getattr(models, 'small_decoder'))
+
+ def testSmallDecoderThreeChannels(self):
+ self._testDecoder(28, 28, 3)
+
+ def testSmallDecoderBatchNorm(self):
+ self._testDecoder(28, 28, 4, models.default_batch_norm_params(False))
+
+ def testSmallDecoderIsTrainingBatchNorm(self):
+ self._testDecoder(28, 28, 4, models.default_batch_norm_params(True))
+
+ def testLargeDecoder(self):
+ self._testDecoder(32, 32, 4, None, getattr(models, 'large_decoder'))
+
+ def testLargeDecoderThreeChannels(self):
+ self._testDecoder(32, 32, 3, None, getattr(models, 'large_decoder'))
+
+ def testLargeDecoderBatchNorm(self):
+ self._testDecoder(32, 32, 4,
+ models.default_batch_norm_params(False),
+ getattr(models, 'large_decoder'))
+
+ def testLargeDecoderIsTrainingBatchNorm(self):
+ self._testDecoder(32, 32, 4,
+ models.default_batch_norm_params(True),
+ getattr(models, 'large_decoder'))
+
+ def testGtsrbDecoder(self):
+    self._testDecoder(40, 40, 3, None, getattr(models, 'gtsrb_decoder'))
+
+ def testGtsrbDecoderBatchNorm(self):
+ self._testDecoder(40, 40, 4,
+ models.default_batch_norm_params(False),
+ getattr(models, 'gtsrb_decoder'))
+
+ def testGtsrbDecoderIsTrainingBatchNorm(self):
+ self._testDecoder(40, 40, 4,
+ models.default_batch_norm_params(True),
+ getattr(models, 'gtsrb_decoder'))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/domain_adaptation/domain_separation/utils.py b/models/research/domain_adaptation/domain_separation/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e144ee86120bd58eb06b710fb35f3f58b5a05343
--- /dev/null
+++ b/models/research/domain_adaptation/domain_separation/utils.py
@@ -0,0 +1,183 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Auxiliary functions for domain adaptation related losses.
+"""
+import math
+import tensorflow as tf
+
+
+def create_summaries(end_points, prefix='', max_images=3, use_op_name=False):
+ """Creates a tf summary per endpoint.
+
+ If the endpoint is a 4 dimensional tensor it displays it as an image
+ otherwise if it is a two dimensional one it creates a histogram summary.
+
+ Args:
+ end_points: a dictionary of name, tf tensor pairs.
+ prefix: an optional string to prefix the summary with.
+ max_images: the maximum number of images to display per summary.
+ use_op_name: Use the op name as opposed to the shorter end_points key.
+ """
+ for layer_name in end_points:
+ if use_op_name:
+ name = end_points[layer_name].op.name
+ else:
+ name = layer_name
+ if len(end_points[layer_name].get_shape().as_list()) == 4:
+ # if it's an actual image do not attempt to reshape it
+ if end_points[layer_name].get_shape().as_list()[-1] == 1 or end_points[
+ layer_name].get_shape().as_list()[-1] == 3:
+ visualization_image = end_points[layer_name]
+ else:
+ visualization_image = reshape_feature_maps(end_points[layer_name])
+ tf.summary.image(
+ '{}/{}'.format(prefix, name),
+ visualization_image,
+ max_outputs=max_images)
+ elif len(end_points[layer_name].get_shape().as_list()) == 3:
+ images = tf.expand_dims(end_points[layer_name], 3)
+ tf.summary.image(
+ '{}/{}'.format(prefix, name),
+ images,
+ max_outputs=max_images)
+ elif len(end_points[layer_name].get_shape().as_list()) == 2:
+ tf.summary.histogram('{}/{}'.format(prefix, name), end_points[layer_name])
+
+
+def reshape_feature_maps(features_tensor):
+ """Reshape activations for tf.summary.image visualization.
+
+ Arguments:
+ features_tensor: a tensor of activations with a square number of feature
+ maps, eg 4, 9, 16, etc.
+ Returns:
+ A composite image with all the feature maps that can be passed as an
+ argument to tf.summary.image.
+ """
+ assert len(features_tensor.get_shape().as_list()) == 4
+ num_filters = features_tensor.get_shape().as_list()[-1]
+ assert num_filters > 0
+ num_filters_sqrt = math.sqrt(num_filters)
+  assert num_filters_sqrt.is_integer(), (
+      'Number of filters should be a square number but got {}'.format(
+          num_filters))
+ num_filters_sqrt = int(num_filters_sqrt)
+ conv_summary = tf.unstack(features_tensor, axis=3)
+ conv_one_row = tf.concat(axis=2, values=conv_summary[0:num_filters_sqrt])
+ ind = 1
+ conv_final = conv_one_row
+ for ind in range(1, num_filters_sqrt):
+    conv_one_row = tf.concat(
+        axis=2,
+        values=conv_summary[ind * num_filters_sqrt:(ind + 1) * num_filters_sqrt])
+ conv_final = tf.concat(
+ axis=1, values=[tf.squeeze(conv_final), tf.squeeze(conv_one_row)])
+ conv_final = tf.expand_dims(conv_final, -1)
+ return conv_final
+
+
+def accuracy(predictions, labels):
+ """Calculates the classificaton accuracy.
+
+ Args:
+ predictions: the predicted values, a tensor whose size matches 'labels'.
+ labels: the ground truth values, a tensor of any size.
+
+ Returns:
+ a tensor whose value on evaluation returns the total accuracy.
+ """
+ return tf.reduce_mean(tf.cast(tf.equal(predictions, labels), tf.float32))
+
+
+def compute_upsample_values(input_tensor, upsample_height, upsample_width):
+ """Compute values for an upsampling op (ops.BatchCropAndResize).
+
+ Args:
+ input_tensor: image tensor with shape [batch, height, width, in_channels]
+ upsample_height: integer
+ upsample_width: integer
+
+ Returns:
+ grid_centers: tensor with shape [batch, 1]
+ crop_sizes: tensor with shape [batch, 1]
+ output_height: integer
+ output_width: integer
+ """
+  batch, input_height, input_width, _ = input_tensor.get_shape().as_list()
+
+ height_half = input_height / 2.
+ width_half = input_width / 2.
+ grid_centers = tf.constant(batch * [[height_half, width_half]])
+ crop_sizes = tf.constant(batch * [[input_height, input_width]])
+ output_height = input_height * upsample_height
+ output_width = input_width * upsample_width
+
+ return grid_centers, tf.to_float(crop_sizes), output_height, output_width
+
+
+def compute_pairwise_distances(x, y):
+ """Computes the squared pairwise Euclidean distances between x and y.
+
+ Args:
+ x: a tensor of shape [num_x_samples, num_features]
+ y: a tensor of shape [num_y_samples, num_features]
+
+ Returns:
+ a distance matrix of dimensions [num_x_samples, num_y_samples].
+
+ Raises:
+ ValueError: if the inputs do no matched the specified dimensions.
+ """
+
+ if not len(x.get_shape()) == len(y.get_shape()) == 2:
+ raise ValueError('Both inputs should be matrices.')
+
+ if x.get_shape().as_list()[1] != y.get_shape().as_list()[1]:
+ raise ValueError('The number of features should be the same.')
+
+ norm = lambda x: tf.reduce_sum(tf.square(x), 1)
+
+  # By making the `inner' dimensions of the two matrices equal to 1 using
+  # broadcasting, we are essentially subtracting every pair of rows
+  # of x and y.
+  # x will be num_samples x num_features x 1,
+  # and y will be 1 x num_features x num_samples (after broadcasting).
+  # After the subtraction we will get a
+  # num_x_samples x num_features x num_y_samples matrix.
+  # The resulting dist will be of shape num_y_samples x num_x_samples,
+  # and thus we need to transpose it again.
+ return tf.transpose(norm(tf.expand_dims(x, 2) - tf.transpose(y)))
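+
+# Worked example: for x = [[0., 0.]] and y = [[3., 4.]] the single entry of the
+# returned matrix is (3 - 0)^2 + (4 - 0)^2 = 25 (the squared Euclidean
+# distance; no square root is taken).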
+
+
+def gaussian_kernel_matrix(x, y, sigmas):
+ r"""Computes a Guassian Radial Basis Kernel between the samples of x and y.
+
+ We create a sum of multiple gaussian kernels each having a width sigma_i.
+
+ Args:
+ x: a tensor of shape [num_samples, num_features]
+ y: a tensor of shape [num_samples, num_features]
+ sigmas: a tensor of floats which denote the widths of each of the
+ gaussians in the kernel.
+ Returns:
+ A tensor of shape [num_samples{x}, num_samples{y}] with the RBF kernel.
+ """
+ beta = 1. / (2. * (tf.expand_dims(sigmas, 1)))
+
+ dist = compute_pairwise_distances(x, y)
+
+ s = tf.matmul(beta, tf.reshape(dist, (1, -1)))
+
+ return tf.reshape(tf.reduce_sum(tf.exp(-s), 0), tf.shape(dist))
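+
+# Equivalently, each entry of the returned matrix is
+#   sum_k exp(-||x_i - y_j||^2 / (2 * sigma_k))
+# for one pair (x_i, y_j) of samples; this is the multi-kernel form used by
+# mmd_loss in losses.py.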
diff --git a/models/research/domain_adaptation/pixel_domain_adaptation/BUILD b/models/research/domain_adaptation/pixel_domain_adaptation/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..2bc8d4a49a828f97b8f45166aa2bbc552d4a3b92
--- /dev/null
+++ b/models/research/domain_adaptation/pixel_domain_adaptation/BUILD
@@ -0,0 +1,90 @@
+# Description:
+# Contains code for domain-adaptation style transfer.
+
+package(
+ default_visibility = [
+ ":internal",
+ ],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package_group(
+ name = "internal",
+ packages = [
+ "//domain_adaptation/...",
+ ],
+)
+
+py_library(
+ name = "pixelda_preprocess",
+ srcs = ["pixelda_preprocess.py"],
+ deps = [
+
+ ],
+)
+
+py_test(
+ name = "pixelda_preprocess_test",
+ srcs = ["pixelda_preprocess_test.py"],
+ deps = [
+ ":pixelda_preprocess",
+
+ ],
+)
+
+py_library(
+ name = "pixelda_model",
+ srcs = [
+ "pixelda_model.py",
+ "pixelda_task_towers.py",
+ "hparams.py",
+ ],
+ deps = [
+
+ ],
+)
+
+py_library(
+ name = "pixelda_utils",
+ srcs = ["pixelda_utils.py"],
+ deps = [
+
+ ],
+)
+
+py_library(
+ name = "pixelda_losses",
+ srcs = ["pixelda_losses.py"],
+ deps = [
+
+ ],
+)
+
+py_binary(
+ name = "pixelda_train",
+ srcs = ["pixelda_train.py"],
+ deps = [
+ ":pixelda_losses",
+ ":pixelda_model",
+ ":pixelda_preprocess",
+ ":pixelda_utils",
+
+ "//domain_adaptation/datasets:dataset_factory",
+ ],
+)
+
+py_binary(
+ name = "pixelda_eval",
+ srcs = ["pixelda_eval.py"],
+ deps = [
+ ":pixelda_losses",
+ ":pixelda_model",
+ ":pixelda_preprocess",
+ ":pixelda_utils",
+
+ "//domain_adaptation/datasets:dataset_factory",
+ ],
+)
diff --git a/models/research/domain_adaptation/pixel_domain_adaptation/README.md b/models/research/domain_adaptation/pixel_domain_adaptation/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/domain_adaptation/pixel_domain_adaptation/baselines/BUILD b/models/research/domain_adaptation/pixel_domain_adaptation/baselines/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..c41a4ffeee80114145c4c3fc32a2191879b1b08a
--- /dev/null
+++ b/models/research/domain_adaptation/pixel_domain_adaptation/baselines/BUILD
@@ -0,0 +1,23 @@
+licenses(["notice"]) # Apache 2.0
+
+py_binary(
+ name = "baseline_train",
+ srcs = ["baseline_train.py"],
+ deps = [
+
+ "//domain_adaptation/datasets:dataset_factory",
+ "//domain_adaptation/pixel_domain_adaptation:pixelda_model",
+ "//domain_adaptation/pixel_domain_adaptation:pixelda_preprocess",
+ ],
+)
+
+py_binary(
+ name = "baseline_eval",
+ srcs = ["baseline_eval.py"],
+ deps = [
+
+ "//domain_adaptation/datasets:dataset_factory",
+ "//domain_adaptation/pixel_domain_adaptation:pixelda_model",
+ "//domain_adaptation/pixel_domain_adaptation:pixelda_preprocess",
+ ],
+)
diff --git a/models/research/domain_adaptation/pixel_domain_adaptation/baselines/README.md b/models/research/domain_adaptation/pixel_domain_adaptation/baselines/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d61195ad2de6867801143aeda906cb5efe30a5e3
--- /dev/null
+++ b/models/research/domain_adaptation/pixel_domain_adaptation/baselines/README.md
@@ -0,0 +1,60 @@
+The best baselines are obtained with the following configurations:
+
+
+## MNIST => MNIST_M
+
+Accuracy:
+MNIST-Train: 99.9
+MNIST_M-Train: 63.9
+MNIST_M-Valid: 63.9
+MNIST_M-Test: 63.6
+
+Learning Rate = 0.0001
+Weight Decay = 0.0
+Number of Steps: 105,000
+
+## MNIST => USPS
+
+Accuracy:
+MNIST-Train: 100.0
+USPS-Train: 82.8
+USPS-Valid: 82.8
+USPS-Test: 78.9
+
+Learning Rate = 0.0001
+Weight Decay = 0.0
+Number of Steps: 22,000
+
+## MNIST_M => MNIST
+
+Accuracy:
+MNIST_M-Train: 100
+MNIST-Train: 98.5
+MNIST-Valid: 98.5
+MNIST-Test: 98.1
+
+Learning Rate = 0.001
+Weight Decay = 0.0
+Number of Steps: 604,400
+
+## MNIST_M => MNIST_M
+
+Accuracy:
+MNIST_M-Train: 100.0
+MNIST_M-Valid: 96.6
+MNIST_M-Test: 96.4
+
+Learning Rate = 0.001
+Weight Decay = 0.0
+Number of Steps: 139,400
+
+## USPS => USPS
+
+Accuracy:
+USPS-Train: 100.0
+USPS-Valid: 100.0
+USPS-Test: 96.5
+
+Learning Rate = 0.001
+Weight Decay = 0.0
+Number of Steps: 67,000
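+
+## Example invocations
+
+A sketch of how the binaries above can be run for the MNIST => MNIST_M
+baseline (dataset and log directories are placeholders):
+
+    python baseline_train.py \
+      --dataset_name=mnist \
+      --split_name=train \
+      --dataset_dir=/tmp/datasets \
+      --learning_rate=0.0001 \
+      --weight_decay=0.0 \
+      --logdir=/tmp/baselines/mnist
+
+and, once checkpoints exist, evaluation on the target domain:
+
+    python baseline_eval.py \
+      --dataset_name=mnist_m \
+      --split_name=test \
+      --dataset_dir=/tmp/datasets \
+      --checkpoint_dir=/tmp/baselines/mnist \
+      --eval_dir=/tmp/baselines/mnist/eval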
diff --git a/models/research/domain_adaptation/pixel_domain_adaptation/baselines/baseline_eval.py b/models/research/domain_adaptation/pixel_domain_adaptation/baselines/baseline_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b7ef6452b4897b00dc8c977bf40526ad5052ede
--- /dev/null
+++ b/models/research/domain_adaptation/pixel_domain_adaptation/baselines/baseline_eval.py
@@ -0,0 +1,141 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Evaluates the classification/pose baselines."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from functools import partial
+
+import math
+
+# Dependency imports
+
+import tensorflow as tf
+
+from domain_adaptation.datasets import dataset_factory
+from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
+from domain_adaptation.pixel_domain_adaptation import pixelda_task_towers
+
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+slim = tf.contrib.slim
+
+flags.DEFINE_string('master', '', 'BNS name of the TensorFlow server')
+
+flags.DEFINE_string(
+ 'checkpoint_dir', None, 'The location of the checkpoint files.')
+
+flags.DEFINE_string(
+ 'eval_dir', None, 'The directory where evaluation logs are written.')
+
+flags.DEFINE_integer('batch_size', 32, 'The number of samples per batch.')
+
+flags.DEFINE_string('dataset_name', None, 'The name of the dataset.')
+
+flags.DEFINE_string('dataset_dir', None,
+ 'The directory where the data is stored.')
+
+flags.DEFINE_string('split_name', None, 'The name of the train/test split.')
+
+flags.DEFINE_integer('eval_interval_secs', 60 * 5,
+ 'How often (in seconds) to run evaluation.')
+
+flags.DEFINE_integer(
+ 'num_readers', 4,
+ 'The number of parallel readers that read data from the dataset.')
+
+def main(unused_argv):
+ tf.logging.set_verbosity(tf.logging.INFO)
+ hparams = tf.contrib.training.HParams()
+ hparams.weight_decay_task_classifier = 0.0
+
+ if FLAGS.dataset_name in ['mnist', 'mnist_m', 'usps']:
+ hparams.task_tower = 'mnist'
+ else:
+ raise ValueError('Unknown dataset %s' % FLAGS.dataset_name)
+
+ if not tf.gfile.Exists(FLAGS.eval_dir):
+ tf.gfile.MakeDirs(FLAGS.eval_dir)
+
+ with tf.Graph().as_default():
+ dataset = dataset_factory.get_dataset(FLAGS.dataset_name, FLAGS.split_name,
+ FLAGS.dataset_dir)
+ num_classes = dataset.num_classes
+ num_samples = dataset.num_samples
+
+ preprocess_fn = partial(pixelda_preprocess.preprocess_classification,
+ is_training=False)
+
+ images, labels = dataset_factory.provide_batch(
+ FLAGS.dataset_name,
+ FLAGS.split_name,
+ dataset_dir=FLAGS.dataset_dir,
+ num_readers=FLAGS.num_readers,
+ batch_size=FLAGS.batch_size,
+ num_preprocessing_threads=FLAGS.num_readers)
+
+ # Define the model
+ logits, _ = pixelda_task_towers.add_task_specific_model(
+ images, hparams, num_classes=num_classes, is_training=True)
+
+ #####################
+ # Define the losses #
+ #####################
+ if 'classes' in labels:
+ one_hot_labels = labels['classes']
+ loss = tf.losses.softmax_cross_entropy(
+ onehot_labels=one_hot_labels, logits=logits)
+ tf.summary.scalar('losses/Classification_Loss', loss)
+ else:
+ raise ValueError('Only support classification for now.')
+
+ total_loss = tf.losses.get_total_loss()
+
+ predictions = tf.reshape(tf.argmax(logits, 1), shape=[-1])
+ class_labels = tf.argmax(labels['classes'], 1)
+
+ metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({
+ 'Mean_Loss':
+ tf.contrib.metrics.streaming_mean(total_loss),
+ 'Accuracy':
+ tf.contrib.metrics.streaming_accuracy(predictions,
+ tf.reshape(
+ class_labels,
+ shape=[-1])),
+ 'Recall_at_5':
+ tf.contrib.metrics.streaming_recall_at_k(logits, class_labels, 5),
+ })
+
+ tf.summary.histogram('outputs/Predictions', predictions)
+ tf.summary.histogram('outputs/Ground_Truth', class_labels)
+
+    for name, value in metrics_to_values.items():
+ tf.summary.scalar(name, value)
+
+ num_batches = int(math.ceil(num_samples / float(FLAGS.batch_size)))
+
+ slim.evaluation.evaluation_loop(
+ master=FLAGS.master,
+ checkpoint_dir=FLAGS.checkpoint_dir,
+ logdir=FLAGS.eval_dir,
+ num_evals=num_batches,
+ eval_op=metrics_to_updates.values(),
+ eval_interval_secs=FLAGS.eval_interval_secs)
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/domain_adaptation/pixel_domain_adaptation/baselines/baseline_train.py b/models/research/domain_adaptation/pixel_domain_adaptation/baselines/baseline_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c92bd81a7b68879000dd793ba2fd013f395f408
--- /dev/null
+++ b/models/research/domain_adaptation/pixel_domain_adaptation/baselines/baseline_train.py
@@ -0,0 +1,161 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Trains the classification/pose baselines."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from functools import partial
+
+# Dependency imports
+
+import tensorflow as tf
+
+from domain_adaptation.datasets import dataset_factory
+from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
+from domain_adaptation.pixel_domain_adaptation import pixelda_task_towers
+
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+slim = tf.contrib.slim
+
+flags.DEFINE_string('master', '', 'BNS name of the TensorFlow server')
+
+flags.DEFINE_integer('task', 0, 'The task ID.')
+
+flags.DEFINE_integer('num_ps_tasks', 0,
+ 'The number of parameter servers. If the value is 0, then '
+ 'the parameters are handled locally by the worker.')
+
+flags.DEFINE_integer('batch_size', 32, 'The number of samples per batch.')
+
+flags.DEFINE_string('dataset_name', None, 'The name of the dataset.')
+
+flags.DEFINE_string('dataset_dir', None,
+ 'The directory where the data is stored.')
+
+flags.DEFINE_string('split_name', None, 'The name of the train/test split.')
+
+flags.DEFINE_float('learning_rate', 0.001, 'The initial learning rate.')
+
+flags.DEFINE_integer(
+ 'learning_rate_decay_steps', 20000,
+ 'The frequency, in steps, at which the learning rate is decayed.')
+
+flags.DEFINE_float('learning_rate_decay_factor',
+ 0.95,
+ 'The factor with which the learning rate is decayed.')
+
+flags.DEFINE_float('adam_beta1', 0.5, 'The beta1 value for the AdamOptimizer')
+
+flags.DEFINE_float('weight_decay', 1e-5,
+ 'The L2 coefficient on the model weights.')
+
+flags.DEFINE_string(
+ 'logdir', None, 'The location of the logs and checkpoints.')
+
+flags.DEFINE_integer('save_interval_secs', 600,
+ 'How often, in seconds, we save the model to disk.')
+
+flags.DEFINE_integer('save_summaries_secs', 600,
+ 'How often, in seconds, we compute the summaries.')
+
+flags.DEFINE_integer(
+ 'num_readers', 4,
+ 'The number of parallel readers that read data from the dataset.')
+
+flags.DEFINE_float(
+ 'moving_average_decay', 0.9999,
+ 'The amount of decay to use for moving averages.')
+
+
+def main(unused_argv):
+ tf.logging.set_verbosity(tf.logging.INFO)
+ hparams = tf.contrib.training.HParams()
+ hparams.weight_decay_task_classifier = FLAGS.weight_decay
+
+ if FLAGS.dataset_name in ['mnist', 'mnist_m', 'usps']:
+ hparams.task_tower = 'mnist'
+ else:
+ raise ValueError('Unknown dataset %s' % FLAGS.dataset_name)
+
+ with tf.Graph().as_default():
+ with tf.device(
+ tf.train.replica_device_setter(FLAGS.num_ps_tasks, merge_devices=True)):
+ dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
+ FLAGS.split_name, FLAGS.dataset_dir)
+ num_classes = dataset.num_classes
+
+ preprocess_fn = partial(pixelda_preprocess.preprocess_classification,
+ is_training=True)
+
+ images, labels = dataset_factory.provide_batch(
+ FLAGS.dataset_name,
+ FLAGS.split_name,
+ dataset_dir=FLAGS.dataset_dir,
+ num_readers=FLAGS.num_readers,
+ batch_size=FLAGS.batch_size,
+ num_preprocessing_threads=FLAGS.num_readers)
+ # preprocess_fn=preprocess_fn)
+
+ # Define the model
+ logits, _ = pixelda_task_towers.add_task_specific_model(
+ images, hparams, num_classes=num_classes, is_training=True)
+
+ # Define the losses
+ if 'classes' in labels:
+ one_hot_labels = labels['classes']
+ loss = tf.losses.softmax_cross_entropy(
+ onehot_labels=one_hot_labels, logits=logits)
+ tf.summary.scalar('losses/Classification_Loss', loss)
+ else:
+ raise ValueError('Only support classification for now.')
+
+ total_loss = tf.losses.get_total_loss()
+ tf.summary.scalar('losses/Total_Loss', total_loss)
+
+ # Setup the moving averages
+ moving_average_variables = slim.get_model_variables()
+ variable_averages = tf.train.ExponentialMovingAverage(
+ FLAGS.moving_average_decay, slim.get_or_create_global_step())
+ tf.add_to_collection(
+ tf.GraphKeys.UPDATE_OPS,
+ variable_averages.apply(moving_average_variables))
+
+ # Specify the optimization scheme:
+ learning_rate = tf.train.exponential_decay(
+ FLAGS.learning_rate,
+ slim.get_or_create_global_step(),
+ FLAGS.learning_rate_decay_steps,
+ FLAGS.learning_rate_decay_factor,
+ staircase=True)
+
+ optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.adam_beta1)
+
+ train_op = slim.learning.create_train_op(total_loss, optimizer)
+
+ slim.learning.train(
+ train_op,
+ FLAGS.logdir,
+ master=FLAGS.master,
+ is_chief=(FLAGS.task == 0),
+ save_summaries_secs=FLAGS.save_summaries_secs,
+ save_interval_secs=FLAGS.save_interval_secs)
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/domain_adaptation/pixel_domain_adaptation/hparams.py b/models/research/domain_adaptation/pixel_domain_adaptation/hparams.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba9539f7d435c86f9fc92ed3406835bdaf2b50f3
--- /dev/null
+++ b/models/research/domain_adaptation/pixel_domain_adaptation/hparams.py
@@ -0,0 +1,201 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Define model HParams."""
+import tensorflow as tf
+
+
+def create_hparams(hparam_string=None):
+ """Create model hyperparameters. Parse nondefault from given string."""
+ hparams = tf.contrib.training.HParams(
+ # The name of the architecture to use.
+ arch='resnet',
+ lrelu_leakiness=0.2,
+ batch_norm_decay=0.9,
+ weight_decay=1e-5,
+ normal_init_std=0.02,
+ generator_kernel_size=3,
+ discriminator_kernel_size=3,
+
+      # Stop training after this many examples are processed.
+      # If 0, train indefinitely.
+ num_training_examples=0,
+
+ # Apply data augmentation to datasets
+ # Applies only in training job
+ augment_source_images=False,
+ augment_target_images=False,
+
+ # Discriminator
+ # Number of filters in first layer of discriminator
+ num_discriminator_filters=64,
+ discriminator_conv_block_size=1, # How many convs to have at each size
+ discriminator_filter_factor=2.0, # Multiply # filters by this each layer
+ # Add gaussian noise with this stddev to every hidden layer of D
+ discriminator_noise_stddev=0.2, # lmetz: Start seeing results at >= 0.1
+ # If true, add this gaussian noise to input images to D as well
+ discriminator_image_noise=False,
+ discriminator_first_stride=1, # Stride in first conv of discriminator
+ discriminator_do_pooling=False, # If true, replace stride 2 with avg pool
+ discriminator_dropout_keep_prob=0.9, # keep probability for dropout
+
+ # DCGAN Generator
+ # Number of filters in generator decoder last layer (repeatedly halved
+ # from 1st layer)
+ num_decoder_filters=64,
+ # Number of filters in generator encoder 1st layer (repeatedly doubled
+ # after 1st layer)
+ num_encoder_filters=64,
+
+ # This is the shape to which the noise vector is projected (if we're
+ # transferring from noise).
+ # Write this way instead of [4, 4, 64] for hparam search flexibility
+ projection_shape_size=4,
+ projection_shape_channels=64,
+
+ # Indicates the method by which we enlarge the spatial representation
+ # of an image. Possible values include:
+ # - resize_conv: Performs a nearest neighbor resize followed by a conv.
+ # - conv2d_transpose: Performs a conv2d_transpose.
+ upsample_method='resize_conv',
+
+ # Visualization
+ summary_steps=500, # Output image summary every N steps
+
+ ###################################
+ # Task Classifier Hyperparameters #
+ ###################################
+
+ # Which task-specific prediction tower to use. Possible choices are:
+ # none: No task tower.
+ # doubling_pose_estimator: classifier + quaternion regressor.
+ # [conv + pool]* + FC
+ # Classifiers used in DSN paper:
+ # gtsrb: Classifier used for GTSRB
+ # svhn: Classifier used for SVHN
+ # mnist: Classifier used for MNIST
+ # pose_mini: Classifier + regressor used for pose_mini
+ task_tower='doubling_pose_estimator',
+ weight_decay_task_classifier=1e-5,
+ source_task_loss_weight=1.0,
+ transferred_task_loss_weight=1.0,
+
+ # Number of private layers in doubling_pose_estimator task tower
+ num_private_layers=2,
+
+ # The weight for the log quaternion loss we use for source and transferred
+ # samples of the cropped_linemod dataset.
+ # In the DSN work, 1/8 of the classifier weight worked well for our log
+ # quaternion loss
+ source_pose_weight=0.125 * 2.0,
+ transferred_pose_weight=0.125 * 1.0,
+
+ # If set to True, the style transfer network also attempts to change its
+ # weights to maximize the performance of the task tower. If set to False,
+ # then the style transfer network only attempts to change its weights to
+ # make the transferred images more likely according to the domain
+ # classifier.
+ task_tower_in_g_step=True,
+ task_loss_in_g_weight=1.0, # Weight of task loss in G
+
+ #########################################
+      # 'simple' generator arch model hparams #
+ #########################################
+ simple_num_conv_layers=1,
+ simple_conv_filters=8,
+
+ #########################
+ # Resnet Hyperparameters#
+ #########################
+ resnet_blocks=6, # Number of resnet blocks
+ resnet_filters=64, # Number of filters per conv in resnet blocks
+ # If true, add original input back to result of convolutions inside the
+ # resnet arch. If false, it turns into a simple stack of conv/relu/BN
+ # layers.
+ resnet_residuals=True,
+
+ #######################################
+ # The residual / interpretable model. #
+ #######################################
+ res_int_blocks=2, # The number of residual blocks.
+ res_int_convs=2, # The number of conv calls inside each block.
+ res_int_filters=64, # The number of filters used by each convolution.
+
+ ####################
+ # Latent variables #
+ ####################
+ # if true, then generate random noise and project to input for generator
+ noise_channel=True,
+ # The number of dimensions in the input noise vector.
+ noise_dims=10,
+
+ # If true, then one hot encode source image class and project as an
+ # additional channel for the input to generator. This gives the generator
+ # access to the class, which may help generation performance.
+ condition_on_source_class=False,
+
+ ########################
+ # Loss Hyperparameters #
+ ########################
+ domain_loss_weight=1.0,
+ style_transfer_loss_weight=1.0,
+
+ ########################################################################
+ # Encourages the transferred images to be similar to the source images #
+ # using a configurable metric. #
+ ########################################################################
+
+ # The weight of the loss function encouraging the source and transferred
+ # images to be similar. If set to 0, then the loss function is not used.
+ transferred_similarity_loss_weight=0.0,
+
+ # The type of loss used to encourage transferred and source image
+ # similarity. Valid values include:
+ # mpse: Mean Pairwise Squared Error
+ # mse: Mean Squared Error
+ # hinged_mse: Computes the mean squared error using squared differences
+ # greater than hparams.transferred_similarity_max_diff
+ # hinged_mae: Computes the mean absolute error using absolute
+ # differences greater than hparams.transferred_similarity_max_diff.
+ transferred_similarity_loss='mpse',
+
+ # The maximum allowable difference between the source and target images.
+ # This value is used, in effect, to produce a hinge loss. Note that the
+ # range of values should be between 0 and 1.
+ transferred_similarity_max_diff=0.4,
+
+ ################################
+ # Optimization Hyperparameters #
+ ################################
+ learning_rate=0.001,
+ batch_size=32,
+ lr_decay_steps=20000,
+ lr_decay_rate=0.95,
+
+      # Recommendation from the DCGAN paper:
+ adam_beta1=0.5,
+ clip_gradient_norm=5.0,
+
+ # The number of times we run the discriminator train_op in a row.
+ discriminator_steps=1,
+
+ # The number of times we run the generator train_op in a row.
+ generator_steps=1)
+
+ if hparam_string:
+ tf.logging.info('Parsing command line hparams: %s', hparam_string)
+ hparams.parse(hparam_string)
+
+ tf.logging.info('Final parsed hparams: %s', hparams.values())
+ return hparams
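+
+
+# A minimal usage sketch (illustrative only): non-default values are passed as
+# a comma-separated string, which is how the --hparams flag in pixelda_eval
+# (and the corresponding training binary) feeds overrides into this function.
+#
+#   hparams = create_hparams('arch=resnet,batch_size=64,noise_dims=32')
+#   assert hparams.arch == 'resnet'
+#   assert hparams.batch_size == 64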
diff --git a/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_eval.py b/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..23824249a9e95586ed85e40cd89c5f6814977969
--- /dev/null
+++ b/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_eval.py
@@ -0,0 +1,298 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Evaluates the PIXELDA model.
+
+-- Compile the model for CPU.
+$ bazel build -c opt third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation:pixelda_eval
+
+-- Compile the model for GPU.
+$ bazel build -c opt --copt=-mavx --config=cuda \
+ third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation:pixelda_eval
+
+-- Run the evaluation.
+$ ./bazel-bin/third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation/pixelda_eval \
+ --source_dataset=mnist \
+ --target_dataset=mnist_m \
+ --dataset_dir=/tmp/datasets/ \
+ --alsologtostderr
+
+-- Visualize the results.
+$ bash learning/brain/tensorboard/tensorboard.sh \
+ --port 2222 --logdir=/tmp/pixelda/
+"""
+from functools import partial
+import math
+
+# Dependency imports
+
+import tensorflow as tf
+
+from domain_adaptation.datasets import dataset_factory
+from domain_adaptation.pixel_domain_adaptation import pixelda_model
+from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
+from domain_adaptation.pixel_domain_adaptation import pixelda_utils
+from domain_adaptation.pixel_domain_adaptation import pixelda_losses
+from domain_adaptation.pixel_domain_adaptation.hparams import create_hparams
+
+slim = tf.contrib.slim
+
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('master', '', 'BNS name of the TensorFlow master to use.')
+
+flags.DEFINE_string('checkpoint_dir', '/tmp/pixelda/',
+ 'Directory where the model was written to.')
+
+flags.DEFINE_string('eval_dir', '/tmp/pixelda/',
+ 'Directory where the results are saved to.')
+
+flags.DEFINE_integer('eval_interval_secs', 60,
+ 'The frequency, in seconds, with which evaluation is run.')
+
+flags.DEFINE_string('target_split_name', 'test',
+ 'The name of the train/test split.')
+flags.DEFINE_string('source_split_name', 'train', 'Split for source dataset.'
+ ' Defaults to train.')
+
+flags.DEFINE_string('source_dataset', 'mnist',
+ 'The name of the source dataset.')
+
+flags.DEFINE_string('target_dataset', 'mnist_m',
+ 'The name of the target dataset.')
+
+flags.DEFINE_string(
+ 'dataset_dir',
+ '', # None,
+ 'The directory where the datasets can be found.')
+
+flags.DEFINE_integer(
+ 'num_readers', 4,
+ 'The number of parallel readers that read data from the dataset.')
+
+flags.DEFINE_integer('num_preprocessing_threads', 4,
+ 'The number of threads used to create the batches.')
+
+# HParams
+
+flags.DEFINE_string('hparams', '', 'Comma separated hyperparameter values')
+
+
+def run_eval(run_dir, checkpoint_dir, hparams):
+ """Runs the eval loop.
+
+ Args:
+ run_dir: The directory where eval specific logs are placed
+ checkpoint_dir: The directory where the checkpoints are stored
+ hparams: The hyperparameters struct.
+
+ Raises:
+ ValueError: if hparams.arch is not recognized.
+ """
+ for checkpoint_path in slim.evaluation.checkpoints_iterator(
+ checkpoint_dir, FLAGS.eval_interval_secs):
+ with tf.Graph().as_default():
+ #########################
+ # Preprocess the inputs #
+ #########################
+ target_dataset = dataset_factory.get_dataset(
+ FLAGS.target_dataset,
+ split_name=FLAGS.target_split_name,
+ dataset_dir=FLAGS.dataset_dir)
+ target_images, target_labels = dataset_factory.provide_batch(
+ FLAGS.target_dataset, FLAGS.target_split_name, FLAGS.dataset_dir,
+ FLAGS.num_readers, hparams.batch_size,
+ FLAGS.num_preprocessing_threads)
+ num_target_classes = target_dataset.num_classes
+ target_labels['class'] = tf.argmax(target_labels['classes'], 1)
+ del target_labels['classes']
+
+ if hparams.arch not in ['dcgan']:
+ source_dataset = dataset_factory.get_dataset(
+ FLAGS.source_dataset,
+ split_name=FLAGS.source_split_name,
+ dataset_dir=FLAGS.dataset_dir)
+ num_source_classes = source_dataset.num_classes
+ source_images, source_labels = dataset_factory.provide_batch(
+ FLAGS.source_dataset, FLAGS.source_split_name, FLAGS.dataset_dir,
+ FLAGS.num_readers, hparams.batch_size,
+ FLAGS.num_preprocessing_threads)
+ source_labels['class'] = tf.argmax(source_labels['classes'], 1)
+ del source_labels['classes']
+ if num_source_classes != num_target_classes:
+ raise ValueError(
+              'Input and output datasets must have the same number of classes')
+ else:
+ source_images = None
+ source_labels = None
+
+ ####################
+ # Define the model #
+ ####################
+ end_points = pixelda_model.create_model(
+ hparams,
+ target_images,
+ source_images=source_images,
+ source_labels=source_labels,
+ is_training=False,
+ num_classes=num_target_classes)
+
+ #######################
+ # Metrics & Summaries #
+ #######################
+ names_to_values, names_to_updates = create_metrics(end_points,
+ source_labels,
+ target_labels, hparams)
+ pixelda_utils.summarize_model(end_points)
+ pixelda_utils.summarize_transferred_grid(
+ end_points['transferred_images'], source_images, name='Transferred')
+ if 'source_images_recon' in end_points:
+ pixelda_utils.summarize_transferred_grid(
+ end_points['source_images_recon'],
+ source_images,
+ name='Source Reconstruction')
+ pixelda_utils.summarize_images(target_images, 'Target')
+
+      for name, value in names_to_values.items():
+ tf.summary.scalar(name, value)
+
+ # Use the entire split by default
+ num_examples = target_dataset.num_samples
+
+ num_batches = math.ceil(num_examples / float(hparams.batch_size))
+ global_step = slim.get_or_create_global_step()
+
+ result = slim.evaluation.evaluate_once(
+ master=FLAGS.master,
+ checkpoint_path=checkpoint_path,
+ logdir=run_dir,
+ num_evals=num_batches,
+ eval_op=names_to_updates.values(),
+ final_op=names_to_values)
+
+
+def to_degrees(log_quaternion_loss):
+ """Converts a log quaternion distance to an angle.
+
+ Args:
+ log_quaternion_loss: The log quaternion distance between two
+ unit quaternions (or a batch of pairs of quaternions).
+
+ Returns:
+ The angle in degrees of the implied angle-axis representation.
+ """
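+  # Note (added for clarity): log_quaternion_loss here is
+  # log(1e-4 + 1 - |<p, l>|) (see pixelda_losses.log_quaternion_loss_batch),
+  # so -(exp(loss) - 1) = |<p, l>| - 1e-4, and acos(...) * 2 recovers, up to
+  # the 1e-4 epsilon, the rotation angle 2 * acos(|<p, l>|) between the two
+  # unit quaternions; the factor 180 / pi converts it to degrees.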
+ return tf.acos(-(tf.exp(log_quaternion_loss) - 1)) * 2 * 180 / math.pi
+
+
+def create_metrics(end_points, source_labels, target_labels, hparams):
+ """Create metrics for the model.
+
+ Args:
+ end_points: A dictionary of end point name to tensor
+ source_labels: Labels for source images. batch_size x 1
+ target_labels: Labels for target images. batch_size x 1
+ hparams: The hyperparameters struct.
+
+ Returns:
+ Tuple of (names_to_values, names_to_updates), dictionaries that map a metric
+ name to its value and update op, respectively
+
+ """
+ ###########################################
+ # Evaluate the Domain Prediction Accuracy #
+ ###########################################
+ batch_size = hparams.batch_size
+ names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
+ ('eval/Domain_Accuracy-Transferred'):
+ tf.contrib.metrics.streaming_accuracy(
+ tf.to_int32(
+ tf.round(tf.sigmoid(end_points[
+ 'transferred_domain_logits']))),
+ tf.zeros(batch_size, dtype=tf.int32)),
+ ('eval/Domain_Accuracy-Target'):
+ tf.contrib.metrics.streaming_accuracy(
+ tf.to_int32(
+ tf.round(tf.sigmoid(end_points['target_domain_logits']))),
+ tf.ones(batch_size, dtype=tf.int32))
+ })
+
+ ################################
+ # Evaluate the task classifier #
+ ################################
+ if 'source_task_logits' in end_points:
+ metric_name = 'eval/Task_Accuracy-Source'
+ names_to_values[metric_name], names_to_updates[
+ metric_name] = tf.contrib.metrics.streaming_accuracy(
+ tf.argmax(end_points['source_task_logits'], 1),
+ source_labels['class'])
+
+ if 'transferred_task_logits' in end_points:
+ metric_name = 'eval/Task_Accuracy-Transferred'
+ names_to_values[metric_name], names_to_updates[
+ metric_name] = tf.contrib.metrics.streaming_accuracy(
+ tf.argmax(end_points['transferred_task_logits'], 1),
+ source_labels['class'])
+
+ if 'target_task_logits' in end_points:
+ metric_name = 'eval/Task_Accuracy-Target'
+ names_to_values[metric_name], names_to_updates[
+ metric_name] = tf.contrib.metrics.streaming_accuracy(
+ tf.argmax(end_points['target_task_logits'], 1),
+ target_labels['class'])
+
+ ##########################################################################
+ # Pose data-specific losses.
+ ##########################################################################
+  if 'quaternion' in source_labels:
+ params = {}
+ params['use_logging'] = False
+ params['batch_size'] = batch_size
+
+ angle_loss_source = to_degrees(
+ pixelda_losses.log_quaternion_loss_batch(end_points[
+ 'source_quaternion'], source_labels['quaternion'], params))
+ angle_loss_transferred = to_degrees(
+ pixelda_losses.log_quaternion_loss_batch(end_points[
+ 'transferred_quaternion'], source_labels['quaternion'], params))
+ angle_loss_target = to_degrees(
+ pixelda_losses.log_quaternion_loss_batch(end_points[
+ 'target_quaternion'], target_labels['quaternion'], params))
+
+ metric_name = 'eval/Angle_Loss-Source'
+ names_to_values[metric_name], names_to_updates[
+ metric_name] = slim.metrics.mean(angle_loss_source)
+
+ metric_name = 'eval/Angle_Loss-Transferred'
+ names_to_values[metric_name], names_to_updates[
+ metric_name] = slim.metrics.mean(angle_loss_transferred)
+
+ metric_name = 'eval/Angle_Loss-Target'
+ names_to_values[metric_name], names_to_updates[
+ metric_name] = slim.metrics.mean(angle_loss_target)
+
+ return names_to_values, names_to_updates
+
+
+def main(_):
+ tf.logging.set_verbosity(tf.logging.INFO)
+ hparams = create_hparams(FLAGS.hparams)
+ run_eval(
+ run_dir=FLAGS.eval_dir,
+ checkpoint_dir=FLAGS.checkpoint_dir,
+ hparams=hparams)
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_losses.py b/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf39765d4d28c5a04cb8868cdc465cdd0129b0df
--- /dev/null
+++ b/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_losses.py
@@ -0,0 +1,385 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines the various loss functions in use by the PIXELDA model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+
+import tensorflow as tf
+
+slim = tf.contrib.slim
+
+
+def add_domain_classifier_losses(end_points, hparams):
+ """Adds losses related to the domain-classifier.
+
+ Args:
+ end_points: A map of network end point names to `Tensors`.
+ hparams: The hyperparameters struct.
+
+ Returns:
+    loss: A `Tensor` representing the total domain-classifier loss.
+ """
+ if hparams.domain_loss_weight == 0:
+ tf.logging.info(
+ 'Domain classifier loss weight is 0, so not creating losses.')
+ return 0
+
+ # The domain prediction loss is minimized with respect to the domain
+ # classifier features only. Its aim is to predict the domain of the images.
+ # Note: 1 = 'real image' label, 0 = 'fake image' label
+ transferred_domain_loss = tf.losses.sigmoid_cross_entropy(
+ multi_class_labels=tf.zeros_like(end_points['transferred_domain_logits']),
+ logits=end_points['transferred_domain_logits'])
+ tf.summary.scalar('Domain_loss_transferred', transferred_domain_loss)
+
+ target_domain_loss = tf.losses.sigmoid_cross_entropy(
+ multi_class_labels=tf.ones_like(end_points['target_domain_logits']),
+ logits=end_points['target_domain_logits'])
+ tf.summary.scalar('Domain_loss_target', target_domain_loss)
+
+ # Compute the total domain loss:
+ total_domain_loss = transferred_domain_loss + target_domain_loss
+ total_domain_loss *= hparams.domain_loss_weight
+ tf.summary.scalar('Domain_loss_total', total_domain_loss)
+
+ return total_domain_loss
+
+def log_quaternion_loss_batch(predictions, labels, params):
+ """A helper function to compute the error between quaternions.
+
+ Args:
+ predictions: A Tensor of size [batch_size, 4].
+ labels: A Tensor of size [batch_size, 4].
+ params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.
+
+ Returns:
+ A Tensor of size [batch_size], denoting the error between the quaternions.
+ """
+ use_logging = params['use_logging']
+ assertions = []
+ if use_logging:
+ assertions.append(
+ tf.Assert(
+ tf.reduce_all(
+ tf.less(
+ tf.abs(tf.reduce_sum(tf.square(predictions), [1]) - 1),
+ 1e-4)),
+ ['The l2 norm of each prediction quaternion vector should be 1.']))
+ assertions.append(
+ tf.Assert(
+ tf.reduce_all(
+ tf.less(
+ tf.abs(tf.reduce_sum(tf.square(labels), [1]) - 1), 1e-4)),
+ ['The l2 norm of each label quaternion vector should be 1.']))
+
+ with tf.control_dependencies(assertions):
+ product = tf.multiply(predictions, labels)
+ internal_dot_products = tf.reduce_sum(product, [1])
+
+ if use_logging:
+ internal_dot_products = tf.Print(internal_dot_products, [
+ internal_dot_products,
+ tf.shape(internal_dot_products)
+ ], 'internal_dot_products:')
+
+ logcost = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
+ return logcost
+
+
+def log_quaternion_loss(predictions, labels, params):
+ """A helper function to compute the mean error between batches of quaternions.
+
+ The caller is expected to add the loss to the graph.
+
+ Args:
+ predictions: A Tensor of size [batch_size, 4].
+ labels: A Tensor of size [batch_size, 4].
+ params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.
+
+ Returns:
+ A Tensor of size 1, denoting the mean error between batches of quaternions.
+ """
+ use_logging = params['use_logging']
+ logcost = log_quaternion_loss_batch(predictions, labels, params)
+ logcost = tf.reduce_sum(logcost, [0])
+ batch_size = params['batch_size']
+ logcost = tf.multiply(logcost, 1.0 / batch_size, name='log_quaternion_loss')
+ if use_logging:
+ logcost = tf.Print(
+ logcost, [logcost], '[logcost]', name='log_quaternion_loss_print')
+ return logcost
+
+def _quaternion_loss(labels, predictions, weight, batch_size, domain,
+ add_summaries):
+ """Creates a Quaternion Loss.
+
+ Args:
+ labels: The true quaternions.
+ predictions: The predicted quaternions.
+ weight: A scalar weight.
+ batch_size: The size of the batches.
+ domain: The name of the domain from which the labels were taken.
+ add_summaries: Whether or not to add summaries for the losses.
+
+ Returns:
+ A `Tensor` representing the loss.
+ """
+ assert domain in ['Source', 'Transferred']
+
+ params = {'use_logging': False, 'batch_size': batch_size}
+ loss = weight * log_quaternion_loss(labels, predictions, params)
+
+ if add_summaries:
+ assert_op = tf.Assert(tf.is_finite(loss), [loss])
+ with tf.control_dependencies([assert_op]):
+ tf.summary.histogram(
+ 'Log_Quaternion_Loss_%s' % domain, loss, collections='losses')
+ tf.summary.scalar(
+ 'Task_Quaternion_Loss_%s' % domain, loss, collections='losses')
+
+ return loss
+
+
+def _add_task_specific_losses(end_points, source_labels, num_classes, hparams,
+ add_summaries=False):
+ """Adds losses related to the task-classifier.
+
+ Args:
+ end_points: A map of network end point names to `Tensors`.
+ source_labels: A dictionary of output labels to `Tensors`.
+ num_classes: The number of classes used by the classifier.
+ hparams: The hyperparameters struct.
+ add_summaries: Whether or not to add the summaries.
+
+ Returns:
+ loss: A `Tensor` representing the total task-classifier loss.
+ """
+ # TODO(ddohan): Make sure the l2 regularization is added to the loss
+
+ one_hot_labels = slim.one_hot_encoding(source_labels['class'], num_classes)
+ total_loss = 0
+
+ if 'source_task_logits' in end_points:
+ loss = tf.losses.softmax_cross_entropy(
+ onehot_labels=one_hot_labels,
+ logits=end_points['source_task_logits'],
+ weights=hparams.source_task_loss_weight)
+ if add_summaries:
+ tf.summary.scalar('Task_Classifier_Loss_Source', loss)
+ total_loss += loss
+
+ if 'transferred_task_logits' in end_points:
+ loss = tf.losses.softmax_cross_entropy(
+ onehot_labels=one_hot_labels,
+ logits=end_points['transferred_task_logits'],
+ weights=hparams.transferred_task_loss_weight)
+ if add_summaries:
+ tf.summary.scalar('Task_Classifier_Loss_Transferred', loss)
+ total_loss += loss
+
+ #########################
+ # Pose specific losses. #
+ #########################
+ if 'quaternion' in source_labels:
+ total_loss += _quaternion_loss(
+ source_labels['quaternion'],
+ end_points['source_quaternion'],
+ hparams.source_pose_weight,
+ hparams.batch_size,
+ 'Source',
+ add_summaries)
+
+ total_loss += _quaternion_loss(
+ source_labels['quaternion'],
+ end_points['transferred_quaternion'],
+ hparams.transferred_pose_weight,
+ hparams.batch_size,
+ 'Transferred',
+ add_summaries)
+
+ if add_summaries:
+ tf.summary.scalar('Task_Loss_Total', total_loss)
+
+ return total_loss
+
+
+def _transferred_similarity_loss(reconstructions,
+ source_images,
+ weight=1.0,
+ method='mse',
+ max_diff=0.4,
+ name='similarity'):
+ """Computes a loss encouraging similarity between source and transferred.
+
+ Args:
+ reconstructions: A `Tensor` of shape [batch_size, height, width, channels]
+ source_images: A `Tensor` of shape [batch_size, height, width, channels].
+    weight: Multiply the similarity loss by this weight before returning.
+ method: One of:
+ mpse = Mean Pairwise Squared Error
+ mse = Mean Squared Error
+ hinged_mse = Computes the mean squared error using squared differences
+ greater than hparams.transferred_similarity_max_diff
+ hinged_mae = Computes the mean absolute error using absolute
+ differences greater than hparams.transferred_similarity_max_diff.
+ max_diff: Maximum unpenalized difference for hinged losses
+ name: Identifying name to use for creating summaries
+
+ Returns:
+ A `Tensor` representing the transferred similarity loss.
+
+ Raises:
+ ValueError: if `method` is not recognized.
+ """
+ if weight == 0:
+ return 0
+
+ source_channels = source_images.shape.as_list()[-1]
+ reconstruction_channels = reconstructions.shape.as_list()[-1]
+
+ # Convert grayscale source to RGB if target is RGB
+ if source_channels == 1 and reconstruction_channels != 1:
+ source_images = tf.tile(source_images, [1, 1, 1, reconstruction_channels])
+ if reconstruction_channels == 1 and source_channels != 1:
+ reconstructions = tf.tile(reconstructions, [1, 1, 1, source_channels])
+
+ if method == 'mpse':
+ reconstruction_similarity_loss_fn = (
+ tf.contrib.losses.mean_pairwise_squared_error)
+ elif method == 'masked_mpse':
+
+ def masked_mpse(predictions, labels, weight):
+ """Masked mpse assuming we have a depth to create a mask from."""
+ assert labels.shape.as_list()[-1] == 4
+ mask = tf.to_float(tf.less(labels[:, :, :, 3:4], 0.99))
+ mask = tf.tile(mask, [1, 1, 1, 4])
+ predictions *= mask
+ labels *= mask
+      tf.summary.image('masked_pred', predictions)
+      tf.summary.image('masked_label', labels)
+ return tf.contrib.losses.mean_pairwise_squared_error(
+ predictions, labels, weight)
+
+ reconstruction_similarity_loss_fn = masked_mpse
+ elif method == 'mse':
+ reconstruction_similarity_loss_fn = tf.contrib.losses.mean_squared_error
+ elif method == 'hinged_mse':
+
+ def hinged_mse(predictions, labels, weight):
+ diffs = tf.square(predictions - labels)
+ diffs = tf.maximum(0.0, diffs - max_diff)
+ return tf.reduce_mean(diffs) * weight
+
+ reconstruction_similarity_loss_fn = hinged_mse
+ elif method == 'hinged_mae':
+
+ def hinged_mae(predictions, labels, weight):
+ diffs = tf.abs(predictions - labels)
+ diffs = tf.maximum(0.0, diffs - max_diff)
+ return tf.reduce_mean(diffs) * weight
+
+ reconstruction_similarity_loss_fn = hinged_mae
+ else:
+ raise ValueError('Unknown reconstruction loss %s' % method)
+
+ reconstruction_similarity_loss = reconstruction_similarity_loss_fn(
+ reconstructions, source_images, weight)
+
+ name = '%s_Similarity_(%s)' % (name, method)
+ tf.summary.scalar(name, reconstruction_similarity_loss)
+ return reconstruction_similarity_loss
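+
+
+# Illustration (added note, not in the original file): with method='hinged_mse'
+# and max_diff=0.4, a pixel pair whose squared difference is 0.5 contributes
+# 0.5 - 0.4 = 0.1 to the mean, while any squared difference at or below 0.4
+# contributes nothing, so small source/transferred deviations go unpenalized.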
+
+
+def g_step_loss(source_images, source_labels, end_points, hparams, num_classes):
+ """Configures the loss function which runs during the g-step.
+
+ Args:
+ source_images: A `Tensor` of shape [batch_size, height, width, channels].
+ source_labels: A dictionary of `Tensors` of shape [batch_size]. Valid keys
+ are 'class' and 'quaternion'.
+ end_points: A map of the network end points.
+ hparams: The hyperparameters struct.
+ num_classes: Number of classes for classifier loss
+
+ Returns:
+ A `Tensor` representing a loss function.
+
+ Raises:
+ ValueError: if hparams.transferred_similarity_loss_weight is non-zero but
+ hparams.transferred_similarity_loss is invalid.
+ """
+ generator_loss = 0
+
+ ################################################################
+ # Adds a loss which encourages the discriminator probabilities #
+ # to be high (near one).
+ ################################################################
+
+ # As per the GAN paper, maximize the log probs, instead of minimizing
+ # log(1-probs). Since we're minimizing, we'll minimize -log(probs) which is
+ # the same thing.
+ style_transfer_loss = tf.losses.sigmoid_cross_entropy(
+ logits=end_points['transferred_domain_logits'],
+ multi_class_labels=tf.ones_like(end_points['transferred_domain_logits']),
+ weights=hparams.style_transfer_loss_weight)
+ tf.summary.scalar('Style_transfer_loss', style_transfer_loss)
+ generator_loss += style_transfer_loss
+
+ # Optimizes the style transfer network to produce transferred images similar
+ # to the source images.
+ generator_loss += _transferred_similarity_loss(
+ end_points['transferred_images'],
+ source_images,
+ weight=hparams.transferred_similarity_loss_weight,
+ method=hparams.transferred_similarity_loss,
+ name='transferred_similarity')
+
+ # Optimizes the style transfer network to maximize classification accuracy.
+ if source_labels is not None and hparams.task_tower_in_g_step:
+ generator_loss += _add_task_specific_losses(
+ end_points, source_labels, num_classes,
+ hparams) * hparams.task_loss_in_g_weight
+
+ return generator_loss
+
+
+def d_step_loss(end_points, source_labels, num_classes, hparams):
+ """Configures the losses during the D-Step.
+
+ Note that during the D-step, the model optimizes both the domain (binary)
+ classifier and the task classifier.
+
+ Args:
+ end_points: A map of the network end points.
+ source_labels: A dictionary of output labels to `Tensors`.
+ num_classes: The number of classes used by the classifier.
+ hparams: The hyperparameters struct.
+
+ Returns:
+ A `Tensor` representing the value of the D-step loss.
+ """
+ domain_classifier_loss = add_domain_classifier_losses(end_points, hparams)
+
+ task_classifier_loss = 0
+ if source_labels is not None:
+ task_classifier_loss = _add_task_specific_losses(
+ end_points, source_labels, num_classes, hparams, add_summaries=True)
+
+ return domain_classifier_loss + task_classifier_loss
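+
+
+# Sketch of how the two losses above are typically alternated (the actual
+# train loop lives in pixelda_train.py; `d_train_op`, `g_train_op` and `sess`
+# are placeholder names):
+#
+#   for _ in range(hparams.discriminator_steps):
+#     sess.run(d_train_op)  # minimizes d_step_loss
+#   for _ in range(hparams.generator_steps):
+#     sess.run(g_train_op)  # minimizes g_step_loss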
diff --git a/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_model.py b/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..16b550a62d88ec2724c91f9dab9e3b34c736ec4f
--- /dev/null
+++ b/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_model.py
@@ -0,0 +1,713 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Contains the Domain Adaptation via Style Transfer (PixelDA) model components.
+
+A number of details in the implementation make reference to the following
+work:
+
+- "Unsupervised Representation Learning with Deep Convolutional
+  Generative Adversarial Networks"
+ https://arxiv.org/abs/1511.06434
+
+This paper makes several architecture recommendations:
+1. Use strided convs in the discriminator, fractionally-strided convs in the
+   generator.
+2. Use batch norm everywhere.
+3. Remove fully connected layers for deep models.
+4. Use ReLU for all layers in the generator, except tanh on the output.
+5. Use LeakyReLU for everything in the discriminator.
+"""
+import functools
+import math
+
+# Dependency imports
+import numpy as np
+
+import tensorflow as tf
+
+slim = tf.contrib.slim
+
+from domain_adaptation.pixel_domain_adaptation import pixelda_task_towers
+
+
+def create_model(hparams,
+ target_images,
+ source_images=None,
+ source_labels=None,
+ is_training=False,
+ noise=None,
+ num_classes=None):
+ """Create a GAN model.
+
+  Args:
+ hparams: HParam object specifying model params
+ target_images: A `Tensor` of size [batch_size, height, width, channels]. It
+ is assumed that the images are [-1, 1] normalized.
+ source_images: A `Tensor` of size [batch_size, height, width, channels]. It
+ is assumed that the images are [-1, 1] normalized.
+ source_labels: A `Tensor` of size [batch_size] of categorical labels between
+ [0, num_classes]
+ is_training: whether model is currently training
+ noise: If None, model generates its own noise. Otherwise use provided.
+ num_classes: Number of classes for classification
+
+ Returns:
+ end_points dict with model outputs
+
+ Raises:
+ ValueError: unknown hparams.arch setting
+ """
+ if num_classes is None and hparams.arch in ['resnet', 'simple']:
+ raise ValueError('Num classes must be provided to create task classifier')
+
+ if target_images.dtype != tf.float32:
+ raise ValueError('target_images must be tf.float32 and [-1, 1] normalized.')
+ if source_images is not None and source_images.dtype != tf.float32:
+ raise ValueError('source_images must be tf.float32 and [-1, 1] normalized.')
+
+ ###########################
+ # Create latent variables #
+ ###########################
+ latent_vars = dict()
+
+ if hparams.noise_channel:
+ noise_shape = [hparams.batch_size, hparams.noise_dims]
+ if noise is not None:
+ assert noise.shape.as_list() == noise_shape
+ tf.logging.info('Using provided noise')
+ else:
+ tf.logging.info('Using random noise')
+ noise = tf.random_uniform(
+ shape=noise_shape,
+ minval=-1,
+ maxval=1,
+ dtype=tf.float32,
+ name='random_noise')
+ latent_vars['noise'] = noise
+
+ ####################
+ # Create generator #
+ ####################
+
+ with slim.arg_scope(
+ [slim.conv2d, slim.conv2d_transpose, slim.fully_connected],
+ normalizer_params=batch_norm_params(is_training,
+ hparams.batch_norm_decay),
+ weights_initializer=tf.random_normal_initializer(
+ stddev=hparams.normal_init_std),
+ weights_regularizer=tf.contrib.layers.l2_regularizer(
+ hparams.weight_decay)):
+ with slim.arg_scope([slim.conv2d], padding='SAME'):
+ if hparams.arch == 'dcgan':
+ end_points = dcgan(
+ target_images, latent_vars, hparams, scope='generator')
+ elif hparams.arch == 'resnet':
+ end_points = resnet_generator(
+ source_images,
+ target_images.shape.as_list()[1:4],
+ hparams=hparams,
+ latent_vars=latent_vars)
+ elif hparams.arch == 'residual_interpretation':
+ end_points = residual_interpretation_generator(
+ source_images, is_training=is_training, hparams=hparams)
+ elif hparams.arch == 'simple':
+ end_points = simple_generator(
+ source_images,
+ target_images,
+ is_training=is_training,
+ hparams=hparams,
+ latent_vars=latent_vars)
+ elif hparams.arch == 'identity':
+ # Pass through unmodified, besides changing # channels
+ # Used to calculate baseline numbers
+ # Also set `generator_steps=0` for baseline
+ if hparams.generator_steps:
+ raise ValueError('Must set generator_steps=0 for identity arch. Is %s'
+ % hparams.generator_steps)
+ transferred_images = source_images
+ source_channels = source_images.shape.as_list()[-1]
+ target_channels = target_images.shape.as_list()[-1]
+ if source_channels == 1 and target_channels == 3:
+ transferred_images = tf.tile(source_images, [1, 1, 1, 3])
+ if source_channels == 3 and target_channels == 1:
+ transferred_images = tf.image.rgb_to_grayscale(source_images)
+ end_points = {'transferred_images': transferred_images}
+ else:
+ raise ValueError('Unknown architecture: %s' % hparams.arch)
+
+ #####################
+ # Domain Classifier #
+ #####################
+ if hparams.arch in [
+ 'dcgan', 'resnet', 'residual_interpretation', 'simple', 'identity',
+ ]:
+
+ # Add a discriminator for these architectures
+ end_points['transferred_domain_logits'] = predict_domain(
+ end_points['transferred_images'],
+ hparams,
+ is_training=is_training,
+ reuse=False)
+ end_points['target_domain_logits'] = predict_domain(
+ target_images,
+ hparams,
+ is_training=is_training,
+ reuse=True)
+
+ ###################
+ # Task Classifier #
+ ###################
+ if hparams.task_tower != 'none' and hparams.arch in [
+ 'resnet', 'residual_interpretation', 'simple', 'identity',
+ ]:
+ with tf.variable_scope('discriminator'):
+ with tf.variable_scope('task_tower'):
+ end_points['source_task_logits'], end_points[
+ 'source_quaternion'] = pixelda_task_towers.add_task_specific_model(
+ source_images,
+ hparams,
+ num_classes=num_classes,
+ is_training=is_training,
+ reuse_private=False,
+ private_scope='source_task_classifier',
+ reuse_shared=False)
+ end_points['transferred_task_logits'], end_points[
+ 'transferred_quaternion'] = (
+ pixelda_task_towers.add_task_specific_model(
+ end_points['transferred_images'],
+ hparams,
+ num_classes=num_classes,
+ is_training=is_training,
+ reuse_private=False,
+ private_scope='transferred_task_classifier',
+ reuse_shared=True))
+ end_points['target_task_logits'], end_points[
+ 'target_quaternion'] = pixelda_task_towers.add_task_specific_model(
+ target_images,
+ hparams,
+ num_classes=num_classes,
+ is_training=is_training,
+ reuse_private=True,
+ private_scope='transferred_task_classifier',
+ reuse_shared=True)
+ # Remove any endpoints with None values
+  return dict((k, v) for k, v in end_points.items() if v is not None)
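+
+
+# A minimal wiring sketch (mirrors the call in pixelda_eval.py; the tensors and
+# num_classes below are placeholders). Images are expected as float32 in
+# [-1, 1], and hparams is built via hparams.create_hparams().
+#
+#   end_points = create_model(hparams, target_images,
+#                             source_images=source_images,
+#                             source_labels=source_labels,
+#                             is_training=True,
+#                             num_classes=num_classes)
+#   transferred = end_points['transferred_images']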
+
+
+def batch_norm_params(is_training, batch_norm_decay):
+ return {
+ 'is_training': is_training,
+ # Decay for the moving averages.
+ 'decay': batch_norm_decay,
+ # epsilon to prevent 0s in variance.
+ 'epsilon': 0.001,
+ }
+
+
+def lrelu(x, leakiness=0.2):
+ """Relu, with optional leaky support."""
+ return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu')
+
+
+def upsample(net, num_filters, scale=2, method='resize_conv', scope=None):
+ """Performs spatial upsampling of the given features.
+
+ Args:
+ net: A `Tensor` of shape [batch_size, height, width, filters].
+ num_filters: The number of output filters.
+ scale: The scale of the upsampling. Must be a positive integer greater or
+ equal to two.
+ method: The method by which the features are upsampled. Valid options
+ include 'resize_conv' and 'conv2d_transpose'.
+ scope: An optional variable scope.
+
+ Returns:
+ A new set of features of shape
+ [batch_size, height*scale, width*scale, num_filters].
+
+ Raises:
+ ValueError: if `method` is not valid or
+    ValueError: if `method` is not valid or `scale` is less than two.
+ if scale < 2:
+ raise ValueError('scale must be greater or equal to two.')
+
+ with tf.variable_scope(scope, 'upsample', [net]):
+ if method == 'resize_conv':
+ net = tf.image.resize_nearest_neighbor(
+ net, [net.shape.as_list()[1] * scale,
+ net.shape.as_list()[2] * scale],
+ align_corners=True,
+ name='resize')
+ return slim.conv2d(net, num_filters, stride=1, scope='conv')
+ elif method == 'conv2d_transpose':
+ return slim.conv2d_transpose(net, num_filters, scope='deconv')
+ else:
+ raise ValueError('Upsample method [%s] was not recognized.' % method)
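+
+
+# Example (added note): within the generator's arg_scope,
+# upsample(net, 64, scale=2, method='resize_conv') maps a [batch, H, W, C]
+# tensor to [batch, 2H, 2W, 64] via a nearest-neighbor resize followed by a
+# stride-1 convolution; 'conv2d_transpose' instead applies a transposed
+# (fractionally-strided) convolution.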
+
+
+def project_latent_vars(hparams, proj_shape, latent_vars, combine_method='sum'):
+ """Generate noise and project to input volume size.
+
+ Args:
+ hparams: The hyperparameter HParams struct.
+ proj_shape: Shape to project noise (not including batch size).
+ latent_vars: dictionary of `'key': Tensor of shape [batch_size, N]`
+ combine_method: How to combine the projected values.
+ sum = project to volume then sum
+ concat = concatenate along last dimension (i.e. channel)
+
+ Returns:
+ If combine_method=sum, a `Tensor` of size `hparams.projection_shape`
+ If combine_method=concat and there are N latent vars, a `Tensor` of size
+ `hparams.projection_shape`, with the last channel multiplied by N
+
+
+ Raises:
+ ValueError: combine_method is not one of sum/concat
+ """
+ values = []
+ for var in latent_vars:
+ with tf.variable_scope(var):
+ # Project & reshape noise to a HxWxC input
+ projected = slim.fully_connected(
+ latent_vars[var],
+ np.prod(proj_shape),
+ activation_fn=tf.nn.relu,
+ normalizer_fn=slim.batch_norm)
+ values.append(tf.reshape(projected, [hparams.batch_size] + proj_shape))
+
+ if combine_method == 'sum':
+ result = values[0]
+ for value in values[1:]:
+ result += value
+ elif combine_method == 'concat':
+ # Concatenate along last axis
+ result = tf.concat(values, len(proj_shape))
+ else:
+ raise ValueError('Unknown combine_method %s' % combine_method)
+
+ tf.logging.info('Latent variables projected to size %s volume', result.shape)
+
+ return result
+
+
+def resnet_block(net, hparams):
+ """Create a resnet block."""
+ net_in = net
+ net = slim.conv2d(
+ net,
+ hparams.resnet_filters,
+ stride=1,
+ normalizer_fn=slim.batch_norm,
+ activation_fn=tf.nn.relu)
+ net = slim.conv2d(
+ net,
+ hparams.resnet_filters,
+ stride=1,
+ normalizer_fn=slim.batch_norm,
+ activation_fn=None)
+ if hparams.resnet_residuals:
+ net += net_in
+ return net
+
+
+def resnet_stack(images, output_shape, hparams, scope=None):
+ """Create a resnet style transfer block.
+
+ Args:
+ images: [batch-size, height, width, channels] image tensor to feed as input
+ output_shape: output image shape in form [height, width, channels]
+ hparams: hparams objects
+ scope: Variable scope
+
+ Returns:
+ Images after processing with resnet blocks.
+ """
+ end_points = {}
+ if hparams.noise_channel:
+ # separate the noise for visualization
+ end_points['noise'] = images[:, :, :, -1]
+ assert images.shape.as_list()[1:3] == output_shape[0:2]
+
+ with tf.variable_scope(scope, 'resnet_style_transfer', [images]):
+ with slim.arg_scope(
+ [slim.conv2d],
+ normalizer_fn=slim.batch_norm,
+ kernel_size=[hparams.generator_kernel_size] * 2,
+ stride=1):
+ net = slim.conv2d(
+ images,
+ hparams.resnet_filters,
+ normalizer_fn=None,
+ activation_fn=tf.nn.relu)
+ for block in range(hparams.resnet_blocks):
+ net = resnet_block(net, hparams)
+ end_points['resnet_block_{}'.format(block)] = net
+
+ net = slim.conv2d(
+ net,
+ output_shape[-1],
+ kernel_size=[1, 1],
+ normalizer_fn=None,
+ activation_fn=tf.nn.tanh,
+ scope='conv_out')
+ end_points['transferred_images'] = net
+ return net, end_points
+
+
+def predict_domain(images,
+ hparams,
+ is_training=False,
+ reuse=False,
+ scope='discriminator'):
+ """Creates a discriminator for a GAN.
+
+ Args:
+ images: A `Tensor` of size [batch_size, height, width, channels]. It is
+ assumed that the images are centered between -1 and 1.
+ hparams: hparam object with params for discriminator
+ is_training: Specifies whether or not we're training or testing.
+ reuse: Whether to reuse variable scope
+ scope: An optional variable_scope.
+
+ Returns:
+ [batch size, 1] - logit output of discriminator.
+ """
+ with tf.variable_scope(scope, 'discriminator', [images], reuse=reuse):
+ lrelu_partial = functools.partial(lrelu, leakiness=hparams.lrelu_leakiness)
+ with slim.arg_scope(
+ [slim.conv2d],
+ kernel_size=[hparams.discriminator_kernel_size] * 2,
+ activation_fn=lrelu_partial,
+ stride=2,
+ normalizer_fn=slim.batch_norm):
+
+ def add_noise(hidden, scope_num=None):
+ if scope_num:
+ hidden = slim.dropout(
+ hidden,
+ hparams.discriminator_dropout_keep_prob,
+ is_training=is_training,
+ scope='dropout_%s' % scope_num)
+ if hparams.discriminator_noise_stddev == 0:
+ return hidden
+ return hidden + tf.random_normal(
+ hidden.shape.as_list(),
+ mean=0.0,
+ stddev=hparams.discriminator_noise_stddev)
+
+ # As per the recommendation of the DCGAN paper, we don't use batch norm
+ # on the discriminator input (https://arxiv.org/pdf/1511.06434v2.pdf).
+ if hparams.discriminator_image_noise:
+ images = add_noise(images)
+ net = slim.conv2d(
+ images,
+ hparams.num_discriminator_filters,
+ normalizer_fn=None,
+ stride=hparams.discriminator_first_stride,
+ scope='conv1_stride%s' % hparams.discriminator_first_stride)
+ net = add_noise(net, 1)
+
+ block_id = 2
+ # Repeatedly stack
+ # discriminator_conv_block_size-1 conv layers with stride 1
+ # followed by a stride 2 layer
+ # Add (optional) noise at every point
+ while net.shape.as_list()[1] > hparams.projection_shape_size:
+ num_filters = int(hparams.num_discriminator_filters *
+ (hparams.discriminator_filter_factor**(block_id - 1)))
+ for conv_id in range(1, hparams.discriminator_conv_block_size):
+ net = slim.conv2d(
+ net,
+ num_filters,
+ stride=1,
+ scope='conv_%s_%s' % (block_id, conv_id))
+ if hparams.discriminator_do_pooling:
+ net = slim.conv2d(
+ net, num_filters, scope='conv_%s_prepool' % block_id)
+ net = slim.avg_pool2d(
+ net, kernel_size=[2, 2], stride=2, scope='pool_%s' % block_id)
+ else:
+ net = slim.conv2d(
+ net, num_filters, scope='conv_%s_stride2' % block_id)
+ net = add_noise(net, block_id)
+ block_id += 1
+ net = slim.flatten(net)
+ net = slim.fully_connected(
+ net,
+ 1,
+ # Models with BN here generally produce noise
+ normalizer_fn=None,
+ activation_fn=None,
+ scope='fc_logit_out') # Returns logits!
+ return net
+
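+# Worked example for the downsampling loop in predict_domain (an illustrative
+# assumption, not taken from the original file): with 32x32 inputs,
+# discriminator_first_stride=2 and projection_shape_size=4, the spatial size
+# shrinks 32 -> 16 -> 8 -> 4; with discriminator_filter_factor=2 the filter
+# count doubles at each stride-2 block before the final 1-unit logit layer.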
+
+def dcgan_generator(images, output_shape, hparams, scope=None):
+ """Transforms the visual style of the input images.
+
+ Args:
+ images: A `Tensor` of shape [batch_size, height, width, channels].
+ output_shape: A list or tuple of 3 elements: the output height, width and
+ number of channels.
+ hparams: hparams object with generator parameters
+ scope: Scope to place generator inside
+
+ Returns:
+ A `Tensor` of shape [batch_size, height, width, output_channels] which
+ represents the result of style transfer.
+
+ Raises:
+ ValueError: If `output_shape` is not a list or tuple or if it doesn't have
+      three elements, or if `output_shape` or `images` aren't square.
+ """
+ if not isinstance(output_shape, (tuple, list)):
+ raise ValueError('output_shape must be a tuple or list.')
+ elif len(output_shape) != 3:
+ raise ValueError('output_shape must have three elements.')
+
+ if output_shape[0] != output_shape[1]:
+ raise ValueError('output_shape must be square')
+ if images.shape.as_list()[1] != images.shape.as_list()[2]:
+ raise ValueError('images height and width must match.')
+
+ outdim = output_shape[0]
+ indim = images.shape.as_list()[1]
+ num_iterations = int(math.ceil(math.log(float(outdim) / float(indim), 2.0)))
+
+ with slim.arg_scope(
+ [slim.conv2d, slim.conv2d_transpose],
+ kernel_size=[hparams.generator_kernel_size] * 2,
+ stride=2):
+ with tf.variable_scope(scope or 'generator'):
+
+ net = images
+
+      # Halve the filter count each layer down to hparams.num_decoder_filters
+ for i in range(num_iterations):
+ num_filters = hparams.num_decoder_filters * 2**(num_iterations - i - 1)
+ net = slim.conv2d_transpose(net, num_filters, scope='deconv_%s' % i)
+
+ # Crop down to desired size (e.g. 32x32 -> 28x28)
+ dif = net.shape.as_list()[1] - outdim
+      low = dif // 2  # Integer division keeps the slice indices as ints.
+ high = net.shape.as_list()[1] - low
+ net = net[:, low:high, low:high, :]
+
+ # No batch norm on generator output
+ net = slim.conv2d(
+ net,
+ output_shape[2],
+ kernel_size=[1, 1],
+ stride=1,
+ normalizer_fn=None,
+ activation_fn=tf.tanh,
+ scope='conv_out')
+ return net
+
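+# Worked example of the upsample-then-crop arithmetic in dcgan_generator
+# (illustrative assumption): for a 4x4 input volume and a 28x28 output,
+# num_iterations = ceil(log2(28 / 4)) = 3, so three stride-2 deconvolutions
+# give 4 -> 8 -> 16 -> 32 and the final slice crops 32x32 down to 28x28
+# (dif = 4, low = 2, high = 30).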
+
+def dcgan(target_images, latent_vars, hparams, scope='dcgan'):
+ """Creates the PixelDA model.
+
+ Args:
+ target_images: A `Tensor` of shape [batch_size, height, width, 3]
+ sampled from the image domain to which we want to transfer.
+ latent_vars: dictionary of 'key': Tensor of shape [batch_size, N]
+ hparams: The hyperparameter map.
+ scope: Surround generator component with this scope
+
+ Returns:
+ A dictionary of model outputs.
+ """
+ proj_shape = [
+ hparams.projection_shape_size, hparams.projection_shape_size,
+ hparams.projection_shape_channels
+ ]
+ source_volume = project_latent_vars(
+ hparams, proj_shape, latent_vars, combine_method='concat')
+
+ ###################################################
+ # Transfer the source images to the target style. #
+ ###################################################
+ with tf.variable_scope(scope, 'generator', [target_images]):
+ transferred_images = dcgan_generator(
+ source_volume,
+ output_shape=target_images.shape.as_list()[1:4],
+ hparams=hparams)
+ assert transferred_images.shape.as_list() == target_images.shape.as_list()
+
+ return {'transferred_images': transferred_images}
+
+
+def resnet_generator(images, output_shape, hparams, latent_vars=None):
+ """Creates a ResNet-based generator.
+
+ Args:
+ images: A `Tensor` of shape [batch_size, height, width, num_channels]
+ sampled from the image domain from which we want to transfer
+ output_shape: A length-3 array indicating the height, width and channels of
+ the output.
+ hparams: The hyperparameter map.
+ latent_vars: dictionary of 'key': Tensor of shape [batch_size, N]
+
+ Returns:
+ A dictionary of model outputs.
+ """
+ with tf.variable_scope('generator'):
+ if latent_vars:
+ noise_channel = project_latent_vars(
+ hparams,
+ proj_shape=images.shape.as_list()[1:3] + [1],
+ latent_vars=latent_vars,
+ combine_method='concat')
+ images = tf.concat([images, noise_channel], 3)
+
+ transferred_images, end_points = resnet_stack(
+ images,
+ output_shape=output_shape,
+ hparams=hparams,
+ scope='resnet_stack')
+ end_points['transferred_images'] = transferred_images
+
+ return end_points
+
+
+def residual_interpretation_block(images, hparams, scope):
+ """Learns a residual image which is added to the incoming image.
+
+ Args:
+ images: A `Tensor` of size [batch_size, height, width, 3]
+ hparams: The hyperparameters struct.
+ scope: The name of the variable op scope.
+
+ Returns:
+ The updated images.
+ """
+ with tf.variable_scope(scope):
+ with slim.arg_scope(
+ [slim.conv2d],
+ normalizer_fn=None,
+ kernel_size=[hparams.generator_kernel_size] * 2):
+
+ net = images
+ for _ in range(hparams.res_int_convs):
+ net = slim.conv2d(
+ net, hparams.res_int_filters, activation_fn=tf.nn.relu)
+ net = slim.conv2d(net, 3, activation_fn=tf.nn.tanh)
+
+ # Add the residual
+ images += net
+
+ # Clip the output
+ images = tf.maximum(images, -1.0)
+ images = tf.minimum(images, 1.0)
+ return images
+
+
+def residual_interpretation_generator(images,
+ is_training,
+ hparams,
+ latent_vars=None):
+ """Creates a generator producing purely residual transformations.
+
+ A residual generator differs from the resnet generator in that each 'block' of
+ the residual generator produces a residual image. Consequently, the 'progress'
+ of the model generation process can be directly observed at inference time,
+ making it easier to diagnose and understand.
+
+ Args:
+ images: A `Tensor` of shape [batch_size, height, width, num_channels]
+ sampled from the image domain from which we want to transfer. It is
+ assumed that the images are centered between -1 and 1.
+ is_training: whether or not the model is training.
+ hparams: The hyperparameter map.
+ latent_vars: dictionary of 'key': Tensor of shape [batch_size, N]
+
+ Returns:
+ A dictionary of model outputs.
+ """
+ end_points = {}
+
+ with tf.variable_scope('generator'):
+ if latent_vars:
+ projected_latent = project_latent_vars(
+ hparams,
+ proj_shape=images.shape.as_list()[1:3] + [images.shape.as_list()[-1]],
+ latent_vars=latent_vars,
+ combine_method='sum')
+ images += projected_latent
+ with tf.variable_scope(None, 'residual_style_transfer', [images]):
+ for i in range(hparams.res_int_blocks):
+ images = residual_interpretation_block(images, hparams,
+ 'residual_%d' % i)
+ end_points['transferred_images_%d' % i] = images
+
+ end_points['transferred_images'] = images
+
+ return end_points
+
+
+def simple_generator(source_images, target_images, is_training, hparams,
+ latent_vars):
+ """Simple generator architecture (stack of convs) for trying small models."""
+ end_points = {}
+ with tf.variable_scope('generator'):
+ feed_source_images = source_images
+
+ if latent_vars:
+ projected_latent = project_latent_vars(
+ hparams,
+ proj_shape=source_images.shape.as_list()[1:3] + [1],
+ latent_vars=latent_vars,
+ combine_method='concat')
+ feed_source_images = tf.concat([source_images, projected_latent], 3)
+
+ end_points = {}
+
+ ###################################################
+ # Transfer the source images to the target style. #
+ ###################################################
+ with slim.arg_scope(
+ [slim.conv2d],
+ normalizer_fn=slim.batch_norm,
+ stride=1,
+ kernel_size=[hparams.generator_kernel_size] * 2):
+ net = feed_source_images
+
+      # One conv per simple_num_conv_layers; the first layer skips batch norm.
+      for i in range(hparams.simple_num_conv_layers):
+        normalizer_fn = None
+        if i != 0:
+          normalizer_fn = slim.batch_norm
+ net = slim.conv2d(
+ net,
+ hparams.simple_conv_filters,
+ normalizer_fn=normalizer_fn,
+ activation_fn=tf.nn.relu)
+
+      # Project back to the target number of image channels
+ net = slim.conv2d(
+ net,
+ target_images.shape.as_list()[-1],
+ kernel_size=[1, 1],
+ stride=1,
+ normalizer_fn=None,
+ activation_fn=tf.tanh,
+ scope='conv_out')
+
+ transferred_images = net
+ assert transferred_images.shape.as_list() == target_images.shape.as_list()
+ end_points['transferred_images'] = transferred_images
+
+ return end_points
diff --git a/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_preprocess.py b/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..747c17b18bf007d85e606015da6687a343bf74d2
--- /dev/null
+++ b/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_preprocess.py
@@ -0,0 +1,129 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Contains functions for preprocessing the inputs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+
+import tensorflow as tf
+
+
+def preprocess_classification(image, labels, is_training=False):
+ """Preprocesses the image and labels for classification purposes.
+
+ Preprocessing includes shifting the images to be 0-centered between -1 and 1.
+ This is not only a popular method of preprocessing (inception) but is also
+ the mechanism used by DSNs.
+
+ Args:
+ image: A `Tensor` of size [height, width, 3].
+ labels: A dictionary of labels.
+ is_training: Whether or not we're training the model.
+
+ Returns:
+ The preprocessed image and labels.
+ """
+ # If the image is uint8, this will scale it to 0-1.
+ image = tf.image.convert_image_dtype(image, tf.float32)
+ image -= 0.5
+ image *= 2
+
+ return image, labels
+
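+# Worked example of the scaling above (illustrative only): a uint8 pixel of
+# 255 becomes 1.0 after convert_image_dtype, then (1.0 - 0.5) * 2 = 1.0,
+# while a pixel of 0 maps to (0.0 - 0.5) * 2 = -1.0, so outputs lie in
+# [-1, 1].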
+
+def preprocess_style_transfer(image,
+ labels,
+ augment=False,
+ size=None,
+ is_training=False):
+ """Preprocesses the image and labels for style transfer purposes.
+
+ Args:
+ image: A `Tensor` of size [height, width, 3].
+ labels: A dictionary of labels.
+ augment: Whether to apply data augmentation to inputs
+ size: The height and width to which images should be resized. If left as
+ `None`, then no resizing is performed
+ is_training: Whether or not we're training the model
+
+ Returns:
+ The preprocessed image and labels. Scaled to [-1, 1]
+ """
+ # If the image is uint8, this will scale it to 0-1.
+ image = tf.image.convert_image_dtype(image, tf.float32)
+ if augment and is_training:
+ image = image_augmentation(image)
+
+ if size:
+ image = resize_image(image, size)
+
+ image -= 0.5
+ image *= 2
+
+ return image, labels
+
+
+def image_augmentation(image):
+ """Performs data augmentation by randomly permuting the inputs.
+
+ Args:
+ image: A float `Tensor` of size [height, width, channels] with values
+ in range[0,1].
+
+ Returns:
+ The mutated batch of images
+ """
+ # Apply photometric data augmentation (contrast etc.)
+  num_channels = image.shape.as_list()[-1]
+ if num_channels == 4:
+ # Only augment image part
+ image, depth = image[:, :, 0:3], image[:, :, 3:4]
+ elif num_channels == 1:
+ image = tf.image.grayscale_to_rgb(image)
+ image = tf.image.random_brightness(image, max_delta=0.1)
+ image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
+ image = tf.image.random_hue(image, max_delta=0.032)
+ image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
+ image = tf.clip_by_value(image, 0, 1.0)
+ if num_channels == 4:
+    image = tf.concat([image, depth], 2)
+ elif num_channels == 1:
+ image = tf.image.rgb_to_grayscale(image)
+ return image
+
+
+def resize_image(image, size=None):
+ """Resize image to target size.
+
+ Args:
+ image: A `Tensor` of size [height, width, 3].
+ size: (height, width) to resize image to.
+
+ Returns:
+ resized image
+ """
+ if size is None:
+ raise ValueError('Must specify size')
+
+  if image.shape.as_list()[:2] == list(size):
+ # Don't resize if not necessary
+ return image
+ image = tf.expand_dims(image, 0)
+ image = tf.image.resize_images(image, size)
+ image = tf.squeeze(image, 0)
+ return image
diff --git a/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_preprocess_test.py b/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_preprocess_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..73f8c7ff05fc7d2614c419759a02f78ffbcdfec0
--- /dev/null
+++ b/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_preprocess_test.py
@@ -0,0 +1,69 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for domain_adaptation.pixel_domain_adaptation.pixelda_preprocess."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+
+import tensorflow as tf
+
+from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
+
+
+class PixelDAPreprocessTest(tf.test.TestCase):
+
+ def assert_preprocess_classification_is_centered(self, dtype, is_training):
+ tf.set_random_seed(0)
+
+ if dtype == tf.uint8:
+ image = tf.random_uniform((100, 200, 3), maxval=255, dtype=tf.int64)
+ image = tf.cast(image, tf.uint8)
+ else:
+ image = tf.random_uniform((100, 200, 3), maxval=1.0, dtype=dtype)
+
+ labels = {}
+ image, labels = pixelda_preprocess.preprocess_classification(
+ image, labels, is_training=is_training)
+
+ with self.test_session() as sess:
+ np_image = sess.run(image)
+
+ self.assertTrue(np_image.min() <= -0.95)
+ self.assertTrue(np_image.min() >= -1.0)
+ self.assertTrue(np_image.max() >= 0.95)
+ self.assertTrue(np_image.max() <= 1.0)
+
+ def testPreprocessClassificationZeroCentersUint8DuringTrain(self):
+ self.assert_preprocess_classification_is_centered(
+ tf.uint8, is_training=True)
+
+ def testPreprocessClassificationZeroCentersUint8DuringTest(self):
+ self.assert_preprocess_classification_is_centered(
+ tf.uint8, is_training=False)
+
+ def testPreprocessClassificationZeroCentersFloatDuringTrain(self):
+ self.assert_preprocess_classification_is_centered(
+ tf.float32, is_training=True)
+
+ def testPreprocessClassificationZeroCentersFloatDuringTest(self):
+ self.assert_preprocess_classification_is_centered(
+ tf.float32, is_training=False)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_task_towers.py b/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_task_towers.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cb42e2d890a7759318cf0981640c0dd1645461e
--- /dev/null
+++ b/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_task_towers.py
@@ -0,0 +1,317 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Task towers for PixelDA model."""
+import tensorflow as tf
+
+slim = tf.contrib.slim
+
+
+def add_task_specific_model(images,
+ hparams,
+ num_classes=10,
+ is_training=False,
+ reuse_private=False,
+ private_scope=None,
+ reuse_shared=False,
+ shared_scope=None):
+ """Create a classifier for the given images.
+
+ The classifier is composed of a few 'private' layers followed by a few
+ 'shared' layers. This lets us account for different image 'style', while
+ sharing the last few layers as 'content' layers.
+
+ Args:
+ images: A `Tensor` of size [batch_size, height, width, 3].
+ hparams: model hparams
+ num_classes: The number of output classes.
+ is_training: whether model is training
+ reuse_private: Whether or not to reuse the private weights, which are the
+ first few layers in the classifier
+ private_scope: The name of the variable_scope for the private (unshared)
+ components of the classifier.
+ reuse_shared: Whether or not to reuse the shared weights, which are the last
+ few layers in the classifier
+ shared_scope: The name of the variable_scope for the shared components of
+ the classifier.
+
+ Returns:
+ The logits, a `Tensor` of shape [batch_size, num_classes].
+
+ Raises:
+ ValueError: If hparams.task_classifier is an unknown value
+ """
+
+ model = hparams.task_tower
+ # Make sure the classifier name shows up in graph
+ shared_scope = shared_scope or (model + '_shared')
+ kwargs = {
+ 'num_classes': num_classes,
+ 'is_training': is_training,
+ 'reuse_private': reuse_private,
+ 'reuse_shared': reuse_shared,
+ }
+
+ if private_scope:
+ kwargs['private_scope'] = private_scope
+ if shared_scope:
+ kwargs['shared_scope'] = shared_scope
+
+ quaternion_pred = None
+ with slim.arg_scope(
+ [slim.conv2d, slim.fully_connected],
+ activation_fn=tf.nn.relu,
+ weights_regularizer=tf.contrib.layers.l2_regularizer(
+ hparams.weight_decay_task_classifier)):
+ with slim.arg_scope([slim.conv2d], padding='SAME'):
+ if model == 'doubling_pose_estimator':
+ logits, quaternion_pred = doubling_cnn_class_and_quaternion(
+ images, num_private_layers=hparams.num_private_layers, **kwargs)
+ elif model == 'mnist':
+ logits, _ = mnist_classifier(images, **kwargs)
+ elif model == 'svhn':
+ logits, _ = svhn_classifier(images, **kwargs)
+ elif model == 'gtsrb':
+ logits, _ = gtsrb_classifier(images, **kwargs)
+ elif model == 'pose_mini':
+ logits, quaternion_pred = pose_mini_tower(images, **kwargs)
+ else:
+ raise ValueError('Unknown task classifier %s' % model)
+
+ return logits, quaternion_pred
+
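+# Illustrative usage sketch (an assumption, not taken from this file): the
+# tower is typically built once per image domain, sharing the 'content'
+# layers via reuse_shared while keeping the 'style' layers private:
+#
+#   src_logits, _ = add_task_specific_model(
+#       source_images, hparams, private_scope='source_task',
+#       shared_scope='task_model')
+#   trg_logits, _ = add_task_specific_model(
+#       transferred_images, hparams, reuse_shared=True,
+#       private_scope='transferred_task', shared_scope='task_model')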
+
+#####################################
+# Classifiers used in the DSN paper #
+#####################################
+
+
+def mnist_classifier(images,
+ is_training=False,
+ num_classes=10,
+ reuse_private=False,
+ private_scope='mnist',
+ reuse_shared=False,
+ shared_scope='task_model'):
+ """Creates the convolutional MNIST model from the gradient reversal paper.
+
+ Note that since the output is a set of 'logits', the values fall in the
+ interval of (-infinity, infinity). Consequently, to convert the outputs to a
+ probability distribution over the characters, one will need to convert them
+ using the softmax function:
+ logits, endpoints = conv_mnist(images, is_training=False)
+ predictions = tf.nn.softmax(logits)
+
+ Args:
+ images: the MNIST digits, a tensor of size [batch_size, 28, 28, 1].
+ is_training: specifies whether or not we're currently training the model.
+ This variable will determine the behaviour of the dropout layer.
+ num_classes: the number of output classes to use.
+
+ Returns:
+ the output logits, a tensor of size [batch_size, num_classes].
+ a dictionary with key/values the layer names and tensors.
+ """
+
+ net = {}
+
+ with tf.variable_scope(private_scope, reuse=reuse_private):
+ net['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1')
+ net['pool1'] = slim.max_pool2d(net['conv1'], [2, 2], 2, scope='pool1')
+
+ with tf.variable_scope(shared_scope, reuse=reuse_shared):
+ net['conv2'] = slim.conv2d(net['pool1'], 48, [5, 5], scope='conv2')
+ net['pool2'] = slim.max_pool2d(net['conv2'], [2, 2], 2, scope='pool2')
+ net['fc3'] = slim.fully_connected(
+ slim.flatten(net['pool2']), 100, scope='fc3')
+ net['fc4'] = slim.fully_connected(
+ slim.flatten(net['fc3']), 100, scope='fc4')
+ logits = slim.fully_connected(
+ net['fc4'], num_classes, activation_fn=None, scope='fc5')
+ return logits, net
+
+
+def svhn_classifier(images,
+ is_training=False,
+ num_classes=10,
+ reuse_private=False,
+ private_scope=None,
+ reuse_shared=False,
+ shared_scope='task_model'):
+ """Creates the convolutional SVHN model from the gradient reversal paper.
+
+ Note that since the output is a set of 'logits', the values fall in the
+ interval of (-infinity, infinity). Consequently, to convert the outputs to a
+ probability distribution over the characters, one will need to convert them
+ using the softmax function:
+ logits = mnist.Mnist(images, is_training=False)
+ predictions = tf.nn.softmax(logits)
+
+ Args:
+ images: the SVHN digits, a tensor of size [batch_size, 40, 40, 3].
+ is_training: specifies whether or not we're currently training the model.
+ This variable will determine the behaviour of the dropout layer.
+ num_classes: the number of output classes to use.
+
+ Returns:
+ the output logits, a tensor of size [batch_size, num_classes].
+ a dictionary with key/values the layer names and tensors.
+ """
+
+ net = {}
+
+ with tf.variable_scope(private_scope, reuse=reuse_private):
+ net['conv1'] = slim.conv2d(images, 64, [5, 5], scope='conv1')
+ net['pool1'] = slim.max_pool2d(net['conv1'], [3, 3], 2, scope='pool1')
+
+ with tf.variable_scope(shared_scope, reuse=reuse_shared):
+ net['conv2'] = slim.conv2d(net['pool1'], 64, [5, 5], scope='conv2')
+ net['pool2'] = slim.max_pool2d(net['conv2'], [3, 3], 2, scope='pool2')
+ net['conv3'] = slim.conv2d(net['pool2'], 128, [5, 5], scope='conv3')
+
+ net['fc3'] = slim.fully_connected(
+ slim.flatten(net['conv3']), 3072, scope='fc3')
+ net['fc4'] = slim.fully_connected(
+ slim.flatten(net['fc3']), 2048, scope='fc4')
+
+ logits = slim.fully_connected(
+ net['fc4'], num_classes, activation_fn=None, scope='fc5')
+
+ return logits, net
+
+
+def gtsrb_classifier(images,
+ is_training=False,
+ num_classes=43,
+ reuse_private=False,
+ private_scope='gtsrb',
+ reuse_shared=False,
+ shared_scope='task_model'):
+ """Creates the convolutional GTSRB model from the gradient reversal paper.
+
+ Note that since the output is a set of 'logits', the values fall in the
+ interval of (-infinity, infinity). Consequently, to convert the outputs to a
+ probability distribution over the characters, one will need to convert them
+ using the softmax function:
+ logits = mnist.Mnist(images, is_training=False)
+ predictions = tf.nn.softmax(logits)
+
+ Args:
+    images: the GTSRB images, a tensor of size [batch_size, 40, 40, 3].
+ is_training: specifies whether or not we're currently training the model.
+ This variable will determine the behaviour of the dropout layer.
+ num_classes: the number of output classes to use.
+ reuse_private: Whether or not to reuse the private components of the model.
+ private_scope: The name of the private scope.
+ reuse_shared: Whether or not to reuse the shared components of the model.
+ shared_scope: The name of the shared scope.
+
+ Returns:
+ the output logits, a tensor of size [batch_size, num_classes].
+ a dictionary with key/values the layer names and tensors.
+ """
+
+ net = {}
+
+ with tf.variable_scope(private_scope, reuse=reuse_private):
+ net['conv1'] = slim.conv2d(images, 96, [5, 5], scope='conv1')
+ net['pool1'] = slim.max_pool2d(net['conv1'], [2, 2], 2, scope='pool1')
+ with tf.variable_scope(shared_scope, reuse=reuse_shared):
+ net['conv2'] = slim.conv2d(net['pool1'], 144, [3, 3], scope='conv2')
+ net['pool2'] = slim.max_pool2d(net['conv2'], [2, 2], 2, scope='pool2')
+ net['conv3'] = slim.conv2d(net['pool2'], 256, [5, 5], scope='conv3')
+ net['pool3'] = slim.max_pool2d(net['conv3'], [2, 2], 2, scope='pool3')
+
+ net['fc3'] = slim.fully_connected(
+ slim.flatten(net['pool3']), 512, scope='fc3')
+ logits = slim.fully_connected(
+ net['fc3'], num_classes, activation_fn=None, scope='fc4')
+
+ return logits, net
+
+
+#########################
+# pose_mini task towers #
+#########################
+
+
+def pose_mini_tower(images,
+ num_classes=11,
+ is_training=False,
+ reuse_private=False,
+ private_scope='pose_mini',
+ reuse_shared=False,
+ shared_scope='task_model'):
+ """Task tower for the pose_mini dataset."""
+
+ with tf.variable_scope(private_scope, reuse=reuse_private):
+ net = slim.conv2d(images, 32, [5, 5], scope='conv1')
+ net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool1')
+ with tf.variable_scope(shared_scope, reuse=reuse_shared):
+ net = slim.conv2d(net, 64, [5, 5], scope='conv2')
+ net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool2')
+ net = slim.flatten(net)
+
+ net = slim.fully_connected(net, 128, scope='fc3')
+ net = slim.dropout(net, 0.5, is_training=is_training, scope='dropout')
+ with tf.variable_scope('quaternion_prediction'):
+ quaternion_pred = slim.fully_connected(
+ net, 4, activation_fn=tf.tanh, scope='fc_q')
+ quaternion_pred = tf.nn.l2_normalize(quaternion_pred, 1)
+
+ logits = slim.fully_connected(
+ net, num_classes, activation_fn=None, scope='fc4')
+
+ return logits, quaternion_pred
+
+
+def doubling_cnn_class_and_quaternion(images,
+ num_private_layers=1,
+ num_classes=10,
+ is_training=False,
+ reuse_private=False,
+ private_scope='doubling_cnn',
+ reuse_shared=False,
+ shared_scope='task_model'):
+ """Alternate conv, pool while doubling filter count."""
+ net = images
+ depth = 32
+ layer_id = 1
+
+ with tf.variable_scope(private_scope, reuse=reuse_private):
+ while num_private_layers > 0 and net.shape.as_list()[1] > 5:
+ net = slim.conv2d(net, depth, [3, 3], scope='conv%s' % layer_id)
+ net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool%s' % layer_id)
+ depth *= 2
+ layer_id += 1
+ num_private_layers -= 1
+
+ with tf.variable_scope(shared_scope, reuse=reuse_shared):
+ while net.shape.as_list()[1] > 5:
+ net = slim.conv2d(net, depth, [3, 3], scope='conv%s' % layer_id)
+ net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool%s' % layer_id)
+ depth *= 2
+ layer_id += 1
+
+ net = slim.flatten(net)
+ net = slim.fully_connected(net, 100, scope='fc1')
+ net = slim.dropout(net, 0.5, is_training=is_training, scope='dropout')
+ quaternion_pred = slim.fully_connected(
+ net, 4, activation_fn=tf.tanh, scope='fc_q')
+ quaternion_pred = tf.nn.l2_normalize(quaternion_pred, 1)
+
+ logits = slim.fully_connected(
+ net, num_classes, activation_fn=None, scope='fc_logits')
+
+ return logits, quaternion_pred
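+# Worked example of the conv/pool doubling above (illustrative assumption):
+# for 32x32 inputs with num_private_layers=1, the private scope applies one
+# conv/pool step (32 -> 16 at depth 32); the shared scope then continues
+# 16 -> 8 (depth 64) and 8 -> 4 (depth 128), stopping once the spatial size
+# is no longer greater than 5, before the class and quaternion heads.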
diff --git a/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_train.py b/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ca072cceafa48769623381b8e564fe650f2a514
--- /dev/null
+++ b/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_train.py
@@ -0,0 +1,409 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Trains the PixelDA model."""
+
+from functools import partial
+import os
+
+# Dependency imports
+
+import tensorflow as tf
+
+from domain_adaptation.datasets import dataset_factory
+from domain_adaptation.pixel_domain_adaptation import pixelda_losses
+from domain_adaptation.pixel_domain_adaptation import pixelda_model
+from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
+from domain_adaptation.pixel_domain_adaptation import pixelda_utils
+from domain_adaptation.pixel_domain_adaptation.hparams import create_hparams
+
+slim = tf.contrib.slim
+
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('master', '', 'BNS name of the TensorFlow master to use.')
+
+flags.DEFINE_integer(
+ 'ps_tasks', 0,
+ 'The number of parameter servers. If the value is 0, then the parameters '
+ 'are handled locally by the worker.')
+
+flags.DEFINE_integer(
+ 'task', 0,
+ 'The Task ID. This value is used when training with multiple workers to '
+ 'identify each worker.')
+
+flags.DEFINE_string('train_log_dir', '/tmp/pixelda/',
+ 'Directory where to write event logs.')
+
+flags.DEFINE_integer(
+ 'save_summaries_steps', 500,
+    'The frequency, in steps, with which summaries are saved.')
+
+flags.DEFINE_integer('save_interval_secs', 300,
+ 'The frequency with which the model is saved, in seconds.')
+
+flags.DEFINE_boolean('summarize_gradients', False,
+ 'Whether to summarize model gradients')
+
+flags.DEFINE_integer(
+ 'print_loss_steps', 100,
+ 'The frequency with which the losses are printed, in steps.')
+
+flags.DEFINE_string('source_dataset', 'mnist', 'The name of the source dataset.'
+ ' If hparams="arch=dcgan", this flag is ignored.')
+
+flags.DEFINE_string('target_dataset', 'mnist_m',
+ 'The name of the target dataset.')
+
+flags.DEFINE_string('source_split_name', 'train',
+ 'Name of the train split for the source.')
+
+flags.DEFINE_string('target_split_name', 'train',
+ 'Name of the train split for the target.')
+
+flags.DEFINE_string('dataset_dir', '',
+ 'The directory where the datasets can be found.')
+
+flags.DEFINE_integer(
+ 'num_readers', 4,
+ 'The number of parallel readers that read data from the dataset.')
+
+flags.DEFINE_integer('num_preprocessing_threads', 4,
+ 'The number of threads used to create the batches.')
+
+# HParams
+
+flags.DEFINE_string('hparams', '', 'Comma separated hyperparameter values')
+
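+# Example invocation (an illustrative sketch; the dataset paths and
+# hyperparameter values below are assumptions, not taken from this file):
+#
+#   python pixelda_train.py \
+#     --source_dataset=mnist --target_dataset=mnist_m \
+#     --dataset_dir=/tmp/datasets --train_log_dir=/tmp/pixelda \
+#     --hparams="batch_size=32,generator_steps=1,discriminator_steps=1"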
+
+def _get_vars_and_update_ops(hparams, scope):
+ """Returns the variables and update ops for a particular variable scope.
+
+ Args:
+ hparams: The hyperparameters struct.
+ scope: The variable scope.
+
+ Returns:
+ A tuple consisting of trainable variables and update ops.
+ """
+ is_trainable = lambda x: x in tf.trainable_variables()
+  var_list = list(filter(is_trainable, slim.get_model_variables(scope)))
+ global_step = slim.get_or_create_global_step()
+
+ update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)
+
+ tf.logging.info('All variables for scope: %s',
+ slim.get_model_variables(scope))
+ tf.logging.info('Trainable variables for scope: %s', var_list)
+
+ return var_list, update_ops
+
+
+def _train(discriminator_train_op,
+ generator_train_op,
+ logdir,
+ master='',
+ is_chief=True,
+ scaffold=None,
+ hooks=None,
+ chief_only_hooks=None,
+ save_checkpoint_secs=600,
+ save_summaries_steps=100,
+ hparams=None):
+ """Runs the training loop.
+
+ Args:
+ discriminator_train_op: A `Tensor` that, when executed, will apply the
+ gradients and return the loss value for the discriminator.
+ generator_train_op: A `Tensor` that, when executed, will apply the
+ gradients and return the loss value for the generator.
+ logdir: The directory where the graph and checkpoints are saved.
+ master: The URL of the master.
+ is_chief: Specifies whether or not the training is being run by the primary
+ replica during replica training.
+ scaffold: An tf.train.Scaffold instance.
+ hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
+ training loop.
+ chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
+ inside the training loop for the chief trainer only.
+ save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
+ using a default checkpoint saver. If `save_checkpoint_secs` is set to
+ `None`, then the default checkpoint saver isn't used.
+ save_summaries_steps: The frequency, in number of global steps, that the
+ summaries are written to disk using a default summary saver. If
+ `save_summaries_steps` is set to `None`, then the default summary saver
+ isn't used.
+ hparams: The hparams struct.
+
+ Returns:
+ the value of the loss function after training.
+
+ Raises:
+ ValueError: if `logdir` is `None` and either `save_checkpoint_secs` or
+      `save_summaries_steps` is not `None`.
+ """
+ global_step = slim.get_or_create_global_step()
+
+ scaffold = scaffold or tf.train.Scaffold()
+
+ hooks = hooks or []
+
+ if is_chief:
+ session_creator = tf.train.ChiefSessionCreator(
+ scaffold=scaffold, checkpoint_dir=logdir, master=master)
+
+ if chief_only_hooks:
+ hooks.extend(chief_only_hooks)
+ hooks.append(tf.train.StepCounterHook(output_dir=logdir))
+
+ if save_summaries_steps:
+ if logdir is None:
+ raise ValueError(
+            'logdir cannot be None when save_summaries_steps is not None')
+ hooks.append(
+ tf.train.SummarySaverHook(
+ scaffold=scaffold,
+ save_steps=save_summaries_steps,
+ output_dir=logdir))
+
+ if save_checkpoint_secs:
+ if logdir is None:
+ raise ValueError(
+            'logdir cannot be None when save_checkpoint_secs is not None')
+ hooks.append(
+ tf.train.CheckpointSaverHook(
+ logdir, save_secs=save_checkpoint_secs, scaffold=scaffold))
+ else:
+ session_creator = tf.train.WorkerSessionCreator(
+ scaffold=scaffold, master=master)
+
+ with tf.train.MonitoredSession(
+ session_creator=session_creator, hooks=hooks) as session:
+ loss = None
+ while not session.should_stop():
+ # Run the domain classifier op X times.
+ for _ in range(hparams.discriminator_steps):
+ if session.should_stop():
+ return loss
+ loss, np_global_step = session.run(
+ [discriminator_train_op, global_step])
+ if np_global_step % FLAGS.print_loss_steps == 0:
+ tf.logging.info('Step %d: Discriminator Loss = %.2f', np_global_step,
+ loss)
+
+ # Run the generator op X times.
+ for _ in range(hparams.generator_steps):
+ if session.should_stop():
+ return loss
+ loss, np_global_step = session.run([generator_train_op, global_step])
+ if np_global_step % FLAGS.print_loss_steps == 0:
+ tf.logging.info('Step %d: Generator Loss = %.2f', np_global_step,
+ loss)
+ return loss
+
+
+def run_training(run_dir, checkpoint_dir, hparams):
+ """Runs the training loop.
+
+ Args:
+ run_dir: The directory where training specific logs are placed
+ checkpoint_dir: The directory where the checkpoints and log files are
+ stored.
+ hparams: The hyperparameters struct.
+
+ Raises:
+ ValueError: if hparams.arch is not recognized.
+ """
+ for path in [run_dir, checkpoint_dir]:
+ if not tf.gfile.Exists(path):
+ tf.gfile.MakeDirs(path)
+
+ # Serialize hparams to log dir
+ hparams_filename = os.path.join(checkpoint_dir, 'hparams.json')
+ with tf.gfile.FastGFile(hparams_filename, 'w') as f:
+ f.write(hparams.to_json())
+
+ with tf.Graph().as_default():
+ with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
+ global_step = slim.get_or_create_global_step()
+
+ #########################
+ # Preprocess the inputs #
+ #########################
+ target_dataset = dataset_factory.get_dataset(
+ FLAGS.target_dataset,
+ split_name='train',
+ dataset_dir=FLAGS.dataset_dir)
+ target_images, _ = dataset_factory.provide_batch(
+ FLAGS.target_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
+ hparams.batch_size, FLAGS.num_preprocessing_threads)
+ num_target_classes = target_dataset.num_classes
+
+ if hparams.arch not in ['dcgan']:
+ source_dataset = dataset_factory.get_dataset(
+ FLAGS.source_dataset,
+ split_name='train',
+ dataset_dir=FLAGS.dataset_dir)
+ num_source_classes = source_dataset.num_classes
+ source_images, source_labels = dataset_factory.provide_batch(
+ FLAGS.source_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
+ hparams.batch_size, FLAGS.num_preprocessing_threads)
+ # Data provider provides 1 hot labels, but we expect categorical.
+ source_labels['class'] = tf.argmax(source_labels['classes'], 1)
+ del source_labels['classes']
+ if num_source_classes != num_target_classes:
+ raise ValueError(
+ 'Source and Target datasets must have same number of classes. '
+              'Got %d and %d.' % (num_source_classes, num_target_classes))
+ else:
+ source_images = None
+ source_labels = None
+
+ ####################
+ # Define the model #
+ ####################
+ end_points = pixelda_model.create_model(
+ hparams,
+ target_images,
+ source_images=source_images,
+ source_labels=source_labels,
+ is_training=True,
+ num_classes=num_target_classes)
+
+ #################################
+ # Get the variables to optimize #
+ #################################
+ generator_vars, generator_update_ops = _get_vars_and_update_ops(
+ hparams, 'generator')
+ discriminator_vars, discriminator_update_ops = _get_vars_and_update_ops(
+ hparams, 'discriminator')
+
+ ########################
+ # Configure the losses #
+ ########################
+ generator_loss = pixelda_losses.g_step_loss(
+ source_images,
+ source_labels,
+ end_points,
+ hparams,
+ num_classes=num_target_classes)
+ discriminator_loss = pixelda_losses.d_step_loss(
+ end_points, source_labels, num_target_classes, hparams)
+
+ ###########################
+ # Create the training ops #
+ ###########################
+ learning_rate = hparams.learning_rate
+ if hparams.lr_decay_steps:
+ learning_rate = tf.train.exponential_decay(
+ learning_rate,
+ slim.get_or_create_global_step(),
+ decay_steps=hparams.lr_decay_steps,
+ decay_rate=hparams.lr_decay_rate,
+ staircase=True)
+ tf.summary.scalar('Learning_rate', learning_rate)
+
+ if hparams.discriminator_steps == 0:
+ discriminator_train_op = tf.no_op()
+ else:
+ discriminator_optimizer = tf.train.AdamOptimizer(
+ learning_rate, beta1=hparams.adam_beta1)
+
+ discriminator_train_op = slim.learning.create_train_op(
+ discriminator_loss,
+ discriminator_optimizer,
+ update_ops=discriminator_update_ops,
+ variables_to_train=discriminator_vars,
+ clip_gradient_norm=hparams.clip_gradient_norm,
+ summarize_gradients=FLAGS.summarize_gradients)
+
+ if hparams.generator_steps == 0:
+ generator_train_op = tf.no_op()
+ else:
+ generator_optimizer = tf.train.AdamOptimizer(
+ learning_rate, beta1=hparams.adam_beta1)
+ generator_train_op = slim.learning.create_train_op(
+ generator_loss,
+ generator_optimizer,
+ update_ops=generator_update_ops,
+ variables_to_train=generator_vars,
+ clip_gradient_norm=hparams.clip_gradient_norm,
+ summarize_gradients=FLAGS.summarize_gradients)
+
+ #############
+ # Summaries #
+ #############
+ pixelda_utils.summarize_model(end_points)
+ pixelda_utils.summarize_transferred_grid(
+ end_points['transferred_images'], source_images, name='Transferred')
+ if 'source_images_recon' in end_points:
+ pixelda_utils.summarize_transferred_grid(
+ end_points['source_images_recon'],
+ source_images,
+ name='Source Reconstruction')
+ pixelda_utils.summaries_color_distributions(end_points['transferred_images'],
+ 'Transferred')
+ pixelda_utils.summaries_color_distributions(target_images, 'Target')
+
+ if source_images is not None:
+ pixelda_utils.summarize_transferred(source_images,
+ end_points['transferred_images'])
+ pixelda_utils.summaries_color_distributions(source_images, 'Source')
+ pixelda_utils.summaries_color_distributions(
+ tf.abs(source_images - end_points['transferred_images']),
+ 'Abs(Source_minus_Transferred)')
+
+ number_of_steps = None
+ if hparams.num_training_examples:
+ # Want to control by amount of data seen, not # steps
+      number_of_steps = hparams.num_training_examples // hparams.batch_size
+
+ hooks = [tf.train.StepCounterHook(),]
+
+ chief_only_hooks = [
+ tf.train.CheckpointSaverHook(
+ saver=tf.train.Saver(),
+ checkpoint_dir=run_dir,
+ save_secs=FLAGS.save_interval_secs)
+ ]
+
+ if number_of_steps:
+ hooks.append(tf.train.StopAtStepHook(last_step=number_of_steps))
+
+ _train(
+ discriminator_train_op,
+ generator_train_op,
+ logdir=run_dir,
+ master=FLAGS.master,
+ is_chief=FLAGS.task == 0,
+ hooks=hooks,
+ chief_only_hooks=chief_only_hooks,
+ save_checkpoint_secs=None,
+ save_summaries_steps=FLAGS.save_summaries_steps,
+ hparams=hparams)
+
+def main(_):
+ tf.logging.set_verbosity(tf.logging.INFO)
+ hparams = create_hparams(FLAGS.hparams)
+ run_training(
+ run_dir=FLAGS.train_log_dir,
+ checkpoint_dir=FLAGS.train_log_dir,
+ hparams=hparams)
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_utils.py b/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..28e8006f267f9bf7f13c3dff78625cc4cbd00185
--- /dev/null
+++ b/models/research/domain_adaptation/pixel_domain_adaptation/pixelda_utils.py
@@ -0,0 +1,195 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for PixelDA model."""
+import math
+
+# Dependency imports
+
+import tensorflow as tf
+
+slim = tf.contrib.slim
+
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+
+def remove_depth(images):
+ """Takes a batch of images and remove depth channel if present."""
+ if images.shape.as_list()[-1] == 4:
+ return images[:, :, :, 0:3]
+ return images
+
+
+def image_grid(images, max_grid_size=4):
+ """Given images and N, return first N^2 images as an NxN image grid.
+
+ Args:
+ images: a `Tensor` of size [batch_size, height, width, channels]
+ max_grid_size: Maximum image grid height/width
+
+ Returns:
+ Single image batch, of dim [1, h*n, w*n, c]
+ """
+ images = remove_depth(images)
+ batch_size = images.shape.as_list()[0]
+ grid_size = min(int(math.sqrt(batch_size)), max_grid_size)
+ assert images.shape.as_list()[0] >= grid_size * grid_size
+
+ # If we have a depth channel
+ if images.shape.as_list()[-1] == 4:
+ images = images[:grid_size * grid_size, :, :, 0:3]
+ depth = tf.image.grayscale_to_rgb(images[:grid_size * grid_size, :, :, 3:4])
+
+ images = tf.reshape(images, [-1, images.shape.as_list()[2], 3])
+    split = tf.split(images, grid_size, 0)
+    depth = tf.reshape(depth, [-1, images.shape.as_list()[2], 3])
+    depth_split = tf.split(depth, grid_size, 0)
+ grid = tf.concat(split + depth_split, 1)
+ return tf.expand_dims(grid, 0)
+ else:
+ images = images[:grid_size * grid_size, :, :, :]
+ images = tf.reshape(
+ images, [-1, images.shape.as_list()[2],
+ images.shape.as_list()[3]])
+ split = tf.split(images, grid_size, 0)
+ grid = tf.concat(split, 1)
+ return tf.expand_dims(grid, 0)
+
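+# Worked example (illustrative only): for a batch of 16 RGB images of size
+# 28x28 with max_grid_size=4, grid_size = min(sqrt(16), 4) = 4, so the first
+# 16 images are tiled into a single [1, 112, 112, 3] summary image.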
+
+def source_and_output_image_grid(output_images,
+ source_images=None,
+ max_grid_size=4):
+ """Create NxN image grid for output, concatenate source grid if given.
+
+ Makes grid out of output_images and, if provided, source_images, and
+ concatenates them.
+
+ Args:
+ output_images: [batch_size, h, w, c] tensor of images
+ source_images: optional[batch_size, h, w, c] tensor of images
+ max_grid_size: Image grid height/width
+
+ Returns:
+ Single image batch, of dim [1, h*n, w*n, c]
+  """
+ output_grid = image_grid(output_images, max_grid_size=max_grid_size)
+ if source_images is not None:
+ source_grid = image_grid(source_images, max_grid_size=max_grid_size)
+ # Make sure they have the same # of channels before concat
+ # Assumes either 1 or 3 channels
+ if output_grid.shape.as_list()[-1] != source_grid.shape.as_list()[-1]:
+ if output_grid.shape.as_list()[-1] == 1:
+ output_grid = tf.tile(output_grid, [1, 1, 1, 3])
+ if source_grid.shape.as_list()[-1] == 1:
+ source_grid = tf.tile(source_grid, [1, 1, 1, 3])
+ output_grid = tf.concat([output_grid, source_grid], 1)
+ return output_grid
+
+
+def summarize_model(end_points):
+ """Summarizes the given model via its end_points.
+
+ Args:
+ end_points: A dictionary of end_point names to `Tensor`.
+ """
+ tf.summary.histogram('domain_logits_transferred',
+ tf.sigmoid(end_points['transferred_domain_logits']))
+
+ tf.summary.histogram('domain_logits_target',
+ tf.sigmoid(end_points['target_domain_logits']))
+
+
+def summarize_transferred_grid(transferred_images,
+ source_images=None,
+ name='Transferred'):
+ """Produces a visual grid summarization of the image transferrence.
+
+ Args:
+ transferred_images: A `Tensor` of size [batch_size, height, width, c].
+ source_images: A `Tensor` of size [batch_size, height, width, c].
+ name: Name to use in summary name
+ """
+ if source_images is not None:
+ grid = source_and_output_image_grid(transferred_images, source_images)
+ else:
+ grid = image_grid(transferred_images)
+ tf.summary.image('%s_Images_Grid' % name, grid, max_outputs=1)
+
+
+def summarize_transferred(source_images,
+ transferred_images,
+ max_images=20,
+ name='Transferred'):
+ """Produces a visual summary of the image transferrence.
+
+ This summary displays the source image, transferred image, and a grayscale
+ difference image which highlights the differences between input and output.
+
+ Args:
+ source_images: A `Tensor` of size [batch_size, height, width, channels].
+ transferred_images: A `Tensor` of size [batch_size, height, width, channels]
+ max_images: The number of images to show.
+ name: Name to use in summary name
+
+ Raises:
+ ValueError: If number of channels in source and target are incompatible
+ """
+ source_channels = source_images.shape.as_list()[-1]
+ transferred_channels = transferred_images.shape.as_list()[-1]
+ if source_channels < transferred_channels:
+ if source_channels != 1:
+ raise ValueError(
+ 'Source must be 1 channel or same # of channels as target')
+ source_images = tf.tile(source_images, [1, 1, 1, transferred_channels])
+ if transferred_channels < source_channels:
+ if transferred_channels != 1:
+ raise ValueError(
+ 'Target must be 1 channel or same # of channels as source')
+ transferred_images = tf.tile(transferred_images, [1, 1, 1, source_channels])
+ diffs = tf.abs(source_images - transferred_images)
+ diffs = tf.reduce_max(diffs, reduction_indices=[3], keep_dims=True)
+ diffs = tf.tile(diffs, [1, 1, 1, max(source_channels, transferred_channels)])
+
+ transition_images = tf.concat([
+ source_images,
+ transferred_images,
+ diffs,
+ ], 2)
+
+ tf.summary.image(
+ '%s_difference' % name, transition_images, max_outputs=max_images)
+
+
+def summaries_color_distributions(images, name):
+ """Produces a histogram of the color distributions of the images.
+
+ Args:
+ images: A `Tensor` of size [batch_size, height, width, 3].
+ name: The name of the images being summarized.
+ """
+ tf.summary.histogram('color_values/%s' % name, images)
+
+
+def summarize_images(images, name):
+ """Produces a visual summary of the given images.
+
+ Args:
+ images: A `Tensor` of size [batch_size, height, width, 3].
+ name: The name of the images being summarized.
+ """
+ grid = image_grid(images)
+ tf.summary.image('%s_Images' % name, grid, max_outputs=1)
diff --git a/models/research/efficient-hrl/README.md b/models/research/efficient-hrl/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6c454c687a3b75e9cf68d1f3737d74b464167e14
--- /dev/null
+++ b/models/research/efficient-hrl/README.md
@@ -0,0 +1,65 @@
+
+
+
+Code for performing Hierarchical RL based on the following publications:
+
+"Data-Efficient Hierarchical Reinforcement Learning" by
+Ofir Nachum, Shixiang (Shane) Gu, Honglak Lee, and Sergey Levine
+(https://arxiv.org/abs/1805.08296).
+
+"Near-Optimal Representation Learning for Hierarchical Reinforcement Learning"
+by Ofir Nachum, Shixiang (Shane) Gu, Honglak Lee, and Sergey Levine
+(https://arxiv.org/abs/1810.01257).
+
+
+Requirements:
+* TensorFlow (see http://www.tensorflow.org for how to install/upgrade)
+* Gin Config (see https://github.com/google/gin-config)
+* Tensorflow Agents (see https://github.com/tensorflow/agents)
+* OpenAI Gym (see http://gym.openai.com/docs, be sure to install MuJoCo as well)
+* NumPy (see http://www.numpy.org/)
+
+
+Quick Start:
+
+Run a training job based on the original HIRO paper on Ant Maze:
+
+```
+python scripts/local_train.py test1 hiro_orig ant_maze base_uvf suite
+```
+
+Run a continuous evaluation job for that experiment:
+
+```
+python scripts/local_eval.py test1 hiro_orig ant_maze base_uvf suite
+```
+
+To run the same experiment with online representation learning (the
+"Near-Optimal" paper), change `hiro_orig` to `hiro_repr`.
+You can also run with `hiro_xy` to run the same experiment with HIRO on only the
+xy coordinates of the agent.
+
+To run on other environments, change `ant_maze` to something else; e.g.,
+`ant_push_multi`, `ant_fall_multi`, etc. See `context/configs/*` for other options.
+
+
+Basic Code Guide:
+
+The code for training resides in train.py. The code trains a lower-level policy
+(a UVF agent in the code) and a higher-level policy (a MetaAgent in the code)
+concurrently. The higher-level policy communicates goals to the lower-level
+policy. In the code, this is called a context. Not only does the lower-level
+policy act with respect to a context (a higher-level specified goal), but the
+higher-level policy also acts with respect to an environment-specified context
+(corresponding to the navigation target location associated with the task).
+Therefore, in `context/configs/*` you will find both specifications for task setup
+as well as goal configurations. Most remaining hyperparameters used for
+training/evaluation may be found in `configs/*`.
+
+NOTE: Not all the code corresponding to the "Near-Optimal" paper is included.
+Namely, changes to low-level policy training proposed in the paper (discounting
+and auxiliary rewards) are not implemented here. Performance should not change
+significantly.
+
+
+Maintained by Ofir Nachum (ofirnachum).
diff --git a/models/research/efficient-hrl/agent.py b/models/research/efficient-hrl/agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..0028ddffa0d37a0e80d2c990e6263a3d9b4ab948
--- /dev/null
+++ b/models/research/efficient-hrl/agent.py
@@ -0,0 +1,774 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A UVF agent.
+"""
+
+import tensorflow as tf
+import gin.tf
+from agents import ddpg_agent
+# pylint: disable=unused-import
+import cond_fn
+from utils import utils as uvf_utils
+from context import gin_imports
+# pylint: enable=unused-import
+slim = tf.contrib.slim
+
+
+@gin.configurable
+class UvfAgentCore(object):
+ """Defines basic functions for UVF agent. Must be inherited with an RL agent.
+
+ Used as lower-level agent.
+ """
+
+ def __init__(self,
+ observation_spec,
+ action_spec,
+ tf_env,
+ tf_context,
+ step_cond_fn=cond_fn.env_transition,
+ reset_episode_cond_fn=cond_fn.env_restart,
+ reset_env_cond_fn=cond_fn.false_fn,
+ metrics=None,
+ **base_agent_kwargs):
+ """Constructs a UVF agent.
+
+ Args:
+ observation_spec: A TensorSpec defining the observations.
+ action_spec: A BoundedTensorSpec defining the actions.
+ tf_env: A Tensorflow environment object.
+ tf_context: A Context class.
+ step_cond_fn: A function indicating whether to increment the num of steps.
+ reset_episode_cond_fn: A function indicating whether to restart the
+ episode, resampling the context.
+ reset_env_cond_fn: A function indicating whether to perform a manual reset
+ of the environment.
+ metrics: A list of functions that evaluate metrics of the agent.
+ **base_agent_kwargs: A dictionary of parameters for base RL Agent.
+ Raises:
+ ValueError: If 'dqda_clipping' is < 0.
+ """
+ self._step_cond_fn = step_cond_fn
+ self._reset_episode_cond_fn = reset_episode_cond_fn
+ self._reset_env_cond_fn = reset_env_cond_fn
+ self.metrics = metrics
+
+ # expose tf_context methods
+ self.tf_context = tf_context(tf_env=tf_env)
+ self.set_replay = self.tf_context.set_replay
+ self.sample_contexts = self.tf_context.sample_contexts
+ self.compute_rewards = self.tf_context.compute_rewards
+ self.gamma_index = self.tf_context.gamma_index
+ self.context_specs = self.tf_context.context_specs
+ self.context_as_action_specs = self.tf_context.context_as_action_specs
+ self.init_context_vars = self.tf_context.create_vars
+
+ self.env_observation_spec = observation_spec[0]
+ merged_observation_spec = (uvf_utils.merge_specs(
+ (self.env_observation_spec,) + self.context_specs),)
+ self._context_vars = dict()
+ self._action_vars = dict()
+
+ self.BASE_AGENT_CLASS.__init__(
+ self,
+ observation_spec=merged_observation_spec,
+ action_spec=action_spec,
+ **base_agent_kwargs
+ )
+
+ def set_meta_agent(self, agent=None):
+ self._meta_agent = agent
+
+ @property
+ def meta_agent(self):
+ return self._meta_agent
+
+ def actor_loss(self, states, actions, rewards, discounts,
+ next_states):
+ """Returns the next action for the state.
+
+ Args:
+ state: A [num_state_dims] tensor representing a state.
+ context: A list of [num_context_dims] tensor representing a context.
+ Returns:
+ A [num_action_dims] tensor representing the action.
+ """
+ return self.BASE_AGENT_CLASS.actor_loss(self, states)
+
+ def action(self, state, context=None):
+ """Returns the next action for the state.
+
+ Args:
+ state: A [num_state_dims] tensor representing a state.
+ context: A list of [num_context_dims] tensor representing a context.
+ Returns:
+ A [num_action_dims] tensor representing the action.
+ """
+ merged_state = self.merged_state(state, context)
+ return self.BASE_AGENT_CLASS.action(self, merged_state)
+
+ def actions(self, state, context=None):
+ """Returns the next action for the state.
+
+ Args:
+ state: A [-1, num_state_dims] tensor representing a state.
+ context: A list of [-1, num_context_dims] tensor representing a context.
+ Returns:
+ A [-1, num_action_dims] tensor representing the action.
+ """
+ merged_states = self.merged_states(state, context)
+ return self.BASE_AGENT_CLASS.actor_net(self, merged_states)
+
+ def log_probs(self, states, actions, state_reprs, contexts=None):
+ assert contexts is not None
+ batch_dims = [tf.shape(states)[0], tf.shape(states)[1]]
+ contexts = self.tf_context.context_multi_transition_fn(
+ contexts, states=tf.to_float(state_reprs))
+
+ flat_states = tf.reshape(states,
+ [batch_dims[0] * batch_dims[1], states.shape[-1]])
+ flat_contexts = [tf.reshape(tf.cast(context, states.dtype),
+ [batch_dims[0] * batch_dims[1], context.shape[-1]])
+ for context in contexts]
+ flat_pred_actions = self.actions(flat_states, flat_contexts)
+ pred_actions = tf.reshape(flat_pred_actions,
+ batch_dims + [flat_pred_actions.shape[-1]])
+
+ error = tf.square(actions - pred_actions)
+ spec_range = (self._action_spec.maximum - self._action_spec.minimum) / 2
+ normalized_error = tf.cast(error, tf.float64) / tf.constant(spec_range) ** 2
+ return -normalized_error
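+
+  # Note: log_probs is a surrogate log-likelihood rather than an exact density:
+  # it returns the negative squared error between the actions actually taken and
+  # the actions the current low-level policy would output, rescaled by half the
+  # action range squared so the penalty is comparable across action dimensions.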
+
+ @gin.configurable('uvf_add_noise_fn')
+ def add_noise_fn(self, action_fn, stddev=1.0, debug=False,
+ clip=True, global_step=None):
+ """Returns the action_fn with additive Gaussian noise.
+
+ Args:
+ action_fn: A callable(`state`, `context`) which returns a
+ [num_action_dims] tensor representing a action.
+      stddev: Standard deviation of the additive Gaussian noise.
+      debug: Print debug messages.
+      clip: Whether to clip the noisy action to the action spec.
+      global_step: Optional global step used to decay the noise scale over time.
+    Returns:
+      A callable with the same signature as `action_fn` that returns noisy
+      (and, if `clip` is True, spec-clipped) actions.
+ """
+ if global_step is not None:
+ stddev *= tf.maximum( # Decay exploration during training.
+ tf.train.exponential_decay(1.0, global_step, 1e6, 0.8), 0.5)
+ def noisy_action_fn(state, context=None):
+ """Noisy action fn."""
+ action = action_fn(state, context)
+ if debug:
+ action = uvf_utils.tf_print(
+ action, [action],
+ message='[add_noise_fn] pre-noise action',
+ first_n=100)
+ noise_dist = tf.distributions.Normal(tf.zeros_like(action),
+ tf.ones_like(action) * stddev)
+ noise = noise_dist.sample()
+ action += noise
+ if debug:
+ action = uvf_utils.tf_print(
+ action, [action],
+ message='[add_noise_fn] post-noise action',
+ first_n=100)
+ if clip:
+ action = uvf_utils.clip_to_spec(action, self._action_spec)
+ return action
+ return noisy_action_fn
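+
+  # With the default schedule above, the effective noise multiplier is
+  # max(0.8 ** (global_step / 1e6), 0.5): 1.0 at step 0, 0.8 after 1e6 steps,
+  # and never less than half the configured stddev.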
+
+ def merged_state(self, state, context=None):
+ """Returns the merged state from the environment state and contexts.
+
+ Args:
+ state: A [num_state_dims] tensor representing a state.
+ context: A list of [num_context_dims] tensor representing a context.
+ If None, use the internal context.
+ Returns:
+ A [num_merged_state_dims] tensor representing the merged state.
+ """
+ if context is None:
+ context = list(self.context_vars)
+ state = tf.concat([state,] + context, axis=-1)
+ self._validate_states(self._batch_state(state))
+ return state
+
+ def merged_states(self, states, contexts=None):
+ """Returns the batch merged state from the batch env state and contexts.
+
+ Args:
+ states: A [batch_size, num_state_dims] tensor representing a batch
+ of states.
+ contexts: A list of [batch_size, num_context_dims] tensor
+ representing a batch of contexts. If None,
+ use the internal context.
+ Returns:
+ A [batch_size, num_merged_state_dims] tensor representing the batch
+ of merged states.
+ """
+ if contexts is None:
+ contexts = [tf.tile(tf.expand_dims(context, axis=0),
+ (tf.shape(states)[0], 1)) for
+ context in self.context_vars]
+ states = tf.concat([states,] + contexts, axis=-1)
+ self._validate_states(states)
+ return states
+
+ def unmerged_states(self, merged_states):
+ """Returns the batch state and contexts from the batch merged state.
+
+ Args:
+ merged_states: A [batch_size, num_merged_state_dims] tensor
+ representing a batch of merged states.
+ Returns:
+ A [batch_size, num_state_dims] tensor and a list of
+ [batch_size, num_context_dims] tensors representing the batch state
+ and contexts respectively.
+ """
+ self._validate_states(merged_states)
+ num_state_dims = self.env_observation_spec.shape.as_list()[0]
+ num_context_dims_list = [c.shape.as_list()[0] for c in self.context_specs]
+ states = merged_states[:, :num_state_dims]
+ contexts = []
+ i = num_state_dims
+ for num_context_dims in num_context_dims_list:
+ contexts.append(merged_states[:, i: i+num_context_dims])
+ i += num_context_dims
+ return states, contexts
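+
+  # Illustrative example (hypothetical sizes): with a 29-dim environment state and
+  # a single 2-dim goal context, merged states have 31 dims, and unmerged_states
+  # returns merged_states[:, :29] plus a one-element list [merged_states[:, 29:31]].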
+
+ def sample_random_actions(self, batch_size=1):
+ """Return random actions.
+
+ Args:
+ batch_size: Batch size.
+ Returns:
+ A [batch_size, num_action_dims] tensor representing the batch of actions.
+ """
+ actions = tf.concat(
+ [
+ tf.random_uniform(
+ shape=(batch_size, 1),
+ minval=self._action_spec.minimum[i],
+ maxval=self._action_spec.maximum[i])
+ for i in range(self._action_spec.shape[0].value)
+ ],
+ axis=1)
+ return actions
+
+ def clip_actions(self, actions):
+ """Clip actions to spec.
+
+ Args:
+ actions: A [batch_size, num_action_dims] tensor representing
+ the batch of actions.
+ Returns:
+ A [batch_size, num_action_dims] tensor representing the batch
+ of clipped actions.
+ """
+ actions = tf.concat(
+ [
+ tf.clip_by_value(
+ actions[:, i:i+1],
+ self._action_spec.minimum[i],
+ self._action_spec.maximum[i])
+ for i in range(self._action_spec.shape[0].value)
+ ],
+ axis=1)
+ return actions
+
+ def mix_contexts(self, contexts, insert_contexts, indices):
+ """Mix two contexts based on indices.
+
+ Args:
+ contexts: A list of [batch_size, num_context_dims] tensor representing
+ the batch of contexts.
+ insert_contexts: A list of [batch_size, num_context_dims] tensor
+ representing the batch of contexts to be inserted.
+ indices: A list of a list of integers denoting indices to replace.
+ Returns:
+ A list of resulting contexts.
+ """
+ if indices is None: indices = [[]] * len(contexts)
+ assert len(contexts) == len(indices)
+ assert all([spec.shape.ndims == 1 for spec in self.context_specs])
+ mix_contexts = []
+ for contexts_, insert_contexts_, indices_, spec in zip(
+ contexts, insert_contexts, indices, self.context_specs):
+ mix_contexts.append(
+ tf.concat(
+ [
+ insert_contexts_[:, i:i + 1] if i in indices_ else
+ contexts_[:, i:i + 1] for i in range(spec.shape.as_list()[0])
+ ],
+ axis=1))
+ return mix_contexts
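+
+  # Illustrative example (hypothetical shapes): with a single context of width 3
+  # and indices=[[0, 2]], dimensions 0 and 2 of the result are taken from
+  # insert_contexts and dimension 1 is kept from the original contexts.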
+
+ def begin_episode_ops(self, mode, action_fn=None, state=None):
+ """Returns ops that reset agent at beginning of episodes.
+
+    Args:
+      mode: a string representing the mode=[train, explore, eval].
+      action_fn: an optional action function passed through to the context reset.
+      state: an optional state tensor passed through to the context reset.
+    Returns:
+      A list of ops.
+ """
+ all_ops = []
+ for _, action_var in sorted(self._action_vars.items()):
+ sample_action = self.sample_random_actions(1)[0]
+ all_ops.append(tf.assign(action_var, sample_action))
+ all_ops += self.tf_context.reset(mode=mode, agent=self._meta_agent,
+ action_fn=action_fn, state=state)
+ return all_ops
+
+ def cond_begin_episode_op(self, cond, input_vars, mode, meta_action_fn):
+ """Returns op that resets agent at beginning of episodes.
+
+    A new episode is begun if the cond op evaluates to `False`.
+
+ Args:
+ cond: a Boolean tensor variable.
+ input_vars: A list of tensor variables.
+      mode: a string representing the mode=[train, explore, eval].
+      meta_action_fn: an action function for the meta-agent, passed through to
+        the context step and reset ops.
+ Returns:
+ Conditional begin op.
+ """
+ (state, action, reward, next_state,
+ state_repr, next_state_repr) = input_vars
+ def continue_fn():
+ """Continue op fn."""
+ items = [state, action, reward, next_state,
+ state_repr, next_state_repr] + list(self.context_vars)
+ batch_items = [tf.expand_dims(item, 0) for item in items]
+ (states, actions, rewards, next_states,
+ state_reprs, next_state_reprs) = batch_items[:6]
+ context_reward = self.compute_rewards(
+ mode, state_reprs, actions, rewards, next_state_reprs,
+ batch_items[6:])[0][0]
+ context_reward = tf.cast(context_reward, dtype=reward.dtype)
+ if self.meta_agent is not None:
+ meta_action = tf.concat(self.context_vars, -1)
+ items = [state, meta_action, reward, next_state,
+ state_repr, next_state_repr] + list(self.meta_agent.context_vars)
+ batch_items = [tf.expand_dims(item, 0) for item in items]
+ (states, meta_actions, rewards, next_states,
+ state_reprs, next_state_reprs) = batch_items[:6]
+ meta_reward = self.meta_agent.compute_rewards(
+ mode, states, meta_actions, rewards,
+ next_states, batch_items[6:])[0][0]
+ meta_reward = tf.cast(meta_reward, dtype=reward.dtype)
+ else:
+ meta_reward = tf.constant(0, dtype=reward.dtype)
+
+ with tf.control_dependencies([context_reward, meta_reward]):
+ step_ops = self.tf_context.step(mode=mode, agent=self._meta_agent,
+ state=state,
+ next_state=next_state,
+ state_repr=state_repr,
+ next_state_repr=next_state_repr,
+ action_fn=meta_action_fn)
+ with tf.control_dependencies(step_ops):
+ context_reward, meta_reward = map(tf.identity, [context_reward, meta_reward])
+ return context_reward, meta_reward
+ def begin_episode_fn():
+ """Begin op fn."""
+ begin_ops = self.begin_episode_ops(mode=mode, action_fn=meta_action_fn, state=state)
+ with tf.control_dependencies(begin_ops):
+ return tf.zeros_like(reward), tf.zeros_like(reward)
+ with tf.control_dependencies(input_vars):
+ cond_begin_episode_op = tf.cond(cond, continue_fn, begin_episode_fn)
+ return cond_begin_episode_op
+
+ def get_env_base_wrapper(self, env_base, **begin_kwargs):
+ """Create a wrapper around env_base, with agent-specific begin/end_episode.
+
+ Args:
+ env_base: A python environment base.
+ **begin_kwargs: Keyword args for begin_episode_ops.
+ Returns:
+ An object with begin_episode() and end_episode().
+ """
+ begin_ops = self.begin_episode_ops(**begin_kwargs)
+ return uvf_utils.get_contextual_env_base(env_base, begin_ops)
+
+ def init_action_vars(self, name, i=None):
+ """Create and return a tensorflow Variable holding an action.
+
+ Args:
+ name: Name of the variables.
+ i: Integer id.
+ Returns:
+      A [num_action_dims] tf.Variable holding the sampled action.
+ """
+ if i is not None:
+ name += '_%d' % i
+ assert name not in self._action_vars, ('Conflict! %s is already '
+ 'initialized.') % name
+ self._action_vars[name] = tf.Variable(
+ self.sample_random_actions(1)[0], name='%s_action' % (name))
+ self._validate_actions(tf.expand_dims(self._action_vars[name], 0))
+ return self._action_vars[name]
+
+ @gin.configurable('uvf_critic_function')
+ def critic_function(self, critic_vals, states, critic_fn=None):
+ """Computes q values based on outputs from the critic net.
+
+ Args:
+ critic_vals: A tf.float32 [batch_size, ...] tensor representing outputs
+ from the critic net.
+ states: A [batch_size, num_state_dims] tensor representing a batch
+ of states.
+      critic_fn: A callable that processes outputs from critic_net and
+ outputs a [batch_size] tensor representing q values.
+ Returns:
+ A tf.float32 [batch_size] tensor representing q values.
+ """
+ if critic_fn is not None:
+ env_states, contexts = self.unmerged_states(states)
+ critic_vals = critic_fn(critic_vals, env_states, contexts)
+ critic_vals.shape.assert_has_rank(1)
+ return critic_vals
+
+ def get_action_vars(self, key):
+ return self._action_vars[key]
+
+ def get_context_vars(self, key):
+ return self.tf_context.context_vars[key]
+
+ def step_cond_fn(self, *args):
+ return self._step_cond_fn(self, *args)
+
+ def reset_episode_cond_fn(self, *args):
+ return self._reset_episode_cond_fn(self, *args)
+
+ def reset_env_cond_fn(self, *args):
+ return self._reset_env_cond_fn(self, *args)
+
+ @property
+ def context_vars(self):
+ return self.tf_context.vars
+
+
+@gin.configurable
+class MetaAgentCore(UvfAgentCore):
+ """Defines basic functions for UVF Meta-agent. Must be inherited with an RL agent.
+
+ Used as higher-level agent.
+ """
+
+ def __init__(self,
+ observation_spec,
+ action_spec,
+ tf_env,
+ tf_context,
+ sub_context,
+ step_cond_fn=cond_fn.env_transition,
+ reset_episode_cond_fn=cond_fn.env_restart,
+ reset_env_cond_fn=cond_fn.false_fn,
+ metrics=None,
+ actions_reg=0.,
+ k=2,
+ **base_agent_kwargs):
+ """Constructs a Meta agent.
+
+ Args:
+ observation_spec: A TensorSpec defining the observations.
+ action_spec: A BoundedTensorSpec defining the actions.
+ tf_env: A Tensorflow environment object.
+      tf_context: A Context class for the meta-agent.
+      sub_context: A Context class for the lower-level agent; its
+        context-as-action specs define this meta-agent's action space.
+      step_cond_fn: A function indicating whether to increment the number of steps.
+ reset_episode_cond_fn: A function indicating whether to restart the
+ episode, resampling the context.
+ reset_env_cond_fn: A function indicating whether to perform a manual reset
+ of the environment.
+      metrics: A list of functions that evaluate metrics of the agent.
+      actions_reg: A scalar coefficient penalizing the magnitude of meta-action
+        dimensions beyond the first `k` (see actor_loss).
+      k: Number of leading meta-action dimensions exempt from the regularizer.
+      **base_agent_kwargs: A dictionary of parameters for base RL Agent.
+ Raises:
+ ValueError: If 'dqda_clipping' is < 0.
+ """
+ self._step_cond_fn = step_cond_fn
+ self._reset_episode_cond_fn = reset_episode_cond_fn
+ self._reset_env_cond_fn = reset_env_cond_fn
+ self.metrics = metrics
+ self._actions_reg = actions_reg
+ self._k = k
+
+ # expose tf_context methods
+ self.tf_context = tf_context(tf_env=tf_env)
+ self.sub_context = sub_context(tf_env=tf_env)
+ self.set_replay = self.tf_context.set_replay
+ self.sample_contexts = self.tf_context.sample_contexts
+ self.compute_rewards = self.tf_context.compute_rewards
+ self.gamma_index = self.tf_context.gamma_index
+ self.context_specs = self.tf_context.context_specs
+ self.context_as_action_specs = self.tf_context.context_as_action_specs
+ self.sub_context_as_action_specs = self.sub_context.context_as_action_specs
+ self.init_context_vars = self.tf_context.create_vars
+
+ self.env_observation_spec = observation_spec[0]
+ merged_observation_spec = (uvf_utils.merge_specs(
+ (self.env_observation_spec,) + self.context_specs),)
+ self._context_vars = dict()
+ self._action_vars = dict()
+
+ assert len(self.context_as_action_specs) == 1
+ self.BASE_AGENT_CLASS.__init__(
+ self,
+ observation_spec=merged_observation_spec,
+ action_spec=self.sub_context_as_action_specs,
+ **base_agent_kwargs
+ )
+
+ @gin.configurable('meta_add_noise_fn')
+ def add_noise_fn(self, action_fn, stddev=1.0, debug=False,
+ global_step=None):
+ noisy_action_fn = super(MetaAgentCore, self).add_noise_fn(
+ action_fn, stddev,
+ clip=True, global_step=global_step)
+ return noisy_action_fn
+
+ def actor_loss(self, states, actions, rewards, discounts,
+ next_states):
+ """Returns the next action for the state.
+
+ Args:
+ state: A [num_state_dims] tensor representing a state.
+ context: A list of [num_context_dims] tensor representing a context.
+ Returns:
+ A [num_action_dims] tensor representing the action.
+ """
+ actions = self.actor_net(states, stop_gradients=False)
+ regularizer = self._actions_reg * tf.reduce_mean(
+ tf.reduce_sum(tf.abs(actions[:, self._k:]), -1), 0)
+ loss = self.BASE_AGENT_CLASS.actor_loss(self, states)
+ return regularizer + loss
+
+
+@gin.configurable
+class UvfAgent(UvfAgentCore, ddpg_agent.TD3Agent):
+ """A DDPG agent with UVF.
+ """
+ BASE_AGENT_CLASS = ddpg_agent.TD3Agent
+ ACTION_TYPE = 'continuous'
+
+ def __init__(self, *args, **kwargs):
+ UvfAgentCore.__init__(self, *args, **kwargs)
+
+
+@gin.configurable
+class MetaAgent(MetaAgentCore, ddpg_agent.TD3Agent):
+ """A DDPG meta-agent.
+ """
+ BASE_AGENT_CLASS = ddpg_agent.TD3Agent
+ ACTION_TYPE = 'continuous'
+
+ def __init__(self, *args, **kwargs):
+ MetaAgentCore.__init__(self, *args, **kwargs)
+
+
+@gin.configurable()
+def state_preprocess_net(
+ states,
+ num_output_dims=2,
+ states_hidden_layers=(100,),
+ normalizer_fn=None,
+ activation_fn=tf.nn.relu,
+ zero_time=True,
+ images=False):
+ """Creates a simple feed forward net for embedding states.
+ """
+ with slim.arg_scope(
+ [slim.fully_connected],
+ activation_fn=activation_fn,
+ normalizer_fn=normalizer_fn,
+ weights_initializer=slim.variance_scaling_initializer(
+ factor=1.0/3.0, mode='FAN_IN', uniform=True)):
+
+ states_shape = tf.shape(states)
+ states_dtype = states.dtype
+ states = tf.to_float(states)
+ if images: # Zero-out x-y
+ states *= tf.constant([0.] * 2 + [1.] * (states.shape[-1] - 2), dtype=states.dtype)
+ if zero_time:
+ states *= tf.constant([1.] * (states.shape[-1] - 1) + [0.], dtype=states.dtype)
+ orig_states = states
+ embed = states
+ if states_hidden_layers:
+ embed = slim.stack(embed, slim.fully_connected, states_hidden_layers,
+ scope='states')
+
+ with slim.arg_scope([slim.fully_connected],
+ weights_regularizer=None,
+ weights_initializer=tf.random_uniform_initializer(
+ minval=-0.003, maxval=0.003)):
+ embed = slim.fully_connected(embed, num_output_dims,
+ activation_fn=None,
+ normalizer_fn=None,
+ scope='value')
+
+ output = embed
+ output = tf.cast(output, states_dtype)
+ return output
+
+
+@gin.configurable()
+def action_embed_net(
+ actions,
+ states=None,
+ num_output_dims=2,
+ hidden_layers=(400, 300),
+ normalizer_fn=None,
+ activation_fn=tf.nn.relu,
+ zero_time=True,
+ images=False):
+ """Creates a simple feed forward net for embedding actions.
+ """
+ with slim.arg_scope(
+ [slim.fully_connected],
+ activation_fn=activation_fn,
+ normalizer_fn=normalizer_fn,
+ weights_initializer=slim.variance_scaling_initializer(
+ factor=1.0/3.0, mode='FAN_IN', uniform=True)):
+
+ actions = tf.to_float(actions)
+ if states is not None:
+ if images: # Zero-out x-y
+ states *= tf.constant([0.] * 2 + [1.] * (states.shape[-1] - 2), dtype=states.dtype)
+ if zero_time:
+ states *= tf.constant([1.] * (states.shape[-1] - 1) + [0.], dtype=states.dtype)
+ actions = tf.concat([actions, tf.to_float(states)], -1)
+
+ embed = actions
+ if hidden_layers:
+ embed = slim.stack(embed, slim.fully_connected, hidden_layers,
+ scope='hidden')
+
+ with slim.arg_scope([slim.fully_connected],
+ weights_regularizer=None,
+ weights_initializer=tf.random_uniform_initializer(
+ minval=-0.003, maxval=0.003)):
+ embed = slim.fully_connected(embed, num_output_dims,
+ activation_fn=None,
+ normalizer_fn=None,
+ scope='value')
+ if num_output_dims == 1:
+ return embed[:, 0, ...]
+ else:
+ return embed
+
+
+def huber(x, kappa=0.1):
+ return (0.5 * tf.square(x) * tf.to_float(tf.abs(x) <= kappa) +
+ kappa * (tf.abs(x) - 0.5 * kappa) * tf.to_float(tf.abs(x) > kappa)
+ ) / kappa
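+
+# A quick numerical check with the default kappa=0.1: huber(0.05) = 0.5 * 0.05**2
+# / 0.1 = 0.0125 (quadratic region), while huber(0.3) = 0.1 * (0.3 - 0.05) / 0.1
+# = 0.25 (linear region).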
+
+
+@gin.configurable()
+class StatePreprocess(object):
+ STATE_PREPROCESS_NET_SCOPE = 'state_process_net'
+ ACTION_EMBED_NET_SCOPE = 'action_embed_net'
+
+ def __init__(self, trainable=False,
+ state_preprocess_net=lambda states: states,
+ action_embed_net=lambda actions, *args, **kwargs: actions,
+ ndims=None):
+ self.trainable = trainable
+ self._scope = tf.get_variable_scope().name
+ self._ndims = ndims
+ self._state_preprocess_net = tf.make_template(
+ self.STATE_PREPROCESS_NET_SCOPE, state_preprocess_net,
+ create_scope_now_=True)
+ self._action_embed_net = tf.make_template(
+ self.ACTION_EMBED_NET_SCOPE, action_embed_net,
+ create_scope_now_=True)
+
+ def __call__(self, states):
+ batched = states.get_shape().ndims != 1
+ if not batched:
+ states = tf.expand_dims(states, 0)
+ embedded = self._state_preprocess_net(states)
+ if self._ndims is not None:
+ embedded = embedded[..., :self._ndims]
+ if not batched:
+ return embedded[0]
+ return embedded
+
+ def loss(self, states, next_states, low_actions, low_states):
+ batch_size = tf.shape(states)[0]
+ d = int(low_states.shape[1])
+ # Sample indices into meta-transition to train on.
+ probs = 0.99 ** tf.range(d, dtype=tf.float32)
+ probs *= tf.constant([1.0] * (d - 1) + [1.0 / (1 - 0.99)],
+ dtype=tf.float32)
+ probs /= tf.reduce_sum(probs)
+ index_dist = tf.distributions.Categorical(probs=probs, dtype=tf.int64)
+ indices = index_dist.sample(batch_size)
+ batch_size = tf.cast(batch_size, tf.int64)
+ next_indices = tf.concat(
+ [tf.range(batch_size, dtype=tf.int64)[:, None],
+ (1 + indices[:, None]) % d], -1)
+ new_next_states = tf.where(indices < d - 1,
+ tf.gather_nd(low_states, next_indices),
+ next_states)
+ next_states = new_next_states
+
+ embed1 = tf.to_float(self._state_preprocess_net(states))
+ embed2 = tf.to_float(self._state_preprocess_net(next_states))
+ action_embed = self._action_embed_net(
+ tf.layers.flatten(low_actions), states=states)
+
+ tau = 2.0
+ fn = lambda z: tau * tf.reduce_sum(huber(z), -1)
+ all_embed = tf.get_variable('all_embed', [1024, int(embed1.shape[-1])],
+ initializer=tf.zeros_initializer())
+ upd = all_embed.assign(tf.concat([all_embed[batch_size:], embed2], 0))
+ with tf.control_dependencies([upd]):
+ close = 1 * tf.reduce_mean(fn(embed1 + action_embed - embed2))
+ prior_log_probs = tf.reduce_logsumexp(
+ -fn((embed1 + action_embed)[:, None, :] - all_embed[None, :, :]),
+ axis=-1) - tf.log(tf.to_float(all_embed.shape[0]))
+ far = tf.reduce_mean(tf.exp(-fn((embed1 + action_embed)[1:] - embed2[:-1])
+ - tf.stop_gradient(prior_log_probs[1:])))
+ repr_log_probs = tf.stop_gradient(
+ -fn(embed1 + action_embed - embed2) - prior_log_probs) / tau
+ return close + far, repr_log_probs, indices
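+
+  # Informal summary: the `close` term pulls embed(s) + embed(a) toward embed(s'),
+  # while the `far` term is a contrastive penalty that pushes the prediction away
+  # from mismatched next-state embeddings, normalized by a prior estimated from the
+  # `all_embed` buffer of recently seen embeddings.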
+
+ def get_trainable_vars(self):
+ return (
+ slim.get_trainable_variables(
+ uvf_utils.join_scope(self._scope, self.STATE_PREPROCESS_NET_SCOPE)) +
+ slim.get_trainable_variables(
+ uvf_utils.join_scope(self._scope, self.ACTION_EMBED_NET_SCOPE)))
+
+
+@gin.configurable()
+class InverseDynamics(object):
+ INVERSE_DYNAMICS_NET_SCOPE = 'inverse_dynamics'
+
+ def __init__(self, spec):
+ self._spec = spec
+
+ def sample(self, states, next_states, num_samples, orig_goals, sc=0.5):
+ goal_dim = orig_goals.shape[-1]
+ spec_range = (self._spec.maximum - self._spec.minimum) / 2 * tf.ones([goal_dim])
+ loc = tf.cast(next_states - states, tf.float32)[:, :goal_dim]
+ scale = sc * tf.tile(tf.reshape(spec_range, [1, goal_dim]),
+ [tf.shape(states)[0], 1])
+ dist = tf.distributions.Normal(loc, scale)
+ if num_samples == 1:
+ return dist.sample()
+ samples = tf.concat([dist.sample(num_samples - 2),
+ tf.expand_dims(loc, 0),
+ tf.expand_dims(orig_goals, 0)], 0)
+ return uvf_utils.clip_to_spec(samples, self._spec)
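+
+  # Informal summary: candidate goals are drawn from a Normal centered on the
+  # observed state change (first goal_dim dims), with scale sc times half the spec
+  # range; in the multi-sample case the mean and the original goals are appended as
+  # extra candidates before clipping everything to the goal spec.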
diff --git a/models/research/efficient-hrl/agents/__init__.py b/models/research/efficient-hrl/agents/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/models/research/efficient-hrl/agents/__init__.py
@@ -0,0 +1 @@
+
diff --git a/models/research/efficient-hrl/agents/circular_buffer.py b/models/research/efficient-hrl/agents/circular_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..72f90f0de89bf99956436e54a84bbcba903df6e7
--- /dev/null
+++ b/models/research/efficient-hrl/agents/circular_buffer.py
@@ -0,0 +1,289 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A circular buffer where each element is a list of tensors.
+
+Each element of the buffer is a list of tensors. An example use case is a replay
+buffer in reinforcement learning, where each element is a list of tensors
+representing the state, action, reward etc.
+
+New elements are added sequentially, and once the buffer is full, we
+start overwriting them in a circular fashion. Reading does not remove any
+elements, only adding new elements does.
+"""
+
+import collections
+import numpy as np
+import tensorflow as tf
+
+import gin.tf
+
+
+@gin.configurable
+class CircularBuffer(object):
+ """A circular buffer where each element is a list of tensors."""
+
+ def __init__(self, buffer_size=1000, scope='replay_buffer'):
+ """Circular buffer of list of tensors.
+
+ Args:
+ buffer_size: (integer) maximum number of tensor lists the buffer can hold.
+ scope: (string) variable scope for creating the variables.
+ """
+ self._buffer_size = np.int64(buffer_size)
+ self._scope = scope
+ self._tensors = collections.OrderedDict()
+ with tf.variable_scope(self._scope):
+ self._num_adds = tf.Variable(0, dtype=tf.int64, name='num_adds')
+ self._num_adds_cs = tf.CriticalSection(name='num_adds')
+
+ @property
+ def buffer_size(self):
+ return self._buffer_size
+
+ @property
+ def scope(self):
+ return self._scope
+
+ @property
+ def num_adds(self):
+ return self._num_adds
+
+ def _create_variables(self, tensors):
+ with tf.variable_scope(self._scope):
+ for name in tensors.keys():
+ tensor = tensors[name]
+ self._tensors[name] = tf.get_variable(
+ name='BufferVariable_' + name,
+ shape=[self._buffer_size] + tensor.get_shape().as_list(),
+ dtype=tensor.dtype,
+ trainable=False)
+
+ def _validate(self, tensors):
+ """Validate shapes of tensors."""
+ if len(tensors) != len(self._tensors):
+ raise ValueError('Expected tensors to have %d elements. Received %d '
+ 'instead.' % (len(self._tensors), len(tensors)))
+ if self._tensors.keys() != tensors.keys():
+      raise ValueError('The keys of tensors should always be the same. '
+                       'Received %s instead of %s.' %
+ (tensors.keys(), self._tensors.keys()))
+ for name, tensor in tensors.items():
+ if tensor.get_shape().as_list() != self._tensors[
+ name].get_shape().as_list()[1:]:
+ raise ValueError('Tensor %s has incorrect shape.' % name)
+ if not tensor.dtype.is_compatible_with(self._tensors[name].dtype):
+ raise ValueError(
+ 'Tensor %s has incorrect data type. Expected %s, received %s' %
+ (name, self._tensors[name].read_value().dtype, tensor.dtype))
+
+ def add(self, tensors):
+ """Adds an element (list/tuple/dict of tensors) to the buffer.
+
+ Args:
+ tensors: (list/tuple/dict of tensors) to be added to the buffer.
+ Returns:
+ An add operation that adds the input `tensors` to the buffer. Similar to
+ an enqueue_op.
+ Raises:
+ ValueError: If the shapes and data types of input `tensors' are not the
+ same across calls to the add function.
+ """
+ return self.maybe_add(tensors, True)
+
+ def maybe_add(self, tensors, condition):
+ """Adds an element (tensors) to the buffer based on the condition..
+
+ Args:
+ tensors: (list/tuple of tensors) to be added to the buffer.
+ condition: A boolean Tensor controlling whether the tensors would be added
+ to the buffer or not.
+ Returns:
+ An add operation that adds the input `tensors` to the buffer. Similar to
+      a maybe_enqueue_op.
+ Raises:
+ ValueError: If the shapes and data types of input `tensors' are not the
+ same across calls to the add function.
+ """
+ if not isinstance(tensors, dict):
+ names = [str(i) for i in range(len(tensors))]
+ tensors = collections.OrderedDict(zip(names, tensors))
+ if not isinstance(tensors, collections.OrderedDict):
+ tensors = collections.OrderedDict(
+ sorted(tensors.items(), key=lambda t: t[0]))
+ if not self._tensors:
+ self._create_variables(tensors)
+ else:
+ self._validate(tensors)
+
+ #@tf.critical_section(self._position_mutex)
+ def _increment_num_adds():
+ # Adding 0 to the num_adds variable is a trick to read the value of the
+ # variable and return a read-only tensor. Doing this in a critical
+ # section allows us to capture a snapshot of the variable that will
+ # not be affected by other threads updating num_adds.
+ return self._num_adds.assign_add(1) + 0
+ def _add():
+ num_adds_inc = self._num_adds_cs.execute(_increment_num_adds)
+ current_pos = tf.mod(num_adds_inc - 1, self._buffer_size)
+ update_ops = []
+ for name in self._tensors.keys():
+ update_ops.append(
+ tf.scatter_update(self._tensors[name], current_pos, tensors[name]))
+ return tf.group(*update_ops)
+
+ return tf.contrib.framework.smart_cond(condition, _add, tf.no_op)
+
+ def get_random_batch(self, batch_size, keys=None, num_steps=1):
+ """Samples a batch of tensors from the buffer with replacement.
+
+ Args:
+ batch_size: (integer) number of elements to sample.
+ keys: List of keys of tensors to retrieve. If None retrieve all.
+ num_steps: (integer) length of trajectories to return. If > 1 will return
+ a list of lists, where each internal list represents a trajectory of
+ length num_steps.
+ Returns:
+ A list of tensors, where each element in the list is a batch sampled from
+ one of the tensors in the buffer.
+ Raises:
+ ValueError: If get_random_batch is called before calling the add function.
+ tf.errors.InvalidArgumentError: If this operation is executed before any
+ items are added to the buffer.
+ """
+ if not self._tensors:
+ raise ValueError('The add function must be called before get_random_batch.')
+ if keys is None:
+ keys = self._tensors.keys()
+
+ latest_start_index = self.get_num_adds() - num_steps + 1
+ empty_buffer_assert = tf.Assert(
+ tf.greater(latest_start_index, 0),
+ ['Not enough elements have been added to the buffer.'])
+ with tf.control_dependencies([empty_buffer_assert]):
+ max_index = tf.minimum(self._buffer_size, latest_start_index)
+ indices = tf.random_uniform(
+ [batch_size],
+ minval=0,
+ maxval=max_index,
+ dtype=tf.int64)
+ if num_steps == 1:
+ return self.gather(indices, keys)
+ else:
+ return self.gather_nstep(num_steps, indices, keys)
+
+ def gather(self, indices, keys=None):
+ """Returns elements at the specified indices from the buffer.
+
+ Args:
+ indices: (list of integers or rank 1 int Tensor) indices in the buffer to
+ retrieve elements from.
+ keys: List of keys of tensors to retrieve. If None retrieve all.
+ Returns:
+ A list of tensors, where each element in the list is obtained by indexing
+ one of the tensors in the buffer.
+ Raises:
+ ValueError: If gather is called before calling the add function.
+ tf.errors.InvalidArgumentError: If indices are bigger than the number of
+ items in the buffer.
+ """
+ if not self._tensors:
+ raise ValueError('The add function must be called before calling gather.')
+ if keys is None:
+ keys = self._tensors.keys()
+ with tf.name_scope('Gather'):
+ index_bound_assert = tf.Assert(
+ tf.less(
+ tf.to_int64(tf.reduce_max(indices)),
+ tf.minimum(self.get_num_adds(), self._buffer_size)),
+ ['Index out of bounds.'])
+ with tf.control_dependencies([index_bound_assert]):
+ indices = tf.convert_to_tensor(indices)
+
+ batch = []
+ for key in keys:
+ batch.append(tf.gather(self._tensors[key], indices, name=key))
+ return batch
+
+ def gather_nstep(self, num_steps, indices, keys=None):
+ """Returns elements at the specified indices from the buffer.
+
+ Args:
+ num_steps: (integer) length of trajectories to return.
+      indices: (rank-1 int Tensor) start indices in the buffer; for each start
+        index, a trajectory of `num_steps` consecutive elements (modulo the
+        buffer size) is retrieved.
+ keys: List of keys of tensors to retrieve. If None retrieve all.
+ Returns:
+ A list of list-of-tensors, where each element in the list is obtained by
+ indexing one of the tensors in the buffer.
+ Raises:
+ ValueError: If gather is called before calling the add function.
+ tf.errors.InvalidArgumentError: If indices are bigger than the number of
+ items in the buffer.
+ """
+ if not self._tensors:
+ raise ValueError('The add function must be called before calling gather.')
+ if keys is None:
+ keys = self._tensors.keys()
+ with tf.name_scope('Gather'):
+ index_bound_assert = tf.Assert(
+ tf.less_equal(
+ tf.to_int64(tf.reduce_max(indices) + num_steps),
+ self.get_num_adds()),
+ ['Trajectory indices go out of bounds.'])
+ with tf.control_dependencies([index_bound_assert]):
+ indices = tf.map_fn(
+ lambda x: tf.mod(tf.range(x, x + num_steps), self._buffer_size),
+ indices,
+ dtype=tf.int64)
+
+ batch = []
+ for key in keys:
+
+ def SampleTrajectories(trajectory_indices, key=key,
+ num_steps=num_steps):
+ trajectory_indices.set_shape([num_steps])
+ return tf.gather(self._tensors[key], trajectory_indices, name=key)
+
+ batch.append(tf.map_fn(SampleTrajectories, indices,
+ dtype=self._tensors[key].dtype))
+ return batch
+
+ def get_position(self):
+ """Returns the position at which the last element was added.
+
+ Returns:
+ An int tensor representing the index at which the last element was added
+ to the buffer or -1 if no elements were added.
+ """
+ return tf.cond(self.get_num_adds() < 1,
+ lambda: self.get_num_adds() - 1,
+ lambda: tf.mod(self.get_num_adds() - 1, self._buffer_size))
+
+ def get_num_adds(self):
+ """Returns the number of additions to the buffer.
+
+ Returns:
+ An int tensor representing the number of elements that were added.
+ """
+ def num_adds():
+ return self._num_adds.value()
+
+ return self._num_adds_cs.execute(num_adds)
+
+ def get_num_tensors(self):
+ """Returns the number of tensors (slots) in the buffer."""
+ return len(self._tensors)
diff --git a/models/research/efficient-hrl/agents/ddpg_agent.py b/models/research/efficient-hrl/agents/ddpg_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..904eb6502717e30681d5a3d50437c31b2aa580b2
--- /dev/null
+++ b/models/research/efficient-hrl/agents/ddpg_agent.py
@@ -0,0 +1,739 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A DDPG/NAF agent.
+
+Implements the Deep Deterministic Policy Gradient (DDPG) algorithm from
+"Continuous control with deep reinforcement learning" - Lilicrap et al.
+https://arxiv.org/abs/1509.02971, and the Normalized Advantage Functions (NAF)
+algorithm "Continuous Deep Q-Learning with Model-based Acceleration" - Gu et al.
+https://arxiv.org/pdf/1603.00748.
+"""
+
+import tensorflow as tf
+slim = tf.contrib.slim
+import gin.tf
+from utils import utils
+from agents import ddpg_networks as networks
+
+
+@gin.configurable
+class DdpgAgent(object):
+ """An RL agent that learns using the DDPG algorithm.
+
+ Example usage:
+
+ def critic_net(states, actions):
+ ...
+ def actor_net(states, num_action_dims):
+ ...
+
+ Given a tensorflow environment tf_env,
+ (of type learning.deepmind.rl.environments.tensorflow.python.tfpyenvironment)
+
+ obs_spec = tf_env.observation_spec()
+ action_spec = tf_env.action_spec()
+
+ ddpg_agent = agent.DdpgAgent(obs_spec,
+ action_spec,
+ actor_net=actor_net,
+ critic_net=critic_net)
+
+ we can perform actions on the environment as follows:
+
+ state = tf_env.observations()[0]
+ action = ddpg_agent.actor_net(tf.expand_dims(state, 0))[0, :]
+ transition_type, reward, discount = tf_env.step([action])
+
+ Train:
+
+ critic_loss = ddpg_agent.critic_loss(states, actions, rewards, discounts,
+ next_states)
+ actor_loss = ddpg_agent.actor_loss(states)
+
+ critic_train_op = slim.learning.create_train_op(
+ critic_loss,
+ critic_optimizer,
+ variables_to_train=ddpg_agent.get_trainable_critic_vars(),
+ )
+
+ actor_train_op = slim.learning.create_train_op(
+ actor_loss,
+ actor_optimizer,
+ variables_to_train=ddpg_agent.get_trainable_actor_vars(),
+ )
+ """
+
+ ACTOR_NET_SCOPE = 'actor_net'
+ CRITIC_NET_SCOPE = 'critic_net'
+ TARGET_ACTOR_NET_SCOPE = 'target_actor_net'
+ TARGET_CRITIC_NET_SCOPE = 'target_critic_net'
+
+ def __init__(self,
+ observation_spec,
+ action_spec,
+ actor_net=networks.actor_net,
+ critic_net=networks.critic_net,
+ td_errors_loss=tf.losses.huber_loss,
+ dqda_clipping=0.,
+ actions_regularizer=0.,
+ target_q_clipping=None,
+ residual_phi=0.0,
+ debug_summaries=False):
+ """Constructs a DDPG agent.
+
+ Args:
+ observation_spec: A TensorSpec defining the observations.
+ action_spec: A BoundedTensorSpec defining the actions.
+ actor_net: A callable that creates the actor network. Must take the
+ following arguments: states, num_actions. Please see networks.actor_net
+ for an example.
+ critic_net: A callable that creates the critic network. Must take the
+ following arguments: states, actions. Please see networks.critic_net
+ for an example.
+ td_errors_loss: A callable defining the loss function for the critic
+ td error.
+ dqda_clipping: (float) clips the gradient dqda element-wise between
+ [-dqda_clipping, dqda_clipping]. Does not perform clipping if
+ dqda_clipping == 0.
+ actions_regularizer: A scalar, when positive penalizes the norm of the
+ actions. This can prevent saturation of actions for the actor_loss.
+ target_q_clipping: (tuple of floats) clips target q values within
+ (low, high) values when computing the critic loss.
+ residual_phi: (float) [0.0, 1.0] Residual algorithm parameter that
+ interpolates between Q-learning and residual gradient algorithm.
+ http://www.leemon.com/papers/1995b.pdf
+ debug_summaries: If True, add summaries to help debug behavior.
+ Raises:
+ ValueError: If 'dqda_clipping' is < 0.
+ """
+ self._observation_spec = observation_spec[0]
+ self._action_spec = action_spec[0]
+ self._state_shape = tf.TensorShape([None]).concatenate(
+ self._observation_spec.shape)
+ self._action_shape = tf.TensorShape([None]).concatenate(
+ self._action_spec.shape)
+ self._num_action_dims = self._action_spec.shape.num_elements()
+
+ self._scope = tf.get_variable_scope().name
+ self._actor_net = tf.make_template(
+ self.ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
+ self._critic_net = tf.make_template(
+ self.CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
+ self._target_actor_net = tf.make_template(
+ self.TARGET_ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
+ self._target_critic_net = tf.make_template(
+ self.TARGET_CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
+ self._td_errors_loss = td_errors_loss
+ if dqda_clipping < 0:
+ raise ValueError('dqda_clipping must be >= 0.')
+ self._dqda_clipping = dqda_clipping
+ self._actions_regularizer = actions_regularizer
+ self._target_q_clipping = target_q_clipping
+ self._residual_phi = residual_phi
+ self._debug_summaries = debug_summaries
+
+ def _batch_state(self, state):
+ """Convert state to a batched state.
+
+ Args:
+      state: A [num_state_dims] state tensor, or a list/tuple whose first element is one.
+ Returns:
+ A tensor [1, num_state_dims]
+ """
+ if isinstance(state, (tuple, list)):
+ state = state[0]
+ if state.get_shape().ndims == 1:
+ state = tf.expand_dims(state, 0)
+ return state
+
+ def action(self, state):
+ """Returns the next action for the state.
+
+ Args:
+ state: A [num_state_dims] tensor representing a state.
+ Returns:
+ A [num_action_dims] tensor representing the action.
+ """
+ return self.actor_net(self._batch_state(state), stop_gradients=True)[0, :]
+
+ @gin.configurable('ddpg_sample_action')
+ def sample_action(self, state, stddev=1.0):
+ """Returns the action for the state with additive noise.
+
+ Args:
+ state: A [num_state_dims] tensor representing a state.
+      stddev: Standard deviation of the additive Gaussian exploration noise.
+ Returns:
+ A [num_action_dims] action tensor.
+ """
+ agent_action = self.action(state)
+ agent_action += tf.random_normal(tf.shape(agent_action)) * stddev
+ return utils.clip_to_spec(agent_action, self._action_spec)
+
+ def actor_net(self, states, stop_gradients=False):
+ """Returns the output of the actor network.
+
+ Args:
+ states: A [batch_size, num_state_dims] tensor representing a batch
+ of states.
+      stop_gradients: (boolean) if true, gradients cannot be propagated through
+ this operation.
+ Returns:
+ A [batch_size, num_action_dims] tensor of actions.
+ Raises:
+ ValueError: If `states` does not have the expected dimensions.
+ """
+ self._validate_states(states)
+ actions = self._actor_net(states, self._action_spec)
+ if stop_gradients:
+ actions = tf.stop_gradient(actions)
+ return actions
+
+ def critic_net(self, states, actions, for_critic_loss=False):
+ """Returns the output of the critic network.
+
+ Args:
+ states: A [batch_size, num_state_dims] tensor representing a batch
+ of states.
+ actions: A [batch_size, num_action_dims] tensor representing a batch
+ of actions.
+ Returns:
+ q values: A [batch_size] tensor of q values.
+ Raises:
+ ValueError: If `states` or `actions' do not have the expected dimensions.
+ """
+ self._validate_states(states)
+ self._validate_actions(actions)
+ return self._critic_net(states, actions,
+ for_critic_loss=for_critic_loss)
+
+ def target_actor_net(self, states):
+ """Returns the output of the target actor network.
+
+ The target network is used to compute stable targets for training.
+
+ Args:
+ states: A [batch_size, num_state_dims] tensor representing a batch
+ of states.
+ Returns:
+ A [batch_size, num_action_dims] tensor of actions.
+ Raises:
+ ValueError: If `states` does not have the expected dimensions.
+ """
+ self._validate_states(states)
+ actions = self._target_actor_net(states, self._action_spec)
+ return tf.stop_gradient(actions)
+
+ def target_critic_net(self, states, actions, for_critic_loss=False):
+ """Returns the output of the target critic network.
+
+ The target network is used to compute stable targets for training.
+
+ Args:
+ states: A [batch_size, num_state_dims] tensor representing a batch
+ of states.
+ actions: A [batch_size, num_action_dims] tensor representing a batch
+ of actions.
+ Returns:
+ q values: A [batch_size] tensor of q values.
+ Raises:
+ ValueError: If `states` or `actions' do not have the expected dimensions.
+ """
+ self._validate_states(states)
+ self._validate_actions(actions)
+ return tf.stop_gradient(
+ self._target_critic_net(states, actions,
+ for_critic_loss=for_critic_loss))
+
+ def value_net(self, states, for_critic_loss=False):
+ """Returns the output of the critic evaluated with the actor.
+
+ Args:
+ states: A [batch_size, num_state_dims] tensor representing a batch
+ of states.
+ Returns:
+ q values: A [batch_size] tensor of q values.
+ """
+ actions = self.actor_net(states)
+ return self.critic_net(states, actions,
+ for_critic_loss=for_critic_loss)
+
+ def target_value_net(self, states, for_critic_loss=False):
+ """Returns the output of the target critic evaluated with the target actor.
+
+ Args:
+ states: A [batch_size, num_state_dims] tensor representing a batch
+ of states.
+ Returns:
+ q values: A [batch_size] tensor of q values.
+ """
+ target_actions = self.target_actor_net(states)
+ return self.target_critic_net(states, target_actions,
+ for_critic_loss=for_critic_loss)
+
+ def critic_loss(self, states, actions, rewards, discounts,
+ next_states):
+ """Computes a loss for training the critic network.
+
+ The loss is the mean squared error between the Q value predictions of the
+ critic and Q values estimated using TD-lambda.
+
+ Args:
+ states: A [batch_size, num_state_dims] tensor representing a batch
+ of states.
+ actions: A [batch_size, num_action_dims] tensor representing a batch
+ of actions.
+ rewards: A [batch_size, ...] tensor representing a batch of rewards,
+ broadcastable to the critic net output.
+ discounts: A [batch_size, ...] tensor representing a batch of discounts,
+ broadcastable to the critic net output.
+ next_states: A [batch_size, num_state_dims] tensor representing a batch
+ of next states.
+ Returns:
+ A rank-0 tensor representing the critic loss.
+ Raises:
+ ValueError: If any of the inputs do not have the expected dimensions, or
+ if their batch_sizes do not match.
+ """
+ self._validate_states(states)
+ self._validate_actions(actions)
+ self._validate_states(next_states)
+
+ target_q_values = self.target_value_net(next_states, for_critic_loss=True)
+ td_targets = target_q_values * discounts + rewards
+ if self._target_q_clipping is not None:
+ td_targets = tf.clip_by_value(td_targets, self._target_q_clipping[0],
+ self._target_q_clipping[1])
+ q_values = self.critic_net(states, actions, for_critic_loss=True)
+ td_errors = td_targets - q_values
+ if self._debug_summaries:
+ gen_debug_td_error_summaries(
+ target_q_values, q_values, td_targets, td_errors)
+
+ loss = self._td_errors_loss(td_targets, q_values)
+
+ if self._residual_phi > 0.0: # compute residual gradient loss
+ residual_q_values = self.value_net(next_states, for_critic_loss=True)
+ residual_td_targets = residual_q_values * discounts + rewards
+ if self._target_q_clipping is not None:
+ residual_td_targets = tf.clip_by_value(residual_td_targets,
+ self._target_q_clipping[0],
+ self._target_q_clipping[1])
+ residual_td_errors = residual_td_targets - q_values
+ residual_loss = self._td_errors_loss(
+ residual_td_targets, residual_q_values)
+ loss = (loss * (1.0 - self._residual_phi) +
+ residual_loss * self._residual_phi)
+ return loss
+
+ def actor_loss(self, states):
+ """Computes a loss for training the actor network.
+
+    Note that the output does not represent an actual loss. It is called a loss only
+ in the sense that its gradient w.r.t. the actor network weights is the
+ correct gradient for training the actor network,
+ i.e. dloss/dweights = (dq/da)*(da/dweights)
+    which is the gradient used in Algorithm 1 of Lillicrap et al.
+
+ Args:
+ states: A [batch_size, num_state_dims] tensor representing a batch
+ of states.
+ Returns:
+ A rank-0 tensor representing the actor loss.
+ Raises:
+ ValueError: If `states` does not have the expected dimensions.
+ """
+ self._validate_states(states)
+ actions = self.actor_net(states, stop_gradients=False)
+ critic_values = self.critic_net(states, actions)
+ q_values = self.critic_function(critic_values, states)
+ dqda = tf.gradients([q_values], [actions])[0]
+ dqda_unclipped = dqda
+ if self._dqda_clipping > 0:
+ dqda = tf.clip_by_value(dqda, -self._dqda_clipping, self._dqda_clipping)
+
+ actions_norm = tf.norm(actions)
+ if self._debug_summaries:
+ with tf.name_scope('dqda'):
+ tf.summary.scalar('actions_norm', actions_norm)
+ tf.summary.histogram('dqda', dqda)
+ tf.summary.histogram('dqda_unclipped', dqda_unclipped)
+ tf.summary.histogram('actions', actions)
+ for a in range(self._num_action_dims):
+ tf.summary.histogram('dqda_unclipped_%d' % a, dqda_unclipped[:, a])
+ tf.summary.histogram('dqda_%d' % a, dqda[:, a])
+
+ actions_norm *= self._actions_regularizer
+ return slim.losses.mean_squared_error(tf.stop_gradient(dqda + actions),
+ actions,
+ scope='actor_loss') + actions_norm
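+
+  # Why this works: treating stop_gradient(dqda + actions) as a fixed target, the
+  # gradient of the mean squared error w.r.t. `actions` is proportional to -dqda,
+  # so minimizing this surrogate moves the actor's output in the direction that
+  # increases the critic's Q value (the deterministic policy gradient direction).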
+
+ @gin.configurable('ddpg_critic_function')
+ def critic_function(self, critic_values, states, weights=None):
+ """Computes q values based on critic_net outputs, states, and weights.
+
+ Args:
+ critic_values: A tf.float32 [batch_size, ...] tensor representing outputs
+ from the critic net.
+ states: A [batch_size, num_state_dims] tensor representing a batch
+ of states.
+ weights: A list or Numpy array or tensor with a shape broadcastable to
+ `critic_values`.
+ Returns:
+ A tf.float32 [batch_size] tensor representing q values.
+ """
+ del states # unused args
+ if weights is not None:
+ weights = tf.convert_to_tensor(weights, dtype=critic_values.dtype)
+ critic_values *= weights
+ if critic_values.shape.ndims > 1:
+ critic_values = tf.reduce_sum(critic_values,
+ range(1, critic_values.shape.ndims))
+ critic_values.shape.assert_has_rank(1)
+ return critic_values
+
+ @gin.configurable('ddpg_update_targets')
+ def update_targets(self, tau=1.0):
+ """Performs a soft update of the target network parameters.
+
+ For each weight w_s in the actor/critic networks, and its corresponding
+ weight w_t in the target actor/critic networks, a soft update is:
+      w_t = (1 - tau) * w_t + tau * w_s
+
+ Args:
+ tau: A float scalar in [0, 1]
+ Returns:
+ An operation that performs a soft update of the target network parameters.
+ Raises:
+ ValueError: If `tau` is not in [0, 1].
+ """
+ if tau < 0 or tau > 1:
+ raise ValueError('Input `tau` should be in [0, 1].')
+ update_actor = utils.soft_variables_update(
+ slim.get_trainable_variables(
+ utils.join_scope(self._scope, self.ACTOR_NET_SCOPE)),
+ slim.get_trainable_variables(
+ utils.join_scope(self._scope, self.TARGET_ACTOR_NET_SCOPE)),
+ tau)
+ update_critic = utils.soft_variables_update(
+ slim.get_trainable_variables(
+ utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)),
+ slim.get_trainable_variables(
+ utils.join_scope(self._scope, self.TARGET_CRITIC_NET_SCOPE)),
+ tau)
+ return tf.group(update_actor, update_critic, name='update_targets')
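+
+  # Typical usage (illustrative values): update_targets(tau=1.0) copies the online
+  # weights into the target networks, e.g. at initialization, while a small tau
+  # such as 0.005 gives the usual slow-moving Polyak average at every train step.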
+
+ def get_trainable_critic_vars(self):
+ """Returns a list of trainable variables in the critic network.
+
+ Returns:
+ A list of trainable variables in the critic network.
+ """
+ return slim.get_trainable_variables(
+ utils.join_scope(self._scope, self.CRITIC_NET_SCOPE))
+
+ def get_trainable_actor_vars(self):
+ """Returns a list of trainable variables in the actor network.
+
+ Returns:
+ A list of trainable variables in the actor network.
+ """
+ return slim.get_trainable_variables(
+ utils.join_scope(self._scope, self.ACTOR_NET_SCOPE))
+
+ def get_critic_vars(self):
+ """Returns a list of all variables in the critic network.
+
+ Returns:
+ A list of trainable variables in the critic network.
+ """
+ return slim.get_model_variables(
+ utils.join_scope(self._scope, self.CRITIC_NET_SCOPE))
+
+ def get_actor_vars(self):
+ """Returns a list of all variables in the actor network.
+
+ Returns:
+ A list of trainable variables in the actor network.
+ """
+ return slim.get_model_variables(
+ utils.join_scope(self._scope, self.ACTOR_NET_SCOPE))
+
+ def _validate_states(self, states):
+ """Raises a value error if `states` does not have the expected shape.
+
+ Args:
+ states: A tensor.
+ Raises:
+ ValueError: If states.shape or states.dtype are not compatible with
+ observation_spec.
+ """
+ states.shape.assert_is_compatible_with(self._state_shape)
+ if not states.dtype.is_compatible_with(self._observation_spec.dtype):
+ raise ValueError('states.dtype={} is not compatible with'
+ ' observation_spec.dtype={}'.format(
+ states.dtype, self._observation_spec.dtype))
+
+ def _validate_actions(self, actions):
+ """Raises a value error if `actions` does not have the expected shape.
+
+ Args:
+ actions: A tensor.
+ Raises:
+ ValueError: If actions.shape or actions.dtype are not compatible with
+ action_spec.
+ """
+ actions.shape.assert_is_compatible_with(self._action_shape)
+ if not actions.dtype.is_compatible_with(self._action_spec.dtype):
+ raise ValueError('actions.dtype={} is not compatible with'
+ ' action_spec.dtype={}'.format(
+ actions.dtype, self._action_spec.dtype))
+
+
+@gin.configurable
+class TD3Agent(DdpgAgent):
+ """An RL agent that learns using the TD3 algorithm."""
+
+ ACTOR_NET_SCOPE = 'actor_net'
+ CRITIC_NET_SCOPE = 'critic_net'
+ CRITIC_NET2_SCOPE = 'critic_net2'
+ TARGET_ACTOR_NET_SCOPE = 'target_actor_net'
+ TARGET_CRITIC_NET_SCOPE = 'target_critic_net'
+ TARGET_CRITIC_NET2_SCOPE = 'target_critic_net2'
+
+ def __init__(self,
+ observation_spec,
+ action_spec,
+ actor_net=networks.actor_net,
+ critic_net=networks.critic_net,
+ td_errors_loss=tf.losses.huber_loss,
+ dqda_clipping=0.,
+ actions_regularizer=0.,
+ target_q_clipping=None,
+ residual_phi=0.0,
+ debug_summaries=False):
+ """Constructs a TD3 agent.
+
+ Args:
+ observation_spec: A TensorSpec defining the observations.
+ action_spec: A BoundedTensorSpec defining the actions.
+ actor_net: A callable that creates the actor network. Must take the
+ following arguments: states, num_actions. Please see networks.actor_net
+ for an example.
+ critic_net: A callable that creates the critic network. Must take the
+ following arguments: states, actions. Please see networks.critic_net
+ for an example.
+ td_errors_loss: A callable defining the loss function for the critic
+ td error.
+ dqda_clipping: (float) clips the gradient dqda element-wise between
+ [-dqda_clipping, dqda_clipping]. Does not perform clipping if
+ dqda_clipping == 0.
+ actions_regularizer: A scalar, when positive penalizes the norm of the
+ actions. This can prevent saturation of actions for the actor_loss.
+ target_q_clipping: (tuple of floats) clips target q values within
+ (low, high) values when computing the critic loss.
+ residual_phi: (float) [0.0, 1.0] Residual algorithm parameter that
+ interpolates between Q-learning and residual gradient algorithm.
+ http://www.leemon.com/papers/1995b.pdf
+ debug_summaries: If True, add summaries to help debug behavior.
+ Raises:
+ ValueError: If 'dqda_clipping' is < 0.
+ """
+ self._observation_spec = observation_spec[0]
+ self._action_spec = action_spec[0]
+ self._state_shape = tf.TensorShape([None]).concatenate(
+ self._observation_spec.shape)
+ self._action_shape = tf.TensorShape([None]).concatenate(
+ self._action_spec.shape)
+ self._num_action_dims = self._action_spec.shape.num_elements()
+
+ self._scope = tf.get_variable_scope().name
+ self._actor_net = tf.make_template(
+ self.ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
+ self._critic_net = tf.make_template(
+ self.CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
+ self._critic_net2 = tf.make_template(
+ self.CRITIC_NET2_SCOPE, critic_net, create_scope_now_=True)
+ self._target_actor_net = tf.make_template(
+ self.TARGET_ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
+ self._target_critic_net = tf.make_template(
+ self.TARGET_CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
+ self._target_critic_net2 = tf.make_template(
+ self.TARGET_CRITIC_NET2_SCOPE, critic_net, create_scope_now_=True)
+ self._td_errors_loss = td_errors_loss
+ if dqda_clipping < 0:
+ raise ValueError('dqda_clipping must be >= 0.')
+ self._dqda_clipping = dqda_clipping
+ self._actions_regularizer = actions_regularizer
+ self._target_q_clipping = target_q_clipping
+ self._residual_phi = residual_phi
+ self._debug_summaries = debug_summaries
+
+ def get_trainable_critic_vars(self):
+ """Returns a list of trainable variables in the critic network.
+ NOTE: This gets the vars of both critic networks.
+
+ Returns:
+ A list of trainable variables in the critic network.
+ """
+ return (
+ slim.get_trainable_variables(
+ utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)))
+
+ def critic_net(self, states, actions, for_critic_loss=False):
+ """Returns the output of the critic network.
+
+ Args:
+ states: A [batch_size, num_state_dims] tensor representing a batch
+ of states.
+ actions: A [batch_size, num_action_dims] tensor representing a batch
+ of actions.
+ Returns:
+ q values: A [batch_size] tensor of q values.
+ Raises:
+ ValueError: If `states` or `actions' do not have the expected dimensions.
+ """
+ values1 = self._critic_net(states, actions,
+ for_critic_loss=for_critic_loss)
+ values2 = self._critic_net2(states, actions,
+ for_critic_loss=for_critic_loss)
+ if for_critic_loss:
+ return values1, values2
+ return values1
+
+ def target_critic_net(self, states, actions, for_critic_loss=False):
+ """Returns the output of the target critic network.
+
+ The target network is used to compute stable targets for training.
+
+ Args:
+ states: A [batch_size, num_state_dims] tensor representing a batch
+ of states.
+ actions: A [batch_size, num_action_dims] tensor representing a batch
+ of actions.
+ Returns:
+ q values: A [batch_size] tensor of q values.
+ Raises:
+ ValueError: If `states` or `actions' do not have the expected dimensions.
+ """
+ self._validate_states(states)
+ self._validate_actions(actions)
+ values1 = tf.stop_gradient(
+ self._target_critic_net(states, actions,
+ for_critic_loss=for_critic_loss))
+ values2 = tf.stop_gradient(
+ self._target_critic_net2(states, actions,
+ for_critic_loss=for_critic_loss))
+ if for_critic_loss:
+ return values1, values2
+ return values1
+
+ def value_net(self, states, for_critic_loss=False):
+ """Returns the output of the critic evaluated with the actor.
+
+ Args:
+ states: A [batch_size, num_state_dims] tensor representing a batch
+ of states.
+ Returns:
+ q values: A [batch_size] tensor of q values.
+ """
+ actions = self.actor_net(states)
+ return self.critic_net(states, actions,
+ for_critic_loss=for_critic_loss)
+
+ def target_value_net(self, states, for_critic_loss=False):
+ """Returns the output of the target critic evaluated with the target actor.
+
+ Args:
+ states: A [batch_size, num_state_dims] tensor representing a batch
+ of states.
+ Returns:
+ q values: A [batch_size] tensor of q values.
+ """
+ target_actions = self.target_actor_net(states)
+ noise = tf.clip_by_value(
+ tf.random_normal(tf.shape(target_actions), stddev=0.2), -0.5, 0.5)
+ values1, values2 = self.target_critic_net(
+ states, target_actions + noise,
+ for_critic_loss=for_critic_loss)
+ values = tf.minimum(values1, values2)
+ return values, values
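+
+  # TD3-style target: clipped Gaussian noise is added to the target action (target
+  # policy smoothing) and the element-wise minimum of the two target critics is
+  # used (clipped double-Q); the minimum is returned twice so the return matches
+  # the paired-critic interface expected by the critic loss.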
+
+ @gin.configurable('td3_update_targets')
+ def update_targets(self, tau=1.0):
+ """Performs a soft update of the target network parameters.
+
+ For each weight w_s in the actor/critic networks, and its corresponding
+ weight w_t in the target actor/critic networks, a soft update is:
+ w_t = (1 - tau) * w_t + tau * w_s
+
+ Args:
+ tau: A float scalar in [0, 1]
+ Returns:
+ An operation that performs a soft update of the target network parameters.
+ Raises:
+ ValueError: If `tau` is not in [0, 1].
+ """
+ if tau < 0 or tau > 1:
+ raise ValueError('Input `tau` should be in [0, 1].')
+ update_actor = utils.soft_variables_update(
+ slim.get_trainable_variables(
+ utils.join_scope(self._scope, self.ACTOR_NET_SCOPE)),
+ slim.get_trainable_variables(
+ utils.join_scope(self._scope, self.TARGET_ACTOR_NET_SCOPE)),
+ tau)
+ # NOTE: This updates both critic networks.
+ update_critic = utils.soft_variables_update(
+ slim.get_trainable_variables(
+ utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)),
+ slim.get_trainable_variables(
+ utils.join_scope(self._scope, self.TARGET_CRITIC_NET_SCOPE)),
+ tau)
+ return tf.group(update_actor, update_critic, name='update_targets')
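+ # Illustrative numbers: with tau = 0.005 (the value bound to
+ # td3_update_targets.tau in base_uvf.gin), a target weight w_t = 1.0 and an
+ # online weight w_s = 2.0 update to 0.995 * 1.0 + 0.005 * 2.0 = 1.005, so
+ # the target networks trail the online networks slowly; tau = 1.0 copies
+ # the online weights outright.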
+
+
+def gen_debug_td_error_summaries(
+ target_q_values, q_values, td_targets, td_errors):
+ """Generates debug summaries for critic given a set of batch samples.
+
+ Args:
+ target_q_values: Predicted q values for the next states.
+ q_values: Current q values predicted by the critic network.
+ td_targets: Discounted target_q_values plus the immediate reward.
+ td_errors: The difference between td_targets and q_values.
+ """
+ with tf.name_scope('td_errors'):
+ tf.summary.histogram('td_targets', td_targets)
+ tf.summary.histogram('q_values', q_values)
+ tf.summary.histogram('target_q_values', target_q_values)
+ tf.summary.histogram('td_errors', td_errors)
+ with tf.name_scope('td_targets'):
+ tf.summary.scalar('mean', tf.reduce_mean(td_targets))
+ tf.summary.scalar('max', tf.reduce_max(td_targets))
+ tf.summary.scalar('min', tf.reduce_min(td_targets))
+ with tf.name_scope('q_values'):
+ tf.summary.scalar('mean', tf.reduce_mean(q_values))
+ tf.summary.scalar('max', tf.reduce_max(q_values))
+ tf.summary.scalar('min', tf.reduce_min(q_values))
+ with tf.name_scope('target_q_values'):
+ tf.summary.scalar('mean', tf.reduce_mean(target_q_values))
+ tf.summary.scalar('max', tf.reduce_max(target_q_values))
+ tf.summary.scalar('min', tf.reduce_min(target_q_values))
+ with tf.name_scope('td_errors'):
+ tf.summary.scalar('mean', tf.reduce_mean(td_errors))
+ tf.summary.scalar('max', tf.reduce_max(td_errors))
+ tf.summary.scalar('min', tf.reduce_min(td_errors))
+ tf.summary.scalar('mean_abs', tf.reduce_mean(tf.abs(td_errors)))
diff --git a/models/research/efficient-hrl/agents/ddpg_networks.py b/models/research/efficient-hrl/agents/ddpg_networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..63074dfb91cf950b602212936ab2560db818c3a4
--- /dev/null
+++ b/models/research/efficient-hrl/agents/ddpg_networks.py
@@ -0,0 +1,150 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Sample actor(policy) and critic(q) networks to use with DDPG/NAF agents.
+
+The DDPG networks are defined in "Section 7: Experiment Details" of
+"Continuous control with deep reinforcement learning" - Lilicrap et al.
+https://arxiv.org/abs/1509.02971
+
+The NAF critic network is based on "Section 4" of "Continuous deep Q-learning
+with model-based acceleration" - Gu et al. https://arxiv.org/pdf/1603.00748.
+"""
+
+import tensorflow as tf
+slim = tf.contrib.slim
+import gin.tf
+
+
+@gin.configurable('ddpg_critic_net')
+def critic_net(states, actions,
+ for_critic_loss=False,
+ num_reward_dims=1,
+ states_hidden_layers=(400,),
+ actions_hidden_layers=None,
+ joint_hidden_layers=(300,),
+ weight_decay=0.0001,
+ normalizer_fn=None,
+ activation_fn=tf.nn.relu,
+ zero_obs=False,
+ images=False):
+ """Creates a critic that returns q values for the given states and actions.
+
+ Args:
+ states: (castable to tf.float32) a [batch_size, num_state_dims] tensor
+ representing a batch of states.
+ actions: (castable to tf.float32) a [batch_size, num_action_dims] tensor
+ representing a batch of actions.
+ for_critic_loss: Whether the q values are computed for the critic loss
+ (if not, vector-valued q values are reduced to a scalar per sample).
+ num_reward_dims: Number of reward dimensions.
+ states_hidden_layers: tuple of hidden layers units for states.
+ actions_hidden_layers: tuple of hidden layers units for actions.
+ joint_hidden_layers: tuple of hidden layers units after joining states
+ and actions using tf.concat().
+ weight_decay: Weight decay for l2 weights regularizer.
+ normalizer_fn: Normalizer function, e.g. slim.layer_norm.
+ activation_fn: Activation function, e.g. tf.nn.relu, slim.leaky_relu, ...
+ zero_obs: Whether to zero out the first two state dimensions (the x, y
+ position) before feeding the states through the network.
+ images: Whether the states are image observations; if so, the x, y
+ dimensions are likewise zeroed out.
+ Returns:
+ A tf.float32 [batch_size] tensor of q values, or a tf.float32
+ [batch_size, num_reward_dims] tensor of vector q values if
+ num_reward_dims > 1.
+ """
+ with slim.arg_scope(
+ [slim.fully_connected],
+ activation_fn=activation_fn,
+ normalizer_fn=normalizer_fn,
+ weights_regularizer=slim.l2_regularizer(weight_decay),
+ weights_initializer=slim.variance_scaling_initializer(
+ factor=1.0/3.0, mode='FAN_IN', uniform=True)):
+
+ orig_states = tf.to_float(states)
+ # TD3-style critic input: the actions are concatenated onto the states
+ # before the state tower (they are concatenated again below for the joint
+ # layers).
+ states = tf.concat([tf.to_float(states), tf.to_float(actions)], -1)
+ if images or zero_obs:
+ # Zero out the x, y position dimensions.
+ states *= tf.constant([0.0] * 2 + [1.0] * (states.shape[1] - 2))
+ actions = tf.to_float(actions)
+ if states_hidden_layers:
+ states = slim.stack(states, slim.fully_connected, states_hidden_layers,
+ scope='states')
+ if actions_hidden_layers:
+ actions = slim.stack(actions, slim.fully_connected, actions_hidden_layers,
+ scope='actions')
+ joint = tf.concat([states, actions], 1)
+ if joint_hidden_layers:
+ joint = slim.stack(joint, slim.fully_connected, joint_hidden_layers,
+ scope='joint')
+ with slim.arg_scope([slim.fully_connected],
+ weights_regularizer=None,
+ weights_initializer=tf.random_uniform_initializer(
+ minval=-0.003, maxval=0.003)):
+ value = slim.fully_connected(joint, num_reward_dims,
+ activation_fn=None,
+ normalizer_fn=None,
+ scope='q_value')
+ if num_reward_dims == 1:
+ value = tf.reshape(value, [-1])
+ if not for_critic_loss and num_reward_dims > 1:
+ value = tf.reduce_sum(
+ value * tf.abs(orig_states[:, -num_reward_dims:]), -1)
+ return value
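+
+# Illustrative usage only (hypothetical shapes, not part of this module):
+#   states = tf.placeholder(tf.float32, [None, 30])
+#   actions = tf.placeholder(tf.float32, [None, 8])
+#   with tf.variable_scope('critic'):
+#     q_values = critic_net(states, actions)  # -> [batch_size] tensor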
+
+
+@gin.configurable('ddpg_actor_net')
+def actor_net(states, action_spec,
+ hidden_layers=(400, 300),
+ normalizer_fn=None,
+ activation_fn=tf.nn.relu,
+ zero_obs=False,
+ images=False):
+ """Creates an actor that returns actions for the given states.
+
+ Args:
+ states: (castable to tf.float32) a [batch_size, num_state_dims] tensor
+ representing a batch of states.
+ action_spec: (BoundedTensorSpec) A tensor spec indicating the shape
+ and range of actions.
+ hidden_layers: tuple of hidden layers units.
+ normalizer_fn: Normalizer function, e.g. slim.layer_norm.
+ activation_fn: Activation function, e.g. tf.nn.relu, slim.leaky_relu, ...
+ zero_obs: Whether to zero out the first two state dimensions (the x, y
+ position) before feeding the states through the network.
+ images: Whether the states are image observations; if so, the x, y
+ dimensions are likewise zeroed out.
+ Returns:
+ A tf.float32 [batch_size, num_action_dims] tensor of actions.
+ """
+
+ with slim.arg_scope(
+ [slim.fully_connected],
+ activation_fn=activation_fn,
+ normalizer_fn=normalizer_fn,
+ weights_initializer=slim.variance_scaling_initializer(
+ factor=1.0/3.0, mode='FAN_IN', uniform=True)):
+
+ states = tf.to_float(states)
+ orig_states = states
+ if images or zero_obs: # Zero-out x, y position. Hacky.
+ states *= tf.constant([0.0] * 2 + [1.0] * (states.shape[1] - 2))
+ if hidden_layers:
+ states = slim.stack(states, slim.fully_connected, hidden_layers,
+ scope='states')
+ with slim.arg_scope([slim.fully_connected],
+ weights_initializer=tf.random_uniform_initializer(
+ minval=-0.003, maxval=0.003)):
+ actions = slim.fully_connected(states,
+ action_spec.shape.num_elements(),
+ scope='actions',
+ normalizer_fn=None,
+ activation_fn=tf.nn.tanh)
+ action_means = (action_spec.maximum + action_spec.minimum) / 2.0
+ action_magnitudes = (action_spec.maximum - action_spec.minimum) / 2.0
+ actions = action_means + action_magnitudes * actions
+
+ return actions
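+
+# Illustrative usage only (hypothetical action spec, not part of this module):
+#   from tf_agents import specs
+#   action_spec = specs.BoundedTensorSpec(
+#       shape=[8], dtype=tf.float32, minimum=-1.0, maximum=1.0)
+#   with tf.variable_scope('actor'):
+#     actions = actor_net(states, action_spec)  # tanh output rescaled into
+#                                               # [minimum, maximum]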
diff --git a/models/research/efficient-hrl/cond_fn.py b/models/research/efficient-hrl/cond_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd1a276e136bf69fa453c3b90c5907eecb1bda1e
--- /dev/null
+++ b/models/research/efficient-hrl/cond_fn.py
@@ -0,0 +1,244 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Defines many boolean functions indicating when to step and reset.
+"""
+
+import tensorflow as tf
+import gin.tf
+
+
+@gin.configurable
+def env_transition(agent, state, action, transition_type, environment_steps,
+ num_episodes):
+ """True if the transition_type is TRANSITION or FINAL_TRANSITION.
+
+ Args:
+ agent: RL agent.
+ state: A [num_state_dims] tensor representing a state.
+ action: Action performed.
+ transition_type: Type of transition after action
+ environment_steps: Number of steps performed by environment.
+ num_episodes: Number of episodes.
+ Returns:
+ cond: Returns an op that evaluates to true if the transition type is
+ not RESTARTING
+ """
+ del agent, state, action, num_episodes, environment_steps
+ cond = tf.logical_not(transition_type)
+ return cond
+
+
+@gin.configurable
+def env_restart(agent, state, action, transition_type, environment_steps,
+ num_episodes):
+ """True if the transition_type is RESTARTING.
+
+ Args:
+ agent: RL agent.
+ state: A [num_state_dims] tensor representing a state.
+ action: Action performed.
+ transition_type: Type of transition after action
+ environment_steps: Number of steps performed by environment.
+ num_episodes: Number of episodes.
+ Returns:
+ cond: Returns an op that evaluates to true if the transition type equals
+ RESTARTING.
+ """
+ del agent, state, action, num_episodes, environment_steps
+ cond = tf.identity(transition_type)
+ return cond
+
+
+@gin.configurable
+def every_n_steps(agent,
+ state,
+ action,
+ transition_type,
+ environment_steps,
+ num_episodes,
+ n=150):
+ """True once every n steps.
+
+ Args:
+ agent: RL agent.
+ state: A [num_state_dims] tensor representing a state.
+ action: Action performed.
+ transition_type: Type of transition after action
+ environment_steps: Number of steps performed by environment.
+ num_episodes: Number of episodes.
+ n: Return true once every n steps.
+ Returns:
+ cond: Returns an op that evaluates to true if environment_steps
+ equals 0 mod n. We increment the step before checking this condition, so
+ we do not need to add one to environment_steps.
+ """
+ del agent, state, action, transition_type, num_episodes
+ cond = tf.equal(tf.mod(environment_steps, n), 0)
+ return cond
+
+
+@gin.configurable
+def every_n_episodes(agent,
+ state,
+ action,
+ transition_type,
+ environment_steps,
+ num_episodes,
+ n=2,
+ steps_per_episode=None):
+ """True once every n episodes.
+
+ Specifically, evaluates to True on the 0th step of every nth episode.
+ Unlike environment_steps, num_episodes starts at 0, so we do want to add
+ one to ensure it does not reset on the first call.
+
+ Args:
+ agent: RL agent.
+ state: A [num_state_dims] tensor representing a state.
+ action: Action performed.
+ transition_type: Type of transition after action
+ environment_steps: Number of steps performed by environment.
+ num_episodes: Number of episodes.
+ n: Return true once every n episodes.
+ steps_per_episode: How many steps per episode. Needed to determine when a
+ new episode starts.
+ Returns:
+ cond: Returns an op that evaluates to true at an episode boundary
+ (environment_steps equals 0 mod steps_per_episode) if either the ant has
+ fallen (torso height outside [0.2, 1.0]) or num_episodes + 1 equals 0
+ mod n.
+ """
+ assert steps_per_episode is not None
+ del agent, action, transition_type
+ ant_fell = tf.logical_or(state[2] < 0.2, state[2] > 1.0)
+ cond = tf.logical_and(
+ tf.logical_or(
+ ant_fell,
+ tf.equal(tf.mod(num_episodes + 1, n), 0)),
+ tf.equal(tf.mod(environment_steps, steps_per_episode), 0))
+ return cond
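+
+# Typically wired up through gin rather than called directly, e.g. (as in the
+# configs added below):
+#   UvfAgent.reset_env_cond_fn = @every_n_episodes
+#   every_n_episodes.n = %RESET_ENV_PERIOD
+#   every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD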
+
+
+@gin.configurable
+def failed_reset_after_n_episodes(agent,
+ state,
+ action,
+ transition_type,
+ environment_steps,
+ num_episodes,
+ steps_per_episode=None,
+ reset_state=None,
+ max_dist=1.0,
+ epsilon=1e-10):
+ """Every n episodes, returns True if the reset agent fails to return.
+
+ Specifically, evaluates to True if the distance between the state and the
+ reset state is greater than max_dist at the end of the episode.
+
+ Args:
+ agent: RL agent.
+ state: A [num_state_dims] tensor representing a state.
+ action: Action performed.
+ transition_type: Type of transition after action
+ environment_steps: Number of steps performed by environment.
+ num_episodes: Number of episodes.
+ steps_per_episode: How many steps per episode. Needed to determine when a
+ new episode starts.
+ reset_state: State to which the reset controller should return.
+ max_dist: Agent is considered to have successfully reset if its distance
+ from the reset_state is less than max_dist.
+ epsilon: small offset to ensure non-negative/zero distance.
+ Returns:
+ cond: Returns an op that evaluates to true at an episode boundary if the
+ distance between the current state and reset_state exceeds max_dist.
+ """
+ assert steps_per_episode is not None
+ assert reset_state is not None
+ del agent, action, transition_type, num_episodes
+ dist = tf.sqrt(
+ tf.reduce_sum(tf.squared_difference(state, reset_state)) + epsilon)
+ cond = tf.logical_and(
+ tf.greater(dist, tf.constant(max_dist)),
+ tf.equal(tf.mod(environment_steps, steps_per_episode), 0))
+ return cond
+
+
+@gin.configurable
+def q_too_small(agent,
+ state,
+ action,
+ transition_type,
+ environment_steps,
+ num_episodes,
+ q_min=0.5):
+ """True of q is too small.
+
+ Args:
+ agent: RL agent.
+ state: A [num_state_dims] tensor representing a state.
+ action: Action performed.
+ transition_type: Type of transition after action
+ environment_steps: Number of steps performed by environment.
+ num_episodes: Number of episodes.
+ q_min: Q value threshold; returns true if the q value is less than q_min.
+ Returns:
+ cond: Returns an op that evaluates to true if qval is less than q_min.
+ """
+ del transition_type, environment_steps, num_episodes
+ # Zero out the last state dimension before querying the reset agent's critic.
+ state_for_reset_agent = tf.concat(
+ [state[:-1], tf.constant([0.0], dtype=tf.float32)], -1)
+ qval = agent.BASE_AGENT_CLASS.critic_net(
+ tf.expand_dims(state_for_reset_agent, 0), tf.expand_dims(action, 0))[0, :]
+ cond = tf.greater(tf.constant(q_min), qval)
+ return cond
+
+
+@gin.configurable
+def true_fn(agent, state, action, transition_type, environment_steps,
+ num_episodes):
+ """Returns an op that evaluates to true.
+
+ Args:
+ agent: RL agent.
+ state: A [num_state_dims] tensor representing a state.
+ action: Action performed.
+ transition_type: Type of transition after action
+ environment_steps: Number of steps performed by environment.
+ num_episodes: Number of episodes.
+ Returns:
+ cond: op that always evaluates to True.
+ """
+ del agent, state, action, transition_type, environment_steps, num_episodes
+ cond = tf.constant(True, dtype=tf.bool)
+ return cond
+
+
+@gin.configurable
+def false_fn(agent, state, action, transition_type, environment_steps,
+ num_episodes):
+ """Returns an op that evaluates to false.
+
+ Args:
+ agent: RL agent.
+ state: A [num_state_dims] tensor representing a state.
+ action: Action performed.
+ transition_type: Type of transition after action
+ environment_steps: Number of steps performed by environment.
+ num_episodes: Number of episodes.
+ Returns:
+ cond: op that always evaluates to False.
+ """
+ del agent, state, action, transition_type, environment_steps, num_episodes
+ cond = tf.constant(False, dtype=tf.bool)
+ return cond
diff --git a/models/research/efficient-hrl/configs/base_uvf.gin b/models/research/efficient-hrl/configs/base_uvf.gin
new file mode 100644
index 0000000000000000000000000000000000000000..2f3f47b67a3fb0a38ee35b7a4deaf54a2700a19a
--- /dev/null
+++ b/models/research/efficient-hrl/configs/base_uvf.gin
@@ -0,0 +1,68 @@
+#-*-Python-*-
+import gin.tf.external_configurables
+
+create_maze_env.top_down_view = %IMAGES
+## Create the agent
+AGENT_CLASS = @UvfAgent
+UvfAgent.tf_context = %CONTEXT
+UvfAgent.actor_net = @agent/ddpg_actor_net
+UvfAgent.critic_net = @agent/ddpg_critic_net
+UvfAgent.dqda_clipping = 0.0
+UvfAgent.td_errors_loss = @tf.losses.huber_loss
+UvfAgent.target_q_clipping = %TARGET_Q_CLIPPING
+
+# Create meta agent
+META_CLASS = @MetaAgent
+MetaAgent.tf_context = %META_CONTEXT
+MetaAgent.sub_context = %CONTEXT
+MetaAgent.actor_net = @meta/ddpg_actor_net
+MetaAgent.critic_net = @meta/ddpg_critic_net
+MetaAgent.dqda_clipping = 0.0
+MetaAgent.td_errors_loss = @tf.losses.huber_loss
+MetaAgent.target_q_clipping = %TARGET_Q_CLIPPING
+
+# Create state preprocess
+STATE_PREPROCESS_CLASS = @StatePreprocess
+StatePreprocess.ndims = %SUBGOAL_DIM
+state_preprocess_net.states_hidden_layers = (100, 100)
+state_preprocess_net.num_output_dims = %SUBGOAL_DIM
+state_preprocess_net.images = %IMAGES
+action_embed_net.num_output_dims = %SUBGOAL_DIM
+INVERSE_DYNAMICS_CLASS = @InverseDynamics
+
+# actor_net
+ACTOR_HIDDEN_SIZE_1 = 300
+ACTOR_HIDDEN_SIZE_2 = 300
+agent/ddpg_actor_net.hidden_layers = (%ACTOR_HIDDEN_SIZE_1, %ACTOR_HIDDEN_SIZE_2)
+agent/ddpg_actor_net.activation_fn = @tf.nn.relu
+agent/ddpg_actor_net.zero_obs = %ZERO_OBS
+agent/ddpg_actor_net.images = %IMAGES
+meta/ddpg_actor_net.hidden_layers = (%ACTOR_HIDDEN_SIZE_1, %ACTOR_HIDDEN_SIZE_2)
+meta/ddpg_actor_net.activation_fn = @tf.nn.relu
+meta/ddpg_actor_net.zero_obs = False
+meta/ddpg_actor_net.images = %IMAGES
+# critic_net
+CRITIC_HIDDEN_SIZE_1 = 300
+CRITIC_HIDDEN_SIZE_2 = 300
+agent/ddpg_critic_net.states_hidden_layers = (%CRITIC_HIDDEN_SIZE_1,)
+agent/ddpg_critic_net.actions_hidden_layers = None
+agent/ddpg_critic_net.joint_hidden_layers = (%CRITIC_HIDDEN_SIZE_2,)
+agent/ddpg_critic_net.weight_decay = 0.0
+agent/ddpg_critic_net.activation_fn = @tf.nn.relu
+agent/ddpg_critic_net.zero_obs = %ZERO_OBS
+agent/ddpg_critic_net.images = %IMAGES
+meta/ddpg_critic_net.states_hidden_layers = (%CRITIC_HIDDEN_SIZE_1,)
+meta/ddpg_critic_net.actions_hidden_layers = None
+meta/ddpg_critic_net.joint_hidden_layers = (%CRITIC_HIDDEN_SIZE_2,)
+meta/ddpg_critic_net.weight_decay = 0.0
+meta/ddpg_critic_net.activation_fn = @tf.nn.relu
+meta/ddpg_critic_net.zero_obs = False
+meta/ddpg_critic_net.images = %IMAGES
+
+tf.losses.huber_loss.delta = 1.0
+# Sample action
+uvf_add_noise_fn.stddev = 1.0
+meta_add_noise_fn.stddev = %META_EXPLORE_NOISE
+# Update targets
+ddpg_update_targets.tau = 0.001
+td3_update_targets.tau = 0.005
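+# Note: macros such as %CONTEXT, %META_CONTEXT, %SUBGOAL_DIM, %IMAGES,
+# %ZERO_OBS, %TARGET_Q_CLIPPING and %META_EXPLORE_NOISE are expected to be
+# bound by the context/configs/*.gin files (e.g. hiro_orig.gin plus an
+# environment config such as ant_maze.gin) loaded alongside this file.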
diff --git a/models/research/efficient-hrl/configs/eval_uvf.gin b/models/research/efficient-hrl/configs/eval_uvf.gin
new file mode 100644
index 0000000000000000000000000000000000000000..7a58241e06aa4a6140faa8a74b262729f1f5e4c1
--- /dev/null
+++ b/models/research/efficient-hrl/configs/eval_uvf.gin
@@ -0,0 +1,14 @@
+#-*-Python-*-
+# Config eval
+evaluate.environment = @create_maze_env()
+evaluate.agent_class = %AGENT_CLASS
+evaluate.meta_agent_class = %META_CLASS
+evaluate.state_preprocess_class = %STATE_PREPROCESS_CLASS
+evaluate.num_episodes_eval = 50
+evaluate.num_episodes_videos = 1
+evaluate.gamma = 1.0
+evaluate.eval_interval_secs = 1
+evaluate.generate_videos = False
+evaluate.generate_summaries = True
+evaluate.eval_modes = %EVAL_MODES
+evaluate.max_steps_per_episode = %RESET_EPISODE_PERIOD
diff --git a/models/research/efficient-hrl/configs/train_uvf.gin b/models/research/efficient-hrl/configs/train_uvf.gin
new file mode 100644
index 0000000000000000000000000000000000000000..8b02d7a6cb468f913ebaefe3fa8f519c3ad8fe4c
--- /dev/null
+++ b/models/research/efficient-hrl/configs/train_uvf.gin
@@ -0,0 +1,52 @@
+#-*-Python-*-
+# Create replay_buffer
+agent/CircularBuffer.buffer_size = 200000
+meta/CircularBuffer.buffer_size = 200000
+agent/CircularBuffer.scope = "agent"
+meta/CircularBuffer.scope = "meta"
+
+# Config train
+train_uvf.environment = @create_maze_env()
+train_uvf.agent_class = %AGENT_CLASS
+train_uvf.meta_agent_class = %META_CLASS
+train_uvf.state_preprocess_class = %STATE_PREPROCESS_CLASS
+train_uvf.inverse_dynamics_class = %INVERSE_DYNAMICS_CLASS
+train_uvf.replay_buffer = @agent/CircularBuffer()
+train_uvf.meta_replay_buffer = @meta/CircularBuffer()
+train_uvf.critic_optimizer = @critic/AdamOptimizer()
+train_uvf.actor_optimizer = @actor/AdamOptimizer()
+train_uvf.meta_critic_optimizer = @meta_critic/AdamOptimizer()
+train_uvf.meta_actor_optimizer = @meta_actor/AdamOptimizer()
+train_uvf.repr_optimizer = @repr/AdamOptimizer()
+train_uvf.num_episodes_train = 25000
+train_uvf.batch_size = 100
+train_uvf.initial_episodes = 5
+train_uvf.gamma = 0.99
+train_uvf.meta_gamma = 0.99
+train_uvf.reward_scale_factor = 1.0
+train_uvf.target_update_period = 2
+train_uvf.num_updates_per_observation = 1
+train_uvf.num_collect_per_update = 1
+train_uvf.num_collect_per_meta_update = 10
+train_uvf.debug_summaries = False
+train_uvf.log_every_n_steps = 1000
+train_uvf.save_policy_every_n_steps = 100000
+
+# Config Optimizers
+critic/AdamOptimizer.learning_rate = 0.001
+critic/AdamOptimizer.beta1 = 0.9
+critic/AdamOptimizer.beta2 = 0.999
+actor/AdamOptimizer.learning_rate = 0.0001
+actor/AdamOptimizer.beta1 = 0.9
+actor/AdamOptimizer.beta2 = 0.999
+
+meta_critic/AdamOptimizer.learning_rate = 0.001
+meta_critic/AdamOptimizer.beta1 = 0.9
+meta_critic/AdamOptimizer.beta2 = 0.999
+meta_actor/AdamOptimizer.learning_rate = 0.0001
+meta_actor/AdamOptimizer.beta1 = 0.9
+meta_actor/AdamOptimizer.beta2 = 0.999
+
+repr/AdamOptimizer.learning_rate = 0.0001
+repr/AdamOptimizer.beta1 = 0.9
+repr/AdamOptimizer.beta2 = 0.999
diff --git a/models/research/efficient-hrl/context/__init__.py b/models/research/efficient-hrl/context/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/models/research/efficient-hrl/context/__init__.py
@@ -0,0 +1 @@
+
diff --git a/models/research/efficient-hrl/context/configs/ant_block.gin b/models/research/efficient-hrl/context/configs/ant_block.gin
new file mode 100644
index 0000000000000000000000000000000000000000..d5bd4f01e015611127bbc188b3ac9af3df6a288a
--- /dev/null
+++ b/models/research/efficient-hrl/context/configs/ant_block.gin
@@ -0,0 +1,67 @@
+#-*-Python-*-
+create_maze_env.env_name = "AntBlock"
+ZERO_OBS = False
+context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
+meta_context_range = ((-4, -4), (20, 20))
+
+RESET_EPISODE_PERIOD = 500
+RESET_ENV_PERIOD = 1
+# End episode every N steps
+UvfAgent.reset_episode_cond_fn = @every_n_steps
+every_n_steps.n = %RESET_EPISODE_PERIOD
+train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
+# Do a manual reset every N episodes
+UvfAgent.reset_env_cond_fn = @every_n_episodes
+every_n_episodes.n = %RESET_ENV_PERIOD
+every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
+
+## Config defaults
+EVAL_MODES = ["eval1", "eval2", "eval3"]
+
+## Config agent
+CONTEXT = @agent/Context
+META_CONTEXT = @meta/Context
+
+## Config agent context
+agent/Context.context_ranges = [%context_range]
+agent/Context.context_shapes = [%SUBGOAL_DIM]
+agent/Context.meta_action_every_n = 10
+agent/Context.samplers = {
+ "train": [@train/DirectionSampler],
+ "explore": [@train/DirectionSampler],
+}
+
+agent/Context.context_transition_fn = @relative_context_transition_fn
+agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
+
+agent/Context.reward_fn = @uvf/negative_distance
+
+## Config meta context
+meta/Context.context_ranges = [%meta_context_range]
+meta/Context.context_shapes = [2]
+meta/Context.samplers = {
+ "train": [@train/RandomSampler],
+ "explore": [@train/RandomSampler],
+ "eval1": [@eval1/ConstantSampler],
+ "eval2": [@eval2/ConstantSampler],
+ "eval3": [@eval3/ConstantSampler],
+}
+meta/Context.reward_fn = @task/negative_distance
+
+## Config rewards
+task/negative_distance.state_indices = [3, 4]
+task/negative_distance.relative_context = False
+task/negative_distance.diff = False
+task/negative_distance.offset = 0.0
+
+## Config samplers
+train/RandomSampler.context_range = %meta_context_range
+train/DirectionSampler.context_range = %context_range
+train/DirectionSampler.k = %SUBGOAL_DIM
+relative_context_transition_fn.k = %SUBGOAL_DIM
+relative_context_multi_transition_fn.k = %SUBGOAL_DIM
+MetaAgent.k = %SUBGOAL_DIM
+
+eval1/ConstantSampler.value = [16, 0]
+eval2/ConstantSampler.value = [16, 16]
+eval3/ConstantSampler.value = [0, 16]
diff --git a/models/research/efficient-hrl/context/configs/ant_block_maze.gin b/models/research/efficient-hrl/context/configs/ant_block_maze.gin
new file mode 100644
index 0000000000000000000000000000000000000000..cebf775be129b51588092699fbeb314fb4f985d0
--- /dev/null
+++ b/models/research/efficient-hrl/context/configs/ant_block_maze.gin
@@ -0,0 +1,67 @@
+#-*-Python-*-
+create_maze_env.env_name = "AntBlockMaze"
+ZERO_OBS = False
+context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
+meta_context_range = ((-4, -4), (12, 20))
+
+RESET_EPISODE_PERIOD = 500
+RESET_ENV_PERIOD = 1
+# End episode every N steps
+UvfAgent.reset_episode_cond_fn = @every_n_steps
+every_n_steps.n = %RESET_EPISODE_PERIOD
+train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
+# Do a manual reset every N episodes
+UvfAgent.reset_env_cond_fn = @every_n_episodes
+every_n_episodes.n = %RESET_ENV_PERIOD
+every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
+
+## Config defaults
+EVAL_MODES = ["eval1", "eval2", "eval3"]
+
+## Config agent
+CONTEXT = @agent/Context
+META_CONTEXT = @meta/Context
+
+## Config agent context
+agent/Context.context_ranges = [%context_range]
+agent/Context.context_shapes = [%SUBGOAL_DIM]
+agent/Context.meta_action_every_n = 10
+agent/Context.samplers = {
+ "train": [@train/DirectionSampler],
+ "explore": [@train/DirectionSampler],
+}
+
+agent/Context.context_transition_fn = @relative_context_transition_fn
+agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
+
+agent/Context.reward_fn = @uvf/negative_distance
+
+## Config meta context
+meta/Context.context_ranges = [%meta_context_range]
+meta/Context.context_shapes = [2]
+meta/Context.samplers = {
+ "train": [@train/RandomSampler],
+ "explore": [@train/RandomSampler],
+ "eval1": [@eval1/ConstantSampler],
+ "eval2": [@eval2/ConstantSampler],
+ "eval3": [@eval3/ConstantSampler],
+}
+meta/Context.reward_fn = @task/negative_distance
+
+## Config rewards
+task/negative_distance.state_indices = [3, 4]
+task/negative_distance.relative_context = False
+task/negative_distance.diff = False
+task/negative_distance.offset = 0.0
+
+## Config samplers
+train/RandomSampler.context_range = %meta_context_range
+train/DirectionSampler.context_range = %context_range
+train/DirectionSampler.k = %SUBGOAL_DIM
+relative_context_transition_fn.k = %SUBGOAL_DIM
+relative_context_multi_transition_fn.k = %SUBGOAL_DIM
+MetaAgent.k = %SUBGOAL_DIM
+
+eval1/ConstantSampler.value = [8, 0]
+eval2/ConstantSampler.value = [8, 16]
+eval3/ConstantSampler.value = [0, 16]
diff --git a/models/research/efficient-hrl/context/configs/ant_fall_multi.gin b/models/research/efficient-hrl/context/configs/ant_fall_multi.gin
new file mode 100644
index 0000000000000000000000000000000000000000..eb89ad0cb164ddb4c0c08ba9649d7f2e5d7a9944
--- /dev/null
+++ b/models/research/efficient-hrl/context/configs/ant_fall_multi.gin
@@ -0,0 +1,62 @@
+#-*-Python-*-
+create_maze_env.env_name = "AntFall"
+context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
+meta_context_range = ((-4, -4, 0), (12, 28, 5))
+
+RESET_EPISODE_PERIOD = 500
+RESET_ENV_PERIOD = 1
+# End episode every N steps
+UvfAgent.reset_episode_cond_fn = @every_n_steps
+every_n_steps.n = %RESET_EPISODE_PERIOD
+train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
+# Do a manual reset every N episodes
+UvfAgent.reset_env_cond_fn = @every_n_episodes
+every_n_episodes.n = %RESET_ENV_PERIOD
+every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
+
+## Config defaults
+EVAL_MODES = ["eval1"]
+
+## Config agent
+CONTEXT = @agent/Context
+META_CONTEXT = @meta/Context
+
+## Config agent context
+agent/Context.context_ranges = [%context_range]
+agent/Context.context_shapes = [%SUBGOAL_DIM]
+agent/Context.meta_action_every_n = 10
+agent/Context.samplers = {
+ "train": [@train/DirectionSampler],
+ "explore": [@train/DirectionSampler],
+}
+
+agent/Context.context_transition_fn = @relative_context_transition_fn
+agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
+
+agent/Context.reward_fn = @uvf/negative_distance
+
+## Config meta context
+meta/Context.context_ranges = [%meta_context_range]
+meta/Context.context_shapes = [3]
+meta/Context.samplers = {
+ "train": [@train/RandomSampler],
+ "explore": [@train/RandomSampler],
+ "eval1": [@eval1/ConstantSampler],
+}
+meta/Context.reward_fn = @task/negative_distance
+
+## Config rewards
+task/negative_distance.state_indices = [0, 1, 2]
+task/negative_distance.relative_context = False
+task/negative_distance.diff = False
+task/negative_distance.offset = 0.0
+
+## Config samplers
+train/RandomSampler.context_range = %meta_context_range
+train/DirectionSampler.context_range = %context_range
+train/DirectionSampler.k = %SUBGOAL_DIM
+relative_context_transition_fn.k = %SUBGOAL_DIM
+relative_context_multi_transition_fn.k = %SUBGOAL_DIM
+MetaAgent.k = %SUBGOAL_DIM
+
+eval1/ConstantSampler.value = [0, 27, 4.5]
diff --git a/models/research/efficient-hrl/context/configs/ant_fall_multi_img.gin b/models/research/efficient-hrl/context/configs/ant_fall_multi_img.gin
new file mode 100644
index 0000000000000000000000000000000000000000..b54fb7c91961ab38febe597325c0d816a872be20
--- /dev/null
+++ b/models/research/efficient-hrl/context/configs/ant_fall_multi_img.gin
@@ -0,0 +1,68 @@
+#-*-Python-*-
+create_maze_env.env_name = "AntFall"
+IMAGES = True
+
+context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
+meta_context_range = ((-4, -4, 0), (12, 28, 5))
+
+RESET_EPISODE_PERIOD = 500
+RESET_ENV_PERIOD = 1
+# End episode every N steps
+UvfAgent.reset_episode_cond_fn = @every_n_steps
+every_n_steps.n = %RESET_EPISODE_PERIOD
+train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
+# Do a manual reset every N episodes
+UvfAgent.reset_env_cond_fn = @every_n_episodes
+every_n_episodes.n = %RESET_ENV_PERIOD
+every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
+
+## Config defaults
+EVAL_MODES = ["eval1"]
+
+## Config agent
+CONTEXT = @agent/Context
+META_CONTEXT = @meta/Context
+
+## Config agent context
+agent/Context.context_ranges = [%context_range]
+agent/Context.context_shapes = [%SUBGOAL_DIM]
+agent/Context.meta_action_every_n = 10
+agent/Context.samplers = {
+ "train": [@train/DirectionSampler],
+ "explore": [@train/DirectionSampler],
+}
+
+agent/Context.context_transition_fn = @relative_context_transition_fn
+agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
+
+agent/Context.reward_fn = @uvf/negative_distance
+
+## Config meta context
+meta/Context.context_ranges = [%meta_context_range]
+meta/Context.context_shapes = [3]
+meta/Context.samplers = {
+ "train": [@train/RandomSampler],
+ "explore": [@train/RandomSampler],
+ "eval1": [@eval1/ConstantSampler],
+}
+meta/Context.context_transition_fn = @task/relative_context_transition_fn
+meta/Context.context_multi_transition_fn = @task/relative_context_multi_transition_fn
+meta/Context.reward_fn = @task/negative_distance
+
+## Config rewards
+task/negative_distance.state_indices = [0, 1, 2]
+task/negative_distance.relative_context = True
+task/negative_distance.diff = False
+task/negative_distance.offset = 0.0
+
+## Config samplers
+train/RandomSampler.context_range = %meta_context_range
+train/DirectionSampler.context_range = %context_range
+train/DirectionSampler.k = %SUBGOAL_DIM
+relative_context_transition_fn.k = %SUBGOAL_DIM
+relative_context_multi_transition_fn.k = %SUBGOAL_DIM
+task/relative_context_transition_fn.k = 3
+task/relative_context_multi_transition_fn.k = 3
+MetaAgent.k = %SUBGOAL_DIM
+
+eval1/ConstantSampler.value = [0, 27, 0]
diff --git a/models/research/efficient-hrl/context/configs/ant_fall_single.gin b/models/research/efficient-hrl/context/configs/ant_fall_single.gin
new file mode 100644
index 0000000000000000000000000000000000000000..56bbc070072182cbcda7580c87cf65a593e8a743
--- /dev/null
+++ b/models/research/efficient-hrl/context/configs/ant_fall_single.gin
@@ -0,0 +1,62 @@
+#-*-Python-*-
+create_maze_env.env_name = "AntFall"
+context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
+meta_context_range = ((-4, -4, 0), (12, 28, 5))
+
+RESET_EPISODE_PERIOD = 500
+RESET_ENV_PERIOD = 1
+# End episode every N steps
+UvfAgent.reset_episode_cond_fn = @every_n_steps
+every_n_steps.n = %RESET_EPISODE_PERIOD
+train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
+# Do a manual reset every N episodes
+UvfAgent.reset_env_cond_fn = @every_n_episodes
+every_n_episodes.n = %RESET_ENV_PERIOD
+every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
+
+## Config defaults
+EVAL_MODES = ["eval1"]
+
+## Config agent
+CONTEXT = @agent/Context
+META_CONTEXT = @meta/Context
+
+## Config agent context
+agent/Context.context_ranges = [%context_range]
+agent/Context.context_shapes = [%SUBGOAL_DIM]
+agent/Context.meta_action_every_n = 10
+agent/Context.samplers = {
+ "train": [@train/DirectionSampler],
+ "explore": [@train/DirectionSampler],
+}
+
+agent/Context.context_transition_fn = @relative_context_transition_fn
+agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
+
+agent/Context.reward_fn = @uvf/negative_distance
+
+## Config meta context
+meta/Context.context_ranges = [%meta_context_range]
+meta/Context.context_shapes = [3]
+meta/Context.samplers = {
+ "train": [@eval1/ConstantSampler],
+ "explore": [@eval1/ConstantSampler],
+ "eval1": [@eval1/ConstantSampler],
+}
+meta/Context.reward_fn = @task/negative_distance
+
+## Config rewards
+task/negative_distance.state_indices = [0, 1, 2]
+task/negative_distance.relative_context = False
+task/negative_distance.diff = False
+task/negative_distance.offset = 0.0
+
+## Config samplers
+train/RandomSampler.context_range = %meta_context_range
+train/DirectionSampler.context_range = %context_range
+train/DirectionSampler.k = %SUBGOAL_DIM
+relative_context_transition_fn.k = %SUBGOAL_DIM
+relative_context_multi_transition_fn.k = %SUBGOAL_DIM
+MetaAgent.k = %SUBGOAL_DIM
+
+eval1/ConstantSampler.value = [0, 27, 4.5]
diff --git a/models/research/efficient-hrl/context/configs/ant_maze.gin b/models/research/efficient-hrl/context/configs/ant_maze.gin
new file mode 100644
index 0000000000000000000000000000000000000000..3a0b73e30d7054dc6d573669b7df728cff93226a
--- /dev/null
+++ b/models/research/efficient-hrl/context/configs/ant_maze.gin
@@ -0,0 +1,66 @@
+#-*-Python-*-
+create_maze_env.env_name = "AntMaze"
+context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
+meta_context_range = ((-4, -4), (20, 20))
+
+RESET_EPISODE_PERIOD = 500
+RESET_ENV_PERIOD = 1
+# End episode every N steps
+UvfAgent.reset_episode_cond_fn = @every_n_steps
+every_n_steps.n = %RESET_EPISODE_PERIOD
+train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
+# Do a manual reset every N episodes
+UvfAgent.reset_env_cond_fn = @every_n_episodes
+every_n_episodes.n = %RESET_ENV_PERIOD
+every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
+
+## Config defaults
+EVAL_MODES = ["eval1", "eval2", "eval3"]
+
+## Config agent
+CONTEXT = @agent/Context
+META_CONTEXT = @meta/Context
+
+## Config agent context
+agent/Context.context_ranges = [%context_range]
+agent/Context.context_shapes = [%SUBGOAL_DIM]
+agent/Context.meta_action_every_n = 10
+agent/Context.samplers = {
+ "train": [@train/DirectionSampler],
+ "explore": [@train/DirectionSampler],
+}
+
+agent/Context.context_transition_fn = @relative_context_transition_fn
+agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
+
+agent/Context.reward_fn = @uvf/negative_distance
+
+## Config meta context
+meta/Context.context_ranges = [%meta_context_range]
+meta/Context.context_shapes = [2]
+meta/Context.samplers = {
+ "train": [@train/RandomSampler],
+ "explore": [@train/RandomSampler],
+ "eval1": [@eval1/ConstantSampler],
+ "eval2": [@eval2/ConstantSampler],
+ "eval3": [@eval3/ConstantSampler],
+}
+meta/Context.reward_fn = @task/negative_distance
+
+## Config rewards
+task/negative_distance.state_indices = [0, 1]
+task/negative_distance.relative_context = False
+task/negative_distance.diff = False
+task/negative_distance.offset = 0.0
+
+## Config samplers
+train/RandomSampler.context_range = %meta_context_range
+train/DirectionSampler.context_range = %context_range
+train/DirectionSampler.k = %SUBGOAL_DIM
+relative_context_transition_fn.k = %SUBGOAL_DIM
+relative_context_multi_transition_fn.k = %SUBGOAL_DIM
+MetaAgent.k = %SUBGOAL_DIM
+
+eval1/ConstantSampler.value = [16, 0]
+eval2/ConstantSampler.value = [16, 16]
+eval3/ConstantSampler.value = [0, 16]
diff --git a/models/research/efficient-hrl/context/configs/ant_maze_img.gin b/models/research/efficient-hrl/context/configs/ant_maze_img.gin
new file mode 100644
index 0000000000000000000000000000000000000000..ceed65a0884587d9cd64cdf162bf1b7e3495469d
--- /dev/null
+++ b/models/research/efficient-hrl/context/configs/ant_maze_img.gin
@@ -0,0 +1,72 @@
+#-*-Python-*-
+create_maze_env.env_name = "AntMaze"
+IMAGES = True
+
+context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
+meta_context_range = ((-4, -4), (20, 20))
+
+RESET_EPISODE_PERIOD = 500
+RESET_ENV_PERIOD = 1
+# End episode every N steps
+UvfAgent.reset_episode_cond_fn = @every_n_steps
+every_n_steps.n = %RESET_EPISODE_PERIOD
+train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
+# Do a manual reset every N episodes
+UvfAgent.reset_env_cond_fn = @every_n_episodes
+every_n_episodes.n = %RESET_ENV_PERIOD
+every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
+
+## Config defaults
+EVAL_MODES = ["eval1", "eval2", "eval3"]
+
+## Config agent
+CONTEXT = @agent/Context
+META_CONTEXT = @meta/Context
+
+## Config agent context
+agent/Context.context_ranges = [%context_range]
+agent/Context.context_shapes = [%SUBGOAL_DIM]
+agent/Context.meta_action_every_n = 10
+agent/Context.samplers = {
+ "train": [@train/DirectionSampler],
+ "explore": [@train/DirectionSampler],
+}
+
+agent/Context.context_transition_fn = @relative_context_transition_fn
+agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
+
+agent/Context.reward_fn = @uvf/negative_distance
+
+## Config meta context
+meta/Context.context_ranges = [%meta_context_range]
+meta/Context.context_shapes = [2]
+meta/Context.samplers = {
+ "train": [@train/RandomSampler],
+ "explore": [@train/RandomSampler],
+ "eval1": [@eval1/ConstantSampler],
+ "eval2": [@eval2/ConstantSampler],
+ "eval3": [@eval3/ConstantSampler],
+}
+meta/Context.context_transition_fn = @task/relative_context_transition_fn
+meta/Context.context_multi_transition_fn = @task/relative_context_multi_transition_fn
+meta/Context.reward_fn = @task/negative_distance
+
+## Config rewards
+task/negative_distance.state_indices = [0, 1]
+task/negative_distance.relative_context = True
+task/negative_distance.diff = False
+task/negative_distance.offset = 0.0
+
+## Config samplers
+train/RandomSampler.context_range = %meta_context_range
+train/DirectionSampler.context_range = %context_range
+train/DirectionSampler.k = %SUBGOAL_DIM
+relative_context_transition_fn.k = %SUBGOAL_DIM
+relative_context_multi_transition_fn.k = %SUBGOAL_DIM
+task/relative_context_transition_fn.k = 2
+task/relative_context_multi_transition_fn.k = 2
+MetaAgent.k = %SUBGOAL_DIM
+
+eval1/ConstantSampler.value = [16, 0]
+eval2/ConstantSampler.value = [16, 16]
+eval3/ConstantSampler.value = [0, 16]
diff --git a/models/research/efficient-hrl/context/configs/ant_push_multi.gin b/models/research/efficient-hrl/context/configs/ant_push_multi.gin
new file mode 100644
index 0000000000000000000000000000000000000000..db9b4ed7bbe81fe38c9fbad10a43dde485a06802
--- /dev/null
+++ b/models/research/efficient-hrl/context/configs/ant_push_multi.gin
@@ -0,0 +1,62 @@
+#-*-Python-*-
+create_maze_env.env_name = "AntPush"
+context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
+meta_context_range = ((-16, -4), (16, 20))
+
+RESET_EPISODE_PERIOD = 500
+RESET_ENV_PERIOD = 1
+# End episode every N steps
+UvfAgent.reset_episode_cond_fn = @every_n_steps
+every_n_steps.n = %RESET_EPISODE_PERIOD
+train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
+# Do a manual reset every N episodes
+UvfAgent.reset_env_cond_fn = @every_n_episodes
+every_n_episodes.n = %RESET_ENV_PERIOD
+every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
+
+## Config defaults
+EVAL_MODES = ["eval2"]
+
+## Config agent
+CONTEXT = @agent/Context
+META_CONTEXT = @meta/Context
+
+## Config agent context
+agent/Context.context_ranges = [%context_range]
+agent/Context.context_shapes = [%SUBGOAL_DIM]
+agent/Context.meta_action_every_n = 10
+agent/Context.samplers = {
+ "train": [@train/DirectionSampler],
+ "explore": [@train/DirectionSampler],
+}
+
+agent/Context.context_transition_fn = @relative_context_transition_fn
+agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
+
+agent/Context.reward_fn = @uvf/negative_distance
+
+## Config meta context
+meta/Context.context_ranges = [%meta_context_range]
+meta/Context.context_shapes = [2]
+meta/Context.samplers = {
+ "train": [@train/RandomSampler],
+ "explore": [@train/RandomSampler],
+ "eval2": [@eval2/ConstantSampler],
+}
+meta/Context.reward_fn = @task/negative_distance
+
+## Config rewards
+task/negative_distance.state_indices = [0, 1]
+task/negative_distance.relative_context = False
+task/negative_distance.diff = False
+task/negative_distance.offset = 0.0
+
+## Config samplers
+train/RandomSampler.context_range = %meta_context_range
+train/DirectionSampler.context_range = %context_range
+train/DirectionSampler.k = %SUBGOAL_DIM
+relative_context_transition_fn.k = %SUBGOAL_DIM
+relative_context_multi_transition_fn.k = %SUBGOAL_DIM
+MetaAgent.k = %SUBGOAL_DIM
+
+eval2/ConstantSampler.value = [0, 19]
diff --git a/models/research/efficient-hrl/context/configs/ant_push_multi_img.gin b/models/research/efficient-hrl/context/configs/ant_push_multi_img.gin
new file mode 100644
index 0000000000000000000000000000000000000000..abdc43402fca8a3e83438655bec26c06b8dfccbe
--- /dev/null
+++ b/models/research/efficient-hrl/context/configs/ant_push_multi_img.gin
@@ -0,0 +1,68 @@
+#-*-Python-*-
+create_maze_env.env_name = "AntPush"
+IMAGES = True
+
+context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
+meta_context_range = ((-16, -4), (16, 20))
+
+RESET_EPISODE_PERIOD = 500
+RESET_ENV_PERIOD = 1
+# End episode every N steps
+UvfAgent.reset_episode_cond_fn = @every_n_steps
+every_n_steps.n = %RESET_EPISODE_PERIOD
+train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
+# Do a manual reset every N episodes
+UvfAgent.reset_env_cond_fn = @every_n_episodes
+every_n_episodes.n = %RESET_ENV_PERIOD
+every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
+
+## Config defaults
+EVAL_MODES = ["eval2"]
+
+## Config agent
+CONTEXT = @agent/Context
+META_CONTEXT = @meta/Context
+
+## Config agent context
+agent/Context.context_ranges = [%context_range]
+agent/Context.context_shapes = [%SUBGOAL_DIM]
+agent/Context.meta_action_every_n = 10
+agent/Context.samplers = {
+ "train": [@train/DirectionSampler],
+ "explore": [@train/DirectionSampler],
+}
+
+agent/Context.context_transition_fn = @relative_context_transition_fn
+agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
+
+agent/Context.reward_fn = @uvf/negative_distance
+
+## Config meta context
+meta/Context.context_ranges = [%meta_context_range]
+meta/Context.context_shapes = [2]
+meta/Context.samplers = {
+ "train": [@train/RandomSampler],
+ "explore": [@train/RandomSampler],
+ "eval2": [@eval2/ConstantSampler],
+}
+meta/Context.context_transition_fn = @task/relative_context_transition_fn
+meta/Context.context_multi_transition_fn = @task/relative_context_multi_transition_fn
+meta/Context.reward_fn = @task/negative_distance
+
+## Config rewards
+task/negative_distance.state_indices = [0, 1]
+task/negative_distance.relative_context = True
+task/negative_distance.diff = False
+task/negative_distance.offset = 0.0
+
+## Config samplers
+train/RandomSampler.context_range = %meta_context_range
+train/DirectionSampler.context_range = %context_range
+train/DirectionSampler.k = %SUBGOAL_DIM
+relative_context_transition_fn.k = %SUBGOAL_DIM
+relative_context_multi_transition_fn.k = %SUBGOAL_DIM
+task/relative_context_transition_fn.k = 2
+task/relative_context_multi_transition_fn.k = 2
+MetaAgent.k = %SUBGOAL_DIM
+
+eval2/ConstantSampler.value = [0, 19]
diff --git a/models/research/efficient-hrl/context/configs/ant_push_single.gin b/models/research/efficient-hrl/context/configs/ant_push_single.gin
new file mode 100644
index 0000000000000000000000000000000000000000..e85c5dfba4d04668cc5407c89aa42ca2044e12fd
--- /dev/null
+++ b/models/research/efficient-hrl/context/configs/ant_push_single.gin
@@ -0,0 +1,62 @@
+#-*-Python-*-
+create_maze_env.env_name = "AntPush"
+context_range = (%CONTEXT_RANGE_MIN, %CONTEXT_RANGE_MAX)
+meta_context_range = ((-16, -4), (16, 20))
+
+RESET_EPISODE_PERIOD = 500
+RESET_ENV_PERIOD = 1
+# End episode every N steps
+UvfAgent.reset_episode_cond_fn = @every_n_steps
+every_n_steps.n = %RESET_EPISODE_PERIOD
+train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
+# Do a manual reset every N episodes
+UvfAgent.reset_env_cond_fn = @every_n_episodes
+every_n_episodes.n = %RESET_ENV_PERIOD
+every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
+
+## Config defaults
+EVAL_MODES = ["eval2"]
+
+## Config agent
+CONTEXT = @agent/Context
+META_CONTEXT = @meta/Context
+
+## Config agent context
+agent/Context.context_ranges = [%context_range]
+agent/Context.context_shapes = [%SUBGOAL_DIM]
+agent/Context.meta_action_every_n = 10
+agent/Context.samplers = {
+ "train": [@train/DirectionSampler],
+ "explore": [@train/DirectionSampler],
+}
+
+agent/Context.context_transition_fn = @relative_context_transition_fn
+agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
+
+agent/Context.reward_fn = @uvf/negative_distance
+
+## Config meta context
+meta/Context.context_ranges = [%meta_context_range]
+meta/Context.context_shapes = [2]
+meta/Context.samplers = {
+ "train": [@eval2/ConstantSampler],
+ "explore": [@eval2/ConstantSampler],
+ "eval2": [@eval2/ConstantSampler],
+}
+meta/Context.reward_fn = @task/negative_distance
+
+## Config rewards
+task/negative_distance.state_indices = [0, 1]
+task/negative_distance.relative_context = False
+task/negative_distance.diff = False
+task/negative_distance.offset = 0.0
+
+## Config samplers
+train/RandomSampler.context_range = %meta_context_range
+train/DirectionSampler.context_range = %context_range
+train/DirectionSampler.k = %SUBGOAL_DIM
+relative_context_transition_fn.k = %SUBGOAL_DIM
+relative_context_multi_transition_fn.k = %SUBGOAL_DIM
+MetaAgent.k = %SUBGOAL_DIM
+
+eval2/ConstantSampler.value = [0, 19]
diff --git a/models/research/efficient-hrl/context/configs/default.gin b/models/research/efficient-hrl/context/configs/default.gin
new file mode 100644
index 0000000000000000000000000000000000000000..65f91e5292db86b62a9275fcc8929d46d779a677
--- /dev/null
+++ b/models/research/efficient-hrl/context/configs/default.gin
@@ -0,0 +1,12 @@
+#-*-Python-*-
+ENV_CONTEXT = None
+EVAL_MODES = ["eval"]
+TARGET_Q_CLIPPING = None
+RESET_EPISODE_PERIOD = None
+ZERO_OBS = False
+CONTEXT_RANGE_MIN = -10
+CONTEXT_RANGE_MAX = 10
+SUBGOAL_DIM = 2
+
+uvf/negative_distance.summarize = False
+uvf/negative_distance.relative_context = True
diff --git a/models/research/efficient-hrl/context/configs/hiro_orig.gin b/models/research/efficient-hrl/context/configs/hiro_orig.gin
new file mode 100644
index 0000000000000000000000000000000000000000..e39ba96be7b7d323ecf1a849dcea89d1468e87c3
--- /dev/null
+++ b/models/research/efficient-hrl/context/configs/hiro_orig.gin
@@ -0,0 +1,14 @@
+#-*-Python-*-
+ENV_CONTEXT = None
+EVAL_MODES = ["eval"]
+TARGET_Q_CLIPPING = None
+RESET_EPISODE_PERIOD = None
+ZERO_OBS = True
+IMAGES = False
+CONTEXT_RANGE_MIN = (-10, -10, -0.5, -1, -1, -1, -1, -0.5, -0.3, -0.5, -0.3, -0.5, -0.3, -0.5, -0.3)
+CONTEXT_RANGE_MAX = ( 10, 10, 0.5, 1, 1, 1, 1, 0.5, 0.3, 0.5, 0.3, 0.5, 0.3, 0.5, 0.3)
+SUBGOAL_DIM = 15
+META_EXPLORE_NOISE = 1.0
+
+uvf/negative_distance.summarize = False
+uvf/negative_distance.relative_context = True
diff --git a/models/research/efficient-hrl/context/configs/hiro_repr.gin b/models/research/efficient-hrl/context/configs/hiro_repr.gin
new file mode 100644
index 0000000000000000000000000000000000000000..a0a8057bd3cc834e5c1be33e73a9b7c6ae370a99
--- /dev/null
+++ b/models/research/efficient-hrl/context/configs/hiro_repr.gin
@@ -0,0 +1,18 @@
+#-*-Python-*-
+ENV_CONTEXT = None
+EVAL_MODES = ["eval"]
+TARGET_Q_CLIPPING = None
+RESET_EPISODE_PERIOD = None
+ZERO_OBS = False
+IMAGES = False
+CONTEXT_RANGE_MIN = -10
+CONTEXT_RANGE_MAX = 10
+SUBGOAL_DIM = 2
+META_EXPLORE_NOISE = 5.0
+
+StatePreprocess.trainable = True
+StatePreprocess.state_preprocess_net = @state_preprocess_net
+StatePreprocess.action_embed_net = @action_embed_net
+
+uvf/negative_distance.summarize = False
+uvf/negative_distance.relative_context = True
diff --git a/models/research/efficient-hrl/context/configs/hiro_xy.gin b/models/research/efficient-hrl/context/configs/hiro_xy.gin
new file mode 100644
index 0000000000000000000000000000000000000000..f35026c9e24246a87e69e16031981a13e22c283d
--- /dev/null
+++ b/models/research/efficient-hrl/context/configs/hiro_xy.gin
@@ -0,0 +1,14 @@
+#-*-Python-*-
+ENV_CONTEXT = None
+EVAL_MODES = ["eval"]
+TARGET_Q_CLIPPING = None
+RESET_EPISODE_PERIOD = None
+ZERO_OBS = False
+IMAGES = False
+CONTEXT_RANGE_MIN = -10
+CONTEXT_RANGE_MAX = 10
+SUBGOAL_DIM = 2
+META_EXPLORE_NOISE = 1.0
+
+uvf/negative_distance.summarize = False
+uvf/negative_distance.relative_context = True
diff --git a/models/research/efficient-hrl/context/configs/point_maze.gin b/models/research/efficient-hrl/context/configs/point_maze.gin
new file mode 100644
index 0000000000000000000000000000000000000000..0ea67d2d5fffedfdbc7b9df443acc3adaf98ec99
--- /dev/null
+++ b/models/research/efficient-hrl/context/configs/point_maze.gin
@@ -0,0 +1,73 @@
+#-*-Python-*-
+# NOTE: For best training, low-level exploration (uvf_add_noise_fn.stddev)
+# should be reduced to around 0.1.
+create_maze_env.env_name = "PointMaze"
+context_range_min = -10
+context_range_max = 10
+context_range = (%context_range_min, %context_range_max)
+meta_context_range = ((-2, -2), (10, 10))
+
+RESET_EPISODE_PERIOD = 500
+RESET_ENV_PERIOD = 1
+# End episode every N steps
+UvfAgent.reset_episode_cond_fn = @every_n_steps
+every_n_steps.n = %RESET_EPISODE_PERIOD
+train_uvf.max_steps_per_episode = %RESET_EPISODE_PERIOD
+# Do a manual reset every N episodes
+UvfAgent.reset_env_cond_fn = @every_n_episodes
+every_n_episodes.n = %RESET_ENV_PERIOD
+every_n_episodes.steps_per_episode = %RESET_EPISODE_PERIOD
+
+## Config defaults
+EVAL_MODES = ["eval1", "eval2", "eval3"]
+
+## Config agent
+CONTEXT = @agent/Context
+META_CONTEXT = @meta/Context
+
+## Config agent context
+agent/Context.context_ranges = [%context_range]
+agent/Context.context_shapes = [%SUBGOAL_DIM]
+agent/Context.meta_action_every_n = 10
+agent/Context.samplers = {
+ "train": [@train/DirectionSampler],
+ "explore": [@train/DirectionSampler],
+ "eval1": [@uvf_eval1/ConstantSampler],
+ "eval2": [@uvf_eval2/ConstantSampler],
+ "eval3": [@uvf_eval3/ConstantSampler],
+}
+
+agent/Context.context_transition_fn = @relative_context_transition_fn
+agent/Context.context_multi_transition_fn = @relative_context_multi_transition_fn
+
+agent/Context.reward_fn = @uvf/negative_distance
+
+## Config meta context
+meta/Context.context_ranges = [%meta_context_range]
+meta/Context.context_shapes = [2]
+meta/Context.samplers = {
+ "train": [@train/RandomSampler],
+ "explore": [@train/RandomSampler],
+ "eval1": [@eval1/ConstantSampler],
+ "eval2": [@eval2/ConstantSampler],
+ "eval3": [@eval3/ConstantSampler],
+}
+meta/Context.reward_fn = @task/negative_distance
+
+## Config rewards
+task/negative_distance.state_indices = [0, 1]
+task/negative_distance.relative_context = False
+task/negative_distance.diff = False
+task/negative_distance.offset = 0.0
+
+## Config samplers
+train/RandomSampler.context_range = %meta_context_range
+train/DirectionSampler.context_range = %context_range
+train/DirectionSampler.k = %SUBGOAL_DIM
+relative_context_transition_fn.k = %SUBGOAL_DIM
+relative_context_multi_transition_fn.k = %SUBGOAL_DIM
+MetaAgent.k = %SUBGOAL_DIM
+
+eval1/ConstantSampler.value = [8, 0]
+eval2/ConstantSampler.value = [8, 8]
+eval3/ConstantSampler.value = [0, 8]
diff --git a/models/research/efficient-hrl/context/context.py b/models/research/efficient-hrl/context/context.py
new file mode 100644
index 0000000000000000000000000000000000000000..76be00b4966539b869e714225fbba124b9602c3a
--- /dev/null
+++ b/models/research/efficient-hrl/context/context.py
@@ -0,0 +1,467 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Context for Universal Value Function agents.
+
+A context specifies a list of contextual variables, each with its own
+sampling and reward computation methods.
+
+Examples of contextual variables include
+ goal states, reward combination vectors, etc.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+import tensorflow as tf
+from tf_agents import specs
+import gin.tf
+from utils import utils as uvf_utils
+
+
+@gin.configurable
+class Context(object):
+ """Base context."""
+ VAR_NAME = 'action'
+
+ def __init__(self,
+ tf_env,
+ context_ranges=None,
+ context_shapes=None,
+ state_indices=None,
+ variable_indices=None,
+ gamma_index=None,
+ settable_context=False,
+ timers=None,
+ samplers=None,
+ reward_weights=None,
+ reward_fn=None,
+ random_sampler_mode='random',
+ normalizers=None,
+ context_transition_fn=None,
+ context_multi_transition_fn=None,
+ meta_action_every_n=None):
+ self._tf_env = tf_env
+ self.variable_indices = variable_indices
+ self.gamma_index = gamma_index
+ self._settable_context = settable_context
+ self.timers = timers
+ self._context_transition_fn = context_transition_fn
+ self._context_multi_transition_fn = context_multi_transition_fn
+ self._random_sampler_mode = random_sampler_mode
+
+ # assign specs
+ self._obs_spec = self._tf_env.observation_spec()
+ self._context_shapes = tuple([
+ shape if shape is not None else self._obs_spec.shape
+ for shape in context_shapes
+ ])
+ self.context_specs = tuple([
+ specs.TensorSpec(dtype=self._obs_spec.dtype, shape=shape)
+ for shape in self._context_shapes
+ ])
+ if context_ranges is not None:
+ self.context_ranges = context_ranges
+ else:
+ self.context_ranges = [None] * len(self._context_shapes)
+
+ self.context_as_action_specs = tuple([
+ specs.BoundedTensorSpec(
+ shape=shape,
+ dtype=(tf.float32 if self._obs_spec.dtype in
+ [tf.float32, tf.float64] else self._obs_spec.dtype),
+ minimum=context_range[0],
+ maximum=context_range[-1])
+ for shape, context_range in zip(self._context_shapes, self.context_ranges)
+ ])
+
+ if state_indices is not None:
+ self.state_indices = state_indices
+ else:
+ self.state_indices = [None] * len(self._context_shapes)
+ if self.variable_indices is not None and self.n != len(
+ self.variable_indices):
+ raise ValueError(
+ 'variable_indices (%s) must have the same length as contexts (%s).' %
+ (self.variable_indices, self.context_specs))
+ assert self.n == len(self.context_ranges)
+ assert self.n == len(self.state_indices)
+
+ # assign reward/sampler fns
+ self._sampler_fns = dict()
+ self._samplers = dict()
+ self._reward_fns = dict()
+
+ # assign reward fns
+ self._add_custom_reward_fns()
+ reward_weights = reward_weights or None
+ self._reward_fn = self._make_reward_fn(reward_fn, reward_weights)
+
+ # assign samplers
+ self._add_custom_sampler_fns()
+ for mode, sampler_fns in samplers.items():
+ self._make_sampler_fn(sampler_fns, mode)
+
+ # create normalizers
+ if normalizers is None:
+ self._normalizers = [None] * len(self.context_specs)
+ else:
+ self._normalizers = [
+ normalizer(tf.zeros(shape=spec.shape, dtype=spec.dtype))
+ if normalizer is not None else None
+ for normalizer, spec in zip(normalizers, self.context_specs)
+ ]
+ assert self.n == len(self._normalizers)
+
+ self.meta_action_every_n = meta_action_every_n
+
+ # create vars
+ self.context_vars = {}
+ self.timer_vars = {}
+ self.create_vars(self.VAR_NAME)
+ self.t = tf.Variable(
+ tf.zeros(shape=(), dtype=tf.int32), name='num_timer_steps')
+
+ def _add_custom_reward_fns(self):
+ pass
+
+ def _add_custom_sampler_fns(self):
+ pass
+
+ def sample_random_contexts(self, batch_size):
+ """Sample random batch contexts."""
+ assert self._random_sampler_mode is not None
+ return self.sample_contexts(self._random_sampler_mode, batch_size)[0]
+
+ def sample_contexts(self, mode, batch_size, state=None, next_state=None,
+ **kwargs):
+ """Sample a batch of contexts.
+
+ Args:
+ mode: A string representing the mode [`train`, `explore`, `eval`].
+      batch_size: Batch size.
+      state: Optional [batch_size, num_state_dims] tensor of current states.
+      next_state: Optional [batch_size, num_state_dims] tensor of next states.
+      **kwargs: Additional keyword arguments forwarded to the sampler fns.
+ Returns:
+ Two lists of [batch_size, num_context_dims] contexts.
+ """
+ contexts, next_contexts = self._sampler_fns[mode](
+ batch_size, state=state, next_state=next_state,
+ **kwargs)
+ self._validate_contexts(contexts)
+ self._validate_contexts(next_contexts)
+ return contexts, next_contexts
+
+ def compute_rewards(self, mode, states, actions, rewards, next_states,
+ contexts):
+ """Compute context-based rewards.
+
+ Args:
+ mode: A string representing the mode ['uvf', 'task'].
+ states: A [batch_size, num_state_dims] tensor.
+ actions: A [batch_size, num_action_dims] tensor.
+ rewards: A [batch_size] tensor representing unmodified rewards.
+ next_states: A [batch_size, num_state_dims] tensor.
+ contexts: A list of [batch_size, num_context_dims] tensors.
+ Returns:
+ A [batch_size] tensor representing rewards.
+ """
+ return self._reward_fn(states, actions, rewards, next_states,
+ contexts)
+
+ def _make_reward_fn(self, reward_fns_list, reward_weights):
+ """Returns a fn that computes rewards.
+
+ Args:
+ reward_fns_list: A fn or a list of reward fns.
+ reward_weights: A list of reward weights.
+ """
+ if not isinstance(reward_fns_list, (list, tuple)):
+ reward_fns_list = [reward_fns_list]
+ if reward_weights is None:
+ reward_weights = [1.0] * len(reward_fns_list)
+ assert len(reward_fns_list) == len(reward_weights)
+
+ reward_fns_list = [
+ self._custom_reward_fns[fn] if isinstance(fn, (str,)) else fn
+ for fn in reward_fns_list
+ ]
+
+ def reward_fn(*args, **kwargs):
+ """Returns rewards, discounts."""
+ reward_tuples = [
+ reward_fn(*args, **kwargs) for reward_fn in reward_fns_list
+ ]
+ rewards_list = [reward_tuple[0] for reward_tuple in reward_tuples]
+ discounts_list = [reward_tuple[1] for reward_tuple in reward_tuples]
+ ndims = max([r.shape.ndims for r in rewards_list])
+ if ndims > 1: # expand reward shapes to allow broadcasting
+ for i in range(len(rewards_list)):
+          for _ in range(ndims - rewards_list[i].shape.ndims):
+            rewards_list[i] = tf.expand_dims(rewards_list[i], axis=-1)
+          for _ in range(ndims - discounts_list[i].shape.ndims):
+            discounts_list[i] = tf.expand_dims(discounts_list[i], axis=-1)
+ rewards = tf.add_n(
+ [r * tf.to_float(w) for r, w in zip(rewards_list, reward_weights)])
+ discounts = discounts_list[0]
+ for d in discounts_list[1:]:
+ discounts *= d
+
+ return rewards, discounts
+
+ return reward_fn
+
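+  # Example (illustrative): _make_reward_fn([fn_a, fn_b], [1.0, 0.1]) builds a
+  # fn returning (r_a + 0.1 * r_b, d_a * d_b), i.e. weighted-sum rewards and
+  # multiplied discounts.
+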
+ def _make_sampler_fn(self, sampler_cls_list, mode):
+ """Returns a fn that samples a list of context vars.
+
+ Args:
+ sampler_cls_list: A list of sampler classes.
+ mode: A string representing the operating mode.
+ """
+ if not isinstance(sampler_cls_list, (list, tuple)):
+ sampler_cls_list = [sampler_cls_list]
+
+ self._samplers[mode] = []
+ sampler_fns = []
+ for spec, sampler in zip(self.context_specs, sampler_cls_list):
+ if isinstance(sampler, (str,)):
+ sampler_fn = self._custom_sampler_fns[sampler]
+ else:
+ sampler_fn = sampler(context_spec=spec)
+ self._samplers[mode].append(sampler_fn)
+ sampler_fns.append(sampler_fn)
+
+ def batch_sampler_fn(batch_size, state=None, next_state=None, **kwargs):
+ """Sampler fn."""
+ contexts_tuples = [
+ sampler(batch_size, state=state, next_state=next_state, **kwargs)
+ for sampler in sampler_fns]
+ contexts = [c[0] for c in contexts_tuples]
+ next_contexts = [c[1] for c in contexts_tuples]
+ contexts = [
+ normalizer.update_apply(c) if normalizer is not None else c
+ for normalizer, c in zip(self._normalizers, contexts)
+ ]
+ next_contexts = [
+ normalizer.apply(c) if normalizer is not None else c
+ for normalizer, c in zip(self._normalizers, next_contexts)
+ ]
+ return contexts, next_contexts
+
+ self._sampler_fns[mode] = batch_sampler_fn
+
+ def set_env_context_op(self, context, disable_unnormalizer=False):
+ """Returns a TensorFlow op that sets the environment context.
+
+ Args:
+ context: A list of context Tensor variables.
+ disable_unnormalizer: Disable unnormalization.
+ Returns:
+ A TensorFlow op that sets the environment context.
+ """
+ ret_val = np.array(1.0, dtype=np.float32)
+ if not self._settable_context:
+ return tf.identity(ret_val)
+
+ if not disable_unnormalizer:
+ context = [
+ normalizer.unapply(tf.expand_dims(c, 0))[0]
+ if normalizer is not None else c
+ for normalizer, c in zip(self._normalizers, context)
+ ]
+
+ def set_context_func(*env_context_values):
+ tf.logging.info('[set_env_context_op] Setting gym environment context.')
+ # pylint: disable=protected-access
+ self.gym_env.set_context(*env_context_values)
+ return ret_val
+ # pylint: enable=protected-access
+
+ with tf.name_scope('set_env_context'):
+ set_op = tf.py_func(set_context_func, context, tf.float32,
+ name='set_env_context_py_func')
+ set_op.set_shape([])
+ return set_op
+
+ def set_replay(self, replay):
+ """Set replay buffer for samplers.
+
+ Args:
+ replay: A replay buffer.
+ """
+ for _, samplers in self._samplers.items():
+ for sampler in samplers:
+ sampler.set_replay(replay)
+
+ def get_clip_fns(self):
+ """Returns a list of clip fns for contexts.
+
+ Returns:
+ A list of fns that clip context tensors.
+ """
+ clip_fns = []
+ for context_range in self.context_ranges:
+ def clip_fn(var_, range_=context_range):
+ """Clip a tensor."""
+ if range_ is None:
+ clipped_var = tf.identity(var_)
+ elif isinstance(range_[0], (int, long, float, list, np.ndarray)):
+ clipped_var = tf.clip_by_value(
+ var_,
+ range_[0],
+ range_[1],)
+ else: raise NotImplementedError(range_)
+ return clipped_var
+ clip_fns.append(clip_fn)
+ return clip_fns
+
+ def _validate_contexts(self, contexts):
+ """Validate if contexts have right specs.
+
+ Args:
+ contexts: A list of [batch_size, num_context_dim] tensors.
+ Raises:
+ ValueError: If shape or dtype mismatches that of spec.
+ """
+ for i, (context, spec) in enumerate(zip(contexts, self.context_specs)):
+ if context[0].shape != spec.shape:
+ raise ValueError('contexts[%d] has invalid shape %s wrt spec shape %s' %
+ (i, context[0].shape, spec.shape))
+ if context.dtype != spec.dtype:
+ raise ValueError('contexts[%d] has invalid dtype %s wrt spec dtype %s' %
+ (i, context.dtype, spec.dtype))
+
+ def context_multi_transition_fn(self, contexts, **kwargs):
+ """Returns multiple future contexts starting from a batch."""
+ assert self._context_multi_transition_fn
+ return self._context_multi_transition_fn(contexts, None, None, **kwargs)
+
+ def step(self, mode, agent=None, action_fn=None, **kwargs):
+ """Returns [next_contexts..., next_timer] list of ops.
+
+ Args:
+ mode: a string representing the mode=[train, explore, eval].
+      agent: an optional agent whose high-level context is stepped first; if
+        None, only this context is stepped.
+      action_fn: a meta action function used to sample a new goal every
+        `meta_action_every_n` steps (only used when `agent` is given).
+      **kwargs: kwargs for context_transition_fn.
+ Returns:
+ a list of ops that set the context.
+ """
+ if agent is None:
+ ops = []
+ if self._context_transition_fn is not None:
+ def sampler_fn():
+ samples = self.sample_contexts(mode, 1)[0]
+ return [s[0] for s in samples]
+ values = self._context_transition_fn(self.vars, self.t, sampler_fn, **kwargs)
+ ops += [tf.assign(var, value) for var, value in zip(self.vars, values)]
+ ops.append(tf.assign_add(self.t, 1)) # increment timer
+ return ops
+ else:
+ ops = agent.tf_context.step(mode, **kwargs)
+ state = kwargs['state']
+ next_state = kwargs['next_state']
+ state_repr = kwargs['state_repr']
+ next_state_repr = kwargs['next_state_repr']
+ with tf.control_dependencies(ops): # Step high level context before computing low level one.
+ # Get the context transition function output.
+ values = self._context_transition_fn(self.vars, self.t, None,
+ state=state_repr,
+ next_state=next_state_repr)
+ # Select a new goal every C steps, otherwise use context transition.
+ low_level_context = [
+ tf.cond(tf.equal(self.t % self.meta_action_every_n, 0),
+ lambda: tf.cast(action_fn(next_state, context=None), tf.float32),
+ lambda: values)]
+ ops = [tf.assign(var, value)
+ for var, value in zip(self.vars, low_level_context)]
+ with tf.control_dependencies(ops):
+ return [tf.assign_add(self.t, 1)] # increment timer
+ return ops
+
+ def reset(self, mode, agent=None, action_fn=None, state=None):
+ """Returns ops that reset the context.
+
+ Args:
+      mode: a string representing the mode=[train, explore, eval].
+      agent: an optional agent; if given, the agent's context is reset and this
+        (low-level) context is zeroed out instead of sampled.
+      action_fn: unused; kept for interface compatibility.
+      state: unused; kept for interface compatibility.
+ Returns:
+ a list of ops that reset the context.
+ """
+ if agent is None:
+ values = self.sample_contexts(mode=mode, batch_size=1)[0]
+ if values is None:
+ return []
+ values = [value[0] for value in values]
+ values[0] = uvf_utils.tf_print(
+ values[0],
+ values,
+ message='context:reset, mode=%s' % mode,
+ first_n=10,
+ name='context:reset:%s' % mode)
+ all_ops = []
+ for _, context_vars in sorted(self.context_vars.items()):
+ ops = [tf.assign(var, value) for var, value in zip(context_vars, values)]
+ all_ops += ops
+ all_ops.append(self.set_env_context_op(values))
+ all_ops.append(tf.assign(self.t, 0)) # reset timer
+ return all_ops
+ else:
+ ops = agent.tf_context.reset(mode)
+ # NOTE: The code is currently written in such a way that the higher level
+ # policy does not provide a low-level context until the second
+      # observation. Instead, we just zero out low-level contexts.
+ for key, context_vars in sorted(self.context_vars.items()):
+ ops += [tf.assign(var, tf.zeros_like(var)) for var, meta_var in
+ zip(context_vars, agent.tf_context.context_vars[key])]
+
+ ops.append(tf.assign(self.t, 0)) # reset timer
+ return ops
+
+ def create_vars(self, name, agent=None):
+ """Create tf variables for contexts.
+
+ Args:
+      name: Name of the variables.
+      agent: an optional meta agent; if given, its create_vars(name) output is
+        also returned as meta_vars.
+    Returns:
+      A (context_vars, meta_vars) tuple, where context_vars is a tuple of
+      [num_context_dims] variables.
+ """
+ if agent is not None:
+ meta_vars = agent.create_vars(name)
+ else:
+ meta_vars = {}
+ assert name not in self.context_vars, ('Conflict! %s is already '
+ 'initialized.') % name
+ self.context_vars[name] = tuple([
+ tf.Variable(
+ tf.zeros(shape=spec.shape, dtype=spec.dtype),
+ name='%s_context_%d' % (name, i))
+ for i, spec in enumerate(self.context_specs)
+ ])
+ return self.context_vars[name], meta_vars
+
+ @property
+ def n(self):
+ return len(self.context_specs)
+
+ @property
+ def vars(self):
+ return self.context_vars[self.VAR_NAME]
+
+ # pylint: disable=protected-access
+ @property
+ def gym_env(self):
+ return self._tf_env.pyenv._gym_env
+
+ @property
+ def tf_env(self):
+ return self._tf_env
+ # pylint: enable=protected-access
diff --git a/models/research/efficient-hrl/context/context_transition_functions.py b/models/research/efficient-hrl/context/context_transition_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..70326debde4185ec9ee5300ebf06b94e8eb1f7ad
--- /dev/null
+++ b/models/research/efficient-hrl/context/context_transition_functions.py
@@ -0,0 +1,123 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Context functions.
+
+Given the current contexts, timer and context sampler, returns new contexts
+ after an environment step. This can be used to define a high-level policy
+ that controls contexts as its actions.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+import gin.tf
+import utils as uvf_utils
+
+
+@gin.configurable
+def periodic_context_fn(contexts, timer, sampler_fn, period=1):
+ """Periodically samples contexts.
+
+ Args:
+ contexts: a list of [num_context_dims] tensor variables representing
+ current contexts.
+ timer: a scalar integer tensor variable holding the current time step.
+ sampler_fn: a sampler function that samples a list of [num_context_dims]
+ tensors.
+ period: (integer) period of update.
+ Returns:
+ a list of [num_context_dims] tensors.
+ """
+ contexts = list(contexts[:]) # create copy
+  return tf.cond(tf.equal(tf.mod(timer, period), 0), sampler_fn,
+                 lambda: contexts)
+
+
+@gin.configurable
+def timer_context_fn(contexts,
+ timer,
+ sampler_fn,
+ period=1,
+ timer_index=-1,
+ debug=False):
+ """Samples contexts based on timer in contexts.
+
+ Args:
+ contexts: a list of [num_context_dims] tensor variables representing
+ current contexts.
+ timer: a scalar integer tensor variable holding the current time step.
+ sampler_fn: a sampler function that samples a list of [num_context_dims]
+ tensors.
+ period: (integer) period of update; actual period = `period` + 1.
+    timer_index: (integer) Index of the context list entry that holds the
+      timer.
+ debug: (boolean) Print debug messages.
+ Returns:
+ a list of [num_context_dims] tensors.
+ """
+ contexts = list(contexts[:]) # create copy
+ cond = tf.equal(contexts[timer_index][0], 0)
+ def reset():
+ """Sample context and reset the timer."""
+ new_contexts = sampler_fn()
+ new_contexts[timer_index] = tf.zeros_like(
+ contexts[timer_index]) + period
+ return new_contexts
+ def update():
+ """Decrement the timer."""
+ contexts[timer_index] -= 1
+ return contexts
+ values = tf.cond(cond, reset, update)
+ if debug:
+ values[0] = uvf_utils.tf_print(
+ values[0],
+ values + [timer],
+ 'timer_context_fn',
+ first_n=200,
+ name='timer_context_fn:contexts')
+ return values
+
+
+@gin.configurable
+def relative_context_transition_fn(
+ contexts, timer, sampler_fn,
+ k=2, state=None, next_state=None,
+ **kwargs):
+ """Contexts updated to be relative to next state.
+ """
+ contexts = list(contexts[:]) # create copy
+ assert len(contexts) == 1
+ new_contexts = [
+ tf.concat(
+ [contexts[0][:k] + state[:k] - next_state[:k],
+ contexts[0][k:]], -1)]
+ return new_contexts
+
+
+@gin.configurable
+def relative_context_multi_transition_fn(
+ contexts, timer, sampler_fn,
+ k=2, states=None,
+ **kwargs):
+ """Given contexts at first state and sequence of states, derives sequence of all contexts.
+ """
+ contexts = list(contexts[:]) # create copy
+ assert len(contexts) == 1
+ contexts = [
+ tf.concat(
+ [tf.expand_dims(contexts[0][:, :k] + states[:, 0, :k], 1) - states[:, :, :k],
+ contexts[0][:, None, k:] * tf.ones_like(states[:, :, :1])], -1)]
+ return contexts
diff --git a/models/research/efficient-hrl/context/gin_imports.py b/models/research/efficient-hrl/context/gin_imports.py
new file mode 100644
index 0000000000000000000000000000000000000000..94512cef8479ac2e9c36a941f4b197b6939d0814
--- /dev/null
+++ b/models/research/efficient-hrl/context/gin_imports.py
@@ -0,0 +1,25 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Import gin configurable modules.
+"""
+
+# pylint: disable=unused-import
+from context import context
+from context import context_transition_functions
+from context import gin_utils
+from context import rewards_functions
+from context import samplers
+# pylint: disable=unused-import
diff --git a/models/research/efficient-hrl/context/gin_utils.py b/models/research/efficient-hrl/context/gin_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab7c1b2d1dd1d7317071ad6e2c08d057d42ec2e1
--- /dev/null
+++ b/models/research/efficient-hrl/context/gin_utils.py
@@ -0,0 +1,45 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Gin configurable utility functions.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import gin.tf
+
+
+@gin.configurable
+def gin_sparse_array(size, values, indices, fill_value=0):
+ arr = np.zeros(size)
+ arr.fill(fill_value)
+ arr[indices] = values
+ return arr
+
+
+@gin.configurable
+def gin_sum(values):
+ result = values[0]
+ for value in values[1:]:
+ result += value
+ return result
+
+
+@gin.configurable
+def gin_range(n):
+ return range(n)
diff --git a/models/research/efficient-hrl/context/rewards_functions.py b/models/research/efficient-hrl/context/rewards_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab560a7f4290dfa1269e001339e0e2cdb116e761
--- /dev/null
+++ b/models/research/efficient-hrl/context/rewards_functions.py
@@ -0,0 +1,741 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Reward shaping functions used by Contexts.
+
+ Each reward function should take the following inputs and return new rewards,
+ and discounts.
+
+ new_rewards, discounts = reward_fn(states, actions, rewards,
+ next_states, contexts)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+import gin.tf
+
+
+def summarize_stats(stats):
+ """Summarize a dictionary of variables.
+
+ Args:
+ stats: a dictionary of {name: tensor} to compute stats over.
+ """
+ for name, stat in stats.items():
+ mean = tf.reduce_mean(stat)
+ tf.summary.scalar('mean_%s' % name, mean)
+ tf.summary.scalar('max_%s' % name, tf.reduce_max(stat))
+ tf.summary.scalar('min_%s' % name, tf.reduce_min(stat))
+ std = tf.sqrt(tf.reduce_mean(tf.square(stat)) - tf.square(mean) + 1e-10)
+ tf.summary.scalar('std_%s' % name, std)
+ tf.summary.histogram(name, stat)
+
+
+def index_states(states, indices):
+ """Return indexed states.
+
+ Args:
+ states: A [batch_size, num_state_dims] Tensor representing a batch
+ of states.
+ indices: (a list of Numpy integer array) Indices of states dimensions
+ to be mapped.
+ Returns:
+ A [batch_size, num_indices] Tensor representing the batch of indexed states.
+ """
+ if indices is None:
+ return states
+ indices = tf.constant(indices, dtype=tf.int32)
+ return tf.gather(states, indices=indices, axis=1)
+
+
+def record_tensor(tensor, indices, stats, name='states'):
+ """Record specified tensor dimensions into stats.
+
+ Args:
+ tensor: A [batch_size, num_dims] Tensor.
+ indices: (a list of integers) Indices of dimensions to record.
+ stats: A dictionary holding stats.
+ name: (string) Name of tensor.
+ """
+ if indices is None:
+ indices = range(tensor.shape.as_list()[1])
+ for index in indices:
+ stats['%s_%02d' % (name, index)] = tensor[:, index]
+
+
+@gin.configurable
+def potential_rewards(states,
+ actions,
+ rewards,
+ next_states,
+ contexts,
+ gamma=1.0,
+ reward_fn=None):
+ """Return the potential-based rewards.
+
+ Args:
+ states: A [batch_size, num_state_dims] Tensor representing a batch
+ of states.
+ actions: A [batch_size, num_action_dims] Tensor representing a batch
+ of actions.
+ rewards: A [batch_size] Tensor representing a batch of rewards.
+ next_states: A [batch_size, num_state_dims] Tensor representing a batch
+ of next states.
+ contexts: A list of [batch_size, num_context_dims] Tensor representing
+ a batch of contexts.
+ gamma: Reward discount.
+ reward_fn: A reward function.
+ Returns:
+ A new tf.float32 [batch_size] rewards Tensor, and
+ tf.float32 [batch_size] discounts tensor.
+ """
+ del actions # unused args
+ gamma = tf.to_float(gamma)
+ rewards_tp1, discounts = reward_fn(None, None, rewards, next_states, contexts)
+ rewards, _ = reward_fn(None, None, rewards, states, contexts)
+ return -rewards + gamma * rewards_tp1, discounts
+
+
+@gin.configurable
+def timed_rewards(states,
+ actions,
+ rewards,
+ next_states,
+ contexts,
+ reward_fn=None,
+ dense=False,
+ timer_index=-1):
+ """Return the timed rewards.
+
+ Args:
+ states: A [batch_size, num_state_dims] Tensor representing a batch
+ of states.
+ actions: A [batch_size, num_action_dims] Tensor representing a batch
+ of actions.
+ rewards: A [batch_size] Tensor representing a batch of rewards.
+ next_states: A [batch_size, num_state_dims] Tensor representing a batch
+ of next states.
+ contexts: A list of [batch_size, num_context_dims] Tensor representing
+ a batch of contexts.
+ reward_fn: A reward function.
+ dense: (boolean) Provide dense rewards or sparse rewards at time = 0.
+ timer_index: (integer) The context list index that specifies timer.
+ Returns:
+ A new tf.float32 [batch_size] rewards Tensor, and
+ tf.float32 [batch_size] discounts tensor.
+ """
+ assert contexts[timer_index].get_shape().as_list()[1] == 1
+ timers = contexts[timer_index][:, 0]
+ rewards, discounts = reward_fn(states, actions, rewards, next_states,
+ contexts)
+ terminates = tf.to_float(timers <= 0) # if terminate set 1, else set 0
+ for _ in range(rewards.shape.ndims - 1):
+ terminates = tf.expand_dims(terminates, axis=-1)
+ if not dense:
+ rewards *= terminates # if terminate, return rewards, else return 0
+ discounts *= (tf.to_float(1.0) - terminates)
+ return rewards, discounts
+
+
+@gin.configurable
+def reset_rewards(states,
+ actions,
+ rewards,
+ next_states,
+ contexts,
+ reset_index=0,
+ reset_state=None,
+ reset_reward_function=None,
+ include_forward_rewards=True,
+ include_reset_rewards=True):
+ """Returns the rewards for a forward/reset agent.
+
+ Args:
+ states: A [batch_size, num_state_dims] Tensor representing a batch
+ of states.
+ actions: A [batch_size, num_action_dims] Tensor representing a batch
+ of actions.
+ rewards: A [batch_size] Tensor representing a batch of rewards.
+ next_states: A [batch_size, num_state_dims] Tensor representing a batch
+ of next states.
+ contexts: A list of [batch_size, num_context_dims] Tensor representing
+ a batch of contexts.
+ reset_index: (integer) The context list index that specifies reset.
+ reset_state: Reset state.
+ reset_reward_function: Reward function for reset step.
+ include_forward_rewards: Include the rewards from the forward pass.
+ include_reset_rewards: Include the rewards from the reset pass.
+
+ Returns:
+ A new tf.float32 [batch_size] rewards Tensor, and
+ tf.float32 [batch_size] discounts tensor.
+ """
+ reset_state = tf.constant(
+ reset_state, dtype=next_states.dtype, shape=next_states.shape)
+ reset_states = tf.expand_dims(reset_state, 0)
+
+ def true_fn():
+ if include_reset_rewards:
+ return reset_reward_function(states, actions, rewards, next_states,
+ [reset_states] + contexts[1:])
+ else:
+ return tf.zeros_like(rewards), tf.ones_like(rewards)
+
+ def false_fn():
+ if include_forward_rewards:
+ return plain_rewards(states, actions, rewards, next_states, contexts)
+ else:
+ return tf.zeros_like(rewards), tf.ones_like(rewards)
+
+ rewards, discounts = tf.cond(
+ tf.cast(contexts[reset_index][0, 0], dtype=tf.bool), true_fn, false_fn)
+ return rewards, discounts
+
+
+@gin.configurable
+def tanh_similarity(states,
+ actions,
+ rewards,
+ next_states,
+ contexts,
+ mse_scale=1.0,
+ state_scales=1.0,
+ goal_scales=1.0,
+ summarize=False):
+ """Returns the similarity between next_states and contexts using tanh and mse.
+
+ Args:
+ states: A [batch_size, num_state_dims] Tensor representing a batch
+ of states.
+ actions: A [batch_size, num_action_dims] Tensor representing a batch
+ of actions.
+ rewards: A [batch_size] Tensor representing a batch of rewards.
+ next_states: A [batch_size, num_state_dims] Tensor representing a batch
+ of next states.
+ contexts: A list of [batch_size, num_context_dims] Tensor representing
+ a batch of contexts.
+ mse_scale: A float, to scale mse before tanh.
+ state_scales: multiplicative scale for (next) states. A scalar or 1D tensor,
+ must be broadcastable to number of state dimensions.
+ goal_scales: multiplicative scale for contexts. A scalar or 1D tensor,
+ must be broadcastable to number of goal dimensions.
+ summarize: (boolean) enable summary ops.
+
+
+ Returns:
+ A new tf.float32 [batch_size] rewards Tensor, and
+ tf.float32 [batch_size] discounts tensor.
+ """
+ del states, actions, rewards # Unused
+ mse = tf.reduce_mean(tf.squared_difference(next_states * state_scales,
+ contexts[0] * goal_scales), -1)
+ tanh = tf.tanh(mse_scale * mse)
+ if summarize:
+ with tf.name_scope('RewardFn/'):
+ tf.summary.scalar('mean_mse', tf.reduce_mean(mse))
+ tf.summary.histogram('mse', mse)
+ tf.summary.scalar('mean_tanh', tf.reduce_mean(tanh))
+ tf.summary.histogram('tanh', tanh)
+ rewards = tf.to_float(1 - tanh)
+ return rewards, tf.ones_like(rewards)
+
+
+@gin.configurable
+def negative_mse(states,
+ actions,
+ rewards,
+ next_states,
+ contexts,
+ state_scales=1.0,
+ goal_scales=1.0,
+ summarize=False):
+ """Returns the negative mean square error between next_states and contexts.
+
+ Args:
+ states: A [batch_size, num_state_dims] Tensor representing a batch
+ of states.
+ actions: A [batch_size, num_action_dims] Tensor representing a batch
+ of actions.
+ rewards: A [batch_size] Tensor representing a batch of rewards.
+ next_states: A [batch_size, num_state_dims] Tensor representing a batch
+ of next states.
+ contexts: A list of [batch_size, num_context_dims] Tensor representing
+ a batch of contexts.
+ state_scales: multiplicative scale for (next) states. A scalar or 1D tensor,
+ must be broadcastable to number of state dimensions.
+ goal_scales: multiplicative scale for contexts. A scalar or 1D tensor,
+ must be broadcastable to number of goal dimensions.
+ summarize: (boolean) enable summary ops.
+
+ Returns:
+ A new tf.float32 [batch_size] rewards Tensor, and
+ tf.float32 [batch_size] discounts tensor.
+ """
+ del states, actions, rewards # Unused
+ mse = tf.reduce_mean(tf.squared_difference(next_states * state_scales,
+ contexts[0] * goal_scales), -1)
+ if summarize:
+ with tf.name_scope('RewardFn/'):
+ tf.summary.scalar('mean_mse', tf.reduce_mean(mse))
+ tf.summary.histogram('mse', mse)
+ rewards = tf.to_float(-mse)
+ return rewards, tf.ones_like(rewards)
+
+
+@gin.configurable
+def negative_distance(states,
+ actions,
+ rewards,
+ next_states,
+ contexts,
+ state_scales=1.0,
+ goal_scales=1.0,
+ reward_scales=1.0,
+ weight_index=None,
+ weight_vector=None,
+ summarize=False,
+ termination_epsilon=1e-4,
+ state_indices=None,
+ goal_indices=None,
+ vectorize=False,
+ relative_context=False,
+ diff=False,
+ norm='L2',
+ epsilon=1e-10,
+ bonus_epsilon=0., #5.,
+ offset=0.0):
+ """Returns the negative euclidean distance between next_states and contexts.
+
+ Args:
+ states: A [batch_size, num_state_dims] Tensor representing a batch
+ of states.
+ actions: A [batch_size, num_action_dims] Tensor representing a batch
+ of actions.
+ rewards: A [batch_size] Tensor representing a batch of rewards.
+ next_states: A [batch_size, num_state_dims] Tensor representing a batch
+ of next states.
+ contexts: A list of [batch_size, num_context_dims] Tensor representing
+ a batch of contexts.
+ state_scales: multiplicative scale for (next) states. A scalar or 1D tensor,
+ must be broadcastable to number of state dimensions.
+ goal_scales: multiplicative scale for goals. A scalar or 1D tensor,
+ must be broadcastable to number of goal dimensions.
+ reward_scales: multiplicative scale for rewards. A scalar or 1D tensor,
+ must be broadcastable to number of reward dimensions.
+ weight_index: (integer) The context list index that specifies weight.
+ weight_vector: (a number or a list or Numpy array) The weighting vector,
+ broadcastable to `next_states`.
+ summarize: (boolean) enable summary ops.
+ termination_epsilon: terminate if dist is less than this quantity.
+ state_indices: (a list of integers) list of state indices to select.
+ goal_indices: (a list of integers) list of goal indices to select.
+    vectorize: Return a vectorized (per-dimension) form instead of a scalar
+      distance per example.
+    relative_context: (boolean) Treat the goal context as relative to `states`
+      (i.e. goal = states + contexts[0]).
+    diff: (boolean) Return the decrease in distance (old_dist - dist) rather
+      than the negative distance itself.
+    norm: L1 or L2.
+    epsilon: small offset to ensure non-negative/zero distance.
+    bonus_epsilon: if positive, adds a bonus of 1 whenever dist < bonus_epsilon.
+    offset: constant added to the returned rewards.
+
+ Returns:
+ A new tf.float32 [batch_size] rewards Tensor, and
+ tf.float32 [batch_size] discounts tensor.
+ """
+ del actions, rewards # Unused
+ stats = {}
+ record_tensor(next_states, state_indices, stats, 'next_states')
+ states = index_states(states, state_indices)
+ next_states = index_states(next_states, state_indices)
+ goals = index_states(contexts[0], goal_indices)
+ if relative_context:
+ goals = states + goals
+ sq_dists = tf.squared_difference(next_states * state_scales,
+ goals * goal_scales)
+ old_sq_dists = tf.squared_difference(states * state_scales,
+ goals * goal_scales)
+ record_tensor(sq_dists, None, stats, 'sq_dists')
+ if weight_vector is not None:
+ sq_dists *= tf.convert_to_tensor(weight_vector, dtype=next_states.dtype)
+ old_sq_dists *= tf.convert_to_tensor(weight_vector, dtype=next_states.dtype)
+ if weight_index is not None:
+ #sq_dists *= contexts[weight_index]
+ weights = tf.abs(index_states(contexts[0], weight_index))
+ #weights /= tf.reduce_sum(weights, -1, keepdims=True)
+ sq_dists *= weights
+ old_sq_dists *= weights
+ if norm == 'L1':
+ dist = tf.sqrt(sq_dists + epsilon)
+ old_dist = tf.sqrt(old_sq_dists + epsilon)
+ if not vectorize:
+ dist = tf.reduce_sum(dist, -1)
+ old_dist = tf.reduce_sum(old_dist, -1)
+ elif norm == 'L2':
+ if vectorize:
+ dist = sq_dists
+ old_dist = old_sq_dists
+ else:
+ dist = tf.reduce_sum(sq_dists, -1)
+ old_dist = tf.reduce_sum(old_sq_dists, -1)
+ dist = tf.sqrt(dist + epsilon) # tf.gradients fails when tf.sqrt(-0.0)
+ old_dist = tf.sqrt(old_dist + epsilon) # tf.gradients fails when tf.sqrt(-0.0)
+ else:
+ raise NotImplementedError(norm)
+ discounts = dist > termination_epsilon
+ if summarize:
+ with tf.name_scope('RewardFn/'):
+ tf.summary.scalar('mean_dist', tf.reduce_mean(dist))
+ tf.summary.histogram('dist', dist)
+ summarize_stats(stats)
+ bonus = tf.to_float(dist < bonus_epsilon)
+ dist *= reward_scales
+ old_dist *= reward_scales
+ if diff:
+ return bonus + offset + tf.to_float(old_dist - dist), tf.to_float(discounts)
+ return bonus + offset + tf.to_float(-dist), tf.to_float(discounts)
+
+
+@gin.configurable
+def cosine_similarity(states,
+ actions,
+ rewards,
+ next_states,
+ contexts,
+ state_scales=1.0,
+ goal_scales=1.0,
+ reward_scales=1.0,
+ normalize_states=True,
+ normalize_goals=True,
+ weight_index=None,
+ weight_vector=None,
+ summarize=False,
+ state_indices=None,
+ goal_indices=None,
+ offset=0.0):
+ """Returns the cosine similarity between next_states - states and contexts.
+
+ Args:
+ states: A [batch_size, num_state_dims] Tensor representing a batch
+ of states.
+ actions: A [batch_size, num_action_dims] Tensor representing a batch
+ of actions.
+ rewards: A [batch_size] Tensor representing a batch of rewards.
+ next_states: A [batch_size, num_state_dims] Tensor representing a batch
+ of next states.
+ contexts: A list of [batch_size, num_context_dims] Tensor representing
+ a batch of contexts.
+ state_scales: multiplicative scale for (next) states. A scalar or 1D tensor,
+ must be broadcastable to number of state dimensions.
+ goal_scales: multiplicative scale for goals. A scalar or 1D tensor,
+ must be broadcastable to number of goal dimensions.
+ reward_scales: multiplicative scale for rewards. A scalar or 1D tensor,
+      must be broadcastable to number of reward dimensions.
+    normalize_states: (boolean) l2-normalize the state direction vector.
+    normalize_goals: (boolean) l2-normalize the goal vector.
+    weight_index: (integer) The context list index that specifies weight.
+    weight_vector: (a number or a list or Numpy array) The weighting vector,
+      broadcastable to `next_states`.
+    summarize: (boolean) enable summary ops.
+    state_indices: (a list of integers) list of state indices to select.
+    goal_indices: (a list of integers) list of goal indices to select.
+    offset: constant added to the returned similarity rewards.
+
+ Returns:
+ A new tf.float32 [batch_size] rewards Tensor, and
+ tf.float32 [batch_size] discounts tensor.
+ """
+ del actions, rewards # Unused
+ stats = {}
+ record_tensor(next_states, state_indices, stats, 'next_states')
+ states = index_states(states, state_indices)
+ next_states = index_states(next_states, state_indices)
+ goals = index_states(contexts[0], goal_indices)
+
+ if weight_vector is not None:
+ goals *= tf.convert_to_tensor(weight_vector, dtype=next_states.dtype)
+ if weight_index is not None:
+ weights = tf.abs(index_states(contexts[0], weight_index))
+ goals *= weights
+
+ direction_vec = next_states - states
+ if normalize_states:
+ direction_vec = tf.nn.l2_normalize(direction_vec, -1)
+ goal_vec = goals
+ if normalize_goals:
+ goal_vec = tf.nn.l2_normalize(goal_vec, -1)
+
+ similarity = tf.reduce_sum(goal_vec * direction_vec, -1)
+ discounts = tf.ones_like(similarity)
+ return offset + tf.to_float(similarity), tf.to_float(discounts)
+
+
+@gin.configurable
+def diff_distance(states,
+ actions,
+ rewards,
+ next_states,
+ contexts,
+ state_scales=1.0,
+ goal_scales=1.0,
+ reward_scales=1.0,
+ weight_index=None,
+ weight_vector=None,
+ summarize=False,
+ termination_epsilon=1e-4,
+ state_indices=None,
+ goal_indices=None,
+ norm='L2',
+ epsilon=1e-10):
+ """Returns the difference in euclidean distance between states/next_states and contexts.
+
+ Args:
+ states: A [batch_size, num_state_dims] Tensor representing a batch
+ of states.
+ actions: A [batch_size, num_action_dims] Tensor representing a batch
+ of actions.
+ rewards: A [batch_size] Tensor representing a batch of rewards.
+ next_states: A [batch_size, num_state_dims] Tensor representing a batch
+ of next states.
+ contexts: A list of [batch_size, num_context_dims] Tensor representing
+ a batch of contexts.
+ state_scales: multiplicative scale for (next) states. A scalar or 1D tensor,
+ must be broadcastable to number of state dimensions.
+ goal_scales: multiplicative scale for goals. A scalar or 1D tensor,
+ must be broadcastable to number of goal dimensions.
+ reward_scales: multiplicative scale for rewards. A scalar or 1D tensor,
+ must be broadcastable to number of reward dimensions.
+ weight_index: (integer) The context list index that specifies weight.
+ weight_vector: (a number or a list or Numpy array) The weighting vector,
+ broadcastable to `next_states`.
+ summarize: (boolean) enable summary ops.
+ termination_epsilon: terminate if dist is less than this quantity.
+ state_indices: (a list of integers) list of state indices to select.
+ goal_indices: (a list of integers) list of goal indices to select.
+ norm: L1 or L2.
+ epsilon: small offset to ensure non-negative/zero distance.
+
+ Returns:
+ A new tf.float32 [batch_size] rewards Tensor, and
+ tf.float32 [batch_size] discounts tensor.
+ """
+ del actions, rewards # Unused
+ stats = {}
+ record_tensor(next_states, state_indices, stats, 'next_states')
+ next_states = index_states(next_states, state_indices)
+ states = index_states(states, state_indices)
+ goals = index_states(contexts[0], goal_indices)
+ next_sq_dists = tf.squared_difference(next_states * state_scales,
+ goals * goal_scales)
+ sq_dists = tf.squared_difference(states * state_scales,
+ goals * goal_scales)
+ record_tensor(sq_dists, None, stats, 'sq_dists')
+ if weight_vector is not None:
+ next_sq_dists *= tf.convert_to_tensor(weight_vector, dtype=next_states.dtype)
+ sq_dists *= tf.convert_to_tensor(weight_vector, dtype=next_states.dtype)
+ if weight_index is not None:
+ next_sq_dists *= contexts[weight_index]
+ sq_dists *= contexts[weight_index]
+ if norm == 'L1':
+ next_dist = tf.sqrt(next_sq_dists + epsilon)
+ dist = tf.sqrt(sq_dists + epsilon)
+ next_dist = tf.reduce_sum(next_dist, -1)
+ dist = tf.reduce_sum(dist, -1)
+ elif norm == 'L2':
+ next_dist = tf.reduce_sum(next_sq_dists, -1)
+ next_dist = tf.sqrt(next_dist + epsilon) # tf.gradients fails when tf.sqrt(-0.0)
+ dist = tf.reduce_sum(sq_dists, -1)
+ dist = tf.sqrt(dist + epsilon) # tf.gradients fails when tf.sqrt(-0.0)
+ else:
+ raise NotImplementedError(norm)
+ discounts = next_dist > termination_epsilon
+ if summarize:
+ with tf.name_scope('RewardFn/'):
+ tf.summary.scalar('mean_dist', tf.reduce_mean(dist))
+ tf.summary.histogram('dist', dist)
+ summarize_stats(stats)
+ diff = dist - next_dist
+ diff *= reward_scales
+ return tf.to_float(diff), tf.to_float(discounts)
+
+
+@gin.configurable
+def binary_indicator(states,
+ actions,
+ rewards,
+ next_states,
+ contexts,
+ termination_epsilon=1e-4,
+ offset=0,
+ epsilon=1e-10,
+ state_indices=None,
+ summarize=False):
+ """Returns 0/1 by checking if next_states and contexts overlap.
+
+ Args:
+ states: A [batch_size, num_state_dims] Tensor representing a batch
+ of states.
+ actions: A [batch_size, num_action_dims] Tensor representing a batch
+ of actions.
+ rewards: A [batch_size] Tensor representing a batch of rewards.
+ next_states: A [batch_size, num_state_dims] Tensor representing a batch
+ of next states.
+ contexts: A list of [batch_size, num_context_dims] Tensor representing
+ a batch of contexts.
+ termination_epsilon: terminate if dist is less than this quantity.
+ offset: Offset the rewards.
+    epsilon: small offset to ensure non-negative/zero distance.
+    state_indices: (a list of integers) list of state indices to select.
+    summarize: (boolean) enable summary ops (unused here).
+
+ Returns:
+ A new tf.float32 [batch_size] rewards Tensor, and
+ tf.float32 [batch_size] discounts tensor.
+ """
+ del states, actions # unused args
+ next_states = index_states(next_states, state_indices)
+ dist = tf.reduce_sum(tf.squared_difference(next_states, contexts[0]), -1)
+ dist = tf.sqrt(dist + epsilon)
+ discounts = dist > termination_epsilon
+ rewards = tf.logical_not(discounts)
+ rewards = tf.to_float(rewards) + offset
+ return tf.to_float(rewards), tf.ones_like(tf.to_float(discounts)) #tf.to_float(discounts)
+
+
+@gin.configurable
+def plain_rewards(states, actions, rewards, next_states, contexts):
+ """Returns the given rewards.
+
+ Args:
+ states: A [batch_size, num_state_dims] Tensor representing a batch
+ of states.
+ actions: A [batch_size, num_action_dims] Tensor representing a batch
+ of actions.
+ rewards: A [batch_size] Tensor representing a batch of rewards.
+ next_states: A [batch_size, num_state_dims] Tensor representing a batch
+ of next states.
+ contexts: A list of [batch_size, num_context_dims] Tensor representing
+ a batch of contexts.
+
+ Returns:
+ A new tf.float32 [batch_size] rewards Tensor, and
+ tf.float32 [batch_size] discounts tensor.
+ """
+ del states, actions, next_states, contexts # Unused
+ return rewards, tf.ones_like(rewards)
+
+
+@gin.configurable
+def ctrl_rewards(states,
+ actions,
+ rewards,
+ next_states,
+ contexts,
+ reward_scales=1.0):
+ """Returns the negative control cost.
+
+ Args:
+ states: A [batch_size, num_state_dims] Tensor representing a batch
+ of states.
+ actions: A [batch_size, num_action_dims] Tensor representing a batch
+ of actions.
+ rewards: A [batch_size] Tensor representing a batch of rewards.
+ next_states: A [batch_size, num_state_dims] Tensor representing a batch
+ of next states.
+ contexts: A list of [batch_size, num_context_dims] Tensor representing
+ a batch of contexts.
+ reward_scales: multiplicative scale for rewards. A scalar or 1D tensor,
+ must be broadcastable to number of reward dimensions.
+
+ Returns:
+ A new tf.float32 [batch_size] rewards Tensor, and
+ tf.float32 [batch_size] discounts tensor.
+ """
+ del states, rewards, contexts # Unused
+ if actions is None:
+ rewards = tf.to_float(tf.zeros(shape=next_states.shape[:1]))
+ else:
+ rewards = -tf.reduce_sum(tf.square(actions), axis=1)
+ rewards *= reward_scales
+ rewards = tf.to_float(rewards)
+ return rewards, tf.ones_like(rewards)
+
+
+@gin.configurable
+def diff_rewards(
+ states,
+ actions,
+ rewards,
+ next_states,
+ contexts,
+ state_indices=None,
+ goal_index=0,):
+ """Returns (next_states - goals) as a batched vector reward."""
+ del states, rewards, actions # Unused
+ if state_indices is not None:
+ next_states = index_states(next_states, state_indices)
+ rewards = tf.to_float(next_states - contexts[goal_index])
+ return rewards, tf.ones_like(rewards)
+
+
+@gin.configurable
+def state_rewards(states,
+ actions,
+ rewards,
+ next_states,
+ contexts,
+ weight_index=None,
+ state_indices=None,
+ weight_vector=1.0,
+ offset_vector=0.0,
+ summarize=False):
+ """Returns the rewards that are linear mapping of next_states.
+
+ Args:
+ states: A [batch_size, num_state_dims] Tensor representing a batch
+ of states.
+ actions: A [batch_size, num_action_dims] Tensor representing a batch
+ of actions.
+ rewards: A [batch_size] Tensor representing a batch of rewards.
+ next_states: A [batch_size, num_state_dims] Tensor representing a batch
+ of next states.
+ contexts: A list of [batch_size, num_context_dims] Tensor representing
+ a batch of contexts.
+ weight_index: (integer) Index of contexts lists that specify weighting.
+ state_indices: (a list of Numpy integer array) Indices of states dimensions
+ to be mapped.
+ weight_vector: (a number or a list or Numpy array) The weighting vector,
+ broadcastable to `next_states`.
+    offset_vector: (a number or a list or Numpy array) The offset vector,
+      broadcastable to `next_states`.
+ summarize: (boolean) enable summary ops.
+
+ Returns:
+ A new tf.float32 [batch_size] rewards Tensor, and
+ tf.float32 [batch_size] discounts tensor.
+ """
+ del states, actions, rewards # unused args
+ stats = {}
+ record_tensor(next_states, state_indices, stats)
+ next_states = index_states(next_states, state_indices)
+ weight = tf.constant(
+ weight_vector, dtype=next_states.dtype, shape=next_states[0].shape)
+ weights = tf.expand_dims(weight, 0)
+ offset = tf.constant(
+ offset_vector, dtype=next_states.dtype, shape=next_states[0].shape)
+ offsets = tf.expand_dims(offset, 0)
+ if weight_index is not None:
+ weights *= contexts[weight_index]
+ rewards = tf.to_float(tf.reduce_sum(weights * (next_states+offsets), axis=1))
+ if summarize:
+ with tf.name_scope('RewardFn/'):
+ summarize_stats(stats)
+ return rewards, tf.ones_like(rewards)
diff --git a/models/research/efficient-hrl/context/samplers.py b/models/research/efficient-hrl/context/samplers.py
new file mode 100644
index 0000000000000000000000000000000000000000..15a22df5eb3bcbd419b5a01bc299b5d5ac71ad91
--- /dev/null
+++ b/models/research/efficient-hrl/context/samplers.py
@@ -0,0 +1,445 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Samplers for Contexts.
+
+ Each sampler class should define __call__(batch_size).
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+slim = tf.contrib.slim
+import gin.tf
+
+
+@gin.configurable
+class BaseSampler(object):
+ """Base sampler."""
+
+ def __init__(self, context_spec, context_range=None, k=2, scope='sampler'):
+ """Construct a base sampler.
+
+ Args:
+ context_spec: A context spec.
+      context_range: A tuple of (minval, maxval), where minval, maxval are
+        floats or Numpy arrays with the same shape as the context.
+      k: (integer) Number of leading context/state dimensions treated as the
+        subgoal (used by samplers such as DirectionSampler).
+      scope: A string denoting scope.
+ """
+ self._context_spec = context_spec
+ self._context_range = context_range
+ self._k = k
+ self._scope = scope
+
+ def __call__(self, batch_size, **kwargs):
+ raise NotImplementedError
+
+ def set_replay(self, replay=None):
+ pass
+
+ def _validate_contexts(self, contexts):
+ """Validate if contexts have right spec.
+
+ Args:
+ contexts: A [batch_size, num_contexts_dim] tensor.
+ Raises:
+ ValueError: If shape or dtype mismatches that of spec.
+ """
+ if contexts[0].shape != self._context_spec.shape:
+ raise ValueError('contexts has invalid shape %s wrt spec shape %s' %
+ (contexts[0].shape, self._context_spec.shape))
+ if contexts.dtype != self._context_spec.dtype:
+ raise ValueError('contexts has invalid dtype %s wrt spec dtype %s' %
+ (contexts.dtype, self._context_spec.dtype))
+
+
+@gin.configurable
+class ZeroSampler(BaseSampler):
+ """Zero sampler."""
+
+ def __call__(self, batch_size, **kwargs):
+ """Sample a batch of context.
+
+ Args:
+ batch_size: Batch size.
+ Returns:
+ Two [batch_size, num_context_dims] tensors.
+ """
+ contexts = tf.zeros(
+ dtype=self._context_spec.dtype,
+ shape=[
+ batch_size,
+ ] + self._context_spec.shape.as_list())
+ return contexts, contexts
+
+
+@gin.configurable
+class BinarySampler(BaseSampler):
+ """Binary sampler."""
+
+ def __init__(self, probs=0.5, *args, **kwargs):
+ """Constructor."""
+ super(BinarySampler, self).__init__(*args, **kwargs)
+ self._probs = probs
+
+ def __call__(self, batch_size, **kwargs):
+ """Sample a batch of context."""
+ spec = self._context_spec
+ contexts = tf.random_uniform(
+ shape=[
+ batch_size,
+ ] + spec.shape.as_list(), dtype=tf.float32)
+ contexts = tf.cast(tf.greater(contexts, self._probs), dtype=spec.dtype)
+ return contexts, contexts
+
+
+@gin.configurable
+class RandomSampler(BaseSampler):
+ """Random sampler."""
+
+ def __call__(self, batch_size, **kwargs):
+ """Sample a batch of context.
+
+ Args:
+ batch_size: Batch size.
+ Returns:
+ Two [batch_size, num_context_dims] tensors.
+ """
+ spec = self._context_spec
+ context_range = self._context_range
+ if isinstance(context_range[0], (int, float)):
+ contexts = tf.random_uniform(
+ shape=[
+ batch_size,
+ ] + spec.shape.as_list(),
+ minval=context_range[0],
+ maxval=context_range[1],
+ dtype=spec.dtype)
+ elif isinstance(context_range[0], (list, tuple, np.ndarray)):
+ assert len(spec.shape.as_list()) == 1
+ assert spec.shape.as_list()[0] == len(context_range[0])
+ assert spec.shape.as_list()[0] == len(context_range[1])
+ contexts = tf.concat(
+ [
+ tf.random_uniform(
+ shape=[
+ batch_size, 1,
+ ] + spec.shape.as_list()[1:],
+ minval=context_range[0][i],
+ maxval=context_range[1][i],
+ dtype=spec.dtype) for i in range(spec.shape.as_list()[0])
+ ],
+ axis=1)
+ else: raise NotImplementedError(context_range)
+ self._validate_contexts(contexts)
+ state, next_state = kwargs['state'], kwargs['next_state']
+ if state is not None and next_state is not None:
+ pass
+ #contexts = tf.concat(
+ # [tf.random_normal(tf.shape(state[:, :self._k]), dtype=tf.float64) +
+ # tf.random_shuffle(state[:, :self._k]),
+ # contexts[:, self._k:]], 1)
+
+ return contexts, contexts
+
+
+@gin.configurable
+class ScheduledSampler(BaseSampler):
+ """Scheduled sampler."""
+
+ def __init__(self,
+ scope='default',
+ values=None,
+ scheduler='cycle',
+ scheduler_params=None,
+ *args, **kwargs):
+ """Construct sampler.
+
+ Args:
+ scope: Scope name.
+ values: A list of numbers or [num_context_dim] Numpy arrays
+ representing the values to cycle.
+ scheduler: scheduler type.
+ scheduler_params: scheduler parameters.
+ *args: arguments.
+ **kwargs: keyword arguments.
+ """
+ super(ScheduledSampler, self).__init__(*args, **kwargs)
+ self._scope = scope
+ self._values = values
+ self._scheduler = scheduler
+ self._scheduler_params = scheduler_params or {}
+ assert self._values is not None and len(
+ self._values), 'must provide non-empty values.'
+ self._n = len(self._values)
+ # TODO(shanegu): move variable creation outside. resolve tf.cond problem.
+ self._count = 0
+ self._i = tf.Variable(
+ tf.zeros(shape=(), dtype=tf.int32),
+ name='%s-scheduled_sampler_%d' % (self._scope, self._count))
+ self._values = tf.constant(self._values, dtype=self._context_spec.dtype)
+
+ def __call__(self, batch_size, **kwargs):
+ """Sample a batch of context.
+
+ Args:
+ batch_size: Batch size.
+ Returns:
+ Two [batch_size, num_context_dims] tensors.
+ """
+ spec = self._context_spec
+ next_op = self._next(self._i)
+ with tf.control_dependencies([next_op]):
+ value = self._values[self._i]
+ if value.get_shape().as_list():
+ values = tf.tile(
+ tf.expand_dims(value, 0), (batch_size,) + (1,) * spec.shape.ndims)
+ else:
+ values = value + tf.zeros(
+ shape=[
+ batch_size,
+ ] + spec.shape.as_list(), dtype=spec.dtype)
+ self._validate_contexts(values)
+ self._count += 1
+ return values, values
+
+ def _next(self, i):
+ """Return op that increments pointer to next value.
+
+ Args:
+ i: A tensorflow integer variable.
+ Returns:
+ Op that increments pointer.
+ """
+ if self._scheduler == 'cycle':
+ inc = ('inc' in self._scheduler_params and
+ self._scheduler_params['inc']) or 1
+ return tf.assign(i, tf.mod(i+inc, self._n))
+ else:
+ raise NotImplementedError(self._scheduler)
+
+
+@gin.configurable
+class ReplaySampler(BaseSampler):
+ """Replay sampler."""
+
+ def __init__(self,
+ prefetch_queue_capacity=2,
+ override_indices=None,
+ state_indices=None,
+ *args,
+ **kwargs):
+ """Construct sampler.
+
+ Args:
+ prefetch_queue_capacity: Capacity for prefetch queue.
+ override_indices: Override indices.
+ state_indices: Select certain indices from state dimension.
+ *args: arguments.
+ **kwargs: keyword arguments.
+ """
+ super(ReplaySampler, self).__init__(*args, **kwargs)
+ self._prefetch_queue_capacity = prefetch_queue_capacity
+ self._override_indices = override_indices
+ self._state_indices = state_indices
+
+ def set_replay(self, replay):
+ """Set replay.
+
+ Args:
+ replay: A replay buffer.
+ """
+ self._replay = replay
+
+ def __call__(self, batch_size, **kwargs):
+ """Sample a batch of context.
+
+ Args:
+ batch_size: Batch size.
+ Returns:
+ Two [batch_size, num_context_dims] tensors.
+ """
+ batch = self._replay.GetRandomBatch(batch_size)
+ next_states = batch[4]
+ if self._prefetch_queue_capacity > 0:
+ batch_queue = slim.prefetch_queue.prefetch_queue(
+ [next_states],
+ capacity=self._prefetch_queue_capacity,
+ name='%s/batch_context_queue' % self._scope)
+ next_states = batch_queue.dequeue()
+ if self._override_indices is not None:
+ assert self._context_range is not None and isinstance(
+ self._context_range[0], (int, long, float))
+ next_states = tf.concat(
+ [
+ tf.random_uniform(
+ shape=next_states[:, :1].shape,
+ minval=self._context_range[0],
+ maxval=self._context_range[1],
+ dtype=next_states.dtype)
+ if i in self._override_indices else next_states[:, i:i + 1]
+ for i in range(self._context_spec.shape.as_list()[0])
+ ],
+ axis=1)
+ if self._state_indices is not None:
+ next_states = tf.concat(
+ [
+ next_states[:, i:i + 1]
+ for i in range(self._context_spec.shape.as_list()[0])
+ ],
+ axis=1)
+ self._validate_contexts(next_states)
+ return next_states, next_states
+
+
+@gin.configurable
+class TimeSampler(BaseSampler):
+ """Time Sampler."""
+
+ def __init__(self, minval=0, maxval=1, timestep=-1, *args, **kwargs):
+ """Construct sampler.
+
+ Args:
+ minval: Min value integer.
+ maxval: Max value integer.
+ timestep: Time step between states and next_states.
+ *args: arguments.
+ **kwargs: keyword arguments.
+ """
+ super(TimeSampler, self).__init__(*args, **kwargs)
+ assert self._context_spec.shape.as_list() == [1]
+ self._minval = minval
+ self._maxval = maxval
+ self._timestep = timestep
+
+ def __call__(self, batch_size, **kwargs):
+ """Sample a batch of context.
+
+ Args:
+ batch_size: Batch size.
+ Returns:
+ Two [batch_size, num_context_dims] tensors.
+ """
+ if self._maxval == self._minval:
+ contexts = tf.constant(
+ self._maxval, shape=[batch_size, 1], dtype=tf.int32)
+ else:
+ contexts = tf.random_uniform(
+ shape=[batch_size, 1],
+ dtype=tf.int32,
+ maxval=self._maxval,
+ minval=self._minval)
+ next_contexts = tf.maximum(contexts + self._timestep, 0)
+
+ return tf.cast(
+ contexts, dtype=self._context_spec.dtype), tf.cast(
+ next_contexts, dtype=self._context_spec.dtype)
+
+
+@gin.configurable
+class ConstantSampler(BaseSampler):
+ """Constant sampler."""
+
+ def __init__(self, value=None, *args, **kwargs):
+ """Construct sampler.
+
+ Args:
+ value: A list or Numpy array for values of the constant.
+ *args: arguments.
+ **kwargs: keyword arguments.
+ """
+ super(ConstantSampler, self).__init__(*args, **kwargs)
+ self._value = value
+
+ def __call__(self, batch_size, **kwargs):
+ """Sample a batch of context.
+
+ Args:
+ batch_size: Batch size.
+ Returns:
+ Two [batch_size, num_context_dims] tensors.
+ """
+ spec = self._context_spec
+ value_ = tf.constant(self._value, shape=spec.shape, dtype=spec.dtype)
+ values = tf.tile(
+ tf.expand_dims(value_, 0), (batch_size,) + (1,) * spec.shape.ndims)
+ self._validate_contexts(values)
+ return values, values
+
+
+@gin.configurable
+class DirectionSampler(RandomSampler):
+ """Direction sampler."""
+
+ def __call__(self, batch_size, **kwargs):
+ """Sample a batch of context.
+
+ Args:
+ batch_size: Batch size.
+ Returns:
+ Two [batch_size, num_context_dims] tensors.
+ """
+ spec = self._context_spec
+ context_range = self._context_range
+ if isinstance(context_range[0], (int, float)):
+ contexts = tf.random_uniform(
+ shape=[
+ batch_size,
+ ] + spec.shape.as_list(),
+ minval=context_range[0],
+ maxval=context_range[1],
+ dtype=spec.dtype)
+ elif isinstance(context_range[0], (list, tuple, np.ndarray)):
+ assert len(spec.shape.as_list()) == 1
+ assert spec.shape.as_list()[0] == len(context_range[0])
+ assert spec.shape.as_list()[0] == len(context_range[1])
+ contexts = tf.concat(
+ [
+ tf.random_uniform(
+ shape=[
+ batch_size, 1,
+ ] + spec.shape.as_list()[1:],
+ minval=context_range[0][i],
+ maxval=context_range[1][i],
+ dtype=spec.dtype) for i in range(spec.shape.as_list()[0])
+ ],
+ axis=1)
+ else: raise NotImplementedError(context_range)
+ self._validate_contexts(contexts)
+ if 'sampler_fn' in kwargs:
+ other_contexts = kwargs['sampler_fn']()
+ else:
+ other_contexts = contexts
+ state, next_state = kwargs['state'], kwargs['next_state']
+ if state is not None and next_state is not None:
+ my_context_range = (np.array(context_range[1]) - np.array(context_range[0])) / 2 * np.ones(spec.shape.as_list())
+ contexts = tf.concat(
+ [0.1 * my_context_range[:self._k] *
+ tf.random_normal(tf.shape(state[:, :self._k]), dtype=state.dtype) +
+ tf.random_shuffle(state[:, :self._k]) - state[:, :self._k],
+ other_contexts[:, self._k:]], 1)
+ #contexts = tf.Print(contexts,
+ # [contexts, tf.reduce_max(contexts, 0),
+ # tf.reduce_min(state, 0), tf.reduce_max(state, 0)], 'contexts', summarize=15)
+      next_contexts = tf.concat(  # Re-express the sampled goal relative to next_state.
+          [state[:, :self._k] + contexts[:, :self._k] - next_state[:, :self._k],
+           other_contexts[:, self._k:]], 1)
+      next_contexts = contexts  # Overrides the line above: keep the same context (cosine-similarity variant).
+ else:
+ next_contexts = contexts
+ return tf.stop_gradient(contexts), tf.stop_gradient(next_contexts)
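For readers skimming the diff, the relative-goal trick in DirectionSampler above is easy to check in isolation. The following sketch is illustrative only and not part of the patch; plain NumPy stands in for the TensorFlow ops. It reproduces the first-k-dimension logic: the context is another state drawn from the batch minus the current state, plus a little noise, i.e. a goal expressed as an offset.

import numpy as np

k, batch = 2, 4
state = np.random.randn(batch, 5)
shuffled = state[np.random.permutation(batch)]   # tf.random_shuffle analogue
noise = 0.1 * np.random.randn(batch, k)          # small exploration noise
contexts_first_k = shuffled[:, :k] - state[:, :k] + noise
assert contexts_first_k.shape == (batch, k)      # [batch_size, k] goal offsets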
diff --git a/models/research/efficient-hrl/environments/__init__.py b/models/research/efficient-hrl/environments/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/models/research/efficient-hrl/environments/__init__.py
@@ -0,0 +1 @@
+
diff --git a/models/research/efficient-hrl/environments/ant.py b/models/research/efficient-hrl/environments/ant.py
new file mode 100644
index 0000000000000000000000000000000000000000..feab1eef4c5fac51a2e0f683de00731b893751c4
--- /dev/null
+++ b/models/research/efficient-hrl/environments/ant.py
@@ -0,0 +1,141 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Wrapper for creating the ant environment in gym_mujoco."""
+
+import math
+import numpy as np
+import mujoco_py
+from gym import utils
+from gym.envs.mujoco import mujoco_env
+
+
+def q_inv(a):
+ return [a[0], -a[1], -a[2], -a[3]]
+
+
+def q_mult(a, b):  # multiply two quaternions
+ w = a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3]
+ i = a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2]
+ j = a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1]
+ k = a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0]
+ return [w, i, j, k]
+
+
+class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
+ FILE = "ant.xml"
+ ORI_IND = 3
+
+ def __init__(self, file_path=None, expose_all_qpos=True,
+ expose_body_coms=None, expose_body_comvels=None):
+ self._expose_all_qpos = expose_all_qpos
+ self._expose_body_coms = expose_body_coms
+ self._expose_body_comvels = expose_body_comvels
+ self._body_com_indices = {}
+ self._body_comvel_indices = {}
+
+ mujoco_env.MujocoEnv.__init__(self, file_path, 5)
+ utils.EzPickle.__init__(self)
+
+ @property
+ def physics(self):
+    # If the installed mujoco-py version is at least 1.50, return self.sim, whose
+    # PyMjData member is used for getting and setting position/velocity; otherwise
+    # fall back to self.model. See https://github.com/openai/mujoco-py/issues/80.
+ if mujoco_py.get_version() >= '1.50':
+ return self.sim
+ else:
+ return self.model
+
+ def _step(self, a):
+ return self.step(a)
+
+ def step(self, a):
+ xposbefore = self.get_body_com("torso")[0]
+ self.do_simulation(a, self.frame_skip)
+ xposafter = self.get_body_com("torso")[0]
+ forward_reward = (xposafter - xposbefore) / self.dt
+ ctrl_cost = .5 * np.square(a).sum()
+ survive_reward = 1.0
+ reward = forward_reward - ctrl_cost + survive_reward
+ state = self.state_vector()
+ done = False
+ ob = self._get_obs()
+ return ob, reward, done, dict(
+ reward_forward=forward_reward,
+ reward_ctrl=-ctrl_cost,
+ reward_survive=survive_reward)
+
+ def _get_obs(self):
+ # No cfrc observation
+ if self._expose_all_qpos:
+ obs = np.concatenate([
+ self.physics.data.qpos.flat[:15], # Ensures only ant obs.
+ self.physics.data.qvel.flat[:14],
+ ])
+ else:
+ obs = np.concatenate([
+ self.physics.data.qpos.flat[2:15],
+ self.physics.data.qvel.flat[:14],
+ ])
+
+ if self._expose_body_coms is not None:
+ for name in self._expose_body_coms:
+ com = self.get_body_com(name)
+ if name not in self._body_com_indices:
+ indices = range(len(obs), len(obs) + len(com))
+ self._body_com_indices[name] = indices
+ obs = np.concatenate([obs, com])
+
+ if self._expose_body_comvels is not None:
+ for name in self._expose_body_comvels:
+ comvel = self.get_body_comvel(name)
+ if name not in self._body_comvel_indices:
+ indices = range(len(obs), len(obs) + len(comvel))
+ self._body_comvel_indices[name] = indices
+ obs = np.concatenate([obs, comvel])
+ return obs
+
+ def reset_model(self):
+ qpos = self.init_qpos + self.np_random.uniform(
+ size=self.model.nq, low=-.1, high=.1)
+ qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
+
+ # Set everything other than ant to original position and 0 velocity.
+ qpos[15:] = self.init_qpos[15:]
+ qvel[14:] = 0.
+ self.set_state(qpos, qvel)
+ return self._get_obs()
+
+ def viewer_setup(self):
+ self.viewer.cam.distance = self.model.stat.extent * 0.5
+
+ def get_ori(self):
+ ori = [0, 1, 0, 0]
+ rot = self.physics.data.qpos[self.__class__.ORI_IND:self.__class__.ORI_IND + 4] # take the quaternion
+ ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3] # project onto x-y plane
+ ori = math.atan2(ori[1], ori[0])
+ return ori
+
+ def set_xy(self, xy):
+ qpos = np.copy(self.physics.data.qpos)
+ qpos[0] = xy[0]
+ qpos[1] = xy[1]
+
+ qvel = self.physics.data.qvel
+ self.set_state(qpos, qvel)
+
+ def get_xy(self):
+ return self.physics.data.qpos[:2]
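A quick way to sanity-check get_ori() above: for a torso quaternion encoding a pure yaw rotation, the projection onto the x-y plane recovers the yaw angle. The snippet below is an illustrative sketch only (not part of the patch); it uses the standard closed-form yaw extraction, which agrees with the q_mult/q_inv projection when only the heading is needed.

import math

def yaw_from_quaternion(w, x, y, z):
    # atan2 of the rotated x-axis, the same heading that get_ori() projects out.
    return math.atan2(2.0 * (w * z + x * y), 1.0 - 2.0 * (y * y + z * z))

theta = 0.3
q = (math.cos(theta / 2), 0.0, 0.0, math.sin(theta / 2))  # pure yaw by theta
assert abs(yaw_from_quaternion(*q) - theta) < 1e-9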
diff --git a/models/research/efficient-hrl/environments/ant_maze_env.py b/models/research/efficient-hrl/environments/ant_maze_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..69a10663f4d02901d295f2781eca2dd3e601e292
--- /dev/null
+++ b/models/research/efficient-hrl/environments/ant_maze_env.py
@@ -0,0 +1,21 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from environments.maze_env import MazeEnv
+from environments.ant import AntEnv
+
+
+class AntMazeEnv(MazeEnv):
+ MODEL_CLASS = AntEnv
diff --git a/models/research/efficient-hrl/environments/assets/ant.xml b/models/research/efficient-hrl/environments/assets/ant.xml
new file mode 100644
index 0000000000000000000000000000000000000000..5a49d7f52a0e577d64c47205ae32c00a9d23a2d9
--- /dev/null
+++ b/models/research/efficient-hrl/environments/assets/ant.xml
@@ -0,0 +1,81 @@
+<!-- MuJoCo model definition for the ant (81 lines of XML); the tag markup was lost in extraction and is not reproduced here. -->
diff --git a/models/research/efficient-hrl/environments/create_maze_env.py b/models/research/efficient-hrl/environments/create_maze_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6dc4f42190b137364700d3dcd970c3ad8b1b9ad
--- /dev/null
+++ b/models/research/efficient-hrl/environments/create_maze_env.py
@@ -0,0 +1,97 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from environments.ant_maze_env import AntMazeEnv
+from environments.point_maze_env import PointMazeEnv
+
+import tensorflow as tf
+import gin.tf
+from tf_agents.environments import gym_wrapper
+from tf_agents.environments import tf_py_environment
+
+
+@gin.configurable
+def create_maze_env(env_name=None, top_down_view=False):
+ n_bins = 0
+ manual_collision = False
+ if env_name.startswith('Ego'):
+ n_bins = 8
+ env_name = env_name[3:]
+ if env_name.startswith('Ant'):
+ cls = AntMazeEnv
+ env_name = env_name[3:]
+ maze_size_scaling = 8
+ elif env_name.startswith('Point'):
+ cls = PointMazeEnv
+ manual_collision = True
+ env_name = env_name[5:]
+ maze_size_scaling = 4
+ else:
+ assert False, 'unknown env %s' % env_name
+
+ maze_id = None
+ observe_blocks = False
+ put_spin_near_agent = False
+ if env_name == 'Maze':
+ maze_id = 'Maze'
+ elif env_name == 'Push':
+ maze_id = 'Push'
+ elif env_name == 'Fall':
+ maze_id = 'Fall'
+ elif env_name == 'Block':
+ maze_id = 'Block'
+ put_spin_near_agent = True
+ observe_blocks = True
+ elif env_name == 'BlockMaze':
+ maze_id = 'BlockMaze'
+ put_spin_near_agent = True
+ observe_blocks = True
+ else:
+ raise ValueError('Unknown maze environment %s' % env_name)
+
+ gym_mujoco_kwargs = {
+ 'maze_id': maze_id,
+ 'n_bins': n_bins,
+ 'observe_blocks': observe_blocks,
+ 'put_spin_near_agent': put_spin_near_agent,
+ 'top_down_view': top_down_view,
+ 'manual_collision': manual_collision,
+ 'maze_size_scaling': maze_size_scaling
+ }
+ gym_env = cls(**gym_mujoco_kwargs)
+ gym_env.reset()
+ wrapped_env = gym_wrapper.GymWrapper(gym_env)
+ return wrapped_env
+
+
+class TFPyEnvironment(tf_py_environment.TFPyEnvironment):
+
+ def __init__(self, *args, **kwargs):
+ super(TFPyEnvironment, self).__init__(*args, **kwargs)
+
+ def start_collect(self):
+ pass
+
+ def current_obs(self):
+ time_step = self.current_time_step()
+ return time_step.observation[0] # For some reason, there is an extra dim.
+
+ def step(self, actions):
+ actions = tf.expand_dims(actions, 0)
+ next_step = super(TFPyEnvironment, self).step(actions)
+ return next_step.is_last()[0], next_step.reward[0], next_step.discount[0]
+
+ def reset(self):
+ return super(TFPyEnvironment, self).reset()
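Hedged usage sketch (not part of the patch): constructing one of the wrapped maze environments and stepping it with a random action. This assumes MuJoCo, gym, and tf_agents are installed and that models/research/efficient-hrl is on PYTHONPATH; the .gym attribute of the tf_agents GymWrapper is used the same way run_env.py uses it later in this diff.

from environments.create_maze_env import create_maze_env

env = create_maze_env(env_name='AntMaze')  # tf_agents PyEnvironment wrapping AntMazeEnv
first_step = env.reset()                   # TimeStep holding the initial observation
action = env.gym.action_space.sample()     # sample from the underlying gym env
next_step = env.step(action)
print(next_step.reward, next_step.observation.shape)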
diff --git a/models/research/efficient-hrl/environments/maze_env.py b/models/research/efficient-hrl/environments/maze_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf7d1f2dc0a0d5883a7953c623b3419a02282206
--- /dev/null
+++ b/models/research/efficient-hrl/environments/maze_env.py
@@ -0,0 +1,499 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Adapted from rllab maze_env.py."""
+
+import os
+import tempfile
+import xml.etree.ElementTree as ET
+import math
+import random
+import numpy as np
+import gym
+
+from environments import maze_env_utils
+
+# Directory that contains mujoco xml files.
+MODEL_DIR = 'environments/assets'
+
+
+class MazeEnv(gym.Env):
+ MODEL_CLASS = None
+
+ MAZE_HEIGHT = None
+ MAZE_SIZE_SCALING = None
+
+ def __init__(
+ self,
+ maze_id=None,
+ maze_height=0.5,
+ maze_size_scaling=8,
+ n_bins=0,
+ sensor_range=3.,
+ sensor_span=2 * math.pi,
+ observe_blocks=False,
+ put_spin_near_agent=False,
+ top_down_view=False,
+ manual_collision=False,
+ *args,
+ **kwargs):
+ self._maze_id = maze_id
+
+ model_cls = self.__class__.MODEL_CLASS
+ if model_cls is None:
+      raise ValueError("MODEL_CLASS unspecified!")
+ xml_path = os.path.join(MODEL_DIR, model_cls.FILE)
+ tree = ET.parse(xml_path)
+ worldbody = tree.find(".//worldbody")
+
+ self.MAZE_HEIGHT = height = maze_height
+ self.MAZE_SIZE_SCALING = size_scaling = maze_size_scaling
+ self._n_bins = n_bins
+ self._sensor_range = sensor_range * size_scaling
+ self._sensor_span = sensor_span
+ self._observe_blocks = observe_blocks
+ self._put_spin_near_agent = put_spin_near_agent
+ self._top_down_view = top_down_view
+ self._manual_collision = manual_collision
+
+ self.MAZE_STRUCTURE = structure = maze_env_utils.construct_maze(maze_id=self._maze_id)
+ self.elevated = any(-1 in row for row in structure) # Elevate the maze to allow for falling.
+ self.blocks = any(
+ any(maze_env_utils.can_move(r) for r in row)
+ for row in structure) # Are there any movable blocks?
+
+ torso_x, torso_y = self._find_robot()
+ self._init_torso_x = torso_x
+ self._init_torso_y = torso_y
+ self._init_positions = [
+ (x - torso_x, y - torso_y)
+ for x, y in self._find_all_robots()]
+
+ self._xy_to_rowcol = lambda x, y: (2 + (y + size_scaling / 2) / size_scaling,
+ 2 + (x + size_scaling / 2) / size_scaling)
+ self._view = np.zeros([5, 5, 3]) # walls (immovable), chasms (fall), movable blocks
+
+ height_offset = 0.
+ if self.elevated:
+ # Increase initial z-pos of ant.
+ height_offset = height * size_scaling
+ torso = tree.find(".//body[@name='torso']")
+ torso.set('pos', '0 0 %.2f' % (0.75 + height_offset))
+ if self.blocks:
+ # If there are movable blocks, change simulation settings to perform
+ # better contact detection.
+ default = tree.find(".//default")
+ default.find('.//geom').set('solimp', '.995 .995 .01')
+
+ self.movable_blocks = []
+ for i in range(len(structure)):
+ for j in range(len(structure[0])):
+ struct = structure[i][j]
+ if struct == 'r' and self._put_spin_near_agent:
+ struct = maze_env_utils.Move.SpinXY
+ if self.elevated and struct not in [-1]:
+ # Create elevated platform.
+ ET.SubElement(
+ worldbody, "geom",
+ name="elevated_%d_%d" % (i, j),
+ pos="%f %f %f" % (j * size_scaling - torso_x,
+ i * size_scaling - torso_y,
+ height / 2 * size_scaling),
+ size="%f %f %f" % (0.5 * size_scaling,
+ 0.5 * size_scaling,
+ height / 2 * size_scaling),
+ type="box",
+ material="",
+ contype="1",
+ conaffinity="1",
+ rgba="0.9 0.9 0.9 1",
+ )
+ if struct == 1: # Unmovable block.
+ # Offset all coordinates so that robot starts at the origin.
+ ET.SubElement(
+ worldbody, "geom",
+ name="block_%d_%d" % (i, j),
+ pos="%f %f %f" % (j * size_scaling - torso_x,
+ i * size_scaling - torso_y,
+ height_offset +
+ height / 2 * size_scaling),
+ size="%f %f %f" % (0.5 * size_scaling,
+ 0.5 * size_scaling,
+ height / 2 * size_scaling),
+ type="box",
+ material="",
+ contype="1",
+ conaffinity="1",
+ rgba="0.4 0.4 0.4 1",
+ )
+ elif maze_env_utils.can_move(struct): # Movable block.
+          # The "falling" blocks are shrunk slightly and increased in mass to
+          # ensure that they can fall easily through a gap in the platform blocks.
+ name = "movable_%d_%d" % (i, j)
+ self.movable_blocks.append((name, struct))
+ falling = maze_env_utils.can_move_z(struct)
+ spinning = maze_env_utils.can_spin(struct)
+ x_offset = 0.25 * size_scaling if spinning else 0.0
+ y_offset = 0.0
+ shrink = 0.1 if spinning else 0.99 if falling else 1.0
+ height_shrink = 0.1 if spinning else 1.0
+ movable_body = ET.SubElement(
+ worldbody, "body",
+ name=name,
+ pos="%f %f %f" % (j * size_scaling - torso_x + x_offset,
+ i * size_scaling - torso_y + y_offset,
+ height_offset +
+ height / 2 * size_scaling * height_shrink),
+ )
+ ET.SubElement(
+ movable_body, "geom",
+ name="block_%d_%d" % (i, j),
+ pos="0 0 0",
+ size="%f %f %f" % (0.5 * size_scaling * shrink,
+ 0.5 * size_scaling * shrink,
+ height / 2 * size_scaling * height_shrink),
+ type="box",
+ material="",
+ mass="0.001" if falling else "0.0002",
+ contype="1",
+ conaffinity="1",
+ rgba="0.9 0.1 0.1 1"
+ )
+ if maze_env_utils.can_move_x(struct):
+ ET.SubElement(
+ movable_body, "joint",
+ armature="0",
+ axis="1 0 0",
+ damping="0.0",
+ limited="true" if falling else "false",
+ range="%f %f" % (-size_scaling, size_scaling),
+ margin="0.01",
+ name="movable_x_%d_%d" % (i, j),
+ pos="0 0 0",
+ type="slide"
+ )
+ if maze_env_utils.can_move_y(struct):
+ ET.SubElement(
+ movable_body, "joint",
+ armature="0",
+ axis="0 1 0",
+ damping="0.0",
+ limited="true" if falling else "false",
+ range="%f %f" % (-size_scaling, size_scaling),
+ margin="0.01",
+ name="movable_y_%d_%d" % (i, j),
+ pos="0 0 0",
+ type="slide"
+ )
+ if maze_env_utils.can_move_z(struct):
+ ET.SubElement(
+ movable_body, "joint",
+ armature="0",
+ axis="0 0 1",
+ damping="0.0",
+ limited="true",
+ range="%f 0" % (-height_offset),
+ margin="0.01",
+ name="movable_z_%d_%d" % (i, j),
+ pos="0 0 0",
+ type="slide"
+ )
+ if maze_env_utils.can_spin(struct):
+ ET.SubElement(
+ movable_body, "joint",
+ armature="0",
+ axis="0 0 1",
+ damping="0.0",
+ limited="false",
+ name="spinable_%d_%d" % (i, j),
+ pos="0 0 0",
+ type="ball"
+ )
+
+ torso = tree.find(".//body[@name='torso']")
+ geoms = torso.findall(".//geom")
+ for geom in geoms:
+ if 'name' not in geom.attrib:
+ raise Exception("Every geom of the torso must have a name "
+ "defined")
+
+ _, file_path = tempfile.mkstemp(text=True, suffix='.xml')
+ tree.write(file_path)
+
+ self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs)
+
+ def get_ori(self):
+ return self.wrapped_env.get_ori()
+
+ def get_top_down_view(self):
+ self._view = np.zeros_like(self._view)
+
+ def valid(row, col):
+ return self._view.shape[0] > row >= 0 and self._view.shape[1] > col >= 0
+
+ def update_view(x, y, d, row=None, col=None):
+ if row is None or col is None:
+ x = x - self._robot_x
+ y = y - self._robot_y
+ th = self._robot_ori
+
+ row, col = self._xy_to_rowcol(x, y)
+ update_view(x, y, d, row=row, col=col)
+ return
+
+ row, row_frac, col, col_frac = int(row), row % 1, int(col), col % 1
+ if row_frac < 0:
+ row_frac += 1
+ if col_frac < 0:
+ col_frac += 1
+
+ if valid(row, col):
+ self._view[row, col, d] += (
+ (min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
+ (min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
+ if valid(row - 1, col):
+ self._view[row - 1, col, d] += (
+ (max(0., 0.5 - row_frac)) *
+ (min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
+ if valid(row + 1, col):
+ self._view[row + 1, col, d] += (
+ (max(0., row_frac - 0.5)) *
+ (min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
+ if valid(row, col - 1):
+ self._view[row, col - 1, d] += (
+ (min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
+ (max(0., 0.5 - col_frac)))
+ if valid(row, col + 1):
+ self._view[row, col + 1, d] += (
+ (min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
+ (max(0., col_frac - 0.5)))
+ if valid(row - 1, col - 1):
+ self._view[row - 1, col - 1, d] += (
+ (max(0., 0.5 - row_frac)) * max(0., 0.5 - col_frac))
+ if valid(row - 1, col + 1):
+ self._view[row - 1, col + 1, d] += (
+ (max(0., 0.5 - row_frac)) * max(0., col_frac - 0.5))
+ if valid(row + 1, col + 1):
+ self._view[row + 1, col + 1, d] += (
+ (max(0., row_frac - 0.5)) * max(0., col_frac - 0.5))
+ if valid(row + 1, col - 1):
+ self._view[row + 1, col - 1, d] += (
+ (max(0., row_frac - 0.5)) * max(0., 0.5 - col_frac))
+
+ # Draw ant.
+ robot_x, robot_y = self.wrapped_env.get_body_com("torso")[:2]
+ self._robot_x = robot_x
+ self._robot_y = robot_y
+ self._robot_ori = self.get_ori()
+
+ structure = self.MAZE_STRUCTURE
+ size_scaling = self.MAZE_SIZE_SCALING
+ height = self.MAZE_HEIGHT
+
+ # Draw immovable blocks and chasms.
+ for i in range(len(structure)):
+ for j in range(len(structure[0])):
+ if structure[i][j] == 1: # Wall.
+ update_view(j * size_scaling - self._init_torso_x,
+ i * size_scaling - self._init_torso_y,
+ 0)
+ if structure[i][j] == -1: # Chasm.
+ update_view(j * size_scaling - self._init_torso_x,
+ i * size_scaling - self._init_torso_y,
+ 1)
+
+ # Draw movable blocks.
+ for block_name, block_type in self.movable_blocks:
+ block_x, block_y = self.wrapped_env.get_body_com(block_name)[:2]
+ update_view(block_x, block_y, 2)
+
+ return self._view
+
+ def get_range_sensor_obs(self):
+ """Returns egocentric range sensor observations of maze."""
+ robot_x, robot_y, robot_z = self.wrapped_env.get_body_com("torso")[:3]
+ ori = self.get_ori()
+
+ structure = self.MAZE_STRUCTURE
+ size_scaling = self.MAZE_SIZE_SCALING
+ height = self.MAZE_HEIGHT
+
+ segments = []
+ # Get line segments (corresponding to outer boundary) of each immovable
+ # block or drop-off.
+ for i in range(len(structure)):
+ for j in range(len(structure[0])):
+ if structure[i][j] in [1, -1]: # There's a wall or drop-off.
+ cx = j * size_scaling - self._init_torso_x
+ cy = i * size_scaling - self._init_torso_y
+ x1 = cx - 0.5 * size_scaling
+ x2 = cx + 0.5 * size_scaling
+ y1 = cy - 0.5 * size_scaling
+ y2 = cy + 0.5 * size_scaling
+ struct_segments = [
+ ((x1, y1), (x2, y1)),
+ ((x2, y1), (x2, y2)),
+ ((x2, y2), (x1, y2)),
+ ((x1, y2), (x1, y1)),
+ ]
+ for seg in struct_segments:
+ segments.append(dict(
+ segment=seg,
+ type=structure[i][j],
+ ))
+ # Get line segments (corresponding to outer boundary) of each movable
+ # block within the agent's z-view.
+ for block_name, block_type in self.movable_blocks:
+ block_x, block_y, block_z = self.wrapped_env.get_body_com(block_name)[:3]
+ if (block_z + height * size_scaling / 2 >= robot_z and
+ robot_z >= block_z - height * size_scaling / 2): # Block in view.
+ x1 = block_x - 0.5 * size_scaling
+ x2 = block_x + 0.5 * size_scaling
+ y1 = block_y - 0.5 * size_scaling
+ y2 = block_y + 0.5 * size_scaling
+ struct_segments = [
+ ((x1, y1), (x2, y1)),
+ ((x2, y1), (x2, y2)),
+ ((x2, y2), (x1, y2)),
+ ((x1, y2), (x1, y1)),
+ ]
+ for seg in struct_segments:
+ segments.append(dict(
+ segment=seg,
+ type=block_type,
+ ))
+
+ sensor_readings = np.zeros((self._n_bins, 3)) # 3 for wall, drop-off, block
+ for ray_idx in range(self._n_bins):
+ ray_ori = (ori - self._sensor_span * 0.5 +
+ (2 * ray_idx + 1.0) / (2 * self._n_bins) * self._sensor_span)
+ ray_segments = []
+ # Get all segments that intersect with ray.
+ for seg in segments:
+ p = maze_env_utils.ray_segment_intersect(
+ ray=((robot_x, robot_y), ray_ori),
+ segment=seg["segment"])
+ if p is not None:
+ ray_segments.append(dict(
+ segment=seg["segment"],
+ type=seg["type"],
+ ray_ori=ray_ori,
+ distance=maze_env_utils.point_distance(p, (robot_x, robot_y)),
+ ))
+ if len(ray_segments) > 0:
+ # Find out which segment is intersected first.
+ first_seg = sorted(ray_segments, key=lambda x: x["distance"])[0]
+ seg_type = first_seg["type"]
+ idx = (0 if seg_type == 1 else # Wall.
+ 1 if seg_type == -1 else # Drop-off.
+ 2 if maze_env_utils.can_move(seg_type) else # Block.
+ None)
+ if first_seg["distance"] <= self._sensor_range:
+ sensor_readings[ray_idx][idx] = (self._sensor_range - first_seg["distance"]) / self._sensor_range
+
+ return sensor_readings
+
+ def _get_obs(self):
+ wrapped_obs = self.wrapped_env._get_obs()
+ if self._top_down_view:
+ view = [self.get_top_down_view().flat]
+ else:
+ view = []
+
+ if self._observe_blocks:
+ additional_obs = []
+ for block_name, block_type in self.movable_blocks:
+ additional_obs.append(self.wrapped_env.get_body_com(block_name))
+ wrapped_obs = np.concatenate([wrapped_obs[:3]] + additional_obs +
+ [wrapped_obs[3:]])
+
+ range_sensor_obs = self.get_range_sensor_obs()
+ return np.concatenate([wrapped_obs,
+ range_sensor_obs.flat] +
+ view + [[self.t * 0.001]])
+
+ def reset(self):
+ self.t = 0
+ self.trajectory = []
+ self.wrapped_env.reset()
+ if len(self._init_positions) > 1:
+ xy = random.choice(self._init_positions)
+ self.wrapped_env.set_xy(xy)
+ return self._get_obs()
+
+ @property
+ def viewer(self):
+ return self.wrapped_env.viewer
+
+ def render(self, *args, **kwargs):
+ return self.wrapped_env.render(*args, **kwargs)
+
+ @property
+ def observation_space(self):
+ shape = self._get_obs().shape
+ high = np.inf * np.ones(shape)
+ low = -high
+ return gym.spaces.Box(low, high)
+
+ @property
+ def action_space(self):
+ return self.wrapped_env.action_space
+
+ def _find_robot(self):
+ structure = self.MAZE_STRUCTURE
+ size_scaling = self.MAZE_SIZE_SCALING
+ for i in range(len(structure)):
+ for j in range(len(structure[0])):
+ if structure[i][j] == 'r':
+ return j * size_scaling, i * size_scaling
+ assert False, 'No robot in maze specification.'
+
+ def _find_all_robots(self):
+ structure = self.MAZE_STRUCTURE
+ size_scaling = self.MAZE_SIZE_SCALING
+ coords = []
+ for i in range(len(structure)):
+ for j in range(len(structure[0])):
+ if structure[i][j] == 'r':
+ coords.append((j * size_scaling, i * size_scaling))
+ return coords
+
+ def _is_in_collision(self, pos):
+ x, y = pos
+ structure = self.MAZE_STRUCTURE
+ size_scaling = self.MAZE_SIZE_SCALING
+ for i in range(len(structure)):
+ for j in range(len(structure[0])):
+ if structure[i][j] == 1:
+ minx = j * size_scaling - size_scaling * 0.5 - self._init_torso_x
+ maxx = j * size_scaling + size_scaling * 0.5 - self._init_torso_x
+ miny = i * size_scaling - size_scaling * 0.5 - self._init_torso_y
+ maxy = i * size_scaling + size_scaling * 0.5 - self._init_torso_y
+ if minx <= x <= maxx and miny <= y <= maxy:
+ return True
+ return False
+
+ def step(self, action):
+ self.t += 1
+ if self._manual_collision:
+ old_pos = self.wrapped_env.get_xy()
+ inner_next_obs, inner_reward, done, info = self.wrapped_env.step(action)
+ new_pos = self.wrapped_env.get_xy()
+ if self._is_in_collision(new_pos):
+ self.wrapped_env.set_xy(old_pos)
+ else:
+ inner_next_obs, inner_reward, done, info = self.wrapped_env.step(action)
+ next_obs = self._get_obs()
+ done = False
+ return next_obs, inner_reward, done, info
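The grid bookkeeping in MazeEnv is easier to see on a concrete layout. The sketch below is illustrative only (not part of the patch); it uses the 'Maze' structure from maze_env_utils and reproduces the _find_robot() scan, where the 'r' cell at row 1, column 1 maps to world coordinates (8, 8) under the default maze_size_scaling of 8.

structure = [
    [1, 1, 1, 1, 1],
    [1, 'r', 0, 0, 1],
    [1, 1, 1, 0, 1],
    [1, 0, 0, 0, 1],
    [1, 1, 1, 1, 1],
]
scaling = 8
robot_xy = [(j * scaling, i * scaling)
            for i, row in enumerate(structure)
            for j, cell in enumerate(row) if cell == 'r'][0]
assert robot_xy == (8, 8)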
diff --git a/models/research/efficient-hrl/environments/maze_env_utils.py b/models/research/efficient-hrl/environments/maze_env_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f52509b65a89b43c1baa8e7448e0692ceefaaab
--- /dev/null
+++ b/models/research/efficient-hrl/environments/maze_env_utils.py
@@ -0,0 +1,164 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Adapted from rllab maze_env_utils.py."""
+import numpy as np
+import math
+
+
+class Move(object):
+ X = 11
+ Y = 12
+ Z = 13
+ XY = 14
+ XZ = 15
+ YZ = 16
+ XYZ = 17
+ SpinXY = 18
+
+
+def can_move_x(movable):
+ return movable in [Move.X, Move.XY, Move.XZ, Move.XYZ,
+ Move.SpinXY]
+
+
+def can_move_y(movable):
+ return movable in [Move.Y, Move.XY, Move.YZ, Move.XYZ,
+ Move.SpinXY]
+
+
+def can_move_z(movable):
+ return movable in [Move.Z, Move.XZ, Move.YZ, Move.XYZ]
+
+
+def can_spin(movable):
+ return movable in [Move.SpinXY]
+
+
+def can_move(movable):
+ return can_move_x(movable) or can_move_y(movable) or can_move_z(movable)
+
+
+def construct_maze(maze_id='Maze'):
+ if maze_id == 'Maze':
+ structure = [
+ [1, 1, 1, 1, 1],
+ [1, 'r', 0, 0, 1],
+ [1, 1, 1, 0, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 1, 1, 1],
+ ]
+ elif maze_id == 'Push':
+ structure = [
+ [1, 1, 1, 1, 1],
+ [1, 0, 'r', 1, 1],
+ [1, 0, Move.XY, 0, 1],
+ [1, 1, 0, 1, 1],
+ [1, 1, 1, 1, 1],
+ ]
+ elif maze_id == 'Fall':
+ structure = [
+ [1, 1, 1, 1],
+ [1, 'r', 0, 1],
+ [1, 0, Move.YZ, 1],
+ [1, -1, -1, 1],
+ [1, 0, 0, 1],
+ [1, 1, 1, 1],
+ ]
+ elif maze_id == 'Block':
+ O = 'r'
+ structure = [
+ [1, 1, 1, 1, 1],
+ [1, O, 0, 0, 1],
+ [1, 0, 0, 0, 1],
+ [1, 0, 0, 0, 1],
+ [1, 1, 1, 1, 1],
+ ]
+ elif maze_id == 'BlockMaze':
+ O = 'r'
+ structure = [
+ [1, 1, 1, 1],
+ [1, O, 0, 1],
+ [1, 1, 0, 1],
+ [1, 0, 0, 1],
+ [1, 1, 1, 1],
+ ]
+ else:
+ raise NotImplementedError('The provided MazeId %s is not recognized' % maze_id)
+
+ return structure
+
+
+def line_intersect(pt1, pt2, ptA, ptB):
+ """
+ Taken from https://www.cs.hmc.edu/ACM/lectures/intersections.html
+
+  This returns the intersection of Line(pt1, pt2) and Line(ptA, ptB).
+ """
+
+ DET_TOLERANCE = 0.00000001
+
+ # the first line is pt1 + r*(pt2-pt1)
+ # in component form:
+ x1, y1 = pt1
+ x2, y2 = pt2
+ dx1 = x2 - x1
+ dy1 = y2 - y1
+
+ # the second line is ptA + s*(ptB-ptA)
+ x, y = ptA
+ xB, yB = ptB
+ dx = xB - x
+ dy = yB - y
+
+ DET = (-dx1 * dy + dy1 * dx)
+
+ if math.fabs(DET) < DET_TOLERANCE: return (0, 0, 0, 0, 0)
+
+ # now, the determinant should be OK
+ DETinv = 1.0 / DET
+
+ # find the scalar amount along the "self" segment
+ r = DETinv * (-dy * (x - x1) + dx * (y - y1))
+
+ # find the scalar amount along the input line
+ s = DETinv * (-dy1 * (x - x1) + dx1 * (y - y1))
+
+ # return the average of the two descriptions
+ xi = (x1 + r * dx1 + x + s * dx) / 2.0
+ yi = (y1 + r * dy1 + y + s * dy) / 2.0
+ return (xi, yi, 1, r, s)
+
+
+def ray_segment_intersect(ray, segment):
+ """
+  Check whether the ray originating from (x, y) with direction theta intersects
+  the line segment (x1, y1) -- (x2, y2), and return the intersection point if
+  there is one.
+ """
+ (x, y), theta = ray
+ # (x1, y1), (x2, y2) = segment
+ pt1 = (x, y)
+  ray_length = 1
+  pt2 = (x + ray_length * math.cos(theta), y + ray_length * math.sin(theta))
+ xo, yo, valid, r, s = line_intersect(pt1, pt2, *segment)
+ if valid and r >= 0 and 0 <= s <= 1:
+ return (xo, yo)
+ return None
+
+
+def point_distance(p1, p2):
+ x1, y1 = p1
+ x2, y2 = p2
+ return ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5
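A small sanity check for the geometry helpers above (illustrative only, not part of the patch; assumes the module is importable as environments.maze_env_utils): a ray leaving the origin along the +x direction should hit the vertical segment at x = 2 exactly at (2, 0).

from environments.maze_env_utils import ray_segment_intersect

hit = ray_segment_intersect(ray=((0.0, 0.0), 0.0),
                            segment=((2.0, -1.0), (2.0, 1.0)))
assert hit is not None
assert abs(hit[0] - 2.0) < 1e-6 and abs(hit[1]) < 1e-6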
diff --git a/models/research/efficient-hrl/environments/point.py b/models/research/efficient-hrl/environments/point.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c2fc80bc824dbc81e228a122b9aaea054f73b74
--- /dev/null
+++ b/models/research/efficient-hrl/environments/point.py
@@ -0,0 +1,97 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Wrapper for creating the point environment in gym_mujoco."""
+
+import math
+import numpy as np
+import mujoco_py
+from gym import utils
+from gym.envs.mujoco import mujoco_env
+
+
+class PointEnv(mujoco_env.MujocoEnv, utils.EzPickle):
+ FILE = "point.xml"
+ ORI_IND = 2
+
+ def __init__(self, file_path=None, expose_all_qpos=True):
+ self._expose_all_qpos = expose_all_qpos
+
+ mujoco_env.MujocoEnv.__init__(self, file_path, 1)
+ utils.EzPickle.__init__(self)
+
+ @property
+ def physics(self):
+    # If the installed mujoco-py version is at least 1.50, return self.sim, whose
+    # PyMjData member is used for getting and setting position/velocity; otherwise
+    # fall back to self.model. See https://github.com/openai/mujoco-py/issues/80.
+ if mujoco_py.get_version() >= '1.50':
+ return self.sim
+ else:
+ return self.model
+
+ def _step(self, a):
+ return self.step(a)
+
+ def step(self, action):
+ action[0] = 0.2 * action[0]
+ qpos = np.copy(self.physics.data.qpos)
+ qpos[2] += action[1]
+ ori = qpos[2]
+ # compute increment in each direction
+ dx = math.cos(ori) * action[0]
+ dy = math.sin(ori) * action[0]
+ # ensure that the robot is within reasonable range
+ qpos[0] = np.clip(qpos[0] + dx, -100, 100)
+ qpos[1] = np.clip(qpos[1] + dy, -100, 100)
+ qvel = self.physics.data.qvel
+ self.set_state(qpos, qvel)
+ for _ in range(0, self.frame_skip):
+ self.physics.step()
+ next_obs = self._get_obs()
+ reward = 0
+ done = False
+ info = {}
+ return next_obs, reward, done, info
+
+ def _get_obs(self):
+ if self._expose_all_qpos:
+ return np.concatenate([
+ self.physics.data.qpos.flat[:3], # Only point-relevant coords.
+ self.physics.data.qvel.flat[:3]])
+ return np.concatenate([
+ self.physics.data.qpos.flat[2:3],
+ self.physics.data.qvel.flat[:3]])
+
+ def reset_model(self):
+ qpos = self.init_qpos + self.np_random.uniform(
+ size=self.physics.model.nq, low=-.1, high=.1)
+ qvel = self.init_qvel + self.np_random.randn(self.physics.model.nv) * .1
+
+ # Set everything other than point to original position and 0 velocity.
+ qpos[3:] = self.init_qpos[3:]
+ qvel[3:] = 0.
+ self.set_state(qpos, qvel)
+ return self._get_obs()
+
+ def get_ori(self):
+ return self.physics.data.qpos[self.__class__.ORI_IND]
+
+ def set_xy(self, xy):
+ qpos = np.copy(self.physics.data.qpos)
+ qpos[0] = xy[0]
+ qpos[1] = xy[1]
+
+    qvel = self.physics.data.qvel
+    self.set_state(qpos, qvel)
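PointEnv.step() above implements simple unicycle-style kinematics: the second action component turns the agent, and the first moves it forward by 0.2 * action[0] along the new heading. A minimal sketch of that update (illustrative only, not part of the patch):

import math

x, y, ori = 0.0, 0.0, 0.0
action = [1.0, math.pi / 2]        # full forward command, quarter turn
ori += action[1]                   # qpos[2] += action[1]
forward = 0.2 * action[0]          # action[0] is rescaled by 0.2
x += math.cos(ori) * forward
y += math.sin(ori) * forward
assert abs(x) < 1e-9 and abs(y - 0.2) < 1e-9   # moved 0.2 along +y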
diff --git a/models/research/efficient-hrl/environments/point_maze_env.py b/models/research/efficient-hrl/environments/point_maze_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d6b819486370d609b87c232d92c4093aa906863
--- /dev/null
+++ b/models/research/efficient-hrl/environments/point_maze_env.py
@@ -0,0 +1,21 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from environments.maze_env import MazeEnv
+from environments.point import PointEnv
+
+
+class PointMazeEnv(MazeEnv):
+ MODEL_CLASS = PointEnv
diff --git a/models/research/efficient-hrl/eval.py b/models/research/efficient-hrl/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f5a4b20a53d920b4a9095c30de2c03698cd1b78
--- /dev/null
+++ b/models/research/efficient-hrl/eval.py
@@ -0,0 +1,460 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+r"""Script for evaluating a UVF agent.
+
+To run locally: See run_eval.py
+
+To run on borg: See train_eval.borg
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tensorflow as tf
+slim = tf.contrib.slim
+import gin.tf
+# pylint: disable=unused-import
+import agent
+import train
+from utils import utils as uvf_utils
+from utils import eval_utils
+from environments import create_maze_env
+# pylint: enable=unused-import
+
+flags = tf.app.flags
+
+flags.DEFINE_string('eval_dir', None,
+ 'Directory for writing logs/summaries during eval.')
+flags.DEFINE_string('checkpoint_dir', None,
+ 'Directory containing checkpoints to eval.')
+FLAGS = flags.FLAGS
+
+
+def get_evaluate_checkpoint_fn(master, output_dir, eval_step_fns,
+ model_rollout_fn, gamma, max_steps_per_episode,
+ num_episodes_eval, num_episodes_videos,
+ tuner_hook, generate_videos,
+ generate_summaries, video_settings):
+ """Returns a function that evaluates a given checkpoint.
+
+ Args:
+ master: BNS name of the TensorFlow master
+ output_dir: The output directory to which the metric summaries are written.
+ eval_step_fns: A dictionary of a functions that return a list of
+ [state, action, reward, discount, transition_type] tensors,
+ indexed by summary tag name.
+ model_rollout_fn: Model rollout fn.
+ gamma: Discount factor for the reward.
+ max_steps_per_episode: Maximum steps to run each episode for.
+ num_episodes_eval: Number of episodes to evaluate and average reward over.
+ num_episodes_videos: Number of episodes to record for video.
+ tuner_hook: A callable(average reward, global step) that updates a Vizier
+ tuner trial.
+ generate_videos: Whether to generate videos of the agent in action.
+ generate_summaries: Whether to generate summaries.
+ video_settings: Settings for generating videos of the agent.
+
+ Returns:
+ A function that evaluates a checkpoint.
+ """
+ sess = tf.Session(master, graph=tf.get_default_graph())
+ sess.run(tf.global_variables_initializer())
+ sess.run(tf.local_variables_initializer())
+ summary_writer = tf.summary.FileWriter(output_dir)
+
+ def evaluate_checkpoint(checkpoint_path):
+ """Performs a one-time evaluation of the given checkpoint.
+
+ Args:
+ checkpoint_path: Checkpoint to evaluate.
+ Returns:
+ True if the evaluation process should stop
+ """
+ restore_fn = tf.contrib.framework.assign_from_checkpoint_fn(
+ checkpoint_path,
+ uvf_utils.get_all_vars(),
+ ignore_missing_vars=True,
+ reshape_variables=False)
+ assert restore_fn is not None, 'cannot restore %s' % checkpoint_path
+ restore_fn(sess)
+ global_step = sess.run(slim.get_global_step())
+ should_stop = False
+ max_reward = -1e10
+ max_meta_reward = -1e10
+
+ for eval_tag, (eval_step, env_base,) in sorted(eval_step_fns.items()):
+ if hasattr(env_base, 'set_sess'):
+ env_base.set_sess(sess) # set session
+
+ if generate_summaries:
+ tf.logging.info(
+ '[%s] Computing average reward over %d episodes at global step %d.',
+ eval_tag, num_episodes_eval, global_step)
+ (average_reward, last_reward,
+ average_meta_reward, last_meta_reward, average_success,
+ states, actions) = eval_utils.compute_average_reward(
+ sess, env_base, eval_step, gamma, max_steps_per_episode,
+ num_episodes_eval)
+ tf.logging.info('[%s] Average reward = %f', eval_tag, average_reward)
+ tf.logging.info('[%s] Last reward = %f', eval_tag, last_reward)
+ tf.logging.info('[%s] Average meta reward = %f', eval_tag, average_meta_reward)
+ tf.logging.info('[%s] Last meta reward = %f', eval_tag, last_meta_reward)
+ tf.logging.info('[%s] Average success = %f', eval_tag, average_success)
+ if model_rollout_fn is not None:
+ preds, model_losses = eval_utils.compute_model_loss(
+ sess, model_rollout_fn, states, actions)
+ for i, (pred, state, model_loss) in enumerate(
+ zip(preds, states, model_losses)):
+ tf.logging.info('[%s] Model rollout step %d: loss=%f', eval_tag, i,
+ model_loss)
+ tf.logging.info('[%s] Model rollout step %d: pred=%s', eval_tag, i,
+ str(pred.tolist()))
+ tf.logging.info('[%s] Model rollout step %d: state=%s', eval_tag, i,
+ str(state.tolist()))
+
+ # Report the eval stats to the tuner.
+ if average_reward > max_reward:
+ max_reward = average_reward
+ if average_meta_reward > max_meta_reward:
+ max_meta_reward = average_meta_reward
+
+ for (tag, value) in [('Reward/average_%s_reward', average_reward),
+ ('Reward/last_%s_reward', last_reward),
+ ('Reward/average_%s_meta_reward', average_meta_reward),
+ ('Reward/last_%s_meta_reward', last_meta_reward),
+ ('Reward/average_%s_success', average_success)]:
+ summary_str = tf.Summary(value=[
+ tf.Summary.Value(
+ tag=tag % eval_tag,
+ simple_value=value)
+ ])
+ summary_writer.add_summary(summary_str, global_step)
+ summary_writer.flush()
+
+ if generate_videos or should_stop:
+ # Do a manual reset before generating the video to see the initial
+ # pose of the robot, towards which the reset controller is moving.
+ if hasattr(env_base, '_gym_env'):
+ tf.logging.info('Resetting before recording video')
+ if hasattr(env_base._gym_env, 'reset_model'):
+ env_base._gym_env.reset_model() # pylint: disable=protected-access
+ else:
+ env_base._gym_env.wrapped_env.reset_model()
+ video_filename = os.path.join(output_dir, 'videos',
+ '%s_step_%d.mp4' % (eval_tag,
+ global_step))
+ eval_utils.capture_video(sess, eval_step, env_base,
+ max_steps_per_episode * num_episodes_videos,
+ video_filename, video_settings,
+ reset_every=max_steps_per_episode)
+
+ should_stop = should_stop or (generate_summaries and tuner_hook and
+ tuner_hook(max_reward, global_step))
+ return bool(should_stop)
+
+ return evaluate_checkpoint
+
+
+def get_model_rollout(uvf_agent, tf_env):
+ """Model rollout function."""
+ state_spec = tf_env.observation_spec()[0]
+ action_spec = tf_env.action_spec()[0]
+ state_ph = tf.placeholder(dtype=state_spec.dtype, shape=state_spec.shape)
+ action_ph = tf.placeholder(dtype=action_spec.dtype, shape=action_spec.shape)
+
+ merged_state = uvf_agent.merged_state(state_ph)
+ diff_value = uvf_agent.critic_net(tf.expand_dims(merged_state, 0),
+ tf.expand_dims(action_ph, 0))[0]
+ diff_value = tf.cast(diff_value, dtype=state_ph.dtype)
+ state_ph.shape.assert_is_compatible_with(diff_value.shape)
+ next_state = state_ph + diff_value
+
+ def model_rollout_fn(sess, state, action):
+ return sess.run(next_state, feed_dict={state_ph: state, action_ph: action})
+
+ return model_rollout_fn
+
+
+def get_eval_step(uvf_agent,
+ state_preprocess,
+ tf_env,
+ action_fn,
+ meta_action_fn,
+ environment_steps,
+ num_episodes,
+ mode='eval'):
+ """Get one-step policy/env stepping ops.
+
+ Args:
+ uvf_agent: A UVF agent.
+    state_preprocess: A state-preprocessing module applied to raw observations.
+    tf_env: A TFEnvironment.
+ action_fn: A function to produce actions given current state.
+ meta_action_fn: A function to produce meta actions given current state.
+ environment_steps: A variable to count the number of steps in the tf_env.
+ num_episodes: A variable to count the number of episodes.
+ mode: a string representing the mode=[train, explore, eval].
+
+ Returns:
+    A step_fn that executes one action in the environment and returns the
+    evaluation tensors for that step.
+ """
+
+ tf_env.start_collect()
+ state = tf_env.current_obs()
+ action = action_fn(state, context=None)
+ state_repr = state_preprocess(state)
+
+ action_spec = tf_env.action_spec()
+ action_ph = tf.placeholder(dtype=action_spec.dtype, shape=action_spec.shape)
+ with tf.control_dependencies([state]):
+ transition_type, reward, discount = tf_env.step(action_ph)
+
+ def increment_step():
+ return environment_steps.assign_add(1)
+
+ def increment_episode():
+ return num_episodes.assign_add(1)
+
+ def no_op_int():
+ return tf.constant(0, dtype=tf.int64)
+
+ step_cond = uvf_agent.step_cond_fn(state, action,
+ transition_type,
+ environment_steps, num_episodes)
+ reset_episode_cond = uvf_agent.reset_episode_cond_fn(
+ state, action,
+ transition_type, environment_steps, num_episodes)
+ reset_env_cond = uvf_agent.reset_env_cond_fn(state, action,
+ transition_type,
+ environment_steps, num_episodes)
+
+ increment_step_op = tf.cond(step_cond, increment_step, no_op_int)
+ with tf.control_dependencies([increment_step_op]):
+ increment_episode_op = tf.cond(reset_episode_cond, increment_episode,
+ no_op_int)
+
+ with tf.control_dependencies([reward, discount]):
+ next_state = tf_env.current_obs()
+ next_state_repr = state_preprocess(next_state)
+
+ with tf.control_dependencies([increment_episode_op]):
+ post_reward, post_meta_reward = uvf_agent.cond_begin_episode_op(
+ tf.logical_not(reset_episode_cond),
+ [state, action_ph, reward, next_state,
+ state_repr, next_state_repr],
+ mode=mode, meta_action_fn=meta_action_fn)
+
+ # Important: do manual reset after getting the final reward from the
+ # unreset environment.
+ with tf.control_dependencies([post_reward, post_meta_reward]):
+ cond_reset_op = tf.cond(reset_env_cond,
+ tf_env.reset,
+ tf_env.current_time_step)
+
+ # Add a dummy control dependency to force the reset_op to run
+ with tf.control_dependencies(cond_reset_op):
+ post_reward, post_meta_reward = map(tf.identity, [post_reward, post_meta_reward])
+
+ eval_step = [next_state, action_ph, transition_type, post_reward, post_meta_reward, discount, uvf_agent.context_vars, state_repr]
+
+ if callable(action):
+ def step_fn(sess):
+ action_value = action(sess)
+ return sess.run(eval_step, feed_dict={action_ph: action_value})
+ else:
+ action = uvf_utils.clip_to_spec(action, action_spec)
+ def step_fn(sess):
+ action_value = sess.run(action)
+ return sess.run(eval_step, feed_dict={action_ph: action_value})
+
+ return step_fn
+
+
+@gin.configurable
+def evaluate(checkpoint_dir,
+ eval_dir,
+ environment=None,
+ num_bin_actions=3,
+ agent_class=None,
+ meta_agent_class=None,
+ state_preprocess_class=None,
+ gamma=1.0,
+ num_episodes_eval=10,
+ eval_interval_secs=60,
+ max_number_of_evaluations=None,
+ checkpoint_timeout=None,
+ timeout_fn=None,
+ tuner_hook=None,
+ generate_videos=False,
+ generate_summaries=True,
+ num_episodes_videos=5,
+ video_settings=None,
+ eval_modes=('eval',),
+ eval_model_rollout=False,
+ policy_save_dir='policy',
+ checkpoint_range=None,
+ checkpoint_path=None,
+ max_steps_per_episode=None,
+ evaluate_nohrl=False):
+ """Loads and repeatedly evaluates a checkpointed model at a set interval.
+
+ Args:
+ checkpoint_dir: The directory where the checkpoints reside.
+ eval_dir: Directory to save the evaluation summary results.
+ environment: A BaseEnvironment to evaluate.
+ num_bin_actions: Number of bins for discretizing continuous actions.
+ agent_class: An RL agent class.
+ meta_agent_class: A Meta agent class.
+ gamma: Discount factor for the reward.
+ num_episodes_eval: Number of episodes to evaluate and average reward over.
+ eval_interval_secs: The number of seconds between each evaluation run.
+ max_number_of_evaluations: The max number of evaluations. If None the
+ evaluation continues indefinitely.
+ checkpoint_timeout: The maximum amount of time to wait between checkpoints.
+ If left as `None`, then the process will wait indefinitely.
+ timeout_fn: Optional function to call after a timeout.
+ tuner_hook: A callable that takes the average reward and global step and
+ updates a Vizier tuner trial.
+ generate_videos: Whether to generate videos of the agent in action.
+ generate_summaries: Whether to generate summaries.
+ num_episodes_videos: Number of episodes to evaluate for generating videos.
+    video_settings: Settings for generating videos of the agent.
+ eval_modes: A tuple of eval modes.
+ eval_model_rollout: Evaluate model rollout.
+ policy_save_dir: Optional sub-directory where the policies are
+ saved.
+ checkpoint_range: Optional. If provided, evaluate all checkpoints in
+ the range.
+ checkpoint_path: Optional sub-directory specifying which checkpoint to
+ evaluate. If None, will evaluate the most recent checkpoint.
+ """
+ tf_env = create_maze_env.TFPyEnvironment(environment)
+ observation_spec = [tf_env.observation_spec()]
+ action_spec = [tf_env.action_spec()]
+
+ assert max_steps_per_episode, 'max_steps_per_episode need to be set'
+
+ if agent_class.ACTION_TYPE == 'discrete':
+ assert False
+ else:
+ assert agent_class.ACTION_TYPE == 'continuous'
+
+ if meta_agent_class is not None:
+ assert agent_class.ACTION_TYPE == meta_agent_class.ACTION_TYPE
+ with tf.variable_scope('meta_agent'):
+ meta_agent = meta_agent_class(
+ observation_spec,
+ action_spec,
+ tf_env,
+ )
+ else:
+ meta_agent = None
+
+ with tf.variable_scope('uvf_agent'):
+ uvf_agent = agent_class(
+ observation_spec,
+ action_spec,
+ tf_env,
+ )
+ uvf_agent.set_meta_agent(agent=meta_agent)
+
+ with tf.variable_scope('state_preprocess'):
+ state_preprocess = state_preprocess_class()
+
+ # run both actor and critic once to ensure networks are initialized
+ # and gin configs will be saved
+ # pylint: disable=protected-access
+ temp_states = tf.expand_dims(
+ tf.zeros(
+ dtype=uvf_agent._observation_spec.dtype,
+ shape=uvf_agent._observation_spec.shape), 0)
+ # pylint: enable=protected-access
+ temp_actions = uvf_agent.actor_net(temp_states)
+ uvf_agent.critic_net(temp_states, temp_actions)
+
+ # create eval_step_fns for each action function
+ eval_step_fns = dict()
+ meta_agent = uvf_agent.meta_agent
+ for meta in [True] + [False] * evaluate_nohrl:
+ meta_tag = 'hrl' if meta else 'nohrl'
+ uvf_agent.set_meta_agent(meta_agent if meta else None)
+ for mode in eval_modes:
+ # wrap environment
+ wrapped_environment = uvf_agent.get_env_base_wrapper(
+ environment, mode=mode)
+ action_wrapper = lambda agent_: agent_.action
+ action_fn = action_wrapper(uvf_agent)
+ meta_action_fn = action_wrapper(meta_agent)
+ eval_step_fns['%s_%s' % (mode, meta_tag)] = (get_eval_step(
+ uvf_agent=uvf_agent,
+ state_preprocess=state_preprocess,
+ tf_env=tf_env,
+ action_fn=action_fn,
+ meta_action_fn=meta_action_fn,
+ environment_steps=tf.Variable(
+ 0, dtype=tf.int64, name='environment_steps'),
+ num_episodes=tf.Variable(0, dtype=tf.int64, name='num_episodes'),
+ mode=mode), wrapped_environment,)
+
+ model_rollout_fn = None
+ if eval_model_rollout:
+ model_rollout_fn = get_model_rollout(uvf_agent, tf_env)
+
+ tf.train.get_or_create_global_step()
+
+ if policy_save_dir:
+ checkpoint_dir = os.path.join(checkpoint_dir, policy_save_dir)
+
+ tf.logging.info('Evaluating policies at %s', checkpoint_dir)
+ tf.logging.info('Running episodes for max %d steps', max_steps_per_episode)
+
+ evaluate_checkpoint_fn = get_evaluate_checkpoint_fn(
+ '', eval_dir, eval_step_fns, model_rollout_fn, gamma,
+ max_steps_per_episode, num_episodes_eval, num_episodes_videos, tuner_hook,
+ generate_videos, generate_summaries, video_settings)
+
+ if checkpoint_path is not None:
+ checkpoint_path = os.path.join(checkpoint_dir, checkpoint_path)
+ evaluate_checkpoint_fn(checkpoint_path)
+ elif checkpoint_range is not None:
+ model_files = tf.gfile.Glob(
+ os.path.join(checkpoint_dir, 'model.ckpt-*.index'))
+ tf.logging.info('Found %s policies at %s', len(model_files), checkpoint_dir)
+ model_files = {
+ int(f.split('model.ckpt-', 1)[1].split('.', 1)[0]):
+ os.path.splitext(f)[0]
+ for f in model_files
+ }
+ model_files = {
+ k: v
+ for k, v in model_files.items()
+ if k >= checkpoint_range[0] and k <= checkpoint_range[1]
+ }
+ tf.logging.info('Evaluating %d policies at %s',
+ len(model_files), checkpoint_dir)
+ for _, checkpoint_path in sorted(model_files.items()):
+ evaluate_checkpoint_fn(checkpoint_path)
+ else:
+ eval_utils.evaluate_checkpoint_repeatedly(
+ checkpoint_dir,
+ evaluate_checkpoint_fn,
+ eval_interval_secs=eval_interval_secs,
+ max_number_of_evaluations=max_number_of_evaluations,
+ checkpoint_timeout=checkpoint_timeout,
+ timeout_fn=timeout_fn)
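The checkpoint_range branch above filters checkpoints by the step number embedded in the filename. The parsing is compact, so here is a standalone sketch of what it does (illustrative only, not part of the patch; the paths are made up):

import os

files = ['/tmp/policy/model.ckpt-100.index', '/tmp/policy/model.ckpt-2500.index']
parsed = {
    int(f.split('model.ckpt-', 1)[1].split('.', 1)[0]): os.path.splitext(f)[0]
    for f in files
}
assert parsed == {100: '/tmp/policy/model.ckpt-100',
                  2500: '/tmp/policy/model.ckpt-2500'}
in_range = {k: v for k, v in parsed.items() if 1000 <= k <= 10000}
assert list(in_range) == [2500]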
diff --git a/models/research/efficient-hrl/run_env.py b/models/research/efficient-hrl/run_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..87fad542aea1dc0f9a39553b53d3da8978ca089f
--- /dev/null
+++ b/models/research/efficient-hrl/run_env.py
@@ -0,0 +1,129 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Random policy on an environment."""
+
+import tensorflow as tf
+import numpy as np
+import random
+
+from environments import create_maze_env
+
+app = tf.app
+flags = tf.flags
+logging = tf.logging
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('env', 'AntMaze', 'environment name: AntMaze, AntPush, or AntFall')
+flags.DEFINE_integer('episode_length', 500, 'episode length')
+flags.DEFINE_integer('num_episodes', 50, 'number of episodes')
+
+
+def get_goal_sample_fn(env_name):
+ if env_name == 'AntMaze':
+    # NOTE: When evaluating (i.e., computing the metrics shown in the paper),
+    # we use the commented-out goal sampling function below. The uncommented
+    # one is only used for training.
+ #return lambda: np.array([0., 16.])
+ return lambda: np.random.uniform((-4, -4), (20, 20))
+ elif env_name == 'AntPush':
+ return lambda: np.array([0., 19.])
+ elif env_name == 'AntFall':
+ return lambda: np.array([0., 27., 4.5])
+ else:
+ assert False, 'Unknown env'
+
+
+def get_reward_fn(env_name):
+ if env_name == 'AntMaze':
+ return lambda obs, goal: -np.sum(np.square(obs[:2] - goal)) ** 0.5
+ elif env_name == 'AntPush':
+ return lambda obs, goal: -np.sum(np.square(obs[:2] - goal)) ** 0.5
+ elif env_name == 'AntFall':
+ return lambda obs, goal: -np.sum(np.square(obs[:3] - goal)) ** 0.5
+ else:
+ assert False, 'Unknown env'
+
+
+def success_fn(last_reward):
+ return last_reward > -5.0
+
+
+class EnvWithGoal(object):
+
+ def __init__(self, base_env, env_name):
+ self.base_env = base_env
+ self.goal_sample_fn = get_goal_sample_fn(env_name)
+ self.reward_fn = get_reward_fn(env_name)
+ self.goal = None
+
+ def reset(self):
+ obs = self.base_env.reset()
+ self.goal = self.goal_sample_fn()
+ return np.concatenate([obs, self.goal])
+
+ def step(self, a):
+ obs, _, done, info = self.base_env.step(a)
+ reward = self.reward_fn(obs, self.goal)
+ return np.concatenate([obs, self.goal]), reward, done, info
+
+ @property
+ def action_space(self):
+ return self.base_env.action_space
+
+
+def run_environment(env_name, episode_length, num_episodes):
+ env = EnvWithGoal(
+ create_maze_env.create_maze_env(env_name).gym,
+ env_name)
+
+ def action_fn(obs):
+ action_space = env.action_space
+ action_space_mean = (action_space.low + action_space.high) / 2.0
+ action_space_magn = (action_space.high - action_space.low) / 2.0
+ random_action = (action_space_mean +
+ action_space_magn *
+ np.random.uniform(low=-1.0, high=1.0,
+ size=action_space.shape))
+ return random_action
+
+ rewards = []
+ successes = []
+ for ep in range(num_episodes):
+ rewards.append(0.0)
+ successes.append(False)
+ obs = env.reset()
+ for _ in range(episode_length):
+ obs, reward, done, _ = env.step(action_fn(obs))
+ rewards[-1] += reward
+ successes[-1] = success_fn(reward)
+ if done:
+ break
+ logging.info('Episode %d reward: %.2f, Success: %d', ep + 1, rewards[-1], successes[-1])
+
+ logging.info('Average Reward over %d episodes: %.2f',
+ num_episodes, np.mean(rewards))
+ logging.info('Average Success over %d episodes: %.2f',
+ num_episodes, np.mean(successes))
+
+
+def main(unused_argv):
+ logging.set_verbosity(logging.INFO)
+ run_environment(FLAGS.env, FLAGS.episode_length, FLAGS.num_episodes)
+
+
+if __name__ == '__main__':
+ app.run()
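Assuming MuJoCo and the dependencies above are installed, run_env.py can be invoked directly to collect random-policy baselines; the flags correspond to the DEFINE_* declarations at the top of the file (the exact values here are just an example):

python run_env.py --env=AntMaze --episode_length=500 --num_episodes=10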
diff --git a/models/research/efficient-hrl/run_eval.py b/models/research/efficient-hrl/run_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..12f12369c4c90762bdf3d7506b957588856bdd3f
--- /dev/null
+++ b/models/research/efficient-hrl/run_eval.py
@@ -0,0 +1,51 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+r"""Script for evaluating a UVF agent.
+
+To run locally: See scripts/local_eval.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+import gin.tf
+# pylint: disable=unused-import
+import eval as eval_
+# pylint: enable=unused-import
+
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+
+def main(_):
+ tf.logging.set_verbosity(tf.logging.INFO)
+ assert FLAGS.checkpoint_dir, "Flag 'checkpoint_dir' must be set."
+ assert FLAGS.eval_dir, "Flag 'eval_dir' must be set."
+
+ if FLAGS.config_file:
+ for config_file in FLAGS.config_file:
+ gin.parse_config_file(config_file)
+ if FLAGS.params:
+ gin.parse_config(FLAGS.params)
+
+ eval_.evaluate(FLAGS.checkpoint_dir, FLAGS.eval_dir)
+
+
+if __name__ == "__main__":
+ tf.app.run()
diff --git a/models/research/efficient-hrl/run_train.py b/models/research/efficient-hrl/run_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d459d60b7f870bdcd81a48edc896158a9c6e4eb
--- /dev/null
+++ b/models/research/efficient-hrl/run_train.py
@@ -0,0 +1,49 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+r"""Script for training an RL agent using the UVF algorithm.
+
+To run locally: See scripts/local_train.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+import gin.tf
+# pylint: disable=unused-import
+import train
+# pylint: enable=unused-import
+
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+
+def main(_):
+ tf.logging.set_verbosity(tf.logging.INFO)
+ if FLAGS.config_file:
+ for config_file in FLAGS.config_file:
+ gin.parse_config_file(config_file)
+ if FLAGS.params:
+ gin.parse_config(FLAGS.params)
+
+ assert FLAGS.train_dir, "Flag 'train_dir' must be set."
+ return train.train_uvf(FLAGS.train_dir)
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/efficient-hrl/scripts/local_eval.py b/models/research/efficient-hrl/scripts/local_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..89ef745a4086197b07cee5f98fabfcc29af6d145
--- /dev/null
+++ b/models/research/efficient-hrl/scripts/local_eval.py
@@ -0,0 +1,76 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Script to run run_eval.py locally.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import os
+from subprocess import call
+import sys
+
+CONFIGS_PATH = 'configs'
+CONTEXT_CONFIGS_PATH = 'context/configs'
+
+def main():
+ bb = './'
+ base_num_args = 6
+ if len(sys.argv) < base_num_args:
+ print(
+ "usage: python %s "
+ " [params...]"
+ % sys.argv[0])
+ sys.exit(0)
+ exp = sys.argv[1]
+ context_setting = sys.argv[2]
+ context = sys.argv[3]
+ agent = sys.argv[4]
+ assert sys.argv[5] in ["suite"], "args[5] must be `suite'"
+ suite = ""
+ binary = "python {bb}/run_eval{suite}.py ".format(bb=bb, suite=suite)
+
+ h = os.environ["HOME"]
+ ucp = CONFIGS_PATH
+ ccp = CONTEXT_CONFIGS_PATH
+ extra = ''
+ command_str = ("{binary} "
+ "--logtostderr "
+ "--checkpoint_dir={h}/tmp/{context_setting}/{context}/{agent}/{exp}/train "
+ "--eval_dir={h}/tmp/{context_setting}/{context}/{agent}/{exp}/eval "
+ "--config_file={ucp}/{agent}.gin "
+ "--config_file={ucp}/eval_{extra}uvf.gin "
+ "--config_file={ccp}/{context_setting}.gin "
+ "--config_file={ccp}/{context}.gin ").format(
+ h=h,
+ ucp=ucp,
+ ccp=ccp,
+ context_setting=context_setting,
+ context=context,
+ agent=agent,
+ extra=extra,
+ suite=suite,
+ exp=exp,
+ binary=binary)
+ for extra_arg in sys.argv[base_num_args:]:
+ command_str += "--params='%s' " % extra_arg
+
+ print(command_str)
+ call(command_str, shell=True)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/models/research/efficient-hrl/scripts/local_train.py b/models/research/efficient-hrl/scripts/local_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..718c88e8fedd381707ac944f5ee9243b636ac915
--- /dev/null
+++ b/models/research/efficient-hrl/scripts/local_train.py
@@ -0,0 +1,76 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Script to run run_train.py locally.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import os
+import random
+from subprocess import call
+import sys
+
+CONFIGS_PATH = './configs'
+CONTEXT_CONFIGS_PATH = './context/configs'
+
+def main():
+ bb = '.'
+ base_num_args = 6
+ if len(sys.argv) < base_num_args:
+ print(
+ "usage: python %s "
+ " [params...]"
+ % sys.argv[0])
+ sys.exit(0)
+ exp = sys.argv[1] # Name for experiment, e.g. 'test001'
+ context_setting = sys.argv[2] # Context setting, e.g. 'hiro_orig'
+ context = sys.argv[3] # Environment-specific context, e.g. 'ant_maze'
+ agent = sys.argv[4] # Agent settings, e.g. 'base_uvf'
+ assert sys.argv[5] in ["suite"], "args[5] must be `suite'"
+ suite = ""
+ binary = "python {bb}/run_train{suite}.py ".format(bb=bb, suite=suite)
+
+ h = os.environ["HOME"]
+ ucp = CONFIGS_PATH
+ ccp = CONTEXT_CONFIGS_PATH
+ extra = ''
+ port = random.randint(2000, 8000)
+ command_str = ("{binary} "
+ "--train_dir={h}/tmp/{context_setting}/{context}/{agent}/{exp}/train "
+ "--config_file={ucp}/{agent}.gin "
+ "--config_file={ucp}/train_{extra}uvf.gin "
+ "--config_file={ccp}/{context_setting}.gin "
+ "--config_file={ccp}/{context}.gin "
+ "--summarize_gradients=False "
+ "--save_interval_secs=60 "
+ "--save_summaries_secs=1 "
+ "--master=local "
+ "--alsologtostderr ").format(h=h, ucp=ucp,
+ context_setting=context_setting,
+ context=context, ccp=ccp,
+ suite=suite, agent=agent, extra=extra,
+ exp=exp, binary=binary,
+ port=port)
+ for extra_arg in sys.argv[base_num_args:]:
+ command_str += "--params='%s' " % extra_arg
+
+ print(command_str)
+ call(command_str, shell=True)
+
+
+if __name__ == "__main__":
+ main()
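+# Example (a sketch): with the sample values from the comments above, a local
+# training run could be launched as
+#
+#   python scripts/local_train.py test001 hiro_orig ant_maze base_uvf suite
+#
+# and the matching evaluation job with the same positional arguments via
+# scripts/local_eval.py. Anything after `suite` is forwarded to the binary as a
+# --params Gin binding.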
diff --git a/models/research/efficient-hrl/train.py b/models/research/efficient-hrl/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..a40e81dbec6c103563192a373661cda8b5ae5fbb
--- /dev/null
+++ b/models/research/efficient-hrl/train.py
@@ -0,0 +1,670 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+r"""Script for training an RL agent using the UVF algorithm.
+
+To run locally: See run_train.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+import tensorflow as tf
+slim = tf.contrib.slim
+
+import gin.tf
+# pylint: disable=unused-import
+import train_utils
+import agent as agent_
+from agents import circular_buffer
+from utils import utils as uvf_utils
+from environments import create_maze_env
+# pylint: enable=unused-import
+
+
+flags = tf.app.flags
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string('goal_sample_strategy', 'sample',
+ 'None, sample, FuN')
+
+LOAD_PATH = None
+
+
+def collect_experience(tf_env, agent, meta_agent, state_preprocess,
+ replay_buffer, meta_replay_buffer,
+ action_fn, meta_action_fn,
+ environment_steps, num_episodes, num_resets,
+ episode_rewards, episode_meta_rewards,
+ store_context,
+ disable_agent_reset):
+ """Collect experience in a tf_env into a replay_buffer using action_fn.
+
+ Args:
+ tf_env: A TFEnvironment.
+ agent: A UVF agent.
+ meta_agent: A Meta Agent.
+ replay_buffer: A Replay buffer to collect experience in.
+ meta_replay_buffer: A Replay buffer to collect meta agent experience in.
+ action_fn: A function to produce actions given current state.
+ meta_action_fn: A function to produce meta actions given current state.
+ environment_steps: A variable to count the number of steps in the tf_env.
+ num_episodes: A variable to count the number of episodes.
+ num_resets: A variable to count the number of resets.
+    store_context: A boolean indicating whether to store the context in the
+      replay buffers.
+    disable_agent_reset: A boolean that disables the agent from resetting.
+
+  Returns:
+    A collect_experience_op that executes an action and stores the transition
+    into the replay buffers.
+ """
+ tf_env.start_collect()
+ state = tf_env.current_obs()
+ state_repr = state_preprocess(state)
+ action = action_fn(state, context=None)
+
+ with tf.control_dependencies([state]):
+ transition_type, reward, discount = tf_env.step(action)
+
+ def increment_step():
+ return environment_steps.assign_add(1)
+
+ def increment_episode():
+ return num_episodes.assign_add(1)
+
+ def increment_reset():
+ return num_resets.assign_add(1)
+
+ def update_episode_rewards(context_reward, meta_reward, reset):
+ new_episode_rewards = tf.concat(
+ [episode_rewards[:1] + context_reward, episode_rewards[1:]], 0)
+ new_episode_meta_rewards = tf.concat(
+ [episode_meta_rewards[:1] + meta_reward,
+ episode_meta_rewards[1:]], 0)
+ return tf.group(
+ episode_rewards.assign(
+ tf.cond(reset,
+ lambda: tf.concat([[0.], episode_rewards[:-1]], 0),
+ lambda: new_episode_rewards)),
+ episode_meta_rewards.assign(
+ tf.cond(reset,
+ lambda: tf.concat([[0.], episode_meta_rewards[:-1]], 0),
+ lambda: new_episode_meta_rewards)))
+
+ def no_op_int():
+ return tf.constant(0, dtype=tf.int64)
+
+ step_cond = agent.step_cond_fn(state, action,
+ transition_type,
+ environment_steps, num_episodes)
+ reset_episode_cond = agent.reset_episode_cond_fn(
+ state, action,
+ transition_type, environment_steps, num_episodes)
+ reset_env_cond = agent.reset_env_cond_fn(state, action,
+ transition_type,
+ environment_steps, num_episodes)
+
+ increment_step_op = tf.cond(step_cond, increment_step, no_op_int)
+ increment_episode_op = tf.cond(reset_episode_cond, increment_episode,
+ no_op_int)
+ increment_reset_op = tf.cond(reset_env_cond, increment_reset, no_op_int)
+ increment_op = tf.group(increment_step_op, increment_episode_op,
+ increment_reset_op)
+
+ with tf.control_dependencies([increment_op, reward, discount]):
+ next_state = tf_env.current_obs()
+ next_state_repr = state_preprocess(next_state)
+ next_reset_episode_cond = tf.logical_or(
+ agent.reset_episode_cond_fn(
+ state, action,
+ transition_type, environment_steps, num_episodes),
+ tf.equal(discount, 0.0))
+
+ if store_context:
+ context = [tf.identity(var) + tf.zeros_like(var) for var in agent.context_vars]
+ meta_context = [tf.identity(var) + tf.zeros_like(var) for var in meta_agent.context_vars]
+ else:
+ context = []
+ meta_context = []
+ with tf.control_dependencies([next_state] + context + meta_context):
+ if disable_agent_reset:
+ collect_experience_ops = [tf.no_op()] # don't reset agent
+ else:
+ collect_experience_ops = agent.cond_begin_episode_op(
+ tf.logical_not(reset_episode_cond),
+ [state, action, reward, next_state,
+ state_repr, next_state_repr],
+ mode='explore', meta_action_fn=meta_action_fn)
+ context_reward, meta_reward = collect_experience_ops
+ collect_experience_ops = list(collect_experience_ops)
+ collect_experience_ops.append(
+ update_episode_rewards(tf.reduce_sum(context_reward), meta_reward,
+ reset_episode_cond))
+
+ meta_action_every_n = agent.tf_context.meta_action_every_n
+ with tf.control_dependencies(collect_experience_ops):
+ transition = [state, action, reward, discount, next_state]
+
+ meta_action = tf.to_float(
+ tf.concat(context, -1)) # Meta agent action is low-level context
+
+ meta_end = tf.logical_and( # End of meta-transition.
+ tf.equal(agent.tf_context.t % meta_action_every_n, 1),
+ agent.tf_context.t > 1)
+ with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
+ states_var = tf.get_variable('states_var',
+ [meta_action_every_n, state.shape[-1]],
+ state.dtype)
+ actions_var = tf.get_variable('actions_var',
+ [meta_action_every_n, action.shape[-1]],
+ action.dtype)
+ state_var = tf.get_variable('state_var', state.shape, state.dtype)
+ reward_var = tf.get_variable('reward_var', reward.shape, reward.dtype)
+ meta_action_var = tf.get_variable('meta_action_var',
+ meta_action.shape, meta_action.dtype)
+ meta_context_var = [
+ tf.get_variable('meta_context_var%d' % idx,
+ meta_context[idx].shape, meta_context[idx].dtype)
+ for idx in range(len(meta_context))]
+
+ actions_var_upd = tf.scatter_update(
+ actions_var, (agent.tf_context.t - 2) % meta_action_every_n, action)
+ with tf.control_dependencies([actions_var_upd]):
+ actions = tf.identity(actions_var) + tf.zeros_like(actions_var)
+ meta_reward = tf.identity(meta_reward) + tf.zeros_like(meta_reward)
+ meta_reward = tf.reshape(meta_reward, reward.shape)
+
+ reward = 0.1 * meta_reward
+ meta_transition = [state_var, meta_action_var,
+ reward_var + reward,
+ discount * (1 - tf.to_float(next_reset_episode_cond)),
+ next_state]
+ meta_transition.extend([states_var, actions])
+ if store_context: # store current and next context into replay
+ transition += context + list(agent.context_vars)
+ meta_transition += meta_context_var + list(meta_agent.context_vars)
+
+ meta_step_cond = tf.squeeze(tf.logical_and(step_cond, tf.logical_or(next_reset_episode_cond, meta_end)))
+
+ collect_experience_op = tf.group(
+ replay_buffer.maybe_add(transition, step_cond),
+ meta_replay_buffer.maybe_add(meta_transition, meta_step_cond),
+ )
+
+ with tf.control_dependencies([collect_experience_op]):
+ collect_experience_op = tf.cond(reset_env_cond,
+ tf_env.reset,
+ tf_env.current_time_step)
+
+ meta_period = tf.equal(agent.tf_context.t % meta_action_every_n, 1)
+ states_var_upd = tf.scatter_update(
+ states_var, (agent.tf_context.t - 1) % meta_action_every_n,
+ next_state)
+ state_var_upd = tf.assign(
+ state_var,
+ tf.cond(meta_period, lambda: next_state, lambda: state_var))
+ reward_var_upd = tf.assign(
+ reward_var,
+ tf.cond(meta_period,
+ lambda: tf.zeros_like(reward_var),
+ lambda: reward_var + reward))
+ meta_action = tf.to_float(tf.concat(agent.context_vars, -1))
+ meta_action_var_upd = tf.assign(
+ meta_action_var,
+ tf.cond(meta_period, lambda: meta_action, lambda: meta_action_var))
+ meta_context_var_upd = [
+ tf.assign(
+ meta_context_var[idx],
+ tf.cond(meta_period,
+ lambda: meta_agent.context_vars[idx],
+ lambda: meta_context_var[idx]))
+ for idx in range(len(meta_context))]
+
+ return tf.group(
+ collect_experience_op,
+ states_var_upd,
+ state_var_upd,
+ reward_var_upd,
+ meta_action_var_upd,
+ *meta_context_var_upd)
+
+
+def sample_best_meta_actions(state_reprs, next_state_reprs, prev_meta_actions,
+ low_states, low_actions, low_state_reprs,
+ inverse_dynamics, uvf_agent, k=10):
+ """Return meta-actions which approximately maximize low-level log-probs."""
+ sampled_actions = inverse_dynamics.sample(state_reprs, next_state_reprs, k, prev_meta_actions)
+ sampled_actions = tf.stop_gradient(sampled_actions)
+ sampled_log_probs = tf.reshape(uvf_agent.log_probs(
+ tf.tile(low_states, [k, 1, 1]),
+ tf.tile(low_actions, [k, 1, 1]),
+ tf.tile(low_state_reprs, [k, 1, 1]),
+ [tf.reshape(sampled_actions, [-1, sampled_actions.shape[-1]])]),
+ [k, low_states.shape[0],
+ low_states.shape[1], -1])
+ fitness = tf.reduce_sum(sampled_log_probs, [2, 3])
+ best_actions = tf.argmax(fitness, 0)
+ actions = tf.gather_nd(
+ sampled_actions,
+ tf.stack([best_actions,
+ tf.range(prev_meta_actions.shape[0], dtype=tf.int64)], -1))
+ return actions
+
+
+@gin.configurable
+def train_uvf(train_dir,
+ environment=None,
+ num_bin_actions=3,
+ agent_class=None,
+ meta_agent_class=None,
+ state_preprocess_class=None,
+ inverse_dynamics_class=None,
+ exp_action_wrapper=None,
+ replay_buffer=None,
+ meta_replay_buffer=None,
+ replay_num_steps=1,
+ meta_replay_num_steps=1,
+ critic_optimizer=None,
+ actor_optimizer=None,
+ meta_critic_optimizer=None,
+ meta_actor_optimizer=None,
+ repr_optimizer=None,
+ relabel_contexts=False,
+ meta_relabel_contexts=False,
+ batch_size=64,
+ repeat_size=0,
+ num_episodes_train=2000,
+ initial_episodes=2,
+ initial_steps=None,
+ num_updates_per_observation=1,
+ num_collect_per_update=1,
+ num_collect_per_meta_update=1,
+ gamma=1.0,
+ meta_gamma=1.0,
+ reward_scale_factor=1.0,
+ target_update_period=1,
+ should_stop_early=None,
+ clip_gradient_norm=0.0,
+ summarize_gradients=False,
+ debug_summaries=False,
+ log_every_n_steps=100,
+ prefetch_queue_capacity=2,
+ policy_save_dir='policy',
+ save_policy_every_n_steps=1000,
+ save_policy_interval_secs=0,
+ replay_context_ratio=0.0,
+ next_state_as_context_ratio=0.0,
+ state_index=0,
+ zero_timer_ratio=0.0,
+ timer_index=-1,
+ debug=False,
+ max_policies_to_save=None,
+ max_steps_per_episode=None,
+ load_path=LOAD_PATH):
+ """Train an agent."""
+ tf_env = create_maze_env.TFPyEnvironment(environment)
+ observation_spec = [tf_env.observation_spec()]
+ action_spec = [tf_env.action_spec()]
+
+ max_steps_per_episode = max_steps_per_episode or tf_env.pyenv.max_episode_steps
+
+  assert max_steps_per_episode, 'max_steps_per_episode needs to be set'
+
+ if initial_steps is None:
+ initial_steps = initial_episodes * max_steps_per_episode
+
+ if agent_class.ACTION_TYPE == 'discrete':
+ assert False
+ else:
+ assert agent_class.ACTION_TYPE == 'continuous'
+
+ assert agent_class.ACTION_TYPE == meta_agent_class.ACTION_TYPE
+ with tf.variable_scope('meta_agent'):
+ meta_agent = meta_agent_class(
+ observation_spec,
+ action_spec,
+ tf_env,
+ debug_summaries=debug_summaries)
+ meta_agent.set_replay(replay=meta_replay_buffer)
+
+ with tf.variable_scope('uvf_agent'):
+ uvf_agent = agent_class(
+ observation_spec,
+ action_spec,
+ tf_env,
+ debug_summaries=debug_summaries)
+ uvf_agent.set_meta_agent(agent=meta_agent)
+ uvf_agent.set_replay(replay=replay_buffer)
+
+ with tf.variable_scope('state_preprocess'):
+ state_preprocess = state_preprocess_class()
+
+ with tf.variable_scope('inverse_dynamics'):
+ inverse_dynamics = inverse_dynamics_class(
+ meta_agent.sub_context_as_action_specs[0])
+
+ # Create counter variables
+ global_step = tf.contrib.framework.get_or_create_global_step()
+ num_episodes = tf.Variable(0, dtype=tf.int64, name='num_episodes')
+ num_resets = tf.Variable(0, dtype=tf.int64, name='num_resets')
+ num_updates = tf.Variable(0, dtype=tf.int64, name='num_updates')
+ num_meta_updates = tf.Variable(0, dtype=tf.int64, name='num_meta_updates')
+ episode_rewards = tf.Variable([0.] * 100, name='episode_rewards')
+ episode_meta_rewards = tf.Variable([0.] * 100, name='episode_meta_rewards')
+
+ # Create counter variables summaries
+ train_utils.create_counter_summaries([
+ ('environment_steps', global_step),
+ ('num_episodes', num_episodes),
+ ('num_resets', num_resets),
+ ('num_updates', num_updates),
+ ('num_meta_updates', num_meta_updates),
+ ('replay_buffer_adds', replay_buffer.get_num_adds()),
+ ('meta_replay_buffer_adds', meta_replay_buffer.get_num_adds()),
+ ])
+
+ tf.summary.scalar('avg_episode_rewards',
+ tf.reduce_mean(episode_rewards[1:]))
+ tf.summary.scalar('avg_episode_meta_rewards',
+ tf.reduce_mean(episode_meta_rewards[1:]))
+ tf.summary.histogram('episode_rewards', episode_rewards[1:])
+ tf.summary.histogram('episode_meta_rewards', episode_meta_rewards[1:])
+
+ # Create init ops
+ action_fn = uvf_agent.action
+ action_fn = uvf_agent.add_noise_fn(action_fn, global_step=None)
+ meta_action_fn = meta_agent.action
+ meta_action_fn = meta_agent.add_noise_fn(meta_action_fn, global_step=None)
+ meta_actions_fn = meta_agent.actions
+ meta_actions_fn = meta_agent.add_noise_fn(meta_actions_fn, global_step=None)
+ init_collect_experience_op = collect_experience(
+ tf_env,
+ uvf_agent,
+ meta_agent,
+ state_preprocess,
+ replay_buffer,
+ meta_replay_buffer,
+ action_fn,
+ meta_action_fn,
+ environment_steps=global_step,
+ num_episodes=num_episodes,
+ num_resets=num_resets,
+ episode_rewards=episode_rewards,
+ episode_meta_rewards=episode_meta_rewards,
+ store_context=True,
+ disable_agent_reset=False,
+ )
+
+ # Create train ops
+ collect_experience_op = collect_experience(
+ tf_env,
+ uvf_agent,
+ meta_agent,
+ state_preprocess,
+ replay_buffer,
+ meta_replay_buffer,
+ action_fn,
+ meta_action_fn,
+ environment_steps=global_step,
+ num_episodes=num_episodes,
+ num_resets=num_resets,
+ episode_rewards=episode_rewards,
+ episode_meta_rewards=episode_meta_rewards,
+ store_context=True,
+ disable_agent_reset=False,
+ )
+
+ train_op_list = []
+ repr_train_op = tf.constant(0.0)
+ for mode in ['meta', 'nometa']:
+ if mode == 'meta':
+ agent = meta_agent
+ buff = meta_replay_buffer
+ critic_opt = meta_critic_optimizer
+ actor_opt = meta_actor_optimizer
+ relabel = meta_relabel_contexts
+ num_steps = meta_replay_num_steps
+      my_gamma = meta_gamma
+ n_updates = num_meta_updates
+ else:
+ agent = uvf_agent
+ buff = replay_buffer
+ critic_opt = critic_optimizer
+ actor_opt = actor_optimizer
+ relabel = relabel_contexts
+ num_steps = replay_num_steps
+ my_gamma = gamma
+ n_updates = num_updates
+
+ with tf.name_scope(mode):
+ batch = buff.get_random_batch(batch_size, num_steps=num_steps)
+ states, actions, rewards, discounts, next_states = batch[:5]
+ with tf.name_scope('Reward'):
+ tf.summary.scalar('average_step_reward', tf.reduce_mean(rewards))
+ rewards *= reward_scale_factor
+ batch_queue = slim.prefetch_queue.prefetch_queue(
+ [states, actions, rewards, discounts, next_states] + batch[5:],
+ capacity=prefetch_queue_capacity,
+ name='batch_queue')
+
+ batch_dequeue = batch_queue.dequeue()
+ if repeat_size > 0:
+ batch_dequeue = [
+ tf.tile(batch, (repeat_size+1,) + (1,) * (batch.shape.ndims - 1))
+ for batch in batch_dequeue
+ ]
+ batch_size *= (repeat_size + 1)
+ states, actions, rewards, discounts, next_states = batch_dequeue[:5]
+ if mode == 'meta':
+ low_states = batch_dequeue[5]
+ low_actions = batch_dequeue[6]
+ low_state_reprs = state_preprocess(low_states)
+ state_reprs = state_preprocess(states)
+ next_state_reprs = state_preprocess(next_states)
+
+ if mode == 'meta': # Re-label meta-action
+ prev_actions = actions
+ if FLAGS.goal_sample_strategy == 'None':
+ pass
+ elif FLAGS.goal_sample_strategy == 'FuN':
+ actions = inverse_dynamics.sample(state_reprs, next_state_reprs, 1, prev_actions, sc=0.1)
+ actions = tf.stop_gradient(actions)
+ elif FLAGS.goal_sample_strategy == 'sample':
+ actions = sample_best_meta_actions(state_reprs, next_state_reprs, prev_actions,
+ low_states, low_actions, low_state_reprs,
+ inverse_dynamics, uvf_agent, k=10)
+ else:
+ assert False
+
+ if state_preprocess.trainable and mode == 'meta':
+ # Representation learning is based on meta-transitions, but is trained
+ # along with low-level policy updates.
+ repr_loss, _, _ = state_preprocess.loss(states, next_states, low_actions, low_states)
+ repr_train_op = slim.learning.create_train_op(
+ repr_loss,
+ repr_optimizer,
+ global_step=None,
+ update_ops=None,
+ summarize_gradients=summarize_gradients,
+ clip_gradient_norm=clip_gradient_norm,
+ variables_to_train=state_preprocess.get_trainable_vars(),)
+
+ # Get contexts for training
+ contexts, next_contexts = agent.sample_contexts(
+ mode='train', batch_size=batch_size,
+ state=states, next_state=next_states,
+ )
+      if not relabel:  # Use the contexts stored in the replay, not relabeled ones.
+ contexts, next_contexts = (
+ batch_dequeue[-2*len(contexts):-1*len(contexts)],
+ batch_dequeue[-1*len(contexts):])
+
+ merged_states = agent.merged_states(states, contexts)
+ merged_next_states = agent.merged_states(next_states, next_contexts)
+ if mode == 'nometa':
+ context_rewards, context_discounts = agent.compute_rewards(
+ 'train', state_reprs, actions, rewards, next_state_reprs, contexts)
+ elif mode == 'meta': # Meta-agent uses sum of rewards, not context-specific rewards.
+ _, context_discounts = agent.compute_rewards(
+ 'train', states, actions, rewards, next_states, contexts)
+ context_rewards = rewards
+
+ if agent.gamma_index is not None:
+ context_discounts *= tf.cast(
+ tf.reshape(contexts[agent.gamma_index], (-1,)),
+ dtype=context_discounts.dtype)
+ else: context_discounts *= my_gamma
+
+ critic_loss = agent.critic_loss(merged_states, actions,
+ context_rewards, context_discounts,
+ merged_next_states)
+
+ critic_loss = tf.reduce_mean(critic_loss)
+
+ actor_loss = agent.actor_loss(merged_states, actions,
+ context_rewards, context_discounts,
+ merged_next_states)
+ actor_loss *= tf.to_float( # Only update actor every N steps.
+ tf.equal(n_updates % target_update_period, 0))
+
+ critic_train_op = slim.learning.create_train_op(
+ critic_loss,
+ critic_opt,
+ global_step=n_updates,
+ update_ops=None,
+ summarize_gradients=summarize_gradients,
+ clip_gradient_norm=clip_gradient_norm,
+ variables_to_train=agent.get_trainable_critic_vars(),)
+ critic_train_op = uvf_utils.tf_print(
+ critic_train_op, [critic_train_op],
+ message='critic_loss',
+ print_freq=1000,
+ name='critic_loss')
+ train_op_list.append(critic_train_op)
+ if actor_loss is not None:
+ actor_train_op = slim.learning.create_train_op(
+ actor_loss,
+ actor_opt,
+ global_step=None,
+ update_ops=None,
+ summarize_gradients=summarize_gradients,
+ clip_gradient_norm=clip_gradient_norm,
+ variables_to_train=agent.get_trainable_actor_vars(),)
+ actor_train_op = uvf_utils.tf_print(
+ actor_train_op, [actor_train_op],
+ message='actor_loss',
+ print_freq=1000,
+ name='actor_loss')
+ train_op_list.append(actor_train_op)
+
+ assert len(train_op_list) == 4
+ # Update targets should happen after the networks have been updated.
+ with tf.control_dependencies(train_op_list[2:]):
+ update_targets_op = uvf_utils.periodically(
+ uvf_agent.update_targets, target_update_period, 'update_targets')
+ if meta_agent is not None:
+ with tf.control_dependencies(train_op_list[:2]):
+ update_meta_targets_op = uvf_utils.periodically(
+ meta_agent.update_targets, target_update_period, 'update_targets')
+
+ assert_op = tf.Assert( # Hack to get training to stop.
+ tf.less_equal(global_step, 200 + num_episodes_train * max_steps_per_episode),
+ [global_step])
+ with tf.control_dependencies([update_targets_op, assert_op]):
+ train_op = tf.add_n(train_op_list[2:], name='post_update_targets')
+ # Representation training steps on every low-level policy training step.
+ train_op += repr_train_op
+ with tf.control_dependencies([update_meta_targets_op, assert_op]):
+ meta_train_op = tf.add_n(train_op_list[:2],
+ name='post_update_meta_targets')
+
+ if debug_summaries:
+      train_utils.gen_debug_batch_summaries(batch)
+ slim.summaries.add_histogram_summaries(
+ uvf_agent.get_trainable_critic_vars(), 'critic_vars')
+ slim.summaries.add_histogram_summaries(
+ uvf_agent.get_trainable_actor_vars(), 'actor_vars')
+
+ train_ops = train_utils.TrainOps(train_op, meta_train_op,
+ collect_experience_op)
+
+ policy_save_path = os.path.join(train_dir, policy_save_dir, 'model.ckpt')
+ policy_vars = uvf_agent.get_actor_vars() + meta_agent.get_actor_vars() + [
+ global_step, num_episodes, num_resets
+ ] + list(uvf_agent.context_vars) + list(meta_agent.context_vars) + state_preprocess.get_trainable_vars()
+ # add critic vars, since some test evaluation depends on them
+ policy_vars += uvf_agent.get_trainable_critic_vars() + meta_agent.get_trainable_critic_vars()
+ policy_saver = tf.train.Saver(
+ policy_vars, max_to_keep=max_policies_to_save, sharded=False)
+
+ lowlevel_vars = (uvf_agent.get_actor_vars() +
+ uvf_agent.get_trainable_critic_vars() +
+ state_preprocess.get_trainable_vars())
+ lowlevel_saver = tf.train.Saver(lowlevel_vars)
+
+ def policy_save_fn(sess):
+ policy_saver.save(
+ sess, policy_save_path, global_step=global_step, write_meta_graph=False)
+ if save_policy_interval_secs > 0:
+ tf.logging.info(
+ 'Wait %d secs after save policy.' % save_policy_interval_secs)
+ time.sleep(save_policy_interval_secs)
+
+ train_step_fn = train_utils.TrainStep(
+ max_number_of_steps=num_episodes_train * max_steps_per_episode + 100,
+ num_updates_per_observation=num_updates_per_observation,
+ num_collect_per_update=num_collect_per_update,
+ num_collect_per_meta_update=num_collect_per_meta_update,
+ log_every_n_steps=log_every_n_steps,
+ policy_save_fn=policy_save_fn,
+ save_policy_every_n_steps=save_policy_every_n_steps,
+ should_stop_early=should_stop_early).train_step
+
+ local_init_op = tf.local_variables_initializer()
+ init_targets_op = tf.group(uvf_agent.update_targets(1.0),
+ meta_agent.update_targets(1.0))
+
+ def initialize_training_fn(sess):
+ """Initialize training function."""
+ sess.run(local_init_op)
+ sess.run(init_targets_op)
+ if load_path:
+ tf.logging.info('Restoring low-level from %s' % load_path)
+ lowlevel_saver.restore(sess, load_path)
+ global_step_value = sess.run(global_step)
+ assert global_step_value == 0, 'Global step should be zero.'
+ collect_experience_call = sess.make_callable(
+ init_collect_experience_op)
+
+ for _ in range(initial_steps):
+ collect_experience_call()
+
+ train_saver = tf.train.Saver(max_to_keep=2, sharded=True)
+ tf.logging.info('train dir: %s', train_dir)
+ return slim.learning.train(
+ train_ops,
+ train_dir,
+ train_step_fn=train_step_fn,
+ save_interval_secs=FLAGS.save_interval_secs,
+ saver=train_saver,
+ log_every_n_steps=0,
+ global_step=global_step,
+ master="",
+ is_chief=(FLAGS.task == 0),
+ save_summaries_secs=FLAGS.save_summaries_secs,
+ init_fn=initialize_training_fn)
diff --git a/models/research/efficient-hrl/train_utils.py b/models/research/efficient-hrl/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae23ef9f095ed1755c74579223b017239ccf1009
--- /dev/null
+++ b/models/research/efficient-hrl/train_utils.py
@@ -0,0 +1,175 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+r""""""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import namedtuple
+import os
+import time
+
+import tensorflow as tf
+
+import gin.tf
+
+flags = tf.app.flags
+
+
+flags.DEFINE_multi_string('config_file', None,
+ 'List of paths to the config files.')
+flags.DEFINE_multi_string('params', None,
+ 'Newline separated list of Gin parameter bindings.')
+
+flags.DEFINE_string('train_dir', None,
+ 'Directory for writing logs/summaries during training.')
+flags.DEFINE_string('master', 'local',
+ 'BNS name of the TensorFlow master to use.')
+flags.DEFINE_integer('task', 0, 'task id')
+flags.DEFINE_integer('save_interval_secs', 300, 'The frequency at which '
+ 'checkpoints are saved, in seconds.')
+flags.DEFINE_integer('save_summaries_secs', 30, 'The frequency at which '
+ 'summaries are saved, in seconds.')
+flags.DEFINE_boolean('summarize_gradients', False,
+ 'Whether to generate gradient summaries.')
+
+FLAGS = flags.FLAGS
+
+TrainOps = namedtuple('TrainOps',
+ ['train_op', 'meta_train_op', 'collect_experience_op'])
+
+
+class TrainStep(object):
+ """Handles training step."""
+
+ def __init__(self,
+ max_number_of_steps=0,
+ num_updates_per_observation=1,
+ num_collect_per_update=1,
+ num_collect_per_meta_update=1,
+ log_every_n_steps=1,
+ policy_save_fn=None,
+ save_policy_every_n_steps=0,
+ should_stop_early=None):
+ """Returns a function that is executed at each step of slim training.
+
+ Args:
+ max_number_of_steps: Optional maximum number of train steps to take.
+      num_updates_per_observation: Number of updates per observation.
+      num_collect_per_update: Number of collect steps per (low-level) update.
+      num_collect_per_meta_update: Number of collect steps per meta update.
+      log_every_n_steps: The frequency, in terms of global steps, at which the
+        loss and global step are logged.
+ policy_save_fn: A tf.Saver().save function to save the policy.
+ save_policy_every_n_steps: How frequently to save the policy.
+ should_stop_early: Optional hook to report whether training should stop.
+ Raises:
+ ValueError: If policy_save_fn is not provided when
+ save_policy_every_n_steps > 0.
+ """
+ if save_policy_every_n_steps and policy_save_fn is None:
+ raise ValueError(
+ 'policy_save_fn is required when save_policy_every_n_steps > 0')
+ self.max_number_of_steps = max_number_of_steps
+ self.num_updates_per_observation = num_updates_per_observation
+ self.num_collect_per_update = num_collect_per_update
+ self.num_collect_per_meta_update = num_collect_per_meta_update
+ self.log_every_n_steps = log_every_n_steps
+ self.policy_save_fn = policy_save_fn
+ self.save_policy_every_n_steps = save_policy_every_n_steps
+ self.should_stop_early = should_stop_early
+ self.last_global_step_val = 0
+ self.train_op_fn = None
+ self.collect_and_train_fn = None
+ tf.logging.info('Training for %d max_number_of_steps',
+ self.max_number_of_steps)
+
+ def train_step(self, sess, train_ops, global_step, _):
+ """This function will be called at each step of training.
+
+ This represents one step of the DDPG algorithm and can include:
+ 1. collect a transition
+ 2. update the target network
+ 3. train the actor
+ 4. train the critic
+
+ Args:
+ sess: A Tensorflow session.
+      train_ops: A TrainOps tuple of train ops to run.
+ global_step: The global step.
+
+ Returns:
+ A scalar total loss.
+      A boolean indicating whether training should stop.
+ """
+ start_time = time.time()
+ if self.train_op_fn is None:
+ self.train_op_fn = sess.make_callable([train_ops.train_op, global_step])
+ self.meta_train_op_fn = sess.make_callable([train_ops.meta_train_op, global_step])
+ self.collect_fn = sess.make_callable([train_ops.collect_experience_op, global_step])
+ self.collect_and_train_fn = sess.make_callable(
+ [train_ops.train_op, global_step, train_ops.collect_experience_op])
+ self.collect_and_meta_train_fn = sess.make_callable(
+ [train_ops.meta_train_op, global_step, train_ops.collect_experience_op])
+ for _ in range(self.num_collect_per_update - 1):
+ self.collect_fn()
+ for _ in range(self.num_updates_per_observation - 1):
+ self.train_op_fn()
+
+ total_loss, global_step_val, _ = self.collect_and_train_fn()
+ if (global_step_val // self.num_collect_per_meta_update !=
+ self.last_global_step_val // self.num_collect_per_meta_update):
+ self.meta_train_op_fn()
+
+ time_elapsed = time.time() - start_time
+ should_stop = False
+ if self.max_number_of_steps:
+ should_stop = global_step_val >= self.max_number_of_steps
+ if global_step_val != self.last_global_step_val:
+ if (self.save_policy_every_n_steps and
+ global_step_val // self.save_policy_every_n_steps !=
+ self.last_global_step_val // self.save_policy_every_n_steps):
+ self.policy_save_fn(sess)
+
+ if (self.log_every_n_steps and
+ global_step_val % self.log_every_n_steps == 0):
+ tf.logging.info(
+ 'global step %d: loss = %.4f (%.3f sec/step) (%d steps/sec)',
+ global_step_val, total_loss, time_elapsed, 1 / time_elapsed)
+
+ self.last_global_step_val = global_step_val
+ stop_early = bool(self.should_stop_early and self.should_stop_early())
+ return total_loss, should_stop or stop_early
+
+
+def create_counter_summaries(counters):
+ """Add named summaries to counters, a list of tuples (name, counter)."""
+ if counters:
+ with tf.name_scope('Counters/'):
+ for name, counter in counters:
+ tf.summary.scalar(name, counter)
+
+
+def gen_debug_batch_summaries(batch):
+ """Generates summaries for the sampled replay batch."""
+ states, actions, rewards, _, next_states = batch
+ with tf.name_scope('batch'):
+ for s in range(states.get_shape()[-1]):
+ tf.summary.histogram('states_%d' % s, states[:, s])
+ for s in range(states.get_shape()[-1]):
+ tf.summary.histogram('next_states_%d' % s, next_states[:, s])
+ for a in range(actions.get_shape()[-1]):
+ tf.summary.histogram('actions_%d' % a, actions[:, a])
+ tf.summary.histogram('rewards', rewards)
diff --git a/models/research/efficient-hrl/utils/__init__.py b/models/research/efficient-hrl/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/models/research/efficient-hrl/utils/__init__.py
@@ -0,0 +1 @@
+
diff --git a/models/research/efficient-hrl/utils/eval_utils.py b/models/research/efficient-hrl/utils/eval_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c88efc80fe1cc3399027cf71e310db85e3653df9
--- /dev/null
+++ b/models/research/efficient-hrl/utils/eval_utils.py
@@ -0,0 +1,151 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Evaluation utility functions.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import time
+import numpy as np
+import tensorflow as tf
+from collections import namedtuple
+logging = tf.logging
+import gin.tf
+
+
+@gin.configurable
+def evaluate_checkpoint_repeatedly(checkpoint_dir,
+ evaluate_checkpoint_fn,
+ eval_interval_secs=600,
+ max_number_of_evaluations=None,
+ checkpoint_timeout=None,
+ timeout_fn=None):
+ """Evaluates a checkpointed model at a set interval."""
+ if max_number_of_evaluations is not None and max_number_of_evaluations <= 0:
+ raise ValueError(
+ '`max_number_of_evaluations` must be either None or a positive number.')
+
+ number_of_evaluations = 0
+ for checkpoint_path in tf.contrib.training.checkpoints_iterator(
+ checkpoint_dir,
+ min_interval_secs=eval_interval_secs,
+ timeout=checkpoint_timeout,
+ timeout_fn=timeout_fn):
+ retries = 3
+ for _ in range(retries):
+ try:
+ should_stop = evaluate_checkpoint_fn(checkpoint_path)
+ break
+ except tf.errors.DataLossError as e:
+ logging.warn(
+ 'Encountered a DataLossError while evaluating a checkpoint. This '
+ 'can happen when reading a checkpoint before it is fully written. '
+ 'Retrying...'
+ )
+ time.sleep(2.0)
+
+
+def compute_model_loss(sess, model_rollout_fn, states, actions):
+ """Computes model loss."""
+ preds, losses = [], []
+ preds.append(states[0])
+ losses.append(0)
+ for state, action in zip(states[1:], actions[1:]):
+ pred = model_rollout_fn(sess, preds[-1], action)
+ loss = np.sqrt(np.sum((state - pred) ** 2))
+ preds.append(pred)
+ losses.append(loss)
+ return preds, losses
+
+
+def compute_average_reward(sess, env_base, step_fn, gamma, num_steps,
+ num_episodes):
+ """Computes the discounted reward for a given number of steps.
+
+ Args:
+ sess: The tensorflow session.
+ env_base: A python environment.
+ step_fn: A function that takes in `sess` and returns a list of
+ [state, action, reward, discount, transition_type] values.
+ gamma: discounting factor to apply to the reward.
+ num_steps: number of steps to compute the reward over.
+ num_episodes: number of episodes to average the reward over.
+  Returns:
+    average_reward: average cumulative discounted reward.
+    average_last_reward: average reward received at the final step.
+    average_meta_reward: average cumulative discounted meta-reward.
+    average_last_meta_reward: average meta-reward received at the final step.
+    average_success: fraction of episodes that ended in success.
+    states: states visited during the last episode.
+    actions: actions taken during the last episode.
+  """
+ average_reward = 0
+ average_last_reward = 0
+ average_meta_reward = 0
+ average_last_meta_reward = 0
+ average_success = 0.
+ states, actions = None, None
+ for i in range(num_episodes):
+ env_base.end_episode()
+ env_base.begin_episode()
+ (reward, last_reward, meta_reward, last_meta_reward,
+ states, actions) = compute_reward(
+ sess, step_fn, gamma, num_steps)
+ s_reward = last_meta_reward # Navigation
+ success = (s_reward > -5.0) # When using diff=False
+ logging.info('Episode = %d, reward = %s, meta_reward = %f, '
+ 'last_reward = %s, last meta_reward = %f, success = %s',
+ i, reward, meta_reward, last_reward, last_meta_reward,
+ success)
+ average_reward += reward
+ average_last_reward += last_reward
+ average_meta_reward += meta_reward
+ average_last_meta_reward += last_meta_reward
+ average_success += success
+ average_reward /= num_episodes
+ average_last_reward /= num_episodes
+ average_meta_reward /= num_episodes
+ average_last_meta_reward /= num_episodes
+ average_success /= num_episodes
+ return (average_reward, average_last_reward,
+ average_meta_reward, average_last_meta_reward,
+ average_success,
+ states, actions)
+
+
+def compute_reward(sess, step_fn, gamma, num_steps):
+ """Computes the discounted reward for a given number of steps.
+
+ Args:
+ sess: The tensorflow session.
+ step_fn: A function that takes in `sess` and returns a list of
+ [state, action, reward, discount, transition_type] values.
+ gamma: discounting factor to apply to the reward.
+ num_steps: number of steps to compute the reward over.
+  Returns:
+    total_reward: cumulative discounted reward.
+    last_reward: reward received at the final step.
+    total_meta_reward: cumulative discounted meta-reward.
+    last_meta_reward: meta-reward received at the final step.
+    states: list of states visited.
+    actions: list of actions taken.
+  """
+
+ total_reward = 0
+ total_meta_reward = 0
+ gamma_step = 1
+ states = []
+ actions = []
+ for _ in range(num_steps):
+ state, action, transition_type, reward, meta_reward, discount, _, _ = step_fn(sess)
+ total_reward += reward * gamma_step * discount
+ total_meta_reward += meta_reward * gamma_step * discount
+ gamma_step *= gamma
+ states.append(state)
+ actions.append(action)
+ return (total_reward, reward, total_meta_reward, meta_reward,
+ states, actions)
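+# In the loop above the discounted return is accumulated as
+#   total_reward = sum_t gamma**t * discount_t * reward_t
+# with gamma_step tracking gamma**t; the meta-reward is discounted identically.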
diff --git a/models/research/efficient-hrl/utils/utils.py b/models/research/efficient-hrl/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e188316c33b006a34baaaf729a02cca2e13d92e8
--- /dev/null
+++ b/models/research/efficient-hrl/utils/utils.py
@@ -0,0 +1,318 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""TensorFlow utility functions.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from copy import deepcopy
+import tensorflow as tf
+from tf_agents import specs
+from tf_agents.utils import common
+
+_tf_print_counts = dict()
+_tf_print_running_sums = dict()
+_tf_print_running_counts = dict()
+_tf_print_ids = 0
+
+
+def get_contextual_env_base(env_base, begin_ops=None, end_ops=None):
+ """Wrap env_base with additional tf ops."""
+ # pylint: disable=protected-access
+ def init(self_, env_base):
+ self_._env_base = env_base
+ attribute_list = ["_render_mode", "_gym_env"]
+ for attribute in attribute_list:
+ if hasattr(env_base, attribute):
+ setattr(self_, attribute, getattr(env_base, attribute))
+ if hasattr(env_base, "physics"):
+ self_._physics = env_base.physics
+ elif hasattr(env_base, "gym"):
+ class Physics(object):
+ def render(self, *args, **kwargs):
+ return env_base.gym.render("rgb_array")
+ physics = Physics()
+ self_._physics = physics
+ self_.physics = physics
+ def set_sess(self_, sess):
+ self_._sess = sess
+ if hasattr(self_._env_base, "set_sess"):
+ self_._env_base.set_sess(sess)
+ def begin_episode(self_):
+ self_._env_base.reset()
+ if begin_ops is not None:
+ self_._sess.run(begin_ops)
+ def end_episode(self_):
+ self_._env_base.reset()
+ if end_ops is not None:
+ self_._sess.run(end_ops)
+ return type("ContextualEnvBase", (env_base.__class__,), dict(
+ __init__=init,
+ set_sess=set_sess,
+ begin_episode=begin_episode,
+ end_episode=end_episode,
+ ))(env_base)
+ # pylint: enable=protected-access
+
+
+def merge_specs(specs_):
+ """Merge TensorSpecs.
+
+ Args:
+ specs_: List of TensorSpecs to be merged.
+ Returns:
+ a TensorSpec: a merged TensorSpec.
+ """
+ shape = specs_[0].shape
+ dtype = specs_[0].dtype
+ name = specs_[0].name
+ for spec in specs_[1:]:
+ assert shape[1:] == spec.shape[1:], "incompatible shapes: %s, %s" % (
+ shape, spec.shape)
+ assert dtype == spec.dtype, "incompatible dtypes: %s, %s" % (
+ dtype, spec.dtype)
+ shape = merge_shapes((shape, spec.shape), axis=0)
+ return specs.TensorSpec(
+ shape=shape,
+ dtype=dtype,
+ name=name,
+ )
+
+
+def merge_shapes(shapes, axis=0):
+ """Merge TensorShapes.
+
+ Args:
+ shapes: List of TensorShapes to be merged.
+    axis: optional, the axis along which to merge.
+ Returns:
+ a TensorShape: a merged TensorShape.
+ """
+ assert len(shapes) > 1
+ dims = deepcopy(shapes[0].dims)
+ for shape in shapes[1:]:
+ assert shapes[0].ndims == shape.ndims
+ dims[axis] += shape.dims[axis]
+ return tf.TensorShape(dims=dims)
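+# Example (a sketch): merging TensorShape([2, 3]) and TensorShape([4, 3]) along
+# axis 0 yields TensorShape([6, 3]); all inputs must have the same rank.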
+
+
+def get_all_vars(ignore_scopes=None):
+ """Get all tf variables in scope.
+
+ Args:
+ ignore_scopes: A list of scope names to ignore.
+ Returns:
+ A list of all tf variables in scope.
+ """
+ all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
+ all_vars = [var for var in all_vars if ignore_scopes is None or not
+ any(var.name.startswith(scope) for scope in ignore_scopes)]
+ return all_vars
+
+
+def clip(tensor, range_=None):
+ """Return a tf op which clips tensor according to range_.
+
+ Args:
+ tensor: A Tensor to be clipped.
+ range_: None, or a tuple representing (minval, maxval)
+ Returns:
+ A clipped Tensor.
+ """
+ if range_ is None:
+ return tf.identity(tensor)
+ elif isinstance(range_, (tuple, list)):
+ assert len(range_) == 2
+ return tf.clip_by_value(tensor, range_[0], range_[1])
+ else: raise NotImplementedError("Unacceptable range input: %r" % range_)
+
+
+def clip_to_bounds(value, minimum, maximum):
+ """Clips value to be between minimum and maximum.
+
+ Args:
+ value: (tensor) value to be clipped.
+ minimum: (numpy float array) minimum value to clip to.
+ maximum: (numpy float array) maximum value to clip to.
+ Returns:
+ clipped_value: (tensor) `value` clipped to between `minimum` and `maximum`.
+ """
+ value = tf.minimum(value, maximum)
+ return tf.maximum(value, minimum)
+
+
+clip_to_spec = common.clip_to_spec
+def _clip_to_spec(value, spec):
+ """Clips value to a given bounded tensor spec.
+
+ Args:
+ value: (tensor) value to be clipped.
+ spec: (BoundedTensorSpec) spec containing min. and max. values for clipping.
+ Returns:
+ clipped_value: (tensor) `value` clipped to be compatible with `spec`.
+ """
+ return clip_to_bounds(value, spec.minimum, spec.maximum)
+
+
+join_scope = common.join_scope
+def _join_scope(parent_scope, child_scope):
+ """Joins a parent and child scope using `/`, checking for empty/none.
+
+ Args:
+ parent_scope: (string) parent/prefix scope.
+ child_scope: (string) child/suffix scope.
+ Returns:
+ joined scope: (string) parent and child scopes joined by /.
+ """
+ if not parent_scope:
+ return child_scope
+ if not child_scope:
+ return parent_scope
+ return '/'.join([parent_scope, child_scope])
+
+
+def assign_vars(vars_, values):
+ """Returns the update ops for assigning a list of vars.
+
+ Args:
+ vars_: A list of variables.
+ values: A list of tensors representing new values.
+ Returns:
+ A list of update ops for the variables.
+ """
+ return [var.assign(value) for var, value in zip(vars_, values)]
+
+
+def identity_vars(vars_):
+ """Return the identity ops for a list of tensors.
+
+ Args:
+ vars_: A list of tensors.
+ Returns:
+ A list of identity ops.
+ """
+ return [tf.identity(var) for var in vars_]
+
+
+def tile(var, batch_size=1):
+ """Return tiled tensor.
+
+ Args:
+ var: A tensor representing the state.
+ batch_size: Batch size.
+ Returns:
+ A tensor with shape [batch_size,] + var.shape.
+ """
+ batch_var = tf.tile(
+ tf.expand_dims(var, 0),
+ (batch_size,) + (1,) * var.get_shape().ndims)
+ return batch_var
+
+
+def batch_list(vars_list):
+ """Batch a list of variables.
+
+ Args:
+ vars_list: A list of tensor variables.
+ Returns:
+ A list of tensor variables with additional first dimension.
+ """
+ return [tf.expand_dims(var, 0) for var in vars_list]
+
+
+def tf_print(op,
+ tensors,
+ message="",
+ first_n=-1,
+ name=None,
+ sub_messages=None,
+ print_freq=-1,
+ include_count=True):
+ """tf.Print, but to stdout."""
+ # TODO(shanegu): `name` is deprecated. Remove from the rest of codes.
+ global _tf_print_ids
+ _tf_print_ids += 1
+ name = _tf_print_ids
+ _tf_print_counts[name] = 0
+ if print_freq > 0:
+ _tf_print_running_sums[name] = [0 for _ in tensors]
+ _tf_print_running_counts[name] = 0
+ def print_message(*xs):
+ """print message fn."""
+ _tf_print_counts[name] += 1
+ if print_freq > 0:
+ for i, x in enumerate(xs):
+ _tf_print_running_sums[name][i] += x
+ _tf_print_running_counts[name] += 1
+ if (print_freq <= 0 or _tf_print_running_counts[name] >= print_freq) and (
+ first_n < 0 or _tf_print_counts[name] <= first_n):
+ for i, x in enumerate(xs):
+ if print_freq > 0:
+ del x
+ x = _tf_print_running_sums[name][i]/_tf_print_running_counts[name]
+ if sub_messages is None:
+ sub_message = str(i)
+ else:
+ sub_message = sub_messages[i]
+ log_message = "%s, %s" % (message, sub_message)
+ if include_count:
+ log_message += ", count=%d" % _tf_print_counts[name]
+ tf.logging.info("[%s]: %s" % (log_message, x))
+ if print_freq > 0:
+ for i, x in enumerate(xs):
+ _tf_print_running_sums[name][i] = 0
+ _tf_print_running_counts[name] = 0
+ return xs[0]
+
+ print_op = tf.py_func(print_message, tensors, tensors[0].dtype)
+ with tf.control_dependencies([print_op]):
+ op = tf.identity(op)
+ return op
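+# Example (a sketch), mirroring how train.py wraps its train ops: the returned
+# op behaves like `op`, but with print_freq > 0 it logs a running average of
+# the monitored tensors every print_freq calls. With a placeholder scalar op
+# `loss_op`:
+#
+#   loss_op = tf_print(loss_op, [loss_op], message='critic_loss',
+#                      print_freq=1000, name='critic_loss')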
+
+
+periodically = common.periodically
+def _periodically(body, period, name='periodically'):
+ """Periodically performs a tensorflow op."""
+ if period is None or period == 0:
+ return tf.no_op()
+
+ if period < 0:
+ raise ValueError("period cannot be less than 0.")
+
+ if period == 1:
+ return body()
+
+ with tf.variable_scope(None, default_name=name):
+ counter = tf.get_variable(
+ "counter",
+ shape=[],
+ dtype=tf.int64,
+ trainable=False,
+ initializer=tf.constant_initializer(period, dtype=tf.int64))
+
+ def _wrapped_body():
+ with tf.control_dependencies([body()]):
+ return counter.assign(1)
+
+ update = tf.cond(
+ tf.equal(counter, period), _wrapped_body,
+ lambda: counter.assign_add(1))
+
+ return update
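+# Example (a sketch), as used by train.py for target-network updates:
+#
+#   update_targets_op = periodically(
+#       uvf_agent.update_targets, target_update_period, 'update_targets')
+#
+# Note that the module exports tf_agents' common.periodically (bound above);
+# _periodically is kept here as a reference implementation.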
+
+soft_variables_update = common.soft_variables_update
diff --git a/models/research/feelvos/CONTRIBUTING.md b/models/research/feelvos/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..939e5341e74dc2371c8b47f0e27b50581bed5f63
--- /dev/null
+++ b/models/research/feelvos/CONTRIBUTING.md
@@ -0,0 +1,28 @@
+# How to Contribute
+
+We'd love to accept your patches and contributions to this project. There are
+just a few small guidelines you need to follow.
+
+## Contributor License Agreement
+
+Contributions to this project must be accompanied by a Contributor License
+Agreement. You (or your employer) retain the copyright to your contribution;
+this simply gives us permission to use and redistribute your contributions as
+part of the project. Head over to <https://cla.developers.google.com/> to see
+your current agreements on file or to sign a new one.
+
+You generally only need to submit a CLA once, so if you've already submitted one
+(even if it was for a different project), you probably don't need to do it
+again.
+
+## Code reviews
+
+All submissions, including submissions by project members, require review. We
+use GitHub pull requests for this purpose. Consult
+[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
+information on using pull requests.
+
+## Community Guidelines
+
+This project follows [Google's Open Source Community
+Guidelines](https://opensource.google.com/conduct/).
diff --git a/models/research/feelvos/LICENSE b/models/research/feelvos/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..d645695673349e3947e8e5ae42332d0ac3164cd7
--- /dev/null
+++ b/models/research/feelvos/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/models/research/feelvos/README.md b/models/research/feelvos/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..69017c8b19fc1427c47cbdfbdce408ffa92ec32c
--- /dev/null
+++ b/models/research/feelvos/README.md
@@ -0,0 +1,102 @@
+
+
+
+
+# FEELVOS: Fast End-to-End Embedding Learning for Video Object Segmentation
+
+FEELVOS is a fast model for video object segmentation which does not rely on fine-tuning on the
+first frame.
+
+For details, please refer to our paper. If you find the code useful, please
+also consider citing it.
+
+* FEELVOS:
+
+```
+@inproceedings{feelvos2019,
+ title={FEELVOS: Fast End-to-End Embedding Learning for Video Object Segmentation},
+ author={Paul Voigtlaender and Yuning Chai and Florian Schroff and Hartwig Adam and Bastian Leibe and Liang-Chieh Chen},
+ booktitle={CVPR},
+ year={2019}
+}
+```
+
+## Dependencies
+
+FEELVOS requires a good GPU with around 12 GB of memory and depends on the following libraries:
+
+* TensorFlow
+* Pillow
+* Numpy
+* Scipy
+* Scikit-Image (scikit-image)
+* tf Slim (which is included in the "tensorflow/models/research/" checkout)
+* DeepLab (which is included in the "tensorflow/models/research/" checkout)
+* correlation_cost (optional, see below)
+
+For detailed steps to install Tensorflow, follow the [Tensorflow installation
+instructions](https://www.tensorflow.org/install/). A typical user can install
+Tensorflow using the following command:
+
+```bash
+pip install tensorflow-gpu
+```
+
+The remaining libraries can also be installed with pip using:
+
+```bash
+pip install pillow scipy scikit-image
+```
+
+## Dependency on correlation_cost
+
+For fast cross-correlation, we use correlation cost as an external dependency. By default FEELVOS
+will use a slow and memory-hungry fallback implementation without correlation_cost. If you care
+about performance, you should set up correlation_cost by following the instructions in
+correlation_cost/README and afterwards setting ```USE_CORRELATION_COST = True``` in
+utils/embedding_utils.py.
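+
+To see why the fallback is expensive, note that a dense cross-correlation compares every feature
+vector of one frame with every feature vector in a search window of the other frame. The following
+is only a rough sketch of that idea in plain NumPy (it is not the FEELVOS implementation; the
+function name, shapes and window size are assumptions for illustration):
+
+```python
+import numpy as np
+
+
+def naive_cost_volume(f1, f2, max_disp=4):
+  """Illustrative dense correlation between two feature maps.
+
+  Args:
+    f1, f2: float arrays of shape [height, width, channels].
+    max_disp: maximum displacement of the search window.
+
+  Returns:
+    Cost volume of shape [height, width, (2 * max_disp + 1) ** 2], where
+    channel k holds the dot product of f1[y, x] with f2 shifted by the k-th
+    (dy, dx) offset.
+  """
+  h, w, _ = f1.shape
+  offsets = [(dy, dx)
+             for dy in range(-max_disp, max_disp + 1)
+             for dx in range(-max_disp, max_disp + 1)]
+  f2_pad = np.pad(f2, ((max_disp, max_disp), (max_disp, max_disp), (0, 0)),
+                  mode='constant')
+  cost = np.zeros((h, w, len(offsets)), dtype=f1.dtype)
+  for k, (dy, dx) in enumerate(offsets):
+    shifted = f2_pad[max_disp + dy:max_disp + dy + h,
+                     max_disp + dx:max_disp + dx + w]
+    cost[:, :, k] = np.sum(f1 * shifted, axis=-1)
+  return cost
+```
+
+Even with this small window the result has (2 * 4 + 1) ** 2 = 81 channels per pixel, which is why
+the dedicated correlation_cost op is preferable for both speed and memory.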
+
+## Pre-trained Models
+
+We provide 2 pre-trained FEELVOS models, both are based on Xception-65:
+
+* [Trained on DAVIS 2017](http://download.tensorflow.org/models/feelvos_davis17_trained.tar.gz)
+* [Trained on DAVIS 2017 and YouTube-VOS](http://download.tensorflow.org/models/feelvos_davis17_and_youtubevos_trained.tar.gz)
+
+Additionally, we provide a [DeepLab checkpoint for Xception-65 pre-trained on ImageNet and COCO](http://download.tensorflow.org/models/xception_65_coco_pretrained_2018_10_02.tar.gz),
+which can be used as an initialization for training FEELVOS.
+
+## Pre-computed Segmentation Masks
+
+We provide [pre-computed segmentation masks](http://download.tensorflow.org/models/feelvos_precomputed_masks.zip)
+for FEELVOS both for training with and without YouTube-VOS data for the following datasets:
+
+* DAVIS 2017 validation set
+* DAVIS 2017 test-dev set
+* YouTube-Objects dataset
+
+## Local Inference
+For a demo of local inference on DAVIS 2017 run
+
+```bash
+# From tensorflow/models/research/feelvos
+sh eval.sh
+```
+
+## Local Training
+For a demo of local training on DAVIS 2017 run
+
+```bash
+# From tensorflow/models/research/feelvos
+sh train.sh
+```
+
+## Contacts (Maintainers)
+* Paul Voigtlaender, github: [pvoigtlaender](https://github.com/pvoigtlaender)
+* Yuning Chai, github: [yuningchai](https://github.com/yuningchai)
+* Liang-Chieh Chen, github: [aquariusjay](https://github.com/aquariusjay)
+
+## License
+
+All the code in the feelvos folder is covered by the [LICENSE](https://github.com/tensorflow/models/blob/master/LICENSE)
+under tensorflow/models. Please refer to the LICENSE for details.
diff --git a/models/research/feelvos/__init__.py b/models/research/feelvos/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f1373443d0ff84fd90714e41dade400ab41a22c
--- /dev/null
+++ b/models/research/feelvos/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/models/research/feelvos/common.py b/models/research/feelvos/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..98f5a9ce348aea36efa4b3cc57048d3659f18895
--- /dev/null
+++ b/models/research/feelvos/common.py
@@ -0,0 +1,163 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Provides flags that are common to scripts.
+
+Common flags from train/vis_video.py are collected in this script.
+"""
+import tensorflow as tf
+
+from deeplab import common
+
+flags = tf.app.flags
+
+flags.DEFINE_enum(
+ 'classification_loss', 'softmax_with_attention',
+ ['softmax', 'triplet', 'softmax_with_attention'],
+ 'Type of loss function used for classifying pixels, can be either softmax, '
+ 'softmax_with_attention, or triplet.')
+
+flags.DEFINE_integer('k_nearest_neighbors', 1,
+ 'The number of nearest neighbors to use.')
+
+flags.DEFINE_integer('embedding_dimension', 100, 'The dimension used for the '
+ 'learned embedding')
+
+flags.DEFINE_boolean('use_softmax_feedback', True,
+ 'Whether to give the softmax predictions of the last '
+ 'frame as additional input to the segmentation head.')
+
+flags.DEFINE_boolean('sample_adjacent_and_consistent_query_frames', True,
+ 'If true, the query frames (all but the first frame '
+ 'which is the reference frame) will be sampled such '
+ 'that they are adjacent video frames and have the same '
+ 'crop coordinates and flip augmentation. Note that if '
+ 'use_softmax_feedback is True, this option will '
+ 'automatically be activated.')
+
+flags.DEFINE_integer('embedding_seg_feature_dimension', 256,
+ 'The dimensionality used in the segmentation head layers.')
+
+flags.DEFINE_integer('embedding_seg_n_layers', 4, 'The number of layers in the '
+ 'segmentation head.')
+
+flags.DEFINE_integer('embedding_seg_kernel_size', 7, 'The kernel size used in '
+ 'the segmentation head.')
+
+flags.DEFINE_multi_integer('embedding_seg_atrous_rates', [],
+ 'The atrous rates to use for the segmentation head.')
+
+flags.DEFINE_boolean('normalize_nearest_neighbor_distances', True,
+ 'Whether to normalize the nearest neighbor distances '
+ 'to [0,1] using sigmoid, scale and shift.')
+
+flags.DEFINE_boolean('also_attend_to_previous_frame', True, 'Whether to also '
+ 'use nearest neighbor attention with respect to the '
+ 'previous frame.')
+
+flags.DEFINE_bool('use_local_previous_frame_attention', True,
+ 'Whether to restrict the previous frame attention to a local '
+ 'search window. Only has an effect, if '
+ 'also_attend_to_previous_frame is True.')
+
+flags.DEFINE_integer('previous_frame_attention_window_size', 15,
+ 'The window size used for local previous frame attention,'
+ ' if use_local_previous_frame_attention is True.')
+
+flags.DEFINE_boolean('use_first_frame_matching', True, 'Whether to extract '
+ 'features by matching to the reference frame. This should '
+ 'always be true except for ablation experiments.')
+
+FLAGS = flags.FLAGS
+
+# Constants
+
+# Perform semantic segmentation predictions.
+OUTPUT_TYPE = common.OUTPUT_TYPE
+
+# Semantic segmentation item names.
+LABELS_CLASS = common.LABELS_CLASS
+IMAGE = common.IMAGE
+HEIGHT = common.HEIGHT
+WIDTH = common.WIDTH
+IMAGE_NAME = common.IMAGE_NAME
+SOURCE_ID = 'source_id'
+VIDEO_ID = 'video_id'
+LABEL = common.LABEL
+ORIGINAL_IMAGE = common.ORIGINAL_IMAGE
+PRECEDING_FRAME_LABEL = 'preceding_frame_label'
+
+# Test set name.
+TEST_SET = common.TEST_SET
+
+# Internal constants.
+OBJECT_LABEL = 'object_label'
+
+
+class VideoModelOptions(common.ModelOptions):
+ """Internal version of immutable class to hold model options."""
+
+ def __new__(cls,
+ outputs_to_num_classes,
+ crop_size=None,
+ atrous_rates=None,
+ output_stride=8):
+ """Constructor to set default values.
+
+ Args:
+ outputs_to_num_classes: A dictionary from output type to the number of
+ classes. For example, for the task of semantic segmentation with 21
+ semantic classes, we would have outputs_to_num_classes['semantic'] = 21.
+ crop_size: A tuple [crop_height, crop_width].
+ atrous_rates: A list of atrous convolution rates for ASPP.
+ output_stride: The ratio of input to output spatial resolution.
+
+ Returns:
+ A new VideoModelOptions instance.
+ """
+ self = super(VideoModelOptions, cls).__new__(
+ cls,
+ outputs_to_num_classes,
+ crop_size,
+ atrous_rates,
+ output_stride)
+ # Add internal flags.
+ self.classification_loss = FLAGS.classification_loss
+
+ return self
+
+
+def parse_decoder_output_stride():
+ """Parses decoder output stride.
+
+ FEELVOS assumes decoder_output_stride = 4. Thus, this function is created for
+ this particular purpose.
+
+ Returns:
+ An integer specifying the decoder_output_stride.
+
+ Raises:
+ ValueError: If decoder_output_stride is None or contains more than one
+ element.
+ """
+ if FLAGS.decoder_output_stride:
+ decoder_output_stride = [
+ int(x) for x in FLAGS.decoder_output_stride]
+ if len(decoder_output_stride) != 1:
+ raise ValueError('Expect decoder output stride has only one element.')
+ decoder_output_stride = decoder_output_stride[0]
+ else:
+ raise ValueError('Expect flag decoder output stride not to be None.')
+ return decoder_output_stride
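+
+
+# Hedged usage sketch (illustrative only, not part of the original module); the
+# concrete values below are assumptions:
+#
+#   model_options = VideoModelOptions(
+#       outputs_to_num_classes={'semantic': 2},
+#       crop_size=[465, 465],
+#       atrous_rates=[12, 24, 36],
+#       output_stride=8)
+#   # classification_loss is filled in from FLAGS.classification_loss.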
diff --git a/models/research/feelvos/correlation_cost/README.md b/models/research/feelvos/correlation_cost/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6cdbe550c7fcf63191f6967dd99c72cf341302bc
--- /dev/null
+++ b/models/research/feelvos/correlation_cost/README.md
@@ -0,0 +1,36 @@
+# correlation_cost
+
+FEELVOS uses correlation_cost as an optional dependency to improve the speed and memory consumption
+of cross-correlation.
+
+## Installation
+
+Unfortunately we cannot provide the code for correlation_cost directly, so you
+will have to copy some files from this pull request
+https://github.com/tensorflow/tensorflow/pull/21392/. For your convenience we
+prepared scripts to download and adjust the code automatically.
+
+In the best case, all you need to do is run build.sh with the path to your
+CUDA installation (tested only with CUDA 9).
+Note that the path should be to a folder containing the cuda folder, not to the
+cuda folder itself, e.g. if your cuda is in /usr/local/cuda-9.0, you can create
+a symlink /usr/local/cuda pointing to /usr/local/cuda-9.0 and then run
+
+```bash
+sh build.sh /usr/local/
+```
+
+This will
+
+* Download the code via ```sh get_code.sh ```
+* Apply minor adjustments to the code via ```sh fix_code.sh```
+* Clone the dependencies cub and thrust from github via ```sh clone_dependencies.sh```
+* Compile a shared library correlation_cost.so for correlation_cost via
+```sh compile.sh "${CUDA_DIR}"```
+
+Please review the licenses of correlation_cost, cub, and thrust.
+
+## Enabling correlation_cost
+If you managed to create the correlation_cost.so file, then set
+```USE_CORRELATION_COST = True``` in feelvos/utils/embedding_utils.py and try to run
+```sh eval.sh```.
diff --git a/models/research/feelvos/correlation_cost/build.sh b/models/research/feelvos/correlation_cost/build.sh
new file mode 100644
index 0000000000000000000000000000000000000000..37d9adb3147df07646a462fd170772393abf5642
--- /dev/null
+++ b/models/research/feelvos/correlation_cost/build.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# This script is used to download and build the code for correlation_cost.
+#
+# Usage:
+# sh ./build.sh cuda_dir
+# Where cuda_dir points to a directory containing the cuda folder (not the cuda folder itself).
+#
+#
+
+if [ "$#" -ne 1 ]; then
+ echo "Illegal number of parameters, usage: ./build.sh cuda_dir"
+ echo "Where cuda_dir points to a directory containing the cuda folder (not the cuda folder itself)"
+ exit 1
+fi
+
+set -e
+set -x
+
+sh ./get_code.sh
+sh ./fix_code.sh
+sh ./clone_dependencies.sh
+sh ./compile.sh $1
diff --git a/models/research/feelvos/correlation_cost/clone_dependencies.sh b/models/research/feelvos/correlation_cost/clone_dependencies.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9174313f58a833a5ab547e21c63cdc87681cbc5d
--- /dev/null
+++ b/models/research/feelvos/correlation_cost/clone_dependencies.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# This script is used to clone the dependencies, i.e. cub and thrust, of correlation_cost from github.
+#
+# Usage:
+# sh ./clone_dependencies.sh
+#
+#
+
+# Clone cub.
+if [ ! -d cub ] ; then
+ git clone https://github.com/dmlc/cub.git
+fi
+# Clone thrust.
+if [ ! -d thrust ] ; then
+ git clone https://github.com/thrust/thrust.git
+fi
diff --git a/models/research/feelvos/correlation_cost/compile.sh b/models/research/feelvos/correlation_cost/compile.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6025292dfa78b44dd6fcf2f1b349af936a43fcc7
--- /dev/null
+++ b/models/research/feelvos/correlation_cost/compile.sh
@@ -0,0 +1,46 @@
+#!/bin/bash
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# This script is used to compile the code for correlation_cost and create correlation_cost.so.
+#
+# Usage:
+# sh ./compile.sh cuda_dir
+# Where cuda_dir points to a directory containing the cuda folder (not the cuda folder itself).
+#
+#
+
+if [ "$#" -ne 1 ]; then
+ echo "Illegal number of parameters, usage: ./compile.sh cuda_dir"
+ exit 1
+fi
+CUDA_DIR=$1
+
+if [ ! -d "${CUDA_DIR}/cuda" ]; then
+ echo "cuda_dir must point to a directory containing the cuda folder, not to the cuda folder itself"
+ exit 1
+fi
+
+TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
+TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )
+CUB_DIR=cub
+THRUST_DIR=thrust
+
+# Depending on the versions of your nvcc and gcc, the flag --expt-relaxed-constexpr might be required or should be removed.
+# If nvcc complains about a too new gcc version, you can point it to another gcc
+# version by using something like nvcc -ccbin /path/to/your/gcc6
+nvcc -std=c++11 --expt-relaxed-constexpr -I ./ -I ${CUB_DIR}/../ -I ${THRUST_DIR} -I ${CUDA_DIR}/ -c -o correlation_cost_op_gpu.o kernels/correlation_cost_op_gpu.cu.cc ${TF_CFLAGS[@]} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC
+
+g++ -std=c++11 -I ./ -L ${CUDA_DIR}/cuda/lib64 -shared -o correlation_cost.so ops/correlation_cost_op.cc kernels/correlation_cost_op.cc correlation_cost_op_gpu.o ${TF_CFLAGS[@]} -fPIC -lcudart ${TF_LFLAGS[@]} -D GOOGLE_CUDA=1
diff --git a/models/research/feelvos/correlation_cost/fix_code.sh b/models/research/feelvos/correlation_cost/fix_code.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d4f285db3d745fc55a20bac57f97c6ca2fd8a5c4
--- /dev/null
+++ b/models/research/feelvos/correlation_cost/fix_code.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# This script is used to modify the downloaded code.
+#
+# Usage:
+# sh ./fix_code.sh
+#
+#
+
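+# Fix the C++ include paths: drop the tensorflow/contrib/correlation_cost/
+# prefix and point the cub include at the locally cloned cub checkout.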
+sed -i "s/tensorflow\/contrib\/correlation_cost\///g" kernels/correlation_cost_op_gpu.cu.cc
+sed -i "s/tensorflow\/contrib\/correlation_cost\///g" kernels/correlation_cost_op.cc
+sed -i "s/external\/cub_archive\//cub\//g" kernels/correlation_cost_op_gpu.cu.cc
+
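+# Rewrite the Python wrapper: replace the contrib loader and resource_loader
+# imports with a plain tensorflow import, alias array_ops/ops to tf, and load
+# the locally compiled correlation_cost.so via tf.load_op_library.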
+sed -i "s/from tensorflow.contrib.util import loader/import tensorflow as tf/g" python/ops/correlation_cost_op.py
+grep -v "from tensorflow" python/ops/correlation_cost_op.py | grep -v resource_loader.get_path_to_datafile > correlation_cost_op.py.tmp && mv correlation_cost_op.py.tmp python/ops/correlation_cost_op.py
+sed -i "s/array_ops/tf/g" python/ops/correlation_cost_op.py
+sed -i "s/ops/tf/g" python/ops/correlation_cost_op.py
+sed -i "s/loader.load_op_library(/tf.load_op_library('feelvos\/correlation_cost\/correlation_cost.so')/g" python/ops/correlation_cost_op.py
+sed -i "s/gen_correlation_cost_op/_correlation_cost_op_so/g" python/ops/correlation_cost_op.py
diff --git a/models/research/feelvos/correlation_cost/get_code.sh b/models/research/feelvos/correlation_cost/get_code.sh
new file mode 100644
index 0000000000000000000000000000000000000000..337142166ac4b61835417e807ef0a495532d749c
--- /dev/null
+++ b/models/research/feelvos/correlation_cost/get_code.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# This script is used to download the code for correlation_cost.
+#
+# Usage:
+# sh ./get_code.sh
+#
+#
+
+mkdir -p kernels ops python/ops
+touch __init__.py
+touch python/__init__.py
+touch python/ops/__init__.py
+wget https://raw.githubusercontent.com/tensorflow/tensorflow/91b163b9bd8dd0f8c2631b4245a67dfd387536a6/tensorflow/contrib/correlation_cost/ops/correlation_cost_op.cc -O ops/correlation_cost_op.cc
+wget https://raw.githubusercontent.com/tensorflow/tensorflow/91b163b9bd8dd0f8c2631b4245a67dfd387536a6/tensorflow/contrib/correlation_cost/python/ops/correlation_cost_op.py -O python/ops/correlation_cost_op.py
+wget https://raw.githubusercontent.com/tensorflow/tensorflow/91b163b9bd8dd0f8c2631b4245a67dfd387536a6/tensorflow/contrib/correlation_cost/kernels/correlation_cost_op.cc -O kernels/correlation_cost_op.cc
+wget https://raw.githubusercontent.com/tensorflow/tensorflow/91b163b9bd8dd0f8c2631b4245a67dfd387536a6/tensorflow/contrib/correlation_cost/kernels/correlation_cost_op.h -O kernels/correlation_cost_op.h
+wget https://raw.githubusercontent.com/tensorflow/tensorflow/91b163b9bd8dd0f8c2631b4245a67dfd387536a6/tensorflow/contrib/correlation_cost/kernels/correlation_cost_op_gpu.cu.cc -O kernels/correlation_cost_op_gpu.cu.cc
diff --git a/models/research/feelvos/datasets/__init__.py b/models/research/feelvos/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f1373443d0ff84fd90714e41dade400ab41a22c
--- /dev/null
+++ b/models/research/feelvos/datasets/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/models/research/feelvos/datasets/build_davis2017_data.py b/models/research/feelvos/datasets/build_davis2017_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e093fc3b4531f5439957ea3608770441bd5ce4a
--- /dev/null
+++ b/models/research/feelvos/datasets/build_davis2017_data.py
@@ -0,0 +1,163 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Converts DAVIS 2017 data to TFRecord file format with SequenceExample protos.
+"""
+
+import io
+import math
+import os
+from StringIO import StringIO
+import numpy as np
+import PIL
+import tensorflow as tf
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string('data_folder', 'DAVIS2017/',
+ 'Folder containing the DAVIS 2017 data')
+
+tf.app.flags.DEFINE_string('imageset', 'val',
+ 'Which subset to use, either train or val')
+
+tf.app.flags.DEFINE_string(
+ 'output_dir', './tfrecord',
+ 'Path to save converted TFRecords of TensorFlow examples.')
+
+_NUM_SHARDS_TRAIN = 10
+_NUM_SHARDS_VAL = 1
+
+
+def read_image(path):
+ with open(path) as fid:
+ image_str = fid.read()
+ image = PIL.Image.open(io.BytesIO(image_str))
+ w, h = image.size
+ return image_str, (h, w)
+
+
+def read_annotation(path):
+ """Reads a single image annotation from a png image.
+
+ Args:
+ path: Path to the png image.
+
+ Returns:
+ png_string: The png encoded as string.
+ size: Tuple of (height, width).
+ """
+ with open(path) as fid:
+ x = np.array(PIL.Image.open(fid))
+ h, w = x.shape
+ im = PIL.Image.fromarray(x)
+
+ output = StringIO()
+ im.save(output, format='png')
+ png_string = output.getvalue()
+ output.close()
+
+ return png_string, (h, w)
+
+
+def process_video(key, input_dir, anno_dir):
+ """Creates a SequenceExample for the video.
+
+ Args:
+ key: Name of the video.
+ input_dir: Directory which contains the image files.
+ anno_dir: Directory which contains the annotation files.
+
+ Returns:
+ The created SequenceExample.
+ """
+ frame_names = sorted(tf.gfile.ListDirectory(input_dir))
+ anno_files = sorted(tf.gfile.ListDirectory(anno_dir))
+ assert len(frame_names) == len(anno_files)
+
+ sequence = tf.train.SequenceExample()
+ context = sequence.context.feature
+ features = sequence.feature_lists.feature_list
+
+ for i, name in enumerate(frame_names):
+ image_str, image_shape = read_image(
+ os.path.join(input_dir, name))
+ anno_str, anno_shape = read_annotation(
+ os.path.join(anno_dir, name[:-4] + '.png'))
+ image_encoded = features['image/encoded'].feature.add()
+ image_encoded.bytes_list.value.append(image_str)
+ segmentation_encoded = features['segmentation/object/encoded'].feature.add()
+ segmentation_encoded.bytes_list.value.append(anno_str)
+
+ np.testing.assert_array_equal(np.array(image_shape), np.array(anno_shape))
+
+ if i == 0:
+ first_shape = np.array(image_shape)
+ else:
+ np.testing.assert_array_equal(np.array(image_shape), first_shape)
+
+ context['video_id'].bytes_list.value.append(key.encode('ascii'))
+ context['clip/frames'].int64_list.value.append(len(frame_names))
+ context['image/format'].bytes_list.value.append('JPEG')
+ context['image/channels'].int64_list.value.append(3)
+ context['image/height'].int64_list.value.append(first_shape[0])
+ context['image/width'].int64_list.value.append(first_shape[1])
+ context['segmentation/object/format'].bytes_list.value.append('PNG')
+ context['segmentation/object/height'].int64_list.value.append(first_shape[0])
+ context['segmentation/object/width'].int64_list.value.append(first_shape[1])
+
+ return sequence
+
+
+def convert(data_folder, imageset, output_dir, num_shards):
+ """Converts the specified subset of DAVIS 2017 to TFRecord format.
+
+ Args:
+ data_folder: The path to the DAVIS 2017 data.
+ imageset: The subset to use, either train or val.
+ output_dir: Where to store the TFRecords.
+ num_shards: The number of shards used for storing the data.
+ """
+ sets_file = os.path.join(data_folder, 'ImageSets', '2017', imageset + '.txt')
+ vids = [x.strip() for x in open(sets_file).readlines()]
+ num_vids = len(vids)
+ num_vids_per_shard = int(math.ceil(num_vids / float(num_shards)))
+ for shard_id in range(num_shards):
+ output_filename = os.path.join(
+ output_dir,
+ '%s-%05d-of-%05d.tfrecord' % (imageset, shard_id, num_shards))
+ with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
+ start_idx = shard_id * num_vids_per_shard
+ end_idx = min((shard_id + 1) * num_vids_per_shard, num_vids)
+ for i in range(start_idx, end_idx):
+ print('Converting video %d/%d shard %d video %s' % (
+ i + 1, num_vids, shard_id, vids[i]))
+ img_dir = os.path.join(data_folder, 'JPEGImages', '480p', vids[i])
+ anno_dir = os.path.join(data_folder, 'Annotations', '480p', vids[i])
+ example = process_video(vids[i], img_dir, anno_dir)
+ tfrecord_writer.write(example.SerializeToString())
+
+
+def main(unused_argv):
+ imageset = FLAGS.imageset
+ assert imageset in ('train', 'val')
+ if imageset == 'train':
+ num_shards = _NUM_SHARDS_TRAIN
+ else:
+ num_shards = _NUM_SHARDS_VAL
+ convert(FLAGS.data_folder, FLAGS.imageset, FLAGS.output_dir, num_shards)
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/feelvos/datasets/download_and_convert_davis17.sh b/models/research/feelvos/datasets/download_and_convert_davis17.sh
new file mode 100644
index 0000000000000000000000000000000000000000..011be61ba7586c8f3d141ccc00194d1c7ae56c3a
--- /dev/null
+++ b/models/research/feelvos/datasets/download_and_convert_davis17.sh
@@ -0,0 +1,77 @@
+#!/bin/bash
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Script to download and preprocess the DAVIS 2017 dataset.
+#
+# Usage:
+# bash ./download_and_convert_davis17.sh
+
+# Exit immediately if a command exits with a non-zero status.
+set -e
+
+CURRENT_DIR=$(pwd)
+WORK_DIR="./davis17"
+mkdir -p "${WORK_DIR}"
+cd "${WORK_DIR}"
+
+# Helper function to download and unpack the DAVIS 2017 dataset.
+download_and_uncompress() {
+ local BASE_URL=${1}
+ local FILENAME=${2}
+
+ if [ ! -f "${FILENAME}" ]; then
+ echo "Downloading ${FILENAME} to ${WORK_DIR}"
+ wget -nd -c "${BASE_URL}/${FILENAME}"
+ echo "Uncompressing ${FILENAME}"
+ unzip "${FILENAME}"
+ fi
+}
+
+BASE_URL="https://data.vision.ee.ethz.ch/csergi/share/davis/"
+FILENAME="DAVIS-2017-trainval-480p.zip"
+
+download_and_uncompress "${BASE_URL}" "${FILENAME}"
+
+cd "${CURRENT_DIR}"
+
+# Root path for DAVIS 2017 dataset.
+DAVIS_ROOT="${WORK_DIR}/DAVIS"
+
+# Build TFRecords of the dataset.
+# First, create output directory for storing TFRecords.
+OUTPUT_DIR="${WORK_DIR}/tfrecord"
+mkdir -p "${OUTPUT_DIR}"
+
+IMAGE_FOLDER="${DAVIS_ROOT}/JPEGImages"
+LIST_FOLDER="${DAVIS_ROOT}/ImageSets/Segmentation"
+
+# Convert validation set.
+if [ ! -f "${OUTPUT_DIR}/val-00000-of-00001.tfrecord" ]; then
+ echo "Converting DAVIS 2017 dataset (val)..."
+ python ./build_davis2017_data.py \
+ --data_folder="${DAVIS_ROOT}" \
+ --imageset=val \
+ --output_dir="${OUTPUT_DIR}"
+fi
+
+# Convert training set.
+if [ ! -f "${OUTPUT_DIR}/train-00009-of-00010.tfrecord" ]; then
+ echo "Converting DAVIS 2017 dataset (train)..."
+ python ./build_davis2017_data.py \
+ --data_folder="${DAVIS_ROOT}" \
+ --imageset=train \
+ --output_dir="${OUTPUT_DIR}"
+fi
diff --git a/models/research/feelvos/datasets/tfsequence_example_decoder.py b/models/research/feelvos/datasets/tfsequence_example_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fa3e95d5b98eb00aa485371037b4ad6b0e7ece3
--- /dev/null
+++ b/models/research/feelvos/datasets/tfsequence_example_decoder.py
@@ -0,0 +1,118 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Contains the TFExampleDecoder.
+
+The TFExampleDecode is a DataDecoder used to decode TensorFlow Example protos.
+In order to do so each requested item must be paired with one or more Example
+features that are parsed to produce the Tensor-based manifestation of the item.
+"""
+
+import tensorflow as tf
+slim = tf.contrib.slim
+data_decoder = slim.data_decoder
+
+
+class TFSequenceExampleDecoder(data_decoder.DataDecoder):
+ """A decoder for TensorFlow SequenceExamples.
+
+ Decoding SequenceExample proto buffers is comprised of two stages:
+ (1) Example parsing and (2) tensor manipulation.
+
+ In the first stage, the tf.parse_single_sequence_example function is called
+ with dictionaries of tf.VarLenFeature, tf.FixedLenFeature and
+ tf.FixedLenSequenceFeature instances. These instances tell TF how to parse
+ the example. The output of this stage is a set of tensors.
+
+ In the second stage, the resulting tensors are manipulated to provide the
+ requested 'item' tensors.
+
+ To perform this decoding operation, a SequenceExampleDecoder is given a list
+ of ItemHandlers. Each ItemHandler indicates the set of features for stage 1
+ and contains the instructions for post_processing its tensors for stage 2.
+ """
+
+ def __init__(self, keys_to_context_features, keys_to_sequence_features,
+ items_to_handlers):
+ """Constructs the decoder.
+
+ Args:
+ keys_to_context_features: a dictionary from TF-SequenceExample context
+ keys to either tf.VarLenFeature or tf.FixedLenFeature instances.
+ See tensorflow's parsing_ops.py.
+ keys_to_sequence_features: a dictionary from TF-SequenceExample sequence
+ keys to either tf.VarLenFeature or tf.FixedLenSequenceFeature instances.
+ See tensorflow's parsing_ops.py.
+ items_to_handlers: a dictionary from items (strings) to ItemHandler
+ instances. Note that the ItemHandler's are provided the keys that they
+ use to return the final item Tensors.
+
+ Raises:
+ ValueError: if the same key is present for context features and sequence
+ features.
+ """
+ unique_keys = set()
+ unique_keys.update(keys_to_context_features)
+ unique_keys.update(keys_to_sequence_features)
+ if len(unique_keys) != (
+ len(keys_to_context_features) + len(keys_to_sequence_features)):
+ # This situation is ambiguous in the decoder's keys_to_tensors variable.
+ raise ValueError('Context and sequence keys are not unique. \n'
+ ' Context keys: %s \n Sequence keys: %s' %
+ (list(keys_to_context_features.keys()),
+ list(keys_to_sequence_features.keys())))
+
+ self._keys_to_context_features = keys_to_context_features
+ self._keys_to_sequence_features = keys_to_sequence_features
+ self._items_to_handlers = items_to_handlers
+
+ def list_items(self):
+ """See base class."""
+ return self._items_to_handlers.keys()
+
+ def decode(self, serialized_example, items=None):
+ """Decodes the given serialized TF-SequenceExample.
+
+ Args:
+ serialized_example: a serialized TF-SequenceExample tensor.
+ items: the list of items to decode. These must be a subset of the item
+ keys in self._items_to_handlers. If `items` is left as None, then all
+ of the items in self._items_to_handlers are decoded.
+
+ Returns:
+ the decoded items, a list of tensors.
+ """
+
+ context, feature_list = tf.parse_single_sequence_example(
+ serialized_example, self._keys_to_context_features,
+ self._keys_to_sequence_features)
+
+ # Reshape non-sparse elements just once:
+ for k in self._keys_to_context_features:
+ v = self._keys_to_context_features[k]
+ if isinstance(v, tf.FixedLenFeature):
+ context[k] = tf.reshape(context[k], v.shape)
+
+ if not items:
+ items = self._items_to_handlers.keys()
+
+ outputs = []
+ for item in items:
+ handler = self._items_to_handlers[item]
+ keys_to_tensors = {
+ key: context[key] if key in context else feature_list[key]
+ for key in handler.keys
+ }
+ outputs.append(handler.tensors_to_item(keys_to_tensors))
+ return outputs
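+
+
+# Hedged usage sketch (illustrative only, not part of the original module).
+# The keys and handlers below are assumptions; see video_dataset.py for the
+# ones actually used by FEELVOS.
+#
+#   keys_to_context_features = {
+#       'video_id': tf.FixedLenFeature((), tf.string, default_value='unknown'),
+#   }
+#   keys_to_sequence_features = {
+#       'image/encoded': tf.FixedLenSequenceFeature((), tf.string),
+#   }
+#   items_to_handlers = {
+#       'video_id': slim.tfexample_decoder.Tensor('video_id'),
+#       'image/encoded': slim.tfexample_decoder.Tensor('image/encoded'),
+#   }
+#   decoder = TFSequenceExampleDecoder(
+#       keys_to_context_features, keys_to_sequence_features, items_to_handlers)
+#   video_id, frames = decoder.decode(
+#       serialized_example, items=['video_id', 'image/encoded'])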
diff --git a/models/research/feelvos/datasets/video_dataset.py b/models/research/feelvos/datasets/video_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..17b62e989af866df0232a0e6d921faee84fe1fa7
--- /dev/null
+++ b/models/research/feelvos/datasets/video_dataset.py
@@ -0,0 +1,196 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Provides data from video object segmentation datasets.
+
+This file provides both images and annotations (instance segmentations) for
+TensorFlow. Currently, we support the following datasets:
+
+1. DAVIS 2017 (https://davischallenge.org/davis2017/code.html).
+
+2. DAVIS 2016 (https://davischallenge.org/davis2016/code.html).
+
+3. YouTube-VOS (https://youtube-vos.org/dataset/download).
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import os.path
+import tensorflow as tf
+from feelvos.datasets import tfsequence_example_decoder
+
+slim = tf.contrib.slim
+dataset = slim.dataset
+tfexample_decoder = slim.tfexample_decoder
+
+
+_ITEMS_TO_DESCRIPTIONS = {
+ 'image': 'A color image of varying height and width.',
+ 'labels_class': ('A semantic segmentation label whose size matches image.'
+ 'Its values range from 0 (background) to num_classes.'),
+}
+
+# Named tuple to describe the dataset properties.
+DatasetDescriptor = collections.namedtuple(
+ 'DatasetDescriptor',
+ ['splits_to_sizes', # Splits of the dataset into training, val, and test.
+ 'num_classes', # Number of semantic classes.
+ 'ignore_label', # Ignore label value.
+ ]
+)
+
+_DAVIS_2016_INFORMATION = DatasetDescriptor(
+ splits_to_sizes={'train': [30, 1830],
+ 'val': [20, 1376]},
+ num_classes=2,
+ ignore_label=255,
+)
+
+_DAVIS_2017_INFORMATION = DatasetDescriptor(
+ splits_to_sizes={'train': [60, 4219],
+ 'val': [30, 2023],
+ 'test-dev': [30, 2037]},
+ num_classes=None, # Number of instances per video differs.
+ ignore_label=255,
+)
+
+_YOUTUBE_VOS_2018_INFORMATION = DatasetDescriptor(
+ # Leave these sizes as None to allow for different splits into
+ # training and validation sets.
+ splits_to_sizes={'train': [None, None],
+ 'val': [None, None]},
+ num_classes=None, # Number of instances per video differs.
+ ignore_label=255,
+)
+
+_DATASETS_INFORMATION = {
+ 'davis_2016': _DAVIS_2016_INFORMATION,
+ 'davis_2017': _DAVIS_2017_INFORMATION,
+ 'youtube_vos_2018': _YOUTUBE_VOS_2018_INFORMATION,
+}
+
+# Default file pattern of SSTable. Note we include '-' to avoid the confusion
+# between `train-` and `trainval-` sets.
+_FILE_PATTERN = '%s-*'
+
+
+def get_dataset(dataset_name,
+ split_name,
+ dataset_dir,
+ file_pattern=None,
+ data_type='tf_sequence_example',
+ decode_video_frames=False):
+ """Gets an instance of slim Dataset.
+
+ Args:
+ dataset_name: String, dataset name.
+ split_name: String, the train/val Split name.
+ dataset_dir: String, the directory of the dataset sources.
+ file_pattern: String, file pattern of SSTable.
+ data_type: String, data type. Currently only 'tf_sequence_example' is
+ supported.
+ decode_video_frames: Boolean, whether to decode the video frames here. Not
+ decoding them here is useful if we subsample frames later.
+
+ Returns:
+ An instance of slim Dataset.
+
+ Raises:
+ ValueError: If the dataset_name or split_name is not recognized, or if
+ the dataset_type is not supported.
+ the data_type is not supported.
+ if dataset_name not in _DATASETS_INFORMATION:
+ raise ValueError('The specified dataset is not supported yet.')
+
+ splits_to_sizes = _DATASETS_INFORMATION[dataset_name].splits_to_sizes
+
+ if split_name not in splits_to_sizes:
+ raise ValueError('data split name %s not recognized' % split_name)
+
+ # Prepare the variables for different datasets.
+ num_classes = _DATASETS_INFORMATION[dataset_name].num_classes
+ ignore_label = _DATASETS_INFORMATION[dataset_name].ignore_label
+
+ if file_pattern is None:
+ file_pattern = _FILE_PATTERN
+ file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
+ if data_type == 'tf_sequence_example':
+ keys_to_context_features = {
+ 'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
+ 'image/height': tf.FixedLenFeature((), tf.int64, default_value=0),
+ 'image/width': tf.FixedLenFeature((), tf.int64, default_value=0),
+ 'segmentation/object/format': tf.FixedLenFeature(
+ (), tf.string, default_value='png'),
+ 'video_id': tf.FixedLenFeature((), tf.string, default_value='unknown')
+ }
+ label_name = 'class' if dataset_name == 'davis_2016' else 'object'
+ keys_to_sequence_features = {
+ 'image/encoded': tf.FixedLenSequenceFeature((), dtype=tf.string),
+ 'segmentation/{}/encoded'.format(label_name):
+ tf.FixedLenSequenceFeature((), tf.string),
+ }
+ items_to_handlers = {
+ 'height': tfexample_decoder.Tensor('image/height'),
+ 'width': tfexample_decoder.Tensor('image/width'),
+ 'video_id': tfexample_decoder.Tensor('video_id')
+ }
+ if decode_video_frames:
+ decode_image_handler = tfexample_decoder.Image(
+ image_key='image/encoded',
+ format_key='image/format',
+ channels=3,
+ repeated=True)
+ items_to_handlers['image'] = decode_image_handler
+ decode_label_handler = tfexample_decoder.Image(
+ image_key='segmentation/{}/encoded'.format(label_name),
+ format_key='segmentation/{}/format'.format(label_name),
+ channels=1,
+ repeated=True)
+ items_to_handlers['labels_class'] = decode_label_handler
+ else:
+ items_to_handlers['image/encoded'] = tfexample_decoder.Tensor(
+ 'image/encoded')
+ items_to_handlers[
+ 'segmentation/object/encoded'] = tfexample_decoder.Tensor(
+ 'segmentation/{}/encoded'.format(label_name))
+ decoder = tfsequence_example_decoder.TFSequenceExampleDecoder(
+ keys_to_context_features, keys_to_sequence_features, items_to_handlers)
+ else:
+ raise ValueError('Unknown data type.')
+
+ size = splits_to_sizes[split_name]
+ if isinstance(size, collections.Sequence):
+ num_videos = size[0]
+ num_samples = size[1]
+ else:
+ num_videos = 0
+ num_samples = size
+
+ return dataset.Dataset(
+ data_sources=file_pattern,
+ reader=tf.TFRecordReader,
+ decoder=decoder,
+ num_samples=num_samples,
+ num_videos=num_videos,
+ items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
+ ignore_label=ignore_label,
+ num_classes=num_classes,
+ name=dataset_name,
+ multi_label=True)
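+
+
+# Hedged usage sketch (illustrative only, not part of the original module);
+# the path below is a placeholder:
+#
+#   davis_val = get_dataset(
+#       'davis_2017', 'val', '/path/to/davis17/tfrecord',
+#       decode_video_frames=True)
+#   # davis_val is a slim Dataset; it can be read, for example, with
+#   # slim.dataset_data_provider.DatasetDataProvider(davis_val).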
diff --git a/models/research/feelvos/eval.sh b/models/research/feelvos/eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..96cb7f409a1e652ba8263f35c3786cb0cb77f5d1
--- /dev/null
+++ b/models/research/feelvos/eval.sh
@@ -0,0 +1,86 @@
+#!/bin/bash
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# This script is used to locally run inference on DAVIS 2017. Users could also
+# modify from this script for their use case. See train.sh for an example of
+# local training.
+#
+# Usage:
+# # From the tensorflow/models/research/feelvos directory.
+# sh ./eval.sh
+#
+#
+
+# Exit immediately if a command exits with a non-zero status.
+set -e
+
+# Move one-level up to tensorflow/models/research directory.
+cd ..
+
+# Update PYTHONPATH.
+export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim:`pwd`/feelvos
+
+# Set up the working environment.
+CURRENT_DIR=$(pwd)
+WORK_DIR="${CURRENT_DIR}/feelvos"
+
+# Run embedding_utils_test first to make sure the PYTHONPATH is correctly set.
+python "${WORK_DIR}"/utils/embedding_utils_test.py -v
+
+# Go to datasets folder and download and convert the DAVIS 2017 dataset.
+DATASET_DIR="datasets"
+cd "${WORK_DIR}/${DATASET_DIR}"
+sh download_and_convert_davis17.sh
+
+# Go to models folder and download and unpack the DAVIS 2017 trained model.
+MODELS_DIR="models"
+mkdir -p "${WORK_DIR}/${MODELS_DIR}"
+cd "${WORK_DIR}/${MODELS_DIR}"
+if [ ! -d "feelvos_davis17_trained" ]; then
+ wget http://download.tensorflow.org/models/feelvos_davis17_trained.tar.gz
+ tar -xvf feelvos_davis17_trained.tar.gz
+ echo "model_checkpoint_path: \"model.ckpt-200004\"" > feelvos_davis17_trained/checkpoint
+ rm feelvos_davis17_trained.tar.gz
+fi
+CHECKPOINT_DIR="${WORK_DIR}/${MODELS_DIR}/feelvos_davis17_trained/"
+
+# Go back to original directory.
+cd "${CURRENT_DIR}"
+
+# Set up the working directories.
+DAVIS_FOLDER="davis17"
+EXP_FOLDER="exp/eval_on_val_set"
+VIS_LOGDIR="${WORK_DIR}/${DATASET_DIR}/${DAVIS_FOLDER}/${EXP_FOLDER}/eval"
+mkdir -p ${VIS_LOGDIR}
+
+DAVIS_DATASET="${WORK_DIR}/${DATASET_DIR}/${DAVIS_FOLDER}/tfrecord"
+
+python "${WORK_DIR}"/vis_video.py \
+ --dataset=davis_2017 \
+ --dataset_dir="${DAVIS_DATASET}" \
+ --vis_logdir="${VIS_LOGDIR}" \
+ --checkpoint_dir="${CHECKPOINT_DIR}" \
+ --logtostderr \
+ --atrous_rates=12 \
+ --atrous_rates=24 \
+ --atrous_rates=36 \
+ --decoder_output_stride=4 \
+ --model_variant=xception_65 \
+ --multi_grid=1 \
+ --multi_grid=1 \
+ --multi_grid=1 \
+ --output_stride=8 \
+ --save_segmentations
diff --git a/models/research/feelvos/input_preprocess.py b/models/research/feelvos/input_preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..954c0b42ef2650b1c25ec8071933beee57e9bd69
--- /dev/null
+++ b/models/research/feelvos/input_preprocess.py
@@ -0,0 +1,280 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Prepare the data used for FEELVOS training/evaluation."""
+import tensorflow as tf
+
+from deeplab.core import feature_extractor
+from deeplab.core import preprocess_utils
+
+# The probability of flipping the images and labels
+# left-right during training
+_PROB_OF_FLIP = 0.5
+
+get_random_scale = preprocess_utils.get_random_scale
+randomly_scale_image_and_label = (
+ preprocess_utils.randomly_scale_image_and_label)
+
+
+def preprocess_image_and_label(image,
+ label,
+ crop_height,
+ crop_width,
+ min_resize_value=None,
+ max_resize_value=None,
+ resize_factor=None,
+ min_scale_factor=1.,
+ max_scale_factor=1.,
+ scale_factor_step_size=0,
+ ignore_label=255,
+ is_training=True,
+ model_variant=None):
+ """Preprocesses the image and label.
+
+ Args:
+ image: Input image.
+ label: Ground truth annotation label.
+ crop_height: The height value used to crop the image and label.
+ crop_width: The width value used to crop the image and label.
+ min_resize_value: Desired size of the smaller image side.
+ max_resize_value: Maximum allowed size of the larger image side.
+ resize_factor: Resized dimensions are multiple of factor plus one.
+ min_scale_factor: Minimum scale factor value.
+ max_scale_factor: Maximum scale factor value.
+ scale_factor_step_size: The step size from min scale factor to max scale
+ factor. The input is randomly scaled based on the value of
+ (min_scale_factor, max_scale_factor, scale_factor_step_size).
+ ignore_label: The label value which will be ignored for training and
+ evaluation.
+ is_training: If the preprocessing is used for training or not.
+ model_variant: Model variant (string) for choosing how to mean-subtract the
+ images. See feature_extractor.network_map for supported model variants.
+
+ Returns:
+ original_image: Original image (could be resized).
+ processed_image: Preprocessed image.
+ label: Preprocessed ground truth segmentation label.
+
+ Raises:
+ ValueError: Ground truth label not provided during training.
+ """
+ if is_training and label is None:
+ raise ValueError('During training, label must be provided.')
+ if model_variant is None:
+ tf.logging.warning('Default mean-subtraction is performed. Please specify '
+ 'a model_variant. See feature_extractor.network_map for '
+ 'supported model variants.')
+
+ # Keep reference to original image.
+ original_image = image
+
+ processed_image = tf.cast(image, tf.float32)
+
+ if label is not None:
+ label = tf.cast(label, tf.int32)
+
+ # Resize image and label to the desired range.
+ if min_resize_value is not None or max_resize_value is not None:
+ [processed_image, label] = (
+ preprocess_utils.resize_to_range(
+ image=processed_image,
+ label=label,
+ min_size=min_resize_value,
+ max_size=max_resize_value,
+ factor=resize_factor,
+ align_corners=True))
+ # The `original_image` becomes the resized image.
+ original_image = tf.identity(processed_image)
+
+ # Data augmentation by randomly scaling the inputs.
+ scale = get_random_scale(
+ min_scale_factor, max_scale_factor, scale_factor_step_size)
+ processed_image, label = randomly_scale_image_and_label(
+ processed_image, label, scale)
+
+ processed_image.set_shape([None, None, 3])
+
+ if crop_height is not None and crop_width is not None:
+ # Pad image and label to have dimensions >= [crop_height, crop_width].
+ image_shape = tf.shape(processed_image)
+ image_height = image_shape[0]
+ image_width = image_shape[1]
+
+ target_height = image_height + tf.maximum(crop_height - image_height, 0)
+ target_width = image_width + tf.maximum(crop_width - image_width, 0)
+
+ # Pad image with mean pixel value.
+ mean_pixel = tf.reshape(
+ feature_extractor.mean_pixel(model_variant), [1, 1, 3])
+ processed_image = preprocess_utils.pad_to_bounding_box(
+ processed_image, 0, 0, target_height, target_width, mean_pixel)
+
+ if label is not None:
+ label = preprocess_utils.pad_to_bounding_box(
+ label, 0, 0, target_height, target_width, ignore_label)
+
+ # Randomly crop the image and label.
+ if is_training and label is not None:
+ processed_image, label = preprocess_utils.random_crop(
+ [processed_image, label], crop_height, crop_width)
+
+ processed_image.set_shape([crop_height, crop_width, 3])
+
+ if label is not None:
+ label.set_shape([crop_height, crop_width, 1])
+
+ if is_training:
+ # Randomly left-right flip the image and label.
+ processed_image, label, _ = preprocess_utils.flip_dim(
+ [processed_image, label], _PROB_OF_FLIP, dim=1)
+
+ return original_image, processed_image, label
+
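+# Hypothetical usage sketch (kept as a comment; the flag values below mirror
+# the training defaults in train.py and are otherwise illustrative):
+#   _, image, label = preprocess_image_and_label(
+#       image, label, crop_height=465, crop_width=465,
+#       min_scale_factor=1.0, max_scale_factor=1.3, scale_factor_step_size=0,
+#       ignore_label=255, is_training=True, model_variant='xception_65')
+# returns the mean-padded, randomly scaled/cropped/flipped image together with
+# the label cropped in exactly the same way.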
+
+def preprocess_images_and_labels_consistently(images,
+ labels,
+ crop_height,
+ crop_width,
+ min_resize_value=None,
+ max_resize_value=None,
+ resize_factor=None,
+ min_scale_factor=1.,
+ max_scale_factor=1.,
+ scale_factor_step_size=0,
+ ignore_label=255,
+ is_training=True,
+ model_variant=None):
+ """Preprocesses images and labels in a consistent way.
+
+ Similar to preprocess_image_and_label, but works on a list of images
+ and a list of labels and uses the same crop coordinates and either flips
+ all images and labels or none of them.
+
+ Args:
+ images: List of input images.
+ labels: List of ground truth annotation labels.
+ crop_height: The height value used to crop the image and label.
+ crop_width: The width value used to crop the image and label.
+ min_resize_value: Desired size of the smaller image side.
+ max_resize_value: Maximum allowed size of the larger image side.
+ resize_factor: Resized dimensions are multiple of factor plus one.
+ min_scale_factor: Minimum scale factor value.
+ max_scale_factor: Maximum scale factor value.
+ scale_factor_step_size: The step size from min scale factor to max scale
+ factor. The input is randomly scaled based on the value of
+ (min_scale_factor, max_scale_factor, scale_factor_step_size).
+ ignore_label: The label value which will be ignored for training and
+ evaluation.
+ is_training: If the preprocessing is used for training or not.
+ model_variant: Model variant (string) for choosing how to mean-subtract the
+ images. See feature_extractor.network_map for supported model variants.
+
+ Returns:
+ original_images: Original images (could be resized).
+ processed_images: Preprocessed images.
+ labels: Preprocessed ground truth segmentation labels.
+
+ Raises:
+ ValueError: Ground truth label not provided during training.
+ """
+ if is_training and labels is None:
+ raise ValueError('During training, labels must be provided.')
+ if model_variant is None:
+ tf.logging.warning('Default mean-subtraction is performed. Please specify '
+ 'a model_variant. See feature_extractor.network_map for '
+ 'supported model variants.')
+ if labels is not None:
+ assert len(images) == len(labels)
+ num_imgs = len(images)
+
+ # Keep reference to original images.
+ original_images = images
+
+ processed_images = [tf.cast(image, tf.float32) for image in images]
+
+ if labels is not None:
+ labels = [tf.cast(label, tf.int32) for label in labels]
+
+ # Resize images and labels to the desired range.
+ if min_resize_value is not None or max_resize_value is not None:
+ processed_images, labels = zip(*[
+ preprocess_utils.resize_to_range(
+ image=processed_image,
+ label=label,
+ min_size=min_resize_value,
+ max_size=max_resize_value,
+ factor=resize_factor,
+ align_corners=True) for processed_image, label
+ in zip(processed_images, labels)])
+ # The `original_images` becomes the resized images.
+ original_images = [tf.identity(processed_image)
+ for processed_image in processed_images]
+
+ # Data augmentation by randomly scaling the inputs.
+ scale = get_random_scale(
+ min_scale_factor, max_scale_factor, scale_factor_step_size)
+ processed_images, labels = zip(
+ *[randomly_scale_image_and_label(processed_image, label, scale)
+ for processed_image, label in zip(processed_images, labels)])
+
+ for processed_image in processed_images:
+ processed_image.set_shape([None, None, 3])
+
+ if crop_height is not None and crop_width is not None:
+ # Pad image and label to have dimensions >= [crop_height, crop_width].
+ image_shape = tf.shape(processed_images[0])
+ image_height = image_shape[0]
+ image_width = image_shape[1]
+
+ target_height = image_height + tf.maximum(crop_height - image_height, 0)
+ target_width = image_width + tf.maximum(crop_width - image_width, 0)
+
+ # Pad image with mean pixel value.
+ mean_pixel = tf.reshape(
+ feature_extractor.mean_pixel(model_variant), [1, 1, 3])
+ processed_images = [preprocess_utils.pad_to_bounding_box(
+ processed_image, 0, 0, target_height, target_width, mean_pixel)
+ for processed_image in processed_images]
+
+ if labels is not None:
+ labels = [preprocess_utils.pad_to_bounding_box(
+ label, 0, 0, target_height, target_width, ignore_label)
+ for label in labels]
+
+ # Randomly crop the images and labels.
+ if is_training and labels is not None:
+ cropped = preprocess_utils.random_crop(
+ processed_images + labels, crop_height, crop_width)
+ assert len(cropped) == 2 * num_imgs
+ processed_images = cropped[:num_imgs]
+ labels = cropped[num_imgs:]
+
+ for processed_image in processed_images:
+ processed_image.set_shape([crop_height, crop_width, 3])
+
+ if labels is not None:
+ for label in labels:
+ label.set_shape([crop_height, crop_width, 1])
+
+ if is_training:
+ # Randomly left-right flip the image and label.
+ res = preprocess_utils.flip_dim(
+ list(processed_images + labels), _PROB_OF_FLIP, dim=1)
+ maybe_flipped = res[:-1]
+ assert len(maybe_flipped) == 2 * num_imgs
+ processed_images = maybe_flipped[:num_imgs]
+ labels = maybe_flipped[num_imgs:]
+
+ return original_images, processed_images, labels
diff --git a/models/research/feelvos/model.py b/models/research/feelvos/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f145f91616958b7327d99bb55efb1b7b5016a223
--- /dev/null
+++ b/models/research/feelvos/model.py
@@ -0,0 +1,480 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+r"""Provides DeepLab model definition and helper functions.
+
+DeepLab is a deep learning system for semantic image segmentation with
+the following features:
+
+(1) Atrous convolution to explicitly control the resolution at which
+feature responses are computed within Deep Convolutional Neural Networks.
+
+(2) Atrous spatial pyramid pooling (ASPP) to robustly segment objects at
+multiple scales with filters at multiple sampling rates and effective
+fields-of-views.
+
+(3) ASPP module augmented with image-level feature and batch normalization.
+
+(4) A simple yet effective decoder module to recover the object boundaries.
+
+See the following papers for more details:
+
+"Encoder-Decoder with Atrous Separable Convolution for Semantic Image
+Segmentation"
+Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, Hartwig Adam.
+(https://arxiv.org/abs/1802.02611)
+
+"Rethinking Atrous Convolution for Semantic Image Segmentation,"
+Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam
+(https://arxiv.org/abs/1706.05587)
+
+"DeepLab: Semantic Image Segmentation with Deep Convolutional Nets,
+Atrous Convolution, and Fully Connected CRFs",
+Liang-Chieh Chen*, George Papandreou*, Iasonas Kokkinos, Kevin Murphy,
+Alan L. Yuille (* equal contribution)
+(https://arxiv.org/abs/1606.00915)
+
+"Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected
+CRFs"
+Liang-Chieh Chen*, George Papandreou*, Iasonas Kokkinos, Kevin Murphy,
+Alan L. Yuille (* equal contribution)
+(https://arxiv.org/abs/1412.7062)
+"""
+import collections
+import tensorflow as tf
+
+from deeplab import model
+from feelvos import common
+from feelvos.utils import embedding_utils
+from feelvos.utils import train_utils
+
+slim = tf.contrib.slim
+
+
+get_branch_logits = model.get_branch_logits
+get_extra_layer_scopes = model.get_extra_layer_scopes
+multi_scale_logits_v2 = model.multi_scale_logits
+refine_by_decoder = model.refine_by_decoder
+scale_dimension = model.scale_dimension
+split_separable_conv2d = model.split_separable_conv2d
+
+MERGED_LOGITS_SCOPE = model.MERGED_LOGITS_SCOPE
+IMAGE_POOLING_SCOPE = model.IMAGE_POOLING_SCOPE
+ASPP_SCOPE = model.ASPP_SCOPE
+CONCAT_PROJECTION_SCOPE = model.CONCAT_PROJECTION_SCOPE
+
+
+def predict_labels(images,
+ model_options,
+ image_pyramid=None,
+ reference_labels=None,
+ k_nearest_neighbors=1,
+ embedding_dimension=None,
+ use_softmax_feedback=False,
+ initial_softmax_feedback=None,
+ embedding_seg_feature_dimension=256,
+ embedding_seg_n_layers=4,
+ embedding_seg_kernel_size=7,
+ embedding_seg_atrous_rates=None,
+ also_return_softmax_probabilities=False,
+ num_frames_per_video=None,
+ normalize_nearest_neighbor_distances=False,
+ also_attend_to_previous_frame=False,
+ use_local_previous_frame_attention=False,
+ previous_frame_attention_window_size=9,
+ use_first_frame_matching=True,
+ also_return_embeddings=False,
+ ref_embeddings=None):
+ """Predicts segmentation labels.
+
+ Args:
+ images: A tensor of size [batch, height, width, channels].
+ model_options: An InternalModelOptions instance to configure models.
+ image_pyramid: Input image scales for multi-scale feature extraction.
+    reference_labels: A tensor of size [batch, height, width, 1] holding the
+      ground truth labels used to perform a nearest neighbor query.
+ k_nearest_neighbors: Integer, the number of neighbors to use for nearest
+ neighbor queries.
+ embedding_dimension: Integer, the dimension used for the learned embedding.
+ use_softmax_feedback: Boolean, whether to give the softmax predictions of
+ the last frame as additional input to the segmentation head.
+ initial_softmax_feedback: Float32 tensor, or None. Can be used to
+ initialize the softmax predictions used for the feedback loop.
+ Typically only useful for inference. Only has an effect if
+ use_softmax_feedback is True.
+ embedding_seg_feature_dimension: Integer, the dimensionality used in the
+ segmentation head layers.
+ embedding_seg_n_layers: Integer, the number of layers in the segmentation
+ head.
+ embedding_seg_kernel_size: Integer, the kernel size used in the
+ segmentation head.
+ embedding_seg_atrous_rates: List of integers of length
+ embedding_seg_n_layers, the atrous rates to use for the segmentation head.
+ also_return_softmax_probabilities: Boolean, if true, additionally return
+ the softmax probabilities as second return value.
+ num_frames_per_video: Integer, the number of frames per video.
+ normalize_nearest_neighbor_distances: Boolean, whether to normalize the
+ nearest neighbor distances to [0,1] using sigmoid, scale and shift.
+ also_attend_to_previous_frame: Boolean, whether to also use nearest
+ neighbor attention with respect to the previous frame.
+ use_local_previous_frame_attention: Boolean, whether to restrict the
+ previous frame attention to a local search window.
+ Only has an effect, if also_attend_to_previous_frame is True.
+ previous_frame_attention_window_size: Integer, the window size used for
+ local previous frame attention, if use_local_previous_frame_attention
+ is True.
+ use_first_frame_matching: Boolean, whether to extract features by matching
+ to the reference frame. This should always be true except for ablation
+ experiments.
+ also_return_embeddings: Boolean, whether to return the embeddings as well.
+ ref_embeddings: Tuple of
+ (first_frame_embeddings, previous_frame_embeddings),
+ each of shape [batch, height, width, embedding_dimension], or None.
+
+ Returns:
+ A dictionary with keys specifying the output_type (e.g., semantic
+ prediction) and values storing Tensors representing predictions (argmax
+ over channels). Each prediction has size [batch, height, width].
+ If also_return_softmax_probabilities is True, the second return value are
+ the softmax probabilities.
+ If also_return_embeddings is True, it will also return an embeddings
+ tensor of shape [batch, height, width, embedding_dimension].
+
+ Raises:
+ ValueError: If classification_loss is not softmax, softmax_with_attention,
+ nor triplet.
+ """
+ if (model_options.classification_loss == 'triplet' and
+ reference_labels is None):
+ raise ValueError('Need reference_labels for triplet loss')
+
+ if model_options.classification_loss == 'softmax_with_attention':
+ if embedding_dimension is None:
+ raise ValueError('Need embedding_dimension for softmax_with_attention '
+ 'loss')
+ if reference_labels is None:
+ raise ValueError('Need reference_labels for softmax_with_attention loss')
+ res = (
+ multi_scale_logits_with_nearest_neighbor_matching(
+ images,
+ model_options=model_options,
+ image_pyramid=image_pyramid,
+ is_training=False,
+ reference_labels=reference_labels,
+ clone_batch_size=1,
+ num_frames_per_video=num_frames_per_video,
+ embedding_dimension=embedding_dimension,
+ max_neighbors_per_object=0,
+ k_nearest_neighbors=k_nearest_neighbors,
+ use_softmax_feedback=use_softmax_feedback,
+ initial_softmax_feedback=initial_softmax_feedback,
+ embedding_seg_feature_dimension=embedding_seg_feature_dimension,
+ embedding_seg_n_layers=embedding_seg_n_layers,
+ embedding_seg_kernel_size=embedding_seg_kernel_size,
+ embedding_seg_atrous_rates=embedding_seg_atrous_rates,
+ normalize_nearest_neighbor_distances=
+ normalize_nearest_neighbor_distances,
+ also_attend_to_previous_frame=also_attend_to_previous_frame,
+ use_local_previous_frame_attention=
+ use_local_previous_frame_attention,
+ previous_frame_attention_window_size=
+ previous_frame_attention_window_size,
+ use_first_frame_matching=use_first_frame_matching,
+ also_return_embeddings=also_return_embeddings,
+ ref_embeddings=ref_embeddings
+ ))
+ if also_return_embeddings:
+ outputs_to_scales_to_logits, embeddings = res
+ else:
+ outputs_to_scales_to_logits = res
+ embeddings = None
+ else:
+ outputs_to_scales_to_logits = multi_scale_logits_v2(
+ images,
+ model_options=model_options,
+ image_pyramid=image_pyramid,
+ is_training=False,
+ fine_tune_batch_norm=False)
+
+ predictions = {}
+ for output in sorted(outputs_to_scales_to_logits):
+ scales_to_logits = outputs_to_scales_to_logits[output]
+ original_logits = scales_to_logits[MERGED_LOGITS_SCOPE]
+ if isinstance(original_logits, list):
+ assert len(original_logits) == 1
+ original_logits = original_logits[0]
+ logits = tf.image.resize_bilinear(original_logits, tf.shape(images)[1:3],
+ align_corners=True)
+ if model_options.classification_loss in ('softmax',
+ 'softmax_with_attention'):
+ predictions[output] = tf.argmax(logits, 3)
+ elif model_options.classification_loss == 'triplet':
+ # to keep this fast, we do the nearest neighbor assignment on the
+ # resolution at which the embedding is extracted and scale the result up
+ # afterwards
+ embeddings = original_logits
+ reference_labels_logits_size = tf.squeeze(
+ tf.image.resize_nearest_neighbor(
+ reference_labels[tf.newaxis],
+ train_utils.resolve_shape(embeddings)[1:3],
+ align_corners=True), axis=0)
+ nn_labels = embedding_utils.assign_labels_by_nearest_neighbors(
+ embeddings[0], embeddings[1:], reference_labels_logits_size,
+ k_nearest_neighbors)
+ predictions[common.OUTPUT_TYPE] = tf.image.resize_nearest_neighbor(
+ nn_labels, tf.shape(images)[1:3], align_corners=True)
+ else:
+ raise ValueError(
+ 'Only support softmax, triplet, or softmax_with_attention for '
+ 'classification_loss.')
+
+ if also_return_embeddings:
+ assert also_return_softmax_probabilities
+ return predictions, tf.nn.softmax(original_logits, axis=-1), embeddings
+ elif also_return_softmax_probabilities:
+ return predictions, tf.nn.softmax(original_logits, axis=-1)
+ else:
+ return predictions
+
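+# Hypothetical inference sketch (comment only; argument values are
+# illustrative, not prescribed by this file): per-frame video segmentation
+# with attention-based matching could be driven roughly like
+#   predictions, probs = predict_labels(
+#       frames, model_options,
+#       reference_labels=first_frame_labels,
+#       embedding_dimension=100,
+#       use_softmax_feedback=True,
+#       also_return_softmax_probabilities=True,
+#       num_frames_per_video=1)
+# where probs from one frame can then be passed back as
+# initial_softmax_feedback when processing the next frame of the same video.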
+
+def multi_scale_logits_with_nearest_neighbor_matching(
+ images,
+ model_options,
+ image_pyramid,
+ clone_batch_size,
+ reference_labels,
+ num_frames_per_video,
+ embedding_dimension,
+ max_neighbors_per_object,
+ weight_decay=0.0001,
+ is_training=False,
+ fine_tune_batch_norm=False,
+ k_nearest_neighbors=1,
+ use_softmax_feedback=False,
+ initial_softmax_feedback=None,
+ embedding_seg_feature_dimension=256,
+ embedding_seg_n_layers=4,
+ embedding_seg_kernel_size=7,
+ embedding_seg_atrous_rates=None,
+ normalize_nearest_neighbor_distances=False,
+ also_attend_to_previous_frame=False,
+ damage_initial_previous_frame_mask=False,
+ use_local_previous_frame_attention=False,
+ previous_frame_attention_window_size=9,
+ use_first_frame_matching=True,
+ also_return_embeddings=False,
+ ref_embeddings=None):
+ """Gets the logits for multi-scale inputs using nearest neighbor attention.
+
+ Adjusted version of multi_scale_logits_v2 to support nearest neighbor
+ attention and a variable number of classes for each element of the batch.
+ The returned logits are all downsampled (due to max-pooling layers)
+ for both training and evaluation.
+
+ Args:
+ images: A tensor of size [batch, height, width, channels].
+ model_options: A ModelOptions instance to configure models.
+ image_pyramid: Input image scales for multi-scale feature extraction.
+ clone_batch_size: Integer, the number of videos on a batch.
+ reference_labels: The segmentation labels of the reference frame on which
+ attention is applied.
+ num_frames_per_video: Integer, the number of frames per video.
+ embedding_dimension: Integer, the dimension of the embedding.
+ max_neighbors_per_object: Integer, the maximum number of candidates
+ for the nearest neighbor query per object after subsampling.
+ Can be 0 for no subsampling.
+ weight_decay: The weight decay for model variables.
+ is_training: Is training or not.
+ fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
+ k_nearest_neighbors: Integer, the number of nearest neighbors to use.
+ use_softmax_feedback: Boolean, whether to give the softmax predictions of
+ the last frame as additional input to the segmentation head.
+ initial_softmax_feedback: List of Float32 tensors, or None.
+ Can be used to initialize the softmax predictions used for the feedback
+ loop. Only has an effect if use_softmax_feedback is True.
+ embedding_seg_feature_dimension: Integer, the dimensionality used in the
+ segmentation head layers.
+ embedding_seg_n_layers: Integer, the number of layers in the segmentation
+ head.
+ embedding_seg_kernel_size: Integer, the kernel size used in the
+ segmentation head.
+ embedding_seg_atrous_rates: List of integers of length
+ embedding_seg_n_layers, the atrous rates to use for the segmentation head.
+ normalize_nearest_neighbor_distances: Boolean, whether to normalize the
+ nearest neighbor distances to [0,1] using sigmoid, scale and shift.
+ also_attend_to_previous_frame: Boolean, whether to also use nearest
+ neighbor attention with respect to the previous frame.
+ damage_initial_previous_frame_mask: Boolean, whether to artificially damage
+ the initial previous frame mask. Only has an effect if
+ also_attend_to_previous_frame is True.
+ use_local_previous_frame_attention: Boolean, whether to restrict the
+ previous frame attention to a local search window.
+ Only has an effect, if also_attend_to_previous_frame is True.
+ previous_frame_attention_window_size: Integer, the window size used for
+ local previous frame attention, if use_local_previous_frame_attention
+ is True.
+ use_first_frame_matching: Boolean, whether to extract features by matching
+ to the reference frame. This should always be true except for ablation
+ experiments.
+ also_return_embeddings: Boolean, whether to return the embeddings as well.
+ ref_embeddings: Tuple of
+ (first_frame_embeddings, previous_frame_embeddings),
+ each of shape [batch, height, width, embedding_dimension], or None.
+
+ Returns:
+ outputs_to_scales_to_logits: A map of maps from output_type (e.g.,
+ semantic prediction) to a dictionary of multi-scale logits names to
+ logits. For each output_type, the dictionary has keys which
+ correspond to the scales and values which correspond to the logits.
+ For example, if `scales` equals [1.0, 1.5], then the keys would
+ include 'merged_logits', 'logits_1.00' and 'logits_1.50'.
+ If also_return_embeddings is True, it will also return an embeddings
+ tensor of shape [batch, height, width, embedding_dimension].
+
+ Raises:
+ ValueError: If model_options doesn't specify crop_size and its
+ add_image_level_feature = True, since add_image_level_feature requires
+ crop_size information.
+ """
+ # Setup default values.
+ if not image_pyramid:
+ image_pyramid = [1.0]
+ crop_height = (
+ model_options.crop_size[0]
+ if model_options.crop_size else tf.shape(images)[1])
+ crop_width = (
+ model_options.crop_size[1]
+ if model_options.crop_size else tf.shape(images)[2])
+
+ # Compute the height, width for the output logits.
+ if model_options.decoder_output_stride:
+ logits_output_stride = min(model_options.decoder_output_stride)
+ else:
+ logits_output_stride = model_options.output_stride
+ logits_height = scale_dimension(
+ crop_height,
+ max(1.0, max(image_pyramid)) / logits_output_stride)
+ logits_width = scale_dimension(
+ crop_width,
+ max(1.0, max(image_pyramid)) / logits_output_stride)
+
+ # Compute the logits for each scale in the image pyramid.
+ outputs_to_scales_to_logits = {
+ k: {}
+ for k in model_options.outputs_to_num_classes
+ }
+
+ for image_scale in image_pyramid:
+ if image_scale != 1.0:
+ scaled_height = scale_dimension(crop_height, image_scale)
+ scaled_width = scale_dimension(crop_width, image_scale)
+ scaled_crop_size = [scaled_height, scaled_width]
+ scaled_images = tf.image.resize_bilinear(
+ images, scaled_crop_size, align_corners=True)
+ scaled_reference_labels = tf.image.resize_nearest_neighbor(
+ reference_labels, scaled_crop_size, align_corners=True
+ )
+ if model_options.crop_size is None:
+ scaled_crop_size = None
+ if model_options.crop_size:
+ scaled_images.set_shape([None, scaled_height, scaled_width, 3])
+ else:
+ scaled_crop_size = model_options.crop_size
+ scaled_images = images
+ scaled_reference_labels = reference_labels
+
+ updated_options = model_options._replace(crop_size=scaled_crop_size)
+ res = embedding_utils.get_logits_with_matching(
+ scaled_images,
+ updated_options,
+ weight_decay=weight_decay,
+ reuse=tf.AUTO_REUSE,
+ is_training=is_training,
+ fine_tune_batch_norm=fine_tune_batch_norm,
+ reference_labels=scaled_reference_labels,
+ batch_size=clone_batch_size,
+ num_frames_per_video=num_frames_per_video,
+ embedding_dimension=embedding_dimension,
+ max_neighbors_per_object=max_neighbors_per_object,
+ k_nearest_neighbors=k_nearest_neighbors,
+ use_softmax_feedback=use_softmax_feedback,
+ initial_softmax_feedback=initial_softmax_feedback,
+ embedding_seg_feature_dimension=embedding_seg_feature_dimension,
+ embedding_seg_n_layers=embedding_seg_n_layers,
+ embedding_seg_kernel_size=embedding_seg_kernel_size,
+ embedding_seg_atrous_rates=embedding_seg_atrous_rates,
+ normalize_nearest_neighbor_distances=
+ normalize_nearest_neighbor_distances,
+ also_attend_to_previous_frame=also_attend_to_previous_frame,
+ damage_initial_previous_frame_mask=damage_initial_previous_frame_mask,
+ use_local_previous_frame_attention=use_local_previous_frame_attention,
+ previous_frame_attention_window_size=
+ previous_frame_attention_window_size,
+ use_first_frame_matching=use_first_frame_matching,
+ also_return_embeddings=also_return_embeddings,
+ ref_embeddings=ref_embeddings
+ )
+ if also_return_embeddings:
+ outputs_to_logits, embeddings = res
+ else:
+ outputs_to_logits = res
+ embeddings = None
+
+ # Resize the logits to have the same dimension before merging.
+ for output in sorted(outputs_to_logits):
+ if isinstance(outputs_to_logits[output], collections.Sequence):
+ outputs_to_logits[output] = [tf.image.resize_bilinear(
+ x, [logits_height, logits_width], align_corners=True)
+ for x in outputs_to_logits[output]]
+ else:
+ outputs_to_logits[output] = tf.image.resize_bilinear(
+ outputs_to_logits[output], [logits_height, logits_width],
+ align_corners=True)
+
+ # Return when only one input scale.
+ if len(image_pyramid) == 1:
+ for output in sorted(model_options.outputs_to_num_classes):
+ outputs_to_scales_to_logits[output][
+ MERGED_LOGITS_SCOPE] = outputs_to_logits[output]
+ if also_return_embeddings:
+ return outputs_to_scales_to_logits, embeddings
+ else:
+ return outputs_to_scales_to_logits
+
+ # Save logits to the output map.
+ for output in sorted(model_options.outputs_to_num_classes):
+ outputs_to_scales_to_logits[output][
+ 'logits_%.2f' % image_scale] = outputs_to_logits[output]
+
+ # Merge the logits from all the multi-scale inputs.
+ for output in sorted(model_options.outputs_to_num_classes):
+ # Concatenate the multi-scale logits for each output type.
+ all_logits = [
+ [tf.expand_dims(l, axis=4)]
+ for logits in outputs_to_scales_to_logits[output].values()
+ for l in logits
+ ]
+ transposed = map(list, zip(*all_logits))
+ all_logits = [tf.concat(t, 4) for t in transposed]
+ merge_fn = (
+ tf.reduce_max
+ if model_options.merge_method == 'max' else tf.reduce_mean)
+ outputs_to_scales_to_logits[output][MERGED_LOGITS_SCOPE] = [merge_fn(
+ l, axis=4) for l in all_logits]
+
+ if also_return_embeddings:
+ return outputs_to_scales_to_logits, embeddings
+ else:
+ return outputs_to_scales_to_logits
diff --git a/models/research/feelvos/train.py b/models/research/feelvos/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..16c085722749bcfde5aeff15cdbec336e5efe451
--- /dev/null
+++ b/models/research/feelvos/train.py
@@ -0,0 +1,630 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Training script for the FEELVOS model.
+
+See model.py for more details and usage.
+"""
+import six
+import tensorflow as tf
+
+from feelvos import common
+from feelvos import model
+from feelvos.datasets import video_dataset
+from feelvos.utils import embedding_utils
+from feelvos.utils import train_utils
+from feelvos.utils import video_input_generator
+from deployment import model_deploy
+
+slim = tf.contrib.slim
+prefetch_queue = slim.prefetch_queue
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+# Settings for multi-GPUs/multi-replicas training.
+
+flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy.')
+
+flags.DEFINE_boolean('clone_on_cpu', False, 'Use CPUs to deploy clones.')
+
+flags.DEFINE_integer('num_replicas', 1, 'Number of worker replicas.')
+
+flags.DEFINE_integer('startup_delay_steps', 15,
+ 'Number of training steps between replicas startup.')
+
+flags.DEFINE_integer('num_ps_tasks', 0,
+ 'The number of parameter servers. If the value is 0, then '
+ 'the parameters are handled locally by the worker.')
+
+flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')
+
+flags.DEFINE_integer('task', 0, 'The task ID.')
+
+# Settings for logging.
+
+flags.DEFINE_string('train_logdir', None,
+ 'Where the checkpoint and logs are stored.')
+
+flags.DEFINE_integer('log_steps', 10,
+ 'Display logging information at every log_steps.')
+
+flags.DEFINE_integer('save_interval_secs', 1200,
+ 'How often, in seconds, we save the model to disk.')
+
+flags.DEFINE_integer('save_summaries_secs', 600,
+ 'How often, in seconds, we compute the summaries.')
+
+# Settings for training strategy.
+
+flags.DEFINE_enum('learning_policy', 'poly', ['poly', 'step'],
+ 'Learning rate policy for training.')
+
+flags.DEFINE_float('base_learning_rate', 0.0007,
+ 'The base learning rate for model training.')
+
+flags.DEFINE_float('learning_rate_decay_factor', 0.1,
+ 'The rate to decay the base learning rate.')
+
+flags.DEFINE_integer('learning_rate_decay_step', 2000,
+ 'Decay the base learning rate at a fixed step.')
+
+flags.DEFINE_float('learning_power', 0.9,
+ 'The power value used in the poly learning policy.')
+
+flags.DEFINE_integer('training_number_of_steps', 200000,
+ 'The number of steps used for training')
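+# Rough reading of the 'poly' policy configured above (the exact computation
+# lives in train_utils.get_model_learning_rate): the rate decays roughly as
+#   lr = base_learning_rate * (1 - global_step / training_number_of_steps) ** learning_power
+# so with the defaults (0.0007, 200000 steps, power 0.9) the half-way point
+# trains at about 0.0007 * 0.5 ** 0.9, i.e. roughly 3.8e-4.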
+
+flags.DEFINE_float('momentum', 0.9, 'The momentum value to use')
+
+flags.DEFINE_integer('train_batch_size', 6,
+ 'The number of images in each batch during training.')
+
+flags.DEFINE_integer('train_num_frames_per_video', 3,
+ 'The number of frames used per video during training')
+
+flags.DEFINE_float('weight_decay', 0.00004,
+ 'The value of the weight decay for training.')
+
+flags.DEFINE_multi_integer('train_crop_size', [465, 465],
+ 'Image crop size [height, width] during training.')
+
+flags.DEFINE_float('last_layer_gradient_multiplier', 1.0,
+ 'The gradient multiplier for last layers, which is used to '
+ 'boost the gradient of last layers if the value > 1.')
+
+flags.DEFINE_boolean('upsample_logits', True,
+ 'Upsample logits during training.')
+
+flags.DEFINE_integer('batch_capacity_factor', 16, 'Batch capacity factor.')
+
+flags.DEFINE_integer('num_readers', 1, 'Number of readers for data provider.')
+
+flags.DEFINE_integer('batch_num_threads', 1, 'Batch number of threads.')
+
+flags.DEFINE_integer('prefetch_queue_capacity_factor', 32,
+ 'Prefetch queue capacity factor.')
+
+flags.DEFINE_integer('prefetch_queue_num_threads', 1,
+ 'Prefetch queue number of threads.')
+
+flags.DEFINE_integer('train_max_neighbors_per_object', 1024,
+ 'The maximum number of candidates for the nearest '
+ 'neighbor query per object after subsampling')
+
+# Settings for fine-tuning the network.
+
+flags.DEFINE_string('tf_initial_checkpoint', None,
+ 'The initial checkpoint in tensorflow format.')
+
+flags.DEFINE_boolean('initialize_last_layer', False,
+ 'Initialize the last layer.')
+
+flags.DEFINE_boolean('last_layers_contain_logits_only', False,
+ 'Only consider logits as last layers or not.')
+
+flags.DEFINE_integer('slow_start_step', 0,
+ 'Training model with small learning rate for few steps.')
+
+flags.DEFINE_float('slow_start_learning_rate', 1e-4,
+ 'Learning rate employed during slow start.')
+
+flags.DEFINE_boolean('fine_tune_batch_norm', False,
+ 'Fine tune the batch norm parameters or not.')
+
+flags.DEFINE_float('min_scale_factor', 1.,
+                   'Minimum scale factor for data augmentation.')
+
+flags.DEFINE_float('max_scale_factor', 1.3,
+ 'Maximum scale factor for data augmentation.')
+
+flags.DEFINE_float('scale_factor_step_size', 0,
+ 'Scale factor step size for data augmentation.')
+
+flags.DEFINE_multi_integer('atrous_rates', None,
+ 'Atrous rates for atrous spatial pyramid pooling.')
+
+flags.DEFINE_integer('output_stride', 8,
+ 'The ratio of input to output spatial resolution.')
+
+flags.DEFINE_boolean('sample_only_first_frame_for_finetuning', False,
+ 'Whether to only sample the first frame during '
+ 'fine-tuning. This should be False when using lucid data, '
+ 'but True when fine-tuning on the first frame only. Only '
+ 'has an effect if first_frame_finetuning is True.')
+
+flags.DEFINE_multi_integer('first_frame_finetuning', [0],
+ 'Whether to only sample the first frame for '
+ 'fine-tuning.')
+
+# Dataset settings.
+
+flags.DEFINE_multi_string('dataset', [], 'Name of the segmentation datasets.')
+
+flags.DEFINE_multi_float('dataset_sampling_probabilities', [],
+ 'A list of probabilities to sample each of the '
+ 'datasets.')
+
+flags.DEFINE_string('train_split', 'train',
+ 'Which split of the dataset to be used for training')
+
+flags.DEFINE_multi_string('dataset_dir', [], 'Where the datasets reside.')
+
+flags.DEFINE_multi_integer('three_frame_dataset', [0],
+ 'Whether the dataset has exactly three frames per '
+ 'video of which the first is to be used as reference'
+ ' and the two others are consecutive frames to be '
+                           'used as query frames. '
+ 'Set true for pascal lucid data.')
+
+flags.DEFINE_boolean('damage_initial_previous_frame_mask', False,
+ 'Whether to artificially damage the initial previous '
+ 'frame mask. Only has an effect if '
+ 'also_attend_to_previous_frame is True.')
+
+flags.DEFINE_float('top_k_percent_pixels', 0.15, 'Float in [0.0, 1.0]. '
+                   'When its value < 1.0, only compute the loss for the top k '
+                   'percent pixels (e.g., the top 20% pixels). This is useful '
+                   'for hard pixel mining.')
+
+flags.DEFINE_integer('hard_example_mining_step', 100000,
+                     'The training step in which the hard example mining '
+ 'kicks off. Note that we gradually reduce the mining '
+ 'percent to the top_k_percent_pixels. For example, if '
+ 'hard_example_mining_step=100K and '
+ 'top_k_percent_pixels=0.25, then mining percent will '
+ 'gradually reduce from 100% to 25% until 100K steps '
+ 'after which we only mine top 25% pixels. Only has an '
+ 'effect if top_k_percent_pixels < 1.0')
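+# Illustrative reading of the mining schedule above (the exact formula lives
+# in train_utils): the mined fraction is interpolated roughly as
+#   ratio = min(1.0, global_step / hard_example_mining_step)
+#   mined_fraction = (1 - ratio) * 1.0 + ratio * top_k_percent_pixels
+# so with the defaults (0.15, step 100000), step 50000 computes the loss on
+# about 0.5 * 1.0 + 0.5 * 0.15 = 57.5% of the pixels.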
+
+
+def _build_deeplab(inputs_queue_or_samples, outputs_to_num_classes,
+ ignore_label):
+ """Builds a clone of DeepLab.
+
+ Args:
+ inputs_queue_or_samples: A prefetch queue for images and labels, or
+ directly a dict of the samples.
+ outputs_to_num_classes: A map from output type to the number of classes.
+ For example, for the task of semantic segmentation with 21 semantic
+ classes, we would have outputs_to_num_classes['semantic'] = 21.
+ ignore_label: Ignore label.
+
+ Returns:
+ A map of maps from output_type (e.g., semantic prediction) to a
+ dictionary of multi-scale logits names to logits. For each output_type,
+ the dictionary has keys which correspond to the scales and values which
+ correspond to the logits. For example, if `scales` equals [1.0, 1.5],
+ then the keys would include 'merged_logits', 'logits_1.00' and
+ 'logits_1.50'.
+
+ Raises:
+ ValueError: If classification_loss is not softmax, softmax_with_attention,
+ or triplet.
+ """
+ if hasattr(inputs_queue_or_samples, 'dequeue'):
+ samples = inputs_queue_or_samples.dequeue()
+ else:
+ samples = inputs_queue_or_samples
+ train_crop_size = (None if 0 in FLAGS.train_crop_size else
+ FLAGS.train_crop_size)
+
+ model_options = common.VideoModelOptions(
+ outputs_to_num_classes=outputs_to_num_classes,
+ crop_size=train_crop_size,
+ atrous_rates=FLAGS.atrous_rates,
+ output_stride=FLAGS.output_stride)
+
+ if model_options.classification_loss == 'softmax_with_attention':
+ clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones
+
+ # Create summaries of ground truth labels.
+ for n in range(clone_batch_size):
+ tf.summary.image(
+ 'gt_label_%d' % n,
+ tf.cast(samples[common.LABEL][
+ n * FLAGS.train_num_frames_per_video:
+ (n + 1) * FLAGS.train_num_frames_per_video],
+ tf.uint8) * 32, max_outputs=FLAGS.train_num_frames_per_video)
+
+ if common.PRECEDING_FRAME_LABEL in samples:
+ preceding_frame_label = samples[common.PRECEDING_FRAME_LABEL]
+ init_softmax = []
+ for n in range(clone_batch_size):
+ init_softmax_n = embedding_utils.create_initial_softmax_from_labels(
+ preceding_frame_label[n, tf.newaxis],
+ samples[common.LABEL][n * FLAGS.train_num_frames_per_video,
+ tf.newaxis],
+ common.parse_decoder_output_stride(),
+ reduce_labels=True)
+ init_softmax_n = tf.squeeze(init_softmax_n, axis=0)
+ init_softmax.append(init_softmax_n)
+ tf.summary.image('preceding_frame_label',
+ tf.cast(preceding_frame_label[n, tf.newaxis],
+ tf.uint8) * 32)
+ else:
+ init_softmax = None
+
+ outputs_to_scales_to_logits = (
+ model.multi_scale_logits_with_nearest_neighbor_matching(
+ samples[common.IMAGE],
+ model_options=model_options,
+ image_pyramid=FLAGS.image_pyramid,
+ weight_decay=FLAGS.weight_decay,
+ is_training=True,
+ fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
+ reference_labels=samples[common.LABEL],
+ clone_batch_size=FLAGS.train_batch_size // FLAGS.num_clones,
+ num_frames_per_video=FLAGS.train_num_frames_per_video,
+ embedding_dimension=FLAGS.embedding_dimension,
+ max_neighbors_per_object=FLAGS.train_max_neighbors_per_object,
+ k_nearest_neighbors=FLAGS.k_nearest_neighbors,
+ use_softmax_feedback=FLAGS.use_softmax_feedback,
+ initial_softmax_feedback=init_softmax,
+ embedding_seg_feature_dimension=
+ FLAGS.embedding_seg_feature_dimension,
+ embedding_seg_n_layers=FLAGS.embedding_seg_n_layers,
+ embedding_seg_kernel_size=FLAGS.embedding_seg_kernel_size,
+ embedding_seg_atrous_rates=FLAGS.embedding_seg_atrous_rates,
+ normalize_nearest_neighbor_distances=
+ FLAGS.normalize_nearest_neighbor_distances,
+ also_attend_to_previous_frame=FLAGS.also_attend_to_previous_frame,
+ damage_initial_previous_frame_mask=
+ FLAGS.damage_initial_previous_frame_mask,
+ use_local_previous_frame_attention=
+ FLAGS.use_local_previous_frame_attention,
+ previous_frame_attention_window_size=
+ FLAGS.previous_frame_attention_window_size,
+ use_first_frame_matching=FLAGS.use_first_frame_matching
+ ))
+ else:
+ outputs_to_scales_to_logits = model.multi_scale_logits_v2(
+ samples[common.IMAGE],
+ model_options=model_options,
+ image_pyramid=FLAGS.image_pyramid,
+ weight_decay=FLAGS.weight_decay,
+ is_training=True,
+ fine_tune_batch_norm=FLAGS.fine_tune_batch_norm)
+
+ if model_options.classification_loss == 'softmax':
+ for output, num_classes in six.iteritems(outputs_to_num_classes):
+ train_utils.add_softmax_cross_entropy_loss_for_each_scale(
+ outputs_to_scales_to_logits[output],
+ samples[common.LABEL],
+ num_classes,
+ ignore_label,
+ loss_weight=1.0,
+ upsample_logits=FLAGS.upsample_logits,
+ scope=output)
+ elif model_options.classification_loss == 'triplet':
+ for output, _ in six.iteritems(outputs_to_num_classes):
+ train_utils.add_triplet_loss_for_each_scale(
+ FLAGS.train_batch_size // FLAGS.num_clones,
+ FLAGS.train_num_frames_per_video,
+ FLAGS.embedding_dimension, outputs_to_scales_to_logits[output],
+ samples[common.LABEL], scope=output)
+ elif model_options.classification_loss == 'softmax_with_attention':
+ labels = samples[common.LABEL]
+ batch_size = FLAGS.train_batch_size // FLAGS.num_clones
+ num_frames_per_video = FLAGS.train_num_frames_per_video
+ h, w = train_utils.resolve_shape(labels)[1:3]
+ labels = tf.reshape(labels, tf.stack(
+ [batch_size, num_frames_per_video, h, w, 1]))
+ # Strip the reference labels off.
+ if FLAGS.also_attend_to_previous_frame or FLAGS.use_softmax_feedback:
+ n_ref_frames = 2
+ else:
+ n_ref_frames = 1
+ labels = labels[:, n_ref_frames:]
+ # Merge batch and time dimensions.
+ labels = tf.reshape(labels, tf.stack(
+ [batch_size * (num_frames_per_video - n_ref_frames), h, w, 1]))
+
+ for output, num_classes in six.iteritems(outputs_to_num_classes):
+ train_utils.add_dynamic_softmax_cross_entropy_loss_for_each_scale(
+ outputs_to_scales_to_logits[output],
+ labels,
+ ignore_label,
+ loss_weight=1.0,
+ upsample_logits=FLAGS.upsample_logits,
+ scope=output,
+ top_k_percent_pixels=FLAGS.top_k_percent_pixels,
+ hard_example_mining_step=FLAGS.hard_example_mining_step)
+ else:
+ raise ValueError('Only support softmax, softmax_with_attention'
+ ' or triplet for classification_loss.')
+
+ return outputs_to_scales_to_logits
+
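+# Shape bookkeeping for the softmax_with_attention branch above (illustrative,
+# using the flag defaults): with train_batch_size=6, num_clones=1 and
+# train_num_frames_per_video=3, a clone sees 6 * 3 = 18 stacked frames; with
+# softmax feedback or previous-frame attention enabled, n_ref_frames=2, so the
+# loss is computed on the remaining 6 * (3 - 2) = 6 query frames per step.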
+
+def main(unused_argv):
+ # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
+ config = model_deploy.DeploymentConfig(
+ num_clones=FLAGS.num_clones,
+ clone_on_cpu=FLAGS.clone_on_cpu,
+ replica_id=FLAGS.task,
+ num_replicas=FLAGS.num_replicas,
+ num_ps_tasks=FLAGS.num_ps_tasks)
+
+ with tf.Graph().as_default():
+ with tf.device(config.inputs_device()):
+ train_crop_size = (None if 0 in FLAGS.train_crop_size else
+ FLAGS.train_crop_size)
+ assert FLAGS.dataset
+ assert len(FLAGS.dataset) == len(FLAGS.dataset_dir)
+ if len(FLAGS.first_frame_finetuning) == 1:
+ first_frame_finetuning = (list(FLAGS.first_frame_finetuning)
+ * len(FLAGS.dataset))
+ else:
+ first_frame_finetuning = FLAGS.first_frame_finetuning
+ if len(FLAGS.three_frame_dataset) == 1:
+ three_frame_dataset = (list(FLAGS.three_frame_dataset)
+ * len(FLAGS.dataset))
+ else:
+ three_frame_dataset = FLAGS.three_frame_dataset
+ assert len(FLAGS.dataset) == len(first_frame_finetuning)
+ assert len(FLAGS.dataset) == len(three_frame_dataset)
+ datasets, samples_list = zip(
+ *[_get_dataset_and_samples(config, train_crop_size, dataset,
+ dataset_dir, bool(first_frame_finetuning_),
+ bool(three_frame_dataset_))
+ for dataset, dataset_dir, first_frame_finetuning_,
+ three_frame_dataset_ in zip(FLAGS.dataset, FLAGS.dataset_dir,
+ first_frame_finetuning,
+ three_frame_dataset)])
+ # Note that this way of doing things is wasteful since it will evaluate
+ # all branches but just use one of them. But let's do it anyway for now,
+ # since it's easy and will probably be fast enough.
+ dataset = datasets[0]
+ if len(samples_list) == 1:
+ samples = samples_list[0]
+ else:
+ probabilities = FLAGS.dataset_sampling_probabilities
+ if probabilities:
+ assert len(probabilities) == len(samples_list)
+ else:
+ # Default to uniform probabilities.
+ probabilities = [1.0 / len(samples_list) for _ in samples_list]
+ probabilities = tf.constant(probabilities)
+ logits = tf.log(probabilities[tf.newaxis])
+ rand_idx = tf.squeeze(tf.multinomial(logits, 1, output_dtype=tf.int32),
+ axis=[0, 1])
+
+        def wrap(x):
+          def f():
+            return x
+          return f
+
+        samples = tf.case({tf.equal(rand_idx, idx): wrap(s)
+                           for idx, s in enumerate(samples_list)},
+                          exclusive=True)
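+        # Illustrative example (dataset names are hypothetical): with
+        # --dataset=a --dataset=b --dataset_sampling_probabilities=0.7
+        # --dataset_sampling_probabilities=0.3, each step draws the sample
+        # dict of one of the two datasets with those probabilities; both
+        # input pipelines are still built, which is the wasteful part noted
+        # above.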
+
+ # Prefetch_queue requires the shape to be known at graph creation time.
+ # So we only use it if we crop to a fixed size.
+ if train_crop_size is None:
+ inputs_queue = samples
+ else:
+ inputs_queue = prefetch_queue.prefetch_queue(
+ samples,
+ capacity=FLAGS.prefetch_queue_capacity_factor*config.num_clones,
+ num_threads=FLAGS.prefetch_queue_num_threads)
+
+ # Create the global step on the device storing the variables.
+ with tf.device(config.variables_device()):
+ global_step = tf.train.get_or_create_global_step()
+
+ # Define the model and create clones.
+ model_fn = _build_deeplab
+ if FLAGS.classification_loss == 'triplet':
+ embedding_dim = FLAGS.embedding_dimension
+ output_type_to_dim = {'embedding': embedding_dim}
+ else:
+ output_type_to_dim = {common.OUTPUT_TYPE: dataset.num_classes}
+ model_args = (inputs_queue, output_type_to_dim, dataset.ignore_label)
+ clones = model_deploy.create_clones(config, model_fn, args=model_args)
+
+ # Gather update_ops from the first clone. These contain, for example,
+ # the updates for the batch_norm variables created by model_fn.
+ first_clone_scope = config.clone_scope(0)
+ update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
+
+ # Gather initial summaries.
+ summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
+
+ # Add summaries for model variables.
+ for model_var in tf.contrib.framework.get_model_variables():
+ summaries.add(tf.summary.histogram(model_var.op.name, model_var))
+
+ # Add summaries for losses.
+ for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
+ summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))
+
+ # Build the optimizer based on the device specification.
+ with tf.device(config.optimizer_device()):
+ learning_rate = train_utils.get_model_learning_rate(
+ FLAGS.learning_policy,
+ FLAGS.base_learning_rate,
+ FLAGS.learning_rate_decay_step,
+ FLAGS.learning_rate_decay_factor,
+ FLAGS.training_number_of_steps,
+ FLAGS.learning_power,
+ FLAGS.slow_start_step,
+ FLAGS.slow_start_learning_rate)
+ optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
+ summaries.add(tf.summary.scalar('learning_rate', learning_rate))
+
+ startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
+
+ with tf.device(config.variables_device()):
+ total_loss, grads_and_vars = model_deploy.optimize_clones(
+ clones, optimizer)
+ total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
+ summaries.add(tf.summary.scalar('total_loss', total_loss))
+
+ # Modify the gradients for biases and last layer variables.
+ last_layers = model.get_extra_layer_scopes(
+ FLAGS.last_layers_contain_logits_only)
+ grad_mult = train_utils.get_model_gradient_multipliers(
+ last_layers, FLAGS.last_layer_gradient_multiplier)
+ if grad_mult:
+ grads_and_vars = slim.learning.multiply_gradients(grads_and_vars,
+ grad_mult)
+
+ with tf.name_scope('grad_clipping'):
+ grads_and_vars = slim.learning.clip_gradient_norms(grads_and_vars, 5.0)
+
+ # Create histogram summaries for the gradients.
+ # We have too many summaries for mldash, so disable this one for now.
+ # for grad, var in grads_and_vars:
+ # summaries.add(tf.summary.histogram(
+ # var.name.replace(':0', '_0') + '/gradient', grad))
+
+ # Create gradient update op.
+ grad_updates = optimizer.apply_gradients(grads_and_vars,
+ global_step=global_step)
+ update_ops.append(grad_updates)
+ update_op = tf.group(*update_ops)
+ with tf.control_dependencies([update_op]):
+ train_tensor = tf.identity(total_loss, name='train_op')
+
+ # Add the summaries from the first clone. These contain the summaries
+ # created by model_fn and either optimize_clones() or _gather_clone_loss().
+ summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
+ first_clone_scope))
+
+ # Merge all summaries together.
+ summary_op = tf.summary.merge(list(summaries))
+
+ # Soft placement allows placing on CPU ops without GPU implementation.
+ session_config = tf.ConfigProto(allow_soft_placement=True,
+ log_device_placement=False)
+
+ # Start the training.
+ slim.learning.train(
+ train_tensor,
+ logdir=FLAGS.train_logdir,
+ log_every_n_steps=FLAGS.log_steps,
+ master=FLAGS.master,
+ number_of_steps=FLAGS.training_number_of_steps,
+ is_chief=(FLAGS.task == 0),
+ session_config=session_config,
+ startup_delay_steps=startup_delay_steps,
+ init_fn=train_utils.get_model_init_fn(FLAGS.train_logdir,
+ FLAGS.tf_initial_checkpoint,
+ FLAGS.initialize_last_layer,
+ last_layers,
+ ignore_missing_vars=True),
+ summary_op=summary_op,
+ save_summaries_secs=FLAGS.save_summaries_secs,
+ save_interval_secs=FLAGS.save_interval_secs)
+
+
+def _get_dataset_and_samples(config, train_crop_size, dataset_name,
+ dataset_dir, first_frame_finetuning,
+ three_frame_dataset):
+ """Creates dataset object and samples dict of tensor.
+
+ Args:
+ config: A DeploymentConfig.
+ train_crop_size: Integer, the crop size used for training.
+ dataset_name: String, the name of the dataset.
+ dataset_dir: String, the directory of the dataset.
+ first_frame_finetuning: Boolean, whether the used dataset is a dataset
+ for first frame fine-tuning.
+ three_frame_dataset: Boolean, whether the dataset has exactly three frames
+ per video of which the first is to be used as reference and the two
+ others are consecutive frames to be used as query frames.
+
+ Returns:
+ dataset: An instance of slim Dataset.
+ samples: A dictionary of tensors for semantic segmentation.
+ """
+
+ # Split the batch across GPUs.
+ assert FLAGS.train_batch_size % config.num_clones == 0, (
+      'Training batch size not divisible by number of clones (GPUs).')
+
+  clone_batch_size = FLAGS.train_batch_size // config.num_clones
+
+ if first_frame_finetuning:
+ train_split = 'val'
+ else:
+ train_split = FLAGS.train_split
+
+ data_type = 'tf_sequence_example'
+ # Get dataset-dependent information.
+ dataset = video_dataset.get_dataset(
+ dataset_name,
+ train_split,
+ dataset_dir=dataset_dir,
+ data_type=data_type)
+
+ tf.gfile.MakeDirs(FLAGS.train_logdir)
+ tf.logging.info('Training on %s set', train_split)
+
+ samples = video_input_generator.get(
+ dataset,
+ FLAGS.train_num_frames_per_video,
+ train_crop_size,
+ clone_batch_size,
+ num_readers=FLAGS.num_readers,
+ num_threads=FLAGS.batch_num_threads,
+ min_resize_value=FLAGS.min_resize_value,
+ max_resize_value=FLAGS.max_resize_value,
+ resize_factor=FLAGS.resize_factor,
+ min_scale_factor=FLAGS.min_scale_factor,
+ max_scale_factor=FLAGS.max_scale_factor,
+ scale_factor_step_size=FLAGS.scale_factor_step_size,
+ dataset_split=FLAGS.train_split,
+ is_training=True,
+ model_variant=FLAGS.model_variant,
+ batch_capacity_factor=FLAGS.batch_capacity_factor,
+ decoder_output_stride=common.parse_decoder_output_stride(),
+ first_frame_finetuning=first_frame_finetuning,
+ sample_only_first_frame_for_finetuning=
+ FLAGS.sample_only_first_frame_for_finetuning,
+ sample_adjacent_and_consistent_query_frames=
+ FLAGS.sample_adjacent_and_consistent_query_frames or
+ FLAGS.use_softmax_feedback,
+ remap_labels_to_reference_frame=True,
+ three_frame_dataset=three_frame_dataset,
+ add_prev_frame_label=not FLAGS.also_attend_to_previous_frame
+ )
+ return dataset, samples
+
+
+if __name__ == '__main__':
+ flags.mark_flag_as_required('train_logdir')
+ tf.logging.set_verbosity(tf.logging.INFO)
+ tf.app.run()
diff --git a/models/research/feelvos/train.sh b/models/research/feelvos/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..63b7ea19d4c53dea932322c3885abb9a95237e0c
--- /dev/null
+++ b/models/research/feelvos/train.sh
@@ -0,0 +1,92 @@
+#!/bin/bash
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# This script is used to run local training on DAVIS 2017. Users can also
+# adapt this script for their own use case. See eval.sh for an example of
+# local inference with a pre-trained model.
+#
+# Note that this script runs local training with a single GPU and a smaller crop
+# and batch size, while in the paper, we trained our models with 16 GPUs with
+# --num_clones=2, --train_batch_size=6, --num_replicas=8,
+# --training_number_of_steps=200000, --train_crop_size=465,
+# --train_crop_size=465.
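+# (2 clones per replica x 8 replicas = 16 GPUs in total.)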
+#
+# Usage:
+# # From the tensorflow/models/research/feelvos directory.
+# sh ./train.sh
+#
+#
+
+# Exit immediately if a command exits with a non-zero status.
+set -e
+
+# Move one-level up to tensorflow/models/research directory.
+cd ..
+
+# Update PYTHONPATH.
+export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim:`pwd`/feelvos
+
+# Set up the working environment.
+CURRENT_DIR=$(pwd)
+WORK_DIR="${CURRENT_DIR}/feelvos"
+
+# Set up the working directories.
+DATASET_DIR="datasets"
+DAVIS_FOLDER="davis17"
+DAVIS_DATASET="${WORK_DIR}/${DATASET_DIR}/${DAVIS_FOLDER}/tfrecord"
+EXP_FOLDER="exp/train"
+TRAIN_LOGDIR="${WORK_DIR}/${DATASET_DIR}/${DAVIS_FOLDER}/${EXP_FOLDER}/train"
+mkdir -p ${TRAIN_LOGDIR}
+
+# Go to datasets folder and download and convert the DAVIS 2017 dataset.
+DATASET_DIR="datasets"
+cd "${WORK_DIR}/${DATASET_DIR}"
+sh download_and_convert_davis17.sh
+
+# Go to models folder and download and unpack the COCO pre-trained model.
+MODELS_DIR="models"
+mkdir -p "${WORK_DIR}/${MODELS_DIR}"
+cd "${WORK_DIR}/${MODELS_DIR}"
+if [ ! -d "xception_65_coco_pretrained" ]; then
+ wget http://download.tensorflow.org/models/xception_65_coco_pretrained_2018_10_02.tar.gz
+ tar -xvf xception_65_coco_pretrained_2018_10_02.tar.gz
+ rm xception_65_coco_pretrained_2018_10_02.tar.gz
+fi
+INIT_CKPT="${WORK_DIR}/${MODELS_DIR}/xception_65_coco_pretrained/x65-b2u1s2p-d48-2-3x256-sc-cr300k_init.ckpt"
+
+# Go back to original directory.
+cd "${CURRENT_DIR}"
+
+python "${WORK_DIR}"/train.py \
+ --dataset=davis_2017 \
+ --dataset_dir="${DAVIS_DATASET}" \
+ --train_logdir="${TRAIN_LOGDIR}" \
+ --tf_initial_checkpoint="${INIT_CKPT}" \
+ --logtostderr \
+ --atrous_rates=6 \
+ --atrous_rates=12 \
+ --atrous_rates=18 \
+ --decoder_output_stride=4 \
+ --model_variant=xception_65 \
+ --multi_grid=1 \
+ --multi_grid=1 \
+ --multi_grid=1 \
+ --output_stride=16 \
+ --weight_decay=0.00004 \
+ --num_clones=1 \
+ --train_batch_size=1 \
+ --train_crop_size=300 \
+ --train_crop_size=300
diff --git a/models/research/feelvos/utils/__init__.py b/models/research/feelvos/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f1373443d0ff84fd90714e41dade400ab41a22c
--- /dev/null
+++ b/models/research/feelvos/utils/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/models/research/feelvos/utils/embedding_utils.py b/models/research/feelvos/utils/embedding_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..233c70d9327d08251537c58821dd8405b42f0fe7
--- /dev/null
+++ b/models/research/feelvos/utils/embedding_utils.py
@@ -0,0 +1,1082 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities for the instance embedding for segmentation."""
+
+import numpy as np
+import tensorflow as tf
+from deeplab import model
+from deeplab.core import preprocess_utils
+from feelvos.utils import mask_damaging
+
+slim = tf.contrib.slim
+resolve_shape = preprocess_utils.resolve_shape
+WRONG_LABEL_PADDING_DISTANCE = 1e20
+
+# With correlation_cost local matching will be much faster. But we provide a
+# slow fallback for convenience.
+USE_CORRELATION_COST = False
+if USE_CORRELATION_COST:
+ # pylint: disable=g-import-not-at-top
+ from correlation_cost.python.ops import correlation_cost_op
+
+
+def pairwise_distances(x, y):
+ """Computes pairwise squared l2 distances between tensors x and y.
+
+ Args:
+ x: Tensor of shape [n, feature_dim].
+ y: Tensor of shape [m, feature_dim].
+
+ Returns:
+ Float32 distances tensor of shape [n, m].
+ """
+ # d[i,j] = (x[i] - y[j]) * (x[i] - y[j])'
+ # = sum(x[i]^2, 1) + sum(y[j]^2, 1) - 2 * x[i] * y[j]'
+ xs = tf.reduce_sum(x * x, axis=1)[:, tf.newaxis]
+ ys = tf.reduce_sum(y * y, axis=1)[tf.newaxis, :]
+ d = xs + ys - 2 * tf.matmul(x, y, transpose_b=True)
+ return d
+
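+# Illustrative sketch (not part of the original FEELVOS code): the expansion
+# used above can be checked against a direct NumPy computation, e.g.:
+#
+#   x = np.random.rand(20, 5).astype('float32')
+#   y = np.random.rand(30, 5).astype('float32')
+#   d_fast = (x * x).sum(1)[:, None] + (y * y).sum(1)[None, :] - 2 * x.dot(y.T)
+#   d_naive = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
+#   assert np.allclose(d_fast, d_naive, atol=1e-4)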
+
+def pairwise_distances2(x, y):
+ """Computes pairwise squared l2 distances between tensors x and y.
+
+  Naive implementation with high memory use. Useful for testing the more
+  efficient implementation.
+
+ Args:
+ x: Tensor of shape [n, feature_dim].
+ y: Tensor of shape [m, feature_dim].
+
+ Returns:
+ distances of shape [n, m].
+ """
+ return tf.reduce_sum(tf.squared_difference(
+ x[:, tf.newaxis], y[tf.newaxis, :]), axis=-1)
+
+
+def cross_correlate(x, y, max_distance=9):
+ """Efficiently computes the cross correlation of x and y.
+
+ Optimized implementation using correlation_cost.
+ Note that we do not normalize by the feature dimension.
+
+ Args:
+ x: Float32 tensor of shape [height, width, feature_dim].
+ y: Float32 tensor of shape [height, width, feature_dim].
+ max_distance: Integer, the maximum distance in pixel coordinates
+ per dimension which is considered to be in the search window.
+
+ Returns:
+ Float32 tensor of shape [height, width, (2 * max_distance + 1) ** 2].
+ """
+ with tf.name_scope('cross_correlation'):
+ corr = correlation_cost_op.correlation_cost(
+ x[tf.newaxis], y[tf.newaxis], kernel_size=1,
+ max_displacement=max_distance, stride_1=1, stride_2=1,
+ pad=max_distance)
+ corr = tf.squeeze(corr, axis=0)
+ # This correlation implementation takes the mean over the feature_dim,
+ # but we want sum here, so multiply by feature_dim.
+ feature_dim = resolve_shape(x)[-1]
+ corr *= feature_dim
+ return corr
+
+
+def local_pairwise_distances(x, y, max_distance=9):
+ """Computes pairwise squared l2 distances using a local search window.
+
+ Optimized implementation using correlation_cost.
+
+ Args:
+ x: Float32 tensor of shape [height, width, feature_dim].
+ y: Float32 tensor of shape [height, width, feature_dim].
+ max_distance: Integer, the maximum distance in pixel coordinates
+ per dimension which is considered to be in the search window.
+
+ Returns:
+ Float32 distances tensor of shape
+ [height, width, (2 * max_distance + 1) ** 2].
+ """
+ with tf.name_scope('local_pairwise_distances'):
+ # d[i,j] = (x[i] - y[j]) * (x[i] - y[j])'
+ # = sum(x[i]^2, -1) + sum(y[j]^2, -1) - 2 * x[i] * y[j]'
+ corr = cross_correlate(x, y, max_distance=max_distance)
+ xs = tf.reduce_sum(x * x, axis=2)[..., tf.newaxis]
+ ys = tf.reduce_sum(y * y, axis=2)[..., tf.newaxis]
+ ones_ys = tf.ones_like(ys)
+ ys = cross_correlate(ones_ys, ys, max_distance=max_distance)
+ d = xs + ys - 2 * corr
+ # Boundary should be set to Inf.
+ boundary = tf.equal(
+ cross_correlate(ones_ys, ones_ys, max_distance=max_distance), 0)
+    d = tf.where(boundary, tf.fill(tf.shape(d), tf.constant(float('inf'))),
+                 d)
+ return d
+
+
+def local_pairwise_distances2(x, y, max_distance=9):
+ """Computes pairwise squared l2 distances using a local search window.
+
+ Naive implementation using map_fn.
+ Used as a slow fallback for when correlation_cost is not available.
+
+ Args:
+ x: Float32 tensor of shape [height, width, feature_dim].
+ y: Float32 tensor of shape [height, width, feature_dim].
+ max_distance: Integer, the maximum distance in pixel coordinates
+ per dimension which is considered to be in the search window.
+
+ Returns:
+ Float32 distances tensor of shape
+ [height, width, (2 * max_distance + 1) ** 2].
+ """
+ with tf.name_scope('local_pairwise_distances2'):
+ padding_val = 1e20
+ padded_y = tf.pad(y, [[max_distance, max_distance],
+ [max_distance, max_distance], [0, 0]],
+ constant_values=padding_val)
+ height, width, _ = resolve_shape(x)
+ dists = []
+ for y_start in range(2 * max_distance + 1):
+ y_end = y_start + height
+ y_slice = padded_y[y_start:y_end]
+ for x_start in range(2 * max_distance + 1):
+ x_end = x_start + width
+ offset_y = y_slice[:, x_start:x_end]
+ dist = tf.reduce_sum(tf.squared_difference(x, offset_y), axis=2)
+ dists.append(dist)
+ dists = tf.stack(dists, axis=2)
+ return dists
+
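+# Illustrative note (not part of the original FEELVOS code): for both local
+# distance implementations above, output channel k corresponds to the offset
+# (dy, dx), with dy and dx in [-max_distance, max_distance], via
+#
+#   k = (dy + max_distance) * (2 * max_distance + 1) + (dx + max_distance)
+#
+# e.g. with max_distance=1 the centered offset (0, 0) maps to channel 4.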
+
+def majority_vote(labels):
+ """Performs a label majority vote along axis 1.
+
+  Implemented efficiently via one-hot encoding and vote counting.
+  We assume that the labels are contiguous starting from 0.
+  It will also work for non-contiguous labels, but will be inefficient.
+
+ Args:
+ labels: Int tensor of shape [n, k]
+
+ Returns:
+ The majority of labels along axis 1
+ """
+ max_label = tf.reduce_max(labels)
+ one_hot = tf.one_hot(labels, depth=max_label + 1)
+ summed = tf.reduce_sum(one_hot, axis=1)
+ majority = tf.argmax(summed, axis=1)
+ return majority
+
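+# Illustrative sketch (not part of the original FEELVOS code): for
+# labels = [[1, 1, 2], [0, 0, 2]], the per-row one-hot sums are
+# [[0, 2, 1], [2, 0, 1]], so the argmax yields the majority labels [1, 0].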
+
+def assign_labels_by_nearest_neighbors(reference_embeddings, query_embeddings,
+ reference_labels, k=1):
+ """Segments by nearest neighbor query wrt the reference frame.
+
+ Args:
+ reference_embeddings: Tensor of shape [height, width, embedding_dim],
+ the embedding vectors for the reference frame
+ query_embeddings: Tensor of shape [n_query_images, height, width,
+ embedding_dim], the embedding vectors for the query frames
+ reference_labels: Tensor of shape [height, width, 1], the class labels of
+ the reference frame
+ k: Integer, the number of nearest neighbors to use
+
+ Returns:
+ The labels of the nearest neighbors as [n_query_frames, height, width, 1]
+ tensor
+
+ Raises:
+ ValueError: If k < 1.
+ """
+ if k < 1:
+ raise ValueError('k must be at least 1')
+ dists = flattened_pairwise_distances(reference_embeddings, query_embeddings)
+ if k == 1:
+ nn_indices = tf.argmin(dists, axis=1)[..., tf.newaxis]
+ else:
+ _, nn_indices = tf.nn.top_k(-dists, k, sorted=False)
+ reference_labels = tf.reshape(reference_labels, [-1])
+ nn_labels = tf.gather(reference_labels, nn_indices)
+ if k == 1:
+ nn_labels = tf.squeeze(nn_labels, axis=1)
+ else:
+ nn_labels = majority_vote(nn_labels)
+ height = tf.shape(reference_embeddings)[0]
+ width = tf.shape(reference_embeddings)[1]
+ n_query_frames = query_embeddings.shape[0]
+ nn_labels = tf.reshape(nn_labels, [n_query_frames, height, width, 1])
+ return nn_labels
+
+
+def flattened_pairwise_distances(reference_embeddings, query_embeddings):
+ """Calculates flattened tensor of pairwise distances between ref and query.
+
+ Args:
+ reference_embeddings: Tensor of shape [..., embedding_dim],
+ the embedding vectors for the reference frame
+ query_embeddings: Tensor of shape [n_query_images, height, width,
+ embedding_dim], the embedding vectors for the query frames.
+
+ Returns:
+ A distance tensor of shape [reference_embeddings.size / embedding_dim,
+ query_embeddings.size / embedding_dim]
+ """
+ embedding_dim = resolve_shape(query_embeddings)[-1]
+ reference_embeddings = tf.reshape(reference_embeddings, [-1, embedding_dim])
+ first_dim = -1
+ query_embeddings = tf.reshape(query_embeddings, [first_dim, embedding_dim])
+ dists = pairwise_distances(query_embeddings, reference_embeddings)
+ return dists
+
+
+def nearest_neighbor_features_per_object(
+ reference_embeddings, query_embeddings, reference_labels,
+ max_neighbors_per_object, k_nearest_neighbors, gt_ids=None, n_chunks=100):
+ """Calculates the distance to the nearest neighbor per object.
+
+ For every pixel of query_embeddings calculate the distance to the
+ nearest neighbor in the (possibly subsampled) reference_embeddings per object.
+
+ Args:
+ reference_embeddings: Tensor of shape [height, width, embedding_dim],
+ the embedding vectors for the reference frame.
+ query_embeddings: Tensor of shape [n_query_images, height, width,
+ embedding_dim], the embedding vectors for the query frames.
+ reference_labels: Tensor of shape [height, width, 1], the class labels of
+ the reference frame.
+ max_neighbors_per_object: Integer, the maximum number of candidates
+ for the nearest neighbor query per object after subsampling,
+ or 0 for no subsampling.
+ k_nearest_neighbors: Integer, the number of nearest neighbors to use.
+ gt_ids: Int tensor of shape [n_objs] of the sorted unique ground truth
+ ids in the first frame. If None, it will be derived from
+ reference_labels.
+ n_chunks: Integer, the number of chunks to use to save memory
+ (set to 1 for no chunking).
+
+ Returns:
+ nn_features: A float32 tensor of nearest neighbor features of shape
+ [n_query_images, height, width, n_objects, feature_dim].
+ gt_ids: An int32 tensor of the unique sorted object ids present
+ in the reference labels.
+ """
+ with tf.name_scope('nn_features_per_object'):
+ reference_labels_flat = tf.reshape(reference_labels, [-1])
+ if gt_ids is None:
+ ref_obj_ids, _ = tf.unique(reference_labels_flat)
+ ref_obj_ids = tf.contrib.framework.sort(ref_obj_ids)
+ gt_ids = ref_obj_ids
+ embedding_dim = resolve_shape(reference_embeddings)[-1]
+ reference_embeddings_flat = tf.reshape(reference_embeddings,
+ [-1, embedding_dim])
+
+ reference_embeddings_flat, reference_labels_flat = (
+ subsample_reference_embeddings_and_labels(reference_embeddings_flat,
+ reference_labels_flat,
+ gt_ids,
+ max_neighbors_per_object))
+ shape = resolve_shape(query_embeddings)
+ query_embeddings_flat = tf.reshape(query_embeddings, [-1, embedding_dim])
+ nn_features = _nearest_neighbor_features_per_object_in_chunks(
+ reference_embeddings_flat, query_embeddings_flat, reference_labels_flat,
+ gt_ids, k_nearest_neighbors, n_chunks)
+ nn_features_dim = resolve_shape(nn_features)[-1]
+ nn_features_reshaped = tf.reshape(nn_features,
+ tf.stack(shape[:3] + [tf.size(gt_ids),
+ nn_features_dim]))
+ return nn_features_reshaped, gt_ids
+
+
+def _nearest_neighbor_features_per_object_in_chunks(
+ reference_embeddings_flat, query_embeddings_flat, reference_labels_flat,
+ ref_obj_ids, k_nearest_neighbors, n_chunks):
+ """Calculates the nearest neighbor features per object in chunks to save mem.
+
+ Uses chunking to bound the memory use.
+
+ Args:
+ reference_embeddings_flat: Tensor of shape [n, embedding_dim],
+ the embedding vectors for the reference frame.
+ query_embeddings_flat: Tensor of shape [m, embedding_dim], the embedding
+ vectors for the query frames.
+ reference_labels_flat: Tensor of shape [n], the class labels of the
+ reference frame.
+ ref_obj_ids: int tensor of unique object ids in the reference labels.
+ k_nearest_neighbors: Integer, the number of nearest neighbors to use.
+ n_chunks: Integer, the number of chunks to use to save memory
+ (set to 1 for no chunking).
+
+ Returns:
+ nn_features: A float32 tensor of nearest neighbor features of shape
+ [m, n_objects, feature_dim].
+ """
+ chunk_size = tf.cast(tf.ceil(tf.cast(tf.shape(query_embeddings_flat)[0],
+ tf.float32) / n_chunks), tf.int32)
+ wrong_label_mask = tf.not_equal(reference_labels_flat,
+ ref_obj_ids[:, tf.newaxis])
+ all_features = []
+ for n in range(n_chunks):
+ if n_chunks == 1:
+ query_embeddings_flat_chunk = query_embeddings_flat
+ else:
+ chunk_start = n * chunk_size
+ chunk_end = (n + 1) * chunk_size
+ query_embeddings_flat_chunk = query_embeddings_flat[chunk_start:chunk_end]
+ # Use control dependencies to make sure that the chunks are not processed
+ # in parallel which would prevent any peak memory savings.
+ with tf.control_dependencies(all_features):
+ features = _nn_features_per_object_for_chunk(
+ reference_embeddings_flat, query_embeddings_flat_chunk,
+ wrong_label_mask, k_nearest_neighbors
+ )
+ all_features.append(features)
+ if n_chunks == 1:
+ nn_features = all_features[0]
+ else:
+ nn_features = tf.concat(all_features, axis=0)
+ return nn_features
+
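+# Illustrative arithmetic (not part of the original FEELVOS code): the
+# per-object distance matrix above has m * n float32 entries, so for e.g.
+# m = 128 * 128 query pixels, n = 1024 reference samples and 3 objects this
+# is about 16384 * 1024 * 3 * 4 bytes ~= 200 MB at once; with n_chunks = 100
+# the peak size of this intermediate drops to roughly 2 MB per chunk.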
+
+def _nn_features_per_object_for_chunk(
+ reference_embeddings, query_embeddings, wrong_label_mask,
+ k_nearest_neighbors):
+ """Extracts features for each object using nearest neighbor attention.
+
+ Args:
+    reference_embeddings: Tensor of shape [n, embedding_dim],
+      the embedding vectors for the (possibly subsampled) reference frame.
+    query_embeddings: Tensor of shape [m_chunk, embedding_dim], the embedding
+      vectors for the current chunk of query pixels.
+    wrong_label_mask: Boolean tensor of shape [n_objects, n], True where a
+      reference pixel does not belong to the respective object; used to
+      exclude those pixels from the nearest neighbor search.
+    k_nearest_neighbors: Integer, the number of nearest neighbors to use.
+
+ Returns:
+ nn_features: A float32 tensor of nearest neighbor features of shape
+ [m_chunk, n_objects, feature_dim].
+ """
+ reference_embeddings_key = reference_embeddings
+ query_embeddings_key = query_embeddings
+ dists = flattened_pairwise_distances(reference_embeddings_key,
+ query_embeddings_key)
+ dists = (dists[:, tf.newaxis, :] +
+ tf.cast(wrong_label_mask[tf.newaxis, :, :], tf.float32) *
+ WRONG_LABEL_PADDING_DISTANCE)
+ if k_nearest_neighbors == 1:
+ features = tf.reduce_min(dists, axis=2, keepdims=True)
+ else:
+    # Find the closest k neighbors and combine their distances by averaging.
+ dists, _ = tf.nn.top_k(-dists, k=k_nearest_neighbors)
+ dists = -dists
+ # If not enough real neighbors were found, pad with the farthest real
+ # neighbor.
+ valid_mask = tf.less(dists, WRONG_LABEL_PADDING_DISTANCE)
+ masked_dists = dists * tf.cast(valid_mask, tf.float32)
+ pad_dist = tf.tile(tf.reduce_max(masked_dists, axis=2)[..., tf.newaxis],
+ multiples=[1, 1, k_nearest_neighbors])
+ dists = tf.where(valid_mask, dists, pad_dist)
+ # take mean of distances
+ features = tf.reduce_mean(dists, axis=2, keepdims=True)
+ return features
+
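+# Illustrative sketch (not part of the original FEELVOS code): with
+# k_nearest_neighbors=3 and only two valid reference pixels for an object,
+# the top-k distances might be [0.2, 0.5, 1e20]; the padded entry is then
+# replaced by the farthest valid distance, giving [0.2, 0.5, 0.5] before
+# the mean is taken.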
+
+def create_embedding_segmentation_features(features, feature_dimension,
+ n_layers, kernel_size, reuse,
+ atrous_rates=None):
+ """Extracts features which can be used to estimate the final segmentation.
+
+ Args:
+ features: input features of shape [batch, height, width, features]
+ feature_dimension: Integer, the dimensionality used in the segmentation
+ head layers.
+ n_layers: Integer, the number of layers in the segmentation head.
+ kernel_size: Integer, the kernel size used in the segmentation head.
+ reuse: reuse mode for the variable_scope.
+ atrous_rates: List of integers of length n_layers, the atrous rates to use.
+
+ Returns:
+ Features to be used to estimate the segmentation labels of shape
+ [batch, height, width, embedding_seg_feat_dim].
+ """
+ if atrous_rates is None or not atrous_rates:
+ atrous_rates = [1 for _ in range(n_layers)]
+ assert len(atrous_rates) == n_layers
+ with tf.variable_scope('embedding_seg', reuse=reuse):
+ for n in range(n_layers):
+ features = model.split_separable_conv2d(
+ features, feature_dimension, kernel_size=kernel_size,
+ rate=atrous_rates[n], scope='split_separable_conv2d_{}'.format(n))
+ return features
+
+
+def add_image_summaries(images, nn_features, logits, batch_size,
+ prev_frame_nn_features=None):
+ """Adds image summaries of input images, attention features and logits.
+
+ Args:
+ images: Image tensor of shape [batch, height, width, channels].
+ nn_features: Nearest neighbor attention features of shape
+ [batch_size, height, width, n_objects, 1].
+ logits: Float32 tensor of logits.
+ batch_size: Integer, the number of videos per clone per mini-batch.
+ prev_frame_nn_features: Nearest neighbor attention features wrt. the
+ last frame of shape [batch_size, height, width, n_objects, 1].
+ Can be None.
+ """
+ # Separate reference and query images.
+ reshaped_images = tf.reshape(images, tf.stack(
+ [batch_size, -1] + resolve_shape(images)[1:]))
+ reference_images = reshaped_images[:, 0]
+ query_images = reshaped_images[:, 1:]
+ query_images_reshaped = tf.reshape(query_images, tf.stack(
+ [-1] + resolve_shape(images)[1:]))
+ tf.summary.image('ref_images', reference_images, max_outputs=batch_size)
+ tf.summary.image('query_images', query_images_reshaped, max_outputs=10)
+ predictions = tf.cast(
+ tf.argmax(logits, axis=-1), tf.uint8)[..., tf.newaxis]
+ # Scale up so that we can actually see something.
+ tf.summary.image('predictions', predictions * 32, max_outputs=10)
+ # We currently only show the first dimension of the features for background
+ # and the first foreground object.
+ tf.summary.image('nn_fg_features', nn_features[..., 0:1, 0],
+ max_outputs=batch_size)
+ if prev_frame_nn_features is not None:
+ tf.summary.image('nn_fg_features_prev', prev_frame_nn_features[..., 0:1, 0],
+ max_outputs=batch_size)
+ tf.summary.image('nn_bg_features', nn_features[..., 1:2, 0],
+ max_outputs=batch_size)
+ if prev_frame_nn_features is not None:
+ tf.summary.image('nn_bg_features_prev',
+ prev_frame_nn_features[..., 1:2, 0],
+ max_outputs=batch_size)
+
+
+def get_embeddings(images, model_options, embedding_dimension):
+ """Extracts embedding vectors for images. Should only be used for inference.
+
+ Args:
+ images: A tensor of shape [batch, height, width, channels].
+ model_options: A ModelOptions instance to configure models.
+ embedding_dimension: Integer, the dimension of the embedding.
+
+ Returns:
+ embeddings: A tensor of shape [batch, height, width, embedding_dimension].
+ """
+ features, end_points = model.extract_features(
+ images,
+ model_options,
+ is_training=False)
+
+ if model_options.decoder_output_stride is not None:
+ decoder_output_stride = min(model_options.decoder_output_stride)
+ if model_options.crop_size is None:
+ height = tf.shape(images)[1]
+ width = tf.shape(images)[2]
+ else:
+ height, width = model_options.crop_size
+ features = model.refine_by_decoder(
+ features,
+ end_points,
+ crop_size=[height, width],
+ decoder_output_stride=[decoder_output_stride],
+ decoder_use_separable_conv=model_options.decoder_use_separable_conv,
+ model_variant=model_options.model_variant,
+ is_training=False)
+
+ with tf.variable_scope('embedding'):
+ embeddings = split_separable_conv2d_with_identity_initializer(
+ features, embedding_dimension, scope='split_separable_conv2d')
+ return embeddings
+
+
+def get_logits_with_matching(images,
+ model_options,
+ weight_decay=0.0001,
+ reuse=None,
+ is_training=False,
+ fine_tune_batch_norm=False,
+ reference_labels=None,
+ batch_size=None,
+ num_frames_per_video=None,
+ embedding_dimension=None,
+ max_neighbors_per_object=0,
+ k_nearest_neighbors=1,
+ use_softmax_feedback=True,
+ initial_softmax_feedback=None,
+ embedding_seg_feature_dimension=256,
+ embedding_seg_n_layers=4,
+ embedding_seg_kernel_size=7,
+ embedding_seg_atrous_rates=None,
+ normalize_nearest_neighbor_distances=True,
+ also_attend_to_previous_frame=True,
+ damage_initial_previous_frame_mask=False,
+ use_local_previous_frame_attention=True,
+ previous_frame_attention_window_size=15,
+ use_first_frame_matching=True,
+ also_return_embeddings=False,
+ ref_embeddings=None):
+ """Gets the logits by atrous/image spatial pyramid pooling using attention.
+
+ Args:
+ images: A tensor of size [batch, height, width, channels].
+ model_options: A ModelOptions instance to configure models.
+ weight_decay: The weight decay for model variables.
+ reuse: Reuse the model variables or not.
+ is_training: Is training or not.
+ fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
+ reference_labels: The segmentation labels of the reference frame on which
+ attention is applied.
+ batch_size: Integer, the number of videos on a batch
+ num_frames_per_video: Integer, the number of frames per video
+ embedding_dimension: Integer, the dimension of the embedding
+ max_neighbors_per_object: Integer, the maximum number of candidates
+ for the nearest neighbor query per object after subsampling.
+ Can be 0 for no subsampling.
+ k_nearest_neighbors: Integer, the number of nearest neighbors to use.
+ use_softmax_feedback: Boolean, whether to give the softmax predictions of
+ the last frame as additional input to the segmentation head.
+ initial_softmax_feedback: List of Float32 tensors, or None. Can be used to
+ initialize the softmax predictions used for the feedback loop.
+ Only has an effect if use_softmax_feedback is True.
+ embedding_seg_feature_dimension: Integer, the dimensionality used in the
+ segmentation head layers.
+ embedding_seg_n_layers: Integer, the number of layers in the segmentation
+ head.
+ embedding_seg_kernel_size: Integer, the kernel size used in the
+ segmentation head.
+ embedding_seg_atrous_rates: List of integers of length
+ embedding_seg_n_layers, the atrous rates to use for the segmentation head.
+ normalize_nearest_neighbor_distances: Boolean, whether to normalize the
+ nearest neighbor distances to [0,1] using sigmoid, scale and shift.
+ also_attend_to_previous_frame: Boolean, whether to also use nearest
+ neighbor attention with respect to the previous frame.
+ damage_initial_previous_frame_mask: Boolean, whether to artificially damage
+ the initial previous frame mask. Only has an effect if
+ also_attend_to_previous_frame is True.
+ use_local_previous_frame_attention: Boolean, whether to restrict the
+ previous frame attention to a local search window.
+ Only has an effect, if also_attend_to_previous_frame is True.
+ previous_frame_attention_window_size: Integer, the window size used for
+ local previous frame attention, if use_local_previous_frame_attention
+ is True.
+ use_first_frame_matching: Boolean, whether to extract features by matching
+ to the reference frame. This should always be true except for ablation
+ experiments.
+ also_return_embeddings: Boolean, whether to return the embeddings as well.
+ ref_embeddings: Tuple of
+ (first_frame_embeddings, previous_frame_embeddings),
+ each of shape [batch, height, width, embedding_dimension], or None.
+ Returns:
+ outputs_to_logits: A map from output_type to logits.
+ If also_return_embeddings is True, it will also return an embeddings
+ tensor of shape [batch, height, width, embedding_dimension].
+ """
+ features, end_points = model.extract_features(
+ images,
+ model_options,
+ weight_decay=weight_decay,
+ reuse=reuse,
+ is_training=is_training,
+ fine_tune_batch_norm=fine_tune_batch_norm)
+
+ if model_options.decoder_output_stride:
+ decoder_output_stride = min(model_options.decoder_output_stride)
+ if model_options.crop_size is None:
+ height = tf.shape(images)[1]
+ width = tf.shape(images)[2]
+ else:
+ height, width = model_options.crop_size
+ decoder_height = model.scale_dimension(height, 1.0 / decoder_output_stride)
+ decoder_width = model.scale_dimension(width, 1.0 / decoder_output_stride)
+ features = model.refine_by_decoder(
+ features,
+ end_points,
+ crop_size=[height, width],
+ decoder_output_stride=[decoder_output_stride],
+ decoder_use_separable_conv=model_options.decoder_use_separable_conv,
+ model_variant=model_options.model_variant,
+ weight_decay=weight_decay,
+ reuse=reuse,
+ is_training=is_training,
+ fine_tune_batch_norm=fine_tune_batch_norm)
+
+ with tf.variable_scope('embedding', reuse=reuse):
+ embeddings = split_separable_conv2d_with_identity_initializer(
+ features, embedding_dimension, scope='split_separable_conv2d')
+ embeddings = tf.identity(embeddings, name='embeddings')
+ scaled_reference_labels = tf.image.resize_nearest_neighbor(
+ reference_labels,
+ resolve_shape(embeddings, 4)[1:3],
+ align_corners=True)
+ h, w = decoder_height, decoder_width
+ if num_frames_per_video is None:
+ num_frames_per_video = tf.size(embeddings) // (
+ batch_size * h * w * embedding_dimension)
+ new_labels_shape = tf.stack([batch_size, -1, h, w, 1])
+ reshaped_reference_labels = tf.reshape(scaled_reference_labels,
+ new_labels_shape)
+ new_embeddings_shape = tf.stack([batch_size,
+ num_frames_per_video, h, w,
+ embedding_dimension])
+ reshaped_embeddings = tf.reshape(embeddings, new_embeddings_shape)
+ all_nn_features = []
+ all_ref_obj_ids = []
+  # To keep things simple, we handle each sequence separately for now.
+ for n in range(batch_size):
+ embedding = reshaped_embeddings[n]
+ if ref_embeddings is None:
+ n_chunks = 100
+ reference_embedding = embedding[0]
+ if also_attend_to_previous_frame or use_softmax_feedback:
+ queries_embedding = embedding[2:]
+ else:
+ queries_embedding = embedding[1:]
+ else:
+ if USE_CORRELATION_COST:
+ n_chunks = 20
+ else:
+ n_chunks = 500
+ reference_embedding = ref_embeddings[0][n]
+ queries_embedding = embedding
+ reference_labels = reshaped_reference_labels[n][0]
+ nn_features_n, ref_obj_ids = nearest_neighbor_features_per_object(
+ reference_embedding, queries_embedding, reference_labels,
+ max_neighbors_per_object, k_nearest_neighbors, n_chunks=n_chunks)
+ if normalize_nearest_neighbor_distances:
+ nn_features_n = (tf.nn.sigmoid(nn_features_n) - 0.5) * 2
+ all_nn_features.append(nn_features_n)
+ all_ref_obj_ids.append(ref_obj_ids)
+
+ feat_dim = resolve_shape(features)[-1]
+ features = tf.reshape(features, tf.stack(
+ [batch_size, num_frames_per_video, h, w, feat_dim]))
+ if ref_embeddings is None:
+ # Strip the features for the reference frame.
+ if also_attend_to_previous_frame or use_softmax_feedback:
+ features = features[:, 2:]
+ else:
+ features = features[:, 1:]
+
+  # To keep things simple, we handle each sequence separately for now.
+ outputs_to_logits = {output: [] for
+ output in model_options.outputs_to_num_classes}
+ for n in range(batch_size):
+ features_n = features[n]
+ nn_features_n = all_nn_features[n]
+ nn_features_n_tr = tf.transpose(nn_features_n, [3, 0, 1, 2, 4])
+ n_objs = tf.shape(nn_features_n_tr)[0]
+ # Repeat features for every object.
+ features_n_tiled = tf.tile(features_n[tf.newaxis],
+ multiples=[n_objs, 1, 1, 1, 1])
+ prev_frame_labels = None
+ if also_attend_to_previous_frame:
+ prev_frame_labels = reshaped_reference_labels[n, 1]
+ if is_training and damage_initial_previous_frame_mask:
+ # Damage the previous frame masks.
+ prev_frame_labels = mask_damaging.damage_masks(prev_frame_labels,
+ dilate=False)
+ tf.summary.image('prev_frame_labels',
+ tf.cast(prev_frame_labels[tf.newaxis],
+ tf.uint8) * 32)
+ initial_softmax_feedback_n = create_initial_softmax_from_labels(
+ prev_frame_labels, reshaped_reference_labels[n][0],
+ decoder_output_stride=None, reduce_labels=True)
+ elif initial_softmax_feedback is not None:
+ initial_softmax_feedback_n = initial_softmax_feedback[n]
+ else:
+ initial_softmax_feedback_n = None
+ if initial_softmax_feedback_n is None:
+ last_softmax = tf.zeros((n_objs, h, w, 1), dtype=tf.float32)
+ else:
+ last_softmax = tf.transpose(initial_softmax_feedback_n, [2, 0, 1])[
+ ..., tf.newaxis]
+ assert len(model_options.outputs_to_num_classes) == 1
+    output = list(model_options.outputs_to_num_classes.keys())[0]
+ logits = []
+ n_ref_frames = 1
+ prev_frame_nn_features_n = None
+ if also_attend_to_previous_frame or use_softmax_feedback:
+ n_ref_frames += 1
+ if ref_embeddings is not None:
+ n_ref_frames = 0
+ for t in range(num_frames_per_video - n_ref_frames):
+ to_concat = [features_n_tiled[:, t]]
+ if use_first_frame_matching:
+ to_concat.append(nn_features_n_tr[:, t])
+ if use_softmax_feedback:
+ to_concat.append(last_softmax)
+ if also_attend_to_previous_frame:
+ assert normalize_nearest_neighbor_distances, (
+ 'previous frame attention currently only works when normalized '
+ 'distances are used')
+ embedding = reshaped_embeddings[n]
+ if ref_embeddings is None:
+ last_frame_embedding = embedding[t + 1]
+ query_embeddings = embedding[t + 2, tf.newaxis]
+ else:
+ last_frame_embedding = ref_embeddings[1][0]
+ query_embeddings = embedding
+ if use_local_previous_frame_attention:
+ assert query_embeddings.shape[0] == 1
+ prev_frame_nn_features_n = (
+ local_previous_frame_nearest_neighbor_features_per_object(
+ last_frame_embedding,
+ query_embeddings[0],
+ prev_frame_labels,
+ all_ref_obj_ids[n],
+ max_distance=previous_frame_attention_window_size)
+ )
+ else:
+ prev_frame_nn_features_n, _ = (
+ nearest_neighbor_features_per_object(
+ last_frame_embedding, query_embeddings, prev_frame_labels,
+ max_neighbors_per_object, k_nearest_neighbors,
+ gt_ids=all_ref_obj_ids[n]))
+ prev_frame_nn_features_n = (tf.nn.sigmoid(
+ prev_frame_nn_features_n) - 0.5) * 2
+ prev_frame_nn_features_n_sq = tf.squeeze(prev_frame_nn_features_n,
+ axis=0)
+ prev_frame_nn_features_n_tr = tf.transpose(
+ prev_frame_nn_features_n_sq, [2, 0, 1, 3])
+ to_concat.append(prev_frame_nn_features_n_tr)
+ features_n_concat_t = tf.concat(to_concat, axis=-1)
+ embedding_seg_features_n_t = (
+ create_embedding_segmentation_features(
+ features_n_concat_t, embedding_seg_feature_dimension,
+ embedding_seg_n_layers, embedding_seg_kernel_size,
+ reuse or n > 0, atrous_rates=embedding_seg_atrous_rates))
+ logits_t = model.get_branch_logits(
+ embedding_seg_features_n_t,
+ 1,
+ model_options.atrous_rates,
+ aspp_with_batch_norm=model_options.aspp_with_batch_norm,
+ kernel_size=model_options.logits_kernel_size,
+ weight_decay=weight_decay,
+ reuse=reuse or n > 0 or t > 0,
+ scope_suffix=output
+ )
+ logits.append(logits_t)
+ prev_frame_labels = tf.transpose(tf.argmax(logits_t, axis=0),
+ [2, 0, 1])
+ last_softmax = tf.nn.softmax(logits_t, axis=0)
+ logits = tf.stack(logits, axis=1)
+ logits_shape = tf.stack(
+ [n_objs, num_frames_per_video - n_ref_frames] +
+ resolve_shape(logits)[2:-1])
+ logits_reshaped = tf.reshape(logits, logits_shape)
+ logits_transposed = tf.transpose(logits_reshaped, [1, 2, 3, 0])
+ outputs_to_logits[output].append(logits_transposed)
+
+ add_image_summaries(
+ images[n * num_frames_per_video: (n+1) * num_frames_per_video],
+ nn_features_n,
+ logits_transposed,
+ batch_size=1,
+ prev_frame_nn_features=prev_frame_nn_features_n)
+ if also_return_embeddings:
+ return outputs_to_logits, embeddings
+ else:
+ return outputs_to_logits
+
+
+def subsample_reference_embeddings_and_labels(
+ reference_embeddings_flat, reference_labels_flat, ref_obj_ids,
+ max_neighbors_per_object):
+ """Subsamples the reference embedding vectors and labels.
+
+ After subsampling, at most max_neighbors_per_object items will remain per
+ class.
+
+ Args:
+ reference_embeddings_flat: Tensor of shape [n, embedding_dim],
+ the embedding vectors for the reference frame.
+ reference_labels_flat: Tensor of shape [n, 1],
+ the class labels of the reference frame.
+ ref_obj_ids: An int32 tensor of the unique object ids present
+ in the reference labels.
+ max_neighbors_per_object: Integer, the maximum number of candidates
+ for the nearest neighbor query per object after subsampling,
+ or 0 for no subsampling.
+
+ Returns:
+ reference_embeddings_flat: Tensor of shape [n_sub, embedding_dim],
+ the subsampled embedding vectors for the reference frame.
+ reference_labels_flat: Tensor of shape [n_sub, 1],
+ the class labels of the reference frame.
+ """
+ if max_neighbors_per_object == 0:
+ return reference_embeddings_flat, reference_labels_flat
+ same_label_mask = tf.equal(reference_labels_flat[tf.newaxis, :],
+ ref_obj_ids[:, tf.newaxis])
+ max_neighbors_per_object_repeated = tf.tile(
+ tf.constant(max_neighbors_per_object)[tf.newaxis],
+ multiples=[tf.size(ref_obj_ids)])
+ # Somehow map_fn on GPU caused trouble sometimes, so let's put it on CPU
+ # for now.
+ with tf.device('cpu:0'):
+ subsampled_indices = tf.map_fn(_create_subsampling_mask,
+ (same_label_mask,
+ max_neighbors_per_object_repeated),
+ dtype=tf.int64,
+ name='subsample_labels_map_fn',
+ parallel_iterations=1)
+ mask = tf.not_equal(subsampled_indices, tf.constant(-1, dtype=tf.int64))
+ masked_indices = tf.boolean_mask(subsampled_indices, mask)
+ reference_embeddings_flat = tf.gather(reference_embeddings_flat,
+ masked_indices)
+ reference_labels_flat = tf.gather(reference_labels_flat, masked_indices)
+ return reference_embeddings_flat, reference_labels_flat
+
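+# Illustrative sketch (not part of the original FEELVOS code): with
+# max_neighbors_per_object=1024, an object covering 50000 reference pixels
+# is reduced to 1024 randomly chosen pixels, while an object covering only
+# 300 pixels keeps all 300; the -1 padding entries are removed again by the
+# boolean mask in subsample_reference_embeddings_and_labels above.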
+
+def _create_subsampling_mask(args):
+  """Creates subsampling indices which can be used to subsample the labels.
+
+  Args:
+    args: Tuple of (label_mask, max_neighbors_per_object), where label_mask
+      is the boolean mask to be subsampled and max_neighbors_per_object is an
+      int scalar, the maximum number of neighbors to be retained after
+      subsampling.
+
+  Returns:
+    Int64 tensor of subsampled indices for one object, padded with -1 up to
+    length max_neighbors_per_object.
+ """
+ label_mask, max_neighbors_per_object = args
+ indices = tf.squeeze(tf.where(label_mask), axis=1)
+ shuffled_indices = tf.random_shuffle(indices)
+ subsampled_indices = shuffled_indices[:max_neighbors_per_object]
+ n_pad = max_neighbors_per_object - tf.size(subsampled_indices)
+ padded_label = -1
+ padding = tf.fill((n_pad,), tf.constant(padded_label, dtype=tf.int64))
+ padded = tf.concat([subsampled_indices, padding], axis=0)
+ return padded
+
+
+def conv2d_identity_initializer(scale=1.0, mean=0, stddev=3e-2):
+ """Creates an identity initializer for TensorFlow conv2d.
+
+ We add a small amount of normal noise to the initialization matrix.
+ Code copied from lcchen@.
+
+ Args:
+ scale: The scale coefficient for the identity weight matrix.
+ mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
+ truncated normal distribution.
+ stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
+ of the truncated normal distribution.
+
+ Returns:
+ An identity initializer function for TensorFlow conv2d.
+ """
+ def _initializer(shape, dtype=tf.float32, partition_info=None):
+ """Returns the identity matrix scaled by `scale`.
+
+ Args:
+ shape: A tuple of int32 numbers indicating the shape of the initializing
+ matrix.
+ dtype: The data type of the initializing matrix.
+ partition_info: (Optional) variable_scope._PartitionInfo object holding
+ additional information about how the variable is partitioned. This input
+ is not used in our case, but is required by TensorFlow.
+
+ Returns:
+      An identity matrix.
+
+    Raises:
+      ValueError: If len(shape) != 4, or shape[0] != shape[1], or shape[0] is
+        not odd, or shape[1] is not odd.
+ """
+ del partition_info
+ if len(shape) != 4:
+ raise ValueError('Expect shape length to be 4.')
+ if shape[0] != shape[1]:
+ raise ValueError('Expect shape[0] = shape[1].')
+ if shape[0] % 2 != 1:
+ raise ValueError('Expect shape[0] to be odd value.')
+ if shape[1] % 2 != 1:
+ raise ValueError('Expect shape[1] to be odd value.')
+ weights = np.zeros(shape, dtype=np.float32)
+    # Use integer division so the result can be used as an index (Python 3).
+    center_y = shape[0] // 2
+    center_x = shape[1] // 2
+ min_channel = min(shape[2], shape[3])
+ for i in range(min_channel):
+ weights[center_y, center_x, i, i] = scale
+ return tf.constant(weights, dtype=dtype) + tf.truncated_normal(
+ shape, mean=mean, stddev=stddev, dtype=dtype)
+
+ return _initializer
+
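+# Illustrative sketch (not part of the original FEELVOS code): for a 3x3 conv
+# with 4 input and 4 output channels, the returned initializer builds a
+# (3, 3, 4, 4) kernel with `scale` at the spatial center for matching channel
+# pairs (weights[1, 1, i, i] = scale) and zeros elsewhere, plus small
+# truncated normal noise, so the initialized convolution starts out close to
+# an identity mapping of its input features.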
+
+def split_separable_conv2d_with_identity_initializer(
+ inputs,
+ filters,
+ kernel_size=3,
+ rate=1,
+ weight_decay=0.00004,
+ scope=None):
+ """Splits a separable conv2d into depthwise and pointwise conv2d.
+
+  This operation differs from `tf.layers.separable_conv2d` in that it applies
+  an activation function between the depthwise and pointwise conv2d.
+
+ Args:
+ inputs: Input tensor with shape [batch, height, width, channels].
+ filters: Number of filters in the 1x1 pointwise convolution.
+    kernel_size: A list of length 2: [kernel_height, kernel_width] of the
+      filters. Can be an int if both values are the same.
+ rate: Atrous convolution rate for the depthwise convolution.
+ weight_decay: The weight decay to use for regularizing the model.
+ scope: Optional scope for the operation.
+
+ Returns:
+ Computed features after split separable conv2d.
+ """
+ initializer = conv2d_identity_initializer()
+ outputs = slim.separable_conv2d(
+ inputs,
+ None,
+ kernel_size=kernel_size,
+ depth_multiplier=1,
+ rate=rate,
+ weights_initializer=initializer,
+ weights_regularizer=None,
+ scope=scope + '_depthwise')
+ return slim.conv2d(
+ outputs,
+ filters,
+ 1,
+ weights_initializer=initializer,
+ weights_regularizer=slim.l2_regularizer(weight_decay),
+ scope=scope + '_pointwise')
+
+
+def create_initial_softmax_from_labels(last_frame_labels, reference_labels,
+ decoder_output_stride, reduce_labels):
+ """Creates initial softmax predictions from last frame labels.
+
+ Args:
+ last_frame_labels: last frame labels of shape [1, height, width, 1].
+ reference_labels: reference frame labels of shape [1, height, width, 1].
+ decoder_output_stride: Integer, the stride of the decoder. Can be None, in
+ this case it's assumed that the last_frame_labels and reference_labels
+ are already scaled to the decoder output resolution.
+ reduce_labels: Boolean, whether to reduce the depth of the softmax one_hot
+ encoding to the actual number of labels present in the reference frame
+ (otherwise the depth will be the highest label index + 1).
+
+ Returns:
+ init_softmax: the initial softmax predictions.
+ """
+ if decoder_output_stride is None:
+ labels_output_size = last_frame_labels
+ reference_labels_output_size = reference_labels
+ else:
+ h = tf.shape(last_frame_labels)[1]
+ w = tf.shape(last_frame_labels)[2]
+ h_sub = model.scale_dimension(h, 1.0 / decoder_output_stride)
+ w_sub = model.scale_dimension(w, 1.0 / decoder_output_stride)
+ labels_output_size = tf.image.resize_nearest_neighbor(
+ last_frame_labels, [h_sub, w_sub], align_corners=True)
+ reference_labels_output_size = tf.image.resize_nearest_neighbor(
+ reference_labels, [h_sub, w_sub], align_corners=True)
+ if reduce_labels:
+ unique_labels, _ = tf.unique(tf.reshape(reference_labels_output_size, [-1]))
+ depth = tf.size(unique_labels)
+ else:
+ depth = tf.reduce_max(reference_labels_output_size) + 1
+ one_hot_assertion = tf.assert_less(tf.reduce_max(labels_output_size), depth)
+ with tf.control_dependencies([one_hot_assertion]):
+ init_softmax = tf.one_hot(tf.squeeze(labels_output_size,
+ axis=-1),
+ depth=depth,
+ dtype=tf.float32)
+ return init_softmax
+
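+# Illustrative sketch (not part of the original FEELVOS code): with
+# reduce_labels=True and a reference frame containing the labels {0, 1, 2},
+# the result has shape [1, height, width, 3] and a pixel labelled 2 in
+# last_frame_labels becomes the one-hot "softmax" vector [0.0, 0.0, 1.0].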
+
+def local_previous_frame_nearest_neighbor_features_per_object(
+ prev_frame_embedding, query_embedding, prev_frame_labels,
+ gt_ids, max_distance=9):
+ """Computes nearest neighbor features while only allowing local matches.
+
+ Args:
+ prev_frame_embedding: Tensor of shape [height, width, embedding_dim],
+ the embedding vectors for the last frame.
+ query_embedding: Tensor of shape [height, width, embedding_dim],
+ the embedding vectors for the query frames.
+ prev_frame_labels: Tensor of shape [height, width, 1], the class labels of
+ the previous frame.
+ gt_ids: Int Tensor of shape [n_objs] of the sorted unique ground truth
+ ids in the first frame.
+ max_distance: Integer, the maximum distance allowed for local matching.
+
+ Returns:
+ nn_features: A float32 np.array of nearest neighbor features of shape
+ [1, height, width, n_objects, 1].
+ """
+ with tf.name_scope(
+ 'local_previous_frame_nearest_neighbor_features_per_object'):
+ if USE_CORRELATION_COST:
+ tf.logging.info('Using correlation_cost.')
+ d = local_pairwise_distances(query_embedding, prev_frame_embedding,
+ max_distance=max_distance)
+ else:
+ # Slow fallback in case correlation_cost is not available.
+ tf.logging.warn('correlation cost is not available, using slow fallback '
+ 'implementation.')
+ d = local_pairwise_distances2(query_embedding, prev_frame_embedding,
+ max_distance=max_distance)
+ d = (tf.nn.sigmoid(d) - 0.5) * 2
+ height = tf.shape(prev_frame_embedding)[0]
+ width = tf.shape(prev_frame_embedding)[1]
+
+ # Create offset versions of the mask.
+ if USE_CORRELATION_COST:
+ # New, faster code with cross-correlation via correlation_cost.
+ # Due to padding we have to add 1 to the labels.
+ offset_labels = correlation_cost_op.correlation_cost(
+ tf.ones((1, height, width, 1)),
+ tf.cast(prev_frame_labels + 1, tf.float32)[tf.newaxis],
+ kernel_size=1,
+ max_displacement=max_distance, stride_1=1, stride_2=1,
+ pad=max_distance)
+ offset_labels = tf.squeeze(offset_labels, axis=0)[..., tf.newaxis]
+ # Subtract the 1 again and round.
+ offset_labels = tf.round(offset_labels - 1)
+ offset_masks = tf.equal(
+ offset_labels,
+ tf.cast(gt_ids, tf.float32)[tf.newaxis, tf.newaxis, tf.newaxis, :])
+ else:
+ # Slower code, without dependency to correlation_cost
+ masks = tf.equal(prev_frame_labels, gt_ids[tf.newaxis, tf.newaxis])
+ padded_masks = tf.pad(masks,
+ [[max_distance, max_distance],
+ [max_distance, max_distance],
+ [0, 0]])
+ offset_masks = []
+ for y_start in range(2 * max_distance + 1):
+ y_end = y_start + height
+ masks_slice = padded_masks[y_start:y_end]
+ for x_start in range(2 * max_distance + 1):
+ x_end = x_start + width
+ offset_mask = masks_slice[:, x_start:x_end]
+ offset_masks.append(offset_mask)
+ offset_masks = tf.stack(offset_masks, axis=2)
+
+ pad = tf.ones((height, width, (2 * max_distance + 1) ** 2, tf.size(gt_ids)))
+ d_tiled = tf.tile(d[..., tf.newaxis], multiples=(1, 1, 1, tf.size(gt_ids)))
+ d_masked = tf.where(offset_masks, d_tiled, pad)
+ dists = tf.reduce_min(d_masked, axis=2)
+ dists = tf.reshape(dists, (1, height, width, tf.size(gt_ids), 1))
+ return dists
diff --git a/models/research/feelvos/utils/embedding_utils_test.py b/models/research/feelvos/utils/embedding_utils_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddebd7b4e7fcc9402887ebf59d247fea815d6cda
--- /dev/null
+++ b/models/research/feelvos/utils/embedding_utils_test.py
@@ -0,0 +1,213 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for embedding utils."""
+
+import unittest
+import numpy as np
+import tensorflow as tf
+from feelvos.utils import embedding_utils
+
+if embedding_utils.USE_CORRELATION_COST:
+ # pylint: disable=g-import-not-at-top
+ from correlation_cost.python.ops import correlation_cost_op
+
+
+class EmbeddingUtilsTest(tf.test.TestCase):
+
+ def test_pairwise_distances(self):
+ x = np.arange(100, dtype=np.float32).reshape(20, 5)
+ y = np.arange(100, 200, dtype=np.float32).reshape(20, 5)
+ g = tf.Graph()
+ with g.as_default():
+ with self.test_session(graph=g) as sess:
+ x = tf.constant(x)
+ y = tf.constant(y)
+ d1 = embedding_utils.pairwise_distances(x, y)
+ d2 = embedding_utils.pairwise_distances2(x, y)
+ d1_val, d2_val = sess.run([d1, d2])
+ self.assertAllClose(d1_val, d2_val)
+
+ @unittest.skipIf(not embedding_utils.USE_CORRELATION_COST,
+ 'depends on correlation_cost')
+ def test_correlation_cost_one_dimensional(self):
+ a = np.array([[[[1.0], [2.0]], [[3.0], [4.0]]]])
+ b = np.array([[[[2.0], [1.0]], [[4.0], [3.0]]]])
+ g = tf.Graph()
+ with g.as_default():
+ with self.test_session(graph=g) as sess:
+ c = correlation_cost_op.correlation_cost(
+ a, b, kernel_size=1, max_displacement=1, stride_1=1, stride_2=1,
+ pad=1)
+ c = tf.squeeze(c, axis=0)
+ c_val = sess.run(c)
+ self.assertAllEqual(c_val.shape, (2, 2, 9))
+ for y in range(2):
+ for x in range(2):
+ for dy in range(-1, 2):
+ for dx in range(-1, 2):
+ a_slice = a[0, y, x, 0]
+ if y + dy < 0 or y + dy > 1 or x + dx < 0 or x + dx > 1:
+ b_slice = 0
+ else:
+ b_slice = b[0, y + dy, x + dx, 0]
+ expected = a_slice * b_slice
+ dy0 = dy + 1
+ dx0 = dx + 1
+ self.assertAlmostEqual(c_val[y, x, 3 * dy0 + dx0], expected)
+
+ @unittest.skipIf(not embedding_utils.USE_CORRELATION_COST,
+ 'depends on correlation_cost')
+ def test_correlation_cost_two_dimensional(self):
+ a = np.array([[[[1.0, -5.0], [7.0, 2.0]], [[1.0, 3.0], [3.0, 4.0]]]])
+ b = np.array([[[[2.0, 1.0], [0.0, -9.0]], [[4.0, 3.0], [3.0, 1.0]]]])
+ g = tf.Graph()
+ with g.as_default():
+ with self.test_session(graph=g) as sess:
+ c = correlation_cost_op.correlation_cost(
+ a, b, kernel_size=1, max_displacement=1, stride_1=1, stride_2=1,
+ pad=1)
+ c = tf.squeeze(c, axis=0)
+ c_val = sess.run(c)
+ self.assertAllEqual(c_val.shape, (2, 2, 9))
+ for y in range(2):
+ for x in range(2):
+ for dy in range(-1, 2):
+ for dx in range(-1, 2):
+ a_slice = a[0, y, x, :]
+ if y + dy < 0 or y + dy > 1 or x + dx < 0 or x + dx > 1:
+ b_slice = 0
+ else:
+ b_slice = b[0, y + dy, x + dx, :]
+ expected = (a_slice * b_slice).mean()
+ dy0 = dy + 1
+ dx0 = dx + 1
+ self.assertAlmostEqual(c_val[y, x, 3 * dy0 + dx0], expected)
+
+ @unittest.skipIf(not embedding_utils.USE_CORRELATION_COST,
+ 'depends on correlation_cost')
+ def test_local_pairwise_distances_one_dimensional(self):
+ a = np.array([[[1.0], [2.0]], [[3.0], [4.0]]])
+ b = np.array([[[2.0], [1.0]], [[4.0], [3.0]]])
+ g = tf.Graph()
+ with g.as_default():
+ with self.test_session(graph=g) as sess:
+ a_tf = tf.constant(a, dtype=tf.float32)
+ b_tf = tf.constant(b, dtype=tf.float32)
+ d = embedding_utils.local_pairwise_distances(a_tf, b_tf,
+ max_distance=1)
+ d_val = sess.run(d)
+ for y in range(2):
+ for x in range(2):
+ for dy in range(-1, 2):
+ for dx in range(-1, 2):
+ a_slice = a[y, x, 0]
+ if y + dy < 0 or y + dy > 1 or x + dx < 0 or x + dx > 1:
+ expected = np.float('inf')
+ else:
+ b_slice = b[y + dy, x + dx, 0]
+ expected = (a_slice - b_slice) ** 2
+ dy0 = dy + 1
+ dx0 = dx + 1
+ self.assertAlmostEqual(d_val[y, x, 3 * dy0 + dx0], expected)
+
+ @unittest.skipIf(not embedding_utils.USE_CORRELATION_COST,
+ 'depends on correlation_cost')
+ def test_local_pairwise_distances_shape(self):
+ a = np.zeros((4, 5, 2))
+ b = np.zeros((4, 5, 2))
+ g = tf.Graph()
+ with g.as_default():
+ with self.test_session(graph=g) as sess:
+ a_tf = tf.constant(a, dtype=tf.float32)
+ b_tf = tf.constant(b, dtype=tf.float32)
+ d = embedding_utils.local_pairwise_distances(a_tf, b_tf, max_distance=4)
+ d_val = sess.run(d)
+ self.assertAllEqual(d_val.shape, (4, 5, 81))
+
+ @unittest.skipIf(not embedding_utils.USE_CORRELATION_COST,
+ 'depends on correlation_cost')
+ def test_local_pairwise_distances_two_dimensional(self):
+ a = np.array([[[1.0, -5.0], [7.0, 2.0]], [[1.0, 3.0], [3.0, 4.0]]])
+ b = np.array([[[2.0, 1.0], [0.0, -9.0]], [[4.0, 3.0], [3.0, 1.0]]])
+ g = tf.Graph()
+ with g.as_default():
+ with self.test_session(graph=g) as sess:
+ a_tf = tf.constant(a, dtype=tf.float32)
+ b_tf = tf.constant(b, dtype=tf.float32)
+ d = embedding_utils.local_pairwise_distances(a_tf, b_tf,
+ max_distance=1)
+ d_val = sess.run(d)
+ for y in range(2):
+ for x in range(2):
+ for dy in range(-1, 2):
+ for dx in range(-1, 2):
+ a_slice = a[y, x, :]
+ if y + dy < 0 or y + dy > 1 or x + dx < 0 or x + dx > 1:
+ expected = np.float('inf')
+ else:
+ b_slice = b[y + dy, x + dx, :]
+ expected = ((a_slice - b_slice) ** 2).sum()
+ dy0 = dy + 1
+ dx0 = dx + 1
+ self.assertAlmostEqual(d_val[y, x, 3 * dy0 + dx0], expected)
+
+ @unittest.skipIf(not embedding_utils.USE_CORRELATION_COST,
+ 'depends on correlation_cost')
+ def test_local_previous_frame_nearest_neighbor_features_per_object(self):
+ prev_frame_embedding = np.array([[[1.0, -5.0], [7.0, 2.0]],
+ [[1.0, 3.0], [3.0, 4.0]]]) / 10
+ query_embedding = np.array([[[2.0, 1.0], [0.0, -9.0]],
+ [[4.0, 3.0], [3.0, 1.0]]]) / 10
+ prev_frame_labels = np.array([[[0], [1]], [[1], [0]]])
+ gt_ids = np.array([0, 1])
+ g = tf.Graph()
+ with g.as_default():
+ with self.test_session(graph=g) as sess:
+ prev_frame_embedding_tf = tf.constant(prev_frame_embedding,
+ dtype=tf.float32)
+ query_embedding_tf = tf.constant(query_embedding, dtype=tf.float32)
+ embu = embedding_utils
+ dists = (
+ embu.local_previous_frame_nearest_neighbor_features_per_object(
+ prev_frame_embedding_tf, query_embedding_tf,
+ prev_frame_labels, gt_ids, max_distance=1))
+ dists = tf.squeeze(dists, axis=4)
+ dists = tf.squeeze(dists, axis=0)
+ dists_val = sess.run(dists)
+ for obj_id in gt_ids:
+ for y in range(2):
+ for x in range(2):
+ curr_min = 1.0
+ for dy in range(-1, 2):
+ for dx in range(-1, 2):
+ # Attention: here we shift the prev frame embedding,
+ # not the query.
+ if y + dy < 0 or y + dy > 1 or x + dx < 0 or x + dx > 1:
+ continue
+ if prev_frame_labels[y + dy, x + dx, 0] != obj_id:
+ continue
+ prev_frame_slice = prev_frame_embedding[y + dy, x + dx, :]
+ query_frame_slice = query_embedding[y, x, :]
+ v_unnorm = ((prev_frame_slice - query_frame_slice) ** 2).sum()
+ v = ((1.0 / (1.0 + np.exp(-v_unnorm))) - 0.5) * 2
+ curr_min = min(curr_min, v)
+ expected = curr_min
+ self.assertAlmostEqual(dists_val[y, x, obj_id], expected,
+ places=5)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/feelvos/utils/eval_utils.py b/models/research/feelvos/utils/eval_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..517ec0d788eb3a6ec48246e10920dd4b55332bf5
--- /dev/null
+++ b/models/research/feelvos/utils/eval_utils.py
@@ -0,0 +1,153 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utility functions for evaluations."""
+
+import numpy as np
+import PIL
+import tensorflow as tf
+
+pascal_colormap = [
+ 0, 0, 0,
+ 0.5020, 0, 0,
+ 0, 0.5020, 0,
+ 0.5020, 0.5020, 0,
+ 0, 0, 0.5020,
+ 0.5020, 0, 0.5020,
+ 0, 0.5020, 0.5020,
+ 0.5020, 0.5020, 0.5020,
+ 0.2510, 0, 0,
+ 0.7529, 0, 0,
+ 0.2510, 0.5020, 0,
+ 0.7529, 0.5020, 0,
+ 0.2510, 0, 0.5020,
+ 0.7529, 0, 0.5020,
+ 0.2510, 0.5020, 0.5020,
+ 0.7529, 0.5020, 0.5020,
+ 0, 0.2510, 0,
+ 0.5020, 0.2510, 0,
+ 0, 0.7529, 0,
+ 0.5020, 0.7529, 0,
+ 0, 0.2510, 0.5020,
+ 0.5020, 0.2510, 0.5020,
+ 0, 0.7529, 0.5020,
+ 0.5020, 0.7529, 0.5020,
+ 0.2510, 0.2510, 0]
+
+
+def save_segmentation_with_colormap(filename, img):
+ """Saves a segmentation with the pascal colormap as expected for DAVIS eval.
+
+ Args:
+ filename: Where to store the segmentation.
+ img: A numpy array of the segmentation to be saved.
+ """
+ if img.shape[-1] == 1:
+ img = img[..., 0]
+
+ # Save with colormap.
+ colormap = (np.array(pascal_colormap) * 255).round().astype('uint8')
+ colormap_image = PIL.Image.new('P', (16, 16))
+ colormap_image.putpalette(colormap)
+ pil_image = PIL.Image.fromarray(img.astype('uint8'))
+ pil_image_with_colormap = pil_image.quantize(palette=colormap_image)
+ with tf.gfile.GFile(filename, 'w') as f:
+ pil_image_with_colormap.save(f)
+
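+# Illustrative sketch (not part of the original FEELVOS code; the path is
+# hypothetical): a predicted segmentation could be written out as
+#
+#   seg = np.zeros((480, 854), dtype=np.uint8)  # DAVIS-sized label map
+#   save_segmentation_with_colormap('/tmp/00000.png', seg)
+#
+# which stores a paletted PNG in the format expected by the DAVIS evaluation.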
+
+def save_embeddings(filename, embeddings):
+ with tf.gfile.GFile(filename, 'w') as f:
+ np.save(f, embeddings)
+
+
+def calculate_iou(pred_labels, ref_labels):
+ """Calculates the intersection over union for binary segmentation.
+
+ Args:
+ pred_labels: predicted segmentation labels.
+ ref_labels: reference segmentation labels.
+
+ Returns:
+ The IoU between pred_labels and ref_labels
+ """
+ if ref_labels.any():
+ i = np.logical_and(pred_labels, ref_labels).sum()
+ u = np.logical_or(pred_labels, ref_labels).sum()
+ return i.astype('float') / u
+ else:
+ if pred_labels.any():
+ return 0.0
+ else:
+ return 1.0
+
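+# Illustrative sketch (not part of the original FEELVOS code):
+#
+#   pred = np.array([[True, True, False, False]])
+#   ref = np.array([[False, True, True, False]])
+#   calculate_iou(pred, ref)  # intersection 1, union 3 -> 0.3333...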
+
+def calculate_multi_object_miou_tf(pred_labels, ref_labels):
+ """Calculates the mIoU for a batch of predicted and reference labels.
+
+ Args:
+ pred_labels: Int32 tensor of shape [batch, height, width, 1].
+ ref_labels: Int32 tensor of shape [batch, height, width, 1].
+
+ Returns:
+ The mIoU between pred_labels and ref_labels as float32 scalar tensor.
+ """
+
+ def calculate_multi_object_miou(pred_labels_, ref_labels_):
+ """Calculates the mIoU for predicted and reference labels in numpy.
+
+ Args:
+ pred_labels_: int32 np.array of shape [batch, height, width, 1].
+ ref_labels_: int32 np.array of shape [batch, height, width, 1].
+
+ Returns:
+ The mIoU between pred_labels_ and ref_labels_.
+ """
+ assert len(pred_labels_.shape) == 4
+ assert pred_labels_.shape[3] == 1
+ assert pred_labels_.shape == ref_labels_.shape
+ ious = []
+ for pred_label, ref_label in zip(pred_labels_, ref_labels_):
+ ids = np.setdiff1d(np.unique(ref_label), [0])
+ if ids.size == 0:
+ continue
+ for id_ in ids:
+ iou = calculate_iou(pred_label == id_, ref_label == id_)
+ ious.append(iou)
+ if ious:
+ return np.cast['float32'](np.mean(ious))
+ else:
+ return np.cast['float32'](1.0)
+
+ miou = tf.py_func(calculate_multi_object_miou, [pred_labels, ref_labels],
+ tf.float32, name='calculate_multi_object_miou')
+ miou.set_shape(())
+ return miou
+
+
+def calculate_multi_object_ious(pred_labels, ref_labels, label_set):
+ """Calculates the intersection over union for binary segmentation.
+
+ Args:
+ pred_labels: predicted segmentation labels.
+ ref_labels: reference segmentation labels.
+ label_set: int np.array of object ids.
+
+ Returns:
+ float np.array of IoUs between pred_labels and ref_labels
+ for each object in label_set.
+ """
+  # Background should not be included as an object label.
+ return np.array([calculate_iou(pred_labels == label, ref_labels == label)
+ for label in label_set if label != 0])
diff --git a/models/research/feelvos/utils/mask_damaging.py b/models/research/feelvos/utils/mask_damaging.py
new file mode 100644
index 0000000000000000000000000000000000000000..74f3cdab5a0e4374f0cd238544a9a582fd0eef92
--- /dev/null
+++ b/models/research/feelvos/utils/mask_damaging.py
@@ -0,0 +1,176 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities for artificially damaging segmentation masks."""
+
+import numpy as np
+from scipy.ndimage import interpolation
+from skimage import morphology
+from skimage import transform
+import tensorflow as tf
+
+
+def damage_masks(labels, shift=True, scale=True, rotate=True, dilate=True):
+ """Damages segmentation masks by random transformations.
+
+ Args:
+ labels: Int32 labels tensor of shape (height, width, 1).
+ shift: Boolean, whether to damage the masks by shifting.
+ scale: Boolean, whether to damage the masks by scaling.
+ rotate: Boolean, whether to damage the masks by rotation.
+ dilate: Boolean, whether to damage the masks by dilation.
+
+ Returns:
+ The damaged version of labels.
+ """
+ def _damage_masks_np(labels_):
+ return damage_masks_np(labels_, shift, scale, rotate, dilate)
+ damaged_masks = tf.py_func(_damage_masks_np, [labels], tf.int32,
+ name='damage_masks')
+ damaged_masks.set_shape(labels.get_shape())
+ return damaged_masks
+
+
+def damage_masks_np(labels, shift=True, scale=True, rotate=True, dilate=True):
+ """Performs the actual mask damaging in numpy.
+
+ Args:
+ labels: Int32 numpy array of shape (height, width, 1).
+ shift: Boolean, whether to damage the masks by shifting.
+ scale: Boolean, whether to damage the masks by scaling.
+ rotate: Boolean, whether to damage the masks by rotation.
+ dilate: Boolean, whether to damage the masks by dilation.
+
+ Returns:
+ The damaged version of labels.
+ """
+ unique_labels = np.unique(labels)
+ unique_labels = np.setdiff1d(unique_labels, [0])
+ # Shuffle to get random depth ordering when combining together.
+ np.random.shuffle(unique_labels)
+ damaged_labels = np.zeros_like(labels)
+ for l in unique_labels:
+ obj_mask = (labels == l)
+ damaged_obj_mask = _damage_single_object_mask(obj_mask, shift, scale,
+ rotate, dilate)
+ damaged_labels[damaged_obj_mask] = l
+ return damaged_labels
+
+
+def _damage_single_object_mask(mask, shift, scale, rotate, dilate):
+ """Performs mask damaging in numpy for a single object.
+
+ Args:
+    mask: Boolean numpy array of shape (height, width, 1).
+ shift: Boolean, whether to damage the masks by shifting.
+ scale: Boolean, whether to damage the masks by scaling.
+ rotate: Boolean, whether to damage the masks by rotation.
+ dilate: Boolean, whether to damage the masks by dilation.
+
+ Returns:
+ The damaged version of mask.
+ """
+  # For now we only apply simple transformations (shift, scale, rotation and
+  # dilation). Better would be affine or thin plate spline transformations.
+ if shift:
+ mask = _shift_mask(mask)
+ if scale:
+ mask = _scale_mask(mask)
+ if rotate:
+ mask = _rotate_mask(mask)
+ if dilate:
+ mask = _dilate_mask(mask)
+ return mask
+
+
+def _shift_mask(mask, max_shift_factor=0.05):
+ """Damages a mask for a single object by randomly shifting it in numpy.
+
+ Args:
+    mask: Boolean numpy array of shape (height, width, 1).
+ max_shift_factor: Float scalar, the maximum factor for random shifting.
+
+ Returns:
+ The shifted version of mask.
+ """
+ nzy, nzx, _ = mask.nonzero()
+ h = nzy.max() - nzy.min()
+ w = nzx.max() - nzx.min()
+ size = np.sqrt(h * w)
+ offset = np.random.uniform(-size * max_shift_factor, size * max_shift_factor,
+ 2)
+ shifted_mask = interpolation.shift(np.squeeze(mask, axis=2),
+ offset, order=0).astype('bool')[...,
+ np.newaxis]
+ return shifted_mask
+
+
+def _scale_mask(mask, scale_amount=0.025):
+ """Damages a mask for a single object by randomly scaling it in numpy.
+
+ Args:
+    mask: Boolean numpy array of shape (height, width, 1).
+ scale_amount: Float scalar, the maximum factor for random scaling.
+
+ Returns:
+ The scaled version of mask.
+ """
+ nzy, nzx, _ = mask.nonzero()
+ cy = 0.5 * (nzy.max() - nzy.min())
+ cx = 0.5 * (nzx.max() - nzx.min())
+ scale_factor = np.random.uniform(1.0 - scale_amount, 1.0 + scale_amount)
+ shift = transform.SimilarityTransform(translation=[-cx, -cy])
+ inv_shift = transform.SimilarityTransform(translation=[cx, cy])
+ s = transform.SimilarityTransform(scale=[scale_factor, scale_factor])
+ m = (shift + (s + inv_shift)).inverse
+ scaled_mask = transform.warp(mask, m) > 0.5
+ return scaled_mask
+
+
+def _rotate_mask(mask, max_rot_degrees=3.0):
+ """Damages a mask for a single object by randomly rotating it in numpy.
+
+ Args:
+    mask: Boolean numpy array of shape (height, width, 1).
+ max_rot_degrees: Float scalar, the maximum number of degrees to rotate.
+
+ Returns:
+    The rotated version of mask.
+ """
+ cy = 0.5 * mask.shape[0]
+ cx = 0.5 * mask.shape[1]
+ rot_degrees = np.random.uniform(-max_rot_degrees, max_rot_degrees)
+ shift = transform.SimilarityTransform(translation=[-cx, -cy])
+ inv_shift = transform.SimilarityTransform(translation=[cx, cy])
+ r = transform.SimilarityTransform(rotation=np.deg2rad(rot_degrees))
+ m = (shift + (r + inv_shift)).inverse
+  rotated_mask = transform.warp(mask, m) > 0.5
+  return rotated_mask
+
+
+def _dilate_mask(mask, dilation_radius=5):
+ """Damages a mask for a single object by dilating it in numpy.
+
+ Args:
+    mask: Boolean numpy array of shape (height, width, 1).
+ dilation_radius: Integer, the radius of the used disk structure element.
+
+ Returns:
+ The dilated version of mask.
+ """
+ disk = morphology.disk(dilation_radius, dtype=np.bool)
+ dilated_mask = morphology.binary_dilation(
+ np.squeeze(mask, axis=2), selem=disk)[..., np.newaxis]
+ return dilated_mask
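As a rough usage sketch (not part of the original file), `damage_masks_np` can be exercised on a synthetic two-object label map; all arguments are the defaults defined above and the label map itself is made up.

```python
import numpy as np
# damage_masks_np is the function defined above in mask_damaging.py.

# Synthetic (height, width, 1) int32 label map with two square objects.
labels = np.zeros((64, 64, 1), dtype=np.int32)
labels[10:30, 10:30, 0] = 1
labels[35:55, 35:55, 0] = 2

damaged = damage_masks_np(labels, shift=True, scale=True,
                          rotate=True, dilate=True)
# Object ids are preserved, but each mask is independently perturbed.
print(np.unique(damaged))          # e.g. [0 1 2]
print((damaged != labels).mean())  # fraction of pixels that changed
```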
diff --git a/models/research/feelvos/utils/train_utils.py b/models/research/feelvos/utils/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..02a04cd33645931c5c795bef8559c0d3f5c4c23c
--- /dev/null
+++ b/models/research/feelvos/utils/train_utils.py
@@ -0,0 +1,269 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utility functions for training."""
+import collections
+import six
+import tensorflow as tf
+
+from deeplab.core import preprocess_utils
+from deeplab.utils import train_utils
+from feelvos.utils import embedding_utils
+from feelvos.utils import eval_utils
+
+slim = tf.contrib.slim
+add_softmax_cross_entropy_loss_for_each_scale = (
+ train_utils.add_softmax_cross_entropy_loss_for_each_scale)
+get_model_gradient_multipliers = train_utils.get_model_gradient_multipliers
+get_model_learning_rate = train_utils.get_model_learning_rate
+resolve_shape = preprocess_utils.resolve_shape
+
+
+def add_triplet_loss_for_each_scale(batch_size, num_frames_per_video,
+ embedding_dim, scales_to_embeddings,
+ labels, scope):
+ """Adds triplet loss for logits of each scale.
+
+ Args:
+    batch_size: Int, the number of video chunks sampled per batch.
+    num_frames_per_video: Int, the number of frames per video.
+    embedding_dim: Int, the dimension of the learned embedding.
+ scales_to_embeddings: A map from embedding names for different scales to
+ embeddings. The embeddings have shape [batch, embeddings_height,
+ embeddings_width, embedding_dim].
+ labels: Groundtruth labels with shape [batch, image_height, image_width, 1].
+ scope: String, the scope for the loss.
+
+ Raises:
+ ValueError: labels is None.
+ """
+ if labels is None:
+ raise ValueError('No label for triplet loss.')
+  for scale, embeddings in six.iteritems(scales_to_embeddings):
+ loss_scope = None
+ if scope:
+ loss_scope = '%s_%s' % (scope, scale)
+ # Label is downsampled to the same size as logits.
+ scaled_labels = tf.image.resize_nearest_neighbor(
+ labels,
+ resolve_shape(embeddings, 4)[1:3],
+ align_corners=True)
+ # Reshape from [batch * num_frames, ...] to [batch, num_frames, ...].
+ h = tf.shape(embeddings)[1]
+ w = tf.shape(embeddings)[2]
+ new_labels_shape = tf.stack([batch_size, num_frames_per_video, h, w, 1])
+ reshaped_labels = tf.reshape(scaled_labels, new_labels_shape)
+ new_embeddings_shape = tf.stack([batch_size, num_frames_per_video, h, w,
+ -1])
+ reshaped_embeddings = tf.reshape(embeddings, new_embeddings_shape)
+
+ with tf.name_scope(loss_scope):
+ total_loss = tf.constant(0, dtype=tf.float32)
+ for n in range(batch_size):
+ embedding = reshaped_embeddings[n]
+ label = reshaped_labels[n]
+ n_pixels = h * w
+ n_anchors_used = 256
+ sampled_anchor_indices = tf.random_shuffle(tf.range(n_pixels))[
+ :n_anchors_used]
+ anchors_pool = tf.reshape(embedding[0], [-1, embedding_dim])
+ anchors_pool_classes = tf.reshape(label[0], [-1])
+ anchors = tf.gather(anchors_pool, sampled_anchor_indices)
+ anchor_classes = tf.gather(anchors_pool_classes, sampled_anchor_indices)
+
+ pos_neg_pool = tf.reshape(embedding[1:], [-1, embedding_dim])
+ pos_neg_pool_classes = tf.reshape(label[1:], [-1])
+ dists = embedding_utils.pairwise_distances(anchors, pos_neg_pool)
+ pos_mask = tf.equal(anchor_classes[:, tf.newaxis],
+ pos_neg_pool_classes[tf.newaxis, :])
+ neg_mask = tf.logical_not(pos_mask)
+ pos_mask_f = tf.cast(pos_mask, tf.float32)
+ neg_mask_f = tf.cast(neg_mask, tf.float32)
+ pos_dists = pos_mask_f * dists + 1e20 * neg_mask_f
+ neg_dists = neg_mask_f * dists + 1e20 * pos_mask_f
+ pos_dists_min = tf.reduce_min(pos_dists, axis=1)
+ neg_dists_min = tf.reduce_min(neg_dists, axis=1)
+ margin = 1.0
+ loss = tf.nn.relu(pos_dists_min - neg_dists_min + margin)
+ # Handle case that no positive is present (per anchor).
+ any_pos = tf.reduce_any(pos_mask, axis=1)
+ loss *= tf.cast(any_pos, tf.float32)
+ # Average over anchors
+ loss = tf.reduce_mean(loss, axis=0)
+ total_loss += loss
+ total_loss /= batch_size
+ # Scale the loss up a bit.
+ total_loss *= 3.0
+ tf.add_to_collection(tf.GraphKeys.LOSSES, total_loss)
+
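The anchor bookkeeping above boils down to a hard-mining margin (triplet) loss per anchor pixel: the closest positive and closest negative are selected by masking out the other set with a huge constant before taking the min. A small NumPy sketch of that inner step, with made-up distances and classes:

```python
import numpy as np

# Pairwise distances from 3 anchor pixels to a pool of 4 pixels (made up).
dists = np.array([[0.2, 1.5, 0.9, 2.0],
                  [1.1, 0.3, 2.2, 0.8],
                  [0.7, 0.6, 0.5, 1.9]], dtype=np.float32)
anchor_classes = np.array([1, 2, 1])
pool_classes = np.array([1, 2, 1, 2])

pos_mask = anchor_classes[:, None] == pool_classes[None, :]
neg_mask = ~pos_mask
# Masking with a huge constant so the min ignores the other set.
pos_dists_min = np.where(pos_mask, dists, 1e20).min(axis=1)
neg_dists_min = np.where(neg_mask, dists, 1e20).min(axis=1)

margin = 1.0
loss = np.maximum(pos_dists_min - neg_dists_min + margin, 0.0)
loss *= pos_mask.any(axis=1)  # drop anchors that have no positive in the pool
print(loss.mean())            # average over anchors, as in the graph code
```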
+
+def add_dynamic_softmax_cross_entropy_loss_for_each_scale(
+ scales_to_logits, labels, ignore_label, loss_weight=1.0,
+ upsample_logits=True, scope=None, top_k_percent_pixels=1.0,
+ hard_example_mining_step=100000):
+ """Adds softmax cross entropy loss per scale for logits with varying classes.
+
+ Also adds summaries for mIoU.
+
+ Args:
+ scales_to_logits: A map from logits names for different scales to logits.
+ The logits are a list of length batch_size of tensors of shape
+ [time, logits_height, logits_width, num_classes].
+ labels: Groundtruth labels with shape [batch_size * time, image_height,
+ image_width, 1].
+ ignore_label: Integer, label to ignore.
+ loss_weight: Float, loss weight.
+ upsample_logits: Boolean, upsample logits or not.
+ scope: String, the scope for the loss.
+ top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its
+ value < 1.0, only compute the loss for the top k percent pixels (e.g.,
+ the top 20% pixels). This is useful for hard pixel mining.
+ hard_example_mining_step: An integer, the training step in which the
+      hard example mining kicks in. Note that we gradually reduce the
+ mining percent to the top_k_percent_pixels. For example, if
+ hard_example_mining_step=100K and top_k_percent_pixels=0.25, then
+ mining percent will gradually reduce from 100% to 25% until 100K steps
+ after which we only mine top 25% pixels.
+
+ Raises:
+ ValueError: Label or logits is None.
+ """
+ if labels is None:
+ raise ValueError('No label for softmax cross entropy loss.')
+
+ if top_k_percent_pixels < 0 or top_k_percent_pixels > 1:
+ raise ValueError('Unexpected value of top_k_percent_pixels.')
+
+ for scale, logits in six.iteritems(scales_to_logits):
+ loss_scope = None
+ if scope:
+ loss_scope = '%s_%s' % (scope, scale)
+
+ if upsample_logits:
+ # Label is not downsampled, and instead we upsample logits.
+ assert isinstance(logits, collections.Sequence)
+ logits = [tf.image.resize_bilinear(
+ x,
+ preprocess_utils.resolve_shape(labels, 4)[1:3],
+ align_corners=True) for x in logits]
+ scaled_labels = labels
+ else:
+ # Label is downsampled to the same size as logits.
+ assert isinstance(logits, collections.Sequence)
+ scaled_labels = tf.image.resize_nearest_neighbor(
+ labels,
+ preprocess_utils.resolve_shape(logits[0], 4)[1:3],
+ align_corners=True)
+
+ batch_size = len(logits)
+ num_time = preprocess_utils.resolve_shape(logits[0])[0]
+ reshaped_labels = tf.reshape(
+ scaled_labels, ([batch_size, num_time] +
+ preprocess_utils.resolve_shape(scaled_labels)[1:]))
+ for n, logits_n in enumerate(logits):
+ labels_n = reshaped_labels[n]
+ labels_n = tf.reshape(labels_n, shape=[-1])
+ not_ignore_mask = tf.to_float(tf.not_equal(labels_n,
+ ignore_label)) * loss_weight
+ num_classes_n = tf.shape(logits_n)[-1]
+ one_hot_labels = slim.one_hot_encoding(
+ labels_n, num_classes_n, on_value=1.0, off_value=0.0)
+ logits_n_flat = tf.reshape(logits_n, shape=[-1, num_classes_n])
+ if top_k_percent_pixels == 1.0:
+ tf.losses.softmax_cross_entropy(
+ one_hot_labels,
+ logits_n_flat,
+ weights=not_ignore_mask,
+ scope=loss_scope)
+ else:
+ # Only compute the loss for top k percent pixels.
+          # First, compute the loss for all pixels. Note we do not add the loss
+          # to the loss collection and we set reduction = None to keep the shape.
+ num_pixels = tf.to_float(tf.shape(logits_n_flat)[0])
+ pixel_losses = tf.losses.softmax_cross_entropy(
+ one_hot_labels,
+ logits_n_flat,
+ weights=not_ignore_mask,
+ scope='pixel_losses',
+ loss_collection=None,
+ reduction=tf.losses.Reduction.NONE)
+ # Compute the top_k_percent pixels based on current training step.
+ if hard_example_mining_step == 0:
+ # Directly focus on the top_k pixels.
+ top_k_pixels = tf.to_int32(top_k_percent_pixels * num_pixels)
+ else:
+ # Gradually reduce the mining percent to top_k_percent_pixels.
+ global_step = tf.to_float(tf.train.get_or_create_global_step())
+ ratio = tf.minimum(1.0, global_step / hard_example_mining_step)
+ top_k_pixels = tf.to_int32(
+ (ratio * top_k_percent_pixels + (1.0 - ratio)) * num_pixels)
+ _, top_k_indices = tf.nn.top_k(pixel_losses,
+ k=top_k_pixels,
+ sorted=True,
+ name='top_k_percent_pixels')
+ # Compute the loss for the top k percent pixels.
+ tf.losses.softmax_cross_entropy(
+ tf.gather(one_hot_labels, top_k_indices),
+ tf.gather(logits_n_flat, top_k_indices),
+ weights=tf.gather(not_ignore_mask, top_k_indices),
+ scope=loss_scope)
+
+ pred_n = tf.argmax(logits_n, axis=-1, output_type=tf.int32)[
+ ..., tf.newaxis]
+ labels_n = labels[n * num_time: (n + 1) * num_time]
+ miou = eval_utils.calculate_multi_object_miou_tf(pred_n, labels_n)
+ tf.summary.scalar('miou', miou)
+
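The hard-example-mining schedule described in the docstring above is a linear interpolation of the mined pixel fraction from 100% down to `top_k_percent_pixels` over `hard_example_mining_step` steps. A standalone sketch of that schedule (plain Python, not part of the file):

```python
def mined_pixel_fraction(global_step, top_k_percent_pixels=0.25,
                         hard_example_mining_step=100000):
    """Fraction of pixels whose loss is kept at a given training step."""
    if hard_example_mining_step == 0:
        return top_k_percent_pixels
    ratio = min(1.0, global_step / float(hard_example_mining_step))
    return ratio * top_k_percent_pixels + (1.0 - ratio)

for step in (0, 50000, 100000, 200000):
    print(step, mined_pixel_fraction(step))
# 0 -> 1.0, 50000 -> 0.625, 100000 -> 0.25, 200000 -> 0.25
```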
+
+def get_model_init_fn(train_logdir,
+ tf_initial_checkpoint,
+ initialize_last_layer,
+ last_layers,
+ ignore_missing_vars=False):
+ """Gets the function initializing model variables from a checkpoint.
+
+ Args:
+ train_logdir: Log directory for training.
+ tf_initial_checkpoint: TensorFlow checkpoint for initialization.
+ initialize_last_layer: Initialize last layer or not.
+ last_layers: Last layers of the model.
+ ignore_missing_vars: Ignore missing variables in the checkpoint.
+
+ Returns:
+ Initialization function.
+ """
+ if tf_initial_checkpoint is None:
+ tf.logging.info('Not initializing the model from a checkpoint.')
+ return None
+
+ if tf.train.latest_checkpoint(train_logdir):
+ tf.logging.info('Ignoring initialization; other checkpoint exists')
+ return None
+
+ tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)
+
+ # Variables that will not be restored.
+ exclude_list = ['global_step']
+ if not initialize_last_layer:
+ exclude_list.extend(last_layers)
+
+ variables_to_restore = slim.get_variables_to_restore(exclude=exclude_list)
+
+ if variables_to_restore:
+ return slim.assign_from_checkpoint_fn(
+ tf_initial_checkpoint,
+ variables_to_restore,
+ ignore_missing_vars=ignore_missing_vars)
+ return None
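A hedged sketch of how `get_model_init_fn` might be used; the checkpoint path, log directory and the `'logits'` scope name are placeholders, not values taken from this repository.

```python
import tensorflow as tf

# Hypothetical usage: build the init function, then apply it to a session.
init_fn = get_model_init_fn(
    train_logdir='/tmp/train_logs',             # placeholder path
    tf_initial_checkpoint='/tmp/deeplab.ckpt',  # placeholder checkpoint
    initialize_last_layer=False,
    last_layers=['logits'],                     # placeholder scope name
    ignore_missing_vars=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if init_fn is not None:
        init_fn(sess)  # restores all variables except global_step / last layers
```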
diff --git a/models/research/feelvos/utils/video_input_generator.py b/models/research/feelvos/utils/video_input_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0135e50110c677865217c8a3f13d1d1d891f0b2
--- /dev/null
+++ b/models/research/feelvos/utils/video_input_generator.py
@@ -0,0 +1,558 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Wrapper for providing semantic segmentation video data."""
+
+import tensorflow as tf
+from feelvos import input_preprocess
+from feelvos import model
+from feelvos.utils import mask_damaging
+from feelvos.utils import train_utils
+
+slim = tf.contrib.slim
+dataset_data_provider = slim.dataset_data_provider
+
+
+MIN_LABEL_COUNT = 10
+
+
+def decode_image_sequence(tensor, image_format='jpeg', shape=None,
+ channels=3, raw_dtype=tf.uint8):
+ """Decodes a sequence of images.
+
+ Args:
+ tensor: the tensor of strings to decode, shape: [num_images]
+ image_format: a string (possibly tensor) with the format of the image.
+ Options include 'jpeg', 'png', and 'raw'.
+ shape: a list or tensor of the decoded image shape for a single image.
+ channels: if 'shape' is None, the third dimension of the image is set to
+ this value.
+ raw_dtype: if the image is encoded as raw bytes, this is the method of
+ decoding the bytes into values.
+ Returns:
+ The decoded images with shape [time, height, width, channels].
+ """
+ handler = slim.tfexample_decoder.Image(
+ shape=shape, channels=channels, dtype=raw_dtype, repeated=True)
+ return handler.tensors_to_item({'image/encoded': tensor,
+ 'image/format': image_format})
+
+
+def _get_data(data_provider, dataset_split, video_frames_are_decoded):
+ """Gets data from data provider.
+
+ Args:
+ data_provider: An object of slim.data_provider.
+ dataset_split: Dataset split.
+ video_frames_are_decoded: Boolean, whether the video frames are already
+      decoded.
+
+ Returns:
+ image: Image Tensor.
+ label: Label Tensor storing segmentation annotations.
+    object_label: An integer referring to the object label according to the
+      labelmap. If there is more than one object label, the first is taken.
+ image_name: Image name.
+ height: Image height.
+ width: Image width.
+ video_id: String tensor representing the name of the video.
+
+ Raises:
+ ValueError: Failed to find label.
+ """
+
+ if video_frames_are_decoded:
+ image, = data_provider.get(['image'])
+ else:
+ image, = data_provider.get(['image/encoded'])
+
+ # Some datasets do not contain image_name.
+ if 'image_name' in data_provider.list_items():
+ image_name, = data_provider.get(['image_name'])
+ else:
+ image_name = tf.constant('')
+
+ height, width = data_provider.get(['height', 'width'])
+
+ label = None
+ if dataset_split != 'test':
+ if video_frames_are_decoded:
+ if 'labels_class' not in data_provider.list_items():
+ raise ValueError('Failed to find labels.')
+ label, = data_provider.get(['labels_class'])
+ else:
+ key = 'segmentation/object/encoded'
+ if key not in data_provider.list_items():
+ raise ValueError('Failed to find labels.')
+ label, = data_provider.get([key])
+
+ object_label = None
+ video_id, = data_provider.get(['video_id'])
+
+ return image, label, object_label, image_name, height, width, video_id
+
+
+def _has_foreground_and_background_in_first_frame(label, subsampling_factor):
+ """Checks if the labels have foreground and background in the first frame.
+
+ Args:
+ label: Label tensor of shape [num_frames, height, width, 1].
+ subsampling_factor: Integer, the subsampling factor.
+
+ Returns:
+ Boolean, whether the labels have foreground and background in the first
+ frame.
+ """
+ h, w = train_utils.resolve_shape(label)[1:3]
+ label_downscaled = tf.squeeze(
+ tf.image.resize_nearest_neighbor(label[0, tf.newaxis],
+ [h // subsampling_factor,
+ w // subsampling_factor],
+ align_corners=True),
+ axis=0)
+ is_bg = tf.equal(label_downscaled, 0)
+ is_fg = tf.logical_not(is_bg)
+  # Just using reduce_any was not robust enough, so let's make sure the count
+  # is at least MIN_LABEL_COUNT.
+  fg_count = tf.reduce_sum(tf.cast(is_fg, tf.int32))
+  bg_count = tf.reduce_sum(tf.cast(is_bg, tf.int32))
+  has_bg = tf.greater_equal(bg_count, MIN_LABEL_COUNT)
+  has_fg = tf.greater_equal(fg_count, MIN_LABEL_COUNT)
+ return tf.logical_and(has_bg, has_fg)
+
+
+def _has_foreground_and_background_in_first_frame_2(label,
+ decoder_output_stride):
+ """Checks if the labels have foreground and background in the first frame.
+
+ Second attempt, this time we use the actual output dimension for resizing.
+
+ Args:
+ label: Label tensor of shape [num_frames, height, width, 1].
+ decoder_output_stride: Integer, the stride of the decoder output.
+
+ Returns:
+ Boolean, whether the labels have foreground and background in the first
+ frame.
+ """
+ h, w = train_utils.resolve_shape(label)[1:3]
+ h_sub = model.scale_dimension(h, 1.0 / decoder_output_stride)
+ w_sub = model.scale_dimension(w, 1.0 / decoder_output_stride)
+ label_downscaled = tf.squeeze(
+ tf.image.resize_nearest_neighbor(label[0, tf.newaxis], [h_sub, w_sub],
+ align_corners=True), axis=0)
+ is_bg = tf.equal(label_downscaled, 0)
+ is_fg = tf.logical_not(is_bg)
+  # Just using reduce_any was not robust enough, so let's make sure the count
+  # is at least MIN_LABEL_COUNT.
+  fg_count = tf.reduce_sum(tf.cast(is_fg, tf.int32))
+  bg_count = tf.reduce_sum(tf.cast(is_bg, tf.int32))
+  has_bg = tf.greater_equal(bg_count, MIN_LABEL_COUNT)
+  has_fg = tf.greater_equal(fg_count, MIN_LABEL_COUNT)
+ return tf.logical_and(has_bg, has_fg)
+
+
+def _has_enough_pixels_of_each_object_in_first_frame(
+ label, decoder_output_stride):
+ """Checks if for each object (incl. background) enough pixels are visible.
+
+ During test time, we will usually not see a reference frame in which only
+ very few pixels of one object are visible. These cases can be problematic
+ during training, especially if more than the 1-nearest neighbor is used.
+ That's why this function can be used to detect and filter these cases.
+
+ Args:
+ label: Label tensor of shape [num_frames, height, width, 1].
+ decoder_output_stride: Integer, the stride of the decoder output.
+
+ Returns:
+ Boolean, whether the labels have enough pixels of each object in the first
+ frame.
+ """
+ h, w = train_utils.resolve_shape(label)[1:3]
+ h_sub = model.scale_dimension(h, 1.0 / decoder_output_stride)
+ w_sub = model.scale_dimension(w, 1.0 / decoder_output_stride)
+ label_downscaled = tf.squeeze(
+ tf.image.resize_nearest_neighbor(label[0, tf.newaxis], [h_sub, w_sub],
+ align_corners=True), axis=0)
+ _, _, counts = tf.unique_with_counts(
+ tf.reshape(label_downscaled, [-1]))
+ has_enough_pixels_per_object = tf.reduce_all(
+ tf.greater_equal(counts, MIN_LABEL_COUNT))
+ return has_enough_pixels_per_object
+
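The filter above simply requires every label id (including background) to cover at least MIN_LABEL_COUNT pixels once the first frame is downscaled to the decoder resolution. A NumPy analogue of that check, on a made-up downscaled frame:

```python
import numpy as np

MIN_LABEL_COUNT = 10

def has_enough_pixels_per_object(label_downscaled, min_count=MIN_LABEL_COUNT):
    # Every id, including background 0, must cover at least min_count pixels.
    _, counts = np.unique(label_downscaled, return_counts=True)
    return bool(np.all(counts >= min_count))

first_frame = np.zeros((32, 32), dtype=np.int32)
first_frame[:4, :2] = 1   # only 8 pixels of object 1
print(has_enough_pixels_per_object(first_frame))  # False: object 1 too small
first_frame[:4, :3] = 1   # now 12 pixels
print(has_enough_pixels_per_object(first_frame))  # True
```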
+
+def get(dataset,
+ num_frames_per_video,
+ crop_size,
+ batch_size,
+ min_resize_value=None,
+ max_resize_value=None,
+ resize_factor=None,
+ min_scale_factor=1.,
+ max_scale_factor=1.,
+ scale_factor_step_size=0,
+ preprocess_image_and_label=True,
+ num_readers=1,
+ num_threads=1,
+ dataset_split=None,
+ is_training=True,
+ model_variant=None,
+ batch_capacity_factor=32,
+ video_frames_are_decoded=False,
+ decoder_output_stride=None,
+ first_frame_finetuning=False,
+ sample_only_first_frame_for_finetuning=False,
+ sample_adjacent_and_consistent_query_frames=False,
+ remap_labels_to_reference_frame=True,
+ generate_prev_frame_mask_by_mask_damaging=False,
+ three_frame_dataset=False,
+ add_prev_frame_label=True):
+ """Gets the dataset split for semantic segmentation.
+
+  This function gets the dataset split for semantic segmentation. In
+  particular, it is a wrapper of (1) dataset_data_provider, which returns the
+  raw dataset split, (2) input_preprocess, which preprocesses the raw data, and
+  (3) the TensorFlow operation that batches the preprocessed data. The output
+  can then be used directly for training, evaluation, or visualization.
+
+ Args:
+ dataset: An instance of slim Dataset.
+ num_frames_per_video: The number of frames used per video
+ crop_size: Image crop size [height, width].
+ batch_size: Batch size.
+ min_resize_value: Desired size of the smaller image side.
+ max_resize_value: Maximum allowed size of the larger image side.
+ resize_factor: Resized dimensions are multiple of factor plus one.
+ min_scale_factor: Minimum scale factor value.
+ max_scale_factor: Maximum scale factor value.
+ scale_factor_step_size: The step size from min scale factor to max scale
+ factor. The input is randomly scaled based on the value of
+ (min_scale_factor, max_scale_factor, scale_factor_step_size).
+ preprocess_image_and_label: Boolean variable specifies if preprocessing of
+ image and label will be performed or not.
+ num_readers: Number of readers for data provider.
+ num_threads: Number of threads for batching data.
+ dataset_split: Dataset split.
+ is_training: Is training or not.
+ model_variant: Model variant (string) for choosing how to mean-subtract the
+ images. See feature_extractor.network_map for supported model variants.
+ batch_capacity_factor: Batch capacity factor affecting the training queue
+ batch capacity.
+ video_frames_are_decoded: Boolean, whether the video frames are already
+      decoded.
+ decoder_output_stride: Integer, the stride of the decoder output.
+ first_frame_finetuning: Boolean, whether to only sample the first frame
+ for fine-tuning.
+ sample_only_first_frame_for_finetuning: Boolean, whether to only sample the
+ first frame during fine-tuning. This should be False when using lucid or
+ wonderland data, but true when fine-tuning on the first frame only.
+ Only has an effect if first_frame_finetuning is True.
+ sample_adjacent_and_consistent_query_frames: Boolean, if true, the query
+ frames (all but the first frame which is the reference frame) will be
+ sampled such that they are adjacent video frames and have the same
+ crop coordinates and flip augmentation.
+ remap_labels_to_reference_frame: Boolean, whether to remap the labels of
+ the query frames to match the labels of the (downscaled) reference frame.
+ If a query frame contains a label which is not present in the reference,
+ it will be mapped to background.
+ generate_prev_frame_mask_by_mask_damaging: Boolean, whether to generate
+ the masks used as guidance from the previous frame by damaging the
+ ground truth mask.
+ three_frame_dataset: Boolean, whether the dataset has exactly three frames
+ per video of which the first is to be used as reference and the two
+ others are consecutive frames to be used as query frames.
+ add_prev_frame_label: Boolean, whether to sample one more frame before the
+ first query frame to obtain a previous frame label. Only has an effect,
+ if sample_adjacent_and_consistent_query_frames is True and
+ generate_prev_frame_mask_by_mask_damaging is False.
+
+ Returns:
+ A dictionary of batched Tensors for semantic segmentation.
+
+ Raises:
+ ValueError: dataset_split is None, or Failed to find labels.
+ """
+ if dataset_split is None:
+ raise ValueError('Unknown dataset split.')
+ if model_variant is None:
+ tf.logging.warning('Please specify a model_variant. See '
+ 'feature_extractor.network_map for supported model '
+ 'variants.')
+
+ data_provider = dataset_data_provider.DatasetDataProvider(
+ dataset,
+ num_readers=num_readers,
+ num_epochs=None if is_training else 1,
+ shuffle=is_training)
+ image, label, object_label, image_name, height, width, video_id = _get_data(
+ data_provider, dataset_split, video_frames_are_decoded)
+
+ sampling_is_valid = tf.constant(True)
+ if num_frames_per_video is not None:
+ total_num_frames = tf.shape(image)[0]
+ if first_frame_finetuning or three_frame_dataset:
+ if sample_only_first_frame_for_finetuning:
+ assert not sample_adjacent_and_consistent_query_frames, (
+ 'this option does not make sense for sampling only first frame.')
+ # Sample the first frame num_frames_per_video times.
+ sel_indices = tf.tile(tf.constant(0, dtype=tf.int32)[tf.newaxis],
+ multiples=[num_frames_per_video])
+ else:
+ if sample_adjacent_and_consistent_query_frames:
+ if add_prev_frame_label:
+ num_frames_per_video += 1
+ # Since this is first frame fine-tuning, we'll for now assume that
+ # each sequence has exactly 3 images: the ref frame and 2 adjacent
+ # query frames.
+ assert num_frames_per_video == 3
+ with tf.control_dependencies([tf.assert_equal(total_num_frames, 3)]):
+ sel_indices = tf.constant([1, 2], dtype=tf.int32)
+ else:
+ # Sample num_frames_per_video - 1 query frames which are not the
+ # first frame.
+ sel_indices = tf.random_shuffle(
+ tf.range(1, total_num_frames))[:(num_frames_per_video - 1)]
+ # Concat first frame as reference frame to the front.
+ sel_indices = tf.concat([tf.constant(0, dtype=tf.int32)[tf.newaxis],
+ sel_indices], axis=0)
+ else:
+ if sample_adjacent_and_consistent_query_frames:
+ if add_prev_frame_label:
+ # Sample one more frame which we can use to provide initial softmax
+ # feedback.
+ num_frames_per_video += 1
+ ref_idx = tf.random_shuffle(tf.range(total_num_frames))[0]
+ sampling_is_valid = tf.greater_equal(total_num_frames,
+ num_frames_per_video)
+ def sample_query_start_idx():
+ return tf.random_shuffle(
+ tf.range(total_num_frames - num_frames_per_video + 1))[0]
+ query_start_idx = tf.cond(sampling_is_valid, sample_query_start_idx,
+ lambda: tf.constant(0, dtype=tf.int32))
+ def sample_sel_indices():
+ return tf.concat(
+ [ref_idx[tf.newaxis],
+ tf.range(
+ query_start_idx,
+ query_start_idx + (num_frames_per_video - 1))], axis=0)
+ sel_indices = tf.cond(
+ sampling_is_valid, sample_sel_indices,
+ lambda: tf.zeros((num_frames_per_video,), dtype=tf.int32))
+ else:
+ # Randomly sample some frames from the video.
+ sel_indices = tf.random_shuffle(
+ tf.range(total_num_frames))[:num_frames_per_video]
+ image = tf.gather(image, sel_indices, axis=0)
+ if not video_frames_are_decoded:
+ image = decode_image_sequence(image)
+
+ if label is not None:
+ if num_frames_per_video is not None:
+ label = tf.gather(label, sel_indices, axis=0)
+ if not video_frames_are_decoded:
+ label = decode_image_sequence(label, image_format='png', channels=1)
+
+ # Sometimes, label is saved as [num_frames_per_video, height, width] or
+ # [num_frames_per_video, height, width, 1]. We change it to be
+ # [num_frames_per_video, height, width, 1].
+ if label.shape.ndims == 3:
+ label = tf.expand_dims(label, 3)
+ elif label.shape.ndims == 4 and label.shape.dims[3] == 1:
+ pass
+ else:
+      raise ValueError('Input label shape must be '
+                       '[num_frames_per_video, height, width] or '
+                       '[num_frames_per_video, height, width, 1]. '
+                       'Got rank {}'.format(label.shape.ndims))
+ label.set_shape([None, None, None, 1])
+
+ # Add size of first dimension since tf can't figure it out automatically.
+ image.set_shape((num_frames_per_video, None, None, None))
+ if label is not None:
+ label.set_shape((num_frames_per_video, None, None, None))
+
+ preceding_frame_label = None
+ if preprocess_image_and_label:
+ if num_frames_per_video is None:
+      raise ValueError('num_frames_per_video must be specified for preproc.')
+ original_images = []
+ images = []
+ labels = []
+ if sample_adjacent_and_consistent_query_frames:
+ num_frames_individual_preproc = 1
+ else:
+ num_frames_individual_preproc = num_frames_per_video
+ for frame_idx in range(num_frames_individual_preproc):
+ original_image_t, image_t, label_t = (
+ input_preprocess.preprocess_image_and_label(
+ image[frame_idx],
+ label[frame_idx],
+ crop_height=crop_size[0] if crop_size is not None else None,
+ crop_width=crop_size[1] if crop_size is not None else None,
+ min_resize_value=min_resize_value,
+ max_resize_value=max_resize_value,
+ resize_factor=resize_factor,
+ min_scale_factor=min_scale_factor,
+ max_scale_factor=max_scale_factor,
+ scale_factor_step_size=scale_factor_step_size,
+ ignore_label=dataset.ignore_label,
+ is_training=is_training,
+ model_variant=model_variant))
+ original_images.append(original_image_t)
+ images.append(image_t)
+ labels.append(label_t)
+ if sample_adjacent_and_consistent_query_frames:
+ imgs_for_preproc = [image[frame_idx] for frame_idx in
+ range(1, num_frames_per_video)]
+ labels_for_preproc = [label[frame_idx] for frame_idx in
+ range(1, num_frames_per_video)]
+ original_image_rest, image_rest, label_rest = (
+ input_preprocess.preprocess_images_and_labels_consistently(
+ imgs_for_preproc,
+ labels_for_preproc,
+ crop_height=crop_size[0] if crop_size is not None else None,
+ crop_width=crop_size[1] if crop_size is not None else None,
+ min_resize_value=min_resize_value,
+ max_resize_value=max_resize_value,
+ resize_factor=resize_factor,
+ min_scale_factor=min_scale_factor,
+ max_scale_factor=max_scale_factor,
+ scale_factor_step_size=scale_factor_step_size,
+ ignore_label=dataset.ignore_label,
+ is_training=is_training,
+ model_variant=model_variant))
+ original_images.extend(original_image_rest)
+ images.extend(image_rest)
+ labels.extend(label_rest)
+ assert len(original_images) == num_frames_per_video
+ assert len(images) == num_frames_per_video
+ assert len(labels) == num_frames_per_video
+
+ if remap_labels_to_reference_frame:
+ # Remap labels to indices into the labels of the (downscaled) reference
+ # frame, or 0, i.e. background, for labels which are not present
+ # in the reference.
+ reference_labels = labels[0][tf.newaxis]
+ h, w = train_utils.resolve_shape(reference_labels)[1:3]
+ embedding_height = model.scale_dimension(
+ h, 1.0 / decoder_output_stride)
+ embedding_width = model.scale_dimension(
+ w, 1.0 / decoder_output_stride)
+ reference_labels_embedding_size = tf.squeeze(
+ tf.image.resize_nearest_neighbor(
+ reference_labels, tf.stack([embedding_height, embedding_width]),
+ align_corners=True),
+ axis=0)
+ # Get sorted unique labels in the reference frame.
+ labels_in_ref_frame, _ = tf.unique(
+ tf.reshape(reference_labels_embedding_size, [-1]))
+ labels_in_ref_frame = tf.contrib.framework.sort(labels_in_ref_frame)
+ for idx in range(1, len(labels)):
+ ref_label_mask = tf.equal(
+ labels[idx],
+ labels_in_ref_frame[tf.newaxis, tf.newaxis, :])
+ remapped = tf.argmax(tf.cast(ref_label_mask, tf.uint8), axis=-1,
+ output_type=tf.int32)
+ # Set to 0 if label is not present
+ is_in_ref = tf.reduce_any(ref_label_mask, axis=-1)
+ remapped *= tf.cast(is_in_ref, tf.int32)
+ labels[idx] = remapped[..., tf.newaxis]
+
+ if sample_adjacent_and_consistent_query_frames:
+ if first_frame_finetuning and generate_prev_frame_mask_by_mask_damaging:
+ preceding_frame_label = mask_damaging.damage_masks(labels[1])
+ elif add_prev_frame_label:
+ # Discard the image of the additional frame and take the label as
+ # initialization for softmax feedback.
+ original_images = [original_images[0]] + original_images[2:]
+ preceding_frame_label = labels[1]
+ images = [images[0]] + images[2:]
+ labels = [labels[0]] + labels[2:]
+ num_frames_per_video -= 1
+
+ original_image = tf.stack(original_images, axis=0)
+ image = tf.stack(images, axis=0)
+ label = tf.stack(labels, axis=0)
+ else:
+ if label is not None:
+ # Need to set label shape due to batching.
+ label.set_shape([num_frames_per_video,
+ None if crop_size is None else crop_size[0],
+ None if crop_size is None else crop_size[1],
+ 1])
+ original_image = tf.to_float(tf.zeros_like(label))
+ if crop_size is None:
+ height = tf.shape(image)[1]
+ width = tf.shape(image)[2]
+ else:
+ height = crop_size[0]
+ width = crop_size[1]
+
+ sample = {'image': image,
+ 'image_name': image_name,
+ 'height': height,
+ 'width': width,
+ 'video_id': video_id}
+ if label is not None:
+ sample['label'] = label
+
+ if object_label is not None:
+ sample['object_label'] = object_label
+
+ if preceding_frame_label is not None:
+ sample['preceding_frame_label'] = preceding_frame_label
+
+ if not is_training:
+ # Original image is only used during visualization.
+ sample['original_image'] = original_image
+
+ if is_training:
+ if first_frame_finetuning:
+ keep_input = tf.constant(True)
+ else:
+ keep_input = tf.logical_and(sampling_is_valid, tf.logical_and(
+ _has_enough_pixels_of_each_object_in_first_frame(
+ label, decoder_output_stride),
+ _has_foreground_and_background_in_first_frame_2(
+ label, decoder_output_stride)))
+
+ batched = tf.train.maybe_batch(sample,
+ keep_input=keep_input,
+ batch_size=batch_size,
+ num_threads=num_threads,
+ capacity=batch_capacity_factor * batch_size,
+ dynamic_pad=True)
+ else:
+ batched = tf.train.batch(sample,
+ batch_size=batch_size,
+ num_threads=num_threads,
+ capacity=batch_capacity_factor * batch_size,
+ dynamic_pad=True)
+
+ # Flatten from [batch, num_frames_per_video, ...] to
+  # [batch * num_frames_per_video, ...].
+ cropped_height = train_utils.resolve_shape(batched['image'])[2]
+ cropped_width = train_utils.resolve_shape(batched['image'])[3]
+ if num_frames_per_video is None:
+ first_dim = -1
+ else:
+ first_dim = batch_size * num_frames_per_video
+ batched['image'] = tf.reshape(batched['image'],
+ [first_dim, cropped_height, cropped_width, 3])
+ if label is not None:
+ batched['label'] = tf.reshape(batched['label'],
+ [first_dim, cropped_height, cropped_width, 1])
+ return batched
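A hedged sketch of how this `get` wrapper might be called for training. The dataset name and `data_type` mirror vis_video.py below; the crop size, scale range, `model_variant` and `decoder_output_stride` are illustrative placeholders, not defaults defined in this file.

```python
from feelvos.datasets import video_dataset
from feelvos.utils import video_input_generator

# Hypothetical training-time input pipeline.
dataset = video_dataset.get_dataset(
    'davis_2016', 'train', dataset_dir='/tmp/davis_tfrecords',
    data_type='tf_sequence_example')

samples = video_input_generator.get(
    dataset,
    num_frames_per_video=3,
    crop_size=[465, 465],           # placeholder crop size
    batch_size=1,
    min_scale_factor=0.5,
    max_scale_factor=2.0,
    scale_factor_step_size=0.25,
    dataset_split='train',
    is_training=True,
    model_variant='xception_65',    # placeholder model variant
    decoder_output_stride=4)        # placeholder stride
# samples['image'] is flattened to [batch_size * num_frames_per_video, h, w, 3].
```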
diff --git a/models/research/feelvos/vis_video.py b/models/research/feelvos/vis_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..211bccf52acdef83aca298285fc473748126de02
--- /dev/null
+++ b/models/research/feelvos/vis_video.py
@@ -0,0 +1,500 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Segmentation results evaluation and visualization for videos using attention.
+"""
+
+import math
+import os
+import time
+import numpy as np
+
+import tensorflow as tf
+
+from feelvos import common
+from feelvos import model
+from feelvos.datasets import video_dataset
+from feelvos.utils import embedding_utils
+from feelvos.utils import eval_utils
+from feelvos.utils import video_input_generator
+
+
+slim = tf.contrib.slim
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+flags.DEFINE_integer('eval_interval_secs', 60 * 5,
+ 'How often (in seconds) to run evaluation.')
+
+flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')
+
+flags.DEFINE_integer('vis_batch_size', 1,
+ 'The number of images in each batch during evaluation.')
+
+flags.DEFINE_string('vis_logdir', None, 'Where to write the event logs.')
+
+flags.DEFINE_string('checkpoint_dir', None, 'Directory of model checkpoints.')
+
+flags.DEFINE_integer('output_stride', 8,
+ 'The ratio of input to output spatial resolution.')
+
+flags.DEFINE_string('dataset', 'davis_2016',
+ 'Name of the segmentation dataset.')
+
+flags.DEFINE_string('vis_split', 'val',
+                    'Which split of the dataset is used for visualizing results.')
+
+flags.DEFINE_string(
+ 'dataset_dir',
+ '/cns/is-d/home/lcchen/data/pascal_voc_seg/example_sstables',
+ 'Where the dataset resides.')
+
+flags.DEFINE_integer('num_vis_examples', -1,
+ 'Number of examples for visualization. If -1, use all '
+ 'samples in the vis data.')
+
+flags.DEFINE_multi_integer('atrous_rates', None,
+ 'Atrous rates for atrous spatial pyramid pooling.')
+
+flags.DEFINE_bool('save_segmentations', False, 'Whether to save the '
+ 'segmentation masks as '
+ 'png images. Might be slow '
+ 'on cns.')
+
+flags.DEFINE_bool('save_embeddings', False, 'Whether to save the embeddings '
+                  'as pickle. Might be slow on cns.')
+
+flags.DEFINE_bool('eval_once_and_quit', False,
+ 'Whether to just run the eval a single time and quit '
+ 'afterwards. Otherwise, the eval is run in a loop with '
+ 'new checkpoints.')
+
+flags.DEFINE_boolean('first_frame_finetuning', False,
+ 'Whether to only sample the first frame for fine-tuning.')
+
+# The folders where segmentation results and embeddings are saved.
+_SEGMENTATION_SAVE_FOLDER = 'segmentation'
+_EMBEDDINGS_SAVE_FOLDER = 'embeddings'
+
+
+def _process_seq_data(segmentation_dir, embeddings_dir, seq_name,
+ predicted_labels, gt_labels, embeddings):
+ """Calculates the sequence IoU and optionally save the segmentation masks.
+
+ Args:
+ segmentation_dir: Directory in which the segmentation results are stored.
+ embeddings_dir: Directory in which the embeddings are stored.
+ seq_name: String, the name of the sequence.
+ predicted_labels: Int64 np.array of shape [n_frames, height, width].
+ gt_labels: Ground truth labels, Int64 np.array of shape
+ [n_frames, height, width].
+ embeddings: Float32 np.array of embeddings of shape
+ [n_frames, decoder_height, decoder_width, embedding_dim], or None.
+
+ Returns:
+ The IoU for the sequence (float).
+ """
+ sequence_dir = os.path.join(segmentation_dir, seq_name)
+ tf.gfile.MakeDirs(sequence_dir)
+ embeddings_seq_dir = os.path.join(embeddings_dir, seq_name)
+ tf.gfile.MakeDirs(embeddings_seq_dir)
+ label_set = np.unique(gt_labels[0])
+ ious = []
+ assert len(predicted_labels) == len(gt_labels)
+ if embeddings is not None:
+ assert len(predicted_labels) == len(embeddings)
+ for t, (predicted_label, gt_label) in enumerate(
+ zip(predicted_labels, gt_labels)):
+ if FLAGS.save_segmentations:
+ seg_filename = os.path.join(segmentation_dir, seq_name, '%05d.png' % t)
+ eval_utils.save_segmentation_with_colormap(seg_filename, predicted_label)
+ if FLAGS.save_embeddings:
+ embedding_filename = os.path.join(embeddings_dir, seq_name,
+ '%05d.npy' % t)
+ assert embeddings is not None
+ eval_utils.save_embeddings(embedding_filename, embeddings[t])
+ object_ious_t = eval_utils.calculate_multi_object_ious(
+ predicted_label, gt_label, label_set)
+ ious.append(object_ious_t)
+ # First and last frame are excluded in DAVIS eval.
+ seq_ious = np.mean(ious[1:-1], axis=0)
+ tf.logging.info('seq ious: %s %s', seq_name, seq_ious)
+ return seq_ious
+
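A tiny NumPy illustration (made-up numbers) of the DAVIS-style sequence score computed at the end of `_process_seq_data`: the per-frame, per-object IoUs are averaged over all frames except the first and the last.

```python
import numpy as np

# Per-frame, per-object IoUs for a 5-frame sequence with 2 objects (made up).
ious = np.array([[1.00, 1.00],   # frame 0: reference frame
                 [0.80, 0.70],
                 [0.75, 0.65],
                 [0.70, 0.60],
                 [0.40, 0.30]])  # last frame
seq_ious = np.mean(ious[1:-1], axis=0)  # first and last frame excluded
print(seq_ious)  # [0.75 0.65]
```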
+
+def create_predictions(samples, reference_labels, first_frame_img,
+ model_options):
+ """Predicts segmentation labels for each frame of the video.
+
+ Slower version than create_predictions_fast, but does support more options.
+
+ Args:
+ samples: Dictionary of input samples.
+ reference_labels: Int tensor of shape [1, height, width, 1].
+ first_frame_img: Float32 tensor of shape [height, width, 3].
+ model_options: An InternalModelOptions instance to configure models.
+
+ Returns:
+ predicted_labels: Int tensor of shape [time, height, width] of
+ predicted labels for each frame.
+ all_embeddings: Float32 tensor of shape
+ [time, height, width, embedding_dim], or None.
+ """
+
+ def predict(args, imgs):
+ """Predicts segmentation labels and softmax probabilities for each image.
+
+ Args:
+ args: A tuple of (predictions, softmax_probabilities), where predictions
+ is an int tensor of shape [1, h, w] and softmax_probabilities is a
+ float32 tensor of shape [1, h_decoder, w_decoder, n_objects].
+ imgs: Either a one-tuple of the image to predict labels for of shape
+ [h, w, 3], or pair of previous frame and current frame image.
+
+ Returns:
+ predictions: The predicted labels as int tensor of shape [1, h, w].
+ softmax_probabilities: The softmax probabilities of shape
+ [1, h_decoder, w_decoder, n_objects].
+ """
+ if FLAGS.save_embeddings:
+ last_frame_predictions, last_softmax_probabilities, _ = args
+ else:
+ last_frame_predictions, last_softmax_probabilities = args
+
+ if FLAGS.also_attend_to_previous_frame or FLAGS.use_softmax_feedback:
+ ref_labels_to_use = tf.concat(
+ [reference_labels, last_frame_predictions[..., tf.newaxis]],
+ axis=0)
+ else:
+ ref_labels_to_use = reference_labels
+
+ predictions, softmax_probabilities = model.predict_labels(
+ tf.stack((first_frame_img,) + imgs),
+ model_options=model_options,
+ image_pyramid=FLAGS.image_pyramid,
+ embedding_dimension=FLAGS.embedding_dimension,
+ reference_labels=ref_labels_to_use,
+ k_nearest_neighbors=FLAGS.k_nearest_neighbors,
+ use_softmax_feedback=FLAGS.use_softmax_feedback,
+ initial_softmax_feedback=last_softmax_probabilities,
+ embedding_seg_feature_dimension=
+ FLAGS.embedding_seg_feature_dimension,
+ embedding_seg_n_layers=FLAGS.embedding_seg_n_layers,
+ embedding_seg_kernel_size=FLAGS.embedding_seg_kernel_size,
+ embedding_seg_atrous_rates=FLAGS.embedding_seg_atrous_rates,
+ also_return_softmax_probabilities=True,
+ num_frames_per_video=
+ (3 if FLAGS.also_attend_to_previous_frame or
+ FLAGS.use_softmax_feedback else 2),
+ normalize_nearest_neighbor_distances=
+ FLAGS.normalize_nearest_neighbor_distances,
+ also_attend_to_previous_frame=FLAGS.also_attend_to_previous_frame,
+ use_local_previous_frame_attention=
+ FLAGS.use_local_previous_frame_attention,
+ previous_frame_attention_window_size=
+ FLAGS.previous_frame_attention_window_size,
+ use_first_frame_matching=FLAGS.use_first_frame_matching
+ )
+ predictions = tf.cast(predictions[common.OUTPUT_TYPE], tf.int32)
+
+ if FLAGS.save_embeddings:
+ names = [n.name for n in tf.get_default_graph().as_graph_def().node]
+ embedding_names = [x for x in names if 'embeddings' in x]
+ # This will crash when multi-scale inference is used.
+ assert len(embedding_names) == 1, len(embedding_names)
+ embedding_name = embedding_names[0] + ':0'
+ embeddings = tf.get_default_graph().get_tensor_by_name(embedding_name)
+ return predictions, softmax_probabilities, embeddings
+ else:
+ return predictions, softmax_probabilities
+
+ init_labels = tf.squeeze(reference_labels, axis=-1)
+ init_softmax = embedding_utils.create_initial_softmax_from_labels(
+ reference_labels, reference_labels, common.parse_decoder_output_stride(),
+ reduce_labels=False)
+ if FLAGS.save_embeddings:
+ decoder_height = tf.shape(init_softmax)[1]
+ decoder_width = tf.shape(init_softmax)[2]
+ n_frames = (3 if FLAGS.also_attend_to_previous_frame
+ or FLAGS.use_softmax_feedback else 2)
+ embeddings_init = tf.zeros((n_frames, decoder_height, decoder_width,
+ FLAGS.embedding_dimension))
+ init = (init_labels, init_softmax, embeddings_init)
+ else:
+ init = (init_labels, init_softmax)
+ # Do not eval the first frame again but concat the first frame ground
+ # truth instead.
+ if FLAGS.also_attend_to_previous_frame or FLAGS.use_softmax_feedback:
+ elems = (samples[common.IMAGE][:-1], samples[common.IMAGE][1:])
+ else:
+ elems = (samples[common.IMAGE][1:],)
+ res = tf.scan(predict, elems,
+ initializer=init,
+ parallel_iterations=1,
+ swap_memory=True)
+ if FLAGS.save_embeddings:
+ predicted_labels, _, all_embeddings = res
+ first_frame_embeddings = all_embeddings[0, 0, tf.newaxis]
+ other_frame_embeddings = all_embeddings[:, -1]
+ all_embeddings = tf.concat(
+ [first_frame_embeddings, other_frame_embeddings], axis=0)
+ else:
+ predicted_labels, _ = res
+ all_embeddings = None
+ predicted_labels = tf.concat([reference_labels[..., 0],
+ tf.squeeze(predicted_labels, axis=1)],
+ axis=0)
+ return predicted_labels, all_embeddings
+
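Both prediction functions drive the per-frame loop with `tf.scan`, threading the previous frame's outputs through as carried state. A toy example of that carried-state pattern (unrelated to segmentation, TF1 graph mode as used in this file):

```python
import tensorflow as tf

elems = tf.constant([1.0, 2.0, 3.0, 4.0])
# Carried state: (running_sum, last_element), analogous to
# (last_frame_predictions, last_softmax_probabilities) above.
init = (tf.constant(0.0), tf.constant(0.0))

def step(state, x):
    running_sum, _ = state
    return running_sum + x, x

sums, lasts = tf.scan(step, elems, initializer=init)
with tf.Session() as sess:
    s, l = sess.run([sums, lasts])
print(s)  # [ 1.  3.  6. 10.]
print(l)  # [1. 2. 3. 4.]
```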
+
+def create_predictions_fast(samples, reference_labels, first_frame_img,
+ model_options):
+ """Predicts segmentation labels for each frame of the video.
+
+ Faster version than create_predictions, but does not support all options.
+
+ Args:
+ samples: Dictionary of input samples.
+ reference_labels: Int tensor of shape [1, height, width, 1].
+ first_frame_img: Float32 tensor of shape [height, width, 3].
+ model_options: An InternalModelOptions instance to configure models.
+
+ Returns:
+ predicted_labels: Int tensor of shape [time, height, width] of
+ predicted labels for each frame.
+
+ Raises:
+ ValueError: If FLAGS.save_embeddings is True, FLAGS.use_softmax_feedback is
+ False, or FLAGS.also_attend_to_previous_frame is False.
+ """
+ if FLAGS.save_embeddings:
+ raise ValueError('save_embeddings does not work with '
+ 'create_predictions_fast. Use the slower '
+ 'create_predictions instead.')
+ if not FLAGS.use_softmax_feedback:
+ raise ValueError('use_softmax_feedback must be True for '
+ 'create_predictions_fast. Use the slower '
+ 'create_predictions instead.')
+ if not FLAGS.also_attend_to_previous_frame:
+ raise ValueError('also_attend_to_previous_frame must be True for '
+ 'create_predictions_fast. Use the slower '
+ 'create_predictions instead.')
+ # Extract embeddings for first frame and prepare initial predictions.
+ first_frame_embeddings = embedding_utils.get_embeddings(
+ first_frame_img[tf.newaxis], model_options, FLAGS.embedding_dimension)
+ init_labels = tf.squeeze(reference_labels, axis=-1)
+ init_softmax = embedding_utils.create_initial_softmax_from_labels(
+ reference_labels, reference_labels, common.parse_decoder_output_stride(),
+ reduce_labels=False)
+ init = (init_labels, init_softmax, first_frame_embeddings)
+
+ def predict(args, img):
+ """Predicts segmentation labels and softmax probabilities for each image.
+
+ Args:
+ args: tuple of
+ (predictions, softmax_probabilities, last_frame_embeddings), where
+ predictions is an int tensor of shape [1, h, w],
+ softmax_probabilities is a float32 tensor of shape
+ [1, h_decoder, w_decoder, n_objects],
+ and last_frame_embeddings is a float32 tensor of shape
+ [h_decoder, w_decoder, embedding_dimension].
+ img: Image to predict labels for of shape [h, w, 3].
+
+ Returns:
+ predictions: The predicted labels as int tensor of shape [1, h, w].
+ softmax_probabilities: The softmax probabilities of shape
+ [1, h_decoder, w_decoder, n_objects].
+ """
+ (last_frame_predictions, last_softmax_probabilities,
+ prev_frame_embeddings) = args
+ ref_labels_to_use = tf.concat(
+ [reference_labels, last_frame_predictions[..., tf.newaxis]],
+ axis=0)
+
+ predictions, softmax_probabilities, embeddings = model.predict_labels(
+ img[tf.newaxis],
+ model_options=model_options,
+ image_pyramid=FLAGS.image_pyramid,
+ embedding_dimension=FLAGS.embedding_dimension,
+ reference_labels=ref_labels_to_use,
+ k_nearest_neighbors=FLAGS.k_nearest_neighbors,
+ use_softmax_feedback=FLAGS.use_softmax_feedback,
+ initial_softmax_feedback=last_softmax_probabilities,
+ embedding_seg_feature_dimension=
+ FLAGS.embedding_seg_feature_dimension,
+ embedding_seg_n_layers=FLAGS.embedding_seg_n_layers,
+ embedding_seg_kernel_size=FLAGS.embedding_seg_kernel_size,
+ embedding_seg_atrous_rates=FLAGS.embedding_seg_atrous_rates,
+ also_return_softmax_probabilities=True,
+ num_frames_per_video=1,
+ normalize_nearest_neighbor_distances=
+ FLAGS.normalize_nearest_neighbor_distances,
+ also_attend_to_previous_frame=FLAGS.also_attend_to_previous_frame,
+ use_local_previous_frame_attention=
+ FLAGS.use_local_previous_frame_attention,
+ previous_frame_attention_window_size=
+ FLAGS.previous_frame_attention_window_size,
+ use_first_frame_matching=FLAGS.use_first_frame_matching,
+ also_return_embeddings=True,
+ ref_embeddings=(first_frame_embeddings, prev_frame_embeddings)
+ )
+ predictions = tf.cast(predictions[common.OUTPUT_TYPE], tf.int32)
+ return predictions, softmax_probabilities, embeddings
+
+ # Do not eval the first frame again but concat the first frame ground
+ # truth instead.
+ # If you have a lot of GPU memory, you can try to set swap_memory=False,
+ # and/or parallel_iterations=2.
+ elems = samples[common.IMAGE][1:]
+ res = tf.scan(predict, elems,
+ initializer=init,
+ parallel_iterations=1,
+ swap_memory=True)
+ predicted_labels, _, _ = res
+ predicted_labels = tf.concat([reference_labels[..., 0],
+ tf.squeeze(predicted_labels, axis=1)],
+ axis=0)
+ return predicted_labels
+
+
+def main(unused_argv):
+ if FLAGS.vis_batch_size != 1:
+ raise ValueError('Only batch size 1 is supported for now')
+
+ data_type = 'tf_sequence_example'
+ # Get dataset-dependent information.
+ dataset = video_dataset.get_dataset(
+ FLAGS.dataset,
+ FLAGS.vis_split,
+ dataset_dir=FLAGS.dataset_dir,
+ data_type=data_type)
+
+ # Prepare for visualization.
+ tf.gfile.MakeDirs(FLAGS.vis_logdir)
+ segmentation_dir = os.path.join(FLAGS.vis_logdir, _SEGMENTATION_SAVE_FOLDER)
+ tf.gfile.MakeDirs(segmentation_dir)
+ embeddings_dir = os.path.join(FLAGS.vis_logdir, _EMBEDDINGS_SAVE_FOLDER)
+ tf.gfile.MakeDirs(embeddings_dir)
+ num_vis_examples = (dataset.num_videos if (FLAGS.num_vis_examples < 0)
+ else FLAGS.num_vis_examples)
+ if FLAGS.first_frame_finetuning:
+ num_vis_examples = 1
+
+ tf.logging.info('Visualizing on %s set', FLAGS.vis_split)
+ g = tf.Graph()
+ with g.as_default():
+ # Without setting device to CPU we run out of memory.
+ with tf.device('cpu:0'):
+ samples = video_input_generator.get(
+ dataset,
+ None,
+ None,
+ FLAGS.vis_batch_size,
+ min_resize_value=FLAGS.min_resize_value,
+ max_resize_value=FLAGS.max_resize_value,
+ resize_factor=FLAGS.resize_factor,
+ dataset_split=FLAGS.vis_split,
+ is_training=False,
+ model_variant=FLAGS.model_variant,
+ preprocess_image_and_label=False,
+ remap_labels_to_reference_frame=False)
+ samples[common.IMAGE] = tf.cast(samples[common.IMAGE], tf.float32)
+ samples[common.LABEL] = tf.cast(samples[common.LABEL], tf.int32)
+ first_frame_img = samples[common.IMAGE][0]
+ reference_labels = samples[common.LABEL][0, tf.newaxis]
+ gt_labels = tf.squeeze(samples[common.LABEL], axis=-1)
+ seq_name = samples[common.VIDEO_ID][0]
+
+ model_options = common.VideoModelOptions(
+ outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_classes},
+ crop_size=None,
+ atrous_rates=FLAGS.atrous_rates,
+ output_stride=FLAGS.output_stride)
+
+ all_embeddings = None
+ predicted_labels = create_predictions_fast(
+ samples, reference_labels, first_frame_img, model_options)
+ # If you need more options like saving embeddings, replace the call to
+ # create_predictions_fast with create_predictions.
+
+ tf.train.get_or_create_global_step()
+ saver = tf.train.Saver(slim.get_variables_to_restore())
+ sv = tf.train.Supervisor(graph=g,
+ logdir=FLAGS.vis_logdir,
+ init_op=tf.global_variables_initializer(),
+ summary_op=None,
+ summary_writer=None,
+ global_step=None,
+ saver=saver)
+ num_batches = int(
+ math.ceil(num_vis_examples / float(FLAGS.vis_batch_size)))
+ last_checkpoint = None
+
+  # Infinite loop to visualize the results whenever a new checkpoint is created.
+ while True:
+ last_checkpoint = slim.evaluation.wait_for_new_checkpoint(
+ FLAGS.checkpoint_dir, last_checkpoint)
+ start = time.time()
+ tf.logging.info(
+ 'Starting visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
+ time.gmtime()))
+ tf.logging.info('Visualizing with model %s', last_checkpoint)
+
+ all_ious = []
+ with sv.managed_session(FLAGS.master,
+ start_standard_services=False) as sess:
+ sv.start_queue_runners(sess)
+ sv.saver.restore(sess, last_checkpoint)
+
+ for batch in range(num_batches):
+ ops = [predicted_labels, gt_labels, seq_name]
+ if FLAGS.save_embeddings:
+ ops.append(all_embeddings)
+ tf.logging.info('Visualizing batch %d / %d', batch + 1, num_batches)
+ res = sess.run(ops)
+ tf.logging.info('Forwarding done')
+ pred_labels_val, gt_labels_val, seq_name_val = res[:3]
+ if FLAGS.save_embeddings:
+ all_embeddings_val = res[3]
+ else:
+ all_embeddings_val = None
+ seq_ious = _process_seq_data(segmentation_dir, embeddings_dir,
+ seq_name_val, pred_labels_val,
+ gt_labels_val, all_embeddings_val)
+ all_ious.append(seq_ious)
+ all_ious = np.concatenate(all_ious, axis=0)
+ tf.logging.info('n_seqs %s, mIoU %f', all_ious.shape, all_ious.mean())
+ tf.logging.info(
+ 'Finished visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
+ time.gmtime()))
+ result_dir = FLAGS.vis_logdir + '/results/'
+ tf.gfile.MakeDirs(result_dir)
+ with tf.gfile.GFile(result_dir + seq_name_val + '.txt', 'w') as f:
+ f.write(str(all_ious))
+ if FLAGS.first_frame_finetuning or FLAGS.eval_once_and_quit:
+ break
+ time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
+ if time_to_next_eval > 0:
+ time.sleep(time_to_next_eval)
+
+
+if __name__ == '__main__':
+ flags.mark_flag_as_required('checkpoint_dir')
+ flags.mark_flag_as_required('vis_logdir')
+ tf.logging.set_verbosity(tf.logging.INFO)
+ tf.app.run()
diff --git a/models/research/fivo/.gitattributes b/models/research/fivo/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..f706c0421d718f8af8e62d96d69101fe383d2b4f
--- /dev/null
+++ b/models/research/fivo/.gitattributes
@@ -0,0 +1,2 @@
+*.pkl binary
+*.tfrecord binary
diff --git a/models/research/fivo/.gitignore b/models/research/fivo/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..af2f537516daf33fdaf579436dfa33fdd9044f49
--- /dev/null
+++ b/models/research/fivo/.gitignore
@@ -0,0 +1,104 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+.static_storage/
+.media/
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
diff --git a/models/research/fivo/README.md b/models/research/fivo/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..36d355b1b2961f2c8c8b721b5ce13c0c3eab1e8b
--- /dev/null
+++ b/models/research/fivo/README.md
@@ -0,0 +1,215 @@
+
+
+
+
+# Filtering Variational Objectives
+
+This folder contains a TensorFlow implementation of the algorithms from
+
+Chris J. Maddison\*, Dieterich Lawson\*, George Tucker\*, Nicolas Heess, Mohammad Norouzi, Andriy Mnih, Arnaud Doucet, and Yee Whye Teh. "Filtering Variational Objectives." NIPS 2017.
+
+[https://arxiv.org/abs/1705.09279](https://arxiv.org/abs/1705.09279)
+
+This code implements 3 different bounds for training sequential latent variable models: the evidence lower bound (ELBO), the importance weighted auto-encoder bound (IWAE), and our bound, the filtering variational objective (FIVO).
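+
+As a rough orientation (this snippet is not part of the codebase), all three bounds are Monte Carlo estimates built from log importance weights; for a single sequence with K weighted samples, an IWAE-style estimate can be computed with a numerically stable log-sum-exp:
+
+```
+import numpy as np
+
+def iwae_estimate(log_weights):
+    # log_weights: array of K log importance weights for one sequence.
+    k = log_weights.shape[0]
+    m = log_weights.max()
+    # log((1/K) * sum_k exp(log_w_k)), computed stably by shifting by the max.
+    return m + np.log(np.exp(log_weights - m).sum()) - np.log(k)
+```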
+
+Additionally, it contains several sequential latent variable model implementations:
+
+* Variational recurrent neural network (VRNN)
+* Stochastic recurrent neural network (SRNN)
+* Gaussian hidden Markov model with linear conditionals (GHMM)
+
+The VRNN and SRNN can be trained for sequence modeling of pianoroll and speech data. The GHMM is trainable on a synthetic dataset, useful as a simple example of an analytically tractable model.
+
+#### Directory Structure
+The important parts of the code are organized as follows.
+
+```
+run_fivo.py # main script, contains flag definitions
+fivo
+├─smc.py # a sequential Monte Carlo implementation
+├─bounds.py # code for computing each bound, uses smc.py
+├─runners.py # code for VRNN and SRNN training and evaluation
+├─ghmm_runners.py # code for GHMM training and evaluation
+├─data
+| ├─datasets.py # readers for pianoroll and speech datasets
+| ├─calculate_pianoroll_mean.py # preprocesses the pianoroll datasets
+| └─create_timit_dataset.py # preprocesses the TIMIT dataset
+└─models
+ ├─base.py # base classes used in other models
+ ├─vrnn.py # VRNN implementation
+ ├─srnn.py # SRNN implementation
+ └─ghmm.py # Gaussian hidden Markov model (GHMM) implementation
+bin
+├─run_train.sh # an example script that runs training
+├─run_eval.sh # an example script that runs evaluation
+├─run_sample.sh # an example script that runs sampling
+├─run_tests.sh # a script that runs all tests
+└─download_pianorolls.sh # a script that downloads pianoroll files
+```
+
+### Pianorolls
+
+Requirements before we start:
+
+* TensorFlow (see [tensorflow.org](http://tensorflow.org) for how to install)
+* [scipy](https://www.scipy.org/)
+* [sonnet](https://github.com/deepmind/sonnet)
+
+
+#### Download the Data
+
+The pianoroll datasets are encoded as pickled sparse arrays and are available at [http://www-etud.iro.umontreal.ca/~boulanni/icml2012](http://www-etud.iro.umontreal.ca/~boulanni/icml2012). You can use the script `bin/download_pianorolls.sh` to download the files into a directory of your choosing.
+```
+export PIANOROLL_DIR=~/pianorolls
+mkdir $PIANOROLL_DIR
+sh bin/download_pianorolls.sh $PIANOROLL_DIR
+```
+
+#### Preprocess the Data
+
+The script `calculate_pianoroll_mean.py` loads a pianoroll pickle file, calculates the mean, updates the pickle file to include the mean under the key `train_mean`, and writes the file back to disk in-place. You should do this for all pianoroll datasets you wish to train on.
+
+```
+python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/piano-midi.de.pkl
+python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/nottingham.pkl
+python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/musedata.pkl
+python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/jsb.pkl
+```
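+
+For reference, the in-place update amounts to something like the sketch below. The exact layout of the pickle is an assumption here (`calculate_pianoroll_mean.py` is the authoritative implementation); the key point is that the computed mean is stored back into the same file under the key `train_mean`.
+
+```
+import pickle
+import numpy as np
+
+path = "piano-midi.de.pkl"  # any of the pianoroll pickles above
+with open(path, "rb") as f:
+    data = pickle.load(f)
+# Assumption: data["train"] holds the training split as sequences of pianoroll frames.
+frames = np.concatenate(
+    [np.asarray(seq, dtype=np.float64) for seq in data["train"]], axis=0)
+data["train_mean"] = frames.mean(axis=0)
+with open(path, "wb") as f:
+    pickle.dump(data, f)
+```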
+
+#### Training
+
+Now we can train a model. Here is the command for a standard training run, taken from `bin/run_train.sh`:
+```
+python run_fivo.py \
+ --mode=train \
+ --logdir=/tmp/fivo \
+ --model=vrnn \
+ --bound=fivo \
+ --summarize_every=100 \
+ --batch_size=4 \
+ --num_samples=4 \
+ --learning_rate=0.0001 \
+ --dataset_path="$PIANOROLL_DIR/jsb.pkl" \
+ --dataset_type="pianoroll"
+```
+
+You should see output that looks something like this (with extra logging cruft):
+
+```
+Saving checkpoints for 0 into /tmp/fivo/model.ckpt.
+Step 1, fivo bound per timestep: -11.322491
+global_step/sec: 7.49971
+Step 101, fivo bound per timestep: -11.399275
+global_step/sec: 8.04498
+Step 201, fivo bound per timestep: -11.174991
+global_step/sec: 8.03989
+Step 301, fivo bound per timestep: -11.073008
+```
+#### Evaluation
+
+You can also evaluate saved checkpoints. The `eval` mode loads a model checkpoint, tests its performance on all items in a dataset, and reports the log-likelihood averaged over the dataset. For example, here is a command, taken from `bin/run_eval.sh`, that will evaluate a JSB model on the test set:
+
+```
+python run_fivo.py \
+ --mode=eval \
+ --split=test \
+ --alsologtostderr \
+ --logdir=/tmp/fivo \
+ --model=vrnn \
+ --batch_size=4 \
+ --num_samples=4 \
+ --dataset_path="$PIANOROLL_DIR/jsb.pkl" \
+ --dataset_type="pianoroll"
+```
+
+You should see output like this:
+```
+Restoring parameters from /tmp/fivo/model.ckpt-0
+Model restored from step 0, evaluating.
+test elbo ll/t: -12.198834, iwae ll/t: -11.981187 fivo ll/t: -11.579776
+test elbo ll/seq: -748.564789, iwae ll/seq: -735.209206 fivo ll/seq: -710.577141
+```
+The evaluation script prints log-likelihood in both nats per timestep (ll/t) and nats per sequence (ll/seq) for all three bounds.
+
+#### Sampling
+
+You can also sample from trained models. The `sample` mode loads a model checkpoint, conditions the model on a prefix of a randomly chosen datapoint, samples a sequence of outputs from the conditioned model, and writes out the samples and prefix to a `.npz` file in `logdir`. For example, here is a command that samples from a model trained on JSB, taken from `bin/run_sample.sh`:
+```
+python run_fivo.py \
+ --mode=sample \
+ --alsologtostderr \
+ --logdir="/tmp/fivo" \
+ --model=vrnn \
+ --bound=fivo \
+ --batch_size=4 \
+ --num_samples=4 \
+ --split=test \
+ --dataset_path="$PIANOROLL_DIR/jsb.pkl" \
+ --dataset_type="pianoroll" \
+ --prefix_length=25 \
+ --sample_length=50
+```
+
+Here `num_samples` denotes the number of samples used when conditioning the model as well as the number of trajectories to sample for each prefix.
+
+You should see very little output.
+```
+Restoring parameters from /tmp/fivo/model.ckpt-0
+Running local_init_op.
+Done running local_init_op.
+```
+
+Loading the samples with `np.load` confirms that we conditioned the model on 4
+prefixes of length 25 and sampled 4 sequences of length 50 for each prefix.
+```
+>>> import numpy as np
+>>> x = np.load("/tmp/fivo/samples.npz")
+>>> x[()]['prefixes'].shape
+(25, 4, 88)
+>>> x[()]['samples'].shape
+(50, 4, 4, 88)
+```
+
+### Training on TIMIT
+
+The TIMIT speech dataset is available at the [Linguistic Data Consortium website](https://catalog.ldc.upenn.edu/LDC93S1), but is unfortunately not free. These instructions will proceed assuming you have downloaded the TIMIT archive and extracted it into the directory `$RAW_TIMIT_DIR`.
+
+#### Preprocess TIMIT
+
+We preprocess TIMIT (as described in our paper) and write it out to a series of TFRecord files. To prepare the TIMIT dataset, use the script `create_timit_dataset.py`:
+```
+export TIMIT_DIR=~/timit_dataset
+mkdir $TIMIT_DIR
+python data/create_timit_dataset.py \
+ --raw_timit_dir=$RAW_TIMIT_DIR \
+ --out_dir=$TIMIT_DIR
+```
+You should see this exact output:
+```
+4389 train / 231 valid / 1680 test
+train mean: 0.006060 train std: 548.136169
+```
+
+#### Training on TIMIT
+This is very similar to training on pianoroll datasets, with just a few flags switched.
+```
+python run_fivo.py \
+ --mode=train \
+ --logdir=/tmp/fivo \
+ --model=vrnn \
+ --bound=fivo \
+ --summarize_every=100 \
+ --batch_size=4 \
+ --num_samples=4 \
+ --learning_rate=0.0001 \
+ --dataset_path="$TIMIT_DIR/train" \
+ --dataset_type="speech"
+```
+Evaluation and sampling are similar.
+
+### Tests
+This codebase comes with a number of tests to verify correctness, runnable via `bin/run_tests.sh`. The tests are also useful to look at for examples of how to use the code.
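+
+For example, from the directory containing this README:
+
+```
+sh bin/run_tests.sh
+```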
+
+### Contact
+
+This codebase is maintained by Dieterich Lawson. For questions and issues please open an issue on the tensorflow/models issues tracker and assign it to @dieterichlawson.
diff --git a/models/research/fivo/bin/download_pianorolls.sh b/models/research/fivo/bin/download_pianorolls.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ef7050b4df5fb9815be04d133e659fa31d8d055e
--- /dev/null
+++ b/models/research/fivo/bin/download_pianorolls.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# A script to download the pianoroll datasets.
+# Accepts one argument, the directory to put the files in
+
+if [ -z "$1" ]
+ then
+ echo "Error, must provide a directory to download the files to."
+  exit 1
+fi
+
+echo "Downloading datasets into $1"
+curl -s "http://www-etud.iro.umontreal.ca/~boulanni/Piano-midi.de.pickle" > $1/piano-midi.de.pkl
+curl -s "http://www-etud.iro.umontreal.ca/~boulanni/Nottingham.pickle" > $1/nottingham.pkl
+curl -s "http://www-etud.iro.umontreal.ca/~boulanni/MuseData.pickle" > $1/musedata.pkl
+curl -s "http://www-etud.iro.umontreal.ca/~boulanni/JSB%20Chorales.pickle" > $1/jsb.pkl
diff --git a/models/research/fivo/bin/run_eval.sh b/models/research/fivo/bin/run_eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b30bcedc2d16e5bdd681386100ecca23612a139a
--- /dev/null
+++ b/models/research/fivo/bin/run_eval.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# An example of running evaluation.
+
+PIANOROLL_DIR=$HOME/pianorolls
+
+python run_fivo.py \
+ --mode=eval \
+ --logdir=/tmp/fivo \
+ --model=vrnn \
+ --batch_size=4 \
+ --num_samples=4 \
+ --split=test \
+ --dataset_path="$PIANOROLL_DIR/jsb.pkl" \
+ --dataset_type="pianoroll"
diff --git a/models/research/fivo/bin/run_sample.sh b/models/research/fivo/bin/run_sample.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e0c82a0cb137822e85035a23081ecf6408b7cca1
--- /dev/null
+++ b/models/research/fivo/bin/run_sample.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# An example of sampling from the model.
+
+PIANOROLL_DIR=$HOME/pianorolls
+
+python run_fivo.py \
+ --mode=sample \
+ --alsologtostderr \
+ --logdir="/tmp/fivo" \
+ --model=vrnn \
+ --bound=fivo \
+ --batch_size=4 \
+ --num_samples=4 \
+ --split=test \
+ --dataset_path="$PIANOROLL_DIR/jsb.pkl" \
+ --dataset_type="pianoroll" \
+ --prefix_length=25 \
+ --sample_length=50
diff --git a/models/research/fivo/bin/run_tests.sh b/models/research/fivo/bin/run_tests.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2ea58f016620db98e258494919c6d339b5fd996e
--- /dev/null
+++ b/models/research/fivo/bin/run_tests.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+python -m fivo.smc_test && \
+python -m fivo.bounds_test && \
+python -m fivo.nested_utils_test && \
+python -m fivo.data.datasets_test && \
+python -m fivo.models.ghmm_test && \
+python -m fivo.models.vrnn_test && \
+python -m fivo.models.srnn_test && \
+python -m fivo.ghmm_runners_test && \
+python -m fivo.runners_test
diff --git a/models/research/fivo/bin/run_train.sh b/models/research/fivo/bin/run_train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a845959770c77cd99528005e1ee69e4593fcae0c
--- /dev/null
+++ b/models/research/fivo/bin/run_train.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# An example of running training.
+
+PIANOROLL_DIR=$HOME/pianorolls
+
+python run_fivo.py \
+ --mode=train \
+ --logdir=/tmp/fivo \
+ --model=vrnn \
+ --bound=fivo \
+ --summarize_every=100 \
+ --batch_size=4 \
+ --num_samples=4 \
+ --learning_rate=0.0001 \
+ --dataset_path="$PIANOROLL_DIR/jsb.pkl" \
+ --dataset_type="pianoroll"
diff --git a/models/research/fivo/experimental/README.md b/models/research/fivo/experimental/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..649de0ba95cdee2fa1b101a588dc48903b2ca13b
--- /dev/null
+++ b/models/research/fivo/experimental/README.md
@@ -0,0 +1 @@
+An experimental codebase for running simple examples.
diff --git a/models/research/fivo/experimental/bounds.py b/models/research/fivo/experimental/bounds.py
new file mode 100644
index 0000000000000000000000000000000000000000..afc970c59a1a86dbe8438b4e8bba791d3c95aa63
--- /dev/null
+++ b/models/research/fivo/experimental/bounds.py
@@ -0,0 +1,673 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import namedtuple
+
+import tensorflow as tf
+import summary_utils as summ
+
+Loss = namedtuple("Loss", "name loss vars")
+Loss.__new__.__defaults__ = (tf.GraphKeys.TRAINABLE_VARIABLES,)
+
+
+def iwae(model, observation, num_timesteps, num_samples=1,
+ summarize=False):
+ """Compute the IWAE evidence lower bound.
+
+ Args:
+ model: A callable that computes one timestep of the model.
+ observation: A shape [batch_size*num_samples, state_size] Tensor
+ containing z_n, the observation for each sequence in the batch.
+ num_timesteps: The number of timesteps in each sequence, an integer.
+    num_samples: The number of samples to use to compute the IWAE bound.
+    summarize: If True, add TensorBoard summaries of the per-timestep weight
+      entropy.
+  Returns:
+    log_p_hat: The IWAE estimator of the lower bound on the log marginal.
+    losses: A list of Loss tuples whose loss terms can be minimized to
+      optimize the bound.
+    maintain_ema_op: A no-op included for compatibility with FIVO.
+    states: The sequence of states sampled.
+    log_weights: The accumulated log weights at each timestep.
+  """
+ # Initialization
+ num_instances = tf.shape(observation)[0]
+ batch_size = tf.cast(num_instances / num_samples, tf.int32)
+ states = [model.zero_state(num_instances)]
+ log_weights = []
+ log_weight_acc = tf.zeros([num_samples, batch_size], dtype=observation.dtype)
+
+ for t in xrange(num_timesteps):
+ # run the model for one timestep
+ (zt, log_q_zt, log_p_zt, log_p_x_given_z, _) = model(
+ states[-1], observation, t)
+ # update accumulators
+ states.append(zt)
+ log_weight = log_p_zt + log_p_x_given_z - log_q_zt
+ log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
+ if summarize:
+ weight_dist = tf.contrib.distributions.Categorical(
+ logits=tf.transpose(log_weight_acc, perm=[1, 0]),
+ allow_nan_stats=False)
+ weight_entropy = weight_dist.entropy()
+ weight_entropy = tf.reduce_mean(weight_entropy)
+ tf.summary.scalar("weight_entropy/%d" % t, weight_entropy)
+ log_weights.append(log_weight_acc)
+ # Compute the lower bound on the log evidence.
+ log_p_hat = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
+ tf.log(tf.cast(num_samples, observation.dtype))) / num_timesteps
+ loss = -tf.reduce_mean(log_p_hat)
+ losses = [Loss("log_p_hat", loss)]
+
+ # we clip off the initial state before returning.
+ # there are no emas for iwae, so we return a noop for that
+ return log_p_hat, losses, tf.no_op(), states[1:], log_weights
+
+
+def multinomial_resampling(log_weights, states, n, b):
+ """Resample states with multinomial resampling.
+
+ Args:
+ log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
+ Categorical distribution.
+    states: A list of (b*n x d) Tensors that will be resampled from the groups
+      formed by every n-th row.
+    n: The number of samples (particles) per batch element, an integer.
+    b: The batch size, an integer.
+
+  Returns:
+    resampled_states: A list of (b*n x d) Tensors resampled via multinomial resampling.
+ log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
+ resampling_parameters: The Tensor of parameters of the resampling distribution.
+ ancestors: An (n x b) Tensor of integral indices representing the ancestry decisions.
+ resampling_dist: The distribution object for resampling.
+ """
+ log_weights = tf.convert_to_tensor(log_weights)
+ states = [tf.convert_to_tensor(state) for state in states]
+
+ resampling_parameters = tf.transpose(log_weights, perm=[1,0])
+ resampling_dist = tf.contrib.distributions.Categorical(logits=resampling_parameters)
+ ancestors = tf.stop_gradient(
+ resampling_dist.sample(sample_shape=n))
+ log_probs = resampling_dist.log_prob(ancestors)
+
+ offset = tf.expand_dims(tf.range(b), 0)
+ ancestor_inds = tf.reshape(ancestors * b + offset, [-1])
+
+ resampled_states = []
+ for state in states:
+ resampled_states.append(tf.gather(state, ancestor_inds))
+ return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist
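+
+# Illustrative usage (sketch only): given accumulated log-weights of shape
+# [num_samples, batch_size] and a list of particle state Tensors of shape
+# [num_samples * batch_size, state_dim]:
+#   resampled, log_probs, _, _, _ = multinomial_resampling(
+#       log_weight_acc, [current_state], num_samples, batch_size)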
+
+def stratified_resampling(log_weights, states, n, b):
+ """Resample states with straitified resampling.
+
+ Args:
+ log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
+ Categorical distribution.
+    states: A list of (b*n x d) Tensors that will be resampled from the groups
+      formed by every n-th row.
+    n: The number of samples (particles) per batch element, an integer.
+    b: The batch size, an integer.
+
+ Returns:
+ resampled_states: A list of (b*n x d) Tensors resampled via stratified sampling.
+ log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
+ resampling_parameters: The Tensor of parameters of the resampling distribution.
+ ancestors: An (n x b) Tensor of integral indices representing the ancestry decisions.
+ resampling_dist: The distribution object for resampling.
+ """
+ log_weights = tf.convert_to_tensor(log_weights)
+ states = [tf.convert_to_tensor(state) for state in states]
+
+ log_weights = tf.transpose(log_weights, perm=[1,0])
+
+ probs = tf.nn.softmax(
+ tf.tile(tf.expand_dims(log_weights, axis=1),
+ [1, n, 1])
+ )
+
+ cdfs = tf.concat([tf.zeros((b,n,1), dtype=probs.dtype), tf.cumsum(probs, axis=2)], 2)
+
+ bins = tf.range(n, dtype=probs.dtype) / n
+ bins = tf.tile(tf.reshape(bins, [1,-1,1]), [b,1,n+1])
+
+ strat_cdfs = tf.minimum(tf.maximum((cdfs - bins) * n, 0.0), 1.0)
+ resampling_parameters = strat_cdfs[:,:,1:] - strat_cdfs[:,:,:-1]
+
+ resampling_dist = tf.contrib.distributions.Categorical(
+ probs = resampling_parameters,
+ allow_nan_stats=False)
+
+ ancestors = tf.stop_gradient(
+ resampling_dist.sample())
+ log_probs = resampling_dist.log_prob(ancestors)
+
+ ancestors = tf.transpose(ancestors, perm=[1,0])
+ log_probs = tf.transpose(log_probs, perm=[1,0])
+
+ offset = tf.expand_dims(tf.range(b), 0)
+ ancestor_inds = tf.reshape(ancestors * b + offset, [-1])
+
+ resampled_states = []
+ for state in states:
+ resampled_states.append(tf.gather(state, ancestor_inds))
+
+ return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist
+
+def systematic_resampling(log_weights, states, n, b):
+ """Resample states with systematic resampling.
+
+ Args:
+ log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
+ Categorical distribution.
+    states: A list of (b*n x d) Tensors that will be resampled from the groups
+      formed by every n-th row.
+    n: The number of samples (particles) per batch element, an integer.
+    b: The batch size, an integer.
+
+  Returns:
+    resampled_states: A list of (b*n x d) Tensors resampled via systematic resampling.
+ log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
+ resampling_parameters: The Tensor of parameters of the resampling distribution.
+ ancestors: An (n x b) Tensor of integral indices representing the ancestry decisions.
+ resampling_dist: The distribution object for resampling.
+ """
+
+ log_weights = tf.convert_to_tensor(log_weights)
+ states = [tf.convert_to_tensor(state) for state in states]
+
+ log_weights = tf.transpose(log_weights, perm=[1,0])
+
+ probs = tf.nn.softmax(
+ tf.tile(tf.expand_dims(log_weights, axis=1),
+ [1, n, 1])
+ )
+
+ cdfs = tf.concat([tf.zeros((b,n,1), dtype=probs.dtype), tf.cumsum(probs, axis=2)], 2)
+
+ bins = tf.range(n, dtype=probs.dtype) / n
+ bins = tf.tile(tf.reshape(bins, [1,-1,1]), [b,1,n+1])
+
+ strat_cdfs = tf.minimum(tf.maximum((cdfs - bins) * n, 0.0), 1.0)
+ resampling_parameters = strat_cdfs[:,:,1:] - strat_cdfs[:,:,:-1]
+
+ resampling_dist = tf.contrib.distributions.Categorical(
+ probs=resampling_parameters,
+ allow_nan_stats=True)
+
+ U = tf.random_uniform((b, 1, 1), dtype=probs.dtype)
+
+ ancestors = tf.stop_gradient(tf.reduce_sum(tf.to_float(U > strat_cdfs[:,:,1:]), axis=-1))
+ log_probs = resampling_dist.log_prob(ancestors)
+
+ ancestors = tf.transpose(ancestors, perm=[1,0])
+ log_probs = tf.transpose(log_probs, perm=[1,0])
+
+ offset = tf.expand_dims(tf.range(b, dtype=probs.dtype), 0)
+ ancestor_inds = tf.reshape(ancestors * b + offset, [-1])
+
+ resampled_states = []
+ for state in states:
+ resampled_states.append(tf.gather(state, ancestor_inds))
+
+ return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist
+
+
+def log_blend(inputs, weights):
+ """Blends state in the log space.
+
+ Args:
+ inputs: A set of scalar states, one for each particle in each particle filter.
+ Should be [num_samples, batch_size].
+ weights: A set of weights used to blend the state. Each set of weights
+ should be of dimension [num_samples] (one weight for each previous particle).
+ There should be one set of weights for each new particle in each particle filter.
+ Thus the shape should be [num_samples, batch_size, num_samples] where
+ the first axis indexes new particle and the last axis indexes old particles.
+ Returns:
+ blended: The blended states, a tensor of shape [num_samples, batch_size].
+ """
+ raw_max = tf.reduce_max(inputs, axis=0, keepdims=True)
+ my_max = tf.stop_gradient(
+ tf.where(tf.is_finite(raw_max), raw_max, tf.zeros_like(raw_max))
+ )
+  # Log-sum-exp trick: subtract the max before exponentiating for numerical
+  # stability, blend the shifted weights with einsum, then add the max back.
+ blended = tf.log(tf.einsum("ijk,kj->ij", weights, tf.exp(inputs - raw_max))) + my_max
+ return blended
+
+
+def relaxed_resampling(log_weights, states, num_samples, batch_size,
+ log_r_x=None, blend_type="log", temperature=0.5,
+ straight_through=False):
+ """Resample states with relaxed resampling.
+
+ Args:
+ log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
+ Categorical distribution.
+    states: A list of (b*n x d) Tensors that will be resampled from the groups
+      formed by every n-th row.
+
+  Returns:
+    resampled_states: A list of (b*n x d) Tensors resampled via relaxed resampling.
+    log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
+    log_r_x: The blended log_r_x values, or the log_r_x argument unchanged if no
+      blending was applied.
+    resampling_parameters: The Tensor of parameters of the resampling distribution.
+ ancestors: An (n x b x n) Tensor of relaxed one hot representations of the ancestry decisions.
+ resampling_dist: The distribution object for resampling.
+ """
+ assert blend_type in ["log", "linear"], "Blend type must be 'log' or 'linear'."
+ log_weights = tf.convert_to_tensor(log_weights)
+ states = [tf.convert_to_tensor(state) for state in states]
+ state_dim = states[0].get_shape().as_list()[-1]
+ # weights are num_samples by batch_size, so we transpose to get a
+ # set of batch_size distributions over [0,num_samples).
+ resampling_parameters = tf.transpose(log_weights, perm=[1, 0])
+ resampling_dist = tf.contrib.distributions.RelaxedOneHotCategorical(
+ temperature,
+ logits=resampling_parameters)
+
+ # sample num_samples samples from the distribution, resulting in a
+ # [num_samples, batch_size, num_samples] Tensor that represents a set of
+ # [num_samples, batch_size] blending weights. The dimensions represent
+ # [sample index, batch index, blending weight index]
+ ancestors = resampling_dist.sample(sample_shape=num_samples)
+ if straight_through:
+ # Forward pass discrete choices, backwards pass soft choices
+ hard_ancestor_indices = tf.argmax(ancestors, axis=-1)
+ hard_ancestors = tf.one_hot(hard_ancestor_indices, num_samples,
+ dtype=ancestors.dtype)
+ ancestors = tf.stop_gradient(hard_ancestors - ancestors) + ancestors
+ log_probs = resampling_dist.log_prob(ancestors)
+ if log_r_x is not None and blend_type == "log":
+ log_r_x = tf.reshape(log_r_x, [num_samples, batch_size])
+ log_r_x = log_blend(log_r_x, ancestors)
+ log_r_x = tf.reshape(log_r_x, [num_samples*batch_size])
+ elif log_r_x is not None and blend_type == "linear":
+ # If blend type is linear just add log_r to the states that will be blended
+ # linearly.
+ states.append(log_r_x)
+
+ # transpose the 'indices' to be [batch_index, blending weight index, sample index]
+ ancestor_inds = tf.transpose(ancestors, perm=[1, 2, 0])
+ resampled_states = []
+ for state in states:
+ # state is currently [num_samples * batch_size, state_dim] so we reshape
+ # to [num_samples, batch_size, state_dim] and then transpose to
+ # [batch_size, state_size, num_samples]
+ state = tf.transpose(tf.reshape(state, [num_samples, batch_size, -1]), perm=[1, 2, 0])
+ # state is now (batch_size, state_size, num_samples)
+ # and ancestor is (batch index, blending weight index, sample index)
+ # multiplying these gives a matrix of size [batch_size, state_size, num_samples]
+ next_state = tf.matmul(state, ancestor_inds)
+ # transpose the state to be [num_samples, batch_size, state_size]
+ # and then reshape it to match the state format.
+ next_state = tf.reshape(tf.transpose(next_state, perm=[2,0,1]), [num_samples*batch_size, state_dim])
+ resampled_states.append(next_state)
+
+ new_dist = tf.contrib.distributions.Categorical(
+ logits=resampling_parameters)
+
+ if log_r_x is not None and blend_type == "linear":
+ # If blend type is linear pop off log_r that we added to the states.
+ log_r_x = tf.squeeze(resampled_states[-1])
+ resampled_states = resampled_states[:-1]
+ return resampled_states, log_probs, log_r_x, resampling_parameters, ancestors, new_dist
+
+
+def fivo(model,
+ observation,
+ num_timesteps,
+ resampling_schedule,
+ num_samples=1,
+ use_resampling_grads=True,
+ resampling_type="multinomial",
+ resampling_temperature=0.5,
+ aux=True,
+ summarize=False):
+ """Compute the FIVO evidence lower bound.
+
+ Args:
+ model: A callable that computes one timestep of the model.
+ observation: A shape [batch_size*num_samples, state_size] Tensor
+ containing z_n, the observation for each sequence in the batch.
+ num_timesteps: The number of timesteps in each sequence, an integer.
+ resampling_schedule: A list of booleans of length num_timesteps, contains
+ True if a resampling should occur on a specific timestep.
+    num_samples: The number of samples to use to compute the FIVO bound.
+    use_resampling_grads: Whether or not to include the resampling gradients
+      in the loss.
+    resampling_type: The type of resampling, one of "multinomial", "stratified",
+      "systematic", "relaxed-logblend", "relaxed-linearblend",
+      "relaxed-stateblend", or "relaxed-stateblend-st".
+    resampling_temperature: A positive temperature only used for relaxed
+      resampling.
+    aux: If true, compute the FIVO-AUX bound.
+    summarize: If True, add TensorBoard summaries of the resampling learning
+      signal.
+  Returns:
+    log_p_hat: The FIVO estimator of the lower bound on the log marginal.
+    losses: A list of Loss tuples whose loss terms can be minimized to
+      optimize the bound.
+    maintain_ema_op: An op to update the baseline ema used for the resampling
+      gradients.
+    states: The sequence of states sampled.
+    log_weights: The accumulated log weights at each timestep.
+  """
+ # Initialization
+ num_instances = tf.cast(tf.shape(observation)[0], tf.int32)
+ batch_size = tf.cast(num_instances / num_samples, tf.int32)
+ states = [model.zero_state(num_instances)]
+ prev_state = states[0]
+ log_weight_acc = tf.zeros(shape=[num_samples, batch_size], dtype=observation.dtype)
+ prev_log_r_zt = tf.zeros([num_instances], dtype=observation.dtype)
+ log_weights = []
+ log_weights_all = []
+ log_p_hats = []
+ resampling_log_probs = []
+ for t in xrange(num_timesteps):
+ # run the model for one timestep
+ (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_zt) = model(
+ prev_state, observation, t)
+ # update accumulators
+ states.append(zt)
+ log_weight = log_p_zt + log_p_x_given_z - log_q_zt
+ if aux:
+ if t == num_timesteps - 1:
+ log_weight -= prev_log_r_zt
+ else:
+ log_weight += log_r_zt - prev_log_r_zt
+ prev_log_r_zt = log_r_zt
+ log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
+ log_weights_all.append(log_weight_acc)
+ if resampling_schedule[t]:
+
+ # These objects will be resampled
+ to_resample = [states[-1]]
+ if aux and "relaxed" not in resampling_type:
+ to_resample.append(prev_log_r_zt)
+
+ # do the resampling
+ if resampling_type == "multinomial":
+ (resampled,
+ resampling_log_prob,
+ _, _, _) = multinomial_resampling(log_weight_acc,
+ to_resample,
+ num_samples,
+ batch_size)
+ elif resampling_type == "stratified":
+ (resampled,
+ resampling_log_prob,
+ _, _, _) = stratified_resampling(log_weight_acc,
+ to_resample,
+ num_samples,
+ batch_size)
+ elif resampling_type == "systematic":
+ (resampled,
+ resampling_log_prob,
+ _, _, _) = systematic_resampling(log_weight_acc,
+ to_resample,
+ num_samples,
+ batch_size)
+ elif "relaxed" in resampling_type:
+ if aux:
+ if resampling_type == "relaxed-logblend":
+ (resampled,
+ resampling_log_prob,
+ prev_log_r_zt,
+ _, _, _) = relaxed_resampling(log_weight_acc,
+ to_resample,
+ num_samples,
+ batch_size,
+ temperature=resampling_temperature,
+ log_r_x=prev_log_r_zt,
+ blend_type="log")
+ elif resampling_type == "relaxed-linearblend":
+ (resampled,
+ resampling_log_prob,
+ prev_log_r_zt,
+ _, _, _) = relaxed_resampling(log_weight_acc,
+ to_resample,
+ num_samples,
+ batch_size,
+ temperature=resampling_temperature,
+ log_r_x=prev_log_r_zt,
+ blend_type="linear")
+ elif resampling_type == "relaxed-stateblend":
+ (resampled,
+ resampling_log_prob,
+ _, _, _, _) = relaxed_resampling(log_weight_acc,
+ to_resample,
+ num_samples,
+ batch_size,
+ temperature=resampling_temperature)
+ # Calculate prev_log_r_zt from the post-resampling state
+ prev_r_zt = model.r.r_xn(resampled[0], t)
+ prev_log_r_zt = tf.reduce_sum(
+ prev_r_zt.log_prob(observation), axis=[1])
+ elif resampling_type == "relaxed-stateblend-st":
+ (resampled,
+ resampling_log_prob,
+ _, _, _, _) = relaxed_resampling(log_weight_acc,
+ to_resample,
+ num_samples,
+ batch_size,
+ temperature=resampling_temperature,
+ straight_through=True)
+ # Calculate prev_log_r_zt from the post-resampling state
+ prev_r_zt = model.r.r_xn(resampled[0], t)
+ prev_log_r_zt = tf.reduce_sum(
+ prev_r_zt.log_prob(observation), axis=[1])
+ else:
+ (resampled,
+ resampling_log_prob,
+ _, _, _, _) = relaxed_resampling(log_weight_acc,
+ to_resample,
+ num_samples,
+ batch_size,
+ temperature=resampling_temperature)
+ #if summarize:
+ # resampling_entropy = resampling_dist.entropy()
+ # resampling_entropy = tf.reduce_mean(resampling_entropy)
+ # tf.summary.scalar("weight_entropy/%d" % t, resampling_entropy)
+
+ resampling_log_probs.append(tf.reduce_sum(resampling_log_prob, axis=0))
+ prev_state = resampled[0]
+ if aux and "relaxed" not in resampling_type:
+ # Squeeze out the extra dim potentially added by resampling.
+ # prev_log_r_zt should always be [num_instances]
+ prev_log_r_zt = tf.squeeze(resampled[1])
+ # Update the log p hat estimate, taking a log sum exp over the sample
+ # dimension. The appended tensor is [batch_size].
+ log_p_hats.append(
+ tf.reduce_logsumexp(log_weight_acc, axis=0) - tf.log(
+ tf.cast(num_samples, dtype=observation.dtype)))
+ # reset the weights
+ log_weights.append(log_weight_acc)
+ log_weight_acc = tf.zeros_like(log_weight_acc)
+ else:
+ prev_state = states[-1]
+ # Compute the final weight update. If we just resampled this will be zero.
+ final_update = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
+ tf.log(tf.cast(num_samples, dtype=observation.dtype)))
+ # If we ever resampled, then sum up the previous log p hat terms
+ if len(log_p_hats) > 0:
+ log_p_hat = tf.reduce_sum(log_p_hats, axis=0) + final_update
+ else: # otherwise, log_p_hat only comes from the final update
+ log_p_hat = final_update
+
+ if use_resampling_grads and any(resampling_schedule):
+ # compute the rewards
+ # cumsum([a, b, c]) => [a, a+b, a+b+c]
+ # learning signal at timestep t is
+ # [sum from i=t+1 to T of log_p_hat_i for t=1:T]
+ # so we will compute (sum from i=1 to T of log_p_hat_i)
+ # and at timestep t will subtract off (sum from i=1 to t of log_p_hat_i)
+ # rewards is a [num_resampling_events, batch_size] Tensor
+ rewards = tf.stop_gradient(
+ tf.expand_dims(log_p_hat, 0) - tf.cumsum(log_p_hats, axis=0))
+ batch_avg_rewards = tf.reduce_mean(rewards, axis=1)
+ # compute ema baseline.
+ # centered_rewards is [num_resampling_events, batch_size]
+ baseline_ema = tf.train.ExponentialMovingAverage(decay=0.94)
+ maintain_baseline_op = baseline_ema.apply([batch_avg_rewards])
+ baseline = tf.expand_dims(baseline_ema.average(batch_avg_rewards), 1)
+ centered_rewards = rewards - baseline
+ if summarize:
+ summ.summarize_learning_signal(rewards, "rewards")
+ summ.summarize_learning_signal(centered_rewards, "centered_rewards")
+ # compute the loss tensor.
+ resampling_grads = tf.reduce_sum(
+ tf.stop_gradient(centered_rewards) * resampling_log_probs, axis=0)
+ losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat)/num_timesteps),
+ Loss("resampling_grads", -tf.reduce_mean(resampling_grads)/num_timesteps)]
+ else:
+ losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat)/num_timesteps)]
+ maintain_baseline_op = tf.no_op()
+
+ log_p_hat /= num_timesteps
+ # we clip off the initial state before returning.
+ return log_p_hat, losses, maintain_baseline_op, states[1:], log_weights_all
+
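+# Illustrative usage (sketch only): `model` follows the callable contract
+# described in the docstring above, and resampling happens at every timestep:
+#   log_p_hat, losses, ema_op, states, log_weights = fivo(
+#       model, observation, num_timesteps,
+#       resampling_schedule=[True] * num_timesteps, num_samples=4)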
+
+def fivo_aux_td(
+ model,
+ observation,
+ num_timesteps,
+ resampling_schedule,
+ num_samples=1,
+ summarize=False):
+ """Compute the FIVO_AUX evidence lower bound."""
+ # Initialization
+ num_instances = tf.cast(tf.shape(observation)[0], tf.int32)
+ batch_size = tf.cast(num_instances / num_samples, tf.int32)
+ states = [model.zero_state(num_instances)]
+ prev_state = states[0]
+ log_weight_acc = tf.zeros(shape=[num_samples, batch_size], dtype=observation.dtype)
+ prev_log_r = tf.zeros([num_instances], dtype=observation.dtype)
+ # must be pre-resampling
+ log_rs = []
+ # must be post-resampling
+ r_tilde_params = [model.r_tilde.r_zt(states[0], observation, 0)]
+ log_r_tildes = []
+ log_p_xs = []
+ # contains the weight at each timestep before resampling only on resampling timesteps
+ log_weights = []
+ # contains weight at each timestep before resampling
+ log_weights_all = []
+ log_p_hats = []
+ for t in xrange(num_timesteps):
+ # run the model for one timestep
+ # zt is state, [num_instances, state_dim]
+ # log_q_zt, log_p_x_given_z is [num_instances]
+ # r_tilde_mu, r_tilde_sigma is [num_instances, state_dim]
+ # p_ztplus1 is a normal distribution on [num_instances, state_dim]
+ (zt, log_q_zt, log_p_zt, log_p_x_given_z,
+ r_tilde_mu, r_tilde_sigma_sq, p_ztplus1) = model(prev_state, observation, t)
+
+ # Compute the log weight without log r.
+ log_weight = log_p_zt + log_p_x_given_z - log_q_zt
+
+ # Compute log r.
+ if t == num_timesteps - 1:
+ log_r = tf.zeros_like(prev_log_r)
+ else:
+ p_mu = p_ztplus1.mean()
+ p_sigma_sq = p_ztplus1.variance()
+ log_r = (tf.log(r_tilde_sigma_sq) -
+ tf.log(r_tilde_sigma_sq + p_sigma_sq) -
+ tf.square(r_tilde_mu - p_mu)/(r_tilde_sigma_sq + p_sigma_sq))
+ log_r = 0.5*tf.reduce_sum(log_r, axis=-1)
+
+ #log_weight += tf.stop_gradient(log_r - prev_log_r)
+ log_weight += log_r - prev_log_r
+ log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
+
+ # Update accumulators
+ states.append(zt)
+ log_weights_all.append(log_weight_acc)
+ log_p_xs.append(log_p_x_given_z)
+ log_rs.append(log_r)
+
+ # Compute log_r_tilde as [num_instances] Tensor.
+ prev_r_tilde_mu, prev_r_tilde_sigma_sq = r_tilde_params[-1]
+ prev_log_r_tilde = -0.5*tf.reduce_sum(
+ tf.square(zt - prev_r_tilde_mu)/prev_r_tilde_sigma_sq, axis=-1)
+ #tf.square(tf.stop_gradient(zt) - r_tilde_mu)/r_tilde_sigma_sq, axis=-1)
+ #tf.square(zt - r_tilde_mu)/r_tilde_sigma_sq, axis=-1)
+ log_r_tildes.append(prev_log_r_tilde)
+
+ # optionally resample
+ if resampling_schedule[t]:
+ # These objects will be resampled
+ if t < num_timesteps - 1:
+ to_resample = [zt, log_r, r_tilde_mu, r_tilde_sigma_sq]
+ else:
+ to_resample = [zt, log_r]
+ (resampled,
+ _, _, _, _) = multinomial_resampling(log_weight_acc,
+ to_resample,
+ num_samples,
+ batch_size)
+ prev_state = resampled[0]
+ # Squeeze out the extra dim potentially added by resampling.
+ # prev_log_r_zt and log_r_tilde should always be [num_instances]
+ prev_log_r = tf.squeeze(resampled[1])
+ if t < num_timesteps -1:
+ r_tilde_params.append((resampled[2], resampled[3]))
+ # Update the log p hat estimate, taking a log sum exp over the sample
+ # dimension. The appended tensor is [batch_size].
+ log_p_hats.append(
+ tf.reduce_logsumexp(log_weight_acc, axis=0) - tf.log(
+ tf.cast(num_samples, dtype=observation.dtype)))
+ # reset the weights
+ log_weights.append(log_weight_acc)
+ log_weight_acc = tf.zeros_like(log_weight_acc)
+ else:
+ prev_state = zt
+ prev_log_r = log_r
+ if t < num_timesteps - 1:
+ r_tilde_params.append((r_tilde_mu, r_tilde_sigma_sq))
+
+ # Compute the final weight update. If we just resampled this will be zero.
+ final_update = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
+ tf.log(tf.cast(num_samples, dtype=observation.dtype)))
+ # If we ever resampled, then sum up the previous log p hat terms
+ if len(log_p_hats) > 0:
+ log_p_hat = tf.reduce_sum(log_p_hats, axis=0) + final_update
+ else: # otherwise, log_p_hat only comes from the final update
+ log_p_hat = final_update
+
+ # Compute the bellman loss.
+ # Will remove the first timestep as it is not used.
+ # log p(x_t|z_t) is in row t-1.
+ log_p_x = tf.reshape(tf.stack(log_p_xs),
+ [num_timesteps, num_samples, batch_size])
+ # log r_t is contained in row t-1.
+  # The last row is zeros (because at timestep T (num_timesteps) r is 1).
+ log_r = tf.reshape(tf.stack(log_rs),
+ [num_timesteps, num_samples, batch_size])
+ # [num_timesteps, num_instances]. log r_tilde_t is in row t-1.
+ log_r_tilde = tf.reshape(tf.stack(log_r_tildes),
+ [num_timesteps, num_samples, batch_size])
+ log_lambda = tf.reduce_mean(log_r_tilde - log_p_x - log_r, axis=1,
+ keepdims=True)
+ bellman_sos = tf.reduce_mean(tf.square(
+ log_r_tilde - tf.stop_gradient(log_lambda + log_p_x + log_r)), axis=[0, 1])
+ bellman_loss = tf.reduce_mean(bellman_sos)/num_timesteps
+ tf.summary.scalar("bellman_loss", bellman_loss)
+
+ if len(tf.get_collection("LOG_P_HAT_VARS")) == 0:
+ log_p_hat_collection = list(set(tf.trainable_variables()) -
+ set(tf.get_collection("R_TILDE_VARS")))
+ for v in log_p_hat_collection:
+ tf.add_to_collection("LOG_P_HAT_VARS", v)
+
+ log_p_hat /= num_timesteps
+ losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat), "LOG_P_HAT_VARS"),
+ Loss("bellman_loss", bellman_loss, "R_TILDE_VARS")]
+
+ return log_p_hat, losses, tf.no_op(), states[1:], log_weights_all
diff --git a/models/research/fivo/experimental/data.py b/models/research/fivo/experimental/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..0842f212991e1651a12cca239c5b8380fea9d0f8
--- /dev/null
+++ b/models/research/fivo/experimental/data.py
@@ -0,0 +1,192 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Datasets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+import models
+
+
+def make_long_chain_dataset(
+ state_size=1,
+ num_obs=5,
+ steps_per_obs=3,
+ variance=1.,
+ observation_variance=1.,
+ batch_size=4,
+ num_samples=1,
+ observation_type=models.STANDARD_OBSERVATION,
+ transition_type=models.STANDARD_TRANSITION,
+ fixed_observation=None,
+ dtype="float32"):
+ """Creates a long chain data generating process.
+
+ Creates a tf.data.Dataset that provides batches of data from a long
+ chain.
+
+ Args:
+ state_size: The dimension of the state space of the process.
+ num_obs: The number of observations in the chain.
+ steps_per_obs: The number of steps between each observation.
+    variance: The variance of the normal transition distributions used at each
+      timestep.
+    observation_variance: The variance of the normal observation distributions.
+    batch_size: The number of trajectories to include in each batch.
+    num_samples: The number of replicas of each trajectory to include in each
+      batch.
+    observation_type: How observations are produced from the state, one of
+      models.STANDARD_OBSERVATION, models.SQUARED_OBSERVATION, or
+      models.ABS_OBSERVATION.
+    transition_type: The transition function, one of models.STANDARD_TRANSITION
+      or models.ROUND_TRANSITION.
+    fixed_observation: If not None, a constant emitted as every observation
+      instead of sampling from the observation distribution.
+    dtype: The datatype of the states and observations.
+ Returns:
+ dataset: A tf.data.Dataset that can be iterated over.
+ """
+ num_timesteps = num_obs * steps_per_obs
+ def data_generator():
+ """An infinite generator of latents and observations from the model."""
+ while True:
+ states = []
+ observations = []
+ # z0 ~ Normal(0, sqrt(variance)).
+ states.append(
+ np.random.normal(size=[state_size],
+ scale=np.sqrt(variance)).astype(dtype))
+ # start at 1 because we've already generated z0
+ # go to num_timesteps+1 because we want to include the num_timesteps-th step
+ for t in xrange(1, num_timesteps+1):
+ if transition_type == models.ROUND_TRANSITION:
+ loc = np.round(states[-1])
+ elif transition_type == models.STANDARD_TRANSITION:
+ loc = states[-1]
+ new_state = np.random.normal(size=[state_size],
+ loc=loc,
+ scale=np.sqrt(variance))
+ states.append(new_state.astype(dtype))
+ if t % steps_per_obs == 0:
+ if fixed_observation is None:
+ if observation_type == models.SQUARED_OBSERVATION:
+ loc = np.square(states[-1])
+ elif observation_type == models.ABS_OBSERVATION:
+ loc = np.abs(states[-1])
+ elif observation_type == models.STANDARD_OBSERVATION:
+ loc = states[-1]
+ new_obs = np.random.normal(size=[state_size],
+ loc=loc,
+ scale=np.sqrt(observation_variance)).astype(dtype)
+ else:
+ new_obs = np.ones([state_size])* fixed_observation
+
+ observations.append(new_obs)
+ yield states, observations
+
+ dataset = tf.data.Dataset.from_generator(
+ data_generator,
+ output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
+ output_shapes=([num_timesteps+1, state_size], [num_obs, state_size]))
+ dataset = dataset.repeat().batch(batch_size)
+
+ def tile_batch(state, observation):
+ state = tf.tile(state, [num_samples, 1, 1])
+ observation = tf.tile(observation, [num_samples, 1, 1])
+ return state, observation
+
+ dataset = dataset.map(tile_batch, num_parallel_calls=12).prefetch(1024)
+ return dataset
+
+
+def make_dataset(bs=None,
+ state_size=1,
+ num_timesteps=10,
+ variance=1.,
+ prior_type="unimodal",
+ bimodal_prior_weight=0.5,
+ bimodal_prior_mean=1,
+ transition_type=models.STANDARD_TRANSITION,
+ fixed_observation=None,
+ batch_size=4,
+ num_samples=1,
+ dtype='float32'):
+ """Creates a data generating process.
+
+ Creates a tf.data.Dataset that provides batches of data.
+
+ Args:
+ bs: The parameters of the data generating process. If None, new bs are
+ randomly generated.
+ state_size: The dimension of the state space of the process.
+ num_timesteps: The length of the state sequences in the process.
+    variance: The variance of the normal distributions used at each timestep.
+    prior_type: The prior on the initial state, one of "unimodal", "bimodal",
+      or "nonlinear".
+    bimodal_prior_weight: The probability of drawing the initial state from the
+      negative-mean mode when prior_type is "bimodal".
+    bimodal_prior_mean: The (positive and negated) mean of the bimodal prior
+      components.
+    transition_type: The transition function, one of models.STANDARD_TRANSITION
+      or models.ROUND_TRANSITION.
+    fixed_observation: If not None, a constant emitted as the observation
+      instead of the final state.
+    batch_size: The number of trajectories to include in each batch.
+    num_samples: The number of replicas of each trajectory to include in each
+      batch.
+    dtype: The datatype of the states and observations.
+ Returns:
+ bs: The true bs used to generate the data
+ dataset: A tf.data.Dataset that can be iterated over.
+ """
+
+ if bs is None:
+ bs = [np.random.uniform(size=[state_size]).astype(dtype) for _ in xrange(num_timesteps)]
+ tf.logging.info("data generating processs bs: %s",
+ np.array(bs).reshape(num_timesteps))
+
+
+ def data_generator():
+ """An infinite generator of latents and observations from the model."""
+ while True:
+ states = []
+ if prior_type == "unimodal" or prior_type == "nonlinear":
+ # Prior is Normal(0, sqrt(variance)).
+ states.append(np.random.normal(size=[state_size], scale=np.sqrt(variance)).astype(dtype))
+ elif prior_type == "bimodal":
+ if np.random.uniform() > bimodal_prior_weight:
+ loc = bimodal_prior_mean
+ else:
+ loc = - bimodal_prior_mean
+ states.append(np.random.normal(size=[state_size],
+ loc=loc,
+ scale=np.sqrt(variance)
+ ).astype(dtype))
+
+ for t in xrange(num_timesteps):
+ if transition_type == models.ROUND_TRANSITION:
+ loc = np.round(states[-1])
+ elif transition_type == models.STANDARD_TRANSITION:
+ loc = states[-1]
+ loc += bs[t]
+ new_state = np.random.normal(size=[state_size],
+ loc=loc,
+ scale=np.sqrt(variance)).astype(dtype)
+ states.append(new_state)
+
+ if fixed_observation is None:
+ observation = states[-1]
+ else:
+ observation = np.ones_like(states[-1]) * fixed_observation
+ yield np.array(states[:-1]), observation
+
+ dataset = tf.data.Dataset.from_generator(
+ data_generator,
+ output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
+ output_shapes=([num_timesteps, state_size], [state_size]))
+ dataset = dataset.repeat().batch(batch_size)
+
+ def tile_batch(state, observation):
+ state = tf.tile(state, [num_samples, 1, 1])
+ observation = tf.tile(observation, [num_samples, 1])
+ return state, observation
+
+ dataset = dataset.map(tile_batch, num_parallel_calls=12).prefetch(1024)
+ return np.array(bs), dataset
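+
+
+# Illustrative usage (sketch only): draw one batch of latents and observations
+# from the data generating process defined above.
+#   bs, dataset = make_dataset(state_size=1, num_timesteps=10,
+#                              batch_size=4, num_samples=1)
+#   states, observations = dataset.make_one_shot_iterator().get_next()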
diff --git a/models/research/fivo/experimental/models.py b/models/research/fivo/experimental/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..62801ca1ee145e64c80b66e0c83dd7d834ac0847
--- /dev/null
+++ b/models/research/fivo/experimental/models.py
@@ -0,0 +1,1227 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import sonnet as snt
+import tensorflow as tf
+import numpy as np
+import math
+
+SQUARED_OBSERVATION = "squared"
+ABS_OBSERVATION = "abs"
+STANDARD_OBSERVATION = "standard"
+OBSERVATION_TYPES = [SQUARED_OBSERVATION, ABS_OBSERVATION, STANDARD_OBSERVATION]
+
+ROUND_TRANSITION = "round"
+STANDARD_TRANSITION = "standard"
+TRANSITION_TYPES = [ROUND_TRANSITION, STANDARD_TRANSITION]
+
+
+class Q(object):
+
+ def __init__(self,
+ state_size,
+ num_timesteps,
+ sigma_min=1e-5,
+ dtype=tf.float32,
+ random_seed=None,
+ init_mu0_to_zero=False,
+ graph_collection_name="Q_VARS"):
+ self.sigma_min = sigma_min
+ self.dtype = dtype
+ self.graph_collection_name = graph_collection_name
+ initializers = []
+ for t in xrange(num_timesteps):
+ if t == 0 and init_mu0_to_zero:
+ initializers.append(
+ {"w": tf.zeros_initializer, "b": tf.zeros_initializer})
+ else:
+ initializers.append(
+ {"w": tf.random_uniform_initializer(seed=random_seed),
+ "b": tf.zeros_initializer})
+
+ def custom_getter(getter, *args, **kwargs):
+ out = getter(*args, **kwargs)
+ ref = tf.get_collection_ref(self.graph_collection_name)
+ if out not in ref:
+ ref.append(out)
+ return out
+
+ self.mus = [
+ snt.Linear(output_size=state_size,
+ initializers=initializers[t],
+ name="q_mu_%d" % t,
+ custom_getter=custom_getter
+ )
+ for t in xrange(num_timesteps)
+ ]
+ self.sigmas = [
+ tf.get_variable(
+ shape=[state_size],
+ dtype=self.dtype,
+ name="q_sigma_%d" % (t + 1),
+ collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
+ initializer=tf.random_uniform_initializer(seed=random_seed))
+ for t in xrange(num_timesteps)
+ ]
+
+ def q_zt(self, observation, prev_state, t):
+ batch_size = tf.shape(prev_state)[0]
+ q_mu = self.mus[t](tf.concat([observation, prev_state], axis=1))
+ q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
+ q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
+ q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
+ return q_zt
+
+ def summarize_weights(self):
+ for t, sigma in enumerate(self.sigmas):
+ tf.summary.scalar("q_sigma/%d" % t, sigma[0])
+ for t, f in enumerate(self.mus):
+ tf.summary.scalar("q_mu/b_%d" % t, f.b[0])
+ tf.summary.scalar("q_mu/w_obs_%d" % t, f.w[0,0])
+ if t != 0:
+ tf.summary.scalar("q_mu/w_prev_state_%d" % t, f.w[1,0])
+
+
+class PreviousStateQ(Q):
+
+ def q_zt(self, unused_observation, prev_state, t):
+ batch_size = tf.shape(prev_state)[0]
+ q_mu = self.mus[t](prev_state)
+ q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
+ q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
+ q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
+ return q_zt
+
+ def summarize_weights(self):
+ for t, sigma in enumerate(self.sigmas):
+ tf.summary.scalar("q_sigma/%d" % t, sigma[0])
+ for t, f in enumerate(self.mus):
+ tf.summary.scalar("q_mu/b_%d" % t, f.b[0])
+ tf.summary.scalar("q_mu/w_prev_state_%d" % t, f.w[0,0])
+
+
+class ObservationQ(Q):
+
+ def q_zt(self, observation, prev_state, t):
+ batch_size = tf.shape(prev_state)[0]
+ q_mu = self.mus[t](observation)
+ q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
+ q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
+ q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
+ return q_zt
+
+ def summarize_weights(self):
+ for t, sigma in enumerate(self.sigmas):
+ tf.summary.scalar("q_sigma/%d" % t, sigma[0])
+ for t, f in enumerate(self.mus):
+ tf.summary.scalar("q_mu/b_%d" % t, f.b[0])
+ tf.summary.scalar("q_mu/w_obs_%d" % t, f.w[0,0])
+
+
+class SimpleMeanQ(object):
+
+ def __init__(self,
+ state_size,
+ num_timesteps,
+ sigma_min=1e-5,
+ dtype=tf.float32,
+ random_seed=None,
+ init_mu0_to_zero=False,
+ graph_collection_name="Q_VARS"):
+ self.sigma_min = sigma_min
+ self.dtype = dtype
+ self.graph_collection_name = graph_collection_name
+ initializers = []
+ for t in xrange(num_timesteps):
+ if t == 0 and init_mu0_to_zero:
+ initializers.append(tf.zeros_initializer)
+ else:
+ initializers.append(tf.random_uniform_initializer(seed=random_seed))
+
+ self.mus = [
+ tf.get_variable(
+ shape=[state_size],
+ dtype=self.dtype,
+ name="q_mu_%d" % (t + 1),
+ collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
+ initializer=initializers[t])
+ for t in xrange(num_timesteps)
+ ]
+ self.sigmas = [
+ tf.get_variable(
+ shape=[state_size],
+ dtype=self.dtype,
+ name="q_sigma_%d" % (t + 1),
+ collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
+ initializer=tf.random_uniform_initializer(seed=random_seed))
+ for t in xrange(num_timesteps)
+ ]
+
+ def q_zt(self, unused_observation, prev_state, t):
+ batch_size = tf.shape(prev_state)[0]
+ q_mu = tf.tile(self.mus[t][tf.newaxis, :], [batch_size, 1])
+ q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
+ q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
+ q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
+ return q_zt
+
+ def summarize_weights(self):
+ for t, sigma in enumerate(self.sigmas):
+ tf.summary.scalar("q_sigma/%d" % t, sigma[0])
+ for t, f in enumerate(self.mus):
+ tf.summary.scalar("q_mu/%d" % t, f[0])
+
+
+class R(object):
+
+ def __init__(self,
+ state_size,
+ num_timesteps,
+ sigma_min=1e-5,
+ dtype=tf.float32,
+ sigma_init=1.,
+ random_seed=None,
+ graph_collection_name="R_VARS"):
+ self.dtype = dtype
+ self.sigma_min = sigma_min
+ initializers = {"w": tf.truncated_normal_initializer(seed=random_seed),
+ "b": tf.zeros_initializer}
+ self.graph_collection_name=graph_collection_name
+
+ def custom_getter(getter, *args, **kwargs):
+ out = getter(*args, **kwargs)
+ ref = tf.get_collection_ref(self.graph_collection_name)
+ if out not in ref:
+ ref.append(out)
+ return out
+
+ self.mus= [
+ snt.Linear(output_size=state_size,
+ initializers=initializers,
+ name="r_mu_%d" % t,
+ custom_getter=custom_getter)
+ for t in xrange(num_timesteps)
+ ]
+
+ self.sigmas = [
+ tf.get_variable(
+ shape=[state_size],
+ dtype=self.dtype,
+ name="r_sigma_%d" % (t + 1),
+ collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
+ #initializer=tf.random_uniform_initializer(seed=random_seed, maxval=100))
+ initializer=tf.constant_initializer(sigma_init))
+ for t in xrange(num_timesteps)
+ ]
+
+ def r_xn(self, z_t, t):
+ batch_size = tf.shape(z_t)[0]
+ r_mu = self.mus[t](z_t)
+ r_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
+ r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1])
+ return tf.contrib.distributions.Normal(
+ loc=r_mu, scale=tf.sqrt(r_sigma))
+
+ def summarize_weights(self):
+ for t in range(len(self.mus) - 1):
+      tf.summary.scalar("r_mu/%d" % t, self.mus[t].b[0])
+ tf.summary.scalar("r_sigma/%d" % t, self.sigmas[t][0])
+
+
+class P(object):
+
+ def __init__(self,
+ state_size,
+ num_timesteps,
+ sigma_min=1e-5,
+ variance=1.0,
+ dtype=tf.float32,
+ random_seed=None,
+ trainable=True,
+ init_bs_to_zero=False,
+ graph_collection_name="P_VARS"):
+ self.state_size = state_size
+ self.num_timesteps = num_timesteps
+ self.sigma_min = sigma_min
+ self.dtype = dtype
+ self.variance = variance
+ self.graph_collection_name = graph_collection_name
+ if init_bs_to_zero:
+ initializers = [tf.zeros_initializer for _ in xrange(num_timesteps)]
+ else:
+ initializers = [tf.random_uniform_initializer(seed=random_seed) for _ in xrange(num_timesteps)]
+
+ self.bs = [
+ tf.get_variable(
+ shape=[state_size],
+ dtype=self.dtype,
+ name="p_b_%d" % (t + 1),
+ initializer=initializers[t],
+ collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
+ trainable=trainable) for t in xrange(num_timesteps)
+ ]
+ self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)
+
+ def posterior(self, observation, prev_state, t):
+ """Computes the true posterior p(z_t|z_{t-1}, z_n)."""
+ # bs[0] is really b_1
+ # Bs[i] is sum from k=i+1^n b_k
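+    # Derivation sketch (n = num_timesteps): z_t | z_{t-1} ~ N(z_{t-1} + bs[t-1], variance)
+    # and z_n | z_t ~ N(z_t + Bs[t], (n - t) * variance), so Gaussian conjugacy gives
+    #   mean = ((n - t) * (z_{t-1} + bs[t-1]) + (z_n - Bs[t])) / (n - t + 1)
+    #   var  = variance * (n - t) / (n - t + 1)
+    # matching the computation below (the prev_state term drops out at t = 0
+    # because the prior mean of z_0 is zero).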
+ mu = observation - self.Bs[t]
+ if t > 0:
+ mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t)
+ mu /= float(self.num_timesteps - t + 1)
+ sigma = tf.ones_like(mu) * self.variance * (
+ float(self.num_timesteps - t) / float(self.num_timesteps - t + 1))
+ return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
+
+ def lookahead(self, state, t):
+ """Computes the true lookahead distribution p(z_n|z_t)."""
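+    # z_n is z_t plus the remaining bs (Bs[t]) and (num_timesteps - t)
+    # independent N(0, variance) increments.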
+ mu = state + self.Bs[t]
+ sigma = tf.ones_like(state) * self.variance * float(self.num_timesteps - t)
+ return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
+
+ def likelihood(self, observation):
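+    # Marginal p(z_n): z_n is the sum of all the bs plus (num_timesteps + 1)
+    # independent N(0, variance) increments (the z_0 prior, the intermediate
+    # transitions, and the final generative step).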
+ batch_size = tf.shape(observation)[0]
+ mu = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
+ sigma = tf.ones_like(mu) * self.variance * (self.num_timesteps + 1)
+ dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
+ # Average over the batch and take the sum over the state size
+ return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1))
+
+ def p_zt(self, prev_state, t):
+ """Computes the model p(z_t| z_{t-1})."""
+ batch_size = tf.shape(prev_state)[0]
+ if t > 0:
+ z_mu_p = prev_state + self.bs[t - 1]
+    else: # p(z_0) is Normal(0, variance)
+ z_mu_p = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
+ p_zt = tf.contrib.distributions.Normal(
+ loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance))
+ return p_zt
+
+ def generative(self, unused_observation, z_nm1):
+ """Computes the model's generative distribution p(z_n| z_{n-1})."""
+ generative_p_mu = z_nm1 + self.bs[-1]
+ return tf.contrib.distributions.Normal(
+ loc=generative_p_mu, scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance))
+
+
+class ShortChainNonlinearP(object):
+
+ def __init__(self,
+ state_size,
+ num_timesteps,
+ sigma_min=1e-5,
+ variance=1.0,
+ observation_variance=1.0,
+ transition_type=STANDARD_TRANSITION,
+ transition_dist=tf.contrib.distributions.Normal,
+ dtype=tf.float32,
+ random_seed=None):
+ self.state_size = state_size
+ self.num_timesteps = num_timesteps
+ self.sigma_min = sigma_min
+ self.dtype = dtype
+ self.variance = variance
+ self.observation_variance = observation_variance
+ self.transition_type = transition_type
+ self.transition_dist = transition_dist
+
+ def p_zt(self, prev_state, t):
+ """Computes the model p(z_t| z_{t-1})."""
+ batch_size = tf.shape(prev_state)[0]
+ if t > 0:
+ if self.transition_type == ROUND_TRANSITION:
+ loc = tf.round(prev_state)
+ tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" % (t, t-1, t-1, self.variance))
+ elif self.transition_type == STANDARD_TRANSITION:
+ loc = prev_state
+ tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" % (t, t-1, t-1, self.variance))
+    else: # p(z_0) is Normal(0, variance)
+ loc = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
+ tf.logging.info("p(z_0) ~ N(0,%0.1f)" % self.variance)
+
+ p_zt = self.transition_dist(
+ loc=loc,
+ scale=tf.sqrt(tf.ones_like(loc) * self.variance))
+ return p_zt
+
+ def generative(self, unused_obs, z_ni):
+ """Computes the model's generative distribution p(x_i| z_{ni})."""
+ if self.transition_type == ROUND_TRANSITION:
+ loc = tf.round(z_ni)
+ elif self.transition_type == STANDARD_TRANSITION:
+ loc = z_ni
+ generative_sigma_sq = tf.ones_like(loc) * self.observation_variance
+ return self.transition_dist(
+ loc=loc, scale=tf.sqrt(generative_sigma_sq))
+
+
+class BimodalPriorP(object):
+
+ def __init__(self,
+ state_size,
+ num_timesteps,
+ mixing_coeff=0.5,
+ prior_mode_mean=1,
+ sigma_min=1e-5,
+ variance=1.0,
+ dtype=tf.float32,
+ random_seed=None,
+ trainable=True,
+ init_bs_to_zero=False,
+ graph_collection_name="P_VARS"):
+ self.state_size = state_size
+ self.num_timesteps = num_timesteps
+ self.sigma_min = sigma_min
+ self.dtype = dtype
+ self.variance = variance
+ self.mixing_coeff = mixing_coeff
+ self.prior_mode_mean = prior_mode_mean
+
+ if init_bs_to_zero:
+ initializers = [tf.zeros_initializer for _ in xrange(num_timesteps)]
+ else:
+ initializers = [tf.random_uniform_initializer(seed=random_seed) for _ in xrange(num_timesteps)]
+
+ self.bs = [
+ tf.get_variable(
+ shape=[state_size],
+ dtype=self.dtype,
+ name="b_%d" % (t + 1),
+ initializer=initializers[t],
+ collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
+ trainable=trainable) for t in xrange(num_timesteps)
+ ]
+ self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)
+
+ def posterior(self, observation, prev_state, t):
+    """Computes the true posterior p(z_t|z_{t-1}, z_n)."""
+    # NOTE: This is currently wrong for the bimodal prior, but fixing it would
+    # require refactoring summarize_qs because the KL is not defined for a
+    # mixture.
+ # bs[0] is really b_1
+ # Bs[i] is sum from k=i+1^n b_k
+ mu = observation - self.Bs[t]
+ if t > 0:
+ mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t)
+ mu /= float(self.num_timesteps - t + 1)
+ sigma = tf.ones_like(mu) * self.variance * (
+ float(self.num_timesteps - t) / float(self.num_timesteps - t + 1))
+ return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
+
+ def lookahead(self, state, t):
+ """Computes the true lookahead distribution p(z_n|z_t)."""
+ mu = state + self.Bs[t]
+ sigma = tf.ones_like(state) * self.variance * float(self.num_timesteps - t)
+ return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
+
+ def likelihood(self, observation):
+ batch_size = tf.shape(observation)[0]
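+    # Marginal p(z_n) under the bimodal prior: a two-component Normal mixture
+    # whose component means are +/- prior_mode_mean shifted by the sum of the
+    # bs, each with variance (num_timesteps + 1) * variance.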
+ sum_of_bs = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
+ sigma = tf.ones_like(sum_of_bs) * self.variance * (self.num_timesteps + 1)
+ mu_pos = (tf.ones([batch_size, self.state_size], dtype=self.dtype) * self.prior_mode_mean) + sum_of_bs
+ mu_neg = (tf.ones([batch_size, self.state_size], dtype=self.dtype) * -self.prior_mode_mean) + sum_of_bs
+ zn_pos = tf.contrib.distributions.Normal(
+ loc=mu_pos,
+ scale=tf.sqrt(sigma))
+ zn_neg = tf.contrib.distributions.Normal(
+ loc=mu_neg,
+ scale=tf.sqrt(sigma))
+ mode_probs = tf.convert_to_tensor([self.mixing_coeff, 1-self.mixing_coeff], dtype=tf.float64)
+ mode_probs = tf.tile(mode_probs[tf.newaxis, tf.newaxis, :], [batch_size, 1, 1])
+ mode_selection_dist = tf.contrib.distributions.Categorical(probs=mode_probs)
+ zn_dist = tf.contrib.distributions.Mixture(
+ cat=mode_selection_dist,
+ components=[zn_pos, zn_neg],
+ validate_args=True)
+ # Average over the batch and take the sum over the state size
+ return tf.reduce_mean(tf.reduce_sum(zn_dist.log_prob(observation), axis=1))
+
+ def p_zt(self, prev_state, t):
+ """Computes the model p(z_t| z_{t-1})."""
+ batch_size = tf.shape(prev_state)[0]
+ if t > 0:
+ z_mu_p = prev_state + self.bs[t - 1]
+ p_zt = tf.contrib.distributions.Normal(
+ loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance))
+ return p_zt
+ else: # p(z_0) is mixture of two Normals
+ mu_pos = tf.ones([batch_size, self.state_size], dtype=self.dtype) * self.prior_mode_mean
+ mu_neg = tf.ones([batch_size, self.state_size], dtype=self.dtype) * -self.prior_mode_mean
+ z0_pos = tf.contrib.distributions.Normal(
+ loc=mu_pos,
+ scale=tf.sqrt(tf.ones_like(mu_pos) * self.variance))
+ z0_neg = tf.contrib.distributions.Normal(
+ loc=mu_neg,
+ scale=tf.sqrt(tf.ones_like(mu_neg) * self.variance))
+ mode_probs = tf.convert_to_tensor([self.mixing_coeff, 1-self.mixing_coeff], dtype=tf.float64)
+ mode_probs = tf.tile(mode_probs[tf.newaxis, tf.newaxis, :], [batch_size, 1, 1])
+ mode_selection_dist = tf.contrib.distributions.Categorical(probs=mode_probs)
+ z0_dist = tf.contrib.distributions.Mixture(
+ cat=mode_selection_dist,
+ components=[z0_pos, z0_neg],
+ validate_args=False)
+ return z0_dist
+
+ def generative(self, unused_observation, z_nm1):
+ """Computes the model's generative distribution p(z_n| z_{n-1})."""
+ generative_p_mu = z_nm1 + self.bs[-1]
+ return tf.contrib.distributions.Normal(
+ loc=generative_p_mu, scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance))
+
+class Model(object):
+
+ def __init__(self,
+ p,
+ q,
+ r,
+ state_size,
+ num_timesteps,
+ dtype=tf.float32):
+ self.p = p
+ self.q = q
+ self.r = r
+ self.state_size = state_size
+ self.num_timesteps = num_timesteps
+ self.dtype = dtype
+
+ def zero_state(self, batch_size):
+ return tf.zeros([batch_size, self.state_size], dtype=self.dtype)
+
+ def __call__(self, prev_state, observation, t):
+ # Compute the q distribution over z, q(z_t|z_n, z_{t-1}).
+ q_zt = self.q.q_zt(observation, prev_state, t)
+ # Compute the p distribution over z, p(z_t|z_{t-1}).
+ p_zt = self.p.p_zt(prev_state, t)
+ # sample from q
+ zt = q_zt.sample()
+ r_xn = self.r.r_xn(zt, t)
+ # Calculate the logprobs and sum over the state size.
+ log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
+ log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
+ log_r_xn = tf.reduce_sum(r_xn.log_prob(observation), axis=1)
+ # If we're at the last timestep, also calc the logprob of the observation.
+ if t == self.num_timesteps - 1:
+ generative_dist = self.p.generative(observation, zt)
+ log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(observation), axis=1)
+ else:
+ log_p_x_given_z = tf.zeros_like(log_q_zt)
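+    # These per-step log-probabilities are combined downstream (in the bound
+    # computations, e.g. bounds.fivo); roughly log p(z_t|z_{t-1}) + log r(x|z_t)
+    # - log q(z_t|x, z_{t-1}), with the generative term added at the last step.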
+ return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn)
+
+ @staticmethod
+ def create(state_size,
+ num_timesteps,
+ sigma_min=1e-5,
+ r_sigma_init=1,
+ variance=1.0,
+ mixing_coeff=0.5,
+ prior_mode_mean=1.0,
+ dtype=tf.float32,
+ random_seed=None,
+ train_p=True,
+ p_type="unimodal",
+ q_type="normal",
+ observation_variance=1.0,
+ transition_type=STANDARD_TRANSITION,
+ use_bs=True):
+ if p_type == "unimodal":
+ p = P(state_size,
+ num_timesteps,
+ sigma_min=sigma_min,
+ variance=variance,
+ dtype=dtype,
+ random_seed=random_seed,
+ trainable=train_p,
+ init_bs_to_zero=not use_bs)
+ elif p_type == "bimodal":
+ p = BimodalPriorP(
+ state_size,
+ num_timesteps,
+ mixing_coeff=mixing_coeff,
+ prior_mode_mean=prior_mode_mean,
+ sigma_min=sigma_min,
+ variance=variance,
+ dtype=dtype,
+ random_seed=random_seed,
+ trainable=train_p,
+ init_bs_to_zero=not use_bs)
+ elif "nonlinear" in p_type:
+ if "cauchy" in p_type:
+ trans_dist = tf.contrib.distributions.Cauchy
+ else:
+ trans_dist = tf.contrib.distributions.Normal
+ p = ShortChainNonlinearP(
+ state_size,
+ num_timesteps,
+ sigma_min=sigma_min,
+ variance=variance,
+ observation_variance=observation_variance,
+ transition_type=transition_type,
+ transition_dist=trans_dist,
+ dtype=dtype,
+ random_seed=random_seed
+ )
+
+ if q_type == "normal":
+ q_class = Q
+ elif q_type == "simple_mean":
+ q_class = SimpleMeanQ
+ elif q_type == "prev_state":
+ q_class = PreviousStateQ
+ elif q_type == "observation":
+ q_class = ObservationQ
+
+ q = q_class(state_size,
+ num_timesteps,
+ sigma_min=sigma_min,
+ dtype=dtype,
+ random_seed=random_seed,
+ init_mu0_to_zero=not use_bs)
+ r = R(state_size,
+ num_timesteps,
+ sigma_min=sigma_min,
+ sigma_init=r_sigma_init,
+ dtype=dtype,
+ random_seed=random_seed)
+ model = Model(p, q, r, state_size, num_timesteps, dtype=dtype)
+ return model
+
+
+class BackwardsModel(object):
+
+ def __init__(self,
+ state_size,
+ num_timesteps,
+ sigma_min=1e-5,
+ dtype=tf.float32):
+ self.state_size = state_size
+ self.num_timesteps = num_timesteps
+ self.sigma_min = sigma_min
+ self.dtype = dtype
+ self.bs = [
+ tf.get_variable(
+ shape=[state_size],
+ dtype=self.dtype,
+ name="b_%d" % (t + 1),
+ initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
+ ]
+ self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)
+ self.q_mus = [
+ snt.Linear(output_size=state_size) for _ in xrange(num_timesteps)
+ ]
+ self.q_sigmas = [
+ tf.get_variable(
+ shape=[state_size],
+ dtype=self.dtype,
+ name="q_sigma_%d" % (t + 1),
+ initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
+ ]
+ self.r_mus = [
+ tf.get_variable(
+ shape=[state_size],
+ dtype=self.dtype,
+ name="r_mu_%d" % (t + 1),
+ initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
+ ]
+ self.r_sigmas = [
+ tf.get_variable(
+ shape=[state_size],
+ dtype=self.dtype,
+ name="r_sigma_%d" % (t + 1),
+ initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
+ ]
+
+ def zero_state(self, batch_size):
+ return tf.zeros([batch_size, self.state_size], dtype=self.dtype)
+
+ def posterior(self, unused_observation, prev_state, unused_t):
+ # TODO(dieterichl): Correct this.
+ return tf.contrib.distributions.Normal(
+ loc=tf.zeros_like(prev_state), scale=tf.zeros_like(prev_state))
+
+ def lookahead(self, state, unused_t):
+ # TODO(dieterichl): Correct this.
+ return tf.contrib.distributions.Normal(
+ loc=tf.zeros_like(state), scale=tf.zeros_like(state))
+
+ def q_zt(self, observation, next_state, t):
+ """Computes the variational posterior q(z_{t}|z_{t+1}, z_n)."""
+ t_backwards = self.num_timesteps - t - 1
+ batch_size = tf.shape(next_state)[0]
+ q_mu = self.q_mus[t_backwards](tf.concat([observation, next_state], axis=1))
+ q_sigma = tf.maximum(
+ tf.nn.softplus(self.q_sigmas[t_backwards]), self.sigma_min)
+ q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
+ q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
+ return q_zt
+
+ def p_zt(self, prev_state, t):
+ """Computes the model p(z_{t+1}| z_{t})."""
+ t_backwards = self.num_timesteps - t - 1
+ z_mu_p = prev_state + self.bs[t_backwards]
+ p_zt = tf.contrib.distributions.Normal(
+ loc=z_mu_p, scale=tf.ones_like(z_mu_p))
+ return p_zt
+
+ def generative(self, unused_observation, z_nm1):
+ """Computes the model's generative distribution p(z_n| z_{n-1})."""
+ generative_p_mu = z_nm1 + self.bs[-1]
+ return tf.contrib.distributions.Normal(
+ loc=generative_p_mu, scale=tf.ones_like(generative_p_mu))
+
+ def r(self, z_t, t):
+ t_backwards = self.num_timesteps - t - 1
+ batch_size = tf.shape(z_t)[0]
+ r_mu = tf.tile(self.r_mus[t_backwards][tf.newaxis, :], [batch_size, 1])
+ r_sigma = tf.maximum(
+ tf.nn.softplus(self.r_sigmas[t_backwards]), self.sigma_min)
+ r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1])
+ return tf.contrib.distributions.Normal(loc=r_mu, scale=tf.sqrt(r_sigma))
+
+ def likelihood(self, observation):
+ batch_size = tf.shape(observation)[0]
+ mu = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
+ sigma = tf.ones_like(mu) * (self.num_timesteps + 1)
+ dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
+ # Average over the batch and take the sum over the state size
+ return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1))
+
+ def __call__(self, next_state, observation, t):
+ # next state = z_{t+1}
+ # Compute the q distribution over z, q(z_{t}|z_n, z_{t+1}).
+ q_zt = self.q_zt(observation, next_state, t)
+ # sample from q
+ zt = q_zt.sample()
+ # Compute the p distribution over z, p(z_{t+1}|z_{t}).
+ p_zt = self.p_zt(zt, t)
+ # Compute log p(z_{t+1} | z_t)
+ if t == 0:
+ log_p_zt = p_zt.log_prob(observation)
+ else:
+ log_p_zt = p_zt.log_prob(next_state)
+
+ # Compute r prior over zt
+ r_zt = self.r(zt, t)
+ log_r_zt = r_zt.log_prob(zt)
+ # Compute proposal density at zt
+ log_q_zt = q_zt.log_prob(zt)
+ # If we're at the last timestep, also calc the logprob of the observation.
+
+ if t == self.num_timesteps - 1:
+ p_z0_dist = tf.contrib.distributions.Normal(
+ loc=tf.zeros_like(zt), scale=tf.ones_like(zt))
+ z0_log_prob = p_z0_dist.log_prob(zt)
+ else:
+ z0_log_prob = tf.zeros_like(log_q_zt)
+ return (zt, log_q_zt, log_p_zt, z0_log_prob, log_r_zt)
+
+
+class LongChainP(object):
+
+ def __init__(self,
+ state_size,
+ num_obs,
+ steps_per_obs,
+ sigma_min=1e-5,
+ variance=1.0,
+ observation_variance=1.0,
+ observation_type=STANDARD_OBSERVATION,
+ transition_type=STANDARD_TRANSITION,
+ dtype=tf.float32,
+ random_seed=None):
+ self.state_size = state_size
+ self.steps_per_obs = steps_per_obs
+ self.num_obs = num_obs
+ self.num_timesteps = steps_per_obs*num_obs + 1
+ self.sigma_min = sigma_min
+ self.dtype = dtype
+ self.variance = variance
+ self.observation_variance = observation_variance
+ self.observation_type = observation_type
+ self.transition_type = transition_type
+
+ def likelihood(self, observations):
+ """Computes the model's true likelihood of the observations.
+
+ Args:
+ observations: A [batch_size, m, state_size] Tensor representing each of
+ the m observations.
+ Returns:
+ logprob: The true likelihood of the observations given the model.
+ """
+ raise ValueError("Likelihood is not defined for long-chain models")
+ # batch_size = tf.shape(observations)[0]
+ # mu = tf.zeros([batch_size, self.state_size, self.num_obs], dtype=self.dtype)
+ # sigma = np.fromfunction(
+ # lambda i, j: 1 + self.steps_per_obs*np.minimum(i+1, j+1),
+ # [self.num_obs, self.num_obs])
+ # sigma += np.eye(self.num_obs)
+ # sigma = tf.convert_to_tensor(sigma * self.variance, dtype=self.dtype)
+ # sigma = tf.tile(sigma[tf.newaxis, tf.newaxis, ...],
+ # [batch_size, self.state_size, 1, 1])
+ # dist = tf.contrib.distributions.MultivariateNormalFullCovariance(
+ # loc=mu,
+ # covariance_matrix=sigma)
+ # Average over the batch and take the sum over the state size
+ #return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observations), axis=1))
+
+ def p_zt(self, prev_state, t):
+ """Computes the model p(z_t| z_{t-1})."""
+ batch_size = tf.shape(prev_state)[0]
+ if t > 0:
+ if self.transition_type == ROUND_TRANSITION:
+ loc = tf.round(prev_state)
+ tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" % (t, t-1, t-1, self.variance))
+ elif self.transition_type == STANDARD_TRANSITION:
+ loc = prev_state
+ tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" % (t, t-1, t-1, self.variance))
+    else: # p(z_0) is Normal(0, variance)
+ loc = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
+ tf.logging.info("p(z_0) ~ N(0,%0.1f)" % self.variance)
+
+ p_zt = tf.contrib.distributions.Normal(
+ loc=loc,
+ scale=tf.sqrt(tf.ones_like(loc) * self.variance))
+ return p_zt
+
+ def generative(self, z_ni, t):
+ """Computes the model's generative distribution p(x_i| z_{ni})."""
+ if self.observation_type == SQUARED_OBSERVATION:
+ generative_mu = tf.square(z_ni)
+ tf.logging.info("p(x_%d | z_%d) ~ N(z_%d^2, %0.1f)" % (t, t, t, self.variance))
+ elif self.observation_type == ABS_OBSERVATION:
+ generative_mu = tf.abs(z_ni)
+ tf.logging.info("p(x_%d | z_%d) ~ N(|z_%d|, %0.1f)" % (t, t, t, self.variance))
+ elif self.observation_type == STANDARD_OBSERVATION:
+ generative_mu = z_ni
+ tf.logging.info("p(x_%d | z_%d) ~ N(z_%d, %0.1f)" % (t, t, t, self.variance))
+ generative_sigma_sq = tf.ones_like(generative_mu) * self.observation_variance
+ return tf.contrib.distributions.Normal(
+ loc=generative_mu, scale=tf.sqrt(generative_sigma_sq))
+
+
+class LongChainQ(object):
+
+ def __init__(self,
+ state_size,
+ num_obs,
+ steps_per_obs,
+ sigma_min=1e-5,
+ dtype=tf.float32,
+ random_seed=None):
+ self.state_size = state_size
+ self.sigma_min = sigma_min
+ self.dtype = dtype
+ self.steps_per_obs = steps_per_obs
+ self.num_obs = num_obs
+ self.num_timesteps = num_obs*steps_per_obs +1
+
+ initializers = {
+ "w": tf.random_uniform_initializer(seed=random_seed),
+ "b": tf.zeros_initializer
+ }
+ self.mus = [
+ snt.Linear(output_size=state_size, initializers=initializers)
+ for t in xrange(self.num_timesteps)
+ ]
+ self.sigmas = [
+ tf.get_variable(
+ shape=[state_size],
+ dtype=self.dtype,
+ name="q_sigma_%d" % (t + 1),
+ initializer=tf.random_uniform_initializer(seed=random_seed))
+ for t in xrange(self.num_timesteps)
+ ]
+
+ def first_relevant_obs_index(self, t):
+ return int(max((t-1)/self.steps_per_obs, 0))
+
+ def q_zt(self, observations, prev_state, t):
+ """Computes a distribution over z_t.
+
+ Args:
+ observations: a [batch_size, num_observations, state_size] Tensor.
+ prev_state: a [batch_size, state_size] Tensor.
+      t: The current timestep, a Python int.
+ """
+ # filter out unneeded past obs
+ first_relevant_obs_index = int(math.floor(max(t-1, 0) / self.steps_per_obs))
+ num_relevant_observations = self.num_obs - first_relevant_obs_index
+ observations = observations[:,first_relevant_obs_index:,:]
+ batch_size = tf.shape(prev_state)[0]
+    # Concatenate the previous state and the observations along the second
+    # axis (the one that is neither the batch nor the state-size axis), then
+    # flatten to [batch_size, (num_relevant_observations + 1) * state_size]
+    # so it can be fed into the linear layer.
+ q_input = tf.concat([observations, prev_state[:,tf.newaxis, :]], axis=1)
+ q_input = tf.reshape(q_input,
+ [batch_size, (num_relevant_observations + 1) * self.state_size])
+ q_mu = self.mus[t](q_input)
+ q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
+ q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
+ q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
+ tf.logging.info(
+ "q(z_{t} | z_{tm1}, x_{obsf}:{obst}) ~ N(Linear([z_{tm1},x_{obsf}:{obst}]), sigma_{t})".format(
+ **{"t": t,
+ "tm1": t-1,
+ "obsf": (first_relevant_obs_index+1)*self.steps_per_obs,
+ "obst":self.steps_per_obs*self.num_obs}))
+ return q_zt
+
+ def summarize_weights(self):
+ pass
+
+class LongChainR(object):
+
+ def __init__(self,
+ state_size,
+ num_obs,
+ steps_per_obs,
+ sigma_min=1e-5,
+ dtype=tf.float32,
+ random_seed=None):
+ self.state_size = state_size
+ self.dtype = dtype
+ self.sigma_min = sigma_min
+ self.steps_per_obs = steps_per_obs
+ self.num_obs = num_obs
+ self.num_timesteps = num_obs*steps_per_obs + 1
+ self.sigmas = [
+ tf.get_variable(
+ shape=[self.num_future_obs(t)],
+ dtype=self.dtype,
+ name="r_sigma_%d" % (t + 1),
+ #initializer=tf.random_uniform_initializer(seed=random_seed, maxval=100))
+ initializer=tf.constant_initializer(1.0))
+ for t in range(self.num_timesteps)
+ ]
+
+ def first_future_obs_index(self, t):
+ return int(math.floor(t / self.steps_per_obs))
+
+ def num_future_obs(self, t):
+ return int(self.num_obs - self.first_future_obs_index(t))
+
+ def r_xn(self, z_t, t):
+ """Computes a distribution over the future observations given current latent
+ state.
+
+ The indexing in these messages is 1 indexed and inclusive. This is
+ consistent with the latex documents.
+
+ Args:
+ z_t: [batch_size, state_size] Tensor
+ t: Current timestep
+ """
+ tf.logging.info(
+ "r(x_{start}:{end} | z_{t}) ~ N(z_{t}, sigma_{t})".format(
+ **{"t": t,
+ "start": (self.first_future_obs_index(t)+1)*self.steps_per_obs,
+ "end": self.num_timesteps-1}))
+ batch_size = tf.shape(z_t)[0]
+ # the mean for all future observations is the same.
+ # this tiling results in a [batch_size, num_future_obs, state_size] Tensor
+ r_mu = tf.tile(z_t[:,tf.newaxis,:], [1, self.num_future_obs(t), 1])
+ # compute the variance
+ r_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
+    # The variance is shared across state dimensions, so we only need to tile
+    # sigma up to [batch_size, num_future_obs, state_size].
+ r_sigma = tf.tile(r_sigma[tf.newaxis,:, tf.newaxis], [batch_size, 1, self.state_size])
+ return tf.contrib.distributions.Normal(
+ loc=r_mu, scale=tf.sqrt(r_sigma))
+
+ def summarize_weights(self):
+ pass
+
+
+class LongChainModel(object):
+
+ def __init__(self,
+ p,
+ q,
+ r,
+ state_size,
+ num_obs,
+ steps_per_obs,
+ dtype=tf.float32,
+ disable_r=False):
+ self.p = p
+ self.q = q
+ self.r = r
+ self.disable_r = disable_r
+ self.state_size = state_size
+ self.num_obs = num_obs
+ self.steps_per_obs = steps_per_obs
+ self.num_timesteps = steps_per_obs*num_obs + 1
+ self.dtype = dtype
+
+ def zero_state(self, batch_size):
+ return tf.zeros([batch_size, self.state_size], dtype=self.dtype)
+
+ def next_obs_ind(self, t):
+ return int(math.floor(max(t-1,0)/self.steps_per_obs))
+
+ def __call__(self, prev_state, observations, t):
+ """Computes the importance weight for the model system.
+
+ Args:
+ prev_state: [batch_size, state_size] Tensor
+ observations: [batch_size, num_observations, state_size] Tensor
+ """
+ # Compute the q distribution over z, q(z_t|z_n, z_{t-1}).
+ q_zt = self.q.q_zt(observations, prev_state, t)
+ # Compute the p distribution over z, p(z_t|z_{t-1}).
+ p_zt = self.p.p_zt(prev_state, t)
+ # sample from q and evaluate the logprobs, summing over the state size
+ zt = q_zt.sample()
+ log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
+ log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
+ if not self.disable_r and t < self.num_timesteps-1:
+ # score the remaining observations using r
+ r_xn = self.r.r_xn(zt, t)
+ log_r_xn = r_xn.log_prob(observations[:, self.next_obs_ind(t+1):, :])
+ # sum over state size and observation, leaving the batch index
+ log_r_xn = tf.reduce_sum(log_r_xn, axis=[1,2])
+ else:
+ log_r_xn = tf.zeros_like(log_p_zt)
+ if t != 0 and t % self.steps_per_obs == 0:
+ generative_dist = self.p.generative(zt, t)
+ log_p_x_given_z = generative_dist.log_prob(observations[:,self.next_obs_ind(t),:])
+ log_p_x_given_z = tf.reduce_sum(log_p_x_given_z, axis=1)
+ else:
+ log_p_x_given_z = tf.zeros_like(log_q_zt)
+ return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn)
+
+ @staticmethod
+ def create(state_size,
+ num_obs,
+ steps_per_obs,
+ sigma_min=1e-5,
+ variance=1.0,
+ observation_variance=1.0,
+ observation_type=STANDARD_OBSERVATION,
+ transition_type=STANDARD_TRANSITION,
+ dtype=tf.float32,
+ random_seed=None,
+ disable_r=False):
+ p = LongChainP(
+ state_size,
+ num_obs,
+ steps_per_obs,
+ sigma_min=sigma_min,
+ variance=variance,
+ observation_variance=observation_variance,
+ observation_type=observation_type,
+ transition_type=transition_type,
+ dtype=dtype,
+ random_seed=random_seed)
+ q = LongChainQ(
+ state_size,
+ num_obs,
+ steps_per_obs,
+ sigma_min=sigma_min,
+ dtype=dtype,
+ random_seed=random_seed)
+ r = LongChainR(
+ state_size,
+ num_obs,
+ steps_per_obs,
+ sigma_min=sigma_min,
+ dtype=dtype,
+ random_seed=random_seed)
+ model = LongChainModel(
+ p, q, r, state_size, num_obs, steps_per_obs,
+ dtype=dtype,
+ disable_r=disable_r)
+ return model
+
+
+class RTilde(object):
+
+ def __init__(self,
+ state_size,
+ num_timesteps,
+ sigma_min=1e-5,
+ dtype=tf.float32,
+ random_seed=None,
+ graph_collection_name="R_TILDE_VARS"):
+ self.dtype = dtype
+ self.sigma_min = sigma_min
+ initializers = {"w": tf.truncated_normal_initializer(seed=random_seed),
+ "b": tf.zeros_initializer}
+ self.graph_collection_name=graph_collection_name
+
+ def custom_getter(getter, *args, **kwargs):
+ out = getter(*args, **kwargs)
+ ref = tf.get_collection_ref(self.graph_collection_name)
+ if out not in ref:
+ ref.append(out)
+ return out
+
+ self.fns = [
+ snt.Linear(output_size=2*state_size,
+ initializers=initializers,
+ name="r_tilde_%d" % t,
+ custom_getter=custom_getter)
+ for t in xrange(num_timesteps)
+ ]
+
+ def r_zt(self, z_t, observation, t):
+ #out = self.fns[t](tf.stop_gradient(tf.concat([z_t, observation], axis=1)))
+ out = self.fns[t](tf.concat([z_t, observation], axis=1))
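+    # A single linear layer produces both the mean and a raw (pre-softplus)
+    # variance; split them apart below.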
+ mu, raw_sigma_sq = tf.split(out, 2, axis=1)
+ sigma_sq = tf.maximum(tf.nn.softplus(raw_sigma_sq), self.sigma_min)
+ return mu, sigma_sq
+
+class TDModel(object):
+
+ def __init__(self,
+ p,
+ q,
+ r_tilde,
+ state_size,
+ num_timesteps,
+ dtype=tf.float32,
+ disable_r=False):
+ self.p = p
+ self.q = q
+ self.r_tilde = r_tilde
+ self.disable_r = disable_r
+ self.state_size = state_size
+ self.num_timesteps = num_timesteps
+ self.dtype = dtype
+
+ def zero_state(self, batch_size):
+ return tf.zeros([batch_size, self.state_size], dtype=self.dtype)
+
+ def __call__(self, prev_state, observation, t):
+ """Computes the importance weight for the model system.
+
+ Args:
+ prev_state: [batch_size, state_size] Tensor
+      observation: [batch_size, state_size] Tensor
+ """
+ # Compute the q distribution over z, q(z_t|z_n, z_{t-1}).
+ q_zt = self.q.q_zt(observation, prev_state, t)
+ # Compute the p distribution over z, p(z_t|z_{t-1}).
+ p_zt = self.p.p_zt(prev_state, t)
+ # sample from q and evaluate the logprobs, summing over the state size
+ zt = q_zt.sample()
+ # If it isn't the last timestep, compute the distribution over the next z.
+ if t < self.num_timesteps - 1:
+ p_ztplus1 = self.p.p_zt(zt, t+1)
+ else:
+ p_ztplus1 = None
+ log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
+ log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
+
+ if not self.disable_r and t < self.num_timesteps-1:
+ # score the remaining observations using r
+ r_tilde_mu, r_tilde_sigma_sq = self.r_tilde.r_zt(zt, observation, t+1)
+ else:
+ r_tilde_mu = None
+ r_tilde_sigma_sq = None
+ if t == self.num_timesteps - 1:
+ generative_dist = self.p.generative(observation, zt)
+ log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(observation), axis=1)
+ else:
+ log_p_x_given_z = tf.zeros_like(log_q_zt)
+ return (zt, log_q_zt, log_p_zt, log_p_x_given_z,
+ r_tilde_mu, r_tilde_sigma_sq, p_ztplus1)
+
+ @staticmethod
+ def create(state_size,
+ num_timesteps,
+ sigma_min=1e-5,
+ variance=1.0,
+ dtype=tf.float32,
+ random_seed=None,
+ train_p=True,
+ p_type="unimodal",
+ q_type="normal",
+ mixing_coeff=0.5,
+ prior_mode_mean=1.0,
+ observation_variance=1.0,
+ transition_type=STANDARD_TRANSITION,
+ use_bs=True):
+ if p_type == "unimodal":
+ p = P(state_size,
+ num_timesteps,
+ sigma_min=sigma_min,
+ variance=variance,
+ dtype=dtype,
+ random_seed=random_seed,
+ trainable=train_p,
+ init_bs_to_zero=not use_bs)
+ elif p_type == "bimodal":
+ p = BimodalPriorP(
+ state_size,
+ num_timesteps,
+ mixing_coeff=mixing_coeff,
+ prior_mode_mean=prior_mode_mean,
+ sigma_min=sigma_min,
+ variance=variance,
+ dtype=dtype,
+ random_seed=random_seed,
+ trainable=train_p,
+ init_bs_to_zero=not use_bs)
+ elif "nonlinear" in p_type:
+ if "cauchy" in p_type:
+ trans_dist = tf.contrib.distributions.Cauchy
+ else:
+ trans_dist = tf.contrib.distributions.Normal
+
+ p = ShortChainNonlinearP(
+ state_size,
+ num_timesteps,
+ sigma_min=sigma_min,
+ variance=variance,
+ observation_variance=observation_variance,
+ transition_type=transition_type,
+ transition_dist=trans_dist,
+ dtype=dtype,
+ random_seed=random_seed
+ )
+
+ if q_type == "normal":
+ q_class = Q
+ elif q_type == "simple_mean":
+ q_class = SimpleMeanQ
+ elif q_type == "prev_state":
+ q_class = PreviousStateQ
+ elif q_type == "observation":
+ q_class = ObservationQ
+
+ q = q_class(state_size,
+ num_timesteps,
+ sigma_min=sigma_min,
+ dtype=dtype,
+ random_seed=random_seed,
+ init_mu0_to_zero=not use_bs)
+ r_tilde = RTilde(
+ state_size,
+ num_timesteps,
+ sigma_min=sigma_min,
+ dtype=dtype,
+ random_seed=random_seed)
+ model = TDModel(p, q, r_tilde, state_size, num_timesteps, dtype=dtype)
+ return model
diff --git a/models/research/fivo/experimental/run.sh b/models/research/fivo/experimental/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c650f636d5313a196960a92b509202b47e7da518
--- /dev/null
+++ b/models/research/fivo/experimental/run.sh
@@ -0,0 +1,54 @@
+#!/bin/bash
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+model="forward"
+T=5
+num_obs=1
+var=0.1
+n=4
+lr=0.0001
+bound="fivo-aux"
+q_type="normal"
+resampling_method="multinomial"
+rgrad="true"
+p_type="unimodal"
+use_bs=false
+
+LOGDIR=/tmp/fivo/model-$model-$bound-$resampling_method-resampling-rgrad-$rgrad-T-$T-var-$var-n-$n-lr-$lr-q-$q_type-p-$p_type
+
+python train.py \
+ --logdir=$LOGDIR \
+ --model=$model \
+ --bound=$bound \
+ --q_type=$q_type \
+ --p_type=$p_type \
+ --variance=$var \
+ --use_resampling_grads=$rgrad \
+ --resampling=always \
+ --resampling_method=$resampling_method \
+ --batch_size=4 \
+ --num_samples=$n \
+ --num_timesteps=$T \
+ --num_eval_samples=256 \
+ --summarize_every=100 \
+ --learning_rate=$lr \
+ --decay_steps=1000000 \
+ --max_steps=1000000000 \
+ --random_seed=1234 \
+ --train_p=false \
+ --use_bs=$use_bs \
+ --alsologtostderr
diff --git a/models/research/fivo/experimental/summary_utils.py b/models/research/fivo/experimental/summary_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..04e4aeea257577e60d3651656d0c62355d501ea8
--- /dev/null
+++ b/models/research/fivo/experimental/summary_utils.py
@@ -0,0 +1,332 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utils for plotting and summarizing.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import matplotlib.gridspec as gridspec
+import matplotlib.pyplot as plt
+import numpy as np
+import scipy
+
+import tensorflow as tf
+
+import models
+
+
+def summarize_ess(weights, only_last_timestep=False):
+ """Plots the effective sample size.
+
+ Args:
+ weights: List of length num_timesteps Tensors of shape
+ [num_samples, batch_size]
+ """
+ num_timesteps = len(weights)
+ batch_size = tf.cast(tf.shape(weights[0])[1], dtype=tf.float64)
+ for i in range(num_timesteps):
+ if only_last_timestep and i < num_timesteps-1: continue
+
+ w = tf.nn.softmax(weights[i], dim=0)
+ centered_weights = w - tf.reduce_mean(w, axis=0, keepdims=True)
+ variance = tf.reduce_sum(tf.square(centered_weights))/(batch_size-1)
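+    # For weights normalized over the sample dimension, ESS = 1 / sum_i w_i^2;
+    # here sum_i w_i^2 is averaged over the batch before taking the reciprocal.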
+ ess = 1./tf.reduce_mean(tf.reduce_sum(tf.square(w), axis=0))
+ tf.summary.scalar("ess/%d" % i, ess)
+ tf.summary.scalar("ese/%d" % i, ess / batch_size)
+ tf.summary.scalar("weight_variance/%d" % i, variance)
+
+
+def summarize_particles(states, weights, observation, model):
+ """Plots particle locations and weights.
+
+ Args:
+ states: List of length num_timesteps Tensors of shape
+ [batch_size*num_particles, state_size].
+ weights: List of length num_timesteps Tensors of shape [num_samples,
+ batch_size]
+ observation: Tensor of shape [batch_size*num_samples, state_size]
+ """
+ num_timesteps = len(weights)
+ num_samples, batch_size = weights[0].get_shape().as_list()
+ # get q0 information for plotting
+ q0_dist = model.q.q_zt(observation, tf.zeros_like(states[0]), 0)
+ q0_loc = q0_dist.loc[0:batch_size, 0]
+ q0_scale = q0_dist.scale[0:batch_size, 0]
+ # get posterior information for plotting
+ post = (model.p.mixing_coeff, model.p.prior_mode_mean, model.p.variance,
+ tf.reduce_sum(model.p.bs), model.p.num_timesteps)
+
+ # Reshape states and weights to be [time, num_samples, batch_size]
+ states = tf.stack(states)
+ weights = tf.stack(weights)
+ # normalize the weights over the sample dimension
+ weights = tf.nn.softmax(weights, dim=1)
+ states = tf.reshape(states, tf.shape(weights))
+
+ ess = 1./tf.reduce_sum(tf.square(weights), axis=1)
+
+ def _plot_states(states_batch, weights_batch, observation_batch, ess_batch, q0, post):
+ """
+ states: [time, num_samples, batch_size]
+ weights [time, num_samples, batch_size]
+ observation: [batch_size, 1]
+ q0: ([batch_size], [batch_size])
+ post: ...
+ """
+ num_timesteps, _, batch_size = states_batch.shape
+ plots = []
+ for i in range(batch_size):
+ states = states_batch[:,:,i]
+ weights = weights_batch[:,:,i]
+ observation = observation_batch[i]
+ ess = ess_batch[:,i]
+ q0_loc = q0[0][i]
+ q0_scale = q0[1][i]
+
+ fig = plt.figure(figsize=(7, (num_timesteps + 1) * 2))
+ # Each timestep gets two plots -- a bar plot and a histogram of state locs.
+ # The bar plot will be bar_rows rows tall.
+ # The histogram will be 1 row tall.
+ # There is also 1 extra plot at the top showing the posterior and q.
+ bar_rows = 8
+ num_rows = (num_timesteps + 1) * (bar_rows + 1)
+ gs = gridspec.GridSpec(num_rows, 1)
+
+ # Figure out how wide to make the plot
+ prior_lims = (post[1] * -2, post[1] * 2)
+ q_lims = (scipy.stats.norm.ppf(0.01, loc=q0_loc, scale=q0_scale),
+ scipy.stats.norm.ppf(0.99, loc=q0_loc, scale=q0_scale))
+ state_width = states.max() - states.min()
+ state_lims = (states.min() - state_width * 0.15,
+ states.max() + state_width * 0.15)
+
+ lims = (min(prior_lims[0], q_lims[0], state_lims[0]),
+ max(prior_lims[1], q_lims[1], state_lims[1]))
+ # plot the posterior
+ z0 = np.arange(lims[0], lims[1], 0.1)
+ alpha, pos_mu, sigma_sq, B, T = post
+ neg_mu = -pos_mu
+ scale = np.sqrt((T + 1) * sigma_sq)
+ p_zn = (
+ alpha * scipy.stats.norm.pdf(
+ observation, loc=pos_mu + B, scale=scale) + (1 - alpha) *
+ scipy.stats.norm.pdf(observation, loc=neg_mu + B, scale=scale))
+ p_z0 = (
+ alpha * scipy.stats.norm.pdf(z0, loc=pos_mu, scale=np.sqrt(sigma_sq))
+ + (1 - alpha) * scipy.stats.norm.pdf(
+ z0, loc=neg_mu, scale=np.sqrt(sigma_sq)))
+ p_zn_given_z0 = scipy.stats.norm.pdf(
+ observation, loc=z0 + B, scale=np.sqrt(T * sigma_sq))
+ post_z0 = (p_z0 * p_zn_given_z0) / p_zn
+ # plot q
+ q_z0 = scipy.stats.norm.pdf(z0, loc=q0_loc, scale=q0_scale)
+ ax = plt.subplot(gs[0:bar_rows, :])
+ ax.plot(z0, q_z0, color="blue")
+ ax.plot(z0, post_z0, color="green")
+ ax.plot(z0, p_z0, color="red")
+ ax.legend(("q", "posterior", "prior"), loc="best", prop={"size": 10})
+
+ ax.set_xticks([])
+ ax.set_xlim(*lims)
+
+ # plot the states
+ for t in range(num_timesteps):
+ start = (t + 1) * (bar_rows + 1)
+ ax1 = plt.subplot(gs[start:start + bar_rows, :])
+ ax2 = plt.subplot(gs[start + bar_rows:start + bar_rows + 1, :])
+ # plot the states barplot
+ # ax1.hist(
+ # states[t, :],
+ # weights=weights[t, :],
+ # bins=50,
+ # edgecolor="none",
+ # alpha=0.2)
+ ax1.bar(states[t,:], weights[t,:], width=0.02, alpha=0.2, edgecolor = "none")
+ ax1.set_ylabel("t=%d" % t)
+ ax1.set_xticks([])
+ ax1.grid(True, which="both")
+ ax1.set_xlim(*lims)
+ # plot the observation
+ ax1.axvline(x=observation, color="red", linestyle="dashed")
+ # add the ESS
+ ax1.text(0.1, 0.9, "ESS: %0.2f" % ess[t],
+ ha='center', va='center', transform=ax1.transAxes)
+
+ # plot the state location histogram
+ ax2.hist2d(
+ states[t, :], np.zeros_like(states[t, :]), bins=[50, 1], cmap="Greys")
+ ax2.grid(False)
+ ax2.set_yticks([])
+ ax2.set_xlim(*lims)
+ if t != num_timesteps - 1:
+ ax2.set_xticks([])
+
+ fig.canvas.draw()
+ p = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
+ plots.append(p.reshape(fig.canvas.get_width_height()[::-1] + (3,)))
+ plt.close(fig)
+ return np.stack(plots)
+
+ plots = tf.py_func(_plot_states,
+ [states, weights, observation, ess, (q0_loc, q0_scale), post],
+ [tf.uint8])[0]
+ tf.summary.image("states", plots, 5, collections=["infrequent_summaries"])
+
+
+def plot_weights(weights, resampled=None):
+ """Plots the weights and effective sample size from an SMC rollout.
+
+ Args:
+ weights: [num_timesteps, num_samples, batch_size] importance weights
+    resampled: [num_timesteps] 0/1 indicating if resampling occurred
+ """
+ weights = tf.convert_to_tensor(weights)
+
+ def _make_plots(weights, resampled):
+ num_timesteps, num_samples, batch_size = weights.shape
+ plots = []
+ for i in range(batch_size):
+ fig, axes = plt.subplots(nrows=1, sharex=True, figsize=(8, 4))
+ axes.stackplot(np.arange(num_timesteps), np.transpose(weights[:, :, i]))
+ axes.set_title("Weights")
+ axes.set_xlabel("Steps")
+ axes.set_ylim([0, 1])
+ axes.set_xlim([0, num_timesteps - 1])
+ for j in np.where(resampled > 0)[0]:
+ axes.axvline(x=j, color="red", linestyle="dashed", ymin=0.0, ymax=1.0)
+ fig.canvas.draw()
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+ plots.append(data)
+ plt.close(fig)
+ return np.stack(plots, axis=0)
+
+ if resampled is None:
+ num_timesteps, _, batch_size = weights.get_shape().as_list()
+ resampled = tf.zeros([num_timesteps], dtype=tf.float32)
+ plots = tf.py_func(_make_plots,
+ [tf.nn.softmax(weights, dim=1),
+ tf.to_float(resampled)], [tf.uint8])[0]
+ batch_size = weights.get_shape().as_list()[-1]
+ tf.summary.image(
+ "weights", plots, batch_size, collections=["infrequent_summaries"])
+
+
+def summarize_weights(weights, num_timesteps, num_samples):
+ # weights is [num_timesteps, num_samples, batch_size]
+ weights = tf.convert_to_tensor(weights)
+ mean = tf.reduce_mean(weights, axis=1, keepdims=True)
+ squared_diff = tf.square(weights - mean)
+ variances = tf.reduce_sum(squared_diff, axis=1) / (num_samples - 1)
+ # average the variance over the batch
+ variances = tf.reduce_mean(variances, axis=1)
+ avg_magnitude = tf.reduce_mean(tf.abs(weights), axis=[1, 2])
+ for t in xrange(num_timesteps):
+ tf.summary.scalar("weights/variance_%d" % t, variances[t])
+ tf.summary.scalar("weights/magnitude_%d" % t, avg_magnitude[t])
+ tf.summary.histogram("weights/step_%d" % t, weights[t])
+
+
+def summarize_learning_signal(rewards, tag):
+ num_resampling_events, _ = rewards.get_shape().as_list()
+ mean = tf.reduce_mean(rewards, axis=1)
+ avg_magnitude = tf.reduce_mean(tf.abs(rewards), axis=1)
+ reward_square = tf.reduce_mean(tf.square(rewards), axis=1)
+ for t in xrange(num_resampling_events):
+ tf.summary.scalar("%s/mean_%d" % (tag, t), mean[t])
+ tf.summary.scalar("%s/magnitude_%d" % (tag, t), avg_magnitude[t])
+ tf.summary.scalar("%s/squared_%d" % (tag, t), reward_square[t])
+ tf.summary.histogram("%s/step_%d" % (tag, t), rewards[t])
+
+
+def summarize_qs(model, observation, states):
+ model.q.summarize_weights()
+ if hasattr(model.p, "posterior") and callable(getattr(model.p, "posterior")):
+ states = [tf.zeros_like(states[0])] + states[:-1]
+ for t, prev_state in enumerate(states):
+ p = model.p.posterior(observation, prev_state, t)
+ q = model.q.q_zt(observation, prev_state, t)
+ kl = tf.reduce_mean(tf.contrib.distributions.kl_divergence(p, q))
+ tf.summary.scalar("kl_q/%d" % t, tf.reduce_mean(kl))
+ mean_diff = q.loc - p.loc
+ mean_abs_err = tf.abs(mean_diff)
+ mean_rel_err = tf.abs(mean_diff / p.loc)
+ tf.summary.scalar("q_mean_convergence/absolute_error_%d" % t,
+ tf.reduce_mean(mean_abs_err))
+ tf.summary.scalar("q_mean_convergence/relative_error_%d" % t,
+ tf.reduce_mean(mean_rel_err))
+ sigma_diff = tf.square(q.scale) - tf.square(p.scale)
+ sigma_abs_err = tf.abs(sigma_diff)
+ sigma_rel_err = tf.abs(sigma_diff / tf.square(p.scale))
+ tf.summary.scalar("q_variance_convergence/absolute_error_%d" % t,
+ tf.reduce_mean(sigma_abs_err))
+ tf.summary.scalar("q_variance_convergence/relative_error_%d" % t,
+ tf.reduce_mean(sigma_rel_err))
+
+
+def summarize_rs(model, states):
+ model.r.summarize_weights()
+ for t, state in enumerate(states):
+ true_r = model.p.lookahead(state, t)
+ r = model.r.r_xn(state, t)
+ kl = tf.reduce_mean(tf.contrib.distributions.kl_divergence(true_r, r))
+ tf.summary.scalar("kl_r/%d" % t, tf.reduce_mean(kl))
+ mean_diff = true_r.loc - r.loc
+ mean_abs_err = tf.abs(mean_diff)
+ mean_rel_err = tf.abs(mean_diff / true_r.loc)
+ tf.summary.scalar("r_mean_convergence/absolute_error_%d" % t,
+ tf.reduce_mean(mean_abs_err))
+ tf.summary.scalar("r_mean_convergence/relative_error_%d" % t,
+ tf.reduce_mean(mean_rel_err))
+ sigma_diff = tf.square(r.scale) - tf.square(true_r.scale)
+ sigma_abs_err = tf.abs(sigma_diff)
+ sigma_rel_err = tf.abs(sigma_diff / tf.square(true_r.scale))
+ tf.summary.scalar("r_variance_convergence/absolute_error_%d" % t,
+ tf.reduce_mean(sigma_abs_err))
+ tf.summary.scalar("r_variance_convergence/relative_error_%d" % t,
+ tf.reduce_mean(sigma_rel_err))
+
+
+def summarize_model(model, true_bs, observation, states, bound, summarize_r=True):
+ if hasattr(model.p, "bs"):
+ model_b = tf.reduce_sum(model.p.bs, axis=0)
+ true_b = tf.reduce_sum(true_bs, axis=0)
+ abs_err = tf.abs(model_b - true_b)
+ rel_err = abs_err / true_b
+ tf.summary.scalar("sum_of_bs/data_generating_process", tf.reduce_mean(true_b))
+ tf.summary.scalar("sum_of_bs/model", tf.reduce_mean(model_b))
+ tf.summary.scalar("sum_of_bs/absolute_error", tf.reduce_mean(abs_err))
+ tf.summary.scalar("sum_of_bs/relative_error", tf.reduce_mean(rel_err))
+ #summarize_qs(model, observation, states)
+ #if bound == "fivo-aux" and summarize_r:
+ # summarize_rs(model, states)
+
+
+def summarize_grads(grads, loss_name):
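+  # Track exponential moving averages of the first and second moments of the
+  # flattened gradients; their difference (second moment minus squared first
+  # moment) estimates the per-parameter gradient variance summarized below.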
+ grad_ema = tf.train.ExponentialMovingAverage(decay=0.99)
+ vectorized_grads = tf.concat(
+ [tf.reshape(g, [-1]) for g, _ in grads if g is not None], axis=0)
+ new_second_moments = tf.square(vectorized_grads)
+ new_first_moments = vectorized_grads
+ maintain_grad_ema_op = grad_ema.apply([new_first_moments, new_second_moments])
+ first_moments = grad_ema.average(new_first_moments)
+ second_moments = grad_ema.average(new_second_moments)
+ variances = second_moments - tf.square(first_moments)
+ tf.summary.scalar("grad_variance/%s" % loss_name, tf.reduce_mean(variances))
+ tf.summary.histogram("grad_variance/%s" % loss_name, variances)
+ tf.summary.histogram("grad_mean/%s" % loss_name, first_moments)
+ return maintain_grad_ema_op
diff --git a/models/research/fivo/experimental/train.py b/models/research/fivo/experimental/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..8abc9909b115298a30151a332d340f7b25e3cf90
--- /dev/null
+++ b/models/research/fivo/experimental/train.py
@@ -0,0 +1,637 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Main script for running fivo"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import defaultdict
+
+import numpy as np
+import tensorflow as tf
+
+import bounds
+import data
+import models
+import summary_utils as summ
+
+tf.logging.set_verbosity(tf.logging.INFO)
+
+tf.app.flags.DEFINE_integer("random_seed", None,
+ "A random seed for the data generating process. Same seed "
+ "-> same data generating process and initialization.")
+tf.app.flags.DEFINE_enum("bound", "fivo", ["iwae", "fivo", "fivo-aux", "fivo-aux-td"],
+ "The bound to optimize.")
+tf.app.flags.DEFINE_enum("model", "forward", ["forward", "long_chain"],
+ "The model to use.")
+tf.app.flags.DEFINE_enum("q_type", "normal",
+ ["normal", "simple_mean", "prev_state", "observation"],
+ "The parameterization to use for q")
+tf.app.flags.DEFINE_enum("p_type", "unimodal", ["unimodal", "bimodal", "nonlinear"],
+ "The type of prior.")
+tf.app.flags.DEFINE_boolean("train_p", True,
+ "If false, do not train the model p.")
+
+tf.app.flags.DEFINE_integer("state_size", 1,
+ "The dimensionality of the state space.")
+tf.app.flags.DEFINE_float("variance", 1.0,
+ "The variance of the data generating process.")
+
+tf.app.flags.DEFINE_boolean("use_bs", True,
+ "If False, initialize all bs to 0.")
+tf.app.flags.DEFINE_float("bimodal_prior_weight", 0.5,
+ "The weight assigned to the positive mode of the prior in "
+ "both the data generating process and p.")
+tf.app.flags.DEFINE_float("bimodal_prior_mean", None,
+ "If supplied, sets the mean of the 2 modes of the prior to "
+ "be 1 and -1 times the supplied value. This is for both the "
+ "data generating process and p.")
+tf.app.flags.DEFINE_float("fixed_observation", None,
+ "If supplied, fix the observation to a constant value in the"
+ " data generating process only.")
+tf.app.flags.DEFINE_float("r_sigma_init", 1.,
+ "Value to initialize variance of r to.")
+tf.app.flags.DEFINE_enum("observation_type",
+ models.STANDARD_OBSERVATION, models.OBSERVATION_TYPES,
+ "The type of observation for the long chain model.")
+tf.app.flags.DEFINE_enum("transition_type",
+ models.STANDARD_TRANSITION, models.TRANSITION_TYPES,
+ "The type of transition for the long chain model.")
+tf.app.flags.DEFINE_float("observation_variance", None,
+ "The variance of the observation. Defaults to 'variance'")
+
+tf.app.flags.DEFINE_integer("num_timesteps", 5,
+ "Number of timesteps in the sequence.")
+tf.app.flags.DEFINE_integer("num_observations", 1,
+ "The number of observations.")
+tf.app.flags.DEFINE_integer("steps_per_observation", 5,
+ "The number of timesteps between each observation.")
+
+tf.app.flags.DEFINE_integer("batch_size", 4,
+ "The number of examples per batch.")
+tf.app.flags.DEFINE_integer("num_samples", 4,
+                            "The number of particles to use.")
+tf.app.flags.DEFINE_integer("num_eval_samples", 512,
+ "The batch size and # of particles to use for eval.")
+
+tf.app.flags.DEFINE_string("resampling", "always",
+ "How to resample. Accepts 'always','never', or a "
+ "comma-separated list of booleans like 'true,true,false'.")
+tf.app.flags.DEFINE_enum("resampling_method", "multinomial", ["multinomial",
+ "stratified",
+ "systematic",
+ "relaxed-logblend",
+ "relaxed-stateblend",
+ "relaxed-linearblend",
+ "relaxed-stateblend-st",],
+ "Type of resampling method to use.")
+tf.app.flags.DEFINE_boolean("use_resampling_grads", True,
+                            "Whether or not to use resampling grads to optimize FIVO. "
+ "Disabled automatically if resampling_method=relaxed.")
+tf.app.flags.DEFINE_boolean("disable_r", False,
+                            "If true, r is not used for fivo-aux and is set to zeros.")
+
+tf.app.flags.DEFINE_float("learning_rate", 1e-4,
+ "The learning rate to use for ADAM or SGD.")
+tf.app.flags.DEFINE_integer("decay_steps", 25000,
+ "The number of steps before the learning rate is halved.")
+tf.app.flags.DEFINE_integer("max_steps", int(1e6),
+ "The number of steps to run training for.")
+
+tf.app.flags.DEFINE_string("logdir", "/tmp/fivo-aux",
+ "Directory for summaries and checkpoints.")
+
+tf.app.flags.DEFINE_integer("summarize_every", int(1e3),
+ "The number of steps between each evaluation.")
+FLAGS = tf.app.flags.FLAGS
+
+
+def combine_grad_lists(grad_lists):
+ # grads is num_losses by num_variables.
+ # each list could have different variables.
+ # for each variable, sum the grads across all losses.
+ grads_dict = defaultdict(list)
+ var_dict = {}
+ for grad_list in grad_lists:
+ for grad, var in grad_list:
+ if grad is not None:
+ grads_dict[var.name].append(grad)
+ var_dict[var.name] = var
+
+ final_grads = []
+ for var_name, var in var_dict.iteritems():
+ grads = grads_dict[var_name]
+ if len(grads) > 0:
+ tf.logging.info("Var %s has combined grads from %s." %
+ (var_name, [g.name for g in grads]))
+ grad = tf.reduce_sum(grads, axis=0)
+ else:
+ tf.logging.info("Var %s has no grads" % var_name)
+ grad = None
+ final_grads.append((grad, var))
+ return final_grads
+
+
+def make_apply_grads_op(losses, global_step, learning_rate, lr_decay_steps):
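+  # Each entry of `losses` is a bounds.Loss (loss name, loss tensor, variable
+  # collection); gradients are computed per loss against its own collection
+  # and summed per variable before a single Adam update.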
+ for l in losses:
+ assert isinstance(l, bounds.Loss)
+
+ lr = tf.train.exponential_decay(
+ learning_rate, global_step, lr_decay_steps, 0.5, staircase=False)
+ tf.summary.scalar("learning_rate", lr)
+ opt = tf.train.AdamOptimizer(lr)
+
+ ema_ops = []
+ grads = []
+ for loss_name, loss, loss_var_collection in losses:
+ tf.logging.info("Computing grads of %s w.r.t. vars in collection %s" %
+ (loss_name, loss_var_collection))
+ g = opt.compute_gradients(loss,
+ var_list=tf.get_collection(loss_var_collection))
+ ema_ops.append(summ.summarize_grads(g, loss_name))
+ grads.append(g)
+
+ all_grads = combine_grad_lists(grads)
+ apply_grads_op = opt.apply_gradients(all_grads, global_step=global_step)
+
+ # Update the emas after applying the grads.
+ with tf.control_dependencies([apply_grads_op]):
+ train_op = tf.group(*ema_ops)
+ return train_op
+
+
+def add_check_numerics_ops():
+ check_op = []
+ for op in tf.get_default_graph().get_operations():
+ bad = ["logits/Log", "sample/Reshape", "log_prob/mul",
+ "log_prob/SparseSoftmaxCrossEntropyWithLogits/Reshape",
+ "entropy/Reshape", "entropy/LogSoftmax", "Categorical", "Mean"]
+ if all([x not in op.name for x in bad]):
+ for output in op.outputs:
+ if output.dtype in [tf.float16, tf.float32, tf.float64]:
+ if op._get_control_flow_context() is not None: # pylint: disable=protected-access
+ raise ValueError("`tf.add_check_numerics_ops() is not compatible "
+ "with TensorFlow control flow operations such as "
+ "`tf.cond()` or `tf.while_loop()`.")
+
+ message = op.name + ":" + str(output.value_index)
+ with tf.control_dependencies(check_op):
+ check_op = [tf.check_numerics(output, message=message)]
+ return tf.group(*check_op)
+
+
+def create_long_chain_graph(bound, state_size, num_obs, steps_per_obs,
+ batch_size, num_samples, num_eval_samples,
+ resampling_schedule, use_resampling_grads,
+ learning_rate, lr_decay_steps, dtype="float64"):
+ num_timesteps = num_obs * steps_per_obs + 1
+ # Make the dataset.
+ dataset = data.make_long_chain_dataset(
+ state_size=state_size,
+ num_obs=num_obs,
+ steps_per_obs=steps_per_obs,
+ batch_size=batch_size,
+ num_samples=num_samples,
+ variance=FLAGS.variance,
+ observation_variance=FLAGS.observation_variance,
+ dtype=dtype,
+ observation_type=FLAGS.observation_type,
+ transition_type=FLAGS.transition_type,
+ fixed_observation=FLAGS.fixed_observation)
+ itr = dataset.make_one_shot_iterator()
+ _, observations = itr.get_next()
+ # Make the dataset for eval
+ eval_dataset = data.make_long_chain_dataset(
+ state_size=state_size,
+ num_obs=num_obs,
+ steps_per_obs=steps_per_obs,
+ batch_size=batch_size,
+ num_samples=num_eval_samples,
+ variance=FLAGS.variance,
+ observation_variance=FLAGS.observation_variance,
+ dtype=dtype,
+ observation_type=FLAGS.observation_type,
+ transition_type=FLAGS.transition_type,
+ fixed_observation=FLAGS.fixed_observation)
+ eval_itr = eval_dataset.make_one_shot_iterator()
+ _, eval_observations = eval_itr.get_next()
+
+ # Make the model.
+ model = models.LongChainModel.create(
+ state_size,
+ num_obs,
+ steps_per_obs,
+ observation_type=FLAGS.observation_type,
+ transition_type=FLAGS.transition_type,
+ variance=FLAGS.variance,
+ observation_variance=FLAGS.observation_variance,
+ dtype=tf.as_dtype(dtype),
+ disable_r=FLAGS.disable_r)
+
+ # Compute the bound and loss
+ if bound == "iwae":
+ (_, losses, ema_op, _, _) = bounds.iwae(
+ model,
+ observations,
+ num_timesteps,
+ num_samples=num_samples)
+ (eval_log_p_hat, _, _, _, eval_log_weights) = bounds.iwae(
+ model,
+ eval_observations,
+ num_timesteps,
+ num_samples=num_eval_samples,
+ summarize=False)
+ eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
+ elif bound == "fivo" or bound == "fivo-aux":
+ (_, losses, ema_op, _, _) = bounds.fivo(
+ model,
+ observations,
+ num_timesteps,
+ resampling_schedule=resampling_schedule,
+ use_resampling_grads=use_resampling_grads,
+ resampling_type=FLAGS.resampling_method,
+ aux=("aux" in bound),
+ num_samples=num_samples)
+ (eval_log_p_hat, _, _, _, eval_log_weights) = bounds.fivo(
+ model,
+ eval_observations,
+ num_timesteps,
+ resampling_schedule=resampling_schedule,
+ use_resampling_grads=False,
+ resampling_type="multinomial",
+ aux=("aux" in bound),
+ num_samples=num_eval_samples,
+ summarize=False)
+ eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
+
+ summ.summarize_ess(eval_log_weights, only_last_timestep=True)
+
+ tf.summary.scalar("log_p_hat", eval_log_p_hat)
+
+ # Compute and apply grads.
+ global_step = tf.train.get_or_create_global_step()
+
+ apply_grads = make_apply_grads_op(losses,
+ global_step,
+ learning_rate,
+ lr_decay_steps)
+
+ # Update the emas after applying the grads.
+ with tf.control_dependencies([apply_grads]):
+ train_op = tf.group(ema_op)
+
+ # We can't calculate the likelihood for most of these models
+ # so we just return zeros.
+ eval_likelihood = tf.zeros([], dtype=dtype)
+ return global_step, train_op, eval_log_p_hat, eval_likelihood
+
+
+def create_graph(bound, state_size, num_timesteps, batch_size,
+ num_samples, num_eval_samples, resampling_schedule,
+ use_resampling_grads, learning_rate, lr_decay_steps,
+ train_p, dtype='float64'):
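+ """Builds the train and eval graphs for the standard (non long chain) models.
+
+ Creates train and eval datasets with data.make_dataset, builds a
+ models.TDModel for the fivo-aux-td bound or a models.Model otherwise,
+ computes the requested bound (IWAE or a FIVO variant), adds summaries, and
+ returns the global step, the train op, the mean eval log p_hat, and the true
+ eval likelihood when model.p exposes a likelihood method (zeros otherwise).
+ """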
+ if FLAGS.use_bs:
+ true_bs = None
+ else:
+ true_bs = [np.zeros([state_size]).astype(dtype) for _ in xrange(num_timesteps)]
+
+ # Make the dataset.
+ true_bs, dataset = data.make_dataset(
+ bs=true_bs,
+ state_size=state_size,
+ num_timesteps=num_timesteps,
+ batch_size=batch_size,
+ num_samples=num_samples,
+ variance=FLAGS.variance,
+ prior_type=FLAGS.p_type,
+ bimodal_prior_weight=FLAGS.bimodal_prior_weight,
+ bimodal_prior_mean=FLAGS.bimodal_prior_mean,
+ transition_type=FLAGS.transition_type,
+ fixed_observation=FLAGS.fixed_observation,
+ dtype=dtype)
+ itr = dataset.make_one_shot_iterator()
+ _, observations = itr.get_next()
+ # Make the dataset for eval
+ _, eval_dataset = data.make_dataset(
+ bs=true_bs,
+ state_size=state_size,
+ num_timesteps=num_timesteps,
+ batch_size=num_eval_samples,
+ num_samples=num_eval_samples,
+ variance=FLAGS.variance,
+ prior_type=FLAGS.p_type,
+ bimodal_prior_weight=FLAGS.bimodal_prior_weight,
+ bimodal_prior_mean=FLAGS.bimodal_prior_mean,
+ transition_type=FLAGS.transition_type,
+ fixed_observation=FLAGS.fixed_observation,
+ dtype=dtype)
+ eval_itr = eval_dataset.make_one_shot_iterator()
+ _, eval_observations = eval_itr.get_next()
+
+ # Make the model.
+ if bound == "fivo-aux-td":
+ model = models.TDModel.create(
+ state_size,
+ num_timesteps,
+ variance=FLAGS.variance,
+ train_p=train_p,
+ p_type=FLAGS.p_type,
+ q_type=FLAGS.q_type,
+ mixing_coeff=FLAGS.bimodal_prior_weight,
+ prior_mode_mean=FLAGS.bimodal_prior_mean,
+ observation_variance=FLAGS.observation_variance,
+ transition_type=FLAGS.transition_type,
+ use_bs=FLAGS.use_bs,
+ dtype=tf.as_dtype(dtype),
+ random_seed=FLAGS.random_seed)
+ else:
+ model = models.Model.create(
+ state_size,
+ num_timesteps,
+ variance=FLAGS.variance,
+ train_p=train_p,
+ p_type=FLAGS.p_type,
+ q_type=FLAGS.q_type,
+ mixing_coeff=FLAGS.bimodal_prior_weight,
+ prior_mode_mean=FLAGS.bimodal_prior_mean,
+ observation_variance=FLAGS.observation_variance,
+ transition_type=FLAGS.transition_type,
+ use_bs=FLAGS.use_bs,
+ r_sigma_init=FLAGS.r_sigma_init,
+ dtype=tf.as_dtype(dtype),
+ random_seed=FLAGS.random_seed)
+
+ # Compute the bound and loss
+ if bound == "iwae":
+ (_, losses, ema_op, _, _) = bounds.iwae(
+ model,
+ observations,
+ num_timesteps,
+ num_samples=num_samples)
+ (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.iwae(
+ model,
+ eval_observations,
+ num_timesteps,
+ num_samples=num_eval_samples,
+ summarize=True)
+
+ eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
+
+ elif "fivo" in bound:
+ if bound == "fivo-aux-td":
+ (_, losses, ema_op, _, _) = bounds.fivo_aux_td(
+ model,
+ observations,
+ num_timesteps,
+ resampling_schedule=resampling_schedule,
+ num_samples=num_samples)
+ (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo_aux_td(
+ model,
+ eval_observations,
+ num_timesteps,
+ resampling_schedule=resampling_schedule,
+ num_samples=num_eval_samples,
+ summarize=True)
+ else:
+ (_, losses, ema_op, _, _) = bounds.fivo(
+ model,
+ observations,
+ num_timesteps,
+ resampling_schedule=resampling_schedule,
+ use_resampling_grads=use_resampling_grads,
+ resampling_type=FLAGS.resampling_method,
+ aux=("aux" in bound),
+ num_samples=num_samples)
+ (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo(
+ model,
+ eval_observations,
+ num_timesteps,
+ resampling_schedule=resampling_schedule,
+ use_resampling_grads=False,
+ resampling_type="multinomial",
+ aux=("aux" in bound),
+ num_samples=num_eval_samples,
+ summarize=True)
+ eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
+
+ summ.summarize_ess(eval_log_weights, only_last_timestep=True)
+
+ # if FLAGS.p_type == "bimodal":
+ # # create the observations that showcase the model.
+ # mode_odds_ratio = tf.convert_to_tensor([1., 3., 1./3., 512., 1./512.],
+ # dtype=tf.float64)
+ # mode_odds_ratio = tf.expand_dims(mode_odds_ratio, 1)
+ # k = ((num_timesteps+1) * FLAGS.variance) / (2*FLAGS.bimodal_prior_mean)
+ # explain_obs = tf.reduce_sum(model.p.bs) + tf.log(mode_odds_ratio) * k
+ # explain_obs = tf.tile(explain_obs, [num_eval_samples, 1])
+ # # run the model on the explainable observations
+ # if bound == "iwae":
+ # (_, _, _, explain_states, explain_log_weights) = bounds.iwae(
+ # model,
+ # explain_obs,
+ # num_timesteps,
+ # num_samples=num_eval_samples)
+ # elif bound == "fivo" or "fivo-aux":
+ # (_, _, _, explain_states, explain_log_weights) = bounds.fivo(
+ # model,
+ # explain_obs,
+ # num_timesteps,
+ # resampling_schedule=resampling_schedule,
+ # use_resampling_grads=False,
+ # resampling_type="multinomial",
+ # aux=("aux" in bound),
+ # num_samples=num_eval_samples)
+ # summ.summarize_particles(explain_states,
+ # explain_log_weights,
+ # explain_obs,
+ # model)
+
+ # Calculate the true likelihood.
+ if hasattr(model.p, 'likelihood') and callable(getattr(model.p, 'likelihood')):
+ eval_likelihood = model.p.likelihood(eval_observations) / FLAGS.num_timesteps
+ else:
+ eval_likelihood = tf.zeros_like(eval_log_p_hat)
+
+ tf.summary.scalar("log_p_hat", eval_log_p_hat)
+ tf.summary.scalar("likelihood", eval_likelihood)
+ tf.summary.scalar("bound_gap", eval_likelihood - eval_log_p_hat)
+ summ.summarize_model(model, true_bs, eval_observations, eval_states, bound,
+ summarize_r=not bound == "fivo-aux-td")
+
+ # Compute and apply grads.
+ global_step = tf.train.get_or_create_global_step()
+
+ apply_grads = make_apply_grads_op(losses,
+ global_step,
+ learning_rate,
+ lr_decay_steps)
+
+ # Update the emas after applying the grads.
+ with tf.control_dependencies([apply_grads]):
+ train_op = tf.group(ema_op)
+ #train_op = tf.group(ema_op, add_check_numerics_ops())
+
+ return global_step, train_op, eval_log_p_hat, eval_likelihood
+
+
+def parse_resampling_schedule(schedule, num_timesteps):
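+ """Parses a resampling schedule string into a list of per-timestep booleans.
+
+ Accepted formats are "always" (resample on every timestep except the last),
+ "never", "every_n" (resample every n timesteps), or an explicit
+ comma-separated list of true/false values with one entry per timestep.
+ """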
+ schedule = schedule.strip().lower()
+ if schedule == "always":
+ return [True] * (num_timesteps - 1) + [False]
+ elif schedule == "never":
+ return [False] * num_timesteps
+ elif "every" in schedule:
+ n = int(schedule.split("_")[1])
+ return [(i+1) % n == 0 for i in xrange(num_timesteps)]
+ else:
+ sched = [x.strip() == "true" for x in schedule.split(",")]
+ assert len(
+ sched
+ ) == num_timesteps, "Wrong number of timesteps in resampling schedule."
+ return sched
+
+
+def create_log_hook(step, eval_log_p_hat, eval_likelihood):
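+ """Creates a LoggingTensorHook that prints the step, log p_hat, and likelihood."""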
+ def summ_formatter(d):
+ return ("Step {step}, log p_hat: {log_p_hat:.5f} likelihood: {likelihood:.5f}".format(**d))
+ hook = tf.train.LoggingTensorHook(
+ {
+ "step": step,
+ "log_p_hat": eval_log_p_hat,
+ "likelihood": eval_likelihood,
+ },
+ every_n_iter=FLAGS.summarize_every,
+ formatter=summ_formatter)
+ return hook
+
+
+def create_infrequent_summary_hook():
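+ """Creates a SummarySaverHook for summaries in the "infrequent_summaries" collection."""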
+ infrequent_summary_hook = tf.train.SummarySaverHook(
+ save_steps=10000,
+ output_dir=FLAGS.logdir,
+ summary_op=tf.summary.merge_all(key="infrequent_summaries")
+ )
+ return infrequent_summary_hook
+
+
+def main(unused_argv):
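+ """Validates flags, builds the requested graph, and runs the training loop."""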
+ if FLAGS.model == "long_chain":
+ resampling_schedule = parse_resampling_schedule(FLAGS.resampling,
+ FLAGS.num_timesteps + 1)
+ else:
+ resampling_schedule = parse_resampling_schedule(FLAGS.resampling,
+ FLAGS.num_timesteps)
+ if FLAGS.random_seed is None:
+ seed = np.random.randint(0, high=10000)
+ else:
+ seed = FLAGS.random_seed
+ tf.logging.info("Using random seed %d", seed)
+
+ if FLAGS.model == "long_chain":
+ assert FLAGS.q_type == "normal", "Q type %s not supported for long chain models" % FLAGS.q_type
+ assert FLAGS.p_type == "unimodal", "Bimodal priors are not supported for long chain models"
+ assert not FLAGS.use_bs, "Bs are not supported with long chain models"
+ assert FLAGS.num_timesteps == FLAGS.num_observations * FLAGS.steps_per_observation, "Num timesteps does not match."
+ assert FLAGS.bound != "fivo-aux-td", "TD Training is not compatible with long chain models."
+
+ if FLAGS.model == "forward":
+ if "nonlinear" not in FLAGS.p_type:
+ assert FLAGS.transition_type == models.STANDARD_TRANSITION, "Non-standard transitions not supported by the forward model."
+ assert FLAGS.observation_type == models.STANDARD_OBSERVATION, "Non-standard observations not supported by the forward model."
+ assert FLAGS.observation_variance is None, "Forward model does not support observation variance."
+ assert FLAGS.num_observations == 1, "Forward model only supports 1 observation."
+
+ if "relaxed" in FLAGS.resampling_method:
+ FLAGS.use_resampling_grads = False
+ assert FLAGS.bound != "fivo-aux-td", "TD Training is not compatible with relaxed resampling."
+
+ if FLAGS.observation_variance is None:
+ FLAGS.observation_variance = FLAGS.variance
+
+ if FLAGS.p_type == "bimodal":
+ assert FLAGS.bimodal_prior_mean is not None, "Must specify prior mean if using bimodal p."
+
+ if FLAGS.p_type == "nonlinear" or FLAGS.p_type == "nonlinear-cauchy":
+ assert not FLAGS.use_bs, "Using bs is not compatible with the nonlinear model."
+
+ g = tf.Graph()
+ with g.as_default():
+ # Set the seeds.
+ tf.set_random_seed(seed)
+ np.random.seed(seed)
+ if FLAGS.model == "long_chain":
+ (global_step, train_op, eval_log_p_hat,
+ eval_likelihood) = create_long_chain_graph(
+ FLAGS.bound,
+ FLAGS.state_size,
+ FLAGS.num_observations,
+ FLAGS.steps_per_observation,
+ FLAGS.batch_size,
+ FLAGS.num_samples,
+ FLAGS.num_eval_samples,
+ resampling_schedule,
+ FLAGS.use_resampling_grads,
+ FLAGS.learning_rate,
+ FLAGS.decay_steps)
+ else:
+ (global_step, train_op,
+ eval_log_p_hat, eval_likelihood) = create_graph(
+ FLAGS.bound,
+ FLAGS.state_size,
+ FLAGS.num_timesteps,
+ FLAGS.batch_size,
+ FLAGS.num_samples,
+ FLAGS.num_eval_samples,
+ resampling_schedule,
+ FLAGS.use_resampling_grads,
+ FLAGS.learning_rate,
+ FLAGS.decay_steps,
+ FLAGS.train_p)
+
+ log_hooks = [create_log_hook(global_step, eval_log_p_hat, eval_likelihood)]
+ if len(tf.get_collection("infrequent_summaries")) > 0:
+ log_hooks.append(create_infrequent_summary_hook())
+
+ tf.logging.info("trainable variables:")
+ tf.logging.info([v.name for v in tf.trainable_variables()])
+ tf.logging.info("p vars:")
+ tf.logging.info([v.name for v in tf.get_collection("P_VARS")])
+ tf.logging.info("q vars:")
+ tf.logging.info([v.name for v in tf.get_collection("Q_VARS")])
+ tf.logging.info("r vars:")
+ tf.logging.info([v.name for v in tf.get_collection("R_VARS")])
+ tf.logging.info("r tilde vars:")
+ tf.logging.info([v.name for v in tf.get_collection("R_TILDE_VARS")])
+
+ with tf.train.MonitoredTrainingSession(
+ master="",
+ is_chief=True,
+ hooks=log_hooks,
+ checkpoint_dir=FLAGS.logdir,
+ save_checkpoint_secs=120,
+ save_summaries_steps=FLAGS.summarize_every,
+ log_step_count_steps=FLAGS.summarize_every) as sess:
+ cur_step = -1
+ while True:
+ if sess.should_stop() or cur_step > FLAGS.max_steps:
+ break
+ # run a step
+ _, cur_step = sess.run([train_op, global_step])
+
+
+if __name__ == "__main__":
+ tf.app.run(main)
diff --git a/models/research/fivo/fivo/__init__.py b/models/research/fivo/fivo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/fivo/fivo/bounds.py b/models/research/fivo/fivo/bounds.py
new file mode 100644
index 0000000000000000000000000000000000000000..088519033dd80669e99015b8e465888bd94a4cb1
--- /dev/null
+++ b/models/research/fivo/fivo/bounds.py
@@ -0,0 +1,317 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Implementation of objectives for training stochastic latent variable models.
+
+Contains implementations of the Importance Weighted Autoencoder objective (IWAE)
+and the Filtering Variational objective (FIVO).
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import tensorflow as tf
+
+from fivo import nested_utils as nested
+from fivo import smc
+
+
+def iwae(model,
+ observations,
+ seq_lengths,
+ num_samples=1,
+ parallel_iterations=30,
+ swap_memory=True):
+ """Computes the IWAE lower bound on the log marginal probability.
+
+ This method accepts a stochastic latent variable model and some observations
+ and computes a stochastic lower bound on the log marginal probability of the
+ observations. The IWAE estimator is defined by averaging multiple importance
+ weights. For more details see "Importance Weighted Autoencoders" by Burda
+ et al. https://arxiv.org/abs/1509.00519.
+
+ When num_samples = 1, this bound becomes the evidence lower bound (ELBO).
+
+ Args:
+ model: A subclass of ELBOTrainableSequenceModel that implements one
+ timestep of the model. See models/vrnn.py for an example.
+ observations: The inputs to the model. A potentially nested list or tuple of
+ Tensors each of shape [max_seq_len, batch_size, ...]. The Tensors must
+ have a rank of at least two and have matching shapes in the first two
+ dimensions, which represent time and the batch respectively. The model
+ will be provided with the observations before computing the bound.
+ seq_lengths: A [batch_size] Tensor of ints encoding the length of each
+ sequence in the batch (sequences can be padded to a common length).
+ num_samples: The number of samples to use.
+ parallel_iterations: The number of parallel iterations to use for the
+ internal while loop.
+ swap_memory: Whether GPU-CPU memory swapping should be enabled for the
+ internal while loop.
+
+ Returns:
+ log_p_hat: A Tensor of shape [batch_size] containing IWAE's estimate of the
+ log marginal probability of the observations.
+ log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples]
+ containing the log weights at each timestep. Will not be valid for
+ timesteps past the end of a sequence.
+ """
+ log_p_hat, log_weights, _, final_state = fivo(
+ model,
+ observations,
+ seq_lengths,
+ num_samples=num_samples,
+ resampling_criterion=smc.never_resample_criterion,
+ parallel_iterations=parallel_iterations,
+ swap_memory=swap_memory)
+ return log_p_hat, log_weights, final_state
+
+
+def fivo(model,
+ observations,
+ seq_lengths,
+ num_samples=1,
+ resampling_criterion=smc.ess_criterion,
+ resampling_type='multinomial',
+ relaxed_resampling_temperature=0.5,
+ parallel_iterations=30,
+ swap_memory=True,
+ random_seed=None):
+ """Computes the FIVO lower bound on the log marginal probability.
+
+ This method accepts a stochastic latent variable model and some observations
+ and computes a stochastic lower bound on the log marginal probability of the
+ observations. The lower bound is defined by a particle filter's unbiased
+ estimate of the marginal probability of the observations. For more details see
+ "Filtering Variational Objectives" by Maddison et al.
+ https://arxiv.org/abs/1705.09279.
+
+ When the resampling criterion is "never resample", this bound becomes IWAE.
+
+ Args:
+ model: A subclass of ELBOTrainableSequenceModel that implements one
+ timestep of the model. See models/vrnn.py for an example.
+ observations: The inputs to the model. A potentially nested list or tuple of
+ Tensors each of shape [max_seq_len, batch_size, ...]. The Tensors must
+ have a rank of at least two and have matching shapes in the first two
+ dimensions, which represent time and the batch respectively. The model
+ will be provided with the observations before computing the bound.
+ seq_lengths: A [batch_size] Tensor of ints encoding the length of each
+ sequence in the batch (sequences can be padded to a common length).
+ num_samples: The number of particles to use in each particle filter.
+ resampling_criterion: The resampling criterion to use for this particle
+ filter. Must accept the number of samples, the current log weights,
+ and the current timestep and return a boolean Tensor of shape [batch_size]
+ indicating whether each particle filter should resample. See
+ ess_criterion and related functions for examples. When
+ resampling_criterion is never_resample_criterion, resampling_fn is ignored
+ and never called.
+ resampling_type: The type of resampling, one of "multinomial" or "relaxed".
+ relaxed_resampling_temperature: A positive temperature only used for relaxed
+ resampling.
+ parallel_iterations: The number of parallel iterations to use for the
+ internal while loop. Note that values greater than 1 can introduce
+ non-determinism even when random_seed is provided.
+ swap_memory: Whether GPU-CPU memory swapping should be enabled for the
+ internal while loop.
+ random_seed: The random seed to pass to the resampling operations in
+ the particle filter. Mainly useful for testing.
+
+ Returns:
+ log_p_hat: A Tensor of shape [batch_size] containing FIVO's estimate of the
+ log marginal probability of the observations.
+ log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples]
+ containing the log weights at each timestep of the particle filter. Note
+ that on timesteps when a resampling operation is performed the log weights
+ are reset to 0. Will not be valid for timesteps past the end of a
+ sequence.
+ resampled: A Tensor of shape [max_seq_len, batch_size] indicating when the
+ particle filters resampled. Will be 1.0 on timesteps when resampling
+ occurred and 0.0 on timesteps when it did not.
+ """
+ # batch_size is the number of particle filters running in parallel.
+ batch_size = tf.shape(seq_lengths)[0]
+
+ # Each sequence in the batch will be the input data for a different
+ # particle filter. The batch will be laid out as:
+ # particle 1 of particle filter 1
+ # particle 1 of particle filter 2
+ # ...
+ # particle 1 of particle filter batch_size
+ # particle 2 of particle filter 1
+ # ...
+ # particle num_samples of particle filter batch_size
+ observations = nested.tile_tensors(observations, [1, num_samples])
+ tiled_seq_lengths = tf.tile(seq_lengths, [num_samples])
+ model.set_observations(observations, tiled_seq_lengths)
+
+ if resampling_type == 'multinomial':
+ resampling_fn = smc.multinomial_resampling
+ elif resampling_type == 'relaxed':
+ resampling_fn = functools.partial(
+ smc.relaxed_resampling, temperature=relaxed_resampling_temperature)
+ resampling_fn = functools.partial(resampling_fn, random_seed=random_seed)
+
+ def transition_fn(prev_state, t):
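+ """Returns the model's zero state on the first call, then proposes and weights particles."""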
+ if prev_state is None:
+ return model.zero_state(batch_size * num_samples, tf.float32)
+ return model.propose_and_weight(prev_state, t)
+
+ log_p_hat, log_weights, resampled, final_state, _ = smc.smc(
+ transition_fn,
+ seq_lengths,
+ num_particles=num_samples,
+ resampling_criterion=resampling_criterion,
+ resampling_fn=resampling_fn,
+ parallel_iterations=parallel_iterations,
+ swap_memory=swap_memory)
+
+ return log_p_hat, log_weights, resampled, final_state
+
+
+def fivo_aux_td(
+ model,
+ observations,
+ seq_lengths,
+ num_samples=1,
+ resampling_criterion=smc.ess_criterion,
+ resampling_type='multinomial',
+ relaxed_resampling_temperature=0.5,
+ parallel_iterations=30,
+ swap_memory=True,
+ random_seed=None):
+ """Experimental."""
+ # batch_size is the number of particle filters running in parallel.
+ batch_size = tf.shape(seq_lengths)[0]
+ max_seq_len = tf.reduce_max(seq_lengths)
+
+ # Each sequence in the batch will be the input data for a different
+ # particle filter. The batch will be laid out as:
+ # particle 1 of particle filter 1
+ # particle 1 of particle filter 2
+ # ...
+ # particle 1 of particle filter batch_size
+ # particle 2 of particle filter 1
+ # ...
+ # particle num_samples of particle filter batch_size
+ observations = nested.tile_tensors(observations, [1, num_samples])
+ tiled_seq_lengths = tf.tile(seq_lengths, [num_samples])
+ model.set_observations(observations, tiled_seq_lengths)
+
+ if resampling_type == 'multinomial':
+ resampling_fn = smc.multinomial_resampling
+ elif resampling_type == 'relaxed':
+ resampling_fn = functools.partial(
+ smc.relaxed_resampling, temperature=relaxed_resampling_temperature)
+ resampling_fn = functools.partial(resampling_fn, random_seed=random_seed)
+
+ def transition_fn(prev_state, t):
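+ """Runs one step of the model and computes the particle weight with the auxiliary log r correction."""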
+ if prev_state is None:
+ model_init_state = model.zero_state(batch_size * num_samples, tf.float32)
+ return (tf.zeros([num_samples*batch_size], dtype=tf.float32),
+ (tf.zeros([num_samples*batch_size, model.latent_size], dtype=tf.float32),
+ tf.zeros([num_samples*batch_size, model.latent_size], dtype=tf.float32)),
+ model_init_state)
+
+ prev_log_r, prev_log_r_tilde, prev_model_state = prev_state
+ (new_model_state, zt, log_q_zt, log_p_zt,
+ log_p_x_given_z, log_r_tilde, p_ztplus1) = model(prev_model_state, t)
+ r_tilde_mu, r_tilde_sigma_sq = log_r_tilde
+ # Compute the weight without r.
+ log_weight = log_p_zt + log_p_x_given_z - log_q_zt
+ # Compute log_r and log_r_tilde.
+ p_mu = tf.stop_gradient(p_ztplus1.mean())
+ p_sigma_sq = tf.stop_gradient(p_ztplus1.variance())
+ log_r = (tf.log(r_tilde_sigma_sq) -
+ tf.log(r_tilde_sigma_sq + p_sigma_sq) -
+ tf.square(r_tilde_mu - p_mu)/(r_tilde_sigma_sq + p_sigma_sq))
+ # log_r is [num_samples*batch_size, latent_size]. We sum it along the last
+ # dimension to compute log r.
+ log_r = 0.5*tf.reduce_sum(log_r, axis=-1)
+ # Compute prev log r tilde
+ prev_r_tilde_mu, prev_r_tilde_sigma_sq = prev_log_r_tilde
+ prev_log_r_tilde = -0.5*tf.reduce_sum(
+ tf.square(tf.stop_gradient(zt) - prev_r_tilde_mu)/prev_r_tilde_sigma_sq, axis=-1)
+ # If the sequence is on the last timestep, log_r and log_r_tilde are just zeros.
+ last_timestep = t >= (tiled_seq_lengths - 1)
+ log_r = tf.where(last_timestep,
+ tf.zeros_like(log_r),
+ log_r)
+ prev_log_r_tilde = tf.where(last_timestep,
+ tf.zeros_like(prev_log_r_tilde),
+ prev_log_r_tilde)
+ log_weight += tf.stop_gradient(log_r - prev_log_r)
+ new_state = (log_r, log_r_tilde, new_model_state)
+ loop_fn_args = (log_r, prev_log_r_tilde, log_p_x_given_z, log_r - prev_log_r)
+ return log_weight, new_state, loop_fn_args
+
+ def loop_fn(loop_state, loop_args, unused_model_state, log_weights, resampled, mask, t):
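+ """Accumulates log p_hat, the Bellman loss, and the log r differences across timesteps."""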
+ if loop_state is None:
+ return (tf.zeros([batch_size], dtype=tf.float32),
+ tf.zeros([batch_size], dtype=tf.float32),
+ tf.zeros([num_samples, batch_size], dtype=tf.float32))
+ log_p_hat_acc, bellman_loss_acc, log_r_diff_acc = loop_state
+ log_r, prev_log_r_tilde, log_p_x_given_z, log_r_diff = loop_args
+ # Compute the log_p_hat update
+ log_p_hat_update = tf.reduce_logsumexp(
+ log_weights, axis=0) - tf.log(tf.to_float(num_samples))
+ # If it is the last timestep, we always add the update.
+ log_p_hat_acc += tf.cond(t >= max_seq_len-1,
+ lambda: log_p_hat_update,
+ lambda: log_p_hat_update * resampled)
+ # Compute the Bellman update.
+ log_r = tf.reshape(log_r, [num_samples, batch_size])
+ prev_log_r_tilde = tf.reshape(prev_log_r_tilde, [num_samples, batch_size])
+ log_p_x_given_z = tf.reshape(log_p_x_given_z, [num_samples, batch_size])
+ mask = tf.reshape(mask, [num_samples, batch_size])
+ # On the first timestep there is no bellman error because there is no
+ # prev_log_r_tilde.
+ mask = tf.cond(tf.equal(t, 0),
+ lambda: tf.zeros_like(mask),
+ lambda: mask)
+ # On the first timestep also fix up prev_log_r_tilde, which will be -inf.
+ prev_log_r_tilde = tf.where(
+ tf.is_inf(prev_log_r_tilde),
+ tf.zeros_like(prev_log_r_tilde),
+ prev_log_r_tilde)
+ # log_lambda is [num_samples, batch_size]
+ log_lambda = tf.reduce_mean(prev_log_r_tilde - log_p_x_given_z - log_r,
+ axis=0, keepdims=True)
+ bellman_error = mask * tf.square(
+ prev_log_r_tilde -
+ tf.stop_gradient(log_lambda + log_p_x_given_z + log_r)
+ )
+ bellman_loss_acc += tf.reduce_mean(bellman_error, axis=0)
+ # Compute the log_r_diff update
+ log_r_diff_acc += mask * tf.reshape(log_r_diff, [num_samples, batch_size])
+ return (log_p_hat_acc, bellman_loss_acc, log_r_diff_acc)
+
+ log_weights, resampled, accs = smc.smc(
+ transition_fn,
+ seq_lengths,
+ num_particles=num_samples,
+ resampling_criterion=resampling_criterion,
+ resampling_fn=resampling_fn,
+ loop_fn=loop_fn,
+ parallel_iterations=parallel_iterations,
+ swap_memory=swap_memory)
+
+ log_p_hat, bellman_loss, log_r_diff = accs
+ loss_per_seq = [- log_p_hat, bellman_loss]
+ tf.summary.scalar("bellman_loss",
+ tf.reduce_mean(bellman_loss / tf.to_float(seq_lengths)))
+ tf.summary.scalar("log_r_diff",
+ tf.reduce_mean(tf.reduce_mean(log_r_diff, axis=0) / tf.to_float(seq_lengths)))
+ return loss_per_seq, log_p_hat, log_weights, resampled
diff --git a/models/research/fivo/fivo/bounds_test.py b/models/research/fivo/fivo/bounds_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c970f74f4cec36a855c54bbe6cdf8d76c3f86599
--- /dev/null
+++ b/models/research/fivo/fivo/bounds_test.py
@@ -0,0 +1,183 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for fivo.bounds"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from fivo.test_utils import create_vrnn
+from fivo import bounds
+
+
+class BoundsTest(tf.test.TestCase):
+
+ def test_elbo(self):
+ """A golden-value test for the ELBO (the IWAE bound with num_samples=1)."""
+ tf.set_random_seed(1234)
+ with self.test_session() as sess:
+ model, inputs, targets, lengths = create_vrnn(random_seed=1234)
+ outs = bounds.iwae(model, (inputs, targets), lengths, num_samples=1,
+ parallel_iterations=1)
+ sess.run(tf.global_variables_initializer())
+ log_p_hat, _, _ = sess.run(outs)
+ self.assertAllClose([-21.615765, -13.614225], log_p_hat)
+
+ def test_iwae(self):
+ """A golden-value test for the IWAE bound."""
+ tf.set_random_seed(1234)
+ with self.test_session() as sess:
+ model, inputs, targets, lengths = create_vrnn(random_seed=1234)
+ outs = bounds.iwae(model, (inputs, targets), lengths, num_samples=4,
+ parallel_iterations=1)
+ sess.run(tf.global_variables_initializer())
+ log_p_hat, weights, _ = sess.run(outs)
+ self.assertAllClose([-23.301426, -13.64028], log_p_hat)
+ weights_gt = np.array(
+ [[[-3.66708851, -2.07074022, -4.91751671, -5.03293562],
+ [-2.99690723, -3.17782736, -4.50084877, -3.48536515]],
+ [[-6.2539978, -4.37615728, -7.43738699, -7.85044909],
+ [-8.27518654, -6.71545124, -8.96198845, -7.05567837]],
+ [[-9.19093227, -8.01637268, -11.64603615, -10.51128292],
+ [-12.34527206, -11.54284477, -11.8667469, -9.69417381]],
+ [[-12.20609856, -10.47217369, -13.66270638, -13.46115875],
+ [-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
+ [[-16.14766312, -15.57472229, -17.47755432, -17.98189926],
+ [-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
+ [[-20.07182884, -18.43191147, -20.1606636, -21.45263863],
+ [-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
+ [[-24.10270691, -22.20865822, -24.14675522, -25.27248383],
+ [-17.17656708, -16.25190353, -15.28658581, -12.33067703]]])
+ self.assertAllClose(weights_gt, weights)
+
+ def test_fivo(self):
+ """A golden-value test for the FIVO bound."""
+ tf.set_random_seed(1234)
+ with self.test_session() as sess:
+ model, inputs, targets, lengths = create_vrnn(random_seed=1234)
+ outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
+ random_seed=1234, parallel_iterations=1)
+ sess.run(tf.global_variables_initializer())
+ log_p_hat, weights, resampled, _ = sess.run(outs)
+ self.assertAllClose([-22.98902512, -14.21689224], log_p_hat)
+ weights_gt = np.array(
+ [[[-3.66708851, -2.07074022, -4.91751671, -5.03293562],
+ [-2.99690723, -3.17782736, -4.50084877, -3.48536515]],
+ [[-2.67100811, -2.30541706, -2.34178066, -2.81751347],
+ [-8.27518654, -6.71545124, -8.96198845, -7.05567837]],
+ [[-5.65190411, -5.94563246, -6.55041981, -5.4783473],
+ [-12.34527206, -11.54284477, -11.8667469, -9.69417381]],
+ [[-8.71947861, -8.40143299, -8.54593086, -8.42822266],
+ [-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
+ [[-12.7003831, -13.5039815, -12.3569726, -12.9489622],
+ [-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
+ [[-16.4520301, -16.3611698, -15.0314846, -16.4197006],
+ [-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
+ [[-20.7010765, -20.1379165, -19.0020351, -20.2395458],
+ [-4.28782988, -4.50591278, -3.40847206, -2.63650274]]])
+ self.assertAllClose(weights_gt, weights)
+ resampled_gt = np.array(
+ [[1., 0.],
+ [0., 0.],
+ [0., 1.],
+ [0., 0.],
+ [0., 0.],
+ [0., 0.],
+ [0., 0.]])
+ self.assertAllClose(resampled_gt, resampled)
+
+ def test_fivo_relaxed(self):
+ """A golden-value test for the FIVO bound with relaxed sampling."""
+ tf.set_random_seed(1234)
+ with self.test_session() as sess:
+ model, inputs, targets, lengths = create_vrnn(random_seed=1234)
+ outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
+ random_seed=1234, parallel_iterations=1,
+ resampling_type="relaxed")
+ sess.run(tf.global_variables_initializer())
+ log_p_hat, weights, resampled, _ = sess.run(outs)
+ self.assertAllClose([-22.942394, -14.273882], log_p_hat)
+ weights_gt = np.array(
+ [[[-3.66708851, -2.07074118, -4.91751575, -5.03293514],
+ [-2.99690628, -3.17782831, -4.50084877, -3.48536515]],
+ [[-2.84939098, -2.30087185, -2.35649204, -2.48417377],
+ [-8.27518654, -6.71545172, -8.96199131, -7.05567837]],
+ [[-5.92327023, -5.9433074, -6.5826683, -5.04259014],
+ [-12.34527206, -11.54284668, -11.86675072, -9.69417477]],
+ [[-8.95323944, -8.40061855, -8.52760506, -7.99130583],
+ [-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
+ [[-12.87836456, -13.49628639, -12.31680107, -12.74228859],
+ [-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
+ [[-16.78347397, -16.35150909, -14.98797417, -16.35162735],
+ [-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
+ [[-20.81165886, -20.1307621, -18.92229652, -20.17458153],
+ [-4.58102798, -4.56017351, -3.46283388, -2.65550804]]])
+ self.assertAllClose(weights_gt, weights)
+ resampled_gt = np.array(
+ [[1., 0.],
+ [0., 0.],
+ [0., 1.],
+ [0., 0.],
+ [0., 0.],
+ [0., 0.],
+ [0., 0.]])
+ self.assertAllClose(resampled_gt, resampled)
+
+ def test_fivo_aux_relaxed(self):
+ """A golden-value test for the FIVO-AUX bound with relaxed sampling."""
+ tf.set_random_seed(1234)
+ with self.test_session() as sess:
+ model, inputs, targets, lengths = create_vrnn(random_seed=1234,
+ use_tilt=True)
+ outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
+ random_seed=1234, parallel_iterations=1,
+ resampling_type="relaxed")
+ sess.run(tf.global_variables_initializer())
+ log_p_hat, weights, resampled, _ = sess.run(outs)
+ self.assertAllClose([-23.1395, -14.271059], log_p_hat)
+ weights_gt = np.array(
+ [[[-5.19826221, -3.55476403, -5.98663855, -6.08058834],
+ [-6.31685925, -5.70243931, -7.07638931, -6.18138981]],
+ [[-3.97986865, -3.58831525, -3.85753584, -3.5010016],
+ [-11.38203049, -8.66213989, -11.23646641, -10.02024746]],
+ [[-6.62269831, -6.36680222, -6.78096485, -5.80072498],
+ [-3.55419445, -8.11326408, -3.48766923, -3.08593249]],
+ [[-10.56472301, -10.16084099, -9.96741676, -8.5270071],
+ [-6.04880285, -7.80853653, -4.72652149, -3.49711013]],
+ [[-13.36585426, -16.08720398, -13.33416367, -13.1017189],
+ [-0., -0., -0., -0.]],
+ [[-17.54233551, -17.35167503, -16.79163361, -16.51471138],
+ [0., -0., -0., -0.]],
+ [[-19.74024963, -18.69452858, -17.76246452, -18.76182365],
+ [0., -0., -0., -0.]]])
+ self.assertAllClose(weights_gt, weights)
+ resampled_gt = np.array([[1., 0.],
+ [0., 1.],
+ [0., 0.],
+ [0., 1.],
+ [0., 0.],
+ [0., 0.],
+ [0., 0.]])
+ self.assertAllClose(resampled_gt, resampled)
+
+
+if __name__ == "__main__":
+ np.set_printoptions(threshold=np.nan) # Used to easily see the gold values.
+ # Use print(repr(numpy_array)) to print the values.
+ tf.test.main()
diff --git a/models/research/fivo/fivo/data/__init__.py b/models/research/fivo/fivo/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/fivo/fivo/data/calculate_pianoroll_mean.py b/models/research/fivo/fivo/data/calculate_pianoroll_mean.py
new file mode 100644
index 0000000000000000000000000000000000000000..93f712bd328f61a83faffc55ad2cf6ca33b47fb7
--- /dev/null
+++ b/models/research/fivo/fivo/data/calculate_pianoroll_mean.py
@@ -0,0 +1,65 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Script to calculate the mean of a pianoroll dataset.
+
+Given a pianoroll pickle file, this script loads the dataset and
+calculates the mean of the training set. Then it updates the pickle file
+so that the key "train_mean" points to the mean vector.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import pickle
+import numpy as np
+
+import tensorflow as tf
+
+
+from datasets import sparse_pianoroll_to_dense
+
+tf.app.flags.DEFINE_string('in_file', None,
+ 'Filename of the pickled pianoroll dataset to load.')
+tf.app.flags.DEFINE_string('out_file', None,
+ 'Name of the output pickle file. Defaults to in_file, '
+ 'updating the input pickle file.')
+tf.app.flags.mark_flag_as_required('in_file')
+
+FLAGS = tf.app.flags.FLAGS
+
+MIN_NOTE = 21
+MAX_NOTE = 108
+NUM_NOTES = MAX_NOTE - MIN_NOTE + 1
+
+
+def main(unused_argv):
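+ """Computes the mean of the train split and writes it back into the pickle."""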
+ if FLAGS.out_file is None:
+ FLAGS.out_file = FLAGS.in_file
+ with tf.gfile.Open(FLAGS.in_file, 'r') as f:
+ pianorolls = pickle.load(f)
+ dense_pianorolls = [sparse_pianoroll_to_dense(p, MIN_NOTE, NUM_NOTES)[0]
+ for p in pianorolls['train']]
+ # Concatenate all elements along the time axis.
+ concatenated = np.concatenate(dense_pianorolls, axis=0)
+ mean = np.mean(concatenated, axis=0)
+ pianorolls['train_mean'] = mean
+ # Write out the whole pickle file, including the train mean.
+ pickle.dump(pianorolls, open(FLAGS.out_file, 'wb'))
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/fivo/fivo/data/create_timit_dataset.py b/models/research/fivo/fivo/data/create_timit_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea1cd3b10cb0812c2d6aad51491924ecfe8eec37
--- /dev/null
+++ b/models/research/fivo/fivo/data/create_timit_dataset.py
@@ -0,0 +1,180 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Preprocesses TIMIT from raw wavfiles to create a set of TFRecords.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import glob
+import os
+import random
+import re
+
+import numpy as np
+import tensorflow as tf
+
+tf.app.flags.DEFINE_string("raw_timit_dir", None,
+ "Directory containing TIMIT files.")
+tf.app.flags.DEFINE_string("out_dir", None,
+ "Output directory for TFRecord files.")
+tf.app.flags.DEFINE_float("valid_frac", 0.05,
+ "Fraction of train set to use as valid set. "
+ "Must be between 0.0 and 1.0.")
+
+tf.app.flags.mark_flag_as_required("raw_timit_dir")
+tf.app.flags.mark_flag_as_required("out_dir")
+
+FLAGS = tf.app.flags.FLAGS
+
+NUM_TRAIN_FILES = 4620
+NUM_TEST_FILES = 1680
+SAMPLES_PER_TIMESTEP = 200
+
+# Regexes for reading SPHERE header files.
+SAMPLE_COUNT_REGEX = re.compile(r"sample_count -i (\d+)")
+SAMPLE_MIN_REGEX = re.compile(r"sample_min -i (-?\d+)")
+SAMPLE_MAX_REGEX = re.compile(r"sample_max -i (-?\d+)")
+
+
+def get_filenames(split):
+ """Get all wav filenames from the TIMIT archive."""
+ path = os.path.join(FLAGS.raw_timit_dir, "TIMIT", split, "*", "*", "*.WAV")
+ # Sort the output by name so the order is deterministic.
+ files = sorted(glob.glob(path))
+ return files
+
+
+def load_timit_wav(filename):
+ """Loads a TIMIT wavfile into a numpy array.
+
+ TIMIT wavfiles include a SPHERE header, detailed in the TIMIT docs. The first
+ line is the header type and the second is the length of the header in bytes.
+ After the header, the remaining bytes are actual WAV data.
+
+ The header includes information about the WAV data such as the number of
+ samples and minimum and maximum amplitude. This function asserts that the
+ loaded wav data matches the header.
+
+ Args:
+ filename: The name of the TIMIT wavfile to load.
+ Returns:
+ wav: A numpy array containing the loaded wav data.
+ """
+ wav_file = open(filename, "rb")
+ header_type = wav_file.readline()
+ header_length_str = wav_file.readline()
+ # The header length includes the length of the first two lines.
+ header_remaining_bytes = (int(header_length_str) - len(header_type) -
+ len(header_length_str))
+ header = wav_file.read(header_remaining_bytes)
+ # Read the relevant header fields.
+ sample_count = int(SAMPLE_COUNT_REGEX.search(header).group(1))
+ sample_min = int(SAMPLE_MIN_REGEX.search(header).group(1))
+ sample_max = int(SAMPLE_MAX_REGEX.search(header).group(1))
+ wav = np.fromstring(wav_file.read(), dtype="int16").astype("float32")
+ # Check that the loaded data conforms to the header description.
+ assert len(wav) == sample_count
+ assert wav.min() == sample_min
+ assert wav.max() == sample_max
+ return wav
+
+
+def preprocess(wavs, block_size, mean, std):
+ """Normalize the wav data and reshape it into chunks."""
+ processed_wavs = []
+ for wav in wavs:
+ wav = (wav - mean) / std
+ wav_length = wav.shape[0]
+ if wav_length % block_size != 0:
+ pad_width = block_size - (wav_length % block_size)
+ wav = np.pad(wav, (0, pad_width), "constant")
+ assert wav.shape[0] % block_size == 0
+ wav = wav.reshape((-1, block_size))
+ processed_wavs.append(wav)
+ return processed_wavs
+
+
+def create_tfrecord_from_wavs(wavs, output_file):
+ """Writes processed wav files to disk as sharded TFRecord files."""
+ with tf.python_io.TFRecordWriter(output_file) as builder:
+ for wav in wavs:
+ builder.write(wav.astype(np.float32).tobytes())
+
+
+def main(unused_argv):
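+ """Creates the TIMIT TFRecord datasets.
+
+ Splits the raw TIMIT files into train/valid/test, normalizes every split
+ with the train mean and standard deviation, and writes each split to a
+ TFRecord file in FLAGS.out_dir.
+ """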
+ train_filenames = get_filenames("TRAIN")
+ test_filenames = get_filenames("TEST")
+
+ num_train_files = len(train_filenames)
+ num_test_files = len(test_filenames)
+ num_valid_files = int(num_train_files * FLAGS.valid_frac)
+ num_train_files -= num_valid_files
+
+ print("%d train / %d valid / %d test" % (
+ num_train_files, num_valid_files, num_test_files))
+
+ random.seed(1234)
+ random.shuffle(train_filenames)
+
+ valid_filenames = train_filenames[:num_valid_files]
+ train_filenames = train_filenames[num_valid_files:]
+
+ # Make sure there is no overlap in the train, test, and valid sets.
+ train_s = set(train_filenames)
+ test_s = set(test_filenames)
+ valid_s = set(valid_filenames)
+ # Disable explicit length testing to make the assertions more readable.
+ # pylint: disable=g-explicit-length-test
+ assert len(train_s & test_s) == 0
+ assert len(train_s & valid_s) == 0
+ assert len(valid_s & test_s) == 0
+ # pylint: enable=g-explicit-length-test
+
+ train_wavs = [load_timit_wav(f) for f in train_filenames]
+ valid_wavs = [load_timit_wav(f) for f in valid_filenames]
+ test_wavs = [load_timit_wav(f) for f in test_filenames]
+ assert len(train_wavs) + len(valid_wavs) == NUM_TRAIN_FILES
+ assert len(test_wavs) == NUM_TEST_FILES
+
+ # Calculate the mean and standard deviation of the train set.
+ train_stacked = np.hstack(train_wavs)
+ train_mean = np.mean(train_stacked)
+ train_std = np.std(train_stacked)
+ print("train mean: %f train std: %f" % (train_mean, train_std))
+
+ # Process all data, normalizing with the train set statistics.
+ processed_train_wavs = preprocess(train_wavs, SAMPLES_PER_TIMESTEP,
+ train_mean, train_std)
+ processed_valid_wavs = preprocess(valid_wavs, SAMPLES_PER_TIMESTEP,
+ train_mean, train_std)
+ processed_test_wavs = preprocess(test_wavs, SAMPLES_PER_TIMESTEP, train_mean,
+ train_std)
+
+ # Write the datasets to disk.
+ create_tfrecord_from_wavs(
+ processed_train_wavs,
+ os.path.join(FLAGS.out_dir, "train"))
+ create_tfrecord_from_wavs(
+ processed_valid_wavs,
+ os.path.join(FLAGS.out_dir, "valid"))
+ create_tfrecord_from_wavs(
+ processed_test_wavs,
+ os.path.join(FLAGS.out_dir, "test"))
+
+
+if __name__ == "__main__":
+ tf.app.run()
diff --git a/models/research/fivo/fivo/data/datasets.py b/models/research/fivo/fivo/data/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d5324623250e31d65b23c97e7e684de59da1ba6
--- /dev/null
+++ b/models/research/fivo/fivo/data/datasets.py
@@ -0,0 +1,453 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Code for creating sequence datasets.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import pickle
+
+import numpy as np
+from scipy.sparse import coo_matrix
+import tensorflow as tf
+
+# The default number of threads used to process data in parallel.
+DEFAULT_PARALLELISM = 12
+
+
+def sparse_pianoroll_to_dense(pianoroll, min_note, num_notes):
+ """Converts a sparse pianoroll to a dense numpy array.
+
+ Given a sparse pianoroll, converts it to a dense numpy array of shape
+ [num_timesteps, num_notes] where entry i,j is 1.0 if note j is active on
+ timestep i and 0.0 otherwise.
+
+ Args:
+ pianoroll: A sparse pianoroll object, a list of tuples where the i'th tuple
+ contains the indices of the notes active at timestep i.
+ min_note: The minimum note in the pianoroll, subtracted from all notes so
+ that the minimum note becomes 0.
+ num_notes: The number of possible different note indices, determines the
+ second dimension of the resulting dense array.
+ Returns:
+ dense_pianoroll: A [num_timesteps, num_notes] numpy array of floats.
+ num_timesteps: A python int, the number of timesteps in the pianoroll.
+ """
+ num_timesteps = len(pianoroll)
+ inds = []
+ for time, chord in enumerate(pianoroll):
+ # Re-index the notes to start from min_note.
+ inds.extend((time, note-min_note) for note in chord)
+ shape = [num_timesteps, num_notes]
+ values = [1.] * len(inds)
+ sparse_pianoroll = coo_matrix(
+ (values, ([x[0] for x in inds], [x[1] for x in inds])),
+ shape=shape)
+ return sparse_pianoroll.toarray(), num_timesteps
+
+
+def create_pianoroll_dataset(path,
+ split,
+ batch_size,
+ num_parallel_calls=DEFAULT_PARALLELISM,
+ shuffle=False,
+ repeat=False,
+ min_note=21,
+ max_note=108):
+ """Creates a pianoroll dataset.
+
+ Args:
+ path: The path of a pickle file containing the dataset to load.
+ split: The split to use, can be train, test, or valid.
+ batch_size: The batch size. If repeat is False then it is not guaranteed
+ that the true batch size will match for all batches since batch_size
+ may not necessarily evenly divide the number of elements.
+ num_parallel_calls: The number of threads to use for parallel processing of
+ the data.
+ shuffle: If true, shuffles the order of the dataset.
+ repeat: If true, repeats the dataset endlessly.
+ min_note: The minimum note number of the dataset. For all pianoroll datasets
+ the minimum note is number 21, and changing this affects the dimension of
+ the data. This is useful mostly for testing.
+ max_note: The maximum note number of the dataset. For all pianoroll datasets
+ the maximum note is number 108, and changing this affects the dimension of
+ the data. This is useful mostly for testing.
+ Returns:
+ inputs: A batch of input sequences represented as a dense Tensor of shape
+ [time, batch_size, data_dimension]. The sequences in inputs are the
+ sequences in targets shifted one timestep into the future, padded with
+ zeros. This tensor is mean-centered, with the mean taken from the pickle
+ file key 'train_mean'.
+ targets: A batch of target sequences represented as a dense Tensor of
+ shape [time, batch_size, data_dimension].
+ lens: An int Tensor of shape [batch_size] representing the lengths of each
+ sequence in the batch.
+ mean: A float Tensor of shape [data_dimension] containing the mean loaded
+ from the pickle file.
+ """
+ # Load the data from disk.
+ num_notes = max_note - min_note + 1
+ with tf.gfile.Open(path, "r") as f:
+ raw_data = pickle.load(f)
+ pianorolls = raw_data[split]
+ mean = raw_data["train_mean"]
+ num_examples = len(pianorolls)
+
+ def pianoroll_generator():
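+ """Yields (dense_pianoroll, num_timesteps) pairs, one per sequence."""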
+ for sparse_pianoroll in pianorolls:
+ yield sparse_pianoroll_to_dense(sparse_pianoroll, min_note, num_notes)
+
+ dataset = tf.data.Dataset.from_generator(
+ pianoroll_generator,
+ output_types=(tf.float64, tf.int64),
+ output_shapes=([None, num_notes], []))
+
+ if repeat: dataset = dataset.repeat()
+ if shuffle: dataset = dataset.shuffle(num_examples)
+
+ # Batch sequences together, padding them to a common length in time.
+ dataset = dataset.padded_batch(batch_size,
+ padded_shapes=([None, num_notes], []))
+
+ def process_pianoroll_batch(data, lengths):
+ """Create mean-centered and time-major next-step prediction Tensors."""
+ data = tf.to_float(tf.transpose(data, perm=[1, 0, 2]))
+ lengths = tf.to_int32(lengths)
+ targets = data
+ # Mean center the inputs.
+ inputs = data - tf.constant(mean, dtype=tf.float32,
+ shape=[1, 1, mean.shape[0]])
+ # Shift the inputs one step forward in time. Also remove the last timestep
+ # so that targets and inputs are the same length.
+ inputs = tf.pad(inputs, [[1, 0], [0, 0], [0, 0]], mode="CONSTANT")[:-1]
+ # Mask out unused timesteps.
+ inputs *= tf.expand_dims(tf.transpose(
+ tf.sequence_mask(lengths, dtype=inputs.dtype)), 2)
+ return inputs, targets, lengths
+
+ dataset = dataset.map(process_pianoroll_batch,
+ num_parallel_calls=num_parallel_calls)
+ dataset = dataset.prefetch(num_examples)
+
+ itr = dataset.make_one_shot_iterator()
+ inputs, targets, lengths = itr.get_next()
+ return inputs, targets, lengths, tf.constant(mean, dtype=tf.float32)
+
+
+def create_human_pose_dataset(
+ path,
+ split,
+ batch_size,
+ num_parallel_calls=DEFAULT_PARALLELISM,
+ shuffle=False,
+ repeat=False,):
+ """Creates a human pose dataset.
+
+ Args:
+ path: The path of a pickle file containing the dataset to load.
+ split: The split to use, can be train, test, or valid.
+ batch_size: The batch size. If repeat is False then it is not guaranteed
+ that the true batch size will match for all batches since batch_size
+ may not necessarily evenly divide the number of elements.
+ num_parallel_calls: The number of threads to use for parallel processing of
+ the data.
+ shuffle: If true, shuffles the order of the dataset.
+ repeat: If true, repeats the dataset endlessly.
+ Returns:
+ inputs: A batch of input sequences represented as a dense Tensor of shape
+ [time, batch_size, data_dimension]. The sequences in inputs are the
+ sequences in targets shifted one timestep into the future, padded with
+ zeros. This tensor is mean-centered, with the mean taken from the pickle
+ file key 'train_mean'.
+ targets: A batch of target sequences represented as a dense Tensor of
+ shape [time, batch_size, data_dimension].
+ lens: An int Tensor of shape [batch_size] representing the lengths of each
+ sequence in the batch.
+ mean: A float Tensor of shape [data_dimension] containing the mean loaded
+ from the pickle file.
+ """
+ # Load the data from disk.
+ with tf.gfile.Open(path, "r") as f:
+ raw_data = pickle.load(f)
+
+ mean = raw_data["train_mean"]
+ pose_sequences = raw_data[split]
+ num_examples = len(pose_sequences)
+ num_features = pose_sequences[0].shape[1]
+
+ def pose_generator():
+ """A generator that yields pose data sequences."""
+ # Each timestep has 32 x values followed by 32 y values so is 64
+ # dimensional.
+ for pose_sequence in pose_sequences:
+ yield pose_sequence, pose_sequence.shape[0]
+
+ dataset = tf.data.Dataset.from_generator(
+ pose_generator,
+ output_types=(tf.float64, tf.int64),
+ output_shapes=([None, num_features], []))
+
+ if repeat:
+ dataset = dataset.repeat()
+ if shuffle:
+ dataset = dataset.shuffle(num_examples)
+
+ # Batch sequences together, padding them to a common length in time.
+ dataset = dataset.padded_batch(
+ batch_size, padded_shapes=([None, num_features], []))
+
+ # Post-process each batch, ensuring that it is mean-centered and time-major.
+ def process_pose_data(data, lengths):
+ """Creates Tensors for next step prediction and mean-centers the input."""
+ data = tf.to_float(tf.transpose(data, perm=[1, 0, 2]))
+ lengths = tf.to_int32(lengths)
+ targets = data
+ # Mean center the inputs.
+ inputs = data - tf.constant(
+ mean, dtype=tf.float32, shape=[1, 1, mean.shape[0]])
+ # Shift the inputs one step forward in time. Also remove the last timestep
+ # so that targets and inputs are the same length.
+ inputs = tf.pad(inputs, [[1, 0], [0, 0], [0, 0]], mode="CONSTANT")[:-1]
+ # Mask out unused timesteps.
+ inputs *= tf.expand_dims(
+ tf.transpose(tf.sequence_mask(lengths, dtype=inputs.dtype)), 2)
+ return inputs, targets, lengths
+
+ dataset = dataset.map(
+ process_pose_data,
+ num_parallel_calls=num_parallel_calls)
+ dataset = dataset.prefetch(num_examples)
+
+ itr = dataset.make_one_shot_iterator()
+ inputs, targets, lengths = itr.get_next()
+ return inputs, targets, lengths, tf.constant(mean, dtype=tf.float32)
+
+
+def create_speech_dataset(path,
+ batch_size,
+ samples_per_timestep=200,
+ num_parallel_calls=DEFAULT_PARALLELISM,
+ prefetch_buffer_size=2048,
+ shuffle=False,
+ repeat=False):
+ """Creates a speech dataset.
+
+ Args:
+ path: The path of a possibly sharded TFRecord file containing the data.
+ batch_size: The batch size. If repeat is False then it is not guaranteed
+ that the true batch size will match for all batches since batch_size
+ may not necessarily evenly divide the number of elements.
+ samples_per_timestep: The number of audio samples per timestep. Used to
+ reshape the data into sequences of shape [time, samples_per_timestep].
+ Should not change except for testing -- in all speech datasets 200 is the
+ number of samples per timestep.
+ num_parallel_calls: The number of threads to use for parallel processing of
+ the data.
+ prefetch_buffer_size: The size of the prefetch queues to use after reading
+ and processing the raw data.
+ shuffle: If true, shuffles the order of the dataset.
+ repeat: If true, repeats the dataset endlessly.
+ Returns:
+ inputs: A batch of input sequences represented as a dense Tensor of shape
+ [time, batch_size, samples_per_timestep]. The sequences in inputs are the
+ sequences in targets shifted one timestep into the future, padded with
+ zeros.
+ targets: A batch of target sequences represented as a dense Tensor of
+ shape [time, batch_size, samples_per_timestep].
+ lens: An int Tensor of shape [batch_size] representing the lengths of each
+ sequence in the batch.
+ """
+ filenames = [path]
+
+ def read_speech_example(value):
+ """Parses a single tf.Example from the TFRecord file."""
+ decoded = tf.decode_raw(value, out_type=tf.float32)
+ example = tf.reshape(decoded, [-1, samples_per_timestep])
+ length = tf.shape(example)[0]
+ return example, length
+
+ # Create the dataset from the TFRecord files
+ dataset = tf.data.TFRecordDataset(filenames).map(
+ read_speech_example, num_parallel_calls=num_parallel_calls)
+ dataset = dataset.prefetch(prefetch_buffer_size)
+
+ if repeat: dataset = dataset.repeat()
+ if shuffle: dataset = dataset.shuffle(prefetch_buffer_size)
+
+ dataset = dataset.padded_batch(
+ batch_size, padded_shapes=([None, samples_per_timestep], []))
+
+ def process_speech_batch(data, lengths):
+ """Creates Tensors for next step prediction."""
+ data = tf.transpose(data, perm=[1, 0, 2])
+ lengths = tf.to_int32(lengths)
+ targets = data
+ # Shift the inputs one step forward in time. Also remove the last timestep
+ # so that targets and inputs are the same length.
+ inputs = tf.pad(data, [[1, 0], [0, 0], [0, 0]], mode="CONSTANT")[:-1]
+ # Mask out unused timesteps.
+ inputs *= tf.expand_dims(
+ tf.transpose(tf.sequence_mask(lengths, dtype=inputs.dtype)), 2)
+ return inputs, targets, lengths
+
+ dataset = dataset.map(process_speech_batch,
+ num_parallel_calls=num_parallel_calls)
+ dataset = dataset.prefetch(prefetch_buffer_size)
+
+ itr = dataset.make_one_shot_iterator()
+ inputs, targets, lengths = itr.get_next()
+ return inputs, targets, lengths
+
+
+SQUARED_OBSERVATION = "squared"
+ABS_OBSERVATION = "abs"
+STANDARD_OBSERVATION = "standard"
+OBSERVATION_TYPES = [SQUARED_OBSERVATION, ABS_OBSERVATION, STANDARD_OBSERVATION]
+
+ROUND_TRANSITION = "round"
+STANDARD_TRANSITION = "standard"
+TRANSITION_TYPES = [ROUND_TRANSITION, STANDARD_TRANSITION]
+
+
+def create_chain_graph_dataset(
+ batch_size,
+ num_timesteps,
+ steps_per_observation=None,
+ state_size=1,
+ transition_variance=1.,
+ observation_variance=1.,
+ transition_type=STANDARD_TRANSITION,
+ observation_type=STANDARD_OBSERVATION,
+ fixed_observation=None,
+ prefetch_buffer_size=2048,
+ dtype="float32"):
+ """Creates a toy chain graph dataset.
+
+ Creates a dataset where the data are sampled from a diffusion process. The
+ 'latent' states of the process are sampled as a chain of Normals:
+
+ z0 ~ N(0, transition_variance)
+ z1 ~ N(transition_fn(z0), transition_variance)
+ ...
+
+ where transition_fn could be round z0 or pass it through unchanged.
+
+ The observations are produced every steps_per_observation timesteps as a
+ function of the latent zs. For example if steps_per_observation is 3 then the
+ first observation will be produced as a function of z3:
+
+ x1 ~ N(observation_fn(z3), observation_variance)
+
+ where observation_fn could square z3, take the absolute value, or pass
+ it through unchanged.
+
+ Only the observations are returned.
+
+ Args:
+ batch_size: The batch size. The number of trajectories to run in parallel.
+ num_timesteps: The length of the chain of latent states (i.e. the
+      number of z's excluding z0).
+ steps_per_observation: The number of latent states between each observation,
+ must evenly divide num_timesteps.
+ state_size: The size of the latent state and observation, must be a
+ python int.
+ transition_variance: The variance of the transition density.
+ observation_variance: The variance of the observation density.
+ transition_type: Must be one of "round" or "standard". "round" means that
+ the transition density is centered at the rounded previous latent state.
+ "standard" centers the transition density at the previous latent state,
+ unchanged.
+ observation_type: Must be one of "squared", "abs" or "standard". "squared"
+ centers the observation density at the squared latent state. "abs"
+      centers the observation density at the absolute value of the current
+ latent state. "standard" centers the observation density at the current
+ latent state.
+ fixed_observation: If not None, fixes all observations to be a constant.
+ Must be a scalar.
+ prefetch_buffer_size: The size of the prefetch queues to use after reading
+ and processing the raw data.
+ dtype: A string convertible to a tensorflow datatype. The datatype used
+ to represent the states and observations.
+ Returns:
+ observations: A batch of observations represented as a dense Tensor of
+ shape [num_observations, batch_size, state_size]. num_observations is
+ num_timesteps/steps_per_observation.
+ lens: An int Tensor of shape [batch_size] representing the lengths of each
+ sequence in the batch. Will contain num_observations as each entry.
+ Raises:
+ ValueError: Raised if steps_per_observation does not evenly divide
+ num_timesteps.
+ """
+ if steps_per_observation is None:
+ steps_per_observation = num_timesteps
+ if num_timesteps % steps_per_observation != 0:
+ raise ValueError("steps_per_observation must evenly divide num_timesteps.")
+ num_observations = int(num_timesteps / steps_per_observation)
+ def data_generator():
+ """An infinite generator of latents and observations from the model."""
+ transition_std = np.sqrt(transition_variance)
+ observation_std = np.sqrt(observation_variance)
+ while True:
+ states = []
+ observations = []
+ # Sample z0 ~ Normal(0, sqrt(variance)).
+ states.append(
+ np.random.normal(size=[state_size],
+ scale=observation_std).astype(dtype))
+ # Start the range at 1 because we've already generated z0.
+ # The range ends at num_timesteps+1 because we want to include the
+ # num_timesteps-th step.
+ for t in xrange(1, num_timesteps+1):
+ if transition_type == ROUND_TRANSITION:
+ loc = np.round(states[-1])
+ elif transition_type == STANDARD_TRANSITION:
+ loc = states[-1]
+ z_t = np.random.normal(size=[state_size], loc=loc, scale=transition_std)
+ states.append(z_t.astype(dtype))
+ if t % steps_per_observation == 0:
+ if fixed_observation is None:
+ if observation_type == SQUARED_OBSERVATION:
+ loc = np.square(states[-1])
+ elif observation_type == ABS_OBSERVATION:
+ loc = np.abs(states[-1])
+ elif observation_type == STANDARD_OBSERVATION:
+ loc = states[-1]
+ x_t = np.random.normal(size=[state_size],
+ loc=loc,
+ scale=observation_std).astype(dtype)
+ else:
+ x_t = np.ones([state_size]) * fixed_observation
+
+ observations.append(x_t)
+ yield states, observations
+
+ dataset = tf.data.Dataset.from_generator(
+ data_generator,
+ output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
+ output_shapes=([num_timesteps+1, state_size],
+ [num_observations, state_size])
+ )
+ dataset = dataset.repeat().batch(batch_size)
+ dataset = dataset.prefetch(prefetch_buffer_size)
+ itr = dataset.make_one_shot_iterator()
+ _, observations = itr.get_next()
+ # Transpose observations from [batch, time, state_size] to
+ # [time, batch, state_size].
+ observations = tf.transpose(observations, perm=[1, 0, 2])
+ lengths = tf.ones([batch_size], dtype=tf.int32) * num_observations
+ return observations, lengths
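
For orientation, a minimal usage sketch of the chain-graph dataset defined above; it is illustrative only (not part of the diff) and assumes the TF 1.x graph/session style the file itself uses.

# Illustrative sketch, not part of the diff: sampling a toy chain-graph batch.
import tensorflow as tf
from fivo.data import datasets

observations, lengths = datasets.create_chain_graph_dataset(
    batch_size=4,
    num_timesteps=10,
    steps_per_observation=5,  # must evenly divide num_timesteps
    state_size=1,
    transition_type=datasets.ROUND_TRANSITION,
    observation_type=datasets.ABS_OBSERVATION)

with tf.Session() as sess:
  obs, lens = sess.run([observations, lengths])
  # obs has shape [num_observations, batch_size, state_size] = (2, 4, 1).
  # lens is [2, 2, 2, 2], one entry of num_observations per batch element.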
diff --git a/models/research/fivo/fivo/data/datasets_test.py b/models/research/fivo/fivo/data/datasets_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6bbfda67aa44efc0bc4b1a34eb0cb9f09d53de5
--- /dev/null
+++ b/models/research/fivo/fivo/data/datasets_test.py
@@ -0,0 +1,303 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for fivo.data.datasets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import pickle
+import os
+
+import numpy as np
+import tensorflow as tf
+
+from fivo.data import datasets
+
+FLAGS = tf.app.flags.FLAGS
+
+
+class DatasetsTest(tf.test.TestCase):
+
+ def test_sparse_pianoroll_to_dense_empty_at_end(self):
+ sparse_pianoroll = [(0, 1), (1, 0), (), (1,), (), ()]
+ dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
+ sparse_pianoroll, min_note=0, num_notes=2)
+ self.assertEqual(num_timesteps, 6)
+ self.assertAllEqual([[1, 1],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [0, 0],
+ [0, 0]], dense_pianoroll)
+
+ def test_sparse_pianoroll_to_dense_with_chord(self):
+ sparse_pianoroll = [(0, 1), (1, 0), (), (1,)]
+ dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
+ sparse_pianoroll, min_note=0, num_notes=2)
+ self.assertEqual(num_timesteps, 4)
+ self.assertAllEqual([[1, 1],
+ [1, 1],
+ [0, 0],
+ [0, 1]], dense_pianoroll)
+
+ def test_sparse_pianoroll_to_dense_simple(self):
+ sparse_pianoroll = [(0,), (), (1,)]
+ dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
+ sparse_pianoroll, min_note=0, num_notes=2)
+ self.assertEqual(num_timesteps, 3)
+ self.assertAllEqual([[1, 0],
+ [0, 0],
+ [0, 1]], dense_pianoroll)
+
+ def test_sparse_pianoroll_to_dense_subtracts_min_note(self):
+ sparse_pianoroll = [(4, 5), (5, 4), (), (5,), (), ()]
+ dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
+ sparse_pianoroll, min_note=4, num_notes=2)
+ self.assertEqual(num_timesteps, 6)
+ self.assertAllEqual([[1, 1],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [0, 0],
+ [0, 0]], dense_pianoroll)
+
+ def test_sparse_pianoroll_to_dense_uses_num_notes(self):
+ sparse_pianoroll = [(4, 5), (5, 4), (), (5,), (), ()]
+ dense_pianoroll, num_timesteps = datasets.sparse_pianoroll_to_dense(
+ sparse_pianoroll, min_note=4, num_notes=3)
+ self.assertEqual(num_timesteps, 6)
+ self.assertAllEqual([[1, 1, 0],
+ [1, 1, 0],
+ [0, 0, 0],
+ [0, 1, 0],
+ [0, 0, 0],
+ [0, 0, 0]], dense_pianoroll)
+
+ def test_pianoroll_dataset(self):
+ pianoroll_data = [[(0,), (), (1,)],
+ [(0, 1), (1,)],
+ [(1,), (0,), (), (0, 1), (), ()]]
+ pianoroll_mean = np.zeros([3])
+ pianoroll_mean[-1] = 1
+ data = {"train": pianoroll_data, "train_mean": pianoroll_mean}
+ path = os.path.join(tf.test.get_temp_dir(), "test.pkl")
+ pickle.dump(data, open(path, "wb"))
+ with self.test_session() as sess:
+ inputs, targets, lens, mean = datasets.create_pianoroll_dataset(
+ path, "train", 2, num_parallel_calls=1,
+ shuffle=False, repeat=False,
+ min_note=0, max_note=2)
+ i1, t1, l1 = sess.run([inputs, targets, lens])
+ i2, t2, l2 = sess.run([inputs, targets, lens])
+ m = sess.run(mean)
+ # Check the lengths.
+ self.assertAllEqual([3, 2], l1)
+ self.assertAllEqual([6], l2)
+ # Check the mean.
+ self.assertAllEqual(pianoroll_mean, m)
+ # Check the targets. The targets should not be mean-centered and should
+ # be padded with zeros to a common length within a batch.
+ self.assertAllEqual([[1, 0, 0],
+ [0, 0, 0],
+ [0, 1, 0]], t1[:, 0, :])
+ self.assertAllEqual([[1, 1, 0],
+ [0, 1, 0],
+ [0, 0, 0]], t1[:, 1, :])
+ self.assertAllEqual([[0, 1, 0],
+ [1, 0, 0],
+ [0, 0, 0],
+ [1, 1, 0],
+ [0, 0, 0],
+ [0, 0, 0]], t2[:, 0, :])
+ # Check the inputs. Each sequence should start with zeros on the first
+ # timestep. Each sequence should be padded with zeros to a common length
+ # within a batch. The mean should be subtracted from all timesteps except
+ # the first and the padding.
+ self.assertAllEqual([[0, 0, 0],
+ [1, 0, -1],
+ [0, 0, -1]], i1[:, 0, :])
+ self.assertAllEqual([[0, 0, 0],
+ [1, 1, -1],
+ [0, 0, 0]], i1[:, 1, :])
+ self.assertAllEqual([[0, 0, 0],
+ [0, 1, -1],
+ [1, 0, -1],
+ [0, 0, -1],
+ [1, 1, -1],
+ [0, 0, -1]], i2[:, 0, :])
+
+ def test_human_pose_dataset(self):
+ pose_data = [
+ [[0, 0], [2, 2]],
+ [[2, 2]],
+ [[0, 0], [0, 0], [2, 2], [2, 2], [0, 0]],
+ ]
+ pose_data = [np.array(x, dtype=np.float64) for x in pose_data]
+ pose_data_mean = np.array([1, 1], dtype=np.float64)
+ data = {
+ "train": pose_data,
+ "train_mean": pose_data_mean,
+ }
+ path = os.path.join(tf.test.get_temp_dir(), "test_human_pose_dataset.pkl")
+ with open(path, "wb") as out:
+ pickle.dump(data, out)
+ with self.test_session() as sess:
+ inputs, targets, lens, mean = datasets.create_human_pose_dataset(
+ path, "train", 2, num_parallel_calls=1, shuffle=False, repeat=False)
+ i1, t1, l1 = sess.run([inputs, targets, lens])
+ i2, t2, l2 = sess.run([inputs, targets, lens])
+ m = sess.run(mean)
+ # Check the lengths.
+ self.assertAllEqual([2, 1], l1)
+ self.assertAllEqual([5], l2)
+ # Check the mean.
+ self.assertAllEqual(pose_data_mean, m)
+ # Check the targets. The targets should not be mean-centered and should
+ # be padded with zeros to a common length within a batch.
+ self.assertAllEqual([[0, 0], [2, 2]], t1[:, 0, :])
+ self.assertAllEqual([[2, 2], [0, 0]], t1[:, 1, :])
+ self.assertAllEqual([[0, 0], [0, 0], [2, 2], [2, 2], [0, 0]], t2[:, 0, :])
+ # Check the inputs. Each sequence should start with zeros on the first
+ # timestep. Each sequence should be padded with zeros to a common length
+ # within a batch. The mean should be subtracted from all timesteps except
+ # the first and the padding.
+ self.assertAllEqual([[0, 0], [-1, -1]], i1[:, 0, :])
+ self.assertAllEqual([[0, 0], [0, 0]], i1[:, 1, :])
+ self.assertAllEqual([[0, 0], [-1, -1], [-1, -1], [1, 1], [1, 1]],
+ i2[:, 0, :])
+
+ def test_speech_dataset(self):
+ with self.test_session() as sess:
+ path = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.realpath(__file__))),
+ "test_data",
+ "tiny_speech_dataset.tfrecord")
+ inputs, targets, lens = datasets.create_speech_dataset(
+ path, 3, samples_per_timestep=2, num_parallel_calls=1,
+ prefetch_buffer_size=3, shuffle=False, repeat=False)
+ inputs1, targets1, lengths1 = sess.run([inputs, targets, lens])
+ inputs2, targets2, lengths2 = sess.run([inputs, targets, lens])
+ # Check the lengths.
+ self.assertAllEqual([1, 2, 3], lengths1)
+ self.assertAllEqual([4], lengths2)
+ # Check the targets. The targets should be padded with zeros to a common
+ # length within a batch.
+ self.assertAllEqual([[[0., 1.], [0., 1.], [0., 1.]],
+ [[0., 0.], [2., 3.], [2., 3.]],
+ [[0., 0.], [0., 0.], [4., 5.]]],
+ targets1)
+ self.assertAllEqual([[[0., 1.]],
+ [[2., 3.]],
+ [[4., 5.]],
+ [[6., 7.]]],
+ targets2)
+ # Check the inputs. Each sequence should start with zeros on the first
+ # timestep. Each sequence should be padded with zeros to a common length
+ # within a batch.
+ self.assertAllEqual([[[0., 0.], [0., 0.], [0., 0.]],
+ [[0., 0.], [0., 1.], [0., 1.]],
+ [[0., 0.], [0., 0.], [2., 3.]]],
+ inputs1)
+ self.assertAllEqual([[[0., 0.]],
+ [[0., 1.]],
+ [[2., 3.]],
+ [[4., 5.]]],
+ inputs2)
+
+ def test_chain_graph_raises_error_on_wrong_steps_per_observation(self):
+ with self.assertRaises(ValueError):
+ datasets.create_chain_graph_dataset(
+ batch_size=4,
+ num_timesteps=10,
+ steps_per_observation=9)
+
+ def test_chain_graph_single_obs(self):
+ with self.test_session() as sess:
+ np.random.seed(1234)
+ num_observations = 1
+ num_timesteps = 5
+ batch_size = 2
+ state_size = 1
+ observations, lengths = datasets.create_chain_graph_dataset(
+ batch_size=batch_size,
+ num_timesteps=num_timesteps,
+ state_size=state_size)
+ out_observations, out_lengths = sess.run([observations, lengths])
+ self.assertAllEqual([num_observations, num_observations], out_lengths)
+ self.assertAllClose(
+ [[[1.426677], [-1.789461]]],
+ out_observations)
+
+ def test_chain_graph_multiple_obs(self):
+ with self.test_session() as sess:
+ np.random.seed(1234)
+ num_observations = 3
+ num_timesteps = 6
+ batch_size = 2
+ state_size = 1
+ observations, lengths = datasets.create_chain_graph_dataset(
+ batch_size=batch_size,
+ num_timesteps=num_timesteps,
+ steps_per_observation=num_timesteps/num_observations,
+ state_size=state_size)
+ out_observations, out_lengths = sess.run([observations, lengths])
+ self.assertAllEqual([num_observations, num_observations], out_lengths)
+ self.assertAllClose(
+ [[[0.40051451], [1.07405114]],
+ [[1.73932898], [3.16880035]],
+ [[-1.98377144], [2.82669163]]],
+ out_observations)
+
+ def test_chain_graph_state_dims(self):
+ with self.test_session() as sess:
+ np.random.seed(1234)
+ num_observations = 1
+ num_timesteps = 5
+ batch_size = 2
+ state_size = 3
+ observations, lengths = datasets.create_chain_graph_dataset(
+ batch_size=batch_size,
+ num_timesteps=num_timesteps,
+ state_size=state_size)
+ out_observations, out_lengths = sess.run([observations, lengths])
+ self.assertAllEqual([num_observations, num_observations], out_lengths)
+ self.assertAllClose(
+ [[[1.052287, -4.560759, 3.07988],
+ [2.008926, 0.495567, 3.488678]]],
+ out_observations)
+
+ def test_chain_graph_fixed_obs(self):
+ with self.test_session() as sess:
+ np.random.seed(1234)
+ num_observations = 3
+ num_timesteps = 6
+ batch_size = 2
+ state_size = 1
+ observations, lengths = datasets.create_chain_graph_dataset(
+ batch_size=batch_size,
+ num_timesteps=num_timesteps,
+ steps_per_observation=num_timesteps/num_observations,
+ state_size=state_size,
+ fixed_observation=4.)
+ out_observations, out_lengths = sess.run([observations, lengths])
+ self.assertAllEqual([num_observations, num_observations], out_lengths)
+ self.assertAllClose(
+ np.ones([num_observations, batch_size, state_size]) * 4.,
+ out_observations)
+
+if __name__ == "__main__":
+ tf.test.main()
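
The pianoroll and pose tests above all verify the same next-step-prediction convention; the short numpy sketch below (illustrative, using the first pianoroll sequence from test_pianoroll_dataset) spells out that transform.

# Illustrative numpy sketch of the transform the tests verify: targets are the
# raw sequence, inputs are the targets shifted forward one step, zero-padded at
# t=0, and mean-centered everywhere except the padding.
import numpy as np

targets = np.array([[1, 0, 0],
                    [0, 0, 0],
                    [0, 1, 0]], dtype=np.float64)  # [time, features]
mean = np.array([0., 0., 1.])

inputs = np.zeros_like(targets)
inputs[1:] = targets[:-1] - mean  # shift one step and subtract the mean
# inputs is now [[0, 0, 0], [1, 0, -1], [0, 0, -1]], matching i1[:, 0, :] above.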
diff --git a/models/research/fivo/fivo/ghmm_runners.py b/models/research/fivo/fivo/ghmm_runners.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f1ba6d4f9ea9ed9dee7d95449ba73285c77f24d
--- /dev/null
+++ b/models/research/fivo/fivo/ghmm_runners.py
@@ -0,0 +1,235 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Creates and runs Gaussian HMM-related graphs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+import tensorflow as tf
+
+from fivo import smc
+from fivo import bounds
+from fivo.data import datasets
+from fivo.models import ghmm
+
+
+def run_train(config):
+ """Runs training for a Gaussian HMM setup."""
+
+ def create_logging_hook(step, bound_value, likelihood, bound_gap):
+ """Creates a logging hook that prints the bound value periodically."""
+ bound_label = config.bound + "/t"
+ def summary_formatter(log_dict):
+ string = ("Step {step}, %s: {value:.3f}, "
+ "likelihood: {ll:.3f}, gap: {gap:.3e}") % bound_label
+ return string.format(**log_dict)
+ logging_hook = tf.train.LoggingTensorHook(
+ {"step": step, "value": bound_value,
+ "ll": likelihood, "gap": bound_gap},
+ every_n_iter=config.summarize_every,
+ formatter=summary_formatter)
+ return logging_hook
+
+ def create_losses(model, observations, lengths):
+ """Creates the loss to be optimized.
+
+ Args:
+ model: A Trainable GHMM model.
+ observations: A set of observations.
+ lengths: The lengths of each sequence in the observations.
+ Returns:
+ loss: A float Tensor that when differentiated yields the gradients
+ to apply to the model. Should be optimized via gradient descent.
+ bound: A float Tensor containing the value of the bound that is
+ being optimized.
+ true_ll: The true log-likelihood of the data under the model.
+ bound_gap: The gap between the bound and the true log-likelihood.
+ """
+ # Compute lower bounds on the log likelihood.
+ if config.bound == "elbo":
+ ll_per_seq, _, _ = bounds.iwae(
+ model, observations, lengths, num_samples=1,
+ parallel_iterations=config.parallel_iterations
+ )
+ elif config.bound == "iwae":
+ ll_per_seq, _, _ = bounds.iwae(
+ model, observations, lengths, num_samples=config.num_samples,
+ parallel_iterations=config.parallel_iterations
+ )
+ elif config.bound == "fivo":
+ if config.resampling_type == "relaxed":
+ ll_per_seq, _, _, _ = bounds.fivo(
+ model,
+ observations,
+ lengths,
+ num_samples=config.num_samples,
+ resampling_criterion=smc.ess_criterion,
+ resampling_type=config.resampling_type,
+          relaxed_resampling_temperature=config.relaxed_resampling_temperature,
+ random_seed=config.random_seed,
+ parallel_iterations=config.parallel_iterations)
+ else:
+ ll_per_seq, _, _, _ = bounds.fivo(
+ model, observations, lengths,
+ num_samples=config.num_samples,
+ resampling_criterion=smc.ess_criterion,
+ resampling_type=config.resampling_type,
+ random_seed=config.random_seed,
+ parallel_iterations=config.parallel_iterations
+ )
+ ll_per_t = tf.reduce_mean(ll_per_seq / tf.to_float(lengths))
+ # Compute the data's true likelihood under the model and the bound gap.
+ true_ll_per_seq = model.likelihood(tf.squeeze(observations))
+ true_ll_per_t = tf.reduce_mean(true_ll_per_seq / tf.to_float(lengths))
+ bound_gap = true_ll_per_seq - ll_per_seq
+    bound_gap = tf.reduce_mean(bound_gap / tf.to_float(lengths))
+ tf.summary.scalar("train_ll_bound", ll_per_t)
+ tf.summary.scalar("train_true_ll", true_ll_per_t)
+ tf.summary.scalar("bound_gap", bound_gap)
+ return -ll_per_t, ll_per_t, true_ll_per_t, bound_gap
+
+ def create_graph():
+ """Creates the training graph."""
+ global_step = tf.train.get_or_create_global_step()
+ xs, lengths = datasets.create_chain_graph_dataset(
+ config.batch_size,
+ config.num_timesteps,
+ steps_per_observation=1,
+ state_size=1,
+ transition_variance=config.variance,
+ observation_variance=config.variance)
+ model = ghmm.TrainableGaussianHMM(
+ config.num_timesteps,
+ config.proposal_type,
+ transition_variances=config.variance,
+ emission_variances=config.variance,
+ random_seed=config.random_seed)
+ loss, bound, true_ll, gap = create_losses(model, xs, lengths)
+ opt = tf.train.AdamOptimizer(config.learning_rate)
+ grads = opt.compute_gradients(loss, var_list=tf.trainable_variables())
+ train_op = opt.apply_gradients(grads, global_step=global_step)
+ return bound, true_ll, gap, train_op, global_step
+
+ with tf.Graph().as_default():
+ if config.random_seed:
+ tf.set_random_seed(config.random_seed)
+ np.random.seed(config.random_seed)
+ bound, true_ll, gap, train_op, global_step = create_graph()
+ log_hook = create_logging_hook(global_step, bound, true_ll, gap)
+ with tf.train.MonitoredTrainingSession(
+ master="",
+ hooks=[log_hook],
+ checkpoint_dir=config.logdir,
+ save_checkpoint_secs=120,
+ save_summaries_steps=config.summarize_every,
+ log_step_count_steps=config.summarize_every*20) as sess:
+ cur_step = -1
+ while cur_step <= config.max_steps and not sess.should_stop():
+ cur_step = sess.run(global_step)
+ _, cur_step = sess.run([train_op, global_step])
+
+
+def run_eval(config):
+ """Evaluates a Gaussian HMM using the given config."""
+
+ def create_bound(model, xs, lengths):
+ """Creates the bound to be evaluated."""
+ if config.bound == "elbo":
+ ll_per_seq, log_weights, _ = bounds.iwae(
+ model, xs, lengths, num_samples=1,
+ parallel_iterations=config.parallel_iterations
+ )
+ elif config.bound == "iwae":
+ ll_per_seq, log_weights, _ = bounds.iwae(
+ model, xs, lengths, num_samples=config.num_samples,
+ parallel_iterations=config.parallel_iterations
+ )
+ elif config.bound == "fivo":
+ ll_per_seq, log_weights, resampled, _ = bounds.fivo(
+ model, xs, lengths,
+ num_samples=config.num_samples,
+ resampling_criterion=smc.ess_criterion,
+ resampling_type=config.resampling_type,
+ random_seed=config.random_seed,
+ parallel_iterations=config.parallel_iterations
+ )
+ # Compute bound scaled by number of timesteps.
+ bound_per_t = ll_per_seq / tf.to_float(lengths)
+ if config.bound == "fivo":
+ return bound_per_t, log_weights, resampled
+ else:
+ return bound_per_t, log_weights
+
+ def create_graph():
+ """Creates the dataset, model, and bound."""
+ xs, lengths = datasets.create_chain_graph_dataset(
+ config.batch_size,
+ config.num_timesteps,
+ steps_per_observation=1,
+ state_size=1,
+ transition_variance=config.variance,
+ observation_variance=config.variance)
+ model = ghmm.TrainableGaussianHMM(
+ config.num_timesteps,
+ config.proposal_type,
+ transition_variances=config.variance,
+ emission_variances=config.variance,
+ random_seed=config.random_seed)
+ true_likelihood = tf.reduce_mean(
+ model.likelihood(tf.squeeze(xs)) / tf.to_float(lengths))
+ outs = [true_likelihood]
+ outs.extend(list(create_bound(model, xs, lengths)))
+ return outs
+
+ with tf.Graph().as_default():
+ if config.random_seed:
+ tf.set_random_seed(config.random_seed)
+ np.random.seed(config.random_seed)
+ graph_outs = create_graph()
+ with tf.train.SingularMonitoredSession(
+ checkpoint_dir=config.logdir) as sess:
+ outs = sess.run(graph_outs)
+ likelihood = outs[0]
+ avg_bound = np.mean(outs[1])
+ std = np.std(outs[1])
+ log_weights = outs[2]
+ log_weight_variances = np.var(log_weights, axis=2)
+ avg_log_weight_variance = np.var(log_weight_variances, axis=1)
+ avg_log_weight = np.mean(log_weights, axis=(1, 2))
+ data = {"mean": avg_bound, "std": std, "log_weights": log_weights,
+ "log_weight_means": avg_log_weight,
+ "log_weight_variances": avg_log_weight_variance}
+ if len(outs) == 4:
+ data["resampled"] = outs[3]
+ data["avg_resampled"] = np.mean(outs[3], axis=1)
+ # Log some useful statistics.
+ tf.logging.info("Evaled bound %s with batch_size: %d, num_samples: %d."
+ % (config.bound, config.batch_size, config.num_samples))
+ tf.logging.info("mean: %f, std: %f" % (avg_bound, std))
+ tf.logging.info("true likelihood: %s" % likelihood)
+ tf.logging.info("avg log weight: %s" % avg_log_weight)
+ tf.logging.info("log weight variance: %s" % avg_log_weight_variance)
+ if len(outs) == 4:
+ tf.logging.info("avg resamples per t: %s" % data["avg_resampled"])
+ if not tf.gfile.Exists(config.logdir):
+ tf.gfile.MakeDirs(config.logdir)
+ with tf.gfile.Open(os.path.join(config.logdir, "out.npz"), "w") as fout:
+ np.save(fout, data)
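
run_train and run_eval take a flag-style config object rather than explicit arguments; a minimal sketch follows, with field names taken from the reads above and the GHMM test config. The values and the logdir path are illustrative.

# Minimal sketch of a config object for run_train/run_eval; field names mirror
# those read above, the values and logdir are illustrative.
from fivo import ghmm_runners

class Config(object):
  pass

config = Config()
config.bound = "fivo"                  # one of "elbo", "iwae", "fivo"
config.proposal_type = "prior"
config.batch_size = 4
config.num_samples = 4
config.num_timesteps = 10
config.variance = 0.1
config.resampling_type = "multinomial"
config.random_seed = 1234
config.parallel_iterations = 1
config.learning_rate = 1e-4
config.summarize_every = 50
config.max_steps = 1000
config.logdir = "/tmp/fivo_ghmm"       # hypothetical output directory

ghmm_runners.run_train(config)
ghmm_runners.run_eval(config)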
diff --git a/models/research/fivo/fivo/ghmm_runners_test.py b/models/research/fivo/fivo/ghmm_runners_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..50044ad475b3458858b580a6ff7664267485757b
--- /dev/null
+++ b/models/research/fivo/fivo/ghmm_runners_test.py
@@ -0,0 +1,106 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for fivo.ghmm_runners."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import numpy as np
+import tensorflow as tf
+
+from fivo import ghmm_runners
+
+
+class GHMMRunnersTest(tf.test.TestCase):
+
+ def default_config(self):
+ class Config(object):
+ pass
+ config = Config()
+ config.model = "ghmm"
+ config.bound = "fivo"
+ config.proposal_type = "prior"
+ config.batch_size = 4
+ config.num_samples = 4
+ config.num_timesteps = 10
+ config.variance = 0.1
+ config.resampling_type = "multinomial"
+ config.random_seed = 1234
+ config.parallel_iterations = 1
+ config.learning_rate = 1e-4
+ config.summarize_every = 1
+ config.max_steps = 1
+ return config
+
+ def test_eval_ghmm_notraining_fivo_prior(self):
+ self.eval_ghmm_notraining("fivo", "prior", -3.063864)
+
+ def test_eval_ghmm_notraining_fivo_true_filtering(self):
+ self.eval_ghmm_notraining("fivo", "true-filtering", -1.1409812)
+
+ def test_eval_ghmm_notraining_fivo_true_smoothing(self):
+ self.eval_ghmm_notraining("fivo", "true-smoothing", -0.85592091)
+
+ def test_eval_ghmm_notraining_iwae_prior(self):
+ self.eval_ghmm_notraining("iwae", "prior", -5.9730167)
+
+ def test_eval_ghmm_notraining_iwae_true_filtering(self):
+ self.eval_ghmm_notraining("iwae", "true-filtering", -1.1485999)
+
+ def test_eval_ghmm_notraining_iwae_true_smoothing(self):
+ self.eval_ghmm_notraining("iwae", "true-smoothing", -0.85592091)
+
+ def eval_ghmm_notraining(self, bound, proposal_type, expected_bound_avg):
+ config = self.default_config()
+ config.proposal_type = proposal_type
+ config.bound = bound
+ config.logdir = os.path.join(
+ tf.test.get_temp_dir(), "test-ghmm-%s-%s" % (proposal_type, bound))
+
+ ghmm_runners.run_eval(config)
+
+ data = np.load(os.path.join(config.logdir, "out.npz")).item()
+ self.assertAlmostEqual(expected_bound_avg, data["mean"], places=3)
+
+ def test_train_ghmm_for_one_step_and_eval_fivo_filtering(self):
+ self.train_ghmm_for_one_step_and_eval("fivo", "filtering", -16.727108)
+
+ def test_train_ghmm_for_one_step_and_eval_fivo_smoothing(self):
+ self.train_ghmm_for_one_step_and_eval("fivo", "smoothing", -19.381277)
+
+ def test_train_ghmm_for_one_step_and_eval_iwae_filtering(self):
+ self.train_ghmm_for_one_step_and_eval("iwae", "filtering", -33.31966)
+
+ def test_train_ghmm_for_one_step_and_eval_iwae_smoothing(self):
+ self.train_ghmm_for_one_step_and_eval("iwae", "smoothing", -46.388447)
+
+ def train_ghmm_for_one_step_and_eval(self, bound, proposal_type, expected_bound_avg):
+ config = self.default_config()
+ config.proposal_type = proposal_type
+ config.bound = bound
+ config.max_steps = 1
+ config.logdir = os.path.join(
+ tf.test.get_temp_dir(), "test-ghmm-training-%s-%s" % (proposal_type, bound))
+ ghmm_runners.run_train(config)
+ ghmm_runners.run_eval(config)
+ data = np.load(os.path.join(config.logdir, "out.npz")).item()
+ self.assertAlmostEqual(expected_bound_avg, data["mean"], places=2)
+
+
+if __name__ == "__main__":
+ tf.test.main()
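
For completeness, a small sketch of reading back the summary file that run_eval produces, as the tests above do; the allow_pickle argument is only needed on newer numpy versions.

# Sketch of loading the evaluation summary that run_eval writes to
# <logdir>/out.npz. The file is a python dict saved with np.save, so
# .item() recovers it; newer numpy also needs allow_pickle=True.
import os
import numpy as np

logdir = "/tmp/fivo_ghmm"  # hypothetical; whatever was passed as config.logdir
data = np.load(os.path.join(logdir, "out.npz"), allow_pickle=True).item()
print(data["mean"], data["std"])   # average bound per timestep and its spread
print(data["log_weight_means"])    # mean log weight per timestep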
diff --git a/models/research/fivo/fivo/models/__init__.py b/models/research/fivo/fivo/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/fivo/fivo/models/base.py b/models/research/fivo/fivo/models/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ffcb7af216f5659e71d7425eeb4e2c3158b3d47
--- /dev/null
+++ b/models/research/fivo/fivo/models/base.py
@@ -0,0 +1,342 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Reusable model classes for FIVO."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sonnet as snt
+import tensorflow as tf
+
+from fivo import nested_utils as nested
+
+tfd = tf.contrib.distributions
+
+
+class ELBOTrainableSequenceModel(object):
+ """An abstract class for ELBO-trainable sequence models to extend.
+
+ Because the ELBO, IWAE, and FIVO bounds all accept the same arguments,
+ any model that is ELBO-trainable is also IWAE- and FIVO-trainable.
+ """
+
+ def zero_state(self, batch_size, dtype):
+ """Returns the initial state of the model as a Tensor or tuple of Tensors.
+
+ Args:
+ batch_size: The batch size.
+ dtype: The datatype to use for the state.
+ """
+ raise NotImplementedError("zero_state not yet implemented.")
+
+ def set_observations(self, observations, seq_lengths):
+ """Sets the observations for the model.
+
+ This method provides the model with all observed variables including both
+ inputs and targets. It will be called before running any computations with
+ the model that require the observations, e.g. training the model or
+ computing bounds, and should be used to run any necessary preprocessing
+ steps.
+
+ Args:
+ observations: A potentially nested set of Tensors containing
+ all observations for the model, both inputs and targets. Typically
+ a set of Tensors with shape [max_seq_len, batch_size, data_size].
+ seq_lengths: A [batch_size] Tensor of ints encoding the length of each
+ sequence in the batch (sequences can be padded to a common length).
+ """
+ self.observations = observations
+ self.max_seq_len = tf.reduce_max(seq_lengths)
+ self.observations_ta = nested.tas_for_tensors(
+ observations, self.max_seq_len, clear_after_read=False)
+ self.seq_lengths = seq_lengths
+
+ def propose_and_weight(self, state, t):
+ """Propogates model state one timestep and computes log weights.
+
+ This method accepts the current state of the model and computes the state
+ for the next timestep as well as the incremental log weight of each
+ element in the batch.
+
+ Args:
+ state: The current state of the model.
+ t: A scalar integer Tensor representing the current timestep.
+ Returns:
+ next_state: The state of the model after one timestep.
+ log_weights: A [batch_size] Tensor containing the incremental log weights.
+ """
+ raise NotImplementedError("propose_and_weight not yet implemented.")
+
+DEFAULT_INITIALIZERS = {"w": tf.contrib.layers.xavier_initializer(),
+ "b": tf.zeros_initializer()}
+
+
+class ConditionalNormalDistribution(object):
+ """A Normal distribution conditioned on Tensor inputs via a fc network."""
+
+ def __init__(self, size, hidden_layer_sizes, sigma_min=0.0,
+ raw_sigma_bias=0.25, hidden_activation_fn=tf.nn.relu,
+ initializers=None, name="conditional_normal_distribution"):
+ """Creates a conditional Normal distribution.
+
+ Args:
+ size: The dimension of the random variable.
+ hidden_layer_sizes: The sizes of the hidden layers of the fully connected
+ network used to condition the distribution on the inputs.
+ sigma_min: The minimum standard deviation allowed, a scalar.
+ raw_sigma_bias: A scalar that is added to the raw standard deviation
+ output from the fully connected network. Set to 0.25 by default to
+ prevent standard deviations close to 0.
+ hidden_activation_fn: The activation function to use on the hidden layers
+ of the fully connected network.
+      initializers: The variable initializers to use for the fully connected
+ network. The network is implemented using snt.nets.MLP so it must
+ be a dictionary mapping the keys 'w' and 'b' to the initializers for
+ the weights and biases. Defaults to xavier for the weights and zeros
+ for the biases when initializers is None.
+ name: The name of this distribution, used for sonnet scoping.
+ """
+ self.sigma_min = sigma_min
+ self.raw_sigma_bias = raw_sigma_bias
+ self.name = name
+ self.size = size
+ if initializers is None:
+ initializers = DEFAULT_INITIALIZERS
+ self.fcnet = snt.nets.MLP(
+ output_sizes=hidden_layer_sizes + [2*size],
+ activation=hidden_activation_fn,
+ initializers=initializers,
+ activate_final=False,
+ use_bias=True,
+ name=name + "_fcnet")
+
+ def condition(self, tensor_list, **unused_kwargs):
+ """Computes the parameters of a normal distribution based on the inputs."""
+ inputs = tf.concat(tensor_list, axis=1)
+ outs = self.fcnet(inputs)
+ mu, sigma = tf.split(outs, 2, axis=1)
+ sigma = tf.maximum(tf.nn.softplus(sigma + self.raw_sigma_bias),
+ self.sigma_min)
+ return mu, sigma
+
+ def __call__(self, *args, **kwargs):
+ """Creates a normal distribution conditioned on the inputs."""
+ mu, sigma = self.condition(args, **kwargs)
+ return tf.contrib.distributions.Normal(loc=mu, scale=sigma)
+
+
+class ConditionalBernoulliDistribution(object):
+ """A Bernoulli distribution conditioned on Tensor inputs via a fc net."""
+
+ def __init__(self, size, hidden_layer_sizes, hidden_activation_fn=tf.nn.relu,
+ initializers=None, bias_init=0.0,
+ name="conditional_bernoulli_distribution"):
+ """Creates a conditional Bernoulli distribution.
+
+ Args:
+ size: The dimension of the random variable.
+ hidden_layer_sizes: The sizes of the hidden layers of the fully connected
+ network used to condition the distribution on the inputs.
+ hidden_activation_fn: The activation function to use on the hidden layers
+ of the fully connected network.
+      initializers: The variable initializers to use for the fully connected
+ network. The network is implemented using snt.nets.MLP so it must
+ be a dictionary mapping the keys 'w' and 'b' to the initializers for
+ the weights and biases. Defaults to xavier for the weights and zeros
+ for the biases when initializers is None.
+ bias_init: A scalar or vector Tensor that is added to the output of the
+ fully-connected network that parameterizes the mean of this
+ distribution.
+ name: The name of this distribution, used for sonnet scoping.
+ """
+ self.bias_init = bias_init
+ self.size = size
+ if initializers is None:
+ initializers = DEFAULT_INITIALIZERS
+ self.fcnet = snt.nets.MLP(
+ output_sizes=hidden_layer_sizes + [size],
+ activation=hidden_activation_fn,
+ initializers=initializers,
+ activate_final=False,
+ use_bias=True,
+ name=name + "_fcnet")
+
+ def condition(self, tensor_list):
+ """Computes the p parameter of the Bernoulli distribution."""
+ inputs = tf.concat(tensor_list, axis=1)
+ return self.fcnet(inputs) + self.bias_init
+
+ def __call__(self, *args):
+ p = self.condition(args)
+ return tf.contrib.distributions.Bernoulli(logits=p)
+
+
+class NormalApproximatePosterior(ConditionalNormalDistribution):
+ """A Normally-distributed approx. posterior with res_q parameterization."""
+
+ def __init__(self, size, hidden_layer_sizes, sigma_min=0.0,
+ raw_sigma_bias=0.25, hidden_activation_fn=tf.nn.relu,
+ initializers=None, smoothing=False,
+ name="conditional_normal_distribution"):
+ super(NormalApproximatePosterior, self).__init__(
+ size, hidden_layer_sizes, sigma_min=sigma_min,
+ raw_sigma_bias=raw_sigma_bias,
+ hidden_activation_fn=hidden_activation_fn, initializers=initializers,
+ name=name)
+ self.smoothing = smoothing
+
+ def condition(self, tensor_list, prior_mu, smoothing_tensors=None):
+ """Generates the mean and variance of the normal distribution.
+
+ Args:
+ tensor_list: The list of Tensors to condition on. Will be concatenated and
+ fed through a fully connected network.
+ prior_mu: The mean of the prior distribution associated with this
+ approximate posterior. Will be added to the mean produced by
+ this approximate posterior, in res_q fashion.
+ smoothing_tensors: A list of Tensors. If smoothing is True, these Tensors
+ will be concatenated with the tensors in tensor_list.
+ Returns:
+ mu: The mean of the approximate posterior.
+ sigma: The standard deviation of the approximate posterior.
+ """
+ if self.smoothing:
+ tensor_list.extend(smoothing_tensors)
+ mu, sigma = super(NormalApproximatePosterior, self).condition(tensor_list)
+ return mu + prior_mu, sigma
+
+
+class NonstationaryLinearDistribution(object):
+ """A set of loc-scale distributions that are linear functions of inputs.
+
+ This class defines a series of location-scale distributions such that
+ the means are learnable linear functions of the inputs and the log variances
+ are learnable constants. The functions and log variances are different across
+ timesteps, allowing the distributions to be nonstationary.
+ """
+
+ def __init__(self,
+ num_timesteps,
+ inputs_per_timestep=None,
+ outputs_per_timestep=None,
+ initializers=None,
+ variance_min=0.0,
+ output_distribution=tfd.Normal,
+ dtype=tf.float32):
+ """Creates a NonstationaryLinearDistribution.
+
+ Args:
+ num_timesteps: The number of timesteps, i.e. the number of distributions.
+ inputs_per_timestep: A list of python ints, the dimension of inputs to the
+ linear function at each timestep. If not provided, the dimension at each
+ timestep is assumed to be 1.
+ outputs_per_timestep: A list of python ints, the dimension of the output
+ distribution at each timestep. If not provided, the dimension at each
+ timestep is assumed to be 1.
+      initializers: A dictionary containing initializers for the variables. The
+ initializer under the key 'w' is used for the weights in the linear
+ function and the initializer under the key 'b' is used for the biases.
+ Defaults to xavier initialization for the weights and zeros for the
+ biases.
+ variance_min: Python float, the minimum variance of each distribution.
+      output_distribution: A location-scale subclass of tfd.Distribution that
+ defines the output distribution, e.g. Normal.
+ dtype: The dtype of the weights and biases.
+ """
+ if not initializers:
+ initializers = DEFAULT_INITIALIZERS
+ if not inputs_per_timestep:
+ inputs_per_timestep = [1] * num_timesteps
+ if not outputs_per_timestep:
+ outputs_per_timestep = [1] * num_timesteps
+ self.num_timesteps = num_timesteps
+ self.variance_min = variance_min
+ self.initializers = initializers
+ self.dtype = dtype
+ self.output_distribution = output_distribution
+
+ def _get_variables_ta(shapes, name, initializer, trainable=True):
+ """Creates a sequence of variables and stores them in a TensorArray."""
+ # Infer shape if all shapes are equal.
+ first_shape = shapes[0]
+ infer_shape = all(shape == first_shape for shape in shapes)
+ ta = tf.TensorArray(
+ dtype=dtype, size=len(shapes), dynamic_size=False,
+ clear_after_read=False, infer_shape=infer_shape)
+ for t, shape in enumerate(shapes):
+ var = tf.get_variable(
+ name % t, shape=shape, initializer=initializer, trainable=trainable)
+ ta = ta.write(t, var)
+ return ta
+
+ bias_shapes = [[num_outputs] for num_outputs in outputs_per_timestep]
+ self.log_variances = _get_variables_ta(
+ bias_shapes, "proposal_log_variance_%d", initializers["b"])
+ self.mean_biases = _get_variables_ta(
+ bias_shapes, "proposal_b_%d", initializers["b"])
+    weight_shapes = list(zip(inputs_per_timestep, outputs_per_timestep))
+ self.mean_weights = _get_variables_ta(
+ weight_shapes, "proposal_w_%d", initializers["w"])
+ self.shapes = tf.TensorArray(
+ dtype=tf.int32, size=num_timesteps,
+ dynamic_size=False, clear_after_read=False).unstack(weight_shapes)
+
+ def __call__(self, t, inputs):
+ """Computes the distribution at timestep t.
+
+ Args:
+ t: Scalar integer Tensor, the current timestep. Must be in
+ [0, num_timesteps).
+ inputs: The inputs to the linear function parameterizing the mean of
+ the current distribution. A Tensor of shape [batch_size, num_inputs_t].
+ Returns:
+ A tfd.Distribution subclass representing the distribution at timestep t.
+ """
+ b = self.mean_biases.read(t)
+ w = self.mean_weights.read(t)
+ shape = self.shapes.read(t)
+ w = tf.reshape(w, shape)
+ b = tf.reshape(b, [shape[1], 1])
+ log_variance = self.log_variances.read(t)
+ scale = tf.sqrt(tf.maximum(tf.exp(log_variance), self.variance_min))
+ loc = tf.matmul(w, inputs, transpose_a=True) + b
+ return self.output_distribution(loc=loc, scale=scale)
+
+
+def encode_all(inputs, encoder):
+ """Encodes a timeseries of inputs with a time independent encoder.
+
+ Args:
+ inputs: A [time, batch, feature_dimensions] tensor.
+ encoder: A network that takes a [batch, features_dimensions] input and
+ encodes the input.
+ Returns:
+ A [time, batch, encoded_feature_dimensions] output tensor.
+ """
+ input_shape = tf.shape(inputs)
+ num_timesteps, batch_size = input_shape[0], input_shape[1]
+ reshaped_inputs = tf.reshape(inputs, [-1, inputs.shape[-1]])
+ inputs_encoded = encoder(reshaped_inputs)
+ inputs_encoded = tf.reshape(inputs_encoded,
+ [num_timesteps, batch_size, encoder.output_size])
+ return inputs_encoded
+
+
+def ta_for_tensor(x, **kwargs):
+ """Creates a TensorArray for the input tensor."""
+ return tf.TensorArray(
+ x.dtype, tf.shape(x)[0], dynamic_size=False, **kwargs).unstack(x)
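
A minimal sketch of using one of these conditional distribution modules directly, assuming the Sonnet 1.x / TF 1.x stack imported above; the sizes and inputs are placeholders.

# Illustrative sketch: a Normal whose parameters come from an MLP over inputs.
import tensorflow as tf
from fivo.models import base

q = base.ConditionalNormalDistribution(size=8, hidden_layer_sizes=[32, 32])
inputs = tf.zeros([16, 10])       # [batch_size, input_dim], placeholder data
dist = q(inputs)                  # tfd.Normal with loc/scale of shape [16, 8]
z = dist.sample()
log_q = tf.reduce_sum(dist.log_prob(z), axis=-1)  # per-example log density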
diff --git a/models/research/fivo/fivo/models/ghmm.py b/models/research/fivo/fivo/models/ghmm.py
new file mode 100644
index 0000000000000000000000000000000000000000..07cf6c50e803383ef5690e8d24010e4706286eb7
--- /dev/null
+++ b/models/research/fivo/fivo/models/ghmm.py
@@ -0,0 +1,483 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A Gaussian hidden markov model.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from fivo.models import base
+
+tfd = tf.contrib.distributions
+
+
+class GaussianHMM(object):
+ """A hidden markov model with 1-D Gaussian latent space and observations.
+
+ This is a hidden markov model where the state and observations are
+ one-dimensional Gaussians. The mean of each latent state is a linear
+ function of the previous latent state, and the mean of each observation
+ is a linear function of the current latent state.
+
+ The description that follows is 0-indexed instead of 1-indexed to make
+ it easier to reason about the parameters passed to the model.
+
+ The parameters of the model are:
+    T: The number of timesteps, latent states, and observations.
+ vz_t, t=0 to T-1: The variance of the latent state at timestep t.
+ vx_t, t=0 to T-1: The variance of the observation at timestep t.
+ wz_t, t=1 to T-1: The weight that defines the latent transition at t.
+ wx_t, t=0 to T-1: The weight that defines the observation function at t.
+
+ There are T vz_t, vx_t, and wx_t but only T-1 wz_t because there are only
+ T-1 transitions in the model.
+
+ Given these parameters, sampling from the model is defined as
+
+ z_0 ~ N(0, vz_0)
+ x_0 | z_0 ~ N(wx_0 * z_0, vx_0)
+ z_1 | z_0 ~ N(wz_1 * z_0, vz_1)
+ x_1 | z_1 ~ N(wx_1 * z_1, vx_1)
+ ...
+ z_{T-1} | z_{T-2} ~ N(wz_{T-1} * z_{T-2}, vz_{T-1})
+ x_{T-1} | z_{T-1} ~ N(wx_{T-1} * z_{T-1}, vx_{T-1}).
+ """
+
+ def __init__(self,
+ num_timesteps,
+ transition_variances=1.,
+ emission_variances=1.,
+ transition_weights=1.,
+ emission_weights=1.,
+ dtype=tf.float32):
+ """Creates a gaussian hidden markov model.
+
+ Args:
+ num_timesteps: A python int, the number of timesteps in the model.
+ transition_variances: The variance of p(z_t | z_t-1). Can be a scalar,
+ setting all variances to be the same, or a Tensor of shape
+ [num_timesteps].
+ emission_variances: The variance of p(x_t | z_t). Can be a scalar,
+ setting all variances to be the same, or a Tensor of shape
+ [num_timesteps].
+ transition_weights: The weight that defines the linear function that
+ produces the mean of z_t given z_{t-1}. Can be a scalar, setting
+ all weights to be the same, or a Tensor of shape [num_timesteps-1].
+ emission_weights: The weight that defines the linear function that
+ produces the mean of x_t given z_t. Can be a scalar, setting
+ all weights to be the same, or a Tensor of shape [num_timesteps].
+ dtype: The datatype of the state.
+ """
+ self.num_timesteps = num_timesteps
+ self.dtype = dtype
+
+ def _expand_param(param, size):
+ param = tf.convert_to_tensor(param, dtype=self.dtype)
+ if not param.get_shape().as_list():
+ param = tf.tile(param[tf.newaxis], [size])
+
+ return param
+
+ def _ta_for_param(param):
+ size = tf.shape(param)[0]
+ ta = tf.TensorArray(dtype=param.dtype,
+ size=size,
+ dynamic_size=False,
+ clear_after_read=False).unstack(param)
+ return ta
+
+ self.transition_variances = _ta_for_param(
+ _expand_param(transition_variances, num_timesteps))
+ self.transition_weights = _ta_for_param(
+ _expand_param(transition_weights, num_timesteps-1))
+ em_var = _expand_param(emission_variances, num_timesteps)
+ self.emission_variances = _ta_for_param(em_var)
+ em_w = _expand_param(emission_weights, num_timesteps)
+ self.emission_weights = _ta_for_param(em_w)
+ self._compute_covariances(em_w, em_var)
+
+ def _compute_covariances(self, emission_weights, emission_variances):
+ """Compute all covariance matrices.
+
+    Computes the covariance matrix for the latent variables, the observations,
+ and the covariance between the latents and observations.
+
+ Args:
+ emission_weights: A Tensor of shape [num_timesteps] containing
+ the emission distribution weights at each timestep.
+ emission_variances: A Tensor of shape [num_timesteps] containing
+        the emission distribution variances at each timestep.
+ """
+ # Compute the marginal variance of each latent.
+ z_variances = [self.transition_variances.read(0)]
+ for i in range(1, self.num_timesteps):
+ z_variances.append(
+ z_variances[i-1] * tf.square(self.transition_weights.read(i-1)) +
+ self.transition_variances.read(i))
+ # Compute the latent covariance matrix.
+ sigma_z = []
+ for i in range(self.num_timesteps):
+ sigma_z_row = []
+ for j in range(self.num_timesteps):
+ if i == j:
+ sigma_z_row.append(z_variances[i])
+ continue
+ min_ind = min(i, j)
+ max_ind = max(i, j)
+ weight = tf.reduce_prod(
+ self.transition_weights.gather(tf.range(min_ind, max_ind)))
+ sigma_z_row.append(z_variances[min_ind] * weight)
+ sigma_z.append(tf.stack(sigma_z_row))
+ self.sigma_z = tf.stack(sigma_z)
+ # Compute the observation covariance matrix.
+ x_weights_outer = tf.einsum("i,j->ij", emission_weights, emission_weights)
+ self.sigma_x = x_weights_outer * self.sigma_z + tf.diag(emission_variances)
+ # Compute the latent - observation covariance matrix.
+    # The first axis indexes latents, the second axis indexes observations.
+ self.sigma_zx = emission_weights[tf.newaxis, :] * self.sigma_z
+ self.obs_dist = tfd.MultivariateNormalFullCovariance(
+ loc=tf.zeros([self.num_timesteps], dtype=tf.float32),
+ covariance_matrix=self.sigma_x)
+
+ def transition(self, t, z_prev):
+ """Compute the transition distribution p(z_t | z_t-1).
+
+ Args:
+ t: The current timestep, a scalar integer Tensor. When t=0 z_prev is
+ mostly ignored and the distribution p(z_0) is returned. z_prev is
+ 'mostly' ignored because it is still used to derive batch_size.
+ z_prev: A [batch_size] set of states.
+ Returns:
+ p(z_t | z_t-1) as a univariate normal distribution.
+ """
+ batch_size = tf.shape(z_prev)[0]
+ scale = tf.sqrt(self.transition_variances.read(t))
+ scale = tf.tile(scale[tf.newaxis], [batch_size])
+ loc = tf.cond(tf.greater(t, 0),
+ lambda: self.transition_weights.read(t-1)*z_prev,
+ lambda: tf.zeros_like(scale))
+ return tfd.Normal(loc=loc, scale=scale)
+
+ def emission(self, t, z):
+ """Compute the emission distribution p(x_t | z_t).
+
+ Args:
+ t: The current timestep, a scalar integer Tensor.
+ z: A [batch_size] set of the current states.
+ Returns:
+ p(x_t | z_t) as a univariate normal distribution.
+ """
+ batch_size = tf.shape(z)[0]
+ scale = tf.sqrt(self.emission_variances.read(t))
+ scale = tf.tile(scale[tf.newaxis], [batch_size])
+ loc = self.emission_weights.read(t)*z
+ return tfd.Normal(loc=loc, scale=scale)
+
+ def filtering(self, t, z_prev, x_cur):
+ """Computes the filtering distribution p(z_t | z_{t-1}, x_t).
+
+ Args:
+ t: A python int, the index for z_t. When t is 0, z_prev is ignored,
+ giving p(z_0 | x_0).
+ z_prev: z_{t-1}, the previous z to condition on. A Tensor of shape
+ [batch_size].
+ x_cur: x_t, the current x to condition on. A Tensor of shape [batch_size].
+ Returns:
+ p(z_t | z_{t-1}, x_t) as a univariate normal distribution.
+ """
+ z_prev = tf.convert_to_tensor(z_prev)
+ x_cur = tf.convert_to_tensor(x_cur)
+ batch_size = tf.shape(z_prev)[0]
+ z_var = self.transition_variances.read(t)
+ x_var = self.emission_variances.read(t)
+ x_weight = self.emission_weights.read(t)
+ prev_state_weight = x_var/(tf.square(x_weight)*z_var + x_var)
+ prev_state_weight *= tf.cond(tf.greater(t, 0),
+ lambda: self.transition_weights.read(t-1),
+ lambda: tf.zeros_like(prev_state_weight))
+ cur_obs_weight = (x_weight*z_var)/(tf.square(x_weight)*z_var + x_var)
+ loc = prev_state_weight*z_prev + cur_obs_weight*x_cur
+ scale = tf.sqrt((z_var*x_var)/(tf.square(x_weight)*z_var + x_var))
+ scale = tf.tile(scale[tf.newaxis], [batch_size])
+ return tfd.Normal(loc=loc, scale=scale)
+
+ def smoothing(self, t, z_prev, xs):
+ """Computes the smoothing distribution p(z_t | z_{t-1}, x_{t:num_timesteps).
+
+ Args:
+ t: A python int, the index for z_t. When t is 0, z_prev is ignored,
+ giving p(z_0 | x_{0:num_timesteps-1}).
+ z_prev: z_{t-1}, the previous z to condition on. A Tensor of shape
+ [batch_size].
+ xs: x_{t:num_timesteps}, the future xs to condition on. A Tensor of shape
+ [num_timesteps - t, batch_size].
+ Returns:
+ p(z_t | z_{t-1}, x_{t:num_timesteps}) as a univariate normal distribution.
+ """
+ xs = tf.convert_to_tensor(xs)
+ z_prev = tf.convert_to_tensor(z_prev)
+ batch_size = tf.shape(xs)[1]
+ mess_mean, mess_prec = tf.cond(
+ tf.less(t, self.num_timesteps-1),
+ lambda: tf.unstack(self._compute_backwards_messages(xs[1:]).read(0)),
+ lambda: [tf.zeros([batch_size]), tf.zeros([batch_size])])
+ return self._smoothing_from_message(t, z_prev, xs[0], mess_mean, mess_prec)
+
+ def _smoothing_from_message(self, t, z_prev, x_t, mess_mean, mess_prec):
+ """Computes the smoothing distribution given message incoming to z_t.
+
+ Computes p(z_t | z_{t-1}, x_{t:num_timesteps}) given the message incoming
+ to the node for z_t.
+
+ Args:
+ t: A python int, the index for z_t. When t is 0, z_prev is ignored.
+ z_prev: z_{t-1}, the previous z to condition on. A Tensor of shape
+ [batch_size].
+ x_t: The observation x at timestep t.
+ mess_mean: The mean of the message incoming to z_t, in information form.
+ mess_prec: The precision of the message incoming to z_t.
+ Returns:
+ p(z_t | z_{t-1}, x_{t:num_timesteps}) as a univariate normal distribution.
+ """
+
+ batch_size = tf.shape(x_t)[0]
+ z_var = self.transition_variances.read(t)
+ x_var = self.emission_variances.read(t)
+ w_x = self.emission_weights.read(t)
+
+ def transition_term():
+ return (tf.square(self.transition_weights.read(t))/
+ self.transition_variances.read(t+1))
+
+ prec = 1./z_var + tf.square(w_x)/x_var + mess_prec
+ prec += tf.cond(tf.less(t, self.num_timesteps-1),
+ transition_term, lambda: 0.)
+ mean = x_t*(w_x/x_var) + mess_mean
+ mean += tf.cond(tf.greater(t, 0),
+ lambda: z_prev*(self.transition_weights.read(t-1)/z_var),
+ lambda: 0.)
+ mean = tf.reshape(mean / prec, [batch_size])
+ scale = tf.reshape(tf.sqrt(1./prec), [batch_size])
+ return tfd.Normal(loc=mean, scale=scale)
+
+ def _compute_backwards_messages(self, xs):
+ """Computes the backwards messages used in smoothing."""
+ batch_size = tf.shape(xs)[1]
+ num_xs = tf.shape(xs)[0]
+ until_t = self.num_timesteps - num_xs
+ xs = tf.TensorArray(dtype=xs.dtype,
+ size=num_xs,
+ dynamic_size=False,
+ clear_after_read=True).unstack(xs)
+ messages_ta = tf.TensorArray(dtype=xs.dtype,
+ size=num_xs,
+ dynamic_size=False,
+ clear_after_read=False)
+
+ def compute_message(t, prev_mean, prev_prec, messages_ta):
+ """Computes one step of the backwards messages."""
+ z_var = self.transition_variances.read(t)
+ w_z = self.transition_weights.read(t-1)
+ x_var = self.emission_variances.read(t)
+ w_x = self.emission_weights.read(t)
+ cur_x = xs.read(t - until_t)
+
+ # If it isn't the first message, add the terms from the transition.
+ def transition_term():
+ return (tf.square(self.transition_weights.read(t))/
+ self.transition_variances.read(t+1))
+
+ unary_prec = 1/z_var + tf.square(w_x)/x_var
+ unary_prec += tf.cond(tf.less(t, self.num_timesteps-1),
+ transition_term, lambda: 0.)
+
+ unary_mean = (w_x / x_var) * cur_x
+ pairwise_prec = w_z / z_var
+
+ next_prec = -tf.square(pairwise_prec)/(unary_prec + prev_prec)
+ next_mean = (pairwise_prec * (unary_mean + prev_mean) /
+ (unary_prec + prev_prec))
+ next_prec = tf.reshape(next_prec, [batch_size])
+ next_mean = tf.reshape(next_mean, [batch_size])
+ messages_ta = messages_ta.write(t - until_t,
+ tf.stack([next_mean, next_prec]))
+ return t-1, next_mean, next_prec, messages_ta
+
+ def pred(t, *unused_args):
+ return tf.greater_equal(t, until_t)
+
+ init_prec = tf.zeros([batch_size], dtype=xs.dtype)
+ init_mean = tf.zeros([batch_size], dtype=xs.dtype)
+ t0 = tf.constant(self.num_timesteps - 1, dtype=tf.int32)
+
+ outs = tf.while_loop(pred, compute_message,
+ (t0, init_mean, init_prec, messages_ta))
+ messages = outs[-1]
+ return messages
+
+ def lookahead(self, t, z_prev):
+ """Compute the 'lookahead' distribution, p(x_{t:T} | z_{t-1}).
+
+ Args:
+ t: A scalar Tensor int, the current timestep. Must be at least 1.
+ z_prev: The latent state at time t-1. A Tensor of shape [batch_size].
+ Returns:
+ p(x_{t:T} | z_{t-1}) as a multivariate normal distribution.
+ """
+ z_prev = tf.convert_to_tensor(z_prev)
+ sigma_zx = self.sigma_zx[t-1, t:]
+ z_var = self.sigma_z[t-1, t-1]
+ mean = tf.einsum("i,j->ij", z_prev, sigma_zx) / z_var
+ variance = (self.sigma_x[t:, t:] -
+ tf.einsum("i,j->ij", sigma_zx, sigma_zx) / z_var)
+ return tfd.MultivariateNormalFullCovariance(
+ loc=mean, covariance_matrix=variance)
+
+ def likelihood(self, xs):
+ """Compute the true marginal likelihood of the data.
+
+ Args:
+ xs: The observations, a [num_timesteps, batch_size] float Tensor.
+ Returns:
+ likelihoods: A [batch_size] float Tensor representing the likelihood of
+ each sequence of observations in the batch.
+ """
+ return self.obs_dist.log_prob(tf.transpose(xs))
+
+
+class TrainableGaussianHMM(GaussianHMM, base.ELBOTrainableSequenceModel):
+ """An interface between importance-sampling training methods and the GHMM."""
+
+ def __init__(self,
+ num_timesteps,
+ proposal_type,
+ transition_variances=1.,
+ emission_variances=1.,
+ transition_weights=1.,
+ emission_weights=1.,
+ random_seed=None,
+ dtype=tf.float32):
+ """Constructs a trainable Gaussian HMM.
+
+ Args:
+ num_timesteps: A python int, the number of timesteps in the model.
+ proposal_type: The type of proposal to use in the importance sampling
+ setup. Could be "filtering", "smoothing", "prior", "true-filtering",
+ or "true-smoothing". If "true-filtering" or "true-smoothing" are
+ selected, then the true filtering or smoothing distributions are used to
+ propose new states. If "learned-filtering" is selected then a
+ distribution with learnable parameters is used. Specifically at each
+ timestep the proposal is Gaussian with mean that is a learnable linear
+ function of the previous state and current observation. The log variance
+ is a per-timestep learnable constant. "learned-smoothing" is similar,
+ but the mean is a learnable linear function of the previous state and
+ all future observations. Note that this proposal class includes the true
+ posterior. If "prior" is selected then states are proposed from the
+ model's prior.
+ transition_variances: The variance of p(z_t | z_t-1). Can be a scalar,
+ setting all variances to be the same, or a Tensor of shape
+ [num_timesteps].
+ emission_variances: The variance of p(x_t | z_t). Can be a scalar,
+ setting all variances to be the same, or a Tensor of shape
+ [num_timesteps].
+ transition_weights: The weight that defines the linear function that
+ produces the mean of z_t given z_{t-1}. Can be a scalar, setting
+ all weights to be the same, or a Tensor of shape [num_timesteps-1].
+ emission_weights: The weight that defines the linear function that
+ produces the mean of x_t given z_t. Can be a scalar, setting
+ all weights to be the same, or a Tensor of shape [num_timesteps].
+ random_seed: A seed for the proposal sampling, mainly useful for testing.
+ dtype: The datatype of the state.
+ """
+ super(TrainableGaussianHMM, self).__init__(
+ num_timesteps, transition_variances, emission_variances,
+ transition_weights, emission_weights, dtype=dtype)
+ self.random_seed = random_seed
+ assert proposal_type in ["filtering", "smoothing", "prior",
+ "true-filtering", "true-smoothing"]
+ if proposal_type == "true-filtering":
+ self.proposal = self._filtering_proposal
+ elif proposal_type == "true-smoothing":
+ self.proposal = self._smoothing_proposal
+ elif proposal_type == "prior":
+ self.proposal = self.transition
+ elif proposal_type == "filtering":
+ self._learned_proposal_fn = base.NonstationaryLinearDistribution(
+ num_timesteps, inputs_per_timestep=[1] + [2] * (num_timesteps-1))
+ self.proposal = self._learned_filtering_proposal
+ elif proposal_type == "smoothing":
+ inputs_per_timestep = [num_timesteps] + [num_timesteps - t
+ for t in range(num_timesteps-1)]
+ self._learned_proposal_fn = base.NonstationaryLinearDistribution(
+ num_timesteps, inputs_per_timestep=inputs_per_timestep)
+ self.proposal = self._learned_smoothing_proposal
+
+ def set_observations(self, xs, seq_lengths):
+ """Sets the observations and stores the backwards messages."""
+ # Squeeze out data dimension since everything is 1-d.
+ xs = tf.squeeze(xs)
+ self.batch_size = tf.shape(xs)[1]
+ super(TrainableGaussianHMM, self).set_observations(xs, seq_lengths)
+ self.messages = self._compute_backwards_messages(xs[1:])
+
+ def zero_state(self, batch_size, dtype):
+ return tf.zeros([batch_size], dtype=dtype)
+
+ def propose_and_weight(self, state, t):
+ """Computes the next state and log weights for the GHMM."""
+ state_shape = tf.shape(state)
+ xt = self.observations[t]
+ p_zt = self.transition(t, state)
+ q_zt = self.proposal(t, state)
+ zt = q_zt.sample(seed=self.random_seed)
+ zt = tf.reshape(zt, state_shape)
+ p_xt_given_zt = self.emission(t, zt)
+ log_p_zt = p_zt.log_prob(zt)
+ log_q_zt = q_zt.log_prob(zt)
+ log_p_xt_given_zt = p_xt_given_zt.log_prob(xt)
+ weight = log_p_zt + log_p_xt_given_zt - log_q_zt
+ return weight, zt
+
+ def _filtering_proposal(self, t, state):
+ """Uses the stored observations to compute the filtering distribution."""
+ cur_x = self.observations[t]
+ return self.filtering(t, state, cur_x)
+
+ def _smoothing_proposal(self, t, state):
+ """Uses the stored messages to compute the smoothing distribution."""
+ mess_mean, mess_prec = tf.cond(
+ tf.less(t, self.num_timesteps-1),
+ lambda: tf.unstack(self.messages.read(t)),
+ lambda: [tf.zeros([self.batch_size]), tf.zeros([self.batch_size])])
+ return self._smoothing_from_message(t, state, self.observations[t],
+ mess_mean, mess_prec)
+
+ def _learned_filtering_proposal(self, t, state):
+ cur_x = self.observations[t]
+ inputs = tf.cond(tf.greater(t, 0),
+ lambda: tf.stack([state, cur_x], axis=0),
+ lambda: cur_x[tf.newaxis, :])
+ return self._learned_proposal_fn(t, inputs)
+
+ def _learned_smoothing_proposal(self, t, state):
+ xs = self.observations_ta.gather(tf.range(t, self.num_timesteps))
+ inputs = tf.cond(tf.greater(t, 0),
+ lambda: tf.concat([state[tf.newaxis, :], xs], axis=0),
+ lambda: xs)
+ return self._learned_proposal_fn(t, inputs)
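+
+
+# A minimal usage sketch (illustrative only, not part of the original file):
+# build a TrainableGaussianHMM that proposes from the true filtering
+# distribution and compute one step of importance weights. The tensor values
+# below are assumptions chosen for shape clarity; see ghmm_test.py for the
+# exact golden-value tests.
+#
+#   ghmm = TrainableGaussianHMM(3, "true-filtering",
+#                               transition_variances=[1., 2., 3.],
+#                               emission_variances=[4., 5., 6.])
+#   xs = tf.constant([[1., 2.], [3., 4.], [5., 6.]])  # [num_timesteps, batch]
+#   ghmm.set_observations(xs, seq_lengths=[3, 3])
+#   state = ghmm.zero_state(batch_size=2, dtype=tf.float32)
+#   log_weight, next_state = ghmm.propose_and_weight(state, 0)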
diff --git a/models/research/fivo/fivo/models/ghmm_test.py b/models/research/fivo/fivo/models/ghmm_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..15a03c0c7abeae09bd1cfc87f917ef53ecac205f
--- /dev/null
+++ b/models/research/fivo/fivo/models/ghmm_test.py
@@ -0,0 +1,313 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for fivo.models.ghmm"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from fivo.models.ghmm import GaussianHMM
+from fivo.models.ghmm import TrainableGaussianHMM
+
+
+class GHMMTest(tf.test.TestCase):
+
+ def test_transition_no_weights(self):
+ with self.test_session() as sess:
+ ghmm = GaussianHMM(3,
+ transition_variances=[1., 2., 3.])
+ prev_z = tf.constant([1., 2.], dtype=tf.float32)
+ z0 = ghmm.transition(0, prev_z)
+ z1 = ghmm.transition(1, prev_z)
+ z2 = ghmm.transition(2, prev_z)
+ outs = sess.run([z0.mean(), z0.variance(),
+ z1.mean(), z1.variance(),
+ z2.mean(), z2.variance()])
+ self.assertAllClose(outs, [[0., 0.], [1., 1.],
+ [1., 2.], [2., 2.],
+ [1., 2.], [3., 3.]])
+
+ def test_transition_with_weights(self):
+ with self.test_session() as sess:
+ ghmm = GaussianHMM(3,
+ transition_variances=[1., 2., 3.],
+ transition_weights=[2., 3.])
+ prev_z = tf.constant([1., 2.], dtype=tf.float32)
+ z0 = ghmm.transition(0, prev_z)
+ z1 = ghmm.transition(1, prev_z)
+ z2 = ghmm.transition(2, prev_z)
+ outs = sess.run([z0.mean(), z0.variance(),
+ z1.mean(), z1.variance(),
+ z2.mean(), z2.variance()])
+ self.assertAllClose(outs, [[0., 0.], [1., 1.],
+ [2., 4.], [2., 2.],
+ [3., 6.], [3., 3.]])
+
+ def test_emission_no_weights(self):
+ with self.test_session() as sess:
+ ghmm = GaussianHMM(3, emission_variances=[1., 2., 3.])
+ z = tf.constant([1., 2.], dtype=tf.float32)
+ x0 = ghmm.emission(0, z)
+ x1 = ghmm.emission(1, z)
+ x2 = ghmm.emission(2, z)
+ outs = sess.run([x0.mean(), x0.variance(),
+ x1.mean(), x1.variance(),
+ x2.mean(), x2.variance()])
+ self.assertAllClose(outs, [[1., 2.], [1., 1.],
+ [1., 2.], [2., 2.],
+ [1., 2.], [3., 3.]])
+
+ def test_emission_with_weights(self):
+ with self.test_session() as sess:
+ ghmm = GaussianHMM(3,
+ emission_variances=[1., 2., 3.],
+ emission_weights=[1., 2., 3.])
+ z = tf.constant([1., 2.], dtype=tf.float32)
+ x0 = ghmm.emission(0, z)
+ x1 = ghmm.emission(1, z)
+ x2 = ghmm.emission(2, z)
+ outs = sess.run([x0.mean(), x0.variance(),
+ x1.mean(), x1.variance(),
+ x2.mean(), x2.variance()])
+ self.assertAllClose(outs, [[1., 2.], [1., 1.],
+ [2., 4.], [2., 2.],
+ [3., 6.], [3., 3.]])
+
+ def test_filtering_no_weights(self):
+ with self.test_session() as sess:
+ ghmm = GaussianHMM(3,
+ transition_variances=[1., 2., 3.],
+ emission_variances=[4., 5., 6.])
+ z_prev = tf.constant([1., 2.], dtype=tf.float32)
+ x_cur = tf.constant([3., 4.], dtype=tf.float32)
+ expected_outs = [[[3./5., 4./5.], [4./5., 4./5.]],
+ [[11./7., 18./7.], [10./7., 10./7.]],
+ [[5./3., 8./3.], [2., 2.]]]
+ f_post_0 = ghmm.filtering(0, z_prev, x_cur)
+ f_post_1 = ghmm.filtering(1, z_prev, x_cur)
+ f_post_2 = ghmm.filtering(2, z_prev, x_cur)
+ outs = sess.run([[f_post_0.mean(), f_post_0.variance()],
+ [f_post_1.mean(), f_post_1.variance()],
+ [f_post_2.mean(), f_post_2.variance()]])
+ self.assertAllClose(expected_outs, outs)
+
+ def test_filtering_with_weights(self):
+ with self.test_session() as sess:
+ ghmm = GaussianHMM(3,
+ transition_variances=[1., 2., 3.],
+ emission_variances=[4., 5., 6.],
+ transition_weights=[7., 8.],
+ emission_weights=[9., 10., 11])
+ z_prev = tf.constant([1., 2.], dtype=tf.float32)
+ x_cur = tf.constant([3., 4.], dtype=tf.float32)
+ expected_outs = [[[27./85., 36./85.], [4./85., 4./85.]],
+ [[95./205., 150./205.], [10./205., 10./205.]],
+ [[147./369., 228./369.], [18./369., 18./369.]]]
+ f_post_0 = ghmm.filtering(0, z_prev, x_cur)
+ f_post_1 = ghmm.filtering(1, z_prev, x_cur)
+ f_post_2 = ghmm.filtering(2, z_prev, x_cur)
+ outs = sess.run([[f_post_0.mean(), f_post_0.variance()],
+ [f_post_1.mean(), f_post_1.variance()],
+ [f_post_2.mean(), f_post_2.variance()]])
+ self.assertAllClose(expected_outs, outs)
+
+ def test_smoothing(self):
+ with self.test_session() as sess:
+ ghmm = GaussianHMM(3,
+ transition_variances=[1., 2., 3.],
+ emission_variances=[4., 5., 6.])
+ z_prev = tf.constant([1., 2.], dtype=tf.float32)
+ xs = tf.constant([[1., 2.],
+ [3., 4.],
+ [5., 6.]], dtype=tf.float32)
+ s_post1 = ghmm.smoothing(0, z_prev, xs)
+ outs = sess.run([s_post1.mean(), s_post1.variance()])
+ expected_outs = [[281./421., 410./421.], [292./421., 292./421.]]
+ self.assertAllClose(expected_outs, outs)
+
+ expected_outs = [[149./73., 222./73.], [90./73., 90./73.]]
+ s_post2 = ghmm.smoothing(1, z_prev, xs[1:])
+ outs = sess.run([s_post2.mean(), s_post2.variance()])
+ self.assertAllClose(expected_outs, outs)
+
+ s_post3 = ghmm.smoothing(2, z_prev, xs[2:])
+ outs = sess.run([s_post3.mean(), s_post3.variance()])
+ expected_outs = [[7./3., 10./3.], [2., 2.]]
+ self.assertAllClose(expected_outs, outs)
+
+ def test_smoothing_with_weights(self):
+ with self.test_session() as sess:
+ x_weight = np.array([4, 5, 6, 7], dtype=np.float32)
+ sigma_x = np.array([5, 6, 7, 8], dtype=np.float32)
+ z_weight = np.array([1, 2, 3], dtype=np.float32)
+ sigma_z = np.array([1, 2, 3, 4], dtype=np.float32)
+ z_prev = np.array([1, 2], dtype=np.float32)
+ batch_size = 2
+ xs = np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=np.float32)
+
+ z_cov, x_cov, z_x_cov = self._compute_covariance_matrices(
+ x_weight, z_weight, sigma_x, sigma_z)
+
+ expected_outs = []
+ # Compute mean and variance for z_0 when we don't condition
+ # on previous zs.
+ sigma_12 = z_x_cov[0, :]
+ sigma_12_22 = np.dot(sigma_12, np.linalg.inv(x_cov))
+ mean = np.dot(sigma_12_22, xs)
+ variance = np.squeeze(z_cov[0, 0] - np.dot(sigma_12_22, sigma_12))
+ expected_outs.append([mean, np.tile(variance, [batch_size])])
+
+ # Compute mean and variance for remaining z_ts.
+ for t in range(1, 4):
+ sigma_12 = np.concatenate([[z_cov[t, t - 1]], z_x_cov[t, t:]])
+ sigma_22 = np.vstack((
+ np.hstack((z_cov[t-1, t-1], z_x_cov[t-1, t:])),
+ np.hstack((np.transpose([z_x_cov[t-1, t:]]), x_cov[t:, t:]))
+ ))
+ sigma_12_22 = np.dot(sigma_12, np.linalg.inv(sigma_22))
+ mean = np.dot(sigma_12_22, np.vstack((z_prev, xs[t:])))
+ variance = np.squeeze(z_cov[t, t] - np.dot(sigma_12_22, sigma_12))
+ expected_outs.append([mean, np.tile(variance, [batch_size])])
+
+ ghmm = GaussianHMM(4,
+ transition_variances=sigma_z,
+ emission_variances=sigma_x,
+ transition_weights=z_weight,
+ emission_weights=x_weight)
+ out_dists = [ghmm.smoothing(t, z_prev, xs[t:]) for t in range(0, 4)]
+ outs = [[d.mean(), d.variance()] for d in out_dists]
+ run_outs = sess.run(outs)
+ self.assertAllClose(expected_outs, run_outs)
+
+ def test_covariance_matrices(self):
+ with self.test_session() as sess:
+ x_weight = np.array([4, 5, 6, 7], dtype=np.float32)
+ sigma_x = np.array([5, 6, 7, 8], dtype=np.float32)
+ z_weight = np.array([1, 2, 3], dtype=np.float32)
+ sigma_z = np.array([1, 2, 3, 4], dtype=np.float32)
+
+ z_cov, x_cov, z_x_cov = self._compute_covariance_matrices(
+ x_weight, z_weight, sigma_x, sigma_z)
+
+ ghmm = GaussianHMM(4,
+ transition_variances=sigma_z,
+ emission_variances=sigma_x,
+ transition_weights=z_weight,
+ emission_weights=x_weight)
+ self.assertAllClose(z_cov, sess.run(ghmm.sigma_z))
+ self.assertAllClose(x_cov, sess.run(ghmm.sigma_x))
+ self.assertAllClose(z_x_cov, sess.run(ghmm.sigma_zx))
+
+ def _compute_covariance_matrices(self, x_weight, z_weight, sigma_x, sigma_z):
+ # Create z covariance matrix from the definitions.
+ z_cov = np.zeros([4, 4])
+ z_cov[0, 0] = sigma_z[0]
+ for i in range(1, 4):
+ z_cov[i, i] = (z_cov[i - 1, i - 1] * np.square(z_weight[i - 1]) +
+ sigma_z[i])
+ for i in range(4):
+ for j in range(4):
+ if i == j: continue
+ min_ind = min(i, j)
+ max_ind = max(i, j)
+ weights = np.prod(z_weight[min_ind:max_ind])
+ z_cov[i, j] = z_cov[min_ind, min_ind] * weights
+ # Compute the x covariance matrix and the z-x covariance matrix.
+ x_weights_outer = np.outer(x_weight, x_weight)
+ x_cov = x_weights_outer * z_cov + np.diag(sigma_x)
+ z_x_cov = x_weight * z_cov
+ return z_cov, x_cov, z_x_cov
+
+ def test_lookahead(self):
+ x_weight = np.array([4, 5, 6, 7], dtype=np.float32)
+ sigma_x = np.array([5, 6, 7, 8], dtype=np.float32)
+ z_weight = np.array([1, 2, 3], dtype=np.float32)
+ sigma_z = np.array([1, 2, 3, 4], dtype=np.float32)
+ z_prev = np.array([1, 2], dtype=np.float32)
+
+ with self.test_session() as sess:
+ z_cov, x_cov, z_x_cov = self._compute_covariance_matrices(
+ x_weight, z_weight, sigma_x, sigma_z)
+
+ expected_outs = []
+ for t in range(1, 4):
+ sigma_12 = z_x_cov[t-1, t:]
+ z_var = z_cov[t-1, t-1]
+ mean = np.outer(z_prev, sigma_12/z_var)
+ variance = x_cov[t:, t:] - np.outer(sigma_12, sigma_12)/ z_var
+ expected_outs.append([mean, variance])
+
+ ghmm = GaussianHMM(4,
+ transition_variances=sigma_z,
+ emission_variances=sigma_x,
+ transition_weights=z_weight,
+ emission_weights=x_weight)
+ out_dists = [ghmm.lookahead(t, z_prev) for t in range(1, 4)]
+ outs = [[d.mean(), d.covariance()] for d in out_dists]
+ run_outs = sess.run(outs)
+ self.assertAllClose(expected_outs, run_outs)
+
+
+class TrainableGHMMTest(tf.test.TestCase):
+
+ def test_filtering_proposal(self):
+ """Check that stashing the xs doesn't change the filtering distributions."""
+ with self.test_session() as sess:
+ ghmm = TrainableGaussianHMM(
+ 3, "filtering",
+ transition_variances=[1., 2., 3.],
+ emission_variances=[4., 5., 6.],
+ transition_weights=[7., 8.],
+ emission_weights=[9., 10., 11])
+ observations = tf.constant([[3., 4.],
+ [3., 4.],
+ [3., 4.]], dtype=tf.float32)
+ ghmm.set_observations(observations, [3, 3])
+ z_prev = tf.constant([1., 2.], dtype=tf.float32)
+
+ proposals = [ghmm._filtering_proposal(t, z_prev) for t in range(3)]
+ dist_params = [[p.mean(), p.variance()] for p in proposals]
+
+ expected_outs = [[[27./85., 36./85.], [4./85., 4./85.]],
+ [[95./205., 150./205.], [10./205., 10./205.]],
+ [[147./369., 228./369.], [18./369., 18./369.]]]
+ self.assertAllClose(expected_outs, sess.run(dist_params))
+
+ def test_smoothing_proposal(self):
+ with self.test_session() as sess:
+ ghmm = TrainableGaussianHMM(
+ 3, "smoothing",
+ transition_variances=[1., 2., 3.],
+ emission_variances=[4., 5., 6.])
+ xs = tf.constant([[1., 2.],
+ [3., 4.],
+ [5., 6.]], dtype=tf.float32)
+ ghmm.set_observations(xs, [3, 3])
+ z_prev = tf.constant([1., 2.], dtype=tf.float32)
+
+ proposals = [ghmm._smoothing_proposal(t, z_prev) for t in range(3)]
+ dist_params = [[p.mean(), p.variance()] for p in proposals]
+
+ expected_outs = [[[281./421., 410./421.], [292./421., 292./421.]],
+ [[149./73., 222./73.], [90./73., 90./73.]],
+ [[7./3., 10./3.], [2., 2.]]]
+ self.assertAllClose(expected_outs, sess.run(dist_params))
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/research/fivo/fivo/models/srnn.py b/models/research/fivo/fivo/models/srnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdfb560eedffccf8edf41dbab4e85bbd8bbfab46
--- /dev/null
+++ b/models/research/fivo/fivo/models/srnn.py
@@ -0,0 +1,587 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""SRNN classes."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import namedtuple
+import functools
+
+import sonnet as snt
+import tensorflow as tf
+
+from fivo.models import base
+
+
+SRNNState = namedtuple("SRNNState", "rnn_state latent_encoded")
+
+
+class SRNN(object):
+ """Implementation of a Stochastic Recurrent Neural Network (SRNN).
+
+ Introduced in "Sequential Neural Models with Stochastic Layers"
+ by Fraccaro et al. https://arxiv.org/pdf/1605.07571.pdf.
+
+ The SRNN is a sequence model similar to an RNN that uses stochastic latent
+ variables to improve its representational power. It can be thought of as a
+ sequential analogue to the variational auto-encoder (VAE).
+
+ The SRNN has a deterministic RNN as its backbone, represented by the
+ sequence of RNN hidden states h_t. The latent state is conditioned on
+ the deterministic RNN states and previous latent state. Unlike the VRNN,
+ the RNN state is not conditioned on the previous latent state. The latent
+ states have a Markov structure and it is assumed that
+ p(z_t | z_{1:t-1}) = p(z_t | z_{t-1}).
+
+ In this implementation of the SRNN the latent state z_t is Gaussian. The
+ model's prior over z_t (also called the transition distribution) is
+ distributed as Normal(mu_t, diag(sigma_t^2)) where mu_t and sigma_t are the
+ mean and standard deviation output from a fully connected network that accepts
+ the rnn hidden state h_t and previous latent state z_{t-1} as input.
+
+ The emission distribution p(x_t|z_t, h_t) is conditioned on the latent state
+ z_t as well as the current RNN hidden state h_t via a fully connected network.
+
+ To increase the modeling power of the SRNN, two additional networks are
+ used to extract features from the data and the latent state. Those networks
+ are called data_encoder and latent_encoder respectively.
+
+ For an example of how to call the SRNN's methods see sample_step.
+
+ There are a few differences between this exposition and the paper. The main
+ goal was to be consistent with the VRNN code. A few components are renamed.
+ The backward RNN for approximating the posterior, g_phi_a in the paper, is the
+ rev_rnn_cell. The forward RNN that conditions the latent distribution, d in
+ the paper, is the rnn_cell. The paper doesn't name the NNs that serve as
+ feature extractors, and we name them here as the data_encoder and
+ latent_encoder.
+ """
+
+ def __init__(self,
+ rnn_cell,
+ data_encoder,
+ latent_encoder,
+ transition,
+ emission,
+ random_seed=None):
+ """Create a SRNN.
+
+ Args:
+ rnn_cell: A subclass of tf.nn.rnn_cell.RNNCell that will form the
+ deterministic backbone of the SRNN. The inputs to the RNN will be
+ the encoded input of the current timestep, a Tensor of shape
+ [batch_size, encoded_data_size].
+ data_encoder: A callable that accepts a batch of data x_t and
+ 'encodes' it, e.g. runs it through a fully connected network. Must
+ accept as argument the inputs x_t, a Tensor of the shape
+ [batch_size, data_size] and return a Tensor of shape
+ [batch_size, encoded_data_size]. This callable will be called multiple
+ times in the SRNN cell so if scoping is not handled correctly then
+ multiple copies of the variables in this network could be made. It is
+ recommended to use a snt.nets.MLP module, which takes care of this for
+ you.
+ latent_encoder: A callable that accepts a latent state z_t and
+ 'encodes' it, e.g. runs it through a fully connected network. Must
+ accept as argument a Tensor of shape [batch_size, latent_size] and
+ return a Tensor of shape [batch_size, encoded_latent_size].
+ This callable must also have the property 'output_size' defined,
+ returning encoded_latent_size.
+ transition: A callable that implements the transition distribution
+ p(z_t|h_t, z_{t-1}). Must accept as argument the previous RNN hidden state
+ and previous encoded latent state then return a tf.distributions.Normal
+ distribution conditioned on the input.
+ emission: A callable that implements the emission distribution
+ p(x_t|z_t, h_t). Must accept as arguments the encoded latent state
+ and the RNN hidden state and return a subclass of
+ tf.distributions.Distribution that can be used to evaluate the logprob
+ of the targets.
+ random_seed: The seed for the random ops. Sets the seed for sample_step.
+ """
+ self.random_seed = random_seed
+ self.rnn_cell = rnn_cell
+ self.data_encoder = data_encoder
+ self.latent_encoder = latent_encoder
+ self.encoded_z_size = latent_encoder.output_size
+ self.state_size = self.rnn_cell.state_size
+ self._transition = transition
+ self._emission = emission
+
+ def zero_state(self, batch_size, dtype):
+ """The initial state of the SRNN.
+
+ Contains the initial state of the RNN and the initial encoded latent.
+
+ Args:
+ batch_size: The batch size.
+ dtype: The data type of the SRNN.
+ Returns:
+ zero_state: The initial state of the SRNN.
+ """
+ return SRNNState(
+ rnn_state=self.rnn_cell.zero_state(batch_size, dtype),
+ latent_encoded=tf.zeros(
+ [batch_size, self.latent_encoder.output_size], dtype=dtype))
+
+ def run_rnn(self, prev_rnn_state, inputs):
+ """Runs the deterministic RNN for one step.
+
+ Args:
+ prev_rnn_state: The state of the RNN from the previous timestep.
+ inputs: A Tensor of shape [batch_size, data_size], the current inputs to
+ the model. Most often this is x_{t-1}, the previous token in the
+ observation sequence.
+ Returns:
+ rnn_out: The output of the RNN.
+ rnn_state: The new state of the RNN.
+ """
+ rnn_inputs = self.data_encoder(tf.to_float(inputs))
+ rnn_out, rnn_state = self.rnn_cell(rnn_inputs, prev_rnn_state)
+ return rnn_out, rnn_state
+
+ def transition(self, rnn_out, prev_latent_encoded):
+ """Computes the transition distribution p(z_t|h_t, z_{t-1}).
+
+ Note that p(z_t | h_t, z_{t-1}) = p(z_t| z_{t-1}, x_{1:t-1})
+
+ Args:
+ rnn_out: The output of the rnn for the current timestep.
+ prev_latent_encoded: Float Tensor of shape
+ [batch_size, encoded_latent_size], the previous latent state z_{t-1}
+ run through latent_encoder.
+ Returns:
+ p(z_t | h_t): A normal distribution with event shape
+ [batch_size, latent_size].
+ """
+ return self._transition(rnn_out, prev_latent_encoded)
+
+ def emission(self, latent, rnn_out):
+ """Computes the emission distribution p(x_t | z_t, h_t).
+
+ Note that p(x_t | z_t, h_t) = p(x_t | z_t, x_{1:t-1})
+
+ Args:
+ latent: The stochastic latent state z_t.
+ rnn_out: The output of the rnn for the current timestep.
+ Returns:
+ p(x_t | z_t, h_t): A distribution with event shape
+ [batch_size, data_size].
+ latent_encoded: The latent state encoded with latent_encoder. Should be
+ passed to transition() on the next timestep.
+ """
+ latent_encoded = self.latent_encoder(latent)
+ return self._emission(latent_encoded, rnn_out), latent_encoded
+
+ def sample_step(self, prev_state, inputs, unused_t):
+ """Samples one output from the model.
+
+ Args:
+ prev_state: The previous state of the model, a SRNNState containing the
+ previous rnn state and the previous encoded latent.
+ inputs: A Tensor of shape [batch_size, data_size], the current inputs to
+ the model. Most often this is x_{t-1}, the previous token in the
+ observation sequence.
+ unused_t: The current timestep. Not used currently.
+ Returns:
+ new_state: The next state of the model, a SRNNState.
+ xt: A float Tensor of shape [batch_size, data_size], an output sampled
+ from the emission distribution.
+ """
+ rnn_out, rnn_state = self.run_rnn(prev_state.rnn_state,
+ inputs)
+ p_zt = self.transition(rnn_out, prev_state.latent_encoded)
+ zt = p_zt.sample(seed=self.random_seed)
+ p_xt_given_zt, latent_encoded = self.emission(zt, rnn_out)
+ xt = p_xt_given_zt.sample(seed=self.random_seed)
+ new_state = SRNNState(rnn_state=rnn_state, latent_encoded=latent_encoded)
+ return new_state, tf.to_float(xt)
+
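+# A minimal generation sketch (illustrative only, not part of the original
+# file): ancestral sampling from an SRNN by repeatedly calling sample_step.
+# The names `srnn`, `batch_size`, `data_size`, and `num_timesteps` are
+# assumptions introduced only for this example.
+#
+#   state = srnn.zero_state(batch_size, tf.float32)
+#   xt = tf.zeros([batch_size, data_size])
+#   samples = []
+#   for t in range(num_timesteps):
+#     state, xt = srnn.sample_step(state, xt, t)
+#     samples.append(xt)
+#   sampled_seq = tf.stack(samples)  # [num_timesteps, batch_size, data_size]
+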
+# pylint: disable=invalid-name
+# pylint thinks this is a top-level constant.
+TrainableSRNNState = namedtuple("TrainableSRNNState",
+ SRNNState._fields + ("rnn_out",))
+# pylint: enable=invalid-name
+
+
+class TrainableSRNN(SRNN, base.ELBOTrainableSequenceModel):
+ """A SRNN subclass with proposals and methods for training and evaluation.
+
+ This class adds proposals used for training with importance-sampling based
+ methods such as the ELBO. The model can be configured to propose from one
+ of three proposals: a learned filtering proposal, a learned smoothing
+ proposal, or the prior (i.e. the transition distribution).
+
+ As described in the SRNN paper, the learned filtering proposal is
+ parameterized by a fully connected neural network that accepts as input the
+ current target x_t and the current rnn output h_t. The learned smoothing
+ proposal is also given the hidden state of an RNN run in reverse over the
+ inputs, so as to incorporate information about future observations.
+
+ All learned proposals use the 'res_q' parameterization, meaning that instead
+ of directly producing the mean of z_t, the proposal network predicts the
+ 'residual' from the prior's mean. This is explored more in section 3.3 of
+ https://arxiv.org/pdf/1605.07571.pdf.
+
+ During training, the latent state z_t is sampled from the proposal and the
+ reparameterization trick is used to provide low-variance gradients.
+
+ Note that the SRNN paper refers to the proposals as the approximate posterior,
+ but we match the VRNN convention of referring to it as the encoder.
+ """
+
+ def __init__(self,
+ rnn_cell,
+ data_encoder,
+ latent_encoder,
+ transition,
+ emission,
+ proposal_type,
+ proposal=None,
+ rev_rnn_cell=None,
+ tilt=None,
+ random_seed=None):
+ """Create a trainable RNN.
+
+ Args:
+ rnn_cell: A subclass of tf.nn.rnn_cell.RNNCell that will form the
+ deterministic backbone of the SRNN. The inputs to the RNN will be
+ the encoded input of the current timestep, a Tensor of shape
+ [batch_size, encoded_data_size].
+ data_encoder: A callable that accepts a batch of data x_t and
+ 'encodes' it, e.g. runs it through a fully connected network. Must
+ accept as argument the inputs x_t, a Tensor of the shape
+ [batch_size, data_size] and return a Tensor of shape
+ [batch_size, encoded_data_size]. This callable will be called multiple
+ times in the SRNN cell so if scoping is not handled correctly then
+ multiple copies of the variables in this network could be made. It is
+ recommended to use a snt.nets.MLP module, which takes care of this for
+ you.
+ latent_encoder: A callable that accepts a latent state z_t and
+ 'encodes' it, e.g. runs it through a fully connected network. Must
+ accept as argument a Tensor of shape [batch_size, latent_size] and
+ return a Tensor of shape [batch_size, encoded_latent_size].
+ This callable must also have the property 'output_size' defined,
+ returning encoded_latent_size.
+ transition: A callable that implements the transition distribution
+ p(z_t|h_t, z_{t-1}). Must accept as argument the previous RNN hidden state
+ and previous encoded latent state then return a tf.distributions.Normal
+ distribution conditioned on the input.
+ emission: A callable that implements the emission distribution
+ p(x_t|z_t, h_t). Must accept as arguments the encoded latent state
+ and the RNN hidden state and return a subclass of
+ tf.distributions.Distribution that can be used to evaluate the logprob
+ of the targets.
+ proposal_type: A string indicating the type of proposal to use. Can
+ be either "filtering", "smoothing", or "prior". When proposal_type is
+ "filtering" or "smoothing", proposal must be provided. When
+ proposal_type is "smoothing", rev_rnn_cell must also be provided.
+ proposal: A callable that implements the proposal q(z_t| h_t, x_{1:T}).
+ If proposal_type is "filtering" then proposal must accept as arguments
+ the current rnn output, the encoded target of the current timestep,
+ and the mean of the prior. If proposal_type is "smoothing" then
+ in addition to the current rnn output and the mean of the prior
+ proposal must accept as arguments the output of the reverse rnn.
+ proposal should return a tf.distributions.Normal distribution
+ conditioned on its inputs. If proposal_type is "prior" this argument is
+ ignored.
+ rev_rnn_cell: A subclass of tf.nn.rnn_cell.RNNCell that will aggregate
+ forward rnn outputs in the reverse direction. The inputs to the RNN
+ will be the encoded reverse input of the current timestep, a Tensor of
+ shape [batch_size, encoded_data_size].
+ tilt: A callable that implements the log of a positive tilting function
+ (ideally approximating log p(x_{t+1}|z_t, h_t)). Must accept as arguments
+ the encoded latent state and the RNN hidden state and return a subclass
+ of tf.distributions.Distribution that can be used to evaluate the
+ logprob of x_{t+1}. May be None, in which case no tilt is used.
+ random_seed: The seed for the random ops. Sets the seed for sample_step
+ and __call__.
+ """
+ super(TrainableSRNN, self).__init__(
+ rnn_cell, data_encoder, latent_encoder,
+ transition, emission, random_seed=random_seed)
+ self.rev_rnn_cell = rev_rnn_cell
+ self._tilt = tilt
+ assert proposal_type in ["filtering", "smoothing", "prior"]
+ self._proposal = proposal
+ self.proposal_type = proposal_type
+ if proposal_type != "prior":
+ assert proposal, "If not proposing from the prior, must provide proposal."
+ if proposal_type == "smoothing":
+ assert rev_rnn_cell, "Must provide rev_rnn_cell for smoothing proposal."
+
+ def zero_state(self, batch_size, dtype):
+ super_state = super(TrainableSRNN, self).zero_state(batch_size, dtype)
+ return TrainableSRNNState(
+ rnn_out=tf.zeros([batch_size, self.rnn_cell.output_size], dtype=dtype),
+ **super_state._asdict())
+
+ def set_observations(self, observations, seq_lengths):
+ """Stores the model's observations.
+
+ Stores the observations (inputs and targets) in TensorArrays and precomputes
+ things for later like the reverse RNN output and encoded targets.
+
+ Args:
+ observations: The observations of the model, a tuple containing two
+ Tensors of shape [max_seq_len, batch_size, data_size]. The Tensors
+ should be the inputs and targets, respectively.
+ seq_lengths: An int Tensor of shape [batch_size] containing the length
+ of each sequence in observations.
+ """
+ inputs, targets = observations
+ self.seq_lengths = seq_lengths
+ self.max_seq_len = tf.reduce_max(seq_lengths)
+ self.targets_ta = base.ta_for_tensor(targets, clear_after_read=False)
+ targets_encoded = base.encode_all(targets, self.data_encoder)
+ self.targets_encoded_ta = base.ta_for_tensor(targets_encoded,
+ clear_after_read=False)
+ inputs_encoded = base.encode_all(inputs, self.data_encoder)
+ rnn_out, _ = tf.nn.dynamic_rnn(self.rnn_cell,
+ inputs_encoded,
+ time_major=True,
+ dtype=tf.float32,
+ scope="forward_rnn")
+ self.rnn_ta = base.ta_for_tensor(rnn_out,
+ clear_after_read=False)
+ if self.rev_rnn_cell:
+ targets_and_rnn_out = tf.concat([rnn_out, targets_encoded], 2)
+ reversed_targets_and_rnn_out = tf.reverse_sequence(
+ targets_and_rnn_out, seq_lengths, seq_axis=0, batch_axis=1)
+ # Compute the reverse rnn over the targets.
+ reverse_rnn_out, _ = tf.nn.dynamic_rnn(self.rev_rnn_cell,
+ reversed_targets_and_rnn_out,
+ time_major=True,
+ dtype=tf.float32,
+ scope="reverse_rnn")
+ reverse_rnn_out = tf.reverse_sequence(reverse_rnn_out, seq_lengths,
+ seq_axis=0, batch_axis=1)
+ self.reverse_rnn_ta = base.ta_for_tensor(reverse_rnn_out,
+ clear_after_read=False)
+
+ def _filtering_proposal(self, rnn_out, prev_latent_encoded, prior, t):
+ """Computes the filtering proposal distribution."""
+ return self._proposal(rnn_out,
+ prev_latent_encoded,
+ self.targets_encoded_ta.read(t),
+ prior_mu=prior.mean())
+
+ def _smoothing_proposal(self, rnn_out, prev_latent_encoded, prior, t):
+ """Computes the smoothing proposal distribution."""
+ return self._proposal(rnn_out,
+ prev_latent_encoded,
+ smoothing_tensors=[self.reverse_rnn_ta.read(t)],
+ prior_mu=prior.mean())
+
+ def proposal(self, rnn_out, prev_latent_encoded, prior, t):
+ """Computes the proposal distribution specified by proposal_type.
+
+ Args:
+ rnn_out: The output of the rnn for the current timestep.
+ prev_latent_encoded: Float Tensor of shape
+ [batch_size, encoded_latent_size], the previous latent state z_{t-1}
+ run through latent_encoder.
+ prior: A tf.distributions.Normal distribution representing the prior
+ over z_t, p(z_t | z_{1:t-1}, x_{1:t-1}). Used for 'res_q'.
+ t: A scalar int Tensor, the current timestep.
+ """
+ if self.proposal_type == "filtering":
+ return self._filtering_proposal(rnn_out, prev_latent_encoded, prior, t)
+ elif self.proposal_type == "smoothing":
+ return self._smoothing_proposal(rnn_out, prev_latent_encoded, prior, t)
+ elif self.proposal_type == "prior":
+ return self.transition(rnn_out, prev_latent_encoded)
+
+ def tilt(self, rnn_out, latent_encoded, targets):
+ r_func = self._tilt(rnn_out, latent_encoded)
+ return tf.reduce_sum(r_func.log_prob(targets), axis=-1)
+
+ def propose_and_weight(self, state, t):
+ """Runs the model and computes importance weights for one timestep.
+
+ Runs the model and computes importance weights, sampling from the proposal
+ instead of the transition/prior.
+
+ Args:
+ state: The previous state of the model, a TrainableSRNNState containing
+ the previous rnn state, the previous rnn outs, and the previous encoded
+ latent.
+ t: A scalar integer Tensor, the current timestep.
+ Returns:
+ weights: A float Tensor of shape [batch_size].
+ new_state: The new state of the model.
+ """
+ targets = self.targets_ta.read(t)
+ rnn_out = self.rnn_ta.read(t)
+ p_zt = self.transition(rnn_out, state.latent_encoded)
+ q_zt = self.proposal(rnn_out, state.latent_encoded, p_zt, t)
+ zt = q_zt.sample(seed=self.random_seed)
+ p_xt_given_zt, latent_encoded = self.emission(zt, rnn_out)
+ log_p_xt_given_zt = tf.reduce_sum(p_xt_given_zt.log_prob(targets), axis=-1)
+ log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=-1)
+ log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=-1)
+ weights = log_p_zt + log_p_xt_given_zt - log_q_zt
+ if self._tilt:
+ prev_log_r = tf.cond(
+ tf.greater(t, 0),
+ lambda: self.tilt(state.rnn_out, state.latent_encoded, targets),
+ lambda: 0.) # On the first step, prev_log_r = 0.
+ log_r = tf.cond(
+ tf.less(t + 1, self.max_seq_len),
+ lambda: self.tilt(rnn_out, latent_encoded, self.targets_ta.read(t+1)),
+ lambda: 0.)
+ # On the last step, log_r = 0.
+ log_r *= tf.to_float(t < self.seq_lengths - 1)
+ weights += log_r - prev_log_r
+
+ # This reshape is required because the TensorArray reports different shapes
+ # than the initial state provides (where the first dimension is unknown).
+ # The difference breaks the while_loop. Reshape prevents the error.
+ rnn_out = tf.reshape(rnn_out, tf.shape(state.rnn_out))
+
+ new_state = TrainableSRNNState(rnn_out=rnn_out,
+ rnn_state=state.rnn_state, # unmodified
+ latent_encoded=latent_encoded)
+ return weights, new_state
+
+
+_DEFAULT_INITIALIZERS = {"w": tf.contrib.layers.xavier_initializer(),
+ "b": tf.zeros_initializer()}
+
+
+def create_srnn(
+ data_size,
+ latent_size,
+ emission_class,
+ rnn_hidden_size=None,
+ fcnet_hidden_sizes=None,
+ encoded_data_size=None,
+ encoded_latent_size=None,
+ sigma_min=0.0,
+ raw_sigma_bias=0.25,
+ emission_bias_init=0.0,
+ use_tilt=False,
+ proposal_type="filtering",
+ initializers=None,
+ random_seed=None):
+ """A factory method for creating SRNN cells.
+
+ Args:
+ data_size: The dimension of the vectors that make up the data sequences.
+ latent_size: The size of the stochastic latent state of the SRNN.
+ emission_class: The class of the emission distribution. Can be either
+ ConditionalNormalDistribution or ConditionalBernoulliDistribution.
+ rnn_hidden_size: The hidden state dimension of the RNN that forms the
+ deterministic part of this SRNN. If None, then it defaults
+ to latent_size.
+ fcnet_hidden_sizes: A list of python integers, the size of the hidden
+ layers of the fully connected networks that parameterize the conditional
+ distributions of the SRNN. If None, then it defaults to one hidden
+ layer of size latent_size.
+ encoded_data_size: The size of the output of the data encoding network. If
+ None, defaults to latent_size.
+ encoded_latent_size: The size of the output of the latent state encoding
+ network. If None, defaults to latent_size.
+ sigma_min: The minimum value that the standard deviation of the
+ distribution over the latent state can take.
+ raw_sigma_bias: A scalar that is added to the raw standard deviation
+ output from the neural networks that parameterize the prior and
+ approximate posterior. Useful for preventing standard deviations close
+ to zero.
+ emission_bias_init: A bias added to the raw output of the fully
+ connected network that parameterizes the emission distribution. Useful
+ for initializing the mean of the distribution to a sensible starting point
+ such as the mean of the training data. Only used with Bernoulli generative
+ distributions.
+ use_tilt: If true, create a SRNN with a tilting function.
+ proposal_type: The type of proposal to use. Can be "filtering", "smoothing",
+ or "prior".
+ initializers: The variable initializers to use for the fully connected
+ networks and RNN cell. Must be a dictionary mapping the keys 'w' and 'b'
+ to the initializers for the weights and biases. Defaults to xavier for
+ the weights and zeros for the biases when initializers is None.
+ random_seed: A random seed for the SRNN resampling operations.
+ Returns:
+ model: A TrainableSRNN object.
+ """
+ if rnn_hidden_size is None:
+ rnn_hidden_size = latent_size
+ if fcnet_hidden_sizes is None:
+ fcnet_hidden_sizes = [latent_size]
+ if encoded_data_size is None:
+ encoded_data_size = latent_size
+ if encoded_latent_size is None:
+ encoded_latent_size = latent_size
+ if initializers is None:
+ initializers = _DEFAULT_INITIALIZERS
+ data_encoder = snt.nets.MLP(
+ output_sizes=fcnet_hidden_sizes + [encoded_data_size],
+ initializers=initializers,
+ name="data_encoder")
+ latent_encoder = snt.nets.MLP(
+ output_sizes=fcnet_hidden_sizes + [encoded_latent_size],
+ initializers=initializers,
+ name="latent_encoder")
+ transition = base.ConditionalNormalDistribution(
+ size=latent_size,
+ hidden_layer_sizes=fcnet_hidden_sizes,
+ sigma_min=sigma_min,
+ raw_sigma_bias=raw_sigma_bias,
+ initializers=initializers,
+ name="prior")
+ # Construct the emission distribution.
+ if emission_class == base.ConditionalBernoulliDistribution:
+ # For Bernoulli distributed outputs, we initialize the bias so that the
+ # network generates on average the mean from the training set.
+ emission_dist = functools.partial(base.ConditionalBernoulliDistribution,
+ bias_init=emission_bias_init)
+ else:
+ emission_dist = base.ConditionalNormalDistribution
+ emission = emission_dist(
+ size=data_size,
+ hidden_layer_sizes=fcnet_hidden_sizes,
+ initializers=initializers,
+ name="generative")
+ # Construct the proposal distribution.
+ if proposal_type in ["filtering", "smoothing"]:
+ proposal = base.NormalApproximatePosterior(
+ size=latent_size,
+ hidden_layer_sizes=fcnet_hidden_sizes,
+ sigma_min=sigma_min,
+ raw_sigma_bias=raw_sigma_bias,
+ initializers=initializers,
+ smoothing=(proposal_type == "smoothing"),
+ name="approximate_posterior")
+ else:
+ proposal = None
+
+ if use_tilt:
+ tilt = emission_dist(
+ size=data_size,
+ hidden_layer_sizes=fcnet_hidden_sizes,
+ initializers=initializers,
+ name="tilt")
+ else:
+ tilt = None
+
+ rnn_cell = tf.nn.rnn_cell.LSTMCell(rnn_hidden_size,
+ initializer=initializers["w"])
+ rev_rnn_cell = tf.nn.rnn_cell.LSTMCell(rnn_hidden_size,
+ initializer=initializers["w"])
+ return TrainableSRNN(
+ rnn_cell, data_encoder, latent_encoder, transition,
+ emission, proposal_type, proposal=proposal, rev_rnn_cell=rev_rnn_cell,
+ tilt=tilt, random_seed=random_seed)
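+
+
+# A minimal usage sketch (illustrative only, not part of the original file):
+# build a TrainableSRNN with the factory above and compute importance weights
+# for one timestep. The sizes and the `inputs`, `targets`, `seq_lengths`, and
+# `batch_size` names are assumptions; see srnn_test.py for a fully worked setup.
+#
+#   model = create_srnn(data_size=3, latent_size=4,
+#                       emission_class=base.ConditionalNormalDistribution,
+#                       proposal_type="smoothing")
+#   model.set_observations([inputs, targets], seq_lengths)
+#   state = model.zero_state(batch_size, tf.float32)
+#   log_weights, state = model.propose_and_weight(state, 0)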
diff --git a/models/research/fivo/fivo/models/srnn_test.py b/models/research/fivo/fivo/models/srnn_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..39e10da134d3834babcf2eef1bb3e97fce12a07a
--- /dev/null
+++ b/models/research/fivo/fivo/models/srnn_test.py
@@ -0,0 +1,105 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for fivo.models.srnn."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from fivo.models import base
+from fivo.test_utils import create_srnn
+
+
+class SrnnTest(tf.test.TestCase):
+
+ def test_srnn_normal_emission(self):
+ self.run_srnn(base.ConditionalNormalDistribution, [-5.947752, -1.182961])
+
+ def test_srnn_bernoulli_emission(self):
+ self.run_srnn(base.ConditionalBernoulliDistribution, [-2.566631, -2.479234])
+
+ def run_srnn(self, generative_class, gt_log_alpha):
+ """Tests the SRNN.
+
+ All test values are 'golden values' derived by running the code and copying
+ the output.
+
+ Args:
+ generative_class: The class of the generative distribution to use.
+ gt_log_alpha: The ground-truth value of log alpha.
+ """
+ tf.set_random_seed(1234)
+ with self.test_session() as sess:
+ batch_size = 2
+ model, inputs, targets, _ = create_srnn(generative_class=generative_class,
+ batch_size=batch_size,
+ data_lengths=(1, 1),
+ random_seed=1234)
+ zero_state = model.zero_state(batch_size=batch_size, dtype=tf.float32)
+ model.set_observations([inputs, targets], tf.convert_to_tensor([1, 1]))
+ model_out = model.propose_and_weight(zero_state, 0)
+ sess.run(tf.global_variables_initializer())
+ log_alpha, state = sess.run(model_out)
+ self.assertAllClose(
+ state.latent_encoded,
+ [[0.591787, 1.310583], [-1.523136, 0.953918]])
+ self.assertAllClose(state.rnn_out,
+ [[0.041675, -0.056038, -0.001823, 0.005224],
+ [0.042925, -0.044619, 0.021401, 0.016998]])
+ self.assertAllClose(log_alpha, gt_log_alpha)
+
+ def test_srnn_with_tilt_normal_emission(self):
+ self.run_srnn_with_tilt(base.ConditionalNormalDistribution, [-9.13577, -4.56725])
+
+ def test_srnn_with_tilt_bernoulli_emission(self):
+ self.run_srnn_with_tilt(base.ConditionalBernoulliDistribution, [-4.617461, -5.079248])
+
+ def run_srnn_with_tilt(self, generative_class, gt_log_alpha):
+ """Tests the SRNN with a tilting function.
+
+ All test values are 'golden values' derived by running the code and copying
+ the output.
+
+ Args:
+ generative_class: The class of the generative distribution to use.
+ gt_log_alpha: The ground-truth value of log alpha.
+ """
+ tf.set_random_seed(1234)
+ with self.test_session() as sess:
+ batch_size = 2
+ model, inputs, targets, _ = create_srnn(generative_class=generative_class,
+ batch_size=batch_size,
+ data_lengths=(3, 2),
+ random_seed=1234,
+ use_tilt=True)
+ zero_state = model.zero_state(batch_size=batch_size, dtype=tf.float32)
+ model.set_observations([inputs, targets], tf.convert_to_tensor([3, 2]))
+ model_out = model.propose_and_weight(zero_state, 0)
+ sess.run(tf.global_variables_initializer())
+ log_alpha, state = sess.run(model_out)
+ self.assertAllClose(
+ state.latent_encoded,
+ [[0.591787, 1.310583], [-1.523136, 0.953918]])
+ self.assertAllClose(state.rnn_out,
+ [[0.041675, -0.056038, -0.001823, 0.005224],
+ [0.042925, -0.044619, 0.021401, 0.016998]])
+ self.assertAllClose(log_alpha, gt_log_alpha)
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/research/fivo/fivo/models/vrnn.py b/models/research/fivo/fivo/models/vrnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e2552088c19f141a75d791d2be0d0a5238ed87c
--- /dev/null
+++ b/models/research/fivo/fivo/models/vrnn.py
@@ -0,0 +1,572 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""VRNN classes."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import namedtuple
+import functools
+
+import sonnet as snt
+import tensorflow as tf
+
+from fivo.models import base
+
+
+VRNNState = namedtuple("VRNNState", "rnn_state latent_encoded")
+
+
+class VRNN(object):
+ """Implementation of a Variational Recurrent Neural Network (VRNN).
+
+ Introduced in "A Recurrent Latent Variable Model for Sequential data"
+ by Chung et al. https://arxiv.org/pdf/1506.02216.pdf.
+
+ The VRNN is a sequence model similar to an RNN that uses stochastic latent
+ variables to improve its representational power. It can be thought of as a
+ sequential analogue to the variational auto-encoder (VAE).
+
+ The VRNN has a deterministic RNN as its backbone, represented by the
+ sequence of RNN hidden states h_t. At each timestep, the RNN hidden state h_t
+ is conditioned on the previous sequence element, x_{t-1}, as well as the
+ latent state from the previous timestep, z_{t-1}.
+
+ In this implementation of the VRNN the latent state z_t is Gaussian. The
+ model's prior over z_t (also called the transition distribution) is
+ distributed as Normal(mu_t, diag(sigma_t^2)) where mu_t and sigma_t are the
+ mean and standard deviation output from a fully connected network that accepts
+ the rnn hidden state h_t as input.
+
+ The emission distribution p(x_t|z_t, h_t) is conditioned on the latent state
+ z_t as well as the current RNN hidden state h_t via a fully connected network.
+
+ To increase the modeling power of the VRNN, two additional networks are
+ used to extract features from the data and the latent state. Those networks
+ are called data_encoder and latent_encoder respectively.
+
+ For an example of how to call the VRNN's methods see sample_step.
+
+ There are a few differences between this exposition and the paper.
+ First, the indexing scheme for h_t is different than the paper's -- what the
+ paper calls h_t we call h_{t+1}. This is the same notation used by Fraccaro
+ et al. to describe the VRNN in the paper linked above. Also, the VRNN paper
+ uses VAE terminology to refer to the different internal networks, so it
+ refers to the emission distribution as the decoder. This implementation also
+ renames the functions phi_x and phi_z in the paper to data_encoder and
+ latent_encoder.
+ """
+
+ def __init__(self,
+ rnn_cell,
+ data_encoder,
+ latent_encoder,
+ transition,
+ emission,
+ random_seed=None):
+ """Create a VRNN.
+
+ Args:
+ rnn_cell: A subclass of tf.nn.rnn_cell.RNNCell that will form the
+ deterministic backbone of the VRNN. The inputs to the RNN will be the
+ encoded latent state of the previous timestep with shape
+ [batch_size, encoded_latent_size] as well as the encoded input of the
+ current timestep, a Tensor of shape [batch_size, encoded_data_size].
+ data_encoder: A callable that accepts a batch of data x_t and
+ 'encodes' it, e.g. runs it through a fully connected network. Must
+ accept as argument the inputs x_t, a Tensor of the shape
+ [batch_size, data_size] and return a Tensor of shape
+ [batch_size, encoded_data_size]. This callable will be called multiple
+ times in the VRNN cell so if scoping is not handled correctly then
+ multiple copies of the variables in this network could be made. It is
+ recommended to use a snt.nets.MLP module, which takes care of this for
+ you.
+ latent_encoder: A callable that accepts a latent state z_t and
+ 'encodes' it, e.g. runs it through a fully connected network. Must
+ accept as argument a Tensor of shape [batch_size, latent_size] and
+ return a Tensor of shape [batch_size, encoded_latent_size].
+ This callable must also have the property 'output_size' defined,
+ returning encoded_latent_size.
+ transition: A callable that implements the transition distribution
+ p(z_t|h_t). Must accept as argument the previous RNN hidden state and
+ return a tf.distributions.Normal distribution conditioned on the input.
+ emission: A callable that implements the emission distribution
+ p(x_t|z_t, h_t). Must accept as arguments the encoded latent state
+ and the RNN hidden state and return a subclass of
+ tf.distributions.Distribution that can be used to evaluate the logprob
+ of the targets.
+ random_seed: The seed for the random ops. Sets the seed for sample_step.
+ """
+ self.random_seed = random_seed
+ self.rnn_cell = rnn_cell
+ self.data_encoder = data_encoder
+ self.latent_encoder = latent_encoder
+ self.encoded_z_size = latent_encoder.output_size
+ self.state_size = self.rnn_cell.state_size
+ self._transition = transition
+ self._emission = emission
+
+ def zero_state(self, batch_size, dtype):
+ """The initial state of the VRNN.
+
+ Contains the initial state of the RNN and the initial encoded latent.
+
+ Args:
+ batch_size: The batch size.
+ dtype: The data type of the VRNN.
+ Returns:
+ zero_state: The initial state of the VRNN.
+ """
+ return VRNNState(
+ rnn_state=self.rnn_cell.zero_state(batch_size, dtype),
+ latent_encoded=tf.zeros(
+ [batch_size, self.latent_encoder.output_size], dtype=dtype))
+
+ def run_rnn(self, prev_rnn_state, prev_latent_encoded, inputs):
+ """Runs the deterministic RNN for one step.
+
+ Args:
+ prev_rnn_state: The state of the RNN from the previous timestep.
+ prev_latent_encoded: Float Tensor of shape
+ [batch_size, encoded_latent_size], the previous latent state z_{t-1}
+ run through latent_encoder.
+ inputs: A Tensor of shape [batch_size, data_size], the current inputs to
+ the model. Most often this is x_{t-1}, the previous token in the
+ observation sequence.
+ Returns:
+ rnn_out: The output of the RNN.
+ rnn_state: The new state of the RNN.
+ """
+ inputs_encoded = self.data_encoder(tf.to_float(inputs))
+ rnn_inputs = tf.concat([inputs_encoded, prev_latent_encoded], axis=1)
+ rnn_out, rnn_state = self.rnn_cell(rnn_inputs, prev_rnn_state)
+ return rnn_out, rnn_state
+
+ def transition(self, rnn_out):
+ """Computes the transition distribution p(z_t|h_t).
+
+ Note that p(z_t | h_t) = p(z_t| z_{1:t-1}, x_{1:t-1})
+
+ Args:
+ rnn_out: The output of the rnn for the current timestep.
+ Returns:
+ p(z_t | h_t): A normal distribution with event shape
+ [batch_size, latent_size].
+ """
+ return self._transition(rnn_out)
+
+ def emission(self, latent, rnn_out):
+ """Computes the emission distribution p(x_t | z_t, h_t).
+
+ Note that p(x_t | z_t, h_t) = p(x_t | z_{1:t}, x_{1:t-1}).
+
+ Args:
+ latent: The stochastic latent state z_t.
+ rnn_out: The output of the rnn for the current timestep.
+ Returns:
+ p(x_t | z_t, h_t): A distribution with event shape
+ [batch_size, data_size].
+ latent_encoded: The latent state encoded with latent_encoder. Should be
+ passed to run_rnn on the next timestep.
+ """
+ latent_encoded = self.latent_encoder(latent)
+ return self._emission(latent_encoded, rnn_out), latent_encoded
+
+ def sample_step(self, prev_state, inputs, unused_t):
+ """Samples one output from the model.
+
+ Args:
+ prev_state: The previous state of the model, a VRNNState containing the
+ previous rnn state and the previous encoded latent.
+ inputs: A Tensor of shape [batch_size, data_size], the current inputs to
+ the model. Most often this is x_{t-1}, the previous token in the
+ observation sequence.
+ unused_t: The current timestep. Not used currently.
+ Returns:
+ new_state: The next state of the model, a VRNNState.
+ xt: A float Tensor of shape [batch_size, data_size], an output sampled
+ from the emission distribution.
+ """
+ rnn_out, rnn_state = self.run_rnn(prev_state.rnn_state,
+ prev_state.latent_encoded,
+ inputs)
+ p_zt = self.transition(rnn_out)
+ zt = p_zt.sample(seed=self.random_seed)
+ p_xt_given_zt, latent_encoded = self.emission(zt, rnn_out)
+ xt = p_xt_given_zt.sample(seed=self.random_seed)
+ new_state = VRNNState(rnn_state=rnn_state, latent_encoded=latent_encoded)
+ return new_state, tf.to_float(xt)
+
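+# Note on structure (an illustrative aside, not part of the original file):
+# unlike the SRNN, the VRNN's deterministic RNN also consumes the encoded
+# previous latent state, e.g.
+#
+#   rnn_out, rnn_state = vrnn.run_rnn(prev_state.rnn_state,
+#                                     prev_state.latent_encoded, inputs)
+#
+# whereas SRNN.run_rnn takes only (prev_rnn_state, inputs).
+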
+# pylint: disable=invalid-name
+# pylint thinks this is a top-level constant.
+TrainableVRNNState = namedtuple("TrainableVRNNState",
+ VRNNState._fields + ("rnn_out",))
+# pylint: enable=invalid-name
+
+
+class TrainableVRNN(VRNN, base.ELBOTrainableSequenceModel):
+ """A VRNN subclass with proposals and methods for training and evaluation.
+
+ This class adds proposals used for training with importance-sampling based
+ methods such as the ELBO. The model can be configured to propose from one
+ of three proposals: a learned filtering proposal, a learned smoothing
+ proposal, or the prior (i.e. the transition distribution).
+
+ As described in the VRNN paper, the learned filtering proposal is
+ parameterized by a fully connected neural network that accepts as input the
+ current target x_t and the current rnn output h_t. The learned smoothing
+ proposal is also given the hidden state of an RNN run in reverse over the
+ inputs, so as to incorporate information about future observations. This
+ smoothing proposal is not described in the VRNN paper.
+
+ All learned proposals use the 'res_q' parameterization, meaning that instead
+ of directly producing the mean of z_t, the proposal network predicts the
+ 'residual' from the prior's mean. This is explored more in section 3.3 of
+ https://arxiv.org/pdf/1605.07571.pdf.
+
+ During training, the latent state z_t is sampled from the proposal and the
+ reparameterization trick is used to provide low-variance gradients.
+
+ Note that the VRNN paper uses VAE terminology to refer to the different
+ internal networks, so the proposal is referred to as the encoder.
+ """
+
+ def __init__(self,
+ rnn_cell,
+ data_encoder,
+ latent_encoder,
+ transition,
+ emission,
+ proposal_type,
+ proposal=None,
+ rev_rnn_cell=None,
+ tilt=None,
+ random_seed=None):
+ """Create a trainable RNN.
+
+ Args:
+ rnn_cell: A subclass of tf.nn.rnn_cell.RNNCell that will form the
+ deterministic backbone of the VRNN. The inputs to the RNN will be the
+ encoded latent state of the previous timestep with shape
+ [batch_size, encoded_latent_size] as well as the encoded input of the
+ current timestep, a Tensor of shape [batch_size, encoded_data_size].
+ data_encoder: A callable that accepts a batch of data x_t and
+ 'encodes' it, e.g. runs it through a fully connected network. Must
+ accept as argument the inputs x_t, a Tensor of the shape
+ [batch_size, data_size] and return a Tensor of shape
+ [batch_size, encoded_data_size]. This callable will be called multiple
+ times in the VRNN cell so if scoping is not handled correctly then
+ multiple copies of the variables in this network could be made. It is
+ recommended to use a snt.nets.MLP module, which takes care of this for
+ you.
+ latent_encoder: A callable that accepts a latent state z_t and
+ 'encodes' it, e.g. runs it through a fully connected network. Must
+ accept as argument a Tensor of shape [batch_size, latent_size] and
+ return a Tensor of shape [batch_size, encoded_latent_size].
+ This callable must also have the property 'output_size' defined,
+ returning encoded_latent_size.
+ transition: A callable that implements the transition distribution
+ p(z_t|h_t). Must accept as argument the previous RNN hidden state and
+ return a tf.distributions.Normal distribution conditioned on the input.
+ emission: A callable that implements the emission distribution
+ p(x_t|z_t, h_t). Must accept as arguments the encoded latent state
+ and the RNN hidden state and return a subclass of
+ tf.distributions.Distribution that can be used to evaluate the logprob
+ of the targets.
+ proposal_type: A string indicating the type of proposal to use. Can
+ be either "filtering", "smoothing", or "prior". When proposal_type is
+ "filtering" or "smoothing", proposal must be provided. When
+ proposal_type is "smoothing", rev_rnn_cell must also be provided.
+ proposal: A callable that implements the proposal q(z_t| h_t, x_{1:T}).
+ If proposal_type is "filtering" then proposal must accept as arguments
+ the current rnn output, the encoded target of the current timestep,
+ and the mean of the prior. If proposal_type is "smoothing" then
+ in addition to the current rnn output and the mean of the prior
+ proposal must accept as arguments the output of the reverse rnn.
+ proposal should return a tf.distributions.Normal distribution
+ conditioned on its inputs. If proposal_type is "prior" this argument is
+ ignored.
+ rev_rnn_cell: A subclass of tf.nn.rnn_cell.RNNCell that will aggregate
+ observation statistics in the reverse direction. The inputs to the RNN
+ will be the encoded reverse input of the current timestep, a Tensor of
+ shape [batch_size, encoded_data_size].
+      tilt: A callable that implements the log of a positive tilting function
+        (ideally approximating log p(x_{t+1}|z_t, h_t)). Must accept as
+        arguments the encoded latent state and the RNN hidden state and return
+        a subclass of tf.distributions.Distribution that can be used to
+        evaluate the logprob of x_{t+1}. May be None, in which case no tilting
+        function is used.
+ random_seed: The seed for the random ops. Sets the seed for sample_step
+ and __call__.
+ """
+ super(TrainableVRNN, self).__init__(
+ rnn_cell, data_encoder, latent_encoder,
+ transition, emission, random_seed=random_seed)
+ self.rev_rnn_cell = rev_rnn_cell
+ self._tilt = tilt
+ assert proposal_type in ["filtering", "smoothing", "prior"]
+ self._proposal = proposal
+ self.proposal_type = proposal_type
+ if proposal_type != "prior":
+ assert proposal, "If not proposing from the prior, must provide proposal."
+ if proposal_type == "smoothing":
+ assert rev_rnn_cell, "Must provide rev_rnn_cell for smoothing proposal."
+
+ def zero_state(self, batch_size, dtype):
+ super_state = super(TrainableVRNN, self).zero_state(batch_size, dtype)
+ return TrainableVRNNState(
+ rnn_out=tf.zeros([batch_size, self.rnn_cell.output_size], dtype=dtype),
+ **super_state._asdict())
+
+ def set_observations(self, observations, seq_lengths):
+ """Stores the model's observations.
+
+    Stores the observations (inputs and targets) in TensorArrays and
+    precomputes quantities needed later, such as the reverse RNN output and
+    the encoded targets.
+
+ Args:
+ observations: The observations of the model, a tuple containing two
+ Tensors of shape [max_seq_len, batch_size, data_size]. The Tensors
+ should be the inputs and targets, respectively.
+ seq_lengths: An int Tensor of shape [batch_size] containing the length
+ of each sequence in observations.
+ """
+ inputs, targets = observations
+ self.seq_lengths = seq_lengths
+ self.max_seq_len = tf.reduce_max(seq_lengths)
+ self.inputs_ta = base.ta_for_tensor(inputs, clear_after_read=False)
+ self.targets_ta = base.ta_for_tensor(targets, clear_after_read=False)
+ targets_encoded = base.encode_all(targets, self.data_encoder)
+ self.targets_encoded_ta = base.ta_for_tensor(targets_encoded,
+ clear_after_read=False)
+ if self.rev_rnn_cell:
+ reverse_targets_encoded = tf.reverse_sequence(
+ targets_encoded, seq_lengths, seq_axis=0, batch_axis=1)
+ # Compute the reverse rnn over the targets.
+ reverse_rnn_out, _ = tf.nn.dynamic_rnn(self.rev_rnn_cell,
+ reverse_targets_encoded,
+ time_major=True,
+ dtype=tf.float32)
+ reverse_rnn_out = tf.reverse_sequence(reverse_rnn_out, seq_lengths,
+ seq_axis=0, batch_axis=1)
+ self.reverse_rnn_ta = base.ta_for_tensor(reverse_rnn_out,
+ clear_after_read=False)
+
+ def _filtering_proposal(self, rnn_out, prior, t):
+ """Computes the filtering proposal distribution."""
+ return self._proposal(rnn_out,
+ self.targets_encoded_ta.read(t),
+ prior_mu=prior.mean())
+
+ def _smoothing_proposal(self, rnn_out, prior, t):
+ """Computes the smoothing proposal distribution."""
+ return self._proposal(rnn_out,
+ smoothing_tensors=[self.reverse_rnn_ta.read(t)],
+ prior_mu=prior.mean())
+
+ def proposal(self, rnn_out, prior, t):
+ """Computes the proposal distribution specified by proposal_type.
+
+ Args:
+ rnn_out: The output of the rnn for the current timestep.
+ prior: A tf.distributions.Normal distribution representing the prior
+ over z_t, p(z_t | z_{1:t-1}, x_{1:t-1}). Used for 'res_q'.
+ t: A scalar int Tensor, the current timestep.
+ """
+ if self.proposal_type == "filtering":
+ return self._filtering_proposal(rnn_out, prior, t)
+ elif self.proposal_type == "smoothing":
+ return self._smoothing_proposal(rnn_out, prior, t)
+ elif self.proposal_type == "prior":
+ return self.transition(rnn_out)
+
+ def tilt(self, rnn_out, latent_encoded, targets):
+ r_func = self._tilt(rnn_out, latent_encoded)
+ return tf.reduce_sum(r_func.log_prob(targets), axis=-1)
+
+ def propose_and_weight(self, state, t):
+ """Runs the model and computes importance weights for one timestep.
+
+ Runs the model and computes importance weights, sampling from the proposal
+ instead of the transition/prior.
+
+ Args:
+ state: The previous state of the model, a TrainableVRNNState containing
+ the previous rnn state, the previous rnn outs, and the previous encoded
+ latent.
+ t: A scalar integer Tensor, the current timestep.
+ Returns:
+ weights: A float Tensor of shape [batch_size].
+ new_state: The new state of the model.
+ """
+ inputs = self.inputs_ta.read(t)
+ targets = self.targets_ta.read(t)
+ rnn_out, next_rnn_state = self.run_rnn(state.rnn_state,
+ state.latent_encoded,
+ inputs)
+ p_zt = self.transition(rnn_out)
+ q_zt = self.proposal(rnn_out, p_zt, t)
+ zt = q_zt.sample(seed=self.random_seed)
+ p_xt_given_zt, latent_encoded = self.emission(zt, rnn_out)
+ log_p_xt_given_zt = tf.reduce_sum(p_xt_given_zt.log_prob(targets), axis=-1)
+ log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=-1)
+ log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=-1)
+ weights = log_p_zt + log_p_xt_given_zt - log_q_zt
+ if self._tilt:
+ prev_log_r = tf.cond(
+ tf.greater(t, 0),
+ lambda: self.tilt(state.rnn_out, state.latent_encoded, targets),
+ lambda: 0.) # On the first step, prev_log_r = 0.
+ log_r = tf.cond(
+ tf.less(t + 1, self.max_seq_len),
+ lambda: self.tilt(rnn_out, latent_encoded, self.targets_ta.read(t+1)),
+ lambda: 0.)
+ # On the last step, log_r = 0.
+ log_r *= tf.to_float(t < self.seq_lengths - 1)
+ weights += log_r - prev_log_r
+ new_state = TrainableVRNNState(rnn_state=next_rnn_state,
+ rnn_out=rnn_out,
+ latent_encoded=latent_encoded)
+ return weights, new_state
+
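+# Worked sketch of the incremental importance weight computed in
+# propose_and_weight above (illustrative only). For a single particle at
+# timestep t,
+#
+#   log w_t = log p(z_t | h_t) + log p(x_t | z_t, h_t) - log q(z_t | h_t, x_t),
+#
+# and, when a tilting function r is supplied,
+#
+#   log w_t += log r(x_{t+1} | z_t, h_t) - log r(x_t | z_{t-1}, h_{t-1}),
+#
+# where the added term is dropped on the last step and the subtracted term is
+# dropped on the first step.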
+
+_DEFAULT_INITIALIZERS = {"w": tf.contrib.layers.xavier_initializer(),
+ "b": tf.zeros_initializer()}
+
+
+def create_vrnn(
+ data_size,
+ latent_size,
+ emission_class,
+ rnn_hidden_size=None,
+ fcnet_hidden_sizes=None,
+ encoded_data_size=None,
+ encoded_latent_size=None,
+ sigma_min=0.0,
+ raw_sigma_bias=0.25,
+ emission_bias_init=0.0,
+ use_tilt=False,
+ proposal_type="filtering",
+ initializers=None,
+ random_seed=None):
+ """A factory method for creating VRNN cells.
+
+ Args:
+ data_size: The dimension of the vectors that make up the data sequences.
+ latent_size: The size of the stochastic latent state of the VRNN.
+ emission_class: The class of the emission distribution. Can be either
+ ConditionalNormalDistribution or ConditionalBernoulliDistribution.
+ rnn_hidden_size: The hidden state dimension of the RNN that forms the
+ deterministic part of this VRNN. If None, then it defaults
+ to latent_size.
+ fcnet_hidden_sizes: A list of python integers, the size of the hidden
+ layers of the fully connected networks that parameterize the conditional
+ distributions of the VRNN. If None, then it defaults to one hidden
+ layer of size latent_size.
+ encoded_data_size: The size of the output of the data encoding network. If
+ None, defaults to latent_size.
+ encoded_latent_size: The size of the output of the latent state encoding
+ network. If None, defaults to latent_size.
+ sigma_min: The minimum value that the standard deviation of the
+ distribution over the latent state can take.
+ raw_sigma_bias: A scalar that is added to the raw standard deviation
+ output from the neural networks that parameterize the prior and
+ approximate posterior. Useful for preventing standard deviations close
+ to zero.
+    emission_bias_init: A bias added to the raw output of the fully
+ connected network that parameterizes the emission distribution. Useful
+      for initializing the mean of the distribution to a sensible starting point
+ such as the mean of the training data. Only used with Bernoulli generative
+ distributions.
+ use_tilt: If true, create a VRNN with a tilting function.
+ proposal_type: The type of proposal to use. Can be "filtering", "smoothing",
+ or "prior".
+    initializers: The variable initializers to use for the fully connected
+ networks and RNN cell. Must be a dictionary mapping the keys 'w' and 'b'
+ to the initializers for the weights and biases. Defaults to xavier for
+ the weights and zeros for the biases when initializers is None.
+ random_seed: A random seed for the VRNN resampling operations.
+ Returns:
+ model: A TrainableVRNN object.
+ """
+ if rnn_hidden_size is None:
+ rnn_hidden_size = latent_size
+ if fcnet_hidden_sizes is None:
+ fcnet_hidden_sizes = [latent_size]
+ if encoded_data_size is None:
+ encoded_data_size = latent_size
+ if encoded_latent_size is None:
+ encoded_latent_size = latent_size
+ if initializers is None:
+ initializers = _DEFAULT_INITIALIZERS
+ data_encoder = snt.nets.MLP(
+ output_sizes=fcnet_hidden_sizes + [encoded_data_size],
+ initializers=initializers,
+ name="data_encoder")
+ latent_encoder = snt.nets.MLP(
+ output_sizes=fcnet_hidden_sizes + [encoded_latent_size],
+ initializers=initializers,
+ name="latent_encoder")
+ transition = base.ConditionalNormalDistribution(
+ size=latent_size,
+ hidden_layer_sizes=fcnet_hidden_sizes,
+ sigma_min=sigma_min,
+ raw_sigma_bias=raw_sigma_bias,
+ initializers=initializers,
+ name="prior")
+ # Construct the emission distribution.
+ if emission_class == base.ConditionalBernoulliDistribution:
+ # For Bernoulli distributed outputs, we initialize the bias so that the
+ # network generates on average the mean from the training set.
+ emission_dist = functools.partial(base.ConditionalBernoulliDistribution,
+ bias_init=emission_bias_init)
+ else:
+ emission_dist = base.ConditionalNormalDistribution
+ emission = emission_dist(
+ size=data_size,
+ hidden_layer_sizes=fcnet_hidden_sizes,
+ initializers=initializers,
+ name="generative")
+ # Construct the proposal distribution.
+ if proposal_type in ["filtering", "smoothing"]:
+ proposal = base.NormalApproximatePosterior(
+ size=latent_size,
+ hidden_layer_sizes=fcnet_hidden_sizes,
+ sigma_min=sigma_min,
+ raw_sigma_bias=raw_sigma_bias,
+ initializers=initializers,
+ smoothing=(proposal_type == "smoothing"),
+ name="approximate_posterior")
+ else:
+ proposal = None
+
+ if use_tilt:
+ tilt = emission_dist(
+ size=data_size,
+ hidden_layer_sizes=fcnet_hidden_sizes,
+ initializers=initializers,
+ name="tilt")
+ else:
+ tilt = None
+
+ rnn_cell = tf.nn.rnn_cell.LSTMCell(rnn_hidden_size,
+ initializer=initializers["w"])
+ rev_rnn_cell = tf.nn.rnn_cell.LSTMCell(rnn_hidden_size,
+ initializer=initializers["w"])
+ return TrainableVRNN(
+ rnn_cell, data_encoder, latent_encoder, transition,
+ emission, proposal_type, proposal=proposal, rev_rnn_cell=rev_rnn_cell,
+ tilt=tilt, random_seed=random_seed)
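+# Example usage (illustrative only): a VRNN for 88-dimensional pianoroll
+# frames with a 32-dimensional latent state and a filtering proposal, similar
+# to how runners.create_dataset_and_model constructs the model.
+#
+#   model = create_vrnn(data_size=88,
+#                       latent_size=32,
+#                       emission_class=base.ConditionalBernoulliDistribution,
+#                       proposal_type="filtering")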
diff --git a/models/research/fivo/fivo/models/vrnn_test.py b/models/research/fivo/fivo/models/vrnn_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d9bde3d5b6c6f66a82bd331cf50a87737864239
--- /dev/null
+++ b/models/research/fivo/fivo/models/vrnn_test.py
@@ -0,0 +1,137 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for fivo.models.vrnn."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import tensorflow as tf
+
+from fivo.models import base
+from fivo.test_utils import create_vrnn
+
+
+class VrnnTest(tf.test.TestCase):
+
+ def test_vrnn_normal_emission(self):
+ self.run_vrnn(base.ConditionalNormalDistribution, [-4.509767, -3.242221])
+
+ def test_vrnn_bernoulli_emission(self):
+    self.run_vrnn(base.ConditionalBernoulliDistribution, [-2.63812733, -2.02216434])
+
+ def run_vrnn(self, generative_class, gt_log_p_x_given_z):
+ """Tests the VRNN.
+
+ All test values are 'golden values' derived by running the code and copying
+ the output.
+
+ Args:
+ generative_class: The class of the generative distribution to use.
+ gt_log_p_x_given_z: The ground-truth value of log p(x|z).
+ """
+ tf.set_random_seed(1234)
+ with self.test_session() as sess:
+ batch_size = 2
+ model, inputs, targets, _ = create_vrnn(generative_class=generative_class,
+ batch_size=batch_size,
+ data_lengths=(1, 1),
+ random_seed=1234)
+ zero_state = model.zero_state(batch_size=batch_size, dtype=tf.float32)
+ model.set_observations([inputs, targets], tf.convert_to_tensor([1, 1]))
+ model_out = model.propose_and_weight(zero_state, 0)
+ sess.run(tf.global_variables_initializer())
+ log_alpha, state = sess.run(model_out)
+ rnn_state, latent_state, rnn_out = state
+ self.assertAllClose(
+ rnn_state.c,
+ [[-0.15014534, 0.0143046, 0.00160489, -0.12899463],
+ [-0.25015137, 0.09377634, -0.05000039, -0.17123522]])
+ self.assertAllClose(
+ rnn_state.h,
+ [[-0.06842659, 0.00760155, 0.00096106, -0.05434214],
+ [-0.1109542, 0.0441804, -0.03121299, -0.07882939]]
+ )
+ self.assertAllClose(
+ latent_state,
+ [[0.025241, 0.122011, 1.066661, 0.316209, -0.25369, 0.108215,
+ -1.501128, -0.440111, -0.40447, -0.156649, 1.206028],
+ [0.066824, 0.519937, 0.610973, 0.977739, -0.121889, -0.223429,
+ -0.32687, -0.578763, -0.56965, 0.751886, 0.681606]]
+ )
+ self.assertAllClose(rnn_out, [[-0.068427, 0.007602, 0.000961, -0.054342],
+ [-0.110954, 0.04418, -0.031213, -0.078829]])
+ gt_log_q_z = [-8.0895052, -6.75819111]
+ gt_log_p_z = [-7.246827, -6.512877]
+ gt_log_alpha = (np.array(gt_log_p_z) +
+ np.array(gt_log_p_x_given_z) -
+ np.array(gt_log_q_z))
+ self.assertAllClose(log_alpha, gt_log_alpha)
+
+ def test_vrnn_with_tilt_normal_emission(self):
+ self.run_vrnn_with_tilt(base.ConditionalNormalDistribution, [-5.198263, -6.31686])
+
+ def test_vrnn_with_tilt_bernoulli_emission(self):
+ self.run_vrnn_with_tilt(base.ConditionalBernoulliDistribution, [-4.66985, -3.802245])
+
+ def run_vrnn_with_tilt(self, generative_class, gt_log_alpha):
+ """Tests the VRNN with a tilting function.
+
+ All test values are 'golden values' derived by running the code and copying
+ the output.
+
+ Args:
+ generative_class: The class of the generative distribution to use.
+ gt_log_alpha: The ground-truth value of log alpha.
+ """
+ tf.set_random_seed(1234)
+ with self.test_session() as sess:
+ batch_size = 2
+ model, inputs, targets, _ = create_vrnn(generative_class=generative_class,
+ batch_size=batch_size,
+ data_lengths=(3, 2),
+ random_seed=1234,
+ use_tilt=True)
+ zero_state = model.zero_state(batch_size=batch_size, dtype=tf.float32)
+ model.set_observations([inputs, targets], tf.convert_to_tensor([3, 2]))
+ model_out = model.propose_and_weight(zero_state, 0)
+ sess.run(tf.global_variables_initializer())
+ log_alpha, state = sess.run(model_out)
+ rnn_state, latent_state, rnn_out = state
+ self.assertAllClose(
+ rnn_state.c,
+ [[-0.15014534, 0.0143046, 0.00160489, -0.12899463],
+ [-0.25015137, 0.09377634, -0.05000039, -0.17123522]])
+ self.assertAllClose(
+ rnn_state.h,
+ [[-0.06842659, 0.00760155, 0.00096106, -0.05434214],
+ [-0.1109542, 0.0441804, -0.03121299, -0.07882939]]
+ )
+ self.assertAllClose(
+ latent_state,
+ [[0.025241, 0.122011, 1.066661, 0.316209, -0.25369, 0.108215,
+ -1.501128, -0.440111, -0.40447, -0.156649, 1.206028],
+ [0.066824, 0.519937, 0.610973, 0.977739, -0.121889, -0.223429,
+ -0.32687, -0.578763, -0.56965, 0.751886, 0.681606]]
+ )
+ self.assertAllClose(rnn_out, [[-0.068427, 0.007602, 0.000961, -0.054342],
+ [-0.110954, 0.04418, -0.031213, -0.078829]])
+ self.assertAllClose(log_alpha, gt_log_alpha)
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/research/fivo/fivo/nested_utils.py b/models/research/fivo/fivo/nested_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef956a80c40d55331a3acbfe78111e099559ddea
--- /dev/null
+++ b/models/research/fivo/fivo/nested_utils.py
@@ -0,0 +1,139 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A set of utils for dealing with nested lists and tuples of Tensors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+import tensorflow as tf
+
+from tensorflow.python.util import nest
+
+
+def map_nested(map_fn, nested):
+ """Executes map_fn on every element in a (potentially) nested structure.
+
+ Args:
+ map_fn: A callable to execute on each element in 'nested'.
+ nested: A potentially nested combination of sequence objects. Sequence
+ objects include tuples, lists, namedtuples, and all subclasses of
+ collections.Sequence except strings. See nest.is_sequence for details.
+ For example [1, ('hello', 4.3)] is a nested structure containing elements
+ 1, 'hello', and 4.3.
+ Returns:
+ out_structure: A potentially nested combination of sequence objects with the
+ same structure as the 'nested' input argument. out_structure
+ contains the result of applying map_fn to each element in 'nested'. For
+ example map_nested(lambda x: x+1, [1, (3, 4.3)]) returns [2, (4, 5.3)].
+ """
+ out = map(map_fn, nest.flatten(nested))
+ return nest.pack_sequence_as(nested, out)
+
+
+def tile_tensors(tensors, multiples):
+ """Tiles a set of Tensors.
+
+ Args:
+ tensors: A potentially nested tuple or list of Tensors with rank
+ greater than or equal to the length of 'multiples'. The Tensors do not
+ need to have the same rank, but their rank must not be dynamic.
+ multiples: A python list of ints indicating how to tile each Tensor
+ in 'tensors'. Similar to the 'multiples' argument to tf.tile.
+ Returns:
+ tiled_tensors: A potentially nested tuple or list of Tensors with the same
+ structure as the 'tensors' input argument. Contains the result of
+ applying tf.tile to each Tensor in 'tensors'. When the rank of a Tensor
+ in 'tensors' is greater than the length of multiples, multiples is padded
+ at the end with 1s. For example when tiling a 4-dimensional Tensor with
+ multiples [3, 4], multiples would be padded to [3, 4, 1, 1] before tiling.
+ """
+ def tile_fn(x):
+ return tf.tile(x, multiples + [1] * (x.shape.ndims - len(multiples)))
+
+ return map_nested(tile_fn, tensors)
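+# Illustrative example (not part of the original module): with multiples=[3],
+# a Tensor of shape [20] becomes [60] and a Tensor of shape [2, 10] becomes
+# [6, 10], because the multiples are padded to [3, 1] for the rank-2 Tensor.
+#
+#   tiled = tile_tensors([a, (b,)], [3])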
+
+
+def where_tensors(condition, x_tensors, y_tensors):
+ """Performs a tf.where operation on a two sets of Tensors.
+
+ Args:
+ condition: The condition tensor to use for the where operation.
+ x_tensors: A potentially nested tuple or list of Tensors.
+ y_tensors: A potentially nested tuple or list of Tensors. Must have the
+ same structure as x_tensors.
+ Returns:
+ whered_tensors: A potentially nested tuple or list of Tensors with the
+      same structure as the 'x_tensors' input argument. Contains the result of
+ applying tf.where(condition, x, y) on each pair of elements in x_tensors
+ and y_tensors.
+ """
+ flat_x = nest.flatten(x_tensors)
+ flat_y = nest.flatten(y_tensors)
+ result = [tf.where(condition, x, y) for x, y in
+ itertools.izip(flat_x, flat_y)]
+
+ return nest.pack_sequence_as(x_tensors, result)
+
+
+def gather_tensors(tensors, indices):
+ """Performs a tf.gather operation on a set of Tensors.
+
+ Args:
+ tensors: A potentially nested tuple or list of Tensors.
+ indices: The indices to use for the gather operation.
+ Returns:
+ gathered_tensors: A potentially nested tuple or list of Tensors with the
+ same structure as the 'tensors' input argument. Contains the result of
+ applying tf.gather(x, indices) on each element x in 'tensors'.
+ """
+ return map_nested(lambda x: tf.gather(x, indices), tensors)
+
+
+def tas_for_tensors(tensors, length, **kwargs):
+ """Unstacks a set of Tensors into TensorArrays.
+
+ Args:
+ tensors: A potentially nested tuple or list of Tensors with length in the
+ first dimension greater than or equal to the 'length' input argument.
+ length: The desired length of the TensorArrays.
+ **kwargs: Keyword args for TensorArray constructor.
+ Returns:
+ tensorarrays: A potentially nested tuple or list of TensorArrays with the
+ same structure as 'tensors'. Contains the result of unstacking each Tensor
+ in 'tensors'.
+ """
+ def map_fn(x):
+ ta = tf.TensorArray(x.dtype, length,
+ name=x.name.split(':')[0] + '_ta', **kwargs)
+ return ta.unstack(x[:length, :])
+ return map_nested(map_fn, tensors)
+
+
+def read_tas(tas, index):
+ """Performs a read operation on a set of TensorArrays.
+
+ Args:
+ tas: A potentially nested tuple or list of TensorArrays with length greater
+ than 'index'.
+ index: The location to read from.
+ Returns:
+ read_tensors: A potentially nested tuple or list of Tensors with the same
+ structure as the 'tas' input argument. Contains the result of
+ performing a read operation at 'index' on each TensorArray in 'tas'.
+ """
+ return map_nested(lambda ta: ta.read(index), tas)
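+# Illustrative round trip (not part of the original module): unstacking a
+# nested structure of [5, 4] Tensors into TensorArrays and reading back a
+# single timestep.
+#
+#   tas = tas_for_tensors([x, (x, x)], 5)
+#   step_3 = read_tas(tas, 3)  # same structure; each element equals x[3]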
diff --git a/models/research/fivo/fivo/nested_utils_test.py b/models/research/fivo/fivo/nested_utils_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..87991dd79cdb29d12944f9afa3fd0c5178dc4eb5
--- /dev/null
+++ b/models/research/fivo/fivo/nested_utils_test.py
@@ -0,0 +1,125 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for fivo.nested_utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import tensorflow as tf
+nest = tf.contrib.framework.nest
+
+from fivo import nested_utils
+
+# An example namedtuple for use in the following tests.
+ExampleTuple = collections.namedtuple('ExampleTuple', ['a', 'b'])
+
+
+class NestedUtilsTest(tf.test.TestCase):
+
+ def test_map_nested_works_on_nested_structures(self):
+ """Check that map_nested works with nested structures."""
+ original = [1, (2, 3.2, (4., ExampleTuple(5, 6)))]
+ expected = [2, (3, 4.2, (5., ExampleTuple(6, 7)))]
+ out = nested_utils.map_nested(lambda x: x+1, original)
+ self.assertEqual(expected, out)
+
+ def test_map_nested_works_on_single_objects(self):
+ """Check that map_nested works with raw objects."""
+ original = 1
+ expected = 2
+ out = nested_utils.map_nested(lambda x: x+1, original)
+ self.assertEqual(expected, out)
+
+ def test_map_nested_works_on_flat_lists(self):
+ """Check that map_nested works with a flat list."""
+ original = [1, 2, 3]
+ expected = [2, 3, 4]
+ out = nested_utils.map_nested(lambda x: x+1, original)
+ self.assertEqual(expected, out)
+
+ def test_tile_tensors(self):
+ """Checks that tile_tensors correctly tiles tensors of different ranks."""
+ a = tf.range(20)
+ b = tf.reshape(a, [2, 10])
+ c = tf.reshape(a, [2, 2, 5])
+ a_tiled = tf.tile(a, [3])
+ b_tiled = tf.tile(b, [3, 1])
+ c_tiled = tf.tile(c, [3, 1, 1])
+ tensors = [a, (b, ExampleTuple(c, c))]
+ expected_tensors = [a_tiled, (b_tiled, ExampleTuple(c_tiled, c_tiled))]
+ tiled = nested_utils.tile_tensors(tensors, [3])
+ nest.assert_same_structure(expected_tensors, tiled)
+ with self.test_session() as sess:
+ expected, out = sess.run([expected_tensors, tiled])
+ expected = nest.flatten(expected)
+ out = nest.flatten(out)
+ # Check that the tiling is correct.
+ for x, y in zip(expected, out):
+ self.assertAllClose(x, y)
+
+ def test_gather_tensors(self):
+ a = tf.reshape(tf.range(20), [5, 4])
+ inds = [0, 0, 1, 4]
+ a_gathered = tf.gather(a, inds)
+ tensors = [a, (a, ExampleTuple(a, a))]
+ gt_gathered = [a_gathered, (a_gathered,
+ ExampleTuple(a_gathered, a_gathered))]
+ gathered = nested_utils.gather_tensors(tensors, inds)
+ nest.assert_same_structure(gt_gathered, gathered)
+ with self.test_session() as sess:
+ gt, out = sess.run([gt_gathered, gathered])
+ gt = nest.flatten(gt)
+ out = nest.flatten(out)
+ # Check that the gathering is correct.
+ for x, y in zip(gt, out):
+ self.assertAllClose(x, y)
+
+ def test_tas_for_tensors(self):
+ a = tf.reshape(tf.range(20), [5, 4])
+ tensors = [a, (a, ExampleTuple(a, a))]
+ tas = nested_utils.tas_for_tensors(tensors, 5)
+ nest.assert_same_structure(tensors, tas)
+    # We can't pass TensorArrays to sess.run so instead we turn them back into
+ # tensors to check that they were created correctly.
+ stacked = nested_utils.map_nested(lambda x: x.stack(), tas)
+ with self.test_session() as sess:
+ gt, out = sess.run([tensors, stacked])
+ gt = nest.flatten(gt)
+ out = nest.flatten(out)
+ # Check that the tas were created correctly.
+ for x, y in zip(gt, out):
+ self.assertAllClose(x, y)
+
+ def test_read_tas(self):
+ a = tf.reshape(tf.range(20), [5, 4])
+ a_read = a[3, :]
+ tensors = [a, (a, ExampleTuple(a, a))]
+ gt_read = [a_read, (a_read, ExampleTuple(a_read, a_read))]
+ tas = nested_utils.tas_for_tensors(tensors, 5)
+ tas_read = nested_utils.read_tas(tas, 3)
+ nest.assert_same_structure(tas, tas_read)
+ with self.test_session() as sess:
+ gt, out = sess.run([gt_read, tas_read])
+ gt = nest.flatten(gt)
+ out = nest.flatten(out)
+ # Check that the tas were read correctly.
+ for x, y in zip(gt, out):
+ self.assertAllClose(x, y)
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/fivo/fivo/runners.py b/models/research/fivo/fivo/runners.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec6fb91bf51fa2c7c44d7402e635d257f80c3f7a
--- /dev/null
+++ b/models/research/fivo/fivo/runners.py
@@ -0,0 +1,489 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""High-level code for creating and running FIVO-related Tensorflow graphs.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import os
+import time
+
+import numpy as np
+import tensorflow as tf
+
+from fivo import bounds
+from fivo import smc
+
+from fivo.data import datasets
+from fivo.models import base
+from fivo.models import srnn
+from fivo.models import vrnn
+
+
+def create_dataset_and_model(config, split, shuffle, repeat):
+ """Creates the dataset and model for a given config.
+
+ Args:
+ config: A configuration object with config values accessible as properties.
+ Most likely a FLAGS object. This function expects the properties
+ batch_size, dataset_path, dataset_type, and latent_size to be defined.
+ split: The dataset split to load.
+ shuffle: If true, shuffle the dataset randomly.
+ repeat: If true, repeat the dataset endlessly.
+ Returns:
+ inputs: A batch of input sequences represented as a dense Tensor of shape
+ [time, batch_size, data_dimension].
+ targets: A batch of target sequences represented as a dense Tensor of
+ shape [time, batch_size, data_dimension].
+ lens: An int Tensor of shape [batch_size] representing the lengths of each
+ sequence in the batch.
+    model: The created sequence model, e.g. a vrnn.TrainableVRNN object.
+ Raises:
+ ValueError: if the config is invalid.
+ """
+ sigma_min = 0.0
+ if config.dataset_type == "pianoroll":
+ inputs, targets, lengths, mean = datasets.create_pianoroll_dataset(
+ config.dataset_path, split, config.batch_size, shuffle=shuffle,
+ repeat=repeat)
+ # Convert the mean of the training set to logit space so it can be used to
+ # initialize the bias of the generative distribution.
+ emission_bias_init = -tf.log(
+ 1. / tf.clip_by_value(mean, 0.0001, 0.9999) - 1)
+ emission_distribution_class = base.ConditionalBernoulliDistribution
+ elif config.dataset_type == "speech":
+ inputs, targets, lengths = datasets.create_speech_dataset(
+ config.dataset_path, config.batch_size,
+ samples_per_timestep=config.data_dimension, prefetch_buffer_size=1,
+ shuffle=False, repeat=False)
+ # There is no bias for the generative distribution because the test set
+ # is assumed to be already standardized with the training set statistics.
+ mean = None
+ emission_bias_init = None
+ emission_distribution_class = base.ConditionalNormalDistribution
+ if config.model == "vrnn":
+ model = vrnn.create_vrnn(inputs.get_shape().as_list()[2],
+ config.latent_size,
+ emission_distribution_class,
+ emission_bias_init=emission_bias_init,
+ proposal_type=config.proposal_type,
+ sigma_min=sigma_min,
+ raw_sigma_bias=0.5,
+ use_tilt=(config.bound == "fivo-aux"))
+ elif config.model == "srnn":
+ model = srnn.create_srnn(inputs.get_shape().as_list()[2],
+ config.latent_size,
+ emission_distribution_class,
+ emission_bias_init=emission_bias_init,
+ proposal_type=config.proposal_type,
+ sigma_min=sigma_min,
+ raw_sigma_bias=0.5,
+ use_tilt=(config.bound == "fivo-aux"))
+ else:
+ raise ValueError("model flag: %s is unrecognized" % config.model)
+ return inputs, targets, lengths, model, mean
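+# Illustrative sketch (not part of the original module) of the config
+# attributes this function reads, e.g. when calling it without the fivo.py
+# FLAGS object:
+#
+#   class Config(object):
+#     dataset_type = "pianoroll"
+#     dataset_path = "/path/to/pianoroll.pkl"   # placeholder path
+#     batch_size = 4
+#     latent_size = 32
+#     data_dimension = None
+#     model = "vrnn"
+#     proposal_type = "filtering"
+#     bound = "fivo"
+#
+#   inputs, targets, lengths, model, mean = create_dataset_and_model(
+#       Config(), split="train", shuffle=True, repeat=True)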
+
+
+def restore_checkpoint_if_exists(saver, sess, logdir):
+ """Looks for a checkpoint and restores the session from it if found.
+
+ Args:
+ saver: A tf.train.Saver for restoring the session.
+ sess: A TensorFlow session.
+ logdir: The directory to look for checkpoints in.
+ Returns:
+ True if a checkpoint was found and restored, False otherwise.
+ """
+ checkpoint = tf.train.get_checkpoint_state(logdir)
+ if checkpoint:
+ checkpoint_name = os.path.basename(checkpoint.model_checkpoint_path)
+ full_checkpoint_path = os.path.join(logdir, checkpoint_name)
+ saver.restore(sess, full_checkpoint_path)
+ return True
+ return False
+
+
+def wait_for_checkpoint(saver, sess, logdir):
+ """Loops until the session is restored from a checkpoint in logdir.
+
+ Args:
+ saver: A tf.train.Saver for restoring the session.
+ sess: A TensorFlow session.
+ logdir: The directory to look for checkpoints in.
+ """
+ while not restore_checkpoint_if_exists(saver, sess, logdir):
+ tf.logging.info("Checkpoint not found in %s, sleeping for 60 seconds."
+ % logdir)
+ time.sleep(60)
+
+
+def run_train(config, create_dataset_and_model_fn=create_dataset_and_model):
+ """Runs training for a sequential latent variable model.
+
+ Args:
+ config: A configuration object with config values accessible as properties.
+ Most likely a FLAGS object. For a list of expected properties and their
+ meaning see the flags defined in fivo.py.
+ create_dataset_and_model_fn: If present, calls this function to create a
+ dataset and model instead of create_dataset_and_model() above. The
+ signature must be the same.
+ """
+
+ def create_logging_hook(step, bound_value):
+ """Creates a logging hook that prints the bound value periodically."""
+ bound_label = config.bound + " bound"
+ if config.normalize_by_seq_len:
+ bound_label += " per timestep"
+ else:
+ bound_label += " per sequence"
+ def summary_formatter(log_dict):
+ return "Step %d, %s: %f" % (
+ log_dict["step"], bound_label, log_dict["bound_value"])
+ logging_hook = tf.train.LoggingTensorHook(
+ {"step": step, "bound_value": bound_value},
+ every_n_iter=config.summarize_every,
+ formatter=summary_formatter)
+ return logging_hook
+
+ def create_loss():
+ """Creates the loss to be optimized.
+
+ Returns:
+ bound: A float Tensor containing the value of the bound that is
+ being optimized.
+ loss: A float Tensor that when differentiated yields the gradients
+ to apply to the model. Should be optimized via gradient descent.
+ """
+ inputs, targets, lengths, model, _ = create_dataset_and_model_fn(
+ config, split="train", shuffle=True, repeat=True)
+ # Compute lower bounds on the log likelihood.
+ if config.bound == "elbo":
+ ll_per_seq, _, _ = bounds.iwae(
+ model, (inputs, targets), lengths, num_samples=1,
+ parallel_iterations=config.parallel_iterations
+ )
+ elif config.bound == "iwae":
+ ll_per_seq, _, _ = bounds.iwae(
+ model, (inputs, targets), lengths, num_samples=config.num_samples,
+ parallel_iterations=config.parallel_iterations
+ )
+ elif config.bound in ("fivo", "fivo-aux"):
+ if config.resampling_type == "relaxed":
+ ll_per_seq, _, _, _ = bounds.fivo(
+ model, (inputs, targets),
+ lengths,
+ num_samples=config.num_samples,
+ resampling_criterion=smc.ess_criterion,
+ resampling_type=config.resampling_type,
+ random_seed=config.random_seed,
+ relaxed_resampling_temperature=config.
+ relaxed_resampling_temperature,
+ parallel_iterations=config.parallel_iterations
+ )
+ else:
+ ll_per_seq, _, _, _ = bounds.fivo(
+ model, (inputs, targets), lengths, num_samples=config.num_samples,
+ resampling_criterion=smc.ess_criterion,
+ resampling_type=config.resampling_type,
+ random_seed=config.random_seed,
+ parallel_iterations=config.parallel_iterations
+ )
+ # Compute loss scaled by number of timesteps.
+ ll_per_t = tf.reduce_mean(ll_per_seq / tf.to_float(lengths))
+ ll_per_seq = tf.reduce_mean(ll_per_seq)
+
+ tf.summary.scalar("train_ll_per_seq", ll_per_seq)
+ tf.summary.scalar("train_ll_per_t", ll_per_t)
+
+ if config.normalize_by_seq_len:
+ return ll_per_t, -ll_per_t
+ else:
+ return ll_per_seq, -ll_per_seq
+
+ def create_graph():
+ """Creates the training graph."""
+ global_step = tf.train.get_or_create_global_step()
+ bound, loss = create_loss()
+ opt = tf.train.AdamOptimizer(config.learning_rate)
+ grads = opt.compute_gradients(loss, var_list=tf.trainable_variables())
+ train_op = opt.apply_gradients(grads, global_step=global_step)
+ return bound, train_op, global_step
+
+ device = tf.train.replica_device_setter(ps_tasks=config.ps_tasks)
+ with tf.Graph().as_default():
+ if config.random_seed: tf.set_random_seed(config.random_seed)
+ with tf.device(device):
+ bound, train_op, global_step = create_graph()
+ log_hook = create_logging_hook(global_step, bound)
+ start_training = not config.stagger_workers
+ with tf.train.MonitoredTrainingSession(
+ master=config.master,
+ is_chief=config.task == 0,
+ hooks=[log_hook],
+ checkpoint_dir=config.logdir,
+ save_checkpoint_secs=120,
+ save_summaries_steps=config.summarize_every,
+ log_step_count_steps=config.summarize_every) as sess:
+ cur_step = -1
+ while not sess.should_stop() and cur_step <= config.max_steps:
+ if config.task > 0 and not start_training:
+ cur_step = sess.run(global_step)
+ tf.logging.info("task %d not active yet, sleeping at step %d" %
+ (config.task, cur_step))
+ time.sleep(30)
+ if cur_step >= config.task * 1000:
+ start_training = True
+ else:
+ _, cur_step = sess.run([train_op, global_step])
+
+
+def run_eval(config, create_dataset_and_model_fn=create_dataset_and_model):
+ """Runs evaluation for a sequential latent variable model.
+
+ This method runs only one evaluation over the dataset, writes summaries to
+ disk, and then terminates. It does not loop indefinitely.
+
+ Args:
+ config: A configuration object with config values accessible as properties.
+ Most likely a FLAGS object. For a list of expected properties and their
+ meaning see the flags defined in fivo.py.
+ create_dataset_and_model_fn: If present, calls this function to create a
+ dataset and model instead of create_dataset_and_model() above. The
+ signature must be the same.
+ """
+
+ def create_graph():
+ """Creates the evaluation graph.
+
+ Returns:
+ lower_bounds: A tuple of float Tensors containing the values of the 3
+ evidence lower bounds, summed across the batch.
+ total_batch_length: The total number of timesteps in the batch, summed
+ across batch examples.
+ batch_size: The batch size.
+ global_step: The global step the checkpoint was loaded from.
+ """
+ global_step = tf.train.get_or_create_global_step()
+ inputs, targets, lengths, model, _ = create_dataset_and_model_fn(
+ config, split=config.split, shuffle=False, repeat=False)
+ # Compute lower bounds on the log likelihood.
+ elbo_ll_per_seq, _, _ = bounds.iwae(
+ model, (inputs, targets), lengths, num_samples=1,
+ parallel_iterations=config.parallel_iterations
+ )
+ iwae_ll_per_seq, _, _ = bounds.iwae(
+ model, (inputs, targets), lengths, num_samples=config.num_samples,
+ parallel_iterations=config.parallel_iterations
+ )
+ # The resampling type should only be used for training, so we ignore it.
+ fivo_ll_per_seq, _, _, _ = bounds.fivo(
+ model, (inputs, targets), lengths, num_samples=config.num_samples,
+ resampling_criterion=smc.ess_criterion, random_seed=config.random_seed,
+ parallel_iterations=config.parallel_iterations
+ )
+ elbo_ll = tf.reduce_sum(elbo_ll_per_seq)
+ iwae_ll = tf.reduce_sum(iwae_ll_per_seq)
+ fivo_ll = tf.reduce_sum(fivo_ll_per_seq)
+ batch_size = tf.shape(lengths)[0]
+ total_batch_length = tf.reduce_sum(lengths)
+ return ((elbo_ll, iwae_ll, fivo_ll), total_batch_length, batch_size,
+ global_step)
+
+ def average_bounds_over_dataset(lower_bounds, total_batch_length, batch_size,
+ sess):
+ """Computes the values of the bounds, averaged over the datset.
+
+ Args:
+ lower_bounds: Tuple of float Tensors containing the values of the bounds
+ evaluated on a single batch.
+ total_batch_length: Integer Tensor that represents the total number of
+ timesteps in the current batch.
+ batch_size: Integer Tensor containing the batch size. This can vary if the
+ requested batch_size does not evenly divide the size of the dataset.
+ sess: A TensorFlow Session object.
+ Returns:
+ ll_per_t: A length 3 numpy array of floats containing each bound's average
+        value, normalized by the total number of timesteps in the dataset. Can
+ be interpreted as a lower bound on the average log likelihood per
+ timestep in the dataset.
+ ll_per_seq: A length 3 numpy array of floats containing each bound's
+ average value, normalized by the number of sequences in the dataset.
+ Can be interpreted as a lower bound on the average log likelihood per
+        sequence in the dataset.
+ """
+ total_ll = np.zeros(3, dtype=np.float64)
+ total_n_elems = 0.0
+ total_length = 0.0
+ while True:
+ try:
+ outs = sess.run([lower_bounds, batch_size, total_batch_length])
+ except tf.errors.OutOfRangeError:
+ break
+ total_ll += outs[0]
+ total_n_elems += outs[1]
+ total_length += outs[2]
+ ll_per_t = total_ll / total_length
+ ll_per_seq = total_ll / total_n_elems
+ return ll_per_t, ll_per_seq
+
+ def summarize_lls(lls_per_t, lls_per_seq, summary_writer, step):
+ """Creates log-likelihood lower bound summaries and writes them to disk.
+
+ Args:
+ lls_per_t: An array of 3 python floats, contains the values of the
+ evaluated bounds normalized by the number of timesteps.
+ lls_per_seq: An array of 3 python floats, contains the values of the
+ evaluated bounds normalized by the number of sequences.
+ summary_writer: A tf.SummaryWriter.
+ step: The current global step.
+ """
+ def scalar_summary(name, value):
+ value = tf.Summary.Value(tag=name, simple_value=value)
+ return tf.Summary(value=[value])
+
+ for i, bound in enumerate(["elbo", "iwae", "fivo"]):
+ per_t_summary = scalar_summary("%s/%s_ll_per_t" % (config.split, bound),
+ lls_per_t[i])
+ per_seq_summary = scalar_summary("%s/%s_ll_per_seq" %
+ (config.split, bound),
+ lls_per_seq[i])
+ summary_writer.add_summary(per_t_summary, global_step=step)
+ summary_writer.add_summary(per_seq_summary, global_step=step)
+ summary_writer.flush()
+
+ with tf.Graph().as_default():
+ if config.random_seed: tf.set_random_seed(config.random_seed)
+ lower_bounds, total_batch_length, batch_size, global_step = create_graph()
+ summary_dir = config.logdir + "/" + config.split
+ summary_writer = tf.summary.FileWriter(
+ summary_dir, flush_secs=15, max_queue=100)
+ saver = tf.train.Saver()
+ with tf.train.SingularMonitoredSession() as sess:
+ wait_for_checkpoint(saver, sess, config.logdir)
+ step = sess.run(global_step)
+ tf.logging.info("Model restored from step %d, evaluating." % step)
+ ll_per_t, ll_per_seq = average_bounds_over_dataset(
+ lower_bounds, total_batch_length, batch_size, sess)
+ summarize_lls(ll_per_t, ll_per_seq, summary_writer, step)
+ tf.logging.info("%s elbo ll/t: %f, iwae ll/t: %f fivo ll/t: %f",
+ config.split, ll_per_t[0], ll_per_t[1], ll_per_t[2])
+ tf.logging.info("%s elbo ll/seq: %f, iwae ll/seq: %f fivo ll/seq: %f",
+ config.split, ll_per_seq[0], ll_per_seq[1], ll_per_seq[2])
+
+
+def run_sample(config, create_dataset_and_model_fn=create_dataset_and_model):
+ """Sample from the model. Only pianorolls and pose datasets are supported."""
+
+ def sample_from_model(model, initial_state, initial_inputs, mean):
+ """Samples a sequence of outputs from the model.
+
+ The mean must be supplied -- if it isn't the results will be incorrect.
+
+ Args:
+ model: A model with sample_step implemented. See models/vrnn.py for an
+ example.
+ initial_state: The initial state of the model.
+ initial_inputs: The initial inputs to feed into the model.
+ mean: The mean of the training set, a Tensor of shape [data_dimension].
+ Returns:
+ samples: A Tensor of shape [sample_length, batch_size, num_timesteps,
+ data_dimension] containing the samples from the model.
+ """
+ initial_state, initial_output = model.sample_step(initial_state,
+ initial_inputs, 0)
+ output_ta = tf.TensorArray(size=config.sample_length,
+ dtype=tf.float32,
+ dynamic_size=False,
+ clear_after_read=True)
+ output_ta = output_ta.write(0, initial_output)
+ t0 = tf.constant(1, dtype=tf.int32)
+
+ def sample_step(t, state, prev_outputs, output_ta):
+ state, output = model.sample_step(state, prev_outputs, t)
+ output_ta = output_ta.write(t, output)
+ centered_output = output - mean[tf.newaxis, :]
+ return t+1, state, centered_output, output_ta
+
+ def sample_predicate(t, *unused_args):
+ return t < config.sample_length
+
+ _, _, _, output_ta = tf.while_loop(
+ sample_predicate,
+ sample_step,
+ loop_vars=(t0, initial_state, initial_output, output_ta),
+ parallel_iterations=config.parallel_iterations
+ )
+ samples = output_ta.stack()
+ samples = tf.reshape(samples, [config.sample_length, config.batch_size,
+ config.num_samples, config.data_dimension])
+ return samples
+
+ def create_graph():
+ """Creates the graph to sample from the model.
+
+ First, the model is conditioned on a prefix by sampling a batch of data
+ and trimming it to prefix_length. The configured bound is used to do the
+ conditioning. Then the final state from the conditioning is used to sample
+ from the model.
+
+ Returns:
+ samples: A Tensor of shape [sample_length, batch_size,
+ num_samples, data_dimension] representing samples from the model.
+ prefixes: A Tensor of shape [prefix_length, batch_size, data_dimension]
+ representing the prefixes the model was conditioned on.
+ """
+ inputs, targets, lengths, model, mean = create_dataset_and_model_fn(
+ config, split=config.split, shuffle=True, repeat=True)
+ input_prefixes = inputs[:config.prefix_length]
+ target_prefixes = targets[:config.prefix_length]
+ prefix_lengths = tf.ones_like(lengths) * config.prefix_length
+ if config.bound == "elbo":
+ _, _, state = bounds.iwae(
+ model, (input_prefixes, target_prefixes),
+ prefix_lengths, num_samples=1)
+ elif config.bound == "iwae":
+ _, _, state = bounds.iwae(
+ model, (input_prefixes, target_prefixes),
+ prefix_lengths, num_samples=config.num_samples)
+ elif config.bound == "fivo":
+ _, _, _, state = bounds.fivo(
+ model, (input_prefixes, target_prefixes), prefix_lengths,
+ num_samples=config.num_samples,
+ resampling_criterion=smc.ess_criterion,
+ random_seed=config.random_seed)
+ sample_inputs = tf.tile(inputs[config.prefix_length],
+ [config.num_samples, 1])
+ samples = sample_from_model(model, state, sample_inputs, mean)
+ return samples, target_prefixes
+
+ with tf.Graph().as_default():
+ if config.random_seed:
+ tf.set_random_seed(config.random_seed)
+ samples, prefixes = create_graph()
+ if config.sample_out_dir:
+      out_dir = config.sample_out_dir
+ else:
+ out_dir = config.logdir
+ if not tf.gfile.Exists(out_dir):
+ tf.gfile.MakeDirs(out_dir)
+ with tf.train.SingularMonitoredSession(
+ checkpoint_dir=config.logdir) as sess:
+ samples_out, prefixes_out = sess.run([samples, prefixes])
+ with tf.gfile.Open(os.path.join(out_dir, "samples.npz"), "w") as fout:
+ np.save(fout, {"prefixes": prefixes_out, "samples": samples_out})
diff --git a/models/research/fivo/fivo/runners_test.py b/models/research/fivo/fivo/runners_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb050c0a0b38b2511f3d2fb9ec846e63ead3b5ac
--- /dev/null
+++ b/models/research/fivo/fivo/runners_test.py
@@ -0,0 +1,242 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for fivo.runners"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import numpy as np
+import tensorflow as tf
+
+from fivo import runners
+from fivo.models import base
+from fivo.models import vrnn
+
+FLAGS = tf.app.flags.FLAGS
+
+
+class RunnersTest(tf.test.TestCase):
+
+ def default_config(self):
+ class Config(object):
+ pass
+ config = Config()
+ config.model = "vrnn"
+ config.latent_size = 64
+ config.batch_size = 4
+ config.num_samples = 4
+ config.resampling_type = "multinomial"
+ config.normalize_by_seq_len = True
+ config.learning_rate = 0.0001
+ config.max_steps = int(1e6)
+ config.summarize_every = 50
+ # Master must be "" to prevent state from persisting between sessions.
+ config.master = ""
+ config.task = 0
+ config.ps_tasks = 0
+ config.stagger_workers = True
+ config.random_seed = 1234
+ config.parallel_iterations = 1
+ config.dataset_type = "pianoroll"
+ config.data_dimension = None
+ config.dataset_path = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "test_data", "tiny_pianoroll.pkl")
+ config.proposal_type = "filtering"
+ return config
+
+ def run_training_one_step(self, bound, dataset_type, data_dimension,
+ dataset_filename, dir_prefix, resampling_type,
+ model, batch_size=2, num_samples=3,
+ create_dataset_and_model_fn=(runners.create_dataset_and_model)):
+ config = self.default_config()
+ config.model = model
+ config.resampling_type = resampling_type
+ config.relaxed_resampling_temperature = 0.5
+ config.bound = bound
+ config.split = "train"
+ config.dataset_type = dataset_type
+ config.dataset_path = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "test_data",
+ dataset_filename)
+ config.max_steps = 1
+ config.batch_size = batch_size
+ config.num_samples = num_samples
+ config.latent_size = 4
+ config.data_dimension = data_dimension
+ config.logdir = os.path.join(tf.test.get_temp_dir(), "%s-%s-%s-%s" %
+ (dir_prefix, bound, dataset_type, model))
+ runners.run_train(config,
+ create_dataset_and_model_fn=create_dataset_and_model_fn)
+ return config
+
+ def dummmy_dataset_and_model_fn(self, *unused_args, **unused_kwargs):
+ # We ignore the arguments in the dummy but need to preserve prototype.
+ batch_elements = 5
+ sequence_length = 4
+ data_dimensions = 3
+ dataset = tf.data.Dataset.from_tensors(
+ tf.zeros((sequence_length, batch_elements, data_dimensions),
+ dtype=tf.float32))
+ inputs = dataset.make_one_shot_iterator().get_next()
+ targets = tf.zeros_like(inputs)
+ lengths = tf.constant([sequence_length] * batch_elements)
+ mean = tf.constant((0.0, 0.0, 0.0))
+ model = vrnn.create_vrnn(data_dimensions, 1,
+ base.ConditionalNormalDistribution)
+ return inputs, targets, lengths, model, mean
+
+ def test_training_one_step_fivo_pianoroll_vrnn(self):
+ self.run_training_one_step("fivo", "pianoroll", 88, "tiny_pianoroll.pkl",
+ "test-training", "multinomial", "vrnn")
+
+ def test_training_one_step_iwae_pianoroll_vrnn(self):
+ self.run_training_one_step("iwae", "pianoroll", 88, "tiny_pianoroll.pkl",
+ "test-training", "multinomial", "vrnn")
+
+ def test_training_one_step_elbo_pianoroll_vrnn(self):
+ self.run_training_one_step("elbo", "pianoroll", 88, "tiny_pianoroll.pkl",
+ "test-training", "multinomial", "vrnn")
+
+ def test_training_one_step_fivo_speech_vrnn(self):
+ self.run_training_one_step("fivo", "speech", 2, "tiny_speech_dataset.tfrecord",
+ "test-training", "multinomial", "vrnn")
+
+ def test_training_one_step_iwae_speech_vrnn(self):
+ self.run_training_one_step("iwae", "speech", 2, "tiny_speech_dataset.tfrecord",
+ "test-training", "multinomial", "vrnn")
+
+ def test_training_one_step_elbo_speech_vrnn(self):
+ self.run_training_one_step("elbo", "speech", 2, "tiny_speech_dataset.tfrecord",
+ "test-training", "multinomial", "vrnn")
+
+ def test_training_one_step_fivo_pianoroll_srnn(self):
+ self.run_training_one_step("fivo", "pianoroll", 88, "tiny_pianoroll.pkl",
+ "test-training", "multinomial", "srnn")
+
+ def test_training_one_step_iwae_pianoroll_srnn(self):
+ self.run_training_one_step("iwae", "pianoroll", 88, "tiny_pianoroll.pkl",
+ "test-training", "multinomial", "srnn")
+
+ def test_training_one_step_elbo_pianoroll_srnn(self):
+ self.run_training_one_step("elbo", "pianoroll", 88, "tiny_pianoroll.pkl",
+ "test-training", "multinomial", "srnn")
+
+ def test_training_one_step_fivo_speech_srnn(self):
+ self.run_training_one_step("fivo", "speech", 2, "tiny_speech_dataset.tfrecord",
+ "test-training", "multinomial", "srnn")
+
+ def test_training_one_step_iwae_speech_srnn(self):
+ self.run_training_one_step("iwae", "speech", 2, "tiny_speech_dataset.tfrecord",
+ "test-training", "multinomial", "srnn")
+
+ def test_training_one_step_elbo_speech_srnn(self):
+ self.run_training_one_step("elbo", "speech", 2, "tiny_speech_dataset.tfrecord",
+ "test-training", "multinomial", "srnn")
+
+ def test_training_one_step_fivo_pianoroll_vrnn_relaxed(self):
+ self.run_training_one_step("fivo", "pianoroll", 88, "tiny_pianoroll.pkl",
+ "test-training", "relaxed", "vrnn")
+
+ def test_training_one_step_iwae_pianoroll_vrnn_relaxed(self):
+ self.run_training_one_step("iwae", "pianoroll", 88, "tiny_pianoroll.pkl",
+ "test-training", "relaxed", "vrnn")
+
+ def test_training_one_step_elbo_pianoroll_vrnn_relaxed(self):
+ self.run_training_one_step("elbo", "pianoroll", 88, "tiny_pianoroll.pkl",
+ "test-training", "relaxed", "vrnn")
+
+ def test_training_one_step_fivo_pianoroll_srnn_relaxed(self):
+ self.run_training_one_step("fivo", "pianoroll", 88, "tiny_pianoroll.pkl",
+ "test-training", "relaxed", "srnn")
+
+ def test_training_one_step_iwae_pianoroll_srnn_relaxed(self):
+ self.run_training_one_step("iwae", "pianoroll", 88, "tiny_pianoroll.pkl",
+ "test-training", "relaxed", "srnn")
+
+ def test_training_one_step_elbo_pianoroll_srnn_relaxed(self):
+ self.run_training_one_step("elbo", "pianoroll", 88, "tiny_pianoroll.pkl",
+ "test-training", "relaxed", "srnn")
+
+ def test_eval_vrnn(self):
+ self.run_eval("vrnn")
+
+ def test_eval_srnn(self):
+ self.run_eval("srnn")
+
+ def run_eval(self, model):
+ config = self.run_training_one_step(
+ "fivo", "pianoroll", 88, "tiny_pianoroll.pkl", "test-eval-" + model,
+ "multinomial", model)
+ config.split = "train"
+ runners.run_eval(config)
+
+ def test_sampling_vrnn(self):
+ self.run_sampling("vrnn")
+
+ def test_sampling_srnn(self):
+ self.run_sampling("srnn")
+
+ def run_sampling(self, model):
+ """Test sampling from the model."""
+ config = self.run_training_one_step(
+ "fivo", "pianoroll", 88, "tiny_pianoroll.pkl", "test-sampling", "multinomial",
+ model)
+ config.prefix_length = 3
+ config.sample_length = 6
+ config.split = "train"
+ config.sample_out_dir = None
+
+ runners.run_sample(config)
+ unused_samples = np.load(os.path.join(config.logdir, "samples.npz"))
+
+ def test_training_with_custom_fn(self):
+ self.run_training_one_step(
+ "fivo", "pianoroll", 3, "tiny_pianoroll.pkl",
+ "test-training-custom-fn", "multinomial", "vrnn", batch_size=5,
+ create_dataset_and_model_fn=self.dummmy_dataset_and_model_fn)
+
+ def test_eval_with_custom_fn(self):
+ config = self.run_training_one_step(
+ "fivo", "pianoroll", 1, "tiny_pianoroll.pkl",
+ "test-eval-custom-fn", "multinomial", "vrnn", batch_size=1,
+ create_dataset_and_model_fn=self.dummmy_dataset_and_model_fn)
+ config.split = "train"
+ runners.run_eval(
+ config,
+ create_dataset_and_model_fn=self.dummmy_dataset_and_model_fn)
+
+ def test_sampling_with_custom_fn(self):
+ config = self.run_training_one_step(
+ "fivo", "pianoroll", 3, "tiny_pianoroll.pkl",
+ "test-sample-custom-fn", "multinomial", "vrnn", batch_size=5,
+ create_dataset_and_model_fn=self.dummmy_dataset_and_model_fn)
+ config.prefix_length = 2
+ config.sample_length = 3
+ config.split = "train"
+ config.sample_out_dir = None
+
+ runners.run_sample(
+ config,
+ create_dataset_and_model_fn=self.dummmy_dataset_and_model_fn)
+ unused_samples = np.load(os.path.join(config.logdir, "samples.npz"))
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/research/fivo/fivo/smc.py b/models/research/fivo/fivo/smc.py
new file mode 100644
index 0000000000000000000000000000000000000000..25d4969043e2cb8bc2c2c7a3770d3d2dfcca0bef
--- /dev/null
+++ b/models/research/fivo/fivo/smc.py
@@ -0,0 +1,338 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Implementation of sequential Monte Carlo algorithms.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+import fivo.nested_utils as nested
+
+
+def ess_criterion(log_weights, unused_t):
+ """A criterion that resamples based on effective sample size."""
+ num_particles = tf.shape(log_weights)[0]
+ # Calculate the effective sample size.
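+  # ESS = (sum_i w_i)^2 / (sum_i w_i^2); in log space this is
+  # 2*logsumexp(log_w) - logsumexp(2*log_w). Resampling is triggered when the
+  # ESS drops to at most half the number of particles.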
+ ess_num = 2 * tf.reduce_logsumexp(log_weights, axis=0)
+ ess_denom = tf.reduce_logsumexp(2 * log_weights, axis=0)
+ log_ess = ess_num - ess_denom
+ return log_ess <= tf.log(tf.to_float(num_particles) / 2.0)
+
+
+def never_resample_criterion(log_weights, unused_t):
+ """A criterion that never resamples."""
+ batch_size = tf.shape(log_weights)[1]
+ return tf.cast(tf.zeros([batch_size]), tf.bool)
+
+
+def always_resample_criterion(log_weights, unused_t):
+ """A criterion resamples at every timestep."""
+ batch_size = tf.shape(log_weights)[1]
+ return tf.cast(tf.ones([batch_size]), tf.bool)
+
+
+def multinomial_resampling(log_weights, states, num_particles, batch_size,
+ random_seed=None):
+ """Resample states with multinomial resampling.
+
+ Args:
+ log_weights: A [num_particles, batch_size] Tensor representing a batch
+ of batch_size logits for num_particles-ary Categorical distribution.
+ states: A nested list of [batch_size*num_particles, data_size] Tensors that
+ will be resampled from the groups of every num_particles-th row.
+ num_particles: The number of particles/samples.
+ batch_size: The batch size.
+ random_seed: The random seed to pass to the resampling operations in
+ the particle filter. Mainly useful for testing.
+
+ Returns:
+ resampled_states: A nested list of [batch_size*num_particles, data_size]
+ Tensors resampled via multinomial sampling.
+ """
+ # Calculate the ancestor indices via resampling. Because we maintain the
+ # log unnormalized weights, we pass the weights in as logits, allowing
+ # the distribution object to apply a softmax and normalize them.
+ resampling_parameters = tf.transpose(log_weights, perm=[1, 0])
+ resampling_dist = tf.contrib.distributions.Categorical(
+ logits=resampling_parameters)
+ ancestors = tf.stop_gradient(
+ resampling_dist.sample(sample_shape=num_particles, seed=random_seed))
+
+ # Because the batch is flattened, we must modify ancestor_inds to index the
+ # proper samples. The particles in the ith filter are distributed every
+ # batch_size rows in the batch, and offset i rows from the top. So, to
+ # correct the indices we multiply by the batch_size and add the proper offset.
+ # Crucially, when ancestor_inds is flattened the layout of the batch is
+ # maintained.
+ offset = tf.expand_dims(tf.range(batch_size), 0)
+ ancestor_inds = tf.reshape(ancestors * batch_size + offset, [-1])
+
+ resampled_states = nested.gather_tensors(states, ancestor_inds)
+ return resampled_states
+
+
+def _blend_tensor(blending_weights, tensor, num_particles, batch_size):
+ """Blend tensor according to the weights.
+
+ The first dimension of tensor is actually a 2d index compacted to a 1d
+ index and similarly for blended_tensor. So if we index these Tensors
+ by [(i, j), k], then
+
+ blended_tensor[(i, j), k] =
+ sum_l tensor[(l, j), :] * blending_weights[i, j, l].
+
+ Args:
+ blending_weights: [num_particles, batch_size, num_particles] weights where
+ the indices represent [sample index, batch index, blending weight index].
+ tensor: [num_particles * batch_size, state_dim] Tensor to be blended.
+ num_particles: The number of particles/samples.
+ batch_size: The batch size.
+
+ Returns:
+ blended_tensor: [num_particles*batch_size, state_dim] blended Tensor.
+ """
+ # tensor is currently [num_particles * batch_size, state_dim], so we reshape
+ # it to [num_particles, batch_size, state_dim]. Then, transpose it to
+ # [batch_size, state_size, num_particles].
+ tensor = tf.transpose(
+ tf.reshape(tensor, [num_particles, batch_size, -1]), perm=[1, 2, 0])
+ blending_weights = tf.transpose(blending_weights, perm=[1, 2, 0])
+  # blending_weights is [batch index, blending weight index, sample index].
+ # Multiplying these gives a matrix of size [batch_size, state_size,
+ # num_particles].
+ tensor = tf.matmul(tensor, blending_weights)
+ # transpose the tensor to be [num_particles, batch_size, state_size]
+ # and then reshape it to match the original format.
+ tensor = tf.reshape(tf.transpose(tensor, perm=[2, 0, 1]),
+ [num_particles*batch_size, -1])
+ return tensor
+
+
+def relaxed_resampling(log_weights, states, num_particles, batch_size,
+ temperature=0.5, random_seed=None):
+ """Resample states with relaxed resampling.
+
+ Draw soft "ancestors" using the Gumbel-Softmax distribution.
+
+ Args:
+ log_weights: A [num_particles, batch_size] Tensor representing a batch
+ of batch_size logits for num_particles-ary Categorical distribution.
+ states: A nested list of [batch_size * num_particles, d] Tensors that will
+ be resampled from the groups of every num_particles-th row.
+ num_particles: The number of particles/samples.
+ batch_size: The batch size.
+ temperature: The temperature used for the relaxed one hot distribution.
+ random_seed: The random seed to pass to the resampling operations in
+ the particle filter. Mainly useful for testing.
+
+ Returns:
+ resampled_states: A nested list of [batch_size * num_particles, d]
+      Tensors resampled via relaxed resampling.
+ """
+ # log_weights are [num_particles, batch_size], so we transpose to get a
+ # set of batch_size distributions over [0, num_particles).
+ resampling_parameters = tf.transpose(log_weights, perm=[1, 0])
+ resampling_dist = tf.contrib.distributions.RelaxedOneHotCategorical(
+ temperature,
+ logits=resampling_parameters)
+
+ # Sample num_particles samples from the distribution, resulting in a
+ # [num_particles, batch_size, num_particles] Tensor that represents a set of
+ # [num_particles, batch_size] blending weights. The dimensions represent
+ # [particle index, batch index, blending weight index].
+ ancestors = resampling_dist.sample(sample_shape=num_particles,
+ seed=random_seed)
+ def map_fn(tensor):
+ return _blend_tensor(ancestors, tensor, num_particles, batch_size)
+
+ resampled_states = nested.map_nested(map_fn, states)
+ return resampled_states
+
+
+def smc(
+ transition_fn,
+ num_steps,
+ num_particles=1,
+ resampling_criterion=ess_criterion,
+ resampling_fn=multinomial_resampling,
+ loop_fn=None,
+ parallel_iterations=30,
+ swap_memory=True):
+ """Run a sequential Monte Carlo (SMC) algorithm.
+
+ This method runs an SMC algorithm that evolves systems of particles
+ using the supplied transition function for the specified number of steps. The
+ particles are optionally resampled using resampling_fn when indicated by
+ resampling_criterion.
+
+ Args:
+    transition_fn: A callable that propagates a batch of particles one step.
+      Must accept as arguments a batch of particle states and the current
+      timestep. Must return the incremental weights of each particle as a
+      [num_particles*batch_size] float Tensor, the particle states one timestep
+      in the future, and optionally a set of arguments to pass to the loop_fn.
+      If the loop args are not provided, they will be set to None. Before the
+      first timestep transition_fn will be called with the arguments None, -1
+      and should return the initial particle states.
+ num_steps: A [batch_size] Tensor of ints representing the number of steps
+ to run each filter for.
+ num_particles: A scalar int, the number of particles to use in each filter.
+ resampling_criterion: The resampling criterion to use for this particle
+ filter. Must accept the current log weights and timestep and
+ return a boolean Tensor of shape [batch_size] indicating whether each
+ particle filter should resample. See ess_criterion and related functions
+ for examples. When resampling_criterion is never_resample_criterion,
+ resampling_fn is ignored and never called.
+ resampling_fn: A callable that performs the resampling operation. Must
+ accept as arguments the log weights, particle states, num_particles,
+ and batch_size and return the resampled particle states. See
+ multinomial_resampling and relaxed_resampling for examples.
+ loop_fn: A callable that performs operations on the weights and
+ particle states, useful for accumulating and processing state that
+ shouldn't be resampled. At each timestep after (possibly) resampling
+ loop_fn will be called with the previous loop_state, a set of arguments
+ produced by transition_fn called loop_args, the resampled particle states,
+ the current log weights as [num_particles, batch_size] float Tensor, a
+ [batch_size] float Tensor representing whether or not each filter
+ resampled, the current mask indicating which filters are active, and the
+ current timestep. It must return the next loop state. Before the first
+      timestep loop_fn will be called with the arguments None, None, None, None,
+      None, None, -1 and must return the initial loop state. The loop state can
+      be a possibly nested structure of Tensors and TensorArrays.
+ parallel_iterations: The number of parallel iterations to use for the
+ internal while loop. Note that values greater than 1 can introduce
+ non-determinism even when resampling is deterministic.
+ swap_memory: Whether GPU-CPU memory swapping should be enabled for the
+ internal while loop.
+
+ Returns:
+ log_z_hat: A Tensor of shape [batch_size] containing an estimate of the log
+      normalizing constant that converts between the unnormalized target
+ distribution (as defined by the weights) and the true target distribution.
+ log_weights: A Tensor of shape [max_num_steps, batch_size, num_particles]
+ containing the log weights at each timestep of the particle filter.
+ Will not be valid for timesteps past the supplied num_steps.
+ resampled: A float Tensor of shape [max_num_steps, batch_size] indicating
+ when the particle filters resampled. Will be 1.0 on timesteps when
+ resampling occurred and 0.0 on timesteps when it did not.
+ final_loop_state: The final state returned by loop_fn. If loop_fn is None
+ then 0 will be returned.
+ """
+ # batch_size represents the number of particle filters running in parallel.
+ batch_size = tf.shape(num_steps)[0]
+ # Create a TensorArray where element t is the [num_particles*batch_size]
+ # sequence mask for timestep t.
+ max_num_steps = tf.reduce_max(num_steps)
+ seq_mask = tf.transpose(
+ tf.sequence_mask(num_steps, maxlen=max_num_steps, dtype=tf.float32),
+ perm=[1, 0])
+ seq_mask = tf.tile(seq_mask, [1, num_particles])
+ mask_ta = tf.TensorArray(seq_mask.dtype,
+ max_num_steps,
+ name='mask_ta')
+ mask_ta = mask_ta.unstack(seq_mask)
+ # Initialize the state.
+ t0 = tf.constant(0, tf.int32)
+ init_particle_state = transition_fn(None, -1)
+
+ def transition(*args):
+ transition_outs = transition_fn(*args)
+ if len(transition_outs) == 2:
+ return transition_outs + (None,)
+ else:
+ return transition_outs
+
+ if loop_fn is None:
+ loop_fn = lambda *args: 0
+
+ init_loop_state = loop_fn(None, None, None, None, None, None, -1)
+ init_states = (init_particle_state, init_loop_state)
+ ta_names = ['log_weights', 'resampled']
+ tas = [tf.TensorArray(tf.float32, max_num_steps, name='%s_ta' % n)
+ for n in ta_names]
+ log_weights_acc = tf.zeros([num_particles, batch_size], dtype=tf.float32)
+ log_z_hat_acc = tf.zeros([batch_size], dtype=tf.float32)
+
+ def while_predicate(t, *unused_args):
+ return t < max_num_steps
+
+ def while_step(t, state, tas, log_weights_acc, log_z_hat_acc):
+ """Implements one timestep of the particle filter."""
+ particle_state, loop_state = state
+ cur_mask = nested.read_tas(mask_ta, t)
+ # Propagate the particles one step.
+ log_alpha, new_particle_state, loop_args = transition(particle_state, t)
+ # Update the current weights with the incremental weights.
+ log_alpha *= cur_mask
+ log_alpha = tf.reshape(log_alpha, [num_particles, batch_size])
+ log_weights_acc += log_alpha
+
+ should_resample = resampling_criterion(log_weights_acc, t)
+
+ if resampling_criterion == never_resample_criterion:
+ resampled = tf.to_float(should_resample)
+ else:
+ # Compute the states as if we did resample.
+ resampled_states = resampling_fn(
+ log_weights_acc,
+ new_particle_state,
+ num_particles,
+ batch_size)
+ # Decide whether or not we should resample; don't resample if we are past
+ # the end of a sequence.
+ should_resample = tf.logical_and(should_resample,
+ cur_mask[:batch_size] > 0.)
+ float_should_resample = tf.to_float(should_resample)
+ new_particle_state = nested.where_tensors(
+ tf.tile(should_resample, [num_particles]),
+ resampled_states,
+ new_particle_state)
+ resampled = float_should_resample
+
+ new_loop_state = loop_fn(loop_state, loop_args, new_particle_state,
+ log_weights_acc, resampled, cur_mask, t)
+ # Update log Z hat.
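+    # On timesteps where a filter resamples (and unconditionally on the final
+    # timestep) the log mean weight is added to log_z_hat before the weights
+    # are reset; these increments telescope into the SMC estimate of the log
+    # normalizing constant.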
+ log_z_hat_update = tf.reduce_logsumexp(
+ log_weights_acc, axis=0) - tf.log(tf.to_float(num_particles))
+ # If it is the last timestep, always add the update.
+ log_z_hat_acc += tf.cond(t < max_num_steps - 1,
+ lambda: log_z_hat_update * resampled,
+ lambda: log_z_hat_update)
+ # Update the TensorArrays before we reset the weights so that we capture
+ # the incremental weights and not zeros.
+ ta_updates = [log_weights_acc, resampled]
+ new_tas = [ta.write(t, x) for ta, x in zip(tas, ta_updates)]
+ # For the particle filters that resampled, reset weights to zero.
+ log_weights_acc *= (1. - tf.tile(resampled[tf.newaxis, :],
+ [num_particles, 1]))
+ new_state = (new_particle_state, new_loop_state)
+ return t + 1, new_state, new_tas, log_weights_acc, log_z_hat_acc
+
+ _, final_state, tas, _, log_z_hat = tf.while_loop(
+ while_predicate,
+ while_step,
+ loop_vars=(t0, init_states, tas, log_weights_acc, log_z_hat_acc),
+ parallel_iterations=parallel_iterations,
+ swap_memory=swap_memory)
+
+ log_weights, resampled = [x.stack() for x in tas]
+ log_weights = tf.transpose(log_weights, perm=[0, 2, 1])
+ final_particle_state, final_loop_state = final_state
+ return (log_z_hat, log_weights, resampled,
+ final_particle_state, final_loop_state)
diff --git a/models/research/fivo/fivo/smc_test.py b/models/research/fivo/fivo/smc_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae32a62f21e037252bda44e3e1f47e007c9b7b9b
--- /dev/null
+++ b/models/research/fivo/fivo/smc_test.py
@@ -0,0 +1,241 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for fivo.smc."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import scipy
+import tensorflow as tf
+
+from fivo import smc
+
+lse = scipy.special.logsumexp
+
+
+def _simple_transition_fn(state, unused_t):
+ if state is None:
+ return tf.zeros([4], dtype=tf.float32)
+ return tf.constant([5., 4., 1., 0.5]), tf.zeros([4], dtype=tf.float32)
+
+
+def _resample_at_step_criterion(step):
+ """A criterion that resamples once at a specific timestep."""
+ def criterion(log_weights, t):
+ batch_size = tf.shape(log_weights)[1]
+ return tf.fill([batch_size], tf.equal(t, step))
+ return criterion
+
+
+class SMCTest(tf.test.TestCase):
+
+ def test_never_resampling(self):
+ """Test that never_resample_criterion makes smc not resample.
+
+ Also test that the weights and log_z_hat are computed correctly when never
+ resampling.
+ """
+ tf.set_random_seed(1234)
+ with self.test_session() as sess:
+ outs = smc.smc(
+ _simple_transition_fn,
+ num_steps=tf.convert_to_tensor([5, 3]),
+ num_particles=2,
+ resampling_criterion=smc.never_resample_criterion)
+ log_z_hat, weights, resampled = sess.run(outs[0:3])
+ gt_weights = np.array(
+ [[[5, 1], [4, .5]],
+ [[10, 2], [8, 1]],
+ [[15, 3], [12, 1.5]],
+ [[20, 4], [12, 1.5]],
+ [[25, 5], [12, 1.5]]],
+ dtype=np.float32)
+ gt_log_z_hat = np.array(
+ [lse([25, 5]) - np.log(2),
+ lse([12, 1.5]) - np.log(2)],
+ dtype=np.float32)
+ self.assertAllClose(gt_log_z_hat, log_z_hat)
+ self.assertAllClose(gt_weights, weights)
+ self.assertAllEqual(np.zeros_like(resampled), resampled)
+
+ def test_always_resampling(self):
+ """Test always_resample_criterion makes smc always resample.
+
+ Past a sequence end the filter should not resample, however.
+ Also check that weights and log_z_hat estimate are correct.
+ """
+ tf.set_random_seed(1234)
+ with self.test_session() as sess:
+ outs = smc.smc(
+ _simple_transition_fn,
+ num_steps=tf.convert_to_tensor([5, 3]),
+ num_particles=2,
+ resampling_criterion=smc.always_resample_criterion)
+ log_z_hat, weights, resampled = sess.run(outs[0:3])
+ gt_weights = np.array(
+ [[[5, 1], [4, .5]],
+ [[5, 1], [4, .5]],
+ [[5, 1], [4, .5]],
+ [[5, 1], [0., 0.]],
+ [[5, 1], [0., 0.]]],
+ dtype=np.float32)
+ gt_log_z_hat = np.array(
+ [5*lse([5, 1]) - 5*np.log(2),
+ 3*lse([4, .5]) - 3*np.log(2)],
+ dtype=np.float32)
+ gt_resampled = np.array(
+ [[1, 1], [1, 1], [1, 1], [1, 0], [1, 0]],
+ dtype=np.float32)
+ self.assertAllClose(gt_log_z_hat, log_z_hat)
+ self.assertAllClose(gt_weights, weights)
+ self.assertAllEqual(gt_resampled, resampled)
+
+ def test_weights_reset_when_resampling_at_sequence_end(self):
+ """Test that the weights are reset when resampling at the sequence end.
+
+ When resampling happens on the last timestep of a sequence the weights
+ should be set to zero on the next timestep and remain zero afterwards.
+ """
+ tf.set_random_seed(1234)
+ with self.test_session() as sess:
+ outs = smc.smc(
+ _simple_transition_fn,
+ num_steps=tf.convert_to_tensor([5, 3]),
+ num_particles=2,
+ resampling_criterion=_resample_at_step_criterion(2))
+ log_z_hat, weights, resampled = sess.run(outs[0:3])
+ gt_log_z = np.array(
+ [lse([15, 3]) + lse([10, 2]) - 2*np.log(2),
+ lse([12, 1.5]) - np.log(2)],
+ dtype=np.float32)
+ gt_resampled = np.array(
+ [[0, 0], [0, 0], [1, 1], [0, 0], [0, 0]],
+ dtype=np.float32)
+ gt_weights = np.array(
+ [[[5, 1], [4, .5]],
+ [[10, 2], [8, 1]],
+ [[15, 3], [12, 1.5]],
+ [[5, 1], [0, 0]],
+ [[10, 2], [0, 0]]],
+ dtype=np.float32)
+ self.assertAllClose(gt_log_z, log_z_hat)
+ self.assertAllEqual(gt_resampled, resampled)
+ self.assertAllEqual(gt_weights, weights)
+
+ def test_weights_not_updated_past_sequence_end(self):
+ """Test that non-zero weights are not updated past the end of a sequence."""
+ tf.set_random_seed(1234)
+ with self.test_session() as sess:
+ outs = smc.smc(
+ _simple_transition_fn,
+ num_steps=tf.convert_to_tensor([6, 4]),
+ num_particles=2,
+ resampling_criterion=_resample_at_step_criterion(1))
+ log_z_hat, weights, resampled = sess.run(outs[0:3])
+ gt_log_z_hat = np.array(
+ [lse([10, 2]) + lse([20, 4]) - 2*np.log(2),
+ lse([8, 1]) + lse([8, 1]) - 2*np.log(2)],
+ dtype=np.float32)
+ # Ensure that we only resample on the 2nd timestep.
+ gt_resampled = np.array(
+ [[0, 0], [1, 1], [0, 0], [0, 0], [0, 0], [0, 0]],
+ dtype=np.float32)
+ # Ensure that the weights after the end of the sequence don't change.
+ # Ensure that the weights after resampling before the end of the sequence
+ # do change.
+ gt_weights = np.array(
+ [[[5, 1], [4, .5]],
+ [[10, 2], [8, 1]],
+ [[5, 1], [4, .5]],
+ [[10, 2], [8, 1]],
+ [[15, 3], [8, 1]],
+ [[20, 4], [8, 1]]],
+ dtype=np.float32)
+ self.assertAllClose(gt_log_z_hat, log_z_hat)
+ self.assertAllEqual(gt_resampled, resampled)
+ self.assertAllEqual(gt_weights, weights)
+
+ def test_resampling_on_max_num_steps(self):
+ """Test that everything is correct when resampling on step max_num_steps.
+
+ When resampling on step max_num_steps (i.e. the last step of the longest
+ sequence), ensure that there are no off-by-one errors preventing resampling
+ and also that the weights are not updated.
+ """
+ tf.set_random_seed(1234)
+ with self.test_session() as sess:
+ outs = smc.smc(
+ _simple_transition_fn,
+ num_steps=tf.convert_to_tensor([4, 2]),
+ num_particles=2,
+ resampling_criterion=_resample_at_step_criterion(3))
+ log_z_hat, weights, resampled = sess.run(outs[0:3])
+ gt_log_z_hat = np.array(
+ [lse([20, 4]) - np.log(2),
+ lse([8, 1]) - np.log(2)],
+ dtype=np.float32)
+      # Ensure that we only resample on the final (4th) timestep and that the
+      # second filter doesn't resample at all because it is only run for 2 steps.
+ gt_resampled = np.array(
+ [[0, 0], [0, 0], [0, 0], [1, 0]],
+ dtype=np.float32)
+ gt_weights = np.array(
+ [[[5, 1], [4, .5]],
+ [[10, 2], [8, 1]],
+ [[15, 3], [8, 1]],
+ [[20, 4], [8, 1]]],
+ dtype=np.float32)
+ self.assertAllClose(gt_log_z_hat, log_z_hat)
+ self.assertAllEqual(gt_resampled, resampled)
+ self.assertAllEqual(gt_weights, weights)
+
+ def test_multinomial_resampling(self):
+ """Test that mulitnomial resampling selects the correct states."""
+ tf.set_random_seed(1234)
+ with self.test_session() as sess:
+ # Setup input.
+ inf = 1000.0 # Very large value in log space.
+ num_samples = 2
+ batch_size = 2
+ log_weights = tf.convert_to_tensor([[inf, 0], [0, inf]])
+ states = tf.convert_to_tensor([1, 2, 3, 4])
+ # Run test.
+ resampled_states = smc.multinomial_resampling(
+ log_weights, states, num_samples, batch_size, random_seed=0)
+ resampled_states_values = sess.run(resampled_states)
+ self.assertAllEqual(resampled_states_values, [1, 4, 1, 4])
+
+ def test_blend_tensor(self):
+ """Test that relaxed resampling blends the correct states."""
+ tf.set_random_seed(1234)
+ with self.test_session() as sess:
+ # Setup input.
+ num_samples = 2
+ batch_size = 2
+ blending_weights = tf.convert_to_tensor(
+ [[[0.5, 0.5], [0.25, 0.75]], [[0.75, 0.25], [0.5, 0.5]]])
+ states = tf.convert_to_tensor([4., 8., 12., 16.])
+ # Run test.
+ blended_states = smc._blend_tensor(blending_weights, states,
+ num_samples, batch_size)
+ blended_states_values = sess.run(blended_states)
+ self.assertAllClose(blended_states_values[:, 0], [8., 14., 6., 12.])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/fivo/fivo/test_data/tiny_pianoroll.pkl b/models/research/fivo/fivo/test_data/tiny_pianoroll.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..c2e4639da96ff5735576cd45dcccb5e0cd1cabec
Binary files /dev/null and b/models/research/fivo/fivo/test_data/tiny_pianoroll.pkl differ
diff --git a/models/research/fivo/fivo/test_data/tiny_speech_dataset.tfrecord b/models/research/fivo/fivo/test_data/tiny_speech_dataset.tfrecord
new file mode 100644
index 0000000000000000000000000000000000000000..93fe8791b631da35b9d03d37e6494cc7c50cb55d
Binary files /dev/null and b/models/research/fivo/fivo/test_data/tiny_speech_dataset.tfrecord differ
diff --git a/models/research/fivo/fivo/test_utils.py b/models/research/fivo/fivo/test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..48bbd3d483c45457b82b12ac1587d4c314b79f49
--- /dev/null
+++ b/models/research/fivo/fivo/test_utils.py
@@ -0,0 +1,144 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities for testing FIVO.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from fivo.models import base
+from fivo.models import srnn
+from fivo.models import vrnn
+
+
+def create_vrnn(generative_class=base.ConditionalNormalDistribution,
+ batch_size=2, data_size=3, rnn_hidden_size=4,
+ latent_size=5, fcnet_hidden_size=7, encoded_data_size=9,
+ encoded_latent_size=11, num_timesteps=7, data_lengths=(7, 4),
+ use_tilt=False, random_seed=None):
+ """Creates a VRNN and some dummy data to feed it for testing purposes.
+
+ Args:
+ generative_class: The class of the generative distribution.
+ batch_size: The number of elements per batch.
+ data_size: The dimension of the vectors that make up the data sequences.
+ rnn_hidden_size: The hidden state dimension of the RNN that forms the
+ deterministic part of this VRNN.
+ latent_size: The size of the stochastic latent state of the VRNN.
+ fcnet_hidden_size: The size of the hidden layer of the fully connected
+ networks that parameterize the conditional probability distributions
+ of the VRNN.
+ encoded_data_size: The size of the output of the data encoding network.
+ encoded_latent_size: The size of the output of the latent state encoding
+ network.
+ num_timesteps: The maximum number of timesteps in the data.
+ data_lengths: A tuple of size batch_size that contains the desired lengths
+ of each sequence in the dummy data.
+ use_tilt: Use a tilting function.
+ random_seed: A random seed to feed the VRNN, mainly useful for testing
+ purposes.
+
+ Returns:
+ model: A VRNN object.
+ inputs: A Tensor of shape [num_timesteps, batch_size, data_size], the inputs
+ to the model, also known as the observations.
+ targets: A Tensor of shape [num_timesteps, batch_size, data_size], the
+ desired outputs of the model.
+ lengths: A Tensor of shape [batch_size], the lengths of the sequences in the
+ batch.
+ """
+
+ fcnet_hidden_sizes = [fcnet_hidden_size]
+ initializers = {"w": tf.contrib.layers.xavier_initializer(seed=random_seed),
+ "b": tf.zeros_initializer()}
+ model = vrnn.create_vrnn(
+ data_size,
+ latent_size,
+ generative_class,
+ rnn_hidden_size=rnn_hidden_size,
+ fcnet_hidden_sizes=fcnet_hidden_sizes,
+ encoded_data_size=encoded_data_size,
+ encoded_latent_size=encoded_latent_size,
+ use_tilt=use_tilt,
+ initializers=initializers,
+ random_seed=random_seed)
+ inputs = tf.random_uniform([num_timesteps, batch_size, data_size],
+ seed=random_seed, dtype=tf.float32)
+ targets = tf.random_uniform([num_timesteps, batch_size, data_size],
+ seed=random_seed, dtype=tf.float32)
+ lengths = tf.constant(data_lengths, dtype=tf.int32)
+ return model, inputs, targets, lengths
+
+
+def create_srnn(generative_class=base.ConditionalNormalDistribution,
+ batch_size=2, data_size=3, rnn_hidden_size=4,
+ latent_size=5, fcnet_hidden_size=7, encoded_data_size=3,
+ encoded_latent_size=2, num_timesteps=7, data_lengths=(7, 4),
+ use_tilt=False, random_seed=None):
+ """Creates a SRNN and some dummy data to feed it for testing purposes.
+
+ Args:
+ generative_class: The class of the generative distribution.
+ batch_size: The number of elements per batch.
+ data_size: The dimension of the vectors that make up the data sequences.
+ rnn_hidden_size: The hidden state dimension of the RNN that forms the
+ deterministic part of this SRNN.
+ latent_size: The size of the stochastic latent state of the SRNN.
+ fcnet_hidden_size: The size of the hidden layer of the fully connected
+ networks that parameterize the conditional probability distributions
+ of the SRNN.
+ encoded_data_size: The size of the output of the data encoding network.
+ encoded_latent_size: The size of the output of the latent state encoding
+ network.
+ num_timesteps: The maximum number of timesteps in the data.
+ data_lengths: A tuple of size batch_size that contains the desired lengths
+ of each sequence in the dummy data.
+ use_tilt: Use a tilting function.
+ random_seed: A random seed to feed the SRNN, mainly useful for testing
+ purposes.
+
+ Returns:
+ model: A SRNN object.
+ inputs: A Tensor of shape [num_timesteps, batch_size, data_size], the inputs
+ to the model, also known as the observations.
+ targets: A Tensor of shape [num_timesteps, batch_size, data_size], the
+ desired outputs of the model.
+ lengths: A Tensor of shape [batch_size], the lengths of the sequences in the
+ batch.
+ """
+
+ fcnet_hidden_sizes = [fcnet_hidden_size]
+ initializers = {"w": tf.contrib.layers.xavier_initializer(seed=random_seed),
+ "b": tf.zeros_initializer()}
+ model = srnn.create_srnn(
+ data_size,
+ latent_size,
+ generative_class,
+ rnn_hidden_size=rnn_hidden_size,
+ fcnet_hidden_sizes=fcnet_hidden_sizes,
+ encoded_data_size=encoded_data_size,
+ encoded_latent_size=encoded_latent_size,
+ use_tilt=use_tilt,
+ initializers=initializers,
+ random_seed=random_seed)
+ inputs = tf.random_uniform([num_timesteps, batch_size, data_size],
+ seed=random_seed, dtype=tf.float32)
+ targets = tf.random_uniform([num_timesteps, batch_size, data_size],
+ seed=random_seed, dtype=tf.float32)
+ lengths = tf.constant(data_lengths, dtype=tf.int32)
+ return model, inputs, targets, lengths
diff --git a/models/research/fivo/run_fivo.py b/models/research/fivo/run_fivo.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ca079421f09fb65439dae210b1c3760240b51ad
--- /dev/null
+++ b/models/research/fivo/run_fivo.py
@@ -0,0 +1,142 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A script to run training for sequential latent variable models.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from fivo import ghmm_runners
+from fivo import runners
+
+# Shared flags.
+tf.app.flags.DEFINE_enum("mode", "train",
+ ["train", "eval", "sample"],
+ "The mode of the binary.")
+tf.app.flags.DEFINE_enum("model", "vrnn",
+ ["vrnn", "ghmm", "srnn"],
+ "Model choice.")
+tf.app.flags.DEFINE_integer("latent_size", 64,
+ "The size of the latent state of the model.")
+tf.app.flags.DEFINE_enum("dataset_type", "pianoroll",
+ ["pianoroll", "speech", "pose"],
+ "The type of dataset.")
+tf.app.flags.DEFINE_string("dataset_path", "",
+ "Path to load the dataset from.")
+tf.app.flags.DEFINE_integer("data_dimension", None,
+ "The dimension of each vector in the data sequence. "
+ "Defaults to 88 for pianoroll datasets and 200 for speech "
+ "datasets. Should not need to be changed except for "
+ "testing.")
+tf.app.flags.DEFINE_integer("batch_size", 4,
+ "Batch size.")
+tf.app.flags.DEFINE_integer("num_samples", 4,
+ "The number of samples (or particles) for multisample "
+ "algorithms.")
+tf.app.flags.DEFINE_string("logdir", "/tmp/smc_vi",
+ "The directory to keep checkpoints and summaries in.")
+tf.app.flags.DEFINE_integer("random_seed", None,
+ "A random seed for seeding the TensorFlow graph.")
+tf.app.flags.DEFINE_integer("parallel_iterations", 30,
+ "The number of parallel iterations to use for the while "
+ "loop that computes the bounds.")
+
+# Training flags.
+tf.app.flags.DEFINE_enum("bound", "fivo",
+ ["elbo", "iwae", "fivo", "fivo-aux"],
+ "The bound to optimize.")
+tf.app.flags.DEFINE_boolean("normalize_by_seq_len", True,
+ "If true, normalize the loss by the number of timesteps "
+ "per sequence.")
+tf.app.flags.DEFINE_float("learning_rate", 0.0002,
+ "The learning rate for ADAM.")
+tf.app.flags.DEFINE_integer("max_steps", int(1e9),
+ "The number of gradient update steps to train for.")
+tf.app.flags.DEFINE_integer("summarize_every", 50,
+ "The number of steps between summaries.")
+tf.app.flags.DEFINE_enum("resampling_type", "multinomial",
+ ["multinomial", "relaxed"],
+ "The resampling strategy to use for training.")
+tf.app.flags.DEFINE_float("relaxed_resampling_temperature", 0.5,
+ "The relaxation temperature for relaxed resampling.")
+tf.app.flags.DEFINE_enum("proposal_type", "filtering",
+ ["prior", "filtering", "smoothing",
+ "true-filtering", "true-smoothing"],
+ "The type of proposal to use. true-filtering and true-smoothing "
+ "are only available for the GHMM. The specific implementation "
+ "of each proposal type is left to model-writers.")
+
+# Distributed training flags.
+tf.app.flags.DEFINE_string("master", "",
+ "The BNS name of the TensorFlow master to use.")
+tf.app.flags.DEFINE_integer("task", 0,
+ "Task id of the replica running the training.")
+tf.app.flags.DEFINE_integer("ps_tasks", 0,
+ "Number of tasks in the ps job. If 0 no ps job is used.")
+tf.app.flags.DEFINE_boolean("stagger_workers", True,
+ "If true, bring one worker online every 1000 steps.")
+
+# Evaluation flags.
+tf.app.flags.DEFINE_enum("split", "train",
+ ["train", "test", "valid"],
+ "Split to evaluate the model on.")
+
+# Sampling flags.
+tf.app.flags.DEFINE_integer("sample_length", 50,
+ "The number of timesteps to sample for.")
+tf.app.flags.DEFINE_integer("prefix_length", 25,
+ "The number of timesteps to condition the model on "
+ "before sampling.")
+tf.app.flags.DEFINE_string("sample_out_dir", None,
+ "The directory to write the samples to. "
+ "Defaults to logdir.")
+
+# GHMM flags.
+tf.app.flags.DEFINE_float("variance", 0.1,
+ "The variance of the ghmm.")
+tf.app.flags.DEFINE_integer("num_timesteps", 5,
+                            "The number of timesteps to run the GHMM for.")
+FLAGS = tf.app.flags.FLAGS
+
+PIANOROLL_DEFAULT_DATA_DIMENSION = 88
+SPEECH_DEFAULT_DATA_DIMENSION = 200
+
+
+def main(unused_argv):
+ tf.logging.set_verbosity(tf.logging.INFO)
+ if FLAGS.model in ["vrnn", "srnn"]:
+ if FLAGS.data_dimension is None:
+ if FLAGS.dataset_type == "pianoroll":
+ FLAGS.data_dimension = PIANOROLL_DEFAULT_DATA_DIMENSION
+ elif FLAGS.dataset_type == "speech":
+ FLAGS.data_dimension = SPEECH_DEFAULT_DATA_DIMENSION
+ if FLAGS.mode == "train":
+ runners.run_train(FLAGS)
+ elif FLAGS.mode == "eval":
+ runners.run_eval(FLAGS)
+ elif FLAGS.mode == "sample":
+ runners.run_sample(FLAGS)
+ elif FLAGS.model == "ghmm":
+ if FLAGS.mode == "train":
+ ghmm_runners.run_train(FLAGS)
+ elif FLAGS.mode == "eval":
+ ghmm_runners.run_eval(FLAGS)
+
+if __name__ == "__main__":
+ tf.app.run(main)
diff --git a/models/research/global_objectives/README.md b/models/research/global_objectives/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f9a778c59d420f9bf5deccf4b2b45147636de582
--- /dev/null
+++ b/models/research/global_objectives/README.md
@@ -0,0 +1,152 @@
+
+
+
+
+# Global Objectives
+The Global Objectives library provides TensorFlow loss functions that optimize
+directly for a variety of objectives including AUC, recall at precision, and
+more. The global objectives losses can be used as drop-in replacements for
+TensorFlow's standard multilabel loss functions:
+`tf.nn.sigmoid_cross_entropy_with_logits` and `tf.losses.sigmoid_cross_entropy`.
+
+Many machine learning classification models are optimized for classification
+accuracy, while the real objective the user cares about is different, such as
+precision at a fixed recall, precision-recall AUC, ROC AUC, or similar metrics.
+These are referred to as "global objectives" because they depend on how the
+model classifies the dataset as a whole and do not decouple across data points
+as accuracy does.
+
+Because these objectives are combinatorial, discontinuous, and essentially
+intractable to optimize directly, the functions in this library approximate
+their corresponding objectives. This approximation approach follows the same
+pattern as optimizing for accuracy, where a surrogate objective such as
+cross-entropy or the hinge loss is used as an upper bound on the error rate.
+
+## Getting Started
+For a full example of how to use the loss functions in practice, see
+loss_layers_example.py.
+
+Briefly, global objective losses can be used to replace
+`tf.nn.sigmoid_cross_entropy_with_logits` by providing the relevant
+additional arguments. For example,
+
+``` python
+tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
+```
+
+could be replaced with
+
+``` python
+global_objectives.recall_at_precision_loss(
+ labels=labels,
+ logits=logits,
+ target_precision=0.95)[0]
+```
+
+Just as minimizing the cross-entropy loss will maximize accuracy, the loss
+functions in loss_layers.py were written so that minimizing the loss will
+maximize the corresponding objective.
+
+The global objective losses have two return values -- the loss tensor and
+additional quantities for debugging and customization -- which is why the first
+value is used above. For more information, see
+[Visualization & Debugging](#visualization-debugging).
+
+## Binary Label Format
+Binary classification problems can be represented as a multi-class problem with
+two classes, or as a multi-label problem with one label. (Recall that multiclass
+problems have mutually exclusive classes, e.g. 'cat xor dog', while multilabel
+problems have classes which are not mutually exclusive, e.g. an image can
+contain a cat, a dog, both, or neither.) The softmax loss
+(`tf.nn.softmax_cross_entropy_with_logits`) is used for multi-class problems,
+while the sigmoid loss (`tf.nn.sigmoid_cross_entropy_with_logits`) is used for
+multi-label problems.
+
+A multiclass label format for binary classification might represent positives
+with the label [1, 0] and negatives with the label [0, 1], while the multilabel
+format for the same problem would use [1] and [0], respectively.
+
+All global objectives loss functions assume that the multilabel format is used.
+Accordingly, if your current loss function is softmax, the labels will have to
+be reformatted for the loss to work properly.
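+
+As a minimal sketch (not part of the library), a two-class softmax label matrix
+can be collapsed to the multilabel format by keeping only the positive-class
+column:
+
+``` python
+import tensorflow as tf
+
+# Softmax-style binary labels: one row per example, columns are
+# [positive, negative].
+softmax_labels = tf.constant([[1., 0.], [0., 1.], [1., 0.]])
+# Multilabel format: a single column with 1 for positives and 0 for negatives.
+multilabel_labels = softmax_labels[:, 0:1]  # shape [batch_size, 1]
+```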
+
+## Dual Variables
+Global objectives losses (except for `roc_auc_loss`) use internal variables
+called dual variables or Lagrange multipliers to enforce the desired constraint
+(e.g. if optimizing for recall at precision, the constraint is on precision).
+
+These dual variables are created and initialized internally by the loss
+functions, and are updated during training by the same optimizer used for the
+model's other variables. To initialize the dual variables to a particular value,
+use the `lambdas_initializer` argument. The dual variables can be found under
+the key `lambdas` in the `other_outputs` dictionary returned by the losses.
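+
+For example (an illustrative sketch assuming `labels` and `logits` tensors
+already exist), the multipliers can be initialized to a custom value and
+inspected after the loss is built:
+
+``` python
+loss, other_outputs = global_objectives.recall_at_precision_loss(
+    labels=labels,
+    logits=logits,
+    target_precision=0.95,
+    lambdas_initializer=tf.constant_initializer(2.0))
+lambdas = other_outputs['lambdas']  # The dual variables (Lagrange multipliers).
+```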
+
+## Loss Function Arguments
+The following arguments are common to all loss functions in the library, and are
+either required or very important.
+
+* `labels`: Corresponds directly to the `labels` argument of
+ `tf.nn.sigmoid_cross_entropy_with_logits`.
+* `logits`: Corresponds directly to the `logits` argument of
+ `tf.nn.sigmoid_cross_entropy_with_logits`.
+* `dual_rate_factor`: A floating point value which controls the step size for
+ the Lagrange multipliers. Setting this value less than 1.0 will cause the
+ constraint to be enforced more gradually and will result in more stable
+ training.
+
+In addition, the objectives with a single constraint (e.g.
+`recall_at_precision_loss`) have an argument (e.g. `target_precision`) used to
+specify the value of the constraint. The optional `precision_range` argument to
+`precision_recall_auc_loss` is used to specify the range of precision values
+over which to optimize the AUC, and defaults to the interval [0, 1].
+
+Optional arguments:
+
+* `weights`: A tensor which acts as coefficients for the loss. If a weight of x
+ is provided for a datapoint and that datapoint is a true (false) positive
+ (negative), it will be counted as x true (false) positives (negatives).
+ Defaults to 1.0.
+* `label_priors`: A tensor specifying the fraction of positive datapoints for
+ each label. If not provided, it will be computed inside the loss function.
+* `surrogate_type`: Either 'xent' or 'hinge', specifying which upper bound
+ should be used for indicator functions.
+* `lambdas_initializer`: An initializer for the dual variables (Lagrange
+ multipliers). See also the Dual Variables section.
+* `num_anchors` (precision_recall_auc_loss only): The number of grid points used
+ when approximating the AUC as a Riemann sum.
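+
+Putting these together, an illustrative call (assuming `labels`, `logits`, and
+`example_weights` tensors already exist) might look like:
+
+``` python
+loss, other_outputs = global_objectives.precision_at_recall_loss(
+    labels=labels,
+    logits=logits,
+    target_recall=0.9,
+    weights=example_weights,
+    dual_rate_factor=0.1,
+    surrogate_type='hinge')
+train_loss = tf.reduce_mean(loss)
+```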
+
+## Hyperparameters
+While the functional form of the global objectives losses allows them to be
+easily substituted in place of `sigmoid_cross_entropy_with_logits`, model
+hyperparameters such as learning rate, weight decay, etc. may need to be
+fine-tuned to the new loss. Fortunately, the amount of hyperparameter re-tuning
+is usually minor.
+
+The most important hyperparameters to modify are the learning rate and
+dual_rate_factor (see the section on Loss Function Arguments, above).
+
+## Visualization & Debugging
+The global objectives losses return two values. The first is a tensor
+representing the numerical value of the loss, which can be passed to an
+optimizer. The second is a dictionary of tensors created by the loss function
+which are not necessary for optimization but useful in debugging. These vary
+depending on the loss function, but usually include `lambdas` (the Lagrange
+multipliers) as well as the lower bound on true positives and upper bound on
+false positives.
+
+When visualizing the loss during training, note that the global objectives
+losses differ from standard losses in some important ways:
+
+* The global losses may be negative. This is because the value returned by the
+ loss includes terms involving the Lagrange multipliers, which may be negative.
+* The global losses may not decrease over the course of training. To enforce the
+ constraints in the objective, the loss changes over time and may increase.
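+
+As a sketch, these quantities can be written to TensorBoard alongside the loss
+(assuming `loss` and `other_outputs` were returned by one of the loss
+functions above):
+
+``` python
+tf.summary.scalar('global_objectives_loss', tf.reduce_mean(loss))
+tf.summary.histogram('lambdas', other_outputs['lambdas'])
+tf.summary.scalar('true_positives_lower_bound',
+                  tf.reduce_sum(other_outputs['true_positives_lower_bound']))
+```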
+
+## More Info
+For more details, see the [Global Objectives paper](https://arxiv.org/abs/1608.04802).
+
+## Maintainers
+
+* Mariano Schain
+* Elad Eban
+* [Alan Mackey](https://github.com/mackeya-google)
diff --git a/models/research/global_objectives/loss_layers.py b/models/research/global_objectives/loss_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..eaea05398ef3771247060afda63be184ea76cdf0
--- /dev/null
+++ b/models/research/global_objectives/loss_layers.py
@@ -0,0 +1,930 @@
+# Copyright 2018 The TensorFlow Global Objectives Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Loss functions for learning global objectives.
+
+These functions have two return values: a Tensor with the value of
+the loss, and a dictionary of internal quantities for customizability.
+"""
+
+# Dependency imports
+import numpy
+import tensorflow as tf
+
+from global_objectives import util
+
+
+def precision_recall_auc_loss(
+ labels,
+ logits,
+ precision_range=(0.0, 1.0),
+ num_anchors=20,
+ weights=1.0,
+ dual_rate_factor=0.1,
+ label_priors=None,
+ surrogate_type='xent',
+ lambdas_initializer=tf.constant_initializer(1.0),
+ reuse=None,
+ variables_collections=None,
+ trainable=True,
+ scope=None):
+ """Computes precision-recall AUC loss.
+
+ The loss is based on a sum of losses for recall at a range of
+ precision values (anchor points). This sum is a Riemann sum that
+ approximates the area under the precision-recall curve.
+
+ The per-example `weights` argument changes not only the coefficients of
+ individual training examples, but how the examples are counted toward the
+ constraint. If `label_priors` is given, it MUST take `weights` into account.
+ That is,
+ label_priors = P / (P + N)
+ where
+ P = sum_i (wt_i on positives)
+ N = sum_i (wt_i on negatives).
+
+ Args:
+ labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
+ logits: A `Tensor` with the same shape as `labels`.
+ precision_range: A length-two tuple, the range of precision values over
+ which to compute AUC. The entries must be nonnegative, increasing, and
+ less than or equal to 1.0.
+ num_anchors: The number of grid points used to approximate the Riemann sum.
+ weights: Coefficients for the loss. Must be a scalar or `Tensor` of shape
+ [batch_size] or [batch_size, num_labels].
+ dual_rate_factor: A floating point value which controls the step size for
+ the Lagrange multipliers.
+ label_priors: None, or a floating point `Tensor` of shape [num_labels]
+ containing the prior probability of each label (i.e. the fraction of the
+ training data consisting of positive examples). If None, the label
+ priors are computed from `labels` with a moving average. See the notes
+ above regarding the interaction with `weights` and do not set this unless
+ you have a good reason to do so.
+ surrogate_type: Either 'xent' or 'hinge', specifying which upper bound
+ should be used for indicator functions.
+ lambdas_initializer: An initializer for the Lagrange multipliers.
+ reuse: Whether or not the layer and its variables should be reused. To be
+ able to reuse the layer scope must be given.
+ variables_collections: Optional list of collections for the variables.
+ trainable: If `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ scope: Optional scope for `variable_scope`.
+
+ Returns:
+ loss: A `Tensor` of the same shape as `logits` with the component-wise
+ loss.
+ other_outputs: A dictionary of useful internal quantities for debugging. For
+ more details, see http://arxiv.org/pdf/1608.04802.pdf.
+ lambdas: A Tensor of shape [1, num_labels, num_anchors] consisting of the
+ Lagrange multipliers.
+ biases: A Tensor of shape [1, num_labels, num_anchors] consisting of the
+        learned bias term for each label and anchor.
+ label_priors: A Tensor of shape [1, num_labels, 1] consisting of the prior
+ probability of each label learned by the loss, if not provided.
+ true_positives_lower_bound: Lower bound on the number of true positives
+ given `labels` and `logits`. This is the same lower bound which is used
+ in the loss expression to be optimized.
+ false_positives_upper_bound: Upper bound on the number of false positives
+ given `labels` and `logits`. This is the same upper bound which is used
+ in the loss expression to be optimized.
+
+ Raises:
+ ValueError: If `surrogate_type` is not `xent` or `hinge`.
+ """
+ with tf.variable_scope(scope,
+ 'precision_recall_auc',
+ [labels, logits, label_priors],
+ reuse=reuse):
+ labels, logits, weights, original_shape = _prepare_labels_logits_weights(
+ labels, logits, weights)
+ num_labels = util.get_num_labels(logits)
+
+ # Convert other inputs to tensors and standardize dtypes.
+ dual_rate_factor = util.convert_and_cast(
+ dual_rate_factor, 'dual_rate_factor', logits.dtype)
+
+ # Create Tensor of anchor points and distance between anchors.
+ precision_values, delta = _range_to_anchors_and_delta(
+ precision_range, num_anchors, logits.dtype)
+ # Create lambdas with shape [1, num_labels, num_anchors].
+ lambdas, lambdas_variable = _create_dual_variable(
+ 'lambdas',
+ shape=[1, num_labels, num_anchors],
+ dtype=logits.dtype,
+ initializer=lambdas_initializer,
+ collections=variables_collections,
+ trainable=trainable,
+ dual_rate_factor=dual_rate_factor)
+ # Create biases with shape [1, num_labels, num_anchors].
+ biases = tf.contrib.framework.model_variable(
+ name='biases',
+ shape=[1, num_labels, num_anchors],
+ dtype=logits.dtype,
+ initializer=tf.zeros_initializer(),
+ collections=variables_collections,
+ trainable=trainable)
+ # Maybe create label_priors.
+ label_priors = maybe_create_label_priors(
+ label_priors, labels, weights, variables_collections)
+ label_priors = tf.reshape(label_priors, [1, num_labels, 1])
+
+ # Expand logits, labels, and weights to shape [batch_size, num_labels, 1].
+ logits = tf.expand_dims(logits, 2)
+ labels = tf.expand_dims(labels, 2)
+ weights = tf.expand_dims(weights, 2)
+
+ # Calculate weighted loss and other outputs. The log(2.0) term corrects for
+ # logloss not being an upper bound on the indicator function.
+ loss = weights * util.weighted_surrogate_loss(
+ labels,
+ logits + biases,
+ surrogate_type=surrogate_type,
+ positive_weights=1.0 + lambdas * (1.0 - precision_values),
+ negative_weights=lambdas * precision_values)
+ maybe_log2 = tf.log(2.0) if surrogate_type == 'xent' else 1.0
+ maybe_log2 = tf.cast(maybe_log2, logits.dtype.base_dtype)
+ lambda_term = lambdas * (1.0 - precision_values) * label_priors * maybe_log2
+ per_anchor_loss = loss - lambda_term
+ per_label_loss = delta * tf.reduce_sum(per_anchor_loss, 2)
+ # Normalize the AUC such that a perfect score function will have AUC 1.0.
+ # Because precision_range is discretized into num_anchors + 1 intervals
+ # but only num_anchors terms are included in the Riemann sum, the
+ # effective length of the integration interval is `delta` less than the
+ # length of precision_range.
+ scaled_loss = tf.div(per_label_loss,
+ precision_range[1] - precision_range[0] - delta,
+ name='AUC_Normalize')
+ scaled_loss = tf.reshape(scaled_loss, original_shape)
+
+ other_outputs = {
+ 'lambdas': lambdas_variable,
+ 'biases': biases,
+ 'label_priors': label_priors,
+ 'true_positives_lower_bound': true_positives_lower_bound(
+ labels, logits, weights, surrogate_type),
+ 'false_positives_upper_bound': false_positives_upper_bound(
+ labels, logits, weights, surrogate_type)}
+
+ return scaled_loss, other_outputs
+
+
+def roc_auc_loss(
+ labels,
+ logits,
+ weights=1.0,
+ surrogate_type='xent',
+ scope=None):
+ """Computes ROC AUC loss.
+
+ The area under the ROC curve is the probability p that a randomly chosen
+ positive example will be scored higher than a randomly chosen negative
+ example. This loss approximates 1-p by using a surrogate (either hinge loss or
+ cross entropy) for the indicator function. Specifically, the loss is:
+
+ sum_i sum_j w_i*w_j*loss(logit_i - logit_j)
+
+ where i ranges over the positive datapoints, j ranges over the negative
+ datapoints, logit_k denotes the logit (or score) of the k-th datapoint, and
+ loss is either the hinge or log loss given a positive label.
+
+ Args:
+ labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
+ logits: A `Tensor` with the same shape and dtype as `labels`.
+ weights: Coefficients for the loss. Must be a scalar or `Tensor` of shape
+ [batch_size] or [batch_size, num_labels].
+ surrogate_type: Either 'xent' or 'hinge', specifying which upper bound
+ should be used for the indicator function.
+ scope: Optional scope for `name_scope`.
+
+ Returns:
+ loss: A `Tensor` of the same shape as `logits` with the component-wise loss.
+ other_outputs: An empty dictionary, for consistency.
+
+ Raises:
+ ValueError: If `surrogate_type` is not `xent` or `hinge`.
+ """
+ with tf.name_scope(scope, 'roc_auc', [labels, logits, weights]):
+ # Convert inputs to tensors and standardize dtypes.
+ labels, logits, weights, original_shape = _prepare_labels_logits_weights(
+ labels, logits, weights)
+
+ # Create tensors of pairwise differences for logits and labels, and
+ # pairwise products of weights. These have shape
+ # [batch_size, batch_size, num_labels].
+ logits_difference = tf.expand_dims(logits, 0) - tf.expand_dims(logits, 1)
+ labels_difference = tf.expand_dims(labels, 0) - tf.expand_dims(labels, 1)
+ weights_product = tf.expand_dims(weights, 0) * tf.expand_dims(weights, 1)
+
+ signed_logits_difference = labels_difference * logits_difference
+ raw_loss = util.weighted_surrogate_loss(
+ labels=tf.ones_like(signed_logits_difference),
+ logits=signed_logits_difference,
+ surrogate_type=surrogate_type)
+ weighted_loss = weights_product * raw_loss
+
+ # Zero out entries of the loss where labels_difference zero (so loss is only
+ # computed on pairs with different labels).
+ loss = tf.reduce_mean(tf.abs(labels_difference) * weighted_loss, 0) * 0.5
+ loss = tf.reshape(loss, original_shape)
+ return loss, {}
+
+
+def recall_at_precision_loss(
+ labels,
+ logits,
+ target_precision,
+ weights=1.0,
+ dual_rate_factor=0.1,
+ label_priors=None,
+ surrogate_type='xent',
+ lambdas_initializer=tf.constant_initializer(1.0),
+ reuse=None,
+ variables_collections=None,
+ trainable=True,
+ scope=None):
+ """Computes recall at precision loss.
+
+ The loss is based on a surrogate of the form
+ wt * w(+) * loss(+) + wt * w(-) * loss(-) - c * pi,
+ where:
+ - w(+) = 1 + lambdas * (1 - target_precision)
+ - loss(+) is the cross-entropy loss on the positive examples
+ - w(-) = lambdas * target_precision
+ - loss(-) is the cross-entropy loss on the negative examples
+ - wt is a scalar or tensor of per-example weights
+ - c = lambdas * (1 - target_precision)
+ - pi is the label_priors.
+
+ The per-example weights change not only the coefficients of individual
+ training examples, but how the examples are counted toward the constraint.
+ If `label_priors` is given, it MUST take `weights` into account. That is,
+ label_priors = P / (P + N)
+ where
+ P = sum_i (wt_i on positives)
+ N = sum_i (wt_i on negatives).
+
+ Args:
+ labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
+ logits: A `Tensor` with the same shape as `labels`.
+ target_precision: The precision at which to compute the loss. Can be a
+ floating point value between 0 and 1 for a single precision value, or a
+ `Tensor` of shape [num_labels], holding each label's target precision
+ value.
+ weights: Coefficients for the loss. Must be a scalar or `Tensor` of shape
+ [batch_size] or [batch_size, num_labels].
+ dual_rate_factor: A floating point value which controls the step size for
+ the Lagrange multipliers.
+ label_priors: None, or a floating point `Tensor` of shape [num_labels]
+ containing the prior probability of each label (i.e. the fraction of the
+ training data consisting of positive examples). If None, the label
+ priors are computed from `labels` with a moving average. See the notes
+ above regarding the interaction with `weights` and do not set this unless
+ you have a good reason to do so.
+ surrogate_type: Either 'xent' or 'hinge', specifying which upper bound
+ should be used for indicator functions.
+ lambdas_initializer: An initializer for the Lagrange multipliers.
+ reuse: Whether or not the layer and its variables should be reused. To be
+ able to reuse the layer scope must be given.
+ variables_collections: Optional list of collections for the variables.
+ trainable: If `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ scope: Optional scope for `variable_scope`.
+
+ Returns:
+ loss: A `Tensor` of the same shape as `logits` with the component-wise
+ loss.
+ other_outputs: A dictionary of useful internal quantities for debugging. For
+ more details, see http://arxiv.org/pdf/1608.04802.pdf.
+ lambdas: A Tensor of shape [num_labels] consisting of the Lagrange
+ multipliers.
+ label_priors: A Tensor of shape [num_labels] consisting of the prior
+ probability of each label learned by the loss, if not provided.
+ true_positives_lower_bound: Lower bound on the number of true positives
+ given `labels` and `logits`. This is the same lower bound which is used
+ in the loss expression to be optimized.
+ false_positives_upper_bound: Upper bound on the number of false positives
+ given `labels` and `logits`. This is the same upper bound which is used
+ in the loss expression to be optimized.
+
+ Raises:
+ ValueError: If `logits` and `labels` do not have the same shape.
+ """
+ with tf.variable_scope(scope,
+ 'recall_at_precision',
+ [logits, labels, label_priors],
+ reuse=reuse):
+ labels, logits, weights, original_shape = _prepare_labels_logits_weights(
+ labels, logits, weights)
+ num_labels = util.get_num_labels(logits)
+
+ # Convert other inputs to tensors and standardize dtypes.
+ target_precision = util.convert_and_cast(
+ target_precision, 'target_precision', logits.dtype)
+ dual_rate_factor = util.convert_and_cast(
+ dual_rate_factor, 'dual_rate_factor', logits.dtype)
+
+ # Create lambdas.
+ lambdas, lambdas_variable = _create_dual_variable(
+ 'lambdas',
+ shape=[num_labels],
+ dtype=logits.dtype,
+ initializer=lambdas_initializer,
+ collections=variables_collections,
+ trainable=trainable,
+ dual_rate_factor=dual_rate_factor)
+ # Maybe create label_priors.
+ label_priors = maybe_create_label_priors(
+ label_priors, labels, weights, variables_collections)
+
+ # Calculate weighted loss and other outputs. The log(2.0) term corrects for
+ # logloss not being an upper bound on the indicator function.
+ weighted_loss = weights * util.weighted_surrogate_loss(
+ labels,
+ logits,
+ surrogate_type=surrogate_type,
+ positive_weights=1.0 + lambdas * (1.0 - target_precision),
+ negative_weights=lambdas * target_precision)
+ maybe_log2 = tf.log(2.0) if surrogate_type == 'xent' else 1.0
+ maybe_log2 = tf.cast(maybe_log2, logits.dtype.base_dtype)
+ lambda_term = lambdas * (1.0 - target_precision) * label_priors * maybe_log2
+ loss = tf.reshape(weighted_loss - lambda_term, original_shape)
+ other_outputs = {
+ 'lambdas': lambdas_variable,
+ 'label_priors': label_priors,
+ 'true_positives_lower_bound': true_positives_lower_bound(
+ labels, logits, weights, surrogate_type),
+ 'false_positives_upper_bound': false_positives_upper_bound(
+ labels, logits, weights, surrogate_type)}
+
+ return loss, other_outputs
+
+
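+# A minimal usage sketch for recall_at_precision_loss above, assuming the same
+# TF1.x graph-mode setup as the rest of this file. The tensor arguments, the
+# 0.9 precision target, and the helper name are illustrative placeholders
+# rather than part of the library API.
+def _example_recall_at_precision_train_op(example_labels, example_logits,
+                                          learning_rate=0.1):
+  """Builds a train op maximizing recall subject to 90% precision (sketch)."""
+  loss, _ = recall_at_precision_loss(
+      example_labels, example_logits, target_precision=0.9)
+  # The loss comes back per example and per label; reduce it to a scalar
+  # before handing it to an optimizer. Minimizing the scalar also updates the
+  # internal Lagrange multipliers, whose gradients are reversed by design.
+  return tf.train.GradientDescentOptimizer(learning_rate).minimize(
+      tf.reduce_mean(loss))
+
+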
+def precision_at_recall_loss(
+ labels,
+ logits,
+ target_recall,
+ weights=1.0,
+ dual_rate_factor=0.1,
+ label_priors=None,
+ surrogate_type='xent',
+ lambdas_initializer=tf.constant_initializer(1.0),
+ reuse=None,
+ variables_collections=None,
+ trainable=True,
+ scope=None):
+ """Computes precision at recall loss.
+
+ The loss is based on a surrogate of the form
+ wt * loss(-) + lambdas * (pi * (b - 1) + wt * loss(+))
+ where:
+ - loss(-) is the cross-entropy loss on the negative examples
+ - loss(+) is the cross-entropy loss on the positive examples
+ - wt is a scalar or tensor of per-example weights
+ - b is the target recall
+ - pi is the label_priors.
+
+ The per-example weights change not only the coefficients of individual
+ training examples, but how the examples are counted toward the constraint.
+ If `label_priors` is given, it MUST take `weights` into account. That is,
+ label_priors = P / (P + N)
+ where
+ P = sum_i (wt_i on positives)
+ N = sum_i (wt_i on negatives).
+
+ Args:
+ labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
+ logits: A `Tensor` with the same shape as `labels`.
+ target_recall: The recall at which to compute the loss. Can be a floating
+ point value between 0 and 1 for a single target recall value, or a
+ `Tensor` of shape [num_labels] holding each label's target recall value.
+ weights: Coefficients for the loss. Must be a scalar or `Tensor` of shape
+ [batch_size] or [batch_size, num_labels].
+ dual_rate_factor: A floating point value which controls the step size for
+ the Lagrange multipliers.
+ label_priors: None, or a floating point `Tensor` of shape [num_labels]
+ containing the prior probability of each label (i.e. the fraction of the
+ training data consisting of positive examples). If None, the label
+ priors are computed from `labels` with a moving average. See the notes
+ above regarding the interaction with `weights` and do not set this unless
+ you have a good reason to do so.
+ surrogate_type: Either 'xent' or 'hinge', specifying which upper bound
+ should be used for indicator functions.
+ lambdas_initializer: An initializer for the Lagrange multipliers.
+ reuse: Whether or not the layer and its variables should be reused. To be
+      able to reuse the layer, `scope` must be given.
+ variables_collections: Optional list of collections for the variables.
+ trainable: If `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ scope: Optional scope for `variable_scope`.
+
+ Returns:
+ loss: A `Tensor` of the same shape as `logits` with the component-wise
+ loss.
+ other_outputs: A dictionary of useful internal quantities for debugging. For
+ more details, see http://arxiv.org/pdf/1608.04802.pdf.
+ lambdas: A Tensor of shape [num_labels] consisting of the Lagrange
+ multipliers.
+ label_priors: A Tensor of shape [num_labels] consisting of the prior
+ probability of each label learned by the loss, if not provided.
+ true_positives_lower_bound: Lower bound on the number of true positives
+ given `labels` and `logits`. This is the same lower bound which is used
+ in the loss expression to be optimized.
+ false_positives_upper_bound: Upper bound on the number of false positives
+ given `labels` and `logits`. This is the same upper bound which is used
+ in the loss expression to be optimized.
+ """
+ with tf.variable_scope(scope,
+ 'precision_at_recall',
+ [logits, labels, label_priors],
+ reuse=reuse):
+ labels, logits, weights, original_shape = _prepare_labels_logits_weights(
+ labels, logits, weights)
+ num_labels = util.get_num_labels(logits)
+
+ # Convert other inputs to tensors and standardize dtypes.
+ target_recall = util.convert_and_cast(
+ target_recall, 'target_recall', logits.dtype)
+ dual_rate_factor = util.convert_and_cast(
+ dual_rate_factor, 'dual_rate_factor', logits.dtype)
+
+ # Create lambdas.
+ lambdas, lambdas_variable = _create_dual_variable(
+ 'lambdas',
+ shape=[num_labels],
+ dtype=logits.dtype,
+ initializer=lambdas_initializer,
+ collections=variables_collections,
+ trainable=trainable,
+ dual_rate_factor=dual_rate_factor)
+ # Maybe create label_priors.
+ label_priors = maybe_create_label_priors(
+ label_priors, labels, weights, variables_collections)
+
+ # Calculate weighted loss and other outputs. The log(2.0) term corrects for
+ # logloss not being an upper bound on the indicator function.
+ weighted_loss = weights * util.weighted_surrogate_loss(
+ labels,
+ logits,
+ surrogate_type,
+ positive_weights=lambdas,
+ negative_weights=1.0)
+ maybe_log2 = tf.log(2.0) if surrogate_type == 'xent' else 1.0
+ maybe_log2 = tf.cast(maybe_log2, logits.dtype.base_dtype)
+ lambda_term = lambdas * label_priors * (target_recall - 1.0) * maybe_log2
+ loss = tf.reshape(weighted_loss + lambda_term, original_shape)
+ other_outputs = {
+ 'lambdas': lambdas_variable,
+ 'label_priors': label_priors,
+ 'true_positives_lower_bound': true_positives_lower_bound(
+ labels, logits, weights, surrogate_type),
+ 'false_positives_upper_bound': false_positives_upper_bound(
+ labels, logits, weights, surrogate_type)}
+
+ return loss, other_outputs
+
+
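+# A minimal sketch of the per-label form of precision_at_recall_loss above:
+# target_recall may be a [num_labels] tensor, so each label column is held to
+# its own recall target. The argument names and the three targets below are
+# illustrative only and assume float32 logits with three label columns.
+def _example_per_label_precision_at_recall(example_labels, example_logits):
+  """Builds the loss with per-label recall targets of 0.9, 0.95 and 0.8."""
+  per_label_targets = tf.constant([0.9, 0.95, 0.8])
+  loss, other_outputs = precision_at_recall_loss(
+      example_labels, example_logits, target_recall=per_label_targets)
+  # other_outputs['lambdas'] holds one Lagrange multiplier per label.
+  return tf.reduce_mean(loss), other_outputs['lambdas']
+
+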
+def false_positive_rate_at_true_positive_rate_loss(
+ labels,
+ logits,
+ target_rate,
+ weights=1.0,
+ dual_rate_factor=0.1,
+ label_priors=None,
+ surrogate_type='xent',
+ lambdas_initializer=tf.constant_initializer(1.0),
+ reuse=None,
+ variables_collections=None,
+ trainable=True,
+ scope=None):
+ """Computes false positive rate at true positive rate loss.
+
+ Note that `true positive rate` is a synonym for Recall, and that minimizing
+ the false positive rate and maximizing precision are equivalent for a fixed
+ Recall. Therefore, this function is identical to precision_at_recall_loss.
+
+ The per-example weights change not only the coefficients of individual
+ training examples, but how the examples are counted toward the constraint.
+ If `label_priors` is given, it MUST take `weights` into account. That is,
+ label_priors = P / (P + N)
+ where
+ P = sum_i (wt_i on positives)
+ N = sum_i (wt_i on negatives).
+
+ Args:
+ labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
+ logits: A `Tensor` with the same shape as `labels`.
+ target_rate: The true positive rate at which to compute the loss. Can be a
+ floating point value between 0 and 1 for a single true positive rate, or
+ a `Tensor` of shape [num_labels] holding each label's true positive rate.
+ weights: Coefficients for the loss. Must be a scalar or `Tensor` of shape
+ [batch_size] or [batch_size, num_labels].
+ dual_rate_factor: A floating point value which controls the step size for
+ the Lagrange multipliers.
+ label_priors: None, or a floating point `Tensor` of shape [num_labels]
+ containing the prior probability of each label (i.e. the fraction of the
+ training data consisting of positive examples). If None, the label
+ priors are computed from `labels` with a moving average. See the notes
+ above regarding the interaction with `weights` and do not set this unless
+ you have a good reason to do so.
+ surrogate_type: Either 'xent' or 'hinge', specifying which upper bound
+ should be used for indicator functions. 'xent' will use the cross-entropy
+ loss surrogate, and 'hinge' will use the hinge loss.
+ lambdas_initializer: An initializer op for the Lagrange multipliers.
+ reuse: Whether or not the layer and its variables should be reused. To be
+      able to reuse the layer, `scope` must be given.
+ variables_collections: Optional list of collections for the variables.
+ trainable: If `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ scope: Optional scope for `variable_scope`.
+
+ Returns:
+ loss: A `Tensor` of the same shape as `logits` with the component-wise
+ loss.
+ other_outputs: A dictionary of useful internal quantities for debugging. For
+ more details, see http://arxiv.org/pdf/1608.04802.pdf.
+ lambdas: A Tensor of shape [num_labels] consisting of the Lagrange
+ multipliers.
+ label_priors: A Tensor of shape [num_labels] consisting of the prior
+ probability of each label learned by the loss, if not provided.
+ true_positives_lower_bound: Lower bound on the number of true positives
+ given `labels` and `logits`. This is the same lower bound which is used
+ in the loss expression to be optimized.
+ false_positives_upper_bound: Upper bound on the number of false positives
+ given `labels` and `logits`. This is the same upper bound which is used
+ in the loss expression to be optimized.
+
+ Raises:
+ ValueError: If `surrogate_type` is not `xent` or `hinge`.
+ """
+ return precision_at_recall_loss(labels=labels,
+ logits=logits,
+ target_recall=target_rate,
+ weights=weights,
+ dual_rate_factor=dual_rate_factor,
+ label_priors=label_priors,
+ surrogate_type=surrogate_type,
+ lambdas_initializer=lambdas_initializer,
+ reuse=reuse,
+ variables_collections=variables_collections,
+ trainable=trainable,
+ scope=scope)
+
+
+def true_positive_rate_at_false_positive_rate_loss(
+ labels,
+ logits,
+ target_rate,
+ weights=1.0,
+ dual_rate_factor=0.1,
+ label_priors=None,
+ surrogate_type='xent',
+ lambdas_initializer=tf.constant_initializer(1.0),
+ reuse=None,
+ variables_collections=None,
+ trainable=True,
+ scope=None):
+ """Computes true positive rate at false positive rate loss.
+
+ The loss is based on a surrogate of the form
+ wt * loss(+) + lambdas * (wt * loss(-) - r * (1 - pi))
+ where:
+ - loss(-) is the loss on the negative examples
+ - loss(+) is the loss on the positive examples
+ - wt is a scalar or tensor of per-example weights
+ - r is the target rate
+ - pi is the label_priors.
+
+ The per-example weights change not only the coefficients of individual
+ training examples, but how the examples are counted toward the constraint.
+ If `label_priors` is given, it MUST take `weights` into account. That is,
+ label_priors = P / (P + N)
+ where
+ P = sum_i (wt_i on positives)
+ N = sum_i (wt_i on negatives).
+
+ Args:
+ labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
+ logits: A `Tensor` with the same shape as `labels`.
+ target_rate: The false positive rate at which to compute the loss. Can be a
+ floating point value between 0 and 1 for a single false positive rate, or
+ a `Tensor` of shape [num_labels] holding each label's false positive rate.
+ weights: Coefficients for the loss. Must be a scalar or `Tensor` of shape
+ [batch_size] or [batch_size, num_labels].
+ dual_rate_factor: A floating point value which controls the step size for
+ the Lagrange multipliers.
+ label_priors: None, or a floating point `Tensor` of shape [num_labels]
+ containing the prior probability of each label (i.e. the fraction of the
+ training data consisting of positive examples). If None, the label
+ priors are computed from `labels` with a moving average. See the notes
+ above regarding the interaction with `weights` and do not set this unless
+ you have a good reason to do so.
+ surrogate_type: Either 'xent' or 'hinge', specifying which upper bound
+ should be used for indicator functions. 'xent' will use the cross-entropy
+ loss surrogate, and 'hinge' will use the hinge loss.
+ lambdas_initializer: An initializer op for the Lagrange multipliers.
+ reuse: Whether or not the layer and its variables should be reused. To be
+      able to reuse the layer, `scope` must be given.
+ variables_collections: Optional list of collections for the variables.
+ trainable: If `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ scope: Optional scope for `variable_scope`.
+
+ Returns:
+ loss: A `Tensor` of the same shape as `logits` with the component-wise
+ loss.
+ other_outputs: A dictionary of useful internal quantities for debugging. For
+ more details, see http://arxiv.org/pdf/1608.04802.pdf.
+ lambdas: A Tensor of shape [num_labels] consisting of the Lagrange
+ multipliers.
+ label_priors: A Tensor of shape [num_labels] consisting of the prior
+ probability of each label learned by the loss, if not provided.
+ true_positives_lower_bound: Lower bound on the number of true positives
+ given `labels` and `logits`. This is the same lower bound which is used
+ in the loss expression to be optimized.
+ false_positives_upper_bound: Upper bound on the number of false positives
+ given `labels` and `logits`. This is the same upper bound which is used
+ in the loss expression to be optimized.
+
+ Raises:
+ ValueError: If `surrogate_type` is not `xent` or `hinge`.
+ """
+ with tf.variable_scope(scope,
+ 'tpr_at_fpr',
+ [labels, logits, label_priors],
+ reuse=reuse):
+ labels, logits, weights, original_shape = _prepare_labels_logits_weights(
+ labels, logits, weights)
+ num_labels = util.get_num_labels(logits)
+
+ # Convert other inputs to tensors and standardize dtypes.
+ target_rate = util.convert_and_cast(
+ target_rate, 'target_rate', logits.dtype)
+ dual_rate_factor = util.convert_and_cast(
+ dual_rate_factor, 'dual_rate_factor', logits.dtype)
+
+ # Create lambdas.
+ lambdas, lambdas_variable = _create_dual_variable(
+ 'lambdas',
+ shape=[num_labels],
+ dtype=logits.dtype,
+ initializer=lambdas_initializer,
+ collections=variables_collections,
+ trainable=trainable,
+ dual_rate_factor=dual_rate_factor)
+ # Maybe create label_priors.
+ label_priors = maybe_create_label_priors(
+ label_priors, labels, weights, variables_collections)
+
+ # Loss op and other outputs. The log(2.0) term corrects for
+ # logloss not being an upper bound on the indicator function.
+ weighted_loss = weights * util.weighted_surrogate_loss(
+ labels,
+ logits,
+ surrogate_type=surrogate_type,
+ positive_weights=1.0,
+ negative_weights=lambdas)
+ maybe_log2 = tf.log(2.0) if surrogate_type == 'xent' else 1.0
+ maybe_log2 = tf.cast(maybe_log2, logits.dtype.base_dtype)
+ lambda_term = lambdas * target_rate * (1.0 - label_priors) * maybe_log2
+ loss = tf.reshape(weighted_loss - lambda_term, original_shape)
+ other_outputs = {
+ 'lambdas': lambdas_variable,
+ 'label_priors': label_priors,
+ 'true_positives_lower_bound': true_positives_lower_bound(
+ labels, logits, weights, surrogate_type),
+ 'false_positives_upper_bound': false_positives_upper_bound(
+ labels, logits, weights, surrogate_type)}
+
+ return loss, other_outputs
+
+
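+# A minimal sketch of inspecting the constraint behind
+# true_positive_rate_at_false_positive_rate_loss above: alongside the loss,
+# the returned dictionary exposes the false-positive upper bound that the
+# Lagrange multiplier pushes below the target rate. The argument names and
+# the 0.05 target below are illustrative only.
+def _example_tpr_at_fpr_diagnostics(example_labels, example_logits):
+  """Returns the scalar loss and the per-label false-positive upper bound."""
+  loss, other_outputs = true_positive_rate_at_false_positive_rate_loss(
+      example_labels, example_logits, target_rate=0.05)
+  return tf.reduce_mean(loss), other_outputs['false_positives_upper_bound']
+
+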
+def _prepare_labels_logits_weights(labels, logits, weights):
+ """Validates labels, logits, and weights.
+
+ Converts inputs to tensors, checks shape compatibility, and casts dtype if
+ necessary.
+
+ Args:
+ labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
+ logits: A `Tensor` with the same shape as `labels`.
+ weights: Either `None` or a `Tensor` with shape broadcastable to `logits`.
+
+ Returns:
+ labels: Same as `labels` arg after possible conversion to tensor, cast, and
+ reshape.
+ logits: Same as `logits` arg after possible conversion to tensor and
+ reshape.
+ weights: Same as `weights` arg after possible conversion, cast, and reshape.
+ original_shape: Shape of `labels` and `logits` before reshape.
+
+ Raises:
+ ValueError: If `labels` and `logits` do not have the same shape.
+ """
+ # Convert `labels` and `logits` to Tensors and standardize dtypes.
+ logits = tf.convert_to_tensor(logits, name='logits')
+ labels = util.convert_and_cast(labels, 'labels', logits.dtype.base_dtype)
+ weights = util.convert_and_cast(weights, 'weights', logits.dtype.base_dtype)
+
+ try:
+ labels.get_shape().merge_with(logits.get_shape())
+ except ValueError:
+ raise ValueError('logits and labels must have the same shape (%s vs %s)' %
+ (logits.get_shape(), labels.get_shape()))
+
+ original_shape = labels.get_shape().as_list()
+ if labels.get_shape().ndims > 0:
+ original_shape[0] = -1
+ if labels.get_shape().ndims <= 1:
+ labels = tf.reshape(labels, [-1, 1])
+ logits = tf.reshape(logits, [-1, 1])
+
+ if weights.get_shape().ndims == 1:
+ # Weights has shape [batch_size]. Reshape to [batch_size, 1].
+ weights = tf.reshape(weights, [-1, 1])
+ if weights.get_shape().ndims == 0:
+ # Weights is a scalar. Change shape of weights to match logits.
+ weights *= tf.ones_like(logits)
+
+ return labels, logits, weights, original_shape
+
+
+def _range_to_anchors_and_delta(precision_range, num_anchors, dtype):
+ """Calculates anchor points from precision range.
+
+ Args:
+ precision_range: As required in precision_recall_auc_loss.
+ num_anchors: int, number of equally spaced anchor points.
+ dtype: Data type of returned tensors.
+
+ Returns:
+ precision_values: A `Tensor` of data type dtype with equally spaced values
+ in the interval precision_range.
+ delta: The spacing between the values in precision_values.
+
+ Raises:
+ ValueError: If precision_range is invalid.
+ """
+ # Validate precision_range.
+ if not 0 <= precision_range[0] <= precision_range[-1] <= 1:
+ raise ValueError('precision values must obey 0 <= %f <= %f <= 1' %
+ (precision_range[0], precision_range[-1]))
+ if not 0 < len(precision_range) < 3:
+ raise ValueError('length of precision_range (%d) must be 1 or 2' %
+ len(precision_range))
+
+ # Sets precision_values uniformly between min_precision and max_precision.
+ values = numpy.linspace(start=precision_range[0],
+ stop=precision_range[1],
+ num=num_anchors+2)[1:-1]
+ precision_values = util.convert_and_cast(
+ values, 'precision_values', dtype)
+ delta = util.convert_and_cast(
+ values[0] - precision_range[0], 'delta', dtype)
+ # Makes precision_values [1, 1, num_anchors].
+ precision_values = util.expand_outer(precision_values, 3)
+ return precision_values, delta
+
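+# Worked example of the anchor construction above: with
+# precision_range = (0.0, 1.0) and num_anchors = 3,
+# numpy.linspace(0.0, 1.0, num=5)[1:-1] yields [0.25, 0.5, 0.75], so the
+# anchors are equally spaced strictly inside the requested range and
+# delta = 0.25 - 0.0 = 0.25 is the spacing between them.
+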
+
+def _create_dual_variable(name, shape, dtype, initializer, collections,
+ trainable, dual_rate_factor):
+ """Creates a new dual variable.
+
+ Dual variables are required to be nonnegative. If trainable, their gradient
+ is reversed so that they are maximized (rather than minimized) by the
+ optimizer.
+
+ Args:
+ name: A string, the name for the new variable.
+ shape: Shape of the new variable.
+ dtype: Data type for the new variable.
+ initializer: Initializer for the new variable.
+ collections: List of graph collections keys. The new variable is added to
+ these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+ trainable: If `True`, the default, also adds the variable to the graph
+ collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
+ the default list of variables to use by the `Optimizer` classes.
+ dual_rate_factor: A floating point value or `Tensor`. The learning rate for
+ the dual variable is scaled by this factor.
+
+ Returns:
+ dual_value: An op that computes the absolute value of the dual variable
+ and reverses its gradient.
+ dual_variable: The underlying variable itself.
+ """
+ # We disable partitioning while constructing dual variables because they will
+ # be updated with assign, which is not available for partitioned variables.
+ partitioner = tf.get_variable_scope().partitioner
+ try:
+ tf.get_variable_scope().set_partitioner(None)
+ dual_variable = tf.contrib.framework.model_variable(
+ name=name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ collections=collections,
+ trainable=trainable)
+ finally:
+ tf.get_variable_scope().set_partitioner(partitioner)
+ # Using the absolute value enforces nonnegativity.
+ dual_value = tf.abs(dual_variable)
+
+ if trainable:
+ # To reverse the gradient on the dual variable, multiply the gradient by
+ # -dual_rate_factor
+ dual_value = (tf.stop_gradient((1.0 + dual_rate_factor) * dual_value)
+ - dual_rate_factor * dual_value)
+ return dual_value, dual_variable
+
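+# Worked sketch of the gradient-reversal trick above: writing
+# r = dual_rate_factor and v = tf.abs(dual_variable), the returned value is
+#   stop_gradient((1 + r) * v) - r * v,
+# whose forward value is (1 + r) * v - r * v = v, while its gradient with
+# respect to v is -r (the stop_gradient term contributes nothing). A
+# minimizing optimizer therefore takes gradient-ascent steps on the dual
+# variable, scaled by dual_rate_factor, which is exactly the saddle-point
+# update the Lagrangian formulation requires.
+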
+
+def maybe_create_label_priors(label_priors,
+ labels,
+ weights,
+ variables_collections):
+ """Creates moving average ops to track label priors, if necessary.
+
+ Args:
+ label_priors: As required in e.g. precision_recall_auc_loss.
+ labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
+ weights: As required in e.g. precision_recall_auc_loss.
+ variables_collections: Optional list of collections for the variables, if
+ any must be created.
+
+ Returns:
+ label_priors: A Tensor of shape [num_labels] consisting of the
+ weighted label priors, after updating with moving average ops if created.
+ """
+ if label_priors is not None:
+ label_priors = util.convert_and_cast(
+ label_priors, name='label_priors', dtype=labels.dtype.base_dtype)
+ return tf.squeeze(label_priors)
+
+ label_priors = util.build_label_priors(
+ labels,
+ weights,
+ variables_collections=variables_collections)
+ return label_priors
+
+
+def true_positives_lower_bound(labels, logits, weights, surrogate_type):
+ """Calculate a lower bound on the number of true positives.
+
+ This lower bound on the number of true positives given `logits` and `labels`
+ is the same one used in the global objectives loss functions.
+
+ Args:
+ labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
+ logits: A `Tensor` of shape [batch_size, num_labels] or
+ [batch_size, num_labels, num_anchors]. If the third dimension is present,
+ the lower bound is computed on each slice [:, :, k] independently.
+ weights: Per-example loss coefficients, with shape broadcast-compatible with
+ that of `labels`.
+ surrogate_type: Either 'xent' or 'hinge', specifying which upper bound
+ should be used for indicator functions.
+
+ Returns:
+ A `Tensor` of shape [num_labels] or [num_labels, num_anchors].
+ """
+ maybe_log2 = tf.log(2.0) if surrogate_type == 'xent' else 1.0
+ maybe_log2 = tf.cast(maybe_log2, logits.dtype.base_dtype)
+ if logits.get_shape().ndims == 3 and labels.get_shape().ndims < 3:
+ labels = tf.expand_dims(labels, 2)
+ loss_on_positives = util.weighted_surrogate_loss(
+ labels, logits, surrogate_type, negative_weights=0.0) / maybe_log2
+ return tf.reduce_sum(weights * (labels - loss_on_positives), 0)
+
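+# Why the quantity above is a lower bound: for each positive example, the
+# positive-only surrogate loss (divided by log(2) for the 'xent' surrogate)
+# upper-bounds the 0/1 indicator of that example being scored negative, so
+# labels - loss_on_positives lower-bounds labels * 1[scored positive]
+# elementwise. Multiplying by the weights and summing over the batch then
+# lower-bounds the weighted true positive count, label by label.
+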
+
+def false_positives_upper_bound(labels, logits, weights, surrogate_type):
+ """Calculate an upper bound on the number of false positives.
+
+ This upper bound on the number of false positives given `logits` and `labels`
+ is the same one used in the global objectives loss functions.
+
+ Args:
+ labels: A `Tensor` of shape [batch_size, num_labels]
+ logits: A `Tensor` of shape [batch_size, num_labels] or
+ [batch_size, num_labels, num_anchors]. If the third dimension is present,
+ the lower bound is computed on each slice [:, :, k] independently.
+ weights: Per-example loss coefficients, with shape broadcast-compatible with
+ that of `labels`.
+ surrogate_type: Either 'xent' or 'hinge', specifying which upper bound
+ should be used for indicator functions.
+
+ Returns:
+ A `Tensor` of shape [num_labels] or [num_labels, num_anchors].
+ """
+ maybe_log2 = tf.log(2.0) if surrogate_type == 'xent' else 1.0
+ maybe_log2 = tf.cast(maybe_log2, logits.dtype.base_dtype)
+ loss_on_negatives = util.weighted_surrogate_loss(
+ labels, logits, surrogate_type, positive_weights=0.0) / maybe_log2
+ return tf.reduce_sum(weights * loss_on_negatives, 0)
diff --git a/models/research/global_objectives/loss_layers_example.py b/models/research/global_objectives/loss_layers_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..2323cb0762e7f4eade8f283162be61cc45513d49
--- /dev/null
+++ b/models/research/global_objectives/loss_layers_example.py
@@ -0,0 +1,211 @@
+# Copyright 2018 The TensorFlow Global Objectives Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Example for using global objectives.
+
+Illustrates, using synthetic data, how using the precision_at_recall loss
+significantly improves the performance of a linear classifier.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+import numpy as np
+from sklearn.metrics import precision_score
+import tensorflow as tf
+from global_objectives import loss_layers
+
+# When optimizing using global_objectives, if set to True then the saddle point
+# optimization steps are performed internally by the TensorFlow optimizer,
+# otherwise by dedicated saddle-point steps as part of the optimization loop.
+USE_GO_SADDLE_POINT_OPT = False
+
+TARGET_RECALL = 0.98
+TRAIN_ITERATIONS = 150
+LEARNING_RATE = 1.0
+GO_DUAL_RATE_FACTOR = 15.0
+NUM_CHECKPOINTS = 6
+
+EXPERIMENT_DATA_CONFIG = {
+ 'positives_centers': [[0, 1.0], [1, -0.5]],
+ 'negatives_centers': [[0, -0.5], [1, 1.0]],
+ 'positives_variances': [0.15, 0.1],
+ 'negatives_variances': [0.15, 0.1],
+ 'positives_counts': [500, 50],
+ 'negatives_counts': [3000, 100]
+}
+
+
+def create_training_and_eval_data_for_experiment(**data_config):
+ """Creates train and eval data sets.
+
+ Note: The synthesized binary-labeled data is a mixture of four Gaussians - two
+ positives and two negatives. The centers, variances, and sizes for each of
+ the two positives and negatives mixtures are passed in the respective keys
+ of data_config:
+
+ Args:
+ **data_config: Dictionary with Array entries as follows:
+ positives_centers - float [2,2] two centers of positives data sets.
+ negatives_centers - float [2,2] two centers of negatives data sets.
+ positives_variances - float [2] Variances for the positives sets.
+ negatives_variances - float [2] Variances for the negatives sets.
+ positives_counts - int [2] Counts for each of the two positives sets.
+ negatives_counts - int [2] Counts for each of the two negatives sets.
+
+ Returns:
+ A dictionary with two shuffled data sets created - one for training and one
+ for eval. The dictionary keys are 'train_data', 'train_labels', 'eval_data',
+    and 'eval_labels'. The data points are two-dimensional floats, and the
+ labels are in {0,1}.
+ """
+ def data_points(is_positives, index):
+ variance = data_config['positives_variances'
+ if is_positives else 'negatives_variances'][index]
+ center = data_config['positives_centers'
+ if is_positives else 'negatives_centers'][index]
+ count = data_config['positives_counts'
+ if is_positives else 'negatives_counts'][index]
+ return variance*np.random.randn(count, 2) + np.array([center])
+
+ def create_data():
+ return np.concatenate([data_points(False, 0),
+ data_points(True, 0),
+ data_points(True, 1),
+ data_points(False, 1)], axis=0)
+
+ def create_labels():
+ """Creates an array of 0.0 or 1.0 labels for the data_config batches."""
+ return np.array([0.0]*data_config['negatives_counts'][0] +
+ [1.0]*data_config['positives_counts'][0] +
+ [1.0]*data_config['positives_counts'][1] +
+ [0.0]*data_config['negatives_counts'][1])
+
+ permutation = np.random.permutation(
+ sum(data_config['positives_counts'] + data_config['negatives_counts']))
+
+ train_data = create_data()[permutation, :]
+ eval_data = create_data()[permutation, :]
+ train_labels = create_labels()[permutation]
+ eval_labels = create_labels()[permutation]
+
+ return {
+ 'train_data': train_data,
+ 'train_labels': train_labels,
+ 'eval_data': eval_data,
+ 'eval_labels': eval_labels
+ }
+
+
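+# A minimal sketch of exercising the generator above on its own; the tiny
+# config below is hypothetical and much smaller than EXPERIMENT_DATA_CONFIG,
+# purely so that the resulting shapes are easy to read.
+def _example_tiny_dataset():
+  """Builds a 30-point dataset: 5 + 5 positives and 10 + 10 negatives."""
+  tiny_config = {
+      'positives_centers': [[0, 1.0], [1, -0.5]],
+      'negatives_centers': [[0, -0.5], [1, 1.0]],
+      'positives_variances': [0.15, 0.1],
+      'negatives_variances': [0.15, 0.1],
+      'positives_counts': [5, 5],
+      'negatives_counts': [10, 10],
+  }
+  data = create_training_and_eval_data_for_experiment(**tiny_config)
+  # data['train_data'] has shape [30, 2]; data['train_labels'] has shape [30]
+  # with values in {0.0, 1.0}, shuffled consistently with the data points.
+  return data
+
+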
+def train_model(data, use_global_objectives):
+ """Trains a linear model for maximal accuracy or precision at given recall."""
+
+ def precision_at_recall(scores, labels, target_recall):
+ """Computes precision - at target recall - over data."""
+ positive_scores = scores[labels == 1.0]
+ threshold = np.percentile(positive_scores, 100 - target_recall*100)
+ predicted = scores >= threshold
+ return precision_score(labels, predicted)
+
+ w = tf.Variable(tf.constant([-1.0, -1.0], shape=[2, 1]), trainable=True,
+ name='weights', dtype=tf.float32)
+ b = tf.Variable(tf.zeros([1]), trainable=True, name='biases',
+ dtype=tf.float32)
+
+ logits = tf.matmul(tf.cast(data['train_data'], tf.float32), w) + b
+
+ labels = tf.constant(
+ data['train_labels'],
+ shape=[len(data['train_labels']), 1],
+ dtype=tf.float32)
+
+ if use_global_objectives:
+ loss, other_outputs = loss_layers.precision_at_recall_loss(
+ labels, logits,
+ TARGET_RECALL,
+ dual_rate_factor=GO_DUAL_RATE_FACTOR)
+ loss = tf.reduce_mean(loss)
+ else:
+ loss = tf.reduce_mean(
+ tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))
+
+ global_step = tf.Variable(0, trainable=False)
+
+ learning_rate = tf.train.polynomial_decay(
+ LEARNING_RATE,
+ global_step,
+ TRAIN_ITERATIONS, (LEARNING_RATE / TRAIN_ITERATIONS),
+ power=1.0,
+ cycle=False,
+ name='learning_rate')
+
+ optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+
+ if (not use_global_objectives) or USE_GO_SADDLE_POINT_OPT:
+ training_op = optimizer.minimize(loss, global_step=global_step)
+ else:
+ lambdas = other_outputs['lambdas']
+ primal_update_op = optimizer.minimize(loss, var_list=[w, b])
+ dual_update_op = optimizer.minimize(
+ loss, global_step=global_step, var_list=[lambdas])
+
+ # Training loop:
+ with tf.Session() as sess:
+ checkpoint_step = TRAIN_ITERATIONS // NUM_CHECKPOINTS
+ sess.run(tf.global_variables_initializer())
+ step = sess.run(global_step)
+
+ while step <= TRAIN_ITERATIONS:
+ if (not use_global_objectives) or USE_GO_SADDLE_POINT_OPT:
+ _, step, loss_value, w_value, b_value = sess.run(
+ [training_op, global_step, loss, w, b])
+ else:
+ _, w_value, b_value = sess.run([primal_update_op, w, b])
+ _, loss_value, step = sess.run([dual_update_op, loss, global_step])
+
+ if use_global_objectives:
+        go_outputs = sess.run(list(other_outputs.values()))
+
+ if step % checkpoint_step == 0:
+ precision = precision_at_recall(
+ np.dot(data['train_data'], w_value) + b_value,
+ data['train_labels'], TARGET_RECALL)
+
+ tf.logging.info('Loss = %f Precision = %f', loss_value, precision)
+ if use_global_objectives:
+ for i, output_name in enumerate(other_outputs.keys()):
+ tf.logging.info('\t%s = %f', output_name, go_outputs[i])
+
+ w_value, b_value = sess.run([w, b])
+ return precision_at_recall(np.dot(data['eval_data'], w_value) + b_value,
+ data['eval_labels'],
+ TARGET_RECALL)
+
+
+def main(unused_argv):
+ del unused_argv
+ experiment_data = create_training_and_eval_data_for_experiment(
+ **EXPERIMENT_DATA_CONFIG)
+ global_objectives_loss_precision = train_model(experiment_data, True)
+ tf.logging.info('global_objectives precision at requested recall is %f',
+ global_objectives_loss_precision)
+ cross_entropy_loss_precision = train_model(experiment_data, False)
+ tf.logging.info('cross_entropy precision at requested recall is %f',
+ cross_entropy_loss_precision)
+
+
+if __name__ == '__main__':
+ tf.logging.set_verbosity(tf.logging.INFO)
+ tf.app.run()
diff --git a/models/research/global_objectives/loss_layers_test.py b/models/research/global_objectives/loss_layers_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f91c80deec16a34f5271cdfadbd0d364c3a8cea
--- /dev/null
+++ b/models/research/global_objectives/loss_layers_test.py
@@ -0,0 +1,1379 @@
+# Copyright 2018 The TensorFlow Global Objectives Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for global objectives loss layers."""
+
+# Dependency imports
+from absl.testing import parameterized
+import numpy
+import tensorflow as tf
+
+from global_objectives import loss_layers
+from global_objectives import util
+
+
+# TODO: Include weights in the Lagrange multiplier update tests.
+class PrecisionRecallAUCLossTest(parameterized.TestCase, tf.test.TestCase):
+
+ @parameterized.named_parameters(
+ ('_xent', 'xent', 0.7),
+ ('_hinge', 'hinge', 0.7),
+ ('_hinge_2', 'hinge', 0.5)
+ )
+ def testSinglePointAUC(self, surrogate_type, target_precision):
+ # Tests a case with only one anchor point, where the loss should equal
+ # recall_at_precision_loss
+ batch_shape = [10, 2]
+ logits = tf.Variable(tf.random_normal(batch_shape))
+ labels = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(batch_shape), 0.4)))
+
+ auc_loss, _ = loss_layers.precision_recall_auc_loss(
+ labels,
+ logits,
+ precision_range=(target_precision - 0.01, target_precision + 0.01),
+ num_anchors=1,
+ surrogate_type=surrogate_type)
+ point_loss, _ = loss_layers.recall_at_precision_loss(
+ labels, logits, target_precision=target_precision,
+ surrogate_type=surrogate_type)
+
+ with self.test_session():
+ tf.global_variables_initializer().run()
+ self.assertAllClose(auc_loss.eval(), point_loss.eval())
+
+ def testThreePointAUC(self):
+ # Tests a case with three anchor points against a weighted sum of recall
+ # at precision losses.
+ batch_shape = [11, 3]
+ logits = tf.Variable(tf.random_normal(batch_shape))
+ labels = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(batch_shape), 0.4)))
+
+    # TODO: Place the hinge/xent loss in a for loop.
+ auc_loss, _ = loss_layers.precision_recall_auc_loss(
+        labels, logits, num_anchors=3)
+ first_point_loss, _ = loss_layers.recall_at_precision_loss(
+ labels, logits, target_precision=0.25)
+ second_point_loss, _ = loss_layers.recall_at_precision_loss(
+ labels, logits, target_precision=0.5)
+ third_point_loss, _ = loss_layers.recall_at_precision_loss(
+ labels, logits, target_precision=0.75)
+ expected_loss = (first_point_loss + second_point_loss +
+ third_point_loss) / 3
+
+ auc_loss_hinge, _ = loss_layers.precision_recall_auc_loss(
+        labels, logits, num_anchors=3, surrogate_type='hinge')
+ first_point_hinge, _ = loss_layers.recall_at_precision_loss(
+ labels, logits, target_precision=0.25, surrogate_type='hinge')
+ second_point_hinge, _ = loss_layers.recall_at_precision_loss(
+ labels, logits, target_precision=0.5, surrogate_type='hinge')
+ third_point_hinge, _ = loss_layers.recall_at_precision_loss(
+ labels, logits, target_precision=0.75, surrogate_type='hinge')
+ expected_hinge = (first_point_hinge + second_point_hinge +
+ third_point_hinge) / 3
+
+ with self.test_session():
+ tf.global_variables_initializer().run()
+ self.assertAllClose(auc_loss.eval(), expected_loss.eval())
+ self.assertAllClose(auc_loss_hinge.eval(), expected_hinge.eval())
+
+ def testLagrangeMultiplierUpdateDirection(self):
+ for target_precision in [0.35, 0.65]:
+ precision_range = (target_precision - 0.01, target_precision + 0.01)
+
+ for surrogate_type in ['xent', 'hinge']:
+ kwargs = {'precision_range': precision_range,
+ 'num_anchors': 1,
+ 'surrogate_type': surrogate_type,
+ 'scope': 'pr-auc_{}_{}'.format(target_precision,
+ surrogate_type)}
+ run_lagrange_multiplier_test(
+ global_objective=loss_layers.precision_recall_auc_loss,
+ objective_kwargs=kwargs,
+ data_builder=_multilabel_data,
+ test_object=self)
+ kwargs['scope'] = 'other-' + kwargs['scope']
+ run_lagrange_multiplier_test(
+ global_objective=loss_layers.precision_recall_auc_loss,
+ objective_kwargs=kwargs,
+ data_builder=_other_multilabel_data(surrogate_type),
+ test_object=self)
+
+
+class ROCAUCLossTest(parameterized.TestCase, tf.test.TestCase):
+
+ def testSimpleScores(self):
+ # Tests the loss on data with only one negative example with score zero.
+ # In this case, the loss should equal the surrogate loss on the scores with
+ # positive labels.
+ num_positives = 10
+ scores_positives = tf.constant(3.0 * numpy.random.randn(num_positives),
+ shape=[num_positives, 1])
+ labels = tf.constant([0.0] + [1.0] * num_positives,
+ shape=[num_positives + 1, 1])
+ scores = tf.concat([[[0.0]], scores_positives], 0)
+
+ loss = tf.reduce_sum(
+ loss_layers.roc_auc_loss(labels, scores, surrogate_type='hinge')[0])
+ expected_loss = tf.reduce_sum(
+ tf.maximum(1.0 - scores_positives, 0)) / (num_positives + 1)
+ with self.test_session():
+ self.assertAllClose(expected_loss.eval(), loss.eval())
+
+ def testRandomROCLoss(self):
+    # Checks that random Bernoulli scores and labels have ~25% swaps.
+ shape = [1000, 30]
+ scores = tf.constant(
+ numpy.random.randint(0, 2, size=shape), shape=shape, dtype=tf.float32)
+ labels = tf.constant(
+ numpy.random.randint(0, 2, size=shape), shape=shape, dtype=tf.float32)
+ loss = tf.reduce_mean(loss_layers.roc_auc_loss(
+ labels, scores, surrogate_type='hinge')[0])
+ with self.test_session():
+ self.assertAllClose(0.25, loss.eval(), 1e-2)
+
+ @parameterized.named_parameters(
+      ('_zero_xent', 'xent',
+ [0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
+ [-5.0, -7.0, -9.0, 8.0, 10.0, 14.0],
+ 0.0),
+      ('_zero_hinge', 'hinge',
+ [0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
+ [-0.2, 0, -0.1, 1.0, 1.1, 1.0],
+ 0.0),
+ ('_xent', 'xent',
+ [0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
+ [0.0, -17.0, -19.0, 1.0, 14.0, 14.0],
+ numpy.log(1.0 + numpy.exp(-1.0)) / 6),
+ ('_hinge', 'hinge',
+ [0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
+ [-0.2, -0.05, 0.0, 0.95, 0.8, 1.0],
+ 0.4 / 6)
+ )
+ def testManualROCLoss(self, surrogate_type, labels, logits, expected_value):
+ labels = tf.constant(labels)
+ logits = tf.constant(logits)
+ loss, _ = loss_layers.roc_auc_loss(
+ labels=labels, logits=logits, surrogate_type=surrogate_type)
+
+ with self.test_session():
+ self.assertAllClose(expected_value, tf.reduce_sum(loss).eval())
+
+ def testMultiLabelROCLoss(self):
+ # Tests the loss on multi-label data against manually computed loss.
+ targets = numpy.array([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]])
+ scores = numpy.array([[0.1, 1.0, 1.1, 1.0], [1.0, 0.0, 1.3, 1.1]])
+ class_1_auc = tf.reduce_sum(
+ loss_layers.roc_auc_loss(targets[0], scores[0])[0])
+ class_2_auc = tf.reduce_sum(
+ loss_layers.roc_auc_loss(targets[1], scores[1])[0])
+ total_auc = tf.reduce_sum(loss_layers.roc_auc_loss(
+ targets.transpose(), scores.transpose())[0])
+
+ with self.test_session():
+ self.assertAllClose(total_auc.eval(),
+ class_1_auc.eval() + class_2_auc.eval())
+
+ def testWeights(self):
+ # Test the loss with per-example weights.
+ # The logits_negatives below are repeated, so that setting half their
+ # weights to 2 and the other half to 0 should leave the loss unchanged.
+ logits_positives = tf.constant([2.54321, -0.26, 3.334334], shape=[3, 1])
+ logits_negatives = tf.constant([-0.6, 1, -1.3, -1.3, -0.6, 1], shape=[6, 1])
+ logits = tf.concat([logits_positives, logits_negatives], 0)
+ targets = tf.constant([1, 1, 1, 0, 0, 0, 0, 0, 0],
+ shape=[9, 1], dtype=tf.float32)
+ weights = tf.constant([1, 1, 1, 0, 0, 0, 2, 2, 2],
+ shape=[9, 1], dtype=tf.float32)
+
+ loss = tf.reduce_sum(loss_layers.roc_auc_loss(targets, logits)[0])
+ weighted_loss = tf.reduce_sum(
+ loss_layers.roc_auc_loss(targets, logits, weights)[0])
+
+ with self.test_session():
+ self.assertAllClose(loss.eval(), weighted_loss.eval())
+
+
+class RecallAtPrecisionTest(tf.test.TestCase):
+
+ def testEqualWeightLoss(self):
+ # Tests a special case where the loss should equal cross entropy loss.
+ target_precision = 1.0
+ num_labels = 5
+ batch_shape = [20, num_labels]
+ logits = tf.Variable(tf.random_normal(batch_shape))
+ targets = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(batch_shape), 0.7)))
+ label_priors = tf.constant(0.34, shape=[num_labels])
+
+ loss, _ = loss_layers.recall_at_precision_loss(
+ targets, logits, target_precision, label_priors=label_priors)
+ expected_loss = (
+ tf.contrib.nn.deprecated_flipped_sigmoid_cross_entropy_with_logits(
+ logits, targets))
+
+ with self.test_session() as session:
+ tf.global_variables_initializer().run()
+ loss_val, expected_val = session.run([loss, expected_loss])
+ self.assertAllClose(loss_val, expected_val)
+
+ def testEqualWeightLossWithMultiplePrecisions(self):
+ """Tests a case where the loss equals xent loss with multiple precisions."""
+ target_precision = [1.0, 1.0]
+ num_labels = 2
+ batch_size = 20
+ target_shape = [batch_size, num_labels]
+ logits = tf.Variable(tf.random_normal(target_shape))
+ targets = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(target_shape), 0.7)))
+ label_priors = tf.constant([0.34], shape=[num_labels])
+
+ loss, _ = loss_layers.recall_at_precision_loss(
+ targets,
+ logits,
+ target_precision,
+ label_priors=label_priors,
+ surrogate_type='xent',
+ )
+
+ expected_loss = (
+ tf.contrib.nn.deprecated_flipped_sigmoid_cross_entropy_with_logits(
+ logits, targets))
+
+ with self.test_session() as session:
+ tf.global_variables_initializer().run()
+ loss_val, expected_val = session.run([loss, expected_loss])
+ self.assertAllClose(loss_val, expected_val)
+
+ def testPositivesOnlyLoss(self):
+ # Tests a special case where the loss should equal cross entropy loss
+    # on the positive examples only.
+ target_precision = 1.0
+ num_labels = 3
+ batch_shape = [30, num_labels]
+ logits = tf.Variable(tf.random_normal(batch_shape))
+ targets = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(batch_shape), 0.4)))
+ label_priors = tf.constant(0.45, shape=[num_labels])
+
+ loss, _ = loss_layers.recall_at_precision_loss(
+ targets, logits, target_precision, label_priors=label_priors,
+ lambdas_initializer=tf.zeros_initializer())
+ expected_loss = util.weighted_sigmoid_cross_entropy_with_logits(
+ targets,
+ logits,
+ positive_weights=1.0,
+ negative_weights=0.0)
+
+ with self.test_session() as session:
+ tf.global_variables_initializer().run()
+ loss_val, expected_val = session.run([loss, expected_loss])
+ self.assertAllClose(loss_val, expected_val)
+
+ def testEquivalenceBetweenSingleAndMultiplePrecisions(self):
+ """Checks recall at precision with different precision values.
+
+ Runs recall at precision with multiple precision values, and runs each label
+    separately with its own precision value as a scalar. Validates that the
+ returned loss values are the same.
+ """
+ target_precision = [0.2, 0.9, 0.4]
+ num_labels = 3
+ batch_shape = [30, num_labels]
+ logits = tf.Variable(tf.random_normal(batch_shape))
+ targets = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(batch_shape), 0.4)))
+ label_priors = tf.constant([0.45, 0.8, 0.3], shape=[num_labels])
+
+ multi_label_loss, _ = loss_layers.recall_at_precision_loss(
+ targets, logits, target_precision, label_priors=label_priors,
+ )
+
+ single_label_losses = [
+ loss_layers.recall_at_precision_loss(
+ tf.expand_dims(targets[:, i], -1),
+ tf.expand_dims(logits[:, i], -1),
+ target_precision[i],
+ label_priors=label_priors[i])[0]
+ for i in range(num_labels)
+ ]
+
+ single_label_losses = tf.concat(single_label_losses, 1)
+
+ with self.test_session() as session:
+ tf.global_variables_initializer().run()
+ multi_label_loss_val, single_label_loss_val = session.run(
+ [multi_label_loss, single_label_losses])
+ self.assertAllClose(multi_label_loss_val, single_label_loss_val)
+
+ def testEquivalenceBetweenSingleAndEqualMultiplePrecisions(self):
+ """Compares single and multiple target precisions with the same value.
+
+ Checks that using a single target precision and multiple target precisions
+ with the same value would result in the same loss value.
+ """
+ num_labels = 2
+ target_shape = [20, num_labels]
+ logits = tf.Variable(tf.random_normal(target_shape))
+ targets = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(target_shape), 0.7)))
+ label_priors = tf.constant([0.34], shape=[num_labels])
+
+ multi_precision_loss, _ = loss_layers.recall_at_precision_loss(
+ targets,
+ logits,
+ [0.75, 0.75],
+ label_priors=label_priors,
+ surrogate_type='xent',
+ )
+
+ single_precision_loss, _ = loss_layers.recall_at_precision_loss(
+ targets,
+ logits,
+ 0.75,
+ label_priors=label_priors,
+ surrogate_type='xent',
+ )
+
+ with self.test_session() as session:
+ tf.global_variables_initializer().run()
+ multi_precision_loss_val, single_precision_loss_val = session.run(
+ [multi_precision_loss, single_precision_loss])
+ self.assertAllClose(multi_precision_loss_val, single_precision_loss_val)
+
+ def testLagrangeMultiplierUpdateDirection(self):
+ for target_precision in [0.35, 0.65]:
+ for surrogate_type in ['xent', 'hinge']:
+ kwargs = {'target_precision': target_precision,
+ 'surrogate_type': surrogate_type,
+ 'scope': 'r-at-p_{}_{}'.format(target_precision,
+ surrogate_type)}
+ run_lagrange_multiplier_test(
+ global_objective=loss_layers.recall_at_precision_loss,
+ objective_kwargs=kwargs,
+ data_builder=_multilabel_data,
+ test_object=self)
+ kwargs['scope'] = 'other-' + kwargs['scope']
+ run_lagrange_multiplier_test(
+ global_objective=loss_layers.recall_at_precision_loss,
+ objective_kwargs=kwargs,
+ data_builder=_other_multilabel_data(surrogate_type),
+ test_object=self)
+
+ def testLagrangeMultiplierUpdateDirectionWithMultiplePrecisions(self):
+ """Runs Lagrange multiplier test with multiple precision values."""
+ target_precision = [0.65, 0.35]
+
+ for surrogate_type in ['xent', 'hinge']:
+ scope_str = 'r-at-p_{}_{}'.format(
+ '_'.join([str(precision) for precision in target_precision]),
+ surrogate_type)
+ kwargs = {
+ 'target_precision': target_precision,
+ 'surrogate_type': surrogate_type,
+ 'scope': scope_str,
+ }
+ run_lagrange_multiplier_test(
+ global_objective=loss_layers.recall_at_precision_loss,
+ objective_kwargs=kwargs,
+ data_builder=_multilabel_data,
+ test_object=self)
+ kwargs['scope'] = 'other-' + kwargs['scope']
+ run_lagrange_multiplier_test(
+ global_objective=loss_layers.recall_at_precision_loss,
+ objective_kwargs=kwargs,
+ data_builder=_other_multilabel_data(surrogate_type),
+ test_object=self)
+
+
+class PrecisionAtRecallTest(tf.test.TestCase):
+
+ def testCrossEntropyEquivalence(self):
+ # Checks a special case where the loss should equal cross-entropy loss.
+ target_recall = 1.0
+ num_labels = 3
+ batch_shape = [10, num_labels]
+ logits = tf.Variable(tf.random_normal(batch_shape))
+ targets = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(batch_shape), 0.4)))
+
+ loss, _ = loss_layers.precision_at_recall_loss(
+ targets, logits, target_recall,
+ lambdas_initializer=tf.constant_initializer(1.0))
+ expected_loss = util.weighted_sigmoid_cross_entropy_with_logits(
+ targets, logits)
+
+ with self.test_session():
+ tf.global_variables_initializer().run()
+ self.assertAllClose(loss.eval(), expected_loss.eval())
+
+ def testNegativesOnlyLoss(self):
+ # Checks a special case where the loss should equal the loss on
+ # the negative examples only.
+ target_recall = 0.61828
+ num_labels = 4
+ batch_shape = [8, num_labels]
+ logits = tf.Variable(tf.random_normal(batch_shape))
+ targets = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(batch_shape), 0.6)))
+
+ loss, _ = loss_layers.precision_at_recall_loss(
+ targets,
+ logits,
+ target_recall,
+ surrogate_type='hinge',
+ lambdas_initializer=tf.constant_initializer(0.0),
+ scope='negatives_only_test')
+ expected_loss = util.weighted_hinge_loss(
+ targets, logits, positive_weights=0.0, negative_weights=1.0)
+
+ with self.test_session():
+ tf.global_variables_initializer().run()
+ self.assertAllClose(expected_loss.eval(), loss.eval())
+
+ def testLagrangeMultiplierUpdateDirection(self):
+ for target_recall in [0.34, 0.66]:
+ for surrogate_type in ['xent', 'hinge']:
+ kwargs = {'target_recall': target_recall,
+ 'dual_rate_factor': 1.0,
+ 'surrogate_type': surrogate_type,
+ 'scope': 'p-at-r_{}_{}'.format(target_recall, surrogate_type)}
+
+ run_lagrange_multiplier_test(
+ global_objective=loss_layers.precision_at_recall_loss,
+ objective_kwargs=kwargs,
+ data_builder=_multilabel_data,
+ test_object=self)
+ kwargs['scope'] = 'other-' + kwargs['scope']
+ run_lagrange_multiplier_test(
+ global_objective=loss_layers.precision_at_recall_loss,
+ objective_kwargs=kwargs,
+ data_builder=_other_multilabel_data(surrogate_type),
+ test_object=self)
+
+ def testCrossEntropyEquivalenceWithMultipleRecalls(self):
+ """Checks a case where the loss equals xent loss with multiple recalls."""
+ num_labels = 3
+ target_recall = [1.0] * num_labels
+ batch_shape = [10, num_labels]
+ logits = tf.Variable(tf.random_normal(batch_shape))
+ targets = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(batch_shape), 0.4)))
+
+ loss, _ = loss_layers.precision_at_recall_loss(
+ targets, logits, target_recall,
+ lambdas_initializer=tf.constant_initializer(1.0))
+ expected_loss = util.weighted_sigmoid_cross_entropy_with_logits(
+ targets, logits)
+
+ with self.test_session():
+ tf.global_variables_initializer().run()
+ self.assertAllClose(loss.eval(), expected_loss.eval())
+
+ def testNegativesOnlyLossWithMultipleRecalls(self):
+ """Tests a case where the loss equals the loss on the negative examples.
+
+ Checks this special case using multiple target recall values.
+ """
+ num_labels = 4
+ target_recall = [0.61828] * num_labels
+ batch_shape = [8, num_labels]
+ logits = tf.Variable(tf.random_normal(batch_shape))
+ targets = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(batch_shape), 0.6)))
+
+ loss, _ = loss_layers.precision_at_recall_loss(
+ targets,
+ logits,
+ target_recall,
+ surrogate_type='hinge',
+ lambdas_initializer=tf.constant_initializer(0.0),
+ scope='negatives_only_test')
+ expected_loss = util.weighted_hinge_loss(
+ targets, logits, positive_weights=0.0, negative_weights=1.0)
+
+ with self.test_session():
+ tf.global_variables_initializer().run()
+ self.assertAllClose(expected_loss.eval(), loss.eval())
+
+ def testLagrangeMultiplierUpdateDirectionWithMultipleRecalls(self):
+ """Runs Lagrange multiplier test with multiple recall values."""
+ target_recall = [0.34, 0.66]
+ for surrogate_type in ['xent', 'hinge']:
+ scope_str = 'p-at-r_{}_{}'.format(
+ '_'.join([str(recall) for recall in target_recall]),
+ surrogate_type)
+ kwargs = {'target_recall': target_recall,
+ 'dual_rate_factor': 1.0,
+ 'surrogate_type': surrogate_type,
+ 'scope': scope_str}
+
+ run_lagrange_multiplier_test(
+ global_objective=loss_layers.precision_at_recall_loss,
+ objective_kwargs=kwargs,
+ data_builder=_multilabel_data,
+ test_object=self)
+ kwargs['scope'] = 'other-' + kwargs['scope']
+ run_lagrange_multiplier_test(
+ global_objective=loss_layers.precision_at_recall_loss,
+ objective_kwargs=kwargs,
+ data_builder=_other_multilabel_data(surrogate_type),
+ test_object=self)
+
+ def testEquivalenceBetweenSingleAndMultipleRecalls(self):
+ """Checks precision at recall with multiple different recall values.
+
+ Runs precision at recall with multiple recall values, and runs each label
+    separately with its own recall value as a scalar. Validates that the
+ returned loss values are the same.
+ """
+    target_recall = [0.7, 0.9, 0.4]
+ num_labels = 3
+ batch_shape = [30, num_labels]
+ logits = tf.Variable(tf.random_normal(batch_shape))
+ targets = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(batch_shape), 0.4)))
+ label_priors = tf.constant(0.45, shape=[num_labels])
+
+ multi_label_loss, _ = loss_layers.precision_at_recall_loss(
+        targets, logits, target_recall, label_priors=label_priors
+ )
+
+ single_label_losses = [
+ loss_layers.precision_at_recall_loss(
+ tf.expand_dims(targets[:, i], -1),
+ tf.expand_dims(logits[:, i], -1),
+            target_recall[i],
+ label_priors=label_priors[i])[0]
+ for i in range(num_labels)
+ ]
+
+ single_label_losses = tf.concat(single_label_losses, 1)
+
+ with self.test_session() as session:
+ tf.global_variables_initializer().run()
+ multi_label_loss_val, single_label_loss_val = session.run(
+ [multi_label_loss, single_label_losses])
+ self.assertAllClose(multi_label_loss_val, single_label_loss_val)
+
+ def testEquivalenceBetweenSingleAndEqualMultipleRecalls(self):
+ """Compares single and multiple target recalls of the same value.
+
+ Checks that using a single target recall and multiple recalls with the
+ same value would result in the same loss value.
+ """
+ num_labels = 2
+ target_shape = [20, num_labels]
+ logits = tf.Variable(tf.random_normal(target_shape))
+ targets = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(target_shape), 0.7)))
+ label_priors = tf.constant([0.34], shape=[num_labels])
+
+ multi_precision_loss, _ = loss_layers.precision_at_recall_loss(
+ targets,
+ logits,
+ [0.75, 0.75],
+ label_priors=label_priors,
+ surrogate_type='xent',
+ )
+
+ single_precision_loss, _ = loss_layers.precision_at_recall_loss(
+ targets,
+ logits,
+ 0.75,
+ label_priors=label_priors,
+ surrogate_type='xent',
+ )
+
+ with self.test_session() as session:
+ tf.global_variables_initializer().run()
+ multi_precision_loss_val, single_precision_loss_val = session.run(
+ [multi_precision_loss, single_precision_loss])
+ self.assertAllClose(multi_precision_loss_val, single_precision_loss_val)
+
+
+class FalsePositiveRateAtTruePositiveRateTest(tf.test.TestCase):
+
+ def testNegativesOnlyLoss(self):
+ # Checks a special case where the loss returned should be the loss on the
+ # negative examples.
+ target_recall = 0.6
+ num_labels = 3
+ batch_shape = [3, num_labels]
+ logits = tf.Variable(tf.random_normal(batch_shape))
+ targets = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(batch_shape), 0.4)))
+ label_priors = tf.constant(numpy.random.uniform(size=[num_labels]),
+ dtype=tf.float32)
+
+ xent_loss, _ = loss_layers.false_positive_rate_at_true_positive_rate_loss(
+ targets, logits, target_recall, label_priors=label_priors,
+ lambdas_initializer=tf.constant_initializer(0.0))
+ xent_expected = util.weighted_sigmoid_cross_entropy_with_logits(
+ targets,
+ logits,
+ positive_weights=0.0,
+ negative_weights=1.0)
+ hinge_loss, _ = loss_layers.false_positive_rate_at_true_positive_rate_loss(
+ targets, logits, target_recall, label_priors=label_priors,
+ lambdas_initializer=tf.constant_initializer(0.0),
+ surrogate_type='hinge')
+ hinge_expected = util.weighted_hinge_loss(
+ targets,
+ logits,
+ positive_weights=0.0,
+ negative_weights=1.0)
+
+ with self.test_session() as session:
+ tf.global_variables_initializer().run()
+ xent_val, xent_expected = session.run([xent_loss, xent_expected])
+ self.assertAllClose(xent_val, xent_expected)
+ hinge_val, hinge_expected = session.run([hinge_loss, hinge_expected])
+ self.assertAllClose(hinge_val, hinge_expected)
+
+ def testPositivesOnlyLoss(self):
+ # Checks a special case where the loss returned should be the loss on the
+ # positive examples only.
+ target_recall = 1.0
+ num_labels = 5
+ batch_shape = [5, num_labels]
+ logits = tf.Variable(tf.random_normal(batch_shape))
+ targets = tf.ones_like(logits)
+ label_priors = tf.constant(numpy.random.uniform(size=[num_labels]),
+ dtype=tf.float32)
+
+ loss, _ = loss_layers.false_positive_rate_at_true_positive_rate_loss(
+ targets, logits, target_recall, label_priors=label_priors)
+ expected_loss = tf.nn.sigmoid_cross_entropy_with_logits(
+ labels=targets, logits=logits)
+ hinge_loss, _ = loss_layers.false_positive_rate_at_true_positive_rate_loss(
+ targets, logits, target_recall, label_priors=label_priors,
+ surrogate_type='hinge')
+ expected_hinge = util.weighted_hinge_loss(
+ targets, logits)
+
+ with self.test_session():
+ tf.global_variables_initializer().run()
+ self.assertAllClose(loss.eval(), expected_loss.eval())
+ self.assertAllClose(hinge_loss.eval(), expected_hinge.eval())
+
+ def testEqualWeightLoss(self):
+ # Checks a special case where the loss returned should be proportional to
+ # the ordinary loss.
+ target_recall = 1.0
+ num_labels = 4
+ batch_shape = [40, num_labels]
+ logits = tf.Variable(tf.random_normal(batch_shape))
+ targets = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(batch_shape), 0.6)))
+ label_priors = tf.constant(0.5, shape=[num_labels])
+
+ loss, _ = loss_layers.false_positive_rate_at_true_positive_rate_loss(
+ targets, logits, target_recall, label_priors=label_priors)
+ expected_loss = tf.nn.sigmoid_cross_entropy_with_logits(
+ labels=targets, logits=logits)
+
+ with self.test_session():
+ tf.global_variables_initializer().run()
+ self.assertAllClose(loss.eval(), expected_loss.eval())
+
+ def testLagrangeMultiplierUpdateDirection(self):
+ for target_rate in [0.35, 0.65]:
+ for surrogate_type in ['xent', 'hinge']:
+ kwargs = {'target_rate': target_rate,
+ 'surrogate_type': surrogate_type,
+ 'scope': 'fpr-at-tpr_{}_{}'.format(target_rate,
+ surrogate_type)}
+ # True positive rate is a synonym for recall, so we use the
+ # recall constraint data.
+ run_lagrange_multiplier_test(
+ global_objective=(
+ loss_layers.false_positive_rate_at_true_positive_rate_loss),
+ objective_kwargs=kwargs,
+ data_builder=_multilabel_data,
+ test_object=self)
+ kwargs['scope'] = 'other-' + kwargs['scope']
+ run_lagrange_multiplier_test(
+ global_objective=(
+ loss_layers.false_positive_rate_at_true_positive_rate_loss),
+ objective_kwargs=kwargs,
+ data_builder=_other_multilabel_data(surrogate_type),
+ test_object=self)
+
+ def testLagrangeMultiplierUpdateDirectionWithMultipleRates(self):
+ """Runs Lagrange multiplier test with multiple target rates."""
+ target_rate = [0.35, 0.65]
+ for surrogate_type in ['xent', 'hinge']:
+ kwargs = {'target_rate': target_rate,
+ 'surrogate_type': surrogate_type,
+ 'scope': 'fpr-at-tpr_{}_{}'.format(
+ '_'.join([str(target) for target in target_rate]),
+ surrogate_type)}
+ # True positive rate is a synonym for recall, so we use the
+ # recall constraint data.
+ run_lagrange_multiplier_test(
+ global_objective=(
+ loss_layers.false_positive_rate_at_true_positive_rate_loss),
+ objective_kwargs=kwargs,
+ data_builder=_multilabel_data,
+ test_object=self)
+ kwargs['scope'] = 'other-' + kwargs['scope']
+ run_lagrange_multiplier_test(
+ global_objective=(
+ loss_layers.false_positive_rate_at_true_positive_rate_loss),
+ objective_kwargs=kwargs,
+ data_builder=_other_multilabel_data(surrogate_type),
+ test_object=self)
+
+ def testEquivalenceBetweenSingleAndEqualMultipleRates(self):
+ """Compares single and multiple target rates of the same value.
+
+ Checks that using a single target rate and multiple rates with the
+ same value would result in the same loss value.
+ """
+ num_labels = 2
+ target_shape = [20, num_labels]
+ logits = tf.Variable(tf.random_normal(target_shape))
+ targets = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(target_shape), 0.7)))
+ label_priors = tf.constant([0.34], shape=[num_labels])
+
+ multi_label_loss, _ = (
+ loss_layers.false_positive_rate_at_true_positive_rate_loss(
+ targets, logits, [0.75, 0.75], label_priors=label_priors))
+
+ single_label_loss, _ = (
+ loss_layers.false_positive_rate_at_true_positive_rate_loss(
+ targets, logits, 0.75, label_priors=label_priors))
+
+ with self.test_session() as session:
+ tf.global_variables_initializer().run()
+ multi_label_loss_val, single_label_loss_val = session.run(
+ [multi_label_loss, single_label_loss])
+ self.assertAllClose(multi_label_loss_val, single_label_loss_val)
+
+ def testEquivalenceBetweenSingleAndMultipleRates(self):
+ """Compares single and multiple target rates of different values.
+
+ Runs false_positive_rate_at_true_positive_rate_loss with multiple target
+ rates, and runs each label separately with its own target rate as a
+ scalar. Validates that the returned loss values are the same.
+ """
+ target_precision = [0.7, 0.9, 0.4]
+ num_labels = 3
+ batch_shape = [30, num_labels]
+ logits = tf.Variable(tf.random_normal(batch_shape))
+ targets = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(batch_shape), 0.4)))
+ label_priors = tf.constant(0.45, shape=[num_labels])
+
+ multi_label_loss, _ = (
+ loss_layers.false_positive_rate_at_true_positive_rate_loss(
+ targets, logits, target_precision, label_priors=label_priors))
+
+ single_label_losses = [
+ loss_layers.false_positive_rate_at_true_positive_rate_loss(
+ tf.expand_dims(targets[:, i], -1),
+ tf.expand_dims(logits[:, i], -1),
+ target_precision[i],
+ label_priors=label_priors[i])[0]
+ for i in range(num_labels)
+ ]
+
+ single_label_losses = tf.concat(single_label_losses, 1)
+
+ with self.test_session() as session:
+ tf.global_variables_initializer().run()
+ multi_label_loss_val, single_label_loss_val = session.run(
+ [multi_label_loss, single_label_losses])
+ self.assertAllClose(multi_label_loss_val, single_label_loss_val)
+
+
+class TruePositiveRateAtFalsePositiveRateTest(tf.test.TestCase):
+
+ def testPositivesOnlyLoss(self):
+ # A special case where the loss should equal the loss on the positive
+ # examples.
+ target_rate = numpy.random.uniform()
+ num_labels = 3
+ batch_shape = [20, num_labels]
+ logits = tf.Variable(tf.random_normal(batch_shape))
+ targets = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(batch_shape), 0.6)))
+ label_priors = tf.constant(numpy.random.uniform(size=[num_labels]),
+ dtype=tf.float32)
+
+ xent_loss, _ = loss_layers.true_positive_rate_at_false_positive_rate_loss(
+ targets, logits, target_rate, label_priors=label_priors,
+ lambdas_initializer=tf.constant_initializer(0.0))
+ xent_expected = util.weighted_sigmoid_cross_entropy_with_logits(
+ targets,
+ logits,
+ positive_weights=1.0,
+ negative_weights=0.0)
+ hinge_loss, _ = loss_layers.true_positive_rate_at_false_positive_rate_loss(
+ targets, logits, target_rate, label_priors=label_priors,
+ lambdas_initializer=tf.constant_initializer(0.0),
+ surrogate_type='hinge')
+ hinge_expected = util.weighted_hinge_loss(
+ targets,
+ logits,
+ positive_weights=1.0,
+ negative_weights=0.0)
+
+ with self.test_session():
+ tf.global_variables_initializer().run()
+ self.assertAllClose(xent_expected.eval(), xent_loss.eval())
+ self.assertAllClose(hinge_expected.eval(), hinge_loss.eval())
+
+ def testNegativesOnlyLoss(self):
+ # A special case where the loss should equal the loss on the negative
+ # examples, minus target_rate * (1 - label_priors) * maybe_log2.
+ target_rate = numpy.random.uniform()
+ num_labels = 3
+ batch_shape = [25, num_labels]
+ logits = tf.Variable(tf.random_normal(batch_shape))
+ targets = tf.zeros_like(logits)
+ label_priors = tf.constant(numpy.random.uniform(size=[num_labels]),
+ dtype=tf.float32)
+
+ xent_loss, _ = loss_layers.true_positive_rate_at_false_positive_rate_loss(
+ targets, logits, target_rate, label_priors=label_priors)
+ xent_expected = tf.subtract(
+ util.weighted_sigmoid_cross_entropy_with_logits(targets,
+ logits,
+ positive_weights=0.0,
+ negative_weights=1.0),
+ target_rate * (1.0 - label_priors) * numpy.log(2))
+ hinge_loss, _ = loss_layers.true_positive_rate_at_false_positive_rate_loss(
+ targets, logits, target_rate, label_priors=label_priors,
+ surrogate_type='hinge')
+ hinge_expected = util.weighted_hinge_loss(
+ targets, logits) - target_rate * (1.0 - label_priors)
+
+ with self.test_session():
+ tf.global_variables_initializer().run()
+ self.assertAllClose(xent_expected.eval(), xent_loss.eval())
+ self.assertAllClose(hinge_expected.eval(), hinge_loss.eval())
+
+ def testLagrangeMultiplierUpdateDirection(self):
+ for target_rate in [0.35, 0.65]:
+ for surrogate_type in ['xent', 'hinge']:
+ kwargs = {'target_rate': target_rate,
+ 'surrogate_type': surrogate_type,
+ 'scope': 'tpr-at-fpr_{}_{}'.format(target_rate,
+ surrogate_type)}
+ run_lagrange_multiplier_test(
+ global_objective=(
+ loss_layers.true_positive_rate_at_false_positive_rate_loss),
+ objective_kwargs=kwargs,
+ data_builder=_multilabel_data,
+ test_object=self)
+ kwargs['scope'] = 'other-' + kwargs['scope']
+ run_lagrange_multiplier_test(
+ global_objective=(
+ loss_layers.true_positive_rate_at_false_positive_rate_loss),
+ objective_kwargs=kwargs,
+ data_builder=_other_multilabel_data(surrogate_type),
+ test_object=self)
+
+ def testLagrangeMultiplierUpdateDirectionWithMultipleRates(self):
+ """Runs Lagrange multiplier test with multiple target rates."""
+ target_rate = [0.35, 0.65]
+ for surrogate_type in ['xent', 'hinge']:
+ kwargs = {'target_rate': target_rate,
+ 'surrogate_type': surrogate_type,
+ 'scope': 'tpr-at-fpr_{}_{}'.format(
+ '_'.join([str(target) for target in target_rate]),
+ surrogate_type)}
+ run_lagrange_multiplier_test(
+ global_objective=(
+ loss_layers.true_positive_rate_at_false_positive_rate_loss),
+ objective_kwargs=kwargs,
+ data_builder=_multilabel_data,
+ test_object=self)
+ kwargs['scope'] = 'other-' + kwargs['scope']
+ run_lagrange_multiplier_test(
+ global_objective=(
+ loss_layers.true_positive_rate_at_false_positive_rate_loss),
+ objective_kwargs=kwargs,
+ data_builder=_other_multilabel_data(surrogate_type),
+ test_object=self)
+
+ def testEquivalenceBetweenSingleAndEqualMultipleRates(self):
+ """Compares single and multiple target rates of the same value.
+
+ Checks that using a single target rate and multiple rates with the
+ same value would result in the same loss value.
+ """
+ num_labels = 2
+ target_shape = [20, num_labels]
+ logits = tf.Variable(tf.random_normal(target_shape))
+ targets = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(target_shape), 0.7)))
+ label_priors = tf.constant([0.34], shape=[num_labels])
+
+ multi_label_loss, _ = (
+ loss_layers.true_positive_rate_at_false_positive_rate_loss(
+ targets, logits, [0.75, 0.75], label_priors=label_priors))
+
+ single_label_loss, _ = (
+ loss_layers.true_positive_rate_at_false_positive_rate_loss(
+ targets, logits, 0.75, label_priors=label_priors))
+
+ with self.test_session() as session:
+ tf.global_variables_initializer().run()
+ multi_label_loss_val, single_label_loss_val = session.run(
+ [multi_label_loss, single_label_loss])
+ self.assertAllClose(multi_label_loss_val, single_label_loss_val)
+
+ def testEquivalenceBetweenSingleAndMultipleRates(self):
+ """Compares single and multiple target rates of different values.
+
+ Runs true_positive_rate_at_false_positive_rate_loss with multiple target
+ rates, and runs each label separately with its own target rate as a
+ scalar. Validates that the returned loss values are the same.
+ """
+ target_precision = [0.7, 0.9, 0.4]
+ num_labels = 3
+ batch_shape = [30, num_labels]
+ logits = tf.Variable(tf.random_normal(batch_shape))
+ targets = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(batch_shape), 0.4)))
+ label_priors = tf.constant(0.45, shape=[num_labels])
+
+ multi_label_loss, _ = (
+ loss_layers.true_positive_rate_at_false_positive_rate_loss(
+ targets, logits, target_precision, label_priors=label_priors))
+
+ single_label_losses = [
+ loss_layers.true_positive_rate_at_false_positive_rate_loss(
+ tf.expand_dims(targets[:, i], -1),
+ tf.expand_dims(logits[:, i], -1),
+ target_precision[i],
+ label_priors=label_priors[i])[0]
+ for i in range(num_labels)
+ ]
+
+ single_label_losses = tf.concat(single_label_losses, 1)
+
+ with self.test_session() as session:
+ tf.global_variables_initializer().run()
+ multi_label_loss_val, single_label_loss_val = session.run(
+ [multi_label_loss, single_label_losses])
+ self.assertAllClose(multi_label_loss_val, single_label_loss_val)
+
+
+class UtilityFunctionsTest(tf.test.TestCase):
+
+ def testTrainableDualVariable(self):
+ # Confirm correct behavior of a trainable dual variable.
+ x = tf.get_variable('primal', dtype=tf.float32, initializer=2.0)
+ y_value, y = loss_layers._create_dual_variable(
+ 'dual', shape=None, dtype=tf.float32, initializer=1.0, collections=None,
+ trainable=True, dual_rate_factor=0.3)
+ optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
+ update = optimizer.minimize(0.5 * tf.square(x - y_value))
+
+ with self.test_session():
+ tf.global_variables_initializer().run()
+ update.run()
+ self.assertAllClose(0.7, y.eval())
+
+ def testUntrainableDualVariable(self):
+ # Confirm correct behavior of dual variable which is not trainable.
+ x = tf.get_variable('primal', dtype=tf.float32, initializer=-2.0)
+ y_value, y = loss_layers._create_dual_variable(
+ 'dual', shape=None, dtype=tf.float32, initializer=1.0, collections=None,
+ trainable=False, dual_rate_factor=0.8)
+ optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
+ update = optimizer.minimize(tf.square(x) * y_value + tf.exp(y_value))
+
+ with self.test_session():
+ tf.global_variables_initializer().run()
+ update.run()
+ self.assertAllClose(1.0, y.eval())
+
+
+class BoundTest(parameterized.TestCase, tf.test.TestCase):
+
+ @parameterized.named_parameters(
+ ('_xent', 'xent', 1.0, [2.0, 1.0]),
+ ('_xent_weighted', 'xent',
+ numpy.array([0, 2, 0.5, 1, 2, 3]).reshape(6, 1), [2.5, 0]),
+ ('_hinge', 'hinge', 1.0, [2.0, 1.0]),
+ ('_hinge_weighted', 'hinge',
+ numpy.array([1.0, 2, 3, 4, 5, 6]).reshape(6, 1), [5.0, 1]))
+ def testLowerBoundMultilabel(self, surrogate_type, weights, expected):
+ labels, logits, _ = _multilabel_data()
+ lower_bound = loss_layers.true_positives_lower_bound(
+ labels, logits, weights, surrogate_type)
+
+ with self.test_session():
+ self.assertAllClose(lower_bound.eval(), expected)
+
+ @parameterized.named_parameters(
+ ('_xent', 'xent'), ('_hinge', 'hinge'))
+ def testLowerBoundOtherMultilabel(self, surrogate_type):
+ labels, logits, _ = _other_multilabel_data(surrogate_type)()
+ lower_bound = loss_layers.true_positives_lower_bound(
+ labels, logits, 1.0, surrogate_type)
+
+ with self.test_session():
+ self.assertAllClose(lower_bound.eval(), [4.0, 2.0], atol=1e-5)
+
+ @parameterized.named_parameters(
+ ('_xent', 'xent', 1.0, [1.0, 2.0]),
+ ('_xent_weighted', 'xent',
+ numpy.array([3.0, 2, 1, 0, 1, 2]).reshape(6, 1), [2.0, 1.0]),
+ ('_hinge', 'hinge', 1.0, [1.0, 2.0]),
+ ('_hinge_weighted', 'hinge',
+ numpy.array([13, 12, 11, 0.5, 0, 0.5]).reshape(6, 1), [0.5, 0.5]))
+ def testUpperBoundMultilabel(self, surrogate_type, weights, expected):
+ labels, logits, _ = _multilabel_data()
+ upper_bound = loss_layers.false_positives_upper_bound(
+ labels, logits, weights, surrogate_type)
+
+ with self.test_session():
+ self.assertAllClose(upper_bound.eval(), expected)
+
+ @parameterized.named_parameters(
+ ('_xent', 'xent'), ('_hinge', 'hinge'))
+ def testUpperBoundOtherMultilabel(self, surrogate_type):
+ labels, logits, _ = _other_multilabel_data(surrogate_type)()
+ upper_bound = loss_layers.false_positives_upper_bound(
+ labels, logits, 1.0, surrogate_type)
+
+ with self.test_session():
+ self.assertAllClose(upper_bound.eval(), [2.0, 4.0], atol=1e-5)
+
+ @parameterized.named_parameters(
+ ('_lower', 'lower'), ('_upper', 'upper'))
+ def testThreeDimensionalLogits(self, bound):
+ bound_function = loss_layers.false_positives_upper_bound
+ if bound == 'lower':
+ bound_function = loss_layers.true_positives_lower_bound
+ random_labels = numpy.float32(numpy.random.uniform(size=[2, 3]) > 0.5)
+ random_logits = numpy.float32(numpy.random.randn(2, 3, 2))
+ first_slice_logits = random_logits[:, :, 0].reshape(2, 3)
+ second_slice_logits = random_logits[:, :, 1].reshape(2, 3)
+
+ full_bound = bound_function(
+ tf.constant(random_labels), tf.constant(random_logits), 1.0, 'xent')
+ first_slice_bound = bound_function(tf.constant(random_labels),
+ tf.constant(first_slice_logits),
+ 1.0,
+ 'xent')
+ second_slice_bound = bound_function(tf.constant(random_labels),
+ tf.constant(second_slice_logits),
+ 1.0,
+ 'xent')
+ stacked_bound = tf.stack([first_slice_bound, second_slice_bound], axis=1)
+
+ with self.test_session():
+ self.assertAllClose(full_bound.eval(), stacked_bound.eval())
+
+
+def run_lagrange_multiplier_test(global_objective,
+ objective_kwargs,
+ data_builder,
+ test_object):
+ """Runs a test for the Lagrange multiplier update of `global_objective`.
+
+ The test checks that the constraint for `global_objective` is satisfied on
+ the first label of the data produced by `data_builder` but not the second.
+
+ Args:
+ global_objective: One of the global objectives.
+ objective_kwargs: A dictionary of keyword arguments to pass to
+ `global_objective`. Must contain an entry for the constraint argument
+ of `global_objective`, e.g. 'target_rate' or 'target_precision'.
+ data_builder: A function which returns tensors corresponding to labels,
+ logits, and label priors.
+ test_object: An instance of tf.test.TestCase.
+ """
+ # Construct global objective kwargs from a copy of `objective_kwargs`.
+ kwargs = dict(objective_kwargs)
+ targets, logits, priors = data_builder()
+ kwargs['labels'] = targets
+ kwargs['logits'] = logits
+ kwargs['label_priors'] = priors
+
+ loss, output_dict = global_objective(**kwargs)
+ lambdas = tf.squeeze(output_dict['lambdas'])
+ opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
+ update_op = opt.minimize(loss, var_list=[output_dict['lambdas']])
+
+ with test_object.test_session() as session:
+ tf.global_variables_initializer().run()
+ lambdas_before = session.run(lambdas)
+ session.run(update_op)
+ lambdas_after = session.run(lambdas)
+ test_object.assertLess(lambdas_after[0], lambdas_before[0])
+ test_object.assertGreater(lambdas_after[1], lambdas_before[1])
+
+
+class CrossFunctionTest(parameterized.TestCase, tf.test.TestCase):
+
+ @parameterized.named_parameters(
+ ('_auc01xent', loss_layers.precision_recall_auc_loss, {
+ 'precision_range': (0.0, 1.0), 'surrogate_type': 'xent'
+ }),
+ ('_auc051xent', loss_layers.precision_recall_auc_loss, {
+ 'precision_range': (0.5, 1.0), 'surrogate_type': 'xent'
+ }),
+ ('_auc01hinge', loss_layers.precision_recall_auc_loss, {
+ 'precision_range': (0.0, 1.0), 'surrogate_type': 'hinge'
+ }),
+ ('_ratp04', loss_layers.recall_at_precision_loss, {
+ 'target_precision': 0.4, 'surrogate_type': 'xent'
+ }),
+ ('_ratp066', loss_layers.recall_at_precision_loss, {
+ 'target_precision': 0.66, 'surrogate_type': 'xent'
+ }),
+ ('_ratp07_hinge', loss_layers.recall_at_precision_loss, {
+ 'target_precision': 0.7, 'surrogate_type': 'hinge'
+ }),
+ ('_fpattp066', loss_layers.false_positive_rate_at_true_positive_rate_loss,
+ {'target_rate': 0.66, 'surrogate_type': 'xent'}),
+ ('_fpattp046', loss_layers.false_positive_rate_at_true_positive_rate_loss,
+ {
+ 'target_rate': 0.46, 'surrogate_type': 'xent'
+ }),
+ ('_fpattp076_hinge',
+ loss_layers.false_positive_rate_at_true_positive_rate_loss, {
+ 'target_rate': 0.76, 'surrogate_type': 'hinge'
+ }),
+ ('_fpattp036_hinge',
+ loss_layers.false_positive_rate_at_true_positive_rate_loss, {
+ 'target_rate': 0.36, 'surrogate_type': 'hinge'
+ }),
+ )
+ def testWeightedGlobalObjective(self,
+ global_objective,
+ objective_kwargs):
+ """Runs a test of `global_objective` with per-example weights.
+
+ Args:
+ global_objective: One of the global objectives.
+ objective_kwargs: A dictionary of keyword arguments to pass to
+ `global_objective`. Must contain keys 'surrogate_type', and the keyword
+ for the constraint argument of `global_objective`, e.g. 'target_rate' or
+ 'target_precision'.
+ """
+ logits_positives = tf.constant([1, -0.5, 3], shape=[3, 1])
+ logits_negatives = tf.constant([-0.5, 1, -1, -1, -0.5, 1], shape=[6, 1])
+
+ # Dummy tensor is used to compute the gradients.
+ dummy = tf.constant(1.0)
+ logits = tf.concat([logits_positives, logits_negatives], 0)
+ logits = tf.multiply(logits, dummy)
+ targets = tf.constant([1, 1, 1, 0, 0, 0, 0, 0, 0],
+ shape=[9, 1], dtype=tf.float32)
+ priors = tf.constant(1.0/3.0, shape=[1])
+ weights = tf.constant([1, 1, 1, 0, 0, 0, 2, 2, 2],
+ shape=[9, 1], dtype=tf.float32)
+
+ # Construct global objective kwargs.
+ objective_kwargs['labels'] = targets
+ objective_kwargs['logits'] = logits
+ objective_kwargs['label_priors'] = priors
+
+ scope = 'weighted_test'
+ # Unweighted loss.
+ objective_kwargs['scope'] = scope + '_plain'
+ raw_loss, update = global_objective(**objective_kwargs)
+ loss = tf.reduce_sum(raw_loss)
+
+ # Weighted loss.
+ objective_kwargs['weights'] = weights
+ objective_kwargs['scope'] = scope + '_weighted'
+ raw_weighted_loss, weighted_update = global_objective(**objective_kwargs)
+ weighted_loss = tf.reduce_sum(raw_weighted_loss)
+
+ lambdas = tf.contrib.framework.get_unique_variable(scope + '_plain/lambdas')
+ weighted_lambdas = tf.contrib.framework.get_unique_variable(
+ scope + '_weighted/lambdas')
+ logits_gradient = tf.gradients(loss, dummy)
+ weighted_logits_gradient = tf.gradients(weighted_loss, dummy)
+
+ with self.test_session() as session:
+ tf.global_variables_initializer().run()
+ self.assertAllClose(loss.eval(), weighted_loss.eval())
+
+ logits_grad, weighted_logits_grad = session.run(
+ [logits_gradient, weighted_logits_gradient])
+ self.assertAllClose(logits_grad, weighted_logits_grad)
+
+ session.run([update, weighted_update])
+ lambdas_value, weighted_lambdas_value = session.run(
+ [lambdas, weighted_lambdas])
+ self.assertAllClose(lambdas_value, weighted_lambdas_value)
+
+ @parameterized.named_parameters(
+ ('_prauc051xent', loss_layers.precision_recall_auc_loss, {
+ 'precision_range': (0.5, 1.0), 'surrogate_type': 'xent'
+ }),
+ ('_prauc01hinge', loss_layers.precision_recall_auc_loss, {
+ 'precision_range': (0.0, 1.0), 'surrogate_type': 'hinge'
+ }),
+ ('_rocxent', loss_layers.roc_auc_loss, {'surrogate_type': 'xent'}),
+ ('_rochinge', loss_layers.roc_auc_loss, {'surrogate_type': 'hinge'}),
+ ('_ratp04', loss_layers.recall_at_precision_loss, {
+ 'target_precision': 0.4, 'surrogate_type': 'xent'
+ }),
+ ('_ratp07_hinge', loss_layers.recall_at_precision_loss, {
+ 'target_precision': 0.7, 'surrogate_type': 'hinge'
+ }),
+ ('_patr04', loss_layers.precision_at_recall_loss, {
+ 'target_recall': 0.4, 'surrogate_type': 'xent'
+ }),
+ ('_patr07_hinge', loss_layers.precision_at_recall_loss, {
+ 'target_recall': 0.7, 'surrogate_type': 'hinge'
+ }),
+ ('_fpattp046', loss_layers.false_positive_rate_at_true_positive_rate_loss,
+ {
+ 'target_rate': 0.46, 'surrogate_type': 'xent'
+ }),
+ ('_fpattp036_hinge',
+ loss_layers.false_positive_rate_at_true_positive_rate_loss, {
+ 'target_rate': 0.36, 'surrogate_type': 'hinge'
+ }),
+ ('_tpatfp076', loss_layers.true_positive_rate_at_false_positive_rate_loss,
+ {
+ 'target_rate': 0.76, 'surrogate_type': 'xent'
+ }),
+ ('_tpatfp036_hinge',
+ loss_layers.true_positive_rate_at_false_positive_rate_loss, {
+ 'target_rate': 0.36, 'surrogate_type': 'hinge'
+ }),
+ )
+ def testVectorAndMatrixLabelEquivalence(self,
+ global_objective,
+ objective_kwargs):
+ """Tests equivalence between label shape [batch_size] or [batch_size, 1]."""
+ vector_labels = tf.constant([1.0, 1.0, 0.0, 0.0], shape=[4])
+ vector_logits = tf.constant([1.0, 0.1, 0.1, -1.0], shape=[4])
+
+ # Construct vector global objective kwargs and loss.
+ vector_kwargs = objective_kwargs.copy()
+ vector_kwargs['labels'] = vector_labels
+ vector_kwargs['logits'] = vector_logits
+ vector_loss, _ = global_objective(**vector_kwargs)
+ vector_loss_sum = tf.reduce_sum(vector_loss)
+
+ # Construct matrix global objective kwargs and loss.
+ matrix_kwargs = objective_kwargs.copy()
+ matrix_kwargs['labels'] = tf.expand_dims(vector_labels, 1)
+ matrix_kwargs['logits'] = tf.expand_dims(vector_logits, 1)
+ matrix_loss, _ = global_objective(**matrix_kwargs)
+ matrix_loss_sum = tf.reduce_sum(matrix_loss)
+
+ self.assertEqual(1, vector_loss.get_shape().ndims)
+ self.assertEqual(2, matrix_loss.get_shape().ndims)
+
+ with self.test_session():
+ tf.global_variables_initializer().run()
+ self.assertAllClose(vector_loss_sum.eval(), matrix_loss_sum.eval())
+
+ @parameterized.named_parameters(
+ ('_prauc', loss_layers.precision_recall_auc_loss, None),
+ ('_roc', loss_layers.roc_auc_loss, None),
+ ('_rap', loss_layers.recall_at_precision_loss, {'target_precision': 0.8}),
+ ('_patr', loss_layers.precision_at_recall_loss, {'target_recall': 0.7}),
+ ('_fpattp', loss_layers.false_positive_rate_at_true_positive_rate_loss,
+ {'target_rate': 0.9}),
+ ('_tpatfp', loss_layers.true_positive_rate_at_false_positive_rate_loss,
+ {'target_rate': 0.1})
+ )
+ def testUnknownBatchSize(self, global_objective, objective_kwargs):
+ # Tests that there are no errors when the batch size is not known.
+ batch_shape = [5, 2]
+ logits = tf.placeholder(tf.float32)
+ logits_feed = numpy.random.randn(*batch_shape)
+ labels = tf.placeholder(tf.float32)
+ labels_feed = logits_feed > 0.1
+ logits.set_shape([None, 2])
+ labels.set_shape([None, 2])
+
+ if objective_kwargs is None:
+ objective_kwargs = {}
+
+ placeholder_kwargs = objective_kwargs.copy()
+ placeholder_kwargs['labels'] = labels
+ placeholder_kwargs['logits'] = logits
+ placeholder_loss, _ = global_objective(**placeholder_kwargs)
+
+ kwargs = objective_kwargs.copy()
+ kwargs['labels'] = labels_feed
+ kwargs['logits'] = logits_feed
+ loss, _ = global_objective(**kwargs)
+
+ with self.test_session() as session:
+ tf.global_variables_initializer().run()
+ feed_loss_val = session.run(placeholder_loss,
+ feed_dict={logits: logits_feed,
+ labels: labels_feed})
+ loss_val = session.run(loss)
+ self.assertAllClose(feed_loss_val, loss_val)
+
+
+# Both sets of logits below are designed so that the surrogate precision and
+# recall (true positive rate) of class 1 is ~ 2/3, and the same surrogates for
+# class 2 are ~ 1/3. The false positive rate surrogates are ~ 1/3 and 2/3.
+def _multilabel_data():
+ targets = tf.constant([1.0, 1.0, 1.0, 0.0, 0.0, 0.0], shape=[6, 1])
+ targets = tf.concat([targets, targets], 1)
+ logits_positives = tf.constant([[0.0, 15],
+ [16, 0.0],
+ [14, 0.0]], shape=[3, 2])
+ logits_negatives = tf.constant([[-17, 0.0],
+ [-15, 0.0],
+ [0.0, -101]], shape=[3, 2])
+ logits = tf.concat([logits_positives, logits_negatives], 0)
+ priors = tf.constant(0.5, shape=[2])
+
+ return targets, logits, priors
+
+
+def _other_multilabel_data(surrogate_type):
+ targets = tf.constant(
+ [1.0] * 6 + [0.0] * 6, shape=[12, 1])
+ targets = tf.concat([targets, targets], 1)
+ logits_positives = tf.constant([[0.0, 13],
+ [12, 0.0],
+ [15, 0.0],
+ [0.0, 30],
+ [13, 0.0],
+ [18, 0.0]], shape=[6, 2])
+ # A score of cost_2 incurs a loss of ~2.0.
+ cost_2 = 1.0 if surrogate_type == 'hinge' else 1.09861229
+ logits_negatives = tf.constant([[-16, cost_2],
+ [-15, cost_2],
+ [cost_2, -111],
+ [-133, -14,],
+ [-14.0100101, -16,],
+ [-19.888828882, -101]], shape=[6, 2])
+ logits = tf.concat([logits_positives, logits_negatives], 0)
+ priors = tf.constant(0.5, shape=[2])
+
+ def builder():
+ return targets, logits, priors
+
+ return builder
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/global_objectives/test_all.py b/models/research/global_objectives/test_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7e439e219840a9ec5c65382c6bc392b1d68b447
--- /dev/null
+++ b/models/research/global_objectives/test_all.py
@@ -0,0 +1,37 @@
+# Copyright 2018 The TensorFlow Global Objectives Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Runs all unit tests in the Global Objectives package.
+
+Requires that TensorFlow and abseil (https://github.com/abseil/abseil-py) be
+installed on your machine. Command to run the tests:
+python test_all.py
+
+"""
+
+import os
+import sys
+import unittest
+
+this_file = os.path.realpath(__file__)
+start_dir = os.path.dirname(this_file)
+parent_dir = os.path.dirname(start_dir)
+
+sys.path.append(parent_dir)
+loader = unittest.TestLoader()
+suite = loader.discover(start_dir, pattern='*_test.py')
+
+runner = unittest.TextTestRunner(verbosity=2)
+runner.run(suite)
diff --git a/models/research/global_objectives/util.py b/models/research/global_objectives/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b287a90bd743e5466b875c933c3872868f4a5f
--- /dev/null
+++ b/models/research/global_objectives/util.py
@@ -0,0 +1,348 @@
+# Copyright 2018 The TensorFlow Global Objectives Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains utility functions for the global objectives library."""
+
+# Dependency imports
+import tensorflow as tf
+
+
+def weighted_sigmoid_cross_entropy_with_logits(labels,
+ logits,
+ positive_weights=1.0,
+ negative_weights=1.0,
+ name=None):
+ """Computes a weighting of sigmoid cross entropy given `logits`.
+
+ Measures the weighted probability error in discrete classification tasks in
+ which classes are independent and not mutually exclusive. For instance, one
+ could perform multilabel classification where a picture can contain both an
+ elephant and a dog at the same time. The class weight multiplies the
+ different types of errors.
+ For brevity, let `x = logits`, `z = labels`, `c = positive_weights`,
+ `d = negative_weights`. The
+ weighted logistic loss is
+
+ ```
+ c * z * -log(sigmoid(x)) + d * (1 - z) * -log(1 - sigmoid(x))
+ = c * z * -log(1 / (1 + exp(-x))) - d * (1 - z) * log(exp(-x) / (1 + exp(-x)))
+ = c * z * log(1 + exp(-x)) + d * (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
+ = c * z * log(1 + exp(-x)) + d * (1 - z) * (x + log(1 + exp(-x)))
+ = (1 - z) * x * d + (d * (1 - z) + c * z) * log(1 + exp(-x))
+ = - d * x * z + d * x + (d - d * z + c * z ) * log(1 + exp(-x))
+ ```
+
+ To ensure stability and avoid overflow, the implementation uses the identity
+ log(1 + exp(-x)) = max(0,-x) + log(1 + exp(-abs(x)))
+ and the result is computed as
+
+ ```
+ = -d * x * z + d * x
+ + (d - d * z + c * z ) * (max(0,-x) + log(1 + exp(-abs(x))))
+ ```
+
+ Note that the loss is NOT an upper bound on the 0-1 loss, unless it is divided
+ by log(2).
+
+ Args:
+ labels: A `Tensor` of type `float32` or `float64`. `labels` can be a 2D
+ tensor with shape [batch_size, num_labels] or a 3D tensor with shape
+ [batch_size, num_labels, K].
+ logits: A `Tensor` of the same type and shape as `labels`. If `logits` has
+ shape [batch_size, num_labels, K], the loss is computed separately on each
+ slice [:, :, k] of `logits`.
+ positive_weights: A `Tensor` that holds positive weights and has the
+ following semantics according to its shape:
+ scalar - A global positive weight.
+ 1D tensor - must be of size K, a weight for each 'attempt'
+ 2D tensor - of size [num_labels, K'] where K' is either K or 1.
+ The `positive_weights` will be expanded to the left to match the
+ dimensions of logits and labels.
+ negative_weights: A `Tensor` that holds the weights for negative examples
+ and has semantics identical to positive_weights.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of the same shape as `logits` with the componentwise
+ weighted logistic losses.
+ """
+ with tf.name_scope(
+ name,
+ 'weighted_logistic_loss',
+ [logits, labels, positive_weights, negative_weights]) as name:
+ labels, logits, positive_weights, negative_weights = prepare_loss_args(
+ labels, logits, positive_weights, negative_weights)
+
+ softplus_term = tf.add(tf.maximum(-logits, 0.0),
+ tf.log(1.0 + tf.exp(-tf.abs(logits))))
+ weight_dependent_factor = (
+ negative_weights + (positive_weights - negative_weights) * labels)
+ return (negative_weights * (logits - labels * logits) +
+ weight_dependent_factor * softplus_term)
+
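+# Illustrative sketch (not part of the original library): with unit weights the
+# function above reduces to the standard sigmoid cross entropy, e.g.
+#
+#   labels = tf.constant([[1.0, 0.0]])
+#   logits = tf.constant([[2.0, -1.0]])
+#   weighted = weighted_sigmoid_cross_entropy_with_logits(labels, logits,
+#                                                         positive_weights=1.0,
+#                                                         negative_weights=1.0)
+#   baseline = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
+#                                                      logits=logits)
+#   # `weighted` and `baseline` evaluate to the same per-element values.
+#   # (The accompanying util_test.py checks this equivalence.)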
+
+def weighted_hinge_loss(labels,
+ logits,
+ positive_weights=1.0,
+ negative_weights=1.0,
+ name=None):
+ """Computes weighted hinge loss given logits `logits`.
+
+ The loss applies to multi-label classification tasks where labels are
+ independent and not mutually exclusive. See also
+ `weighted_sigmoid_cross_entropy_with_logits`.
+
+ Args:
+ labels: A `Tensor` of type `float32` or `float64`. Each entry must be
+ either 0 or 1. `labels` can be a 2D tensor with shape
+ [batch_size, num_labels] or a 3D tensor with shape
+ [batch_size, num_labels, K].
+ logits: A `Tensor` of the same type and shape as `labels`. If `logits` has
+ shape [batch_size, num_labels, K], the loss is computed separately on each
+ slice [:, :, k] of `logits`.
+ positive_weights: A `Tensor` that holds positive weights and has the
+ following semantics according to its shape:
+ scalar - A global positive weight.
+ 1D tensor - must be of size K, a weight for each 'attempt'
+ 2D tensor - of size [num_labels, K'] where K' is either K or 1.
+ The `positive_weights` will be expanded to the left to match the
+ dimensions of logits and labels.
+ negative_weights: A `Tensor` that holds the weights for negative examples
+ and has semantics identical to positive_weights.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of the same shape as `logits` with the componentwise
+ weighted hinge loss.
+ """
+ with tf.name_scope(
+ name, 'weighted_hinge_loss',
+ [logits, labels, positive_weights, negative_weights]) as name:
+ labels, logits, positive_weights, negative_weights = prepare_loss_args(
+ labels, logits, positive_weights, negative_weights)
+
+ positives_term = positive_weights * labels * tf.maximum(1.0 - logits, 0)
+ negatives_term = (negative_weights * (1.0 - labels)
+ * tf.maximum(1.0 + logits, 0))
+ return positives_term + negatives_term
+
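+# Illustrative example (not part of the original library): for a single positive
+# example (labels = 1.0) with logit 0.3 and positive_weights = 2.0, the loss is
+# 2.0 * max(1.0 - 0.3, 0) = 1.4; the negative term is zero because
+# (1.0 - labels) = 0.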
+
+def weighted_surrogate_loss(labels,
+ logits,
+ surrogate_type='xent',
+ positive_weights=1.0,
+ negative_weights=1.0,
+ name=None):
+ """Returns either weighted cross-entropy or hinge loss.
+
+ For example, if `surrogate_type` is 'xent', this returns the weighted cross
+ entropy loss.
+
+ Args:
+ labels: A `Tensor` of type `float32` or `float64`. Each entry must be
+ between 0 and 1. `labels` can be a 2D tensor with shape
+ [batch_size, num_labels] or a 3D tensor with shape
+ [batch_size, num_labels, K].
+ logits: A `Tensor` of the same type and shape as `labels`. If `logits` has
+ shape [batch_size, num_labels, K], each slice [:, :, k] represents an
+ 'attempt' to predict `labels` and the loss is computed per slice.
+ surrogate_type: A string that determines which loss to return, supports
+ 'xent' for cross-entropy and 'hinge' for hinge loss.
+ positive_weights: A `Tensor` that holds positive weights and has the
+ following semantics according to its shape:
+ scalar - A global positive weight.
+ 1D tensor - must be of size K, a weight for each 'attempt'
+ 2D tensor - of size [num_labels, K'] where K' is either K or 1.
+ The `positive_weights` will be expanded to the left to match the
+ dimensions of logits and labels.
+ negative_weights: A `Tensor` that holds the weights for negative examples
+ and has semantics identical to positive_weights.
+ name: A name for the operation (optional).
+
+ Returns:
+ The weighted loss.
+
+ Raises:
+ ValueError: If value of `surrogate_type` is not supported.
+ """
+ with tf.name_scope(
+ name, 'weighted_loss',
+ [logits, labels, surrogate_type, positive_weights,
+ negative_weights]) as name:
+ if surrogate_type == 'xent':
+ return weighted_sigmoid_cross_entropy_with_logits(
+ logits=logits,
+ labels=labels,
+ positive_weights=positive_weights,
+ negative_weights=negative_weights,
+ name=name)
+ elif surrogate_type == 'hinge':
+ return weighted_hinge_loss(
+ logits=logits,
+ labels=labels,
+ positive_weights=positive_weights,
+ negative_weights=negative_weights,
+ name=name)
+ raise ValueError('surrogate_type %s not supported.' % surrogate_type)
+
+
+def expand_outer(tensor, rank):
+ """Expands the given `Tensor` outwards to a target rank.
+
+ For example, if rank = 3 and tensor.shape is [3, 4], this function will
+ expand `tensor` so that the resulting shape is [1, 3, 4].
+
+ Args:
+ tensor: The tensor to expand.
+ rank: The target dimension.
+
+ Returns:
+ The expanded tensor.
+
+ Raises:
+ ValueError: If rank of `tensor` is unknown, or if `rank` is smaller than
+ the rank of `tensor`.
+ """
+ if tensor.get_shape().ndims is None:
+ raise ValueError('tensor dimension must be known.')
+ if len(tensor.get_shape()) > rank:
+ raise ValueError(
+ '`rank` must be at least the current tensor dimension: (%s vs %s).' %
+ (rank, len(tensor.get_shape())))
+ while len(tensor.get_shape()) < rank:
+ tensor = tf.expand_dims(tensor, 0)
+ return tensor
+
+
+def build_label_priors(labels,
+ weights=None,
+ positive_pseudocount=1.0,
+ negative_pseudocount=1.0,
+ variables_collections=None):
+ """Creates an op to maintain and update label prior probabilities.
+
+ For each label, the label priors are estimated as
+ (P + sum_i w_i y_i) / (P + N + sum_i w_i),
+ where y_i is the ith label, w_i is the ith weight, P is a pseudo-count of
+ positive labels, and N is a pseudo-count of negative labels. The index i
+ ranges over all labels observed during all evaluations of the returned op.
+
+ Args:
+ labels: A `Tensor` with shape [batch_size, num_labels]. Entries should be
+ in [0, 1].
+ weights: Coefficients representing the weight of each label. Must be either
+ a Tensor of shape [batch_size, num_labels] or `None`, in which case each
+ weight is treated as 1.0.
+ positive_pseudocount: Number of positive labels used to initialize the label
+ priors.
+ negative_pseudocount: Number of negative labels used to initialize the label
+ priors.
+ variables_collections: Optional list of collections for created variables.
+
+ Returns:
+ label_priors: An op to update the weighted label_priors. Gives the
+ current value of the label priors when evaluated.
+ """
+ dtype = labels.dtype.base_dtype
+ num_labels = get_num_labels(labels)
+
+ if weights is None:
+ weights = tf.ones_like(labels)
+
+ # We disable partitioning while constructing dual variables because they will
+ # be updated with assign, which is not available for partitioned variables.
+ partitioner = tf.get_variable_scope().partitioner
+ try:
+ tf.get_variable_scope().set_partitioner(None)
+ # Create variable and update op for weighted label counts.
+ weighted_label_counts = tf.contrib.framework.model_variable(
+ name='weighted_label_counts',
+ shape=[num_labels],
+ dtype=dtype,
+ initializer=tf.constant_initializer(
+ [positive_pseudocount] * num_labels, dtype=dtype),
+ collections=variables_collections,
+ trainable=False)
+ weighted_label_counts_update = weighted_label_counts.assign_add(
+ tf.reduce_sum(weights * labels, 0))
+
+ # Create variable and update op for the sum of the weights.
+ weight_sum = tf.contrib.framework.model_variable(
+ name='weight_sum',
+ shape=[num_labels],
+ dtype=dtype,
+ initializer=tf.constant_initializer(
+ [positive_pseudocount + negative_pseudocount] * num_labels,
+ dtype=dtype),
+ collections=variables_collections,
+ trainable=False)
+ weight_sum_update = weight_sum.assign_add(tf.reduce_sum(weights, 0))
+
+ finally:
+ tf.get_variable_scope().set_partitioner(partitioner)
+
+ label_priors = tf.div(
+ weighted_label_counts_update,
+ weight_sum_update)
+ return label_priors
+
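+# Worked example for `build_label_priors` (illustrative numbers only): with the
+# default pseudocounts of 1.0 each and unit weights, a single batch with labels
+# [[1, 0], [1, 1]] yields priors of (1 + 2) / (2 + 2) = 0.75 for the first label
+# and (1 + 1) / (2 + 2) = 0.5 for the second.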
+
+def convert_and_cast(value, name, dtype):
+ """Convert input to tensor and cast to dtype.
+
+ Args:
+ value: An object whose type has a registered Tensor conversion function,
+ e.g. python numerical type or numpy array.
+ name: Name to use for the new Tensor, if one is created.
+ dtype: Optional element type for the returned tensor.
+
+ Returns:
+ A tensor.
+ """
+ return tf.cast(tf.convert_to_tensor(value, name=name), dtype=dtype)
+
+
+def prepare_loss_args(labels, logits, positive_weights, negative_weights):
+ """Prepare arguments for weighted loss functions.
+
+ If needed, will convert given arguments to appropriate type and shape.
+
+ Args:
+ labels: Labels or targets of the loss function.
+ logits: Logits of the loss function.
+ positive_weights: Weight on the positive examples.
+ negative_weights: Weight on the negative examples.
+
+ Returns:
+ Converted labels, logits, positive_weights, negative_weights.
+ """
+ logits = tf.convert_to_tensor(logits, name='logits')
+ labels = convert_and_cast(labels, 'labels', logits.dtype)
+ if len(labels.get_shape()) == 2 and len(logits.get_shape()) == 3:
+ labels = tf.expand_dims(labels, [2])
+
+ positive_weights = convert_and_cast(positive_weights, 'positive_weights',
+ logits.dtype)
+ positive_weights = expand_outer(positive_weights, logits.get_shape().ndims)
+ negative_weights = convert_and_cast(negative_weights, 'negative_weights',
+ logits.dtype)
+ negative_weights = expand_outer(negative_weights, logits.get_shape().ndims)
+ return labels, logits, positive_weights, negative_weights
+
+
+def get_num_labels(labels_or_logits):
+ """Returns the number of labels inferred from labels_or_logits."""
+ if labels_or_logits.get_shape().ndims <= 1:
+ return 1
+ return labels_or_logits.get_shape()[1].value
diff --git a/models/research/global_objectives/util_test.py b/models/research/global_objectives/util_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..195252a53eb1d0a50735d2f987b0882681b0544a
--- /dev/null
+++ b/models/research/global_objectives/util_test.py
@@ -0,0 +1,333 @@
+# Copyright 2018 The TensorFlow Global Objectives Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for global objectives util functions."""
+
+# Dependency imports
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+
+from global_objectives import util
+
+
+def weighted_sigmoid_cross_entropy(targets, logits, weight):
+ return (weight * targets * np.log(1.0 + np.exp(-logits)) + (
+ (1.0 - targets) * np.log(1.0 + 1.0 / np.exp(-logits))))
+
+
+def hinge_loss(labels, logits):
+ # Mostly copied from tensorflow.python.ops.losses but with loss per datapoint.
+ labels = tf.to_float(labels)
+ all_ones = tf.ones_like(labels)
+ labels = tf.subtract(2 * labels, all_ones)
+ return tf.nn.relu(tf.subtract(all_ones, tf.multiply(labels, logits)))
+
+
+class WeightedSigmoidCrossEntropyTest(parameterized.TestCase, tf.test.TestCase):
+
+ def testTrivialCompatibilityWithSigmoidCrossEntropy(self):
+ """Tests compatibility with unweighted function with weight 1.0."""
+ x_shape = [300, 10]
+ targets = np.random.random_sample(x_shape).astype(np.float32)
+ logits = np.random.randn(*x_shape).astype(np.float32)
+ weighted_loss = util.weighted_sigmoid_cross_entropy_with_logits(
+ targets,
+ logits)
+ expected_loss = (
+ tf.contrib.nn.deprecated_flipped_sigmoid_cross_entropy_with_logits(
+ logits, targets))
+ with self.test_session():
+ self.assertAllClose(expected_loss.eval(),
+ weighted_loss.eval(),
+ atol=0.000001)
+
+ def testNonTrivialCompatibilityWithSigmoidCrossEntropy(self):
+ """Tests use of an arbitrary weight (4.12)."""
+ x_shape = [300, 10]
+ targets = np.random.random_sample(x_shape).astype(np.float32)
+ logits = np.random.randn(*x_shape).astype(np.float32)
+ weight = 4.12
+ weighted_loss = util.weighted_sigmoid_cross_entropy_with_logits(
+ targets,
+ logits,
+ weight,
+ weight)
+ expected_loss = (
+ weight *
+ tf.contrib.nn.deprecated_flipped_sigmoid_cross_entropy_with_logits(
+ logits, targets))
+ with self.test_session():
+ self.assertAllClose(expected_loss.eval(),
+ weighted_loss.eval(),
+ atol=0.000001)
+
+ def testDifferentSizeWeightedSigmoidCrossEntropy(self):
+ """Tests correctness on 3D tensors.
+
+ Tests that the function works as expected when logits is a 3D tensor and
+ targets is a 2D tensor.
+ """
+ targets_shape = [30, 4]
+ logits_shape = [targets_shape[0], targets_shape[1], 3]
+ targets = np.random.random_sample(targets_shape).astype(np.float32)
+ logits = np.random.randn(*logits_shape).astype(np.float32)
+
+ weight_vector = [2.0, 3.0, 13.0]
+ loss = util.weighted_sigmoid_cross_entropy_with_logits(targets,
+ logits,
+ weight_vector)
+
+ with self.test_session():
+ loss = loss.eval()
+ for i in range(0, len(weight_vector)):
+ expected = weighted_sigmoid_cross_entropy(targets, logits[:, :, i],
+ weight_vector[i])
+ self.assertAllClose(loss[:, :, i], expected, atol=0.000001)
+
+ @parameterized.parameters((300, 10, 0.3), (20, 4, 2.0), (30, 4, 3.9))
+ def testWeightedSigmoidCrossEntropy(self, batch_size, num_labels, weight):
+ """Tests thats the tf and numpy functions agree on many instances."""
+ x_shape = [batch_size, num_labels]
+ targets = np.random.random_sample(x_shape).astype(np.float32)
+ logits = np.random.randn(*x_shape).astype(np.float32)
+
+ with self.test_session():
+ loss = util.weighted_sigmoid_cross_entropy_with_logits(
+ targets,
+ logits,
+ weight,
+ 1.0,
+ name='weighted-loss')
+ expected = weighted_sigmoid_cross_entropy(targets, logits, weight)
+ self.assertAllClose(expected, loss.eval(), atol=0.000001)
+
+ def testGradients(self):
+ """Tests that weighted loss gradients behave as expected."""
+ dummy_tensor = tf.constant(1.0)
+
+ positives_shape = [10, 1]
+ positives_logits = dummy_tensor * tf.Variable(
+ tf.random_normal(positives_shape) + 1.0)
+ positives_targets = tf.ones(positives_shape)
+ positives_weight = 4.6
+ positives_loss = (
+ tf.contrib.nn.deprecated_flipped_sigmoid_cross_entropy_with_logits(
+ positives_logits, positives_targets) * positives_weight)
+
+ negatives_shape = [190, 1]
+ negatives_logits = dummy_tensor * tf.Variable(
+ tf.random_normal(negatives_shape))
+ negatives_targets = tf.zeros(negatives_shape)
+ negatives_weight = 0.9
+ negatives_loss = (
+ tf.contrib.nn.deprecated_flipped_sigmoid_cross_entropy_with_logits(
+ negatives_logits, negatives_targets) * negatives_weight)
+
+ all_logits = tf.concat([positives_logits, negatives_logits], 0)
+ all_targets = tf.concat([positives_targets, negatives_targets], 0)
+ weighted_loss = tf.reduce_sum(
+ util.weighted_sigmoid_cross_entropy_with_logits(
+ all_targets, all_logits, positives_weight, negatives_weight))
+ weighted_gradients = tf.gradients(weighted_loss, dummy_tensor)
+
+ expected_loss = tf.add(
+ tf.reduce_sum(positives_loss),
+ tf.reduce_sum(negatives_loss))
+ expected_gradients = tf.gradients(expected_loss, dummy_tensor)
+
+ with tf.Session() as session:
+ tf.global_variables_initializer().run()
+ grad, expected_grad = session.run(
+ [weighted_gradients, expected_gradients])
+ self.assertAllClose(grad, expected_grad)
+
+ def testDtypeFlexibility(self):
+ """Tests the loss on inputs of varying data types."""
+ shape = [20, 3]
+ logits = np.random.randn(*shape)
+ targets = tf.truncated_normal(shape)
+ positive_weights = tf.constant(3, dtype=tf.int64)
+ negative_weights = 1
+
+ loss = util.weighted_sigmoid_cross_entropy_with_logits(
+ targets, logits, positive_weights, negative_weights)
+
+ with self.test_session():
+ self.assertEqual(loss.eval().dtype, np.float)
+
+
+class WeightedHingeLossTest(tf.test.TestCase):
+
+ def testTrivialCompatibilityWithHinge(self):
+ # Tests compatibility with unweighted hinge loss.
+ x_shape = [55, 10]
+ logits = tf.constant(np.random.randn(*x_shape).astype(np.float32))
+ targets = tf.to_float(tf.constant(np.random.random_sample(x_shape) > 0.3))
+ weighted_loss = util.weighted_hinge_loss(targets, logits)
+ expected_loss = hinge_loss(targets, logits)
+ with self.test_session():
+ self.assertAllClose(expected_loss.eval(), weighted_loss.eval())
+
+ def testLessTrivialCompatibilityWithHinge(self):
+ # Tests compatibility with a constant weight for positives and negatives.
+ x_shape = [56, 11]
+ logits = tf.constant(np.random.randn(*x_shape).astype(np.float32))
+ targets = tf.to_float(tf.constant(np.random.random_sample(x_shape) > 0.7))
+ weight = 1.0 + 1.0/2 + 1.0/3 + 1.0/4 + 1.0/5 + 1.0/6 + 1.0/7
+ weighted_loss = util.weighted_hinge_loss(targets, logits, weight, weight)
+ expected_loss = hinge_loss(targets, logits) * weight
+ with self.test_session():
+ self.assertAllClose(expected_loss.eval(), weighted_loss.eval())
+
+ def testNontrivialCompatibilityWithHinge(self):
+ # Tests compatibility with different positive and negative weights.
+ x_shape = [23, 8]
+ logits_positives = tf.constant(np.random.randn(*x_shape).astype(np.float32))
+ logits_negatives = tf.constant(np.random.randn(*x_shape).astype(np.float32))
+ targets_positives = tf.ones(x_shape)
+ targets_negatives = tf.zeros(x_shape)
+ logits = tf.concat([logits_positives, logits_negatives], 0)
+ targets = tf.concat([targets_positives, targets_negatives], 0)
+
+ raw_loss = util.weighted_hinge_loss(targets,
+ logits,
+ positive_weights=3.4,
+ negative_weights=1.2)
+ loss = tf.reduce_sum(raw_loss, 0)
+ positives_hinge = hinge_loss(targets_positives, logits_positives)
+ negatives_hinge = hinge_loss(targets_negatives, logits_negatives)
+ expected = tf.add(tf.reduce_sum(3.4 * positives_hinge, 0),
+ tf.reduce_sum(1.2 * negatives_hinge, 0))
+
+ with self.test_session():
+ self.assertAllClose(loss.eval(), expected.eval())
+
+ def test3DLogitsAndTargets(self):
+ # Tests correctness when logits is 3D and targets is 2D.
+ targets_shape = [30, 4]
+ logits_shape = [targets_shape[0], targets_shape[1], 3]
+ targets = tf.to_float(
+ tf.constant(np.random.random_sample(targets_shape) > 0.7))
+ logits = tf.constant(np.random.randn(*logits_shape).astype(np.float32))
+ weight_vector = [1.0, 1.0, 1.0]
+ loss = util.weighted_hinge_loss(targets, logits, weight_vector)
+
+ with self.test_session():
+ loss_value = loss.eval()
+ for i in range(len(weight_vector)):
+ expected = hinge_loss(targets, logits[:, :, i]).eval()
+ self.assertAllClose(loss_value[:, :, i], expected)
+
+
+class BuildLabelPriorsTest(tf.test.TestCase):
+
+ def testLabelPriorConsistency(self):
+ # Checks that, with zero pseudocounts, the returned label priors reproduce
+ # label frequencies in the batch.
+ batch_shape = [4, 10]
+ labels = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(batch_shape), 0.678)))
+
+ label_priors_update = util.build_label_priors(
+ labels=labels, positive_pseudocount=0, negative_pseudocount=0)
+ expected_priors = tf.reduce_mean(labels, 0)
+
+ with self.test_session():
+ tf.global_variables_initializer().run()
+ self.assertAllClose(label_priors_update.eval(), expected_priors.eval())
+
+ def testLabelPriorsUpdate(self):
+ # Checks that the update of label priors behaves as expected.
+ batch_shape = [1, 5]
+ labels = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(batch_shape), 0.4)))
+ label_priors_update = util.build_label_priors(labels)
+
+ label_sum = np.ones(shape=batch_shape)
+ weight_sum = 2.0 * np.ones(shape=batch_shape)
+
+ with self.test_session() as session:
+ tf.global_variables_initializer().run()
+
+ for _ in range(3):
+ label_sum += labels.eval()
+ weight_sum += np.ones(shape=batch_shape)
+ expected_posteriors = label_sum / weight_sum
+ label_priors = label_priors_update.eval().reshape(batch_shape)
+ self.assertAllClose(label_priors, expected_posteriors)
+
+ # Re-initialize labels to get a new random sample.
+ session.run(labels.initializer)
+
+ def testLabelPriorsUpdateWithWeights(self):
+ # Checks the update of label priors with per-example weights.
+ batch_size = 6
+ num_labels = 5
+ batch_shape = [batch_size, num_labels]
+ labels = tf.Variable(
+ tf.to_float(tf.greater(tf.random_uniform(batch_shape), 0.6)))
+ weights = tf.Variable(tf.random_uniform(batch_shape) * 6.2)
+
+ update_op = util.build_label_priors(labels, weights=weights)
+
+ expected_weighted_label_counts = 1.0 + tf.reduce_sum(weights * labels, 0)
+ expected_weight_sum = 2.0 + tf.reduce_sum(weights, 0)
+ expected_label_posteriors = tf.divide(expected_weighted_label_counts,
+ expected_weight_sum)
+
+ with self.test_session() as session:
+ tf.global_variables_initializer().run()
+
+ updated_priors, expected_posteriors = session.run(
+ [update_op, expected_label_posteriors])
+ self.assertAllClose(updated_priors, expected_posteriors)
+
+
+class WeightedSurrogateLossTest(parameterized.TestCase, tf.test.TestCase):
+
+ @parameterized.parameters(
+ ('hinge', util.weighted_hinge_loss),
+ ('xent', util.weighted_sigmoid_cross_entropy_with_logits))
+ def testCompatibilityLoss(self, loss_name, loss_fn):
+ x_shape = [28, 4]
+ logits = tf.constant(np.random.randn(*x_shape).astype(np.float32))
+ targets = tf.to_float(tf.constant(np.random.random_sample(x_shape) > 0.5))
+ positive_weights = 0.66
+ negative_weights = 11.1
+ expected_loss = loss_fn(
+ targets,
+ logits,
+ positive_weights=positive_weights,
+ negative_weights=negative_weights)
+ computed_loss = util.weighted_surrogate_loss(
+ targets,
+ logits,
+ loss_name,
+ positive_weights=positive_weights,
+ negative_weights=negative_weights)
+ with self.test_session():
+ self.assertAllClose(expected_loss.eval(), computed_loss.eval())
+
+ def testSurrogateError(self):
+ x_shape = [7, 3]
+ logits = tf.constant(np.random.randn(*x_shape).astype(np.float32))
+ targets = tf.to_float(tf.constant(np.random.random_sample(x_shape) > 0.5))
+
+ with self.assertRaises(ValueError):
+ util.weighted_surrogate_loss(logits, targets, 'bug')
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/im2txt/.gitignore b/models/research/im2txt/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..fb46913cc7a5994c4324de50829c95d7858c30f4
--- /dev/null
+++ b/models/research/im2txt/.gitignore
@@ -0,0 +1,7 @@
+/bazel-bin
+/bazel-ci_build-cache
+/bazel-genfiles
+/bazel-out
+/bazel-im2txt
+/bazel-testlogs
+/bazel-tf
diff --git a/models/research/im2txt/README.md b/models/research/im2txt/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2eb72822a39e3959a5a9370f26a9cc5c12be0fda
--- /dev/null
+++ b/models/research/im2txt/README.md
@@ -0,0 +1,342 @@
+
+
+
+
+# Show and Tell: A Neural Image Caption Generator
+
+A TensorFlow implementation of the image-to-text model described in the paper:
+
+"Show and Tell: Lessons learned from the 2015 MSCOCO Image Captioning
+Challenge."
+
+Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan.
+
+*IEEE transactions on pattern analysis and machine intelligence (2016).*
+
+Full text available at: http://arxiv.org/abs/1609.06647
+
+## Contact
+***Author:*** Chris Shallue
+
+***Pull requests and issues:*** @cshallue
+
+## Contents
+* [Model Overview](#model-overview)
+ * [Introduction](#introduction)
+ * [Architecture](#architecture)
+* [Getting Started](#getting-started)
+ * [A Note on Hardware and Training Time](#a-note-on-hardware-and-training-time)
+ * [Install Required Packages](#install-required-packages)
+ * [Prepare the Training Data](#prepare-the-training-data)
+ * [Download the Inception v3 Checkpoint](#download-the-inception-v3-checkpoint)
+* [Training a Model](#training-a-model)
+ * [Initial Training](#initial-training)
+ * [Fine Tune the Inception v3 Model](#fine-tune-the-inception-v3-model)
+* [Generating Captions](#generating-captions)
+
+## Model Overview
+
+### Introduction
+
+The *Show and Tell* model is a deep neural network that learns how to describe
+the content of images. For example:
+
+
+![Example captions](g3doc/example_captions.jpg)
+
+### Architecture
+
+The *Show and Tell* model is an example of an *encoder-decoder* neural network.
+It works by first "encoding" an image into a fixed-length vector representation,
+and then "decoding" the representation into a natural language description.
+
+The image encoder is a deep convolutional neural network. This type of
+network is widely used for image tasks and is currently state-of-the-art for
+object recognition and detection. Our particular choice of network is the
+[*Inception v3*](http://arxiv.org/abs/1512.00567) image recognition model
+pretrained on the
+[ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/) image
+classification dataset.
+
+The decoder is a long short-term memory (LSTM) network. This type of network is
+commonly used for sequence modeling tasks such as language modeling and machine
+translation. In the *Show and Tell* model, the LSTM network is trained as a
+language model conditioned on the image encoding.
+
+Words in the captions are represented with an embedding model. Each word in the
+vocabulary is associated with a fixed-length vector representation that is
+learned during training.
+
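+The following sketch (not part of this repository's code) shows how such an
+embedding lookup is typically expressed in TF 1.x; the vocabulary and embedding
+sizes are the defaults from `im2txt/configuration.py`:
+
+```python
+# Minimal embedding-lookup sketch: each word id selects a learned row of the
+# embedding matrix. Illustrative only.
+import tensorflow as tf
+
+vocab_size, embedding_size = 12000, 512  # defaults in configuration.ModelConfig
+embedding_map = tf.get_variable("word_embedding", [vocab_size, embedding_size])
+caption_ids = tf.placeholder(tf.int64, shape=[None])  # integer word ids
+word_embeddings = tf.nn.embedding_lookup(embedding_map, caption_ids)
+```
+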
+The following diagram illustrates the model architecture.
+
+![Show and Tell Architecture](g3doc/show_and_tell_architecture.png)
+
+In this diagram, \{*s*<sub>0</sub>, *s*<sub>1</sub>, ..., *s*<sub>*N*-1</sub>\}
+are the words of the caption and \{*w<sub>e</sub>s*<sub>0</sub>,
+*w<sub>e</sub>s*<sub>1</sub>, ..., *w<sub>e</sub>s*<sub>*N*-1</sub>\}
+are their corresponding word embedding vectors. The outputs \{*p*<sub>1</sub>,
+*p*<sub>2</sub>, ..., *p*<sub>*N*</sub>\} of the LSTM are probability
+distributions generated by the model for the next word in the sentence. The
+terms \{log *p*<sub>1</sub>(*s*<sub>1</sub>),
+log *p*<sub>2</sub>(*s*<sub>2</sub>), ...,
+log *p*<sub>*N*</sub>(*s*<sub>*N*</sub>)\} are the log-likelihoods of the
+correct word at each step; the negated sum of these terms is the minimization
+objective of the model.
+
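+As a rough, framework-agnostic sketch (assuming the per-step distributions and
+the reference caption are already available as plain arrays), the objective can
+be computed as follows:
+
+```python
+# Illustrative only: the negated sum of per-step log-likelihoods of the correct words.
+import numpy as np
+
+def caption_loss(step_distributions, caption_word_ids):
+    """step_distributions: list of vocab-sized arrays p_1..p_N (each sums to 1).
+    caption_word_ids: integer ids of the correct words s_1..s_N."""
+    log_likelihoods = [np.log(p[s]) for p, s in zip(step_distributions, caption_word_ids)]
+    return -np.sum(log_likelihoods)
+```
+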
+During the first phase of training the parameters of the *Inception v3* model
+are kept fixed: it is simply a static image encoder function. A single trainable
+layer is added on top of the *Inception v3* model to transform the image
+embedding into the word embedding vector space. The model is trained with
+respect to the parameters of the word embeddings, the parameters of the layer on
+top of *Inception v3* and the parameters of the LSTM. In the second phase of
+training, all parameters - including the parameters of *Inception v3* - are
+trained to jointly fine-tune the image encoder and the LSTM.
+
+Given a trained model and an image we use *beam search* to generate captions for
+that image. Captions are generated word-by-word, where at each step *t* we use
+the set of sentences already generated with length *t* - 1 to generate a new set
+of sentences with length *t*. We keep only the top *k* candidates at each step,
+where the hyperparameter *k* is called the *beam size*. We have found the best
+performance with *k* = 3.
+
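+The sketch below (illustrative only, using a hypothetical `next_words` callback;
+the real implementation lives in `im2txt/inference_utils/caption_generator.py`)
+shows the core idea of keeping only the top *k* candidates at each step:
+
+```python
+import heapq
+
+def beam_search(start_id, end_id, next_words, beam_size=3, max_length=20):
+    """next_words(sentence) -> list of (word_id, log_prob) pairs for the next word."""
+    beams = [(0.0, [start_id])]        # (cumulative log-probability, word ids)
+    complete = []
+    for _ in range(max_length - 1):
+        candidates = []
+        for logprob, sentence in beams:
+            for word, lp in next_words(sentence):
+                candidates.append((logprob + lp, sentence + [word]))
+        # Keep only the beam_size highest-scoring candidates.
+        top = heapq.nlargest(beam_size, candidates)
+        complete += [c for c in top if c[1][-1] == end_id]
+        beams = [c for c in top if c[1][-1] != end_id]
+        if not beams:
+            break
+    # Prefer finished captions; fall back to partial ones otherwise.
+    return max(complete or beams)
+```
+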
+## Getting Started
+
+### A Note on Hardware and Training Time
+
+The time required to train the *Show and Tell* model depends on your specific
+hardware and computational capacity. In this guide we assume you will be running
+training on a single machine with a GPU. In our experience on an NVIDIA Tesla
+K20m GPU the initial training phase takes 1-2 weeks. The second training phase
+may take several additional weeks to achieve peak performance (but you can stop
+this phase early and still get reasonable results).
+
+It is possible to achieve a speed-up by implementing distributed training across
+a cluster of machines with GPUs, but that is not covered in this guide.
+
+Whilst it is possible to run this code on a CPU, beware that this may be
+approximately 10 times slower.
+
+### Install Required Packages
+First ensure that you have installed the following required packages:
+
+* **Bazel** ([instructions](http://bazel.io/docs/install.html))
+* **Python 2.7**
+* **TensorFlow** 1.0 or greater ([instructions](https://www.tensorflow.org/install/))
+* **NumPy** ([instructions](http://www.scipy.org/install.html))
+* **Natural Language Toolkit (NLTK)**:
+ * First install NLTK ([instructions](http://www.nltk.org/install.html))
+ * Then install the NLTK data package "punkt" ([instructions](http://www.nltk.org/data.html))
+* **Unzip**
+
+### Prepare the Training Data
+
+To train the model you will need to provide training data in native TFRecord
+format. The TFRecord format consists of a set of sharded files containing
+serialized `tf.SequenceExample` protocol buffers. Each `tf.SequenceExample`
+proto contains an image (JPEG format), a caption and metadata such as the image
+id.
+
+Each caption is a list of words. During preprocessing, a dictionary is created
+that assigns each word in the vocabulary to an integer-valued id. Each caption
+is encoded as a list of integer word ids in the `tf.SequenceExample` protos.
+
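+For reference, a single record has roughly the following shape (a simplified
+sketch; the preprocessing script described below constructs these for you,
+including the tokenized `image/caption` string list):
+
+```python
+import tensorflow as tf
+
+def make_sequence_example(image_id, encoded_jpeg, caption_ids):
+    # Context: per-image fields (integer id and JPEG-encoded image bytes).
+    context = tf.train.Features(feature={
+        "image/image_id": tf.train.Feature(
+            int64_list=tf.train.Int64List(value=[image_id])),
+        "image/data": tf.train.Feature(
+            bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
+    })
+    # Feature lists: per-word fields (integer caption ids).
+    feature_lists = tf.train.FeatureLists(feature_list={
+        "image/caption_ids": tf.train.FeatureList(
+            feature=[tf.train.Feature(int64_list=tf.train.Int64List(value=[i]))
+                     for i in caption_ids]),
+    })
+    return tf.train.SequenceExample(context=context, feature_lists=feature_lists)
+```
+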
+We have provided a script to download and preprocess the [MSCOCO](http://mscoco.org/) image captioning data set into this format. Downloading
+and preprocessing the data may take several hours depending on your network and
+computer speed. Please be patient.
+
+Before running the script, ensure that your hard disk has at least 150GB of
+available space for storing the downloaded and processed data.
+
+```shell
+# Location to save the MSCOCO data.
+MSCOCO_DIR="${HOME}/im2txt/data/mscoco"
+
+# Build the preprocessing script.
+cd research/im2txt
+bazel build //im2txt:download_and_preprocess_mscoco
+
+# Run the preprocessing script.
+bazel-bin/im2txt/download_and_preprocess_mscoco "${MSCOCO_DIR}"
+```
+
+The final line of the output should read:
+
+```
+2016-09-01 16:47:47.296630: Finished processing all 20267 image-caption pairs in data set 'test'.
+```
+
+When the script finishes you will find 256 training, 4 validation and 8 testing
+files in `${MSCOCO_DIR}`. The files will match the patterns `train-?????-of-00256`,
+`val-?????-of-00004` and `test-?????-of-00008`, respectively.
+
+### Download the Inception v3 Checkpoint
+
+The *Show and Tell* model requires a pretrained *Inception v3* checkpoint file
+to initialize the parameters of its image encoder submodel.
+
+This checkpoint file is provided by the
+[TensorFlow-Slim image classification library](https://github.com/tensorflow/models/tree/master/research/slim#tensorflow-slim-image-classification-library)
+which provides a suite of pre-trained image classification models. You can read
+more about the models provided by the library
+[here](https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models).
+
+
+Run the following commands to download the *Inception v3* checkpoint.
+
+```shell
+# Location to save the Inception v3 checkpoint.
+INCEPTION_DIR="${HOME}/im2txt/data"
+mkdir -p ${INCEPTION_DIR}
+
+wget "http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz"
+tar -xvf "inception_v3_2016_08_28.tar.gz" -C ${INCEPTION_DIR}
+rm "inception_v3_2016_08_28.tar.gz"
+```
+
+Note that the *Inception v3* checkpoint will only be used for initializing the
+parameters of the *Show and Tell* model. Once the *Show and Tell* model starts
+training it will save its own checkpoint files containing the values of all its
+parameters (including copies of the *Inception v3* parameters). If training is
+stopped and restarted, the parameter values will be restored from the latest
+*Show and Tell* checkpoint and the *Inception v3* checkpoint will be ignored. In
+other words, the *Inception v3* checkpoint is only used in the 0-th global step
+(initialization) of training the *Show and Tell* model.
+
+## Training a Model
+
+### Initial Training
+
+Run the training script.
+
+```shell
+# Directory containing preprocessed MSCOCO data.
+MSCOCO_DIR="${HOME}/im2txt/data/mscoco"
+
+# Inception v3 checkpoint file.
+INCEPTION_CHECKPOINT="${HOME}/im2txt/data/inception_v3.ckpt"
+
+# Directory to save the model.
+MODEL_DIR="${HOME}/im2txt/model"
+
+# Build the model.
+cd research/im2txt
+bazel build -c opt //im2txt/...
+
+# Run the training script.
+bazel-bin/im2txt/train \
+ --input_file_pattern="${MSCOCO_DIR}/train-?????-of-00256" \
+ --inception_checkpoint_file="${INCEPTION_CHECKPOINT}" \
+ --train_dir="${MODEL_DIR}/train" \
+ --train_inception=false \
+ --number_of_steps=1000000
+```
+
+Run the evaluation script in a separate process. This will log evaluation
+metrics to TensorBoard which allows training progress to be monitored in
+real-time.
+
+Note that you may run out of memory if you run the evaluation script on the same
+GPU as the training script. You can run the command
+`export CUDA_VISIBLE_DEVICES=""` to force the evaluation script to run on CPU.
+If evaluation runs too slowly on CPU, you can decrease the value of
+`--num_eval_examples`.
+
+```shell
+MSCOCO_DIR="${HOME}/im2txt/data/mscoco"
+MODEL_DIR="${HOME}/im2txt/model"
+
+# Ignore GPU devices (only necessary if your GPU is currently memory
+# constrained, for example, by running the training script).
+export CUDA_VISIBLE_DEVICES=""
+
+# Run the evaluation script. This will run in a loop, periodically loading the
+# latest model checkpoint file and computing evaluation metrics.
+bazel-bin/im2txt/evaluate \
+ --input_file_pattern="${MSCOCO_DIR}/val-?????-of-00004" \
+ --checkpoint_dir="${MODEL_DIR}/train" \
+ --eval_dir="${MODEL_DIR}/eval"
+```
+
+Run a TensorBoard server in a separate process for real-time monitoring of
+training progress and evaluation metrics.
+
+```shell
+MODEL_DIR="${HOME}/im2txt/model"
+
+# Run a TensorBoard server.
+tensorboard --logdir="${MODEL_DIR}"
+```
+
+### Fine Tune the Inception v3 Model
+
+Your model will already be able to generate reasonable captions after the first
+phase of training. Try it out! (See [Generating Captions](#generating-captions)).
+
+You can further improve the performance of the model by running a
+second training phase to jointly fine-tune the parameters of the *Inception v3*
+image submodel and the LSTM.
+
+```shell
+# Restart the training script with --train_inception=true.
+bazel-bin/im2txt/train \
+ --input_file_pattern="${MSCOCO_DIR}/train-?????-of-00256" \
+ --train_dir="${MODEL_DIR}/train" \
+ --train_inception=true \
+ --number_of_steps=3000000 # Additional 2M steps (assuming 1M in initial training).
+```
+
+Note that training will proceed much slower now, and the model will continue to
+improve by a small amount for a long time. We have found that it will improve
+slowly for an additional 2-2.5 million steps before it begins to overfit. This
+may take several weeks on a single GPU. If you don't care about absolutely
+optimal performance then feel free to halt training sooner by stopping the
+training script or passing a smaller value to the flag `--number_of_steps`. Your
+model will still work reasonably well.
+
+## Generating Captions
+
+Your trained *Show and Tell* model can generate captions for any JPEG image! The
+following command line will generate captions for an image from the test set.
+
+```shell
+# Path to checkpoint file or a directory containing checkpoint files. Passing
+# a directory will only work if there is also a file named 'checkpoint' which
+# lists the available checkpoints in the directory. It will not work if you
+# point to a directory with just a copy of a model checkpoint: in that case,
+# you will need to pass the checkpoint path explicitly.
+CHECKPOINT_PATH="${HOME}/im2txt/model/train"
+
+# Vocabulary file generated by the preprocessing script.
+VOCAB_FILE="${HOME}/im2txt/data/mscoco/word_counts.txt"
+
+# JPEG image file to caption.
+IMAGE_FILE="${HOME}/im2txt/data/mscoco/raw-data/val2014/COCO_val2014_000000224477.jpg"
+
+# Build the inference binary.
+cd research/im2txt
+bazel build -c opt //im2txt:run_inference
+
+# Ignore GPU devices (only necessary if your GPU is currently memory
+# constrained, for example, by running the training script).
+export CUDA_VISIBLE_DEVICES=""
+
+# Run inference to generate captions.
+bazel-bin/im2txt/run_inference \
+ --checkpoint_path=${CHECKPOINT_PATH} \
+ --vocab_file=${VOCAB_FILE} \
+ --input_files=${IMAGE_FILE}
+```
+
+Example output:
+
+```
+Captions for image COCO_val2014_000000224477.jpg:
+ 0) a man riding a wave on top of a surfboard . (p=0.040413)
+ 1) a person riding a surf board on a wave (p=0.017452)
+ 2) a man riding a wave on a surfboard in the ocean . (p=0.005743)
+```
+
+Note: you may get different results. Some variation between different models is
+expected.
+
+Here is the image:
+
+![Surfer](g3doc/COCO_val2014_000000224477.jpg)
diff --git a/models/research/im2txt/WORKSPACE b/models/research/im2txt/WORKSPACE
new file mode 100644
index 0000000000000000000000000000000000000000..22da718b06f9c61be4ffdf45e48919ed4a5f17ae
--- /dev/null
+++ b/models/research/im2txt/WORKSPACE
@@ -0,0 +1 @@
+workspace(name = "im2txt")
diff --git a/models/research/im2txt/conda-env/ubuntu-18-04-environment.yaml b/models/research/im2txt/conda-env/ubuntu-18-04-environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..332ff2a47f8f49fcdde7b769c29ff84cf5a5ff9d
--- /dev/null
+++ b/models/research/im2txt/conda-env/ubuntu-18-04-environment.yaml
@@ -0,0 +1,142 @@
+name: im2txt
+channels:
+ - defaults
+dependencies:
+ - _tflow_select=2.3.0=mkl
+ - absl-py=0.5.0=py27_0
+ - astor=0.7.1=py27_0
+ - backports=1.0=py27_1
+ - backports.functools_lru_cache=1.5=py27_1
+ - backports.shutil_get_terminal_size=1.0.0=py27_2
+ - backports.weakref=1.0.post1=py27_0
+ - backports_abc=0.5=py27_0
+ - blas=1.0=mkl
+ - bleach=3.0.2=py27_0
+ - ca-certificates=2018.03.07=0
+ - certifi=2018.10.15=py27_0
+ - configparser=3.5.0=py27_0
+ - cycler=0.10.0=py27_0
+ - dbus=1.13.2=h714fa37_1
+ - decorator=4.3.0=py27_0
+ - entrypoints=0.2.3=py27_2
+ - enum34=1.1.6=py27_1
+ - expat=2.2.6=he6710b0_0
+ - fastcache=1.0.2=py27h14c3975_2
+ - fontconfig=2.13.0=h9420a91_0
+ - freetype=2.9.1=h8a8886c_1
+ - funcsigs=1.0.2=py27_0
+ - functools32=3.2.3.2=py27_1
+ - futures=3.2.0=py27_0
+ - gast=0.2.0=py27_0
+ - glib=2.56.2=hd408876_0
+ - gmp=6.1.2=h6c8ec71_1
+ - gmpy2=2.0.8=py27h10f8cd9_2
+ - grpcio=1.12.1=py27hdbcaa40_0
+ - gst-plugins-base=1.14.0=hbbd80ab_1
+ - gstreamer=1.14.0=hb453b48_1
+ - h5py=2.8.0=py27h989c5e5_3
+ - hdf5=1.10.2=hba1933b_1
+ - icu=58.2=h9c2bf20_1
+ - intel-openmp=2019.0=118
+ - ipaddress=1.0.22=py27_0
+ - ipykernel=4.10.0=py27_0
+ - ipython=5.8.0=py27_0
+ - ipython_genutils=0.2.0=py27_0
+ - ipywidgets=7.4.2=py27_0
+ - jinja2=2.10=py27_0
+ - jpeg=9b=h024ee3a_2
+ - jsonschema=2.6.0=py27_0
+ - jupyter=1.0.0=py27_7
+ - jupyter_client=5.2.3=py27_0
+ - jupyter_console=5.2.0=py27_1
+ - jupyter_core=4.4.0=py27_0
+ - keras-applications=1.0.6=py27_0
+ - keras-preprocessing=1.0.5=py27_0
+ - kiwisolver=1.0.1=py27hf484d3e_0
+ - libedit=3.1.20170329=h6b74fdf_2
+ - libffi=3.2.1=hd88cf55_4
+ - libgcc-ng=8.2.0=hdf63c60_1
+ - libgfortran-ng=7.3.0=hdf63c60_0
+ - libpng=1.6.35=hbc83047_0
+ - libprotobuf=3.6.0=hdbcaa40_0
+ - libsodium=1.0.16=h1bed415_0
+ - libstdcxx-ng=8.2.0=hdf63c60_1
+ - libuuid=1.0.3=h1bed415_2
+ - libxcb=1.13=h1bed415_1
+ - libxml2=2.9.8=h26e45fe_1
+ - linecache2=1.0.0=py27_0
+ - markdown=3.0.1=py27_0
+ - markupsafe=1.0=py27h14c3975_1
+ - matplotlib=2.2.3=py27hb69df0a_0
+ - mistune=0.8.4=py27h7b6447c_0
+ - mkl=2019.0=118
+ - mkl_fft=1.0.6=py27h7dd41cf_0
+ - mkl_random=1.0.1=py27h4414c95_1
+ - mock=2.0.0=py27_0
+ - mpc=1.1.0=h10f8cd9_1
+ - mpfr=4.0.1=hdf1c602_3
+ - mpmath=1.0.0=py27_2
+ - nbconvert=5.3.1=py27_0
+ - nbformat=4.4.0=py27_0
+ - ncurses=6.1=hf484d3e_0
+ - nltk=3.3.0=py27_0
+ - nose=1.3.7=py27_2
+ - notebook=5.7.0=py27_0
+ - numpy=1.15.3=py27h1d66e8a_0
+ - numpy-base=1.15.3=py27h81de0dd_0
+ - openssl=1.0.2p=h14c3975_0
+ - pandas=0.23.4=py27h04863e7_0
+ - pandoc=2.2.3.2=0
+ - pandocfilters=1.4.2=py27_1
+ - pathlib2=2.3.2=py27_0
+ - pbr=4.3.0=py27_0
+ - pcre=8.42=h439df22_0
+ - pexpect=4.6.0=py27_0
+ - pickleshare=0.7.5=py27_0
+ - pip=10.0.1=py27_0
+ - prometheus_client=0.4.2=py27_0
+ - prompt_toolkit=1.0.15=py27_0
+ - protobuf=3.6.0=py27hf484d3e_0
+ - ptyprocess=0.6.0=py27_0
+ - pygments=2.2.0=py27_0
+ - pyparsing=2.2.2=py27_0
+ - pyqt=5.9.2=py27h05f1152_2
+ - python=2.7.15=h77bded6_2
+ - python-dateutil=2.7.3=py27_0
+ - pytz=2018.5=py27_0
+ - pyzmq=17.1.2=py27h14c3975_0
+ - qt=5.9.6=h8703b6f_2
+ - qtconsole=4.4.2=py27_0
+ - readline=7.0=h7b6447c_5
+ - scandir=1.9.0=py27h14c3975_0
+ - scipy=1.1.0=py27hfa4b5c9_1
+ - send2trash=1.5.0=py27_0
+ - setuptools=40.4.3=py27_0
+ - simplegeneric=0.8.1=py27_2
+ - singledispatch=3.4.0.3=py27_0
+ - sip=4.19.8=py27hf484d3e_0
+ - six=1.11.0=py27_1
+ - sqlite=3.25.2=h7b6447c_0
+ - subprocess32=3.5.3=py27h7b6447c_0
+ - sympy=1.3=py27_0
+ - tensorboard=1.11.0=py27hf484d3e_0
+ - tensorflow=1.11.0=mkl_py27h25e0b76_0
+ - tensorflow-base=1.11.0=mkl_py27h3c3e929_0
+ - termcolor=1.1.0=py27_1
+ - terminado=0.8.1=py27_1
+ - testpath=0.4.2=py27_0
+ - tk=8.6.8=hbc83047_0
+ - tornado=5.1.1=py27h7b6447c_0
+ - traceback2=1.4.0=py27_0
+ - traitlets=4.3.2=py27_0
+ - unittest2=1.1.0=py27_0
+ - wcwidth=0.1.7=py27_0
+ - webencodings=0.5.1=py27_1
+ - werkzeug=0.14.1=py27_0
+ - wheel=0.32.2=py27_0
+ - widgetsnbextension=3.4.2=py27_0
+ - xz=5.2.4=h14c3975_4
+ - zeromq=4.2.5=hf484d3e_1
+ - zlib=1.2.11=ha838bed_2
+prefix: /home/arinto_murdopo/anaconda3/envs/im2txt
+
diff --git a/models/research/im2txt/g3doc/COCO_val2014_000000224477.jpg b/models/research/im2txt/g3doc/COCO_val2014_000000224477.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8976fa84b40b04c5bf1205a49c8d236b747f8f9b
Binary files /dev/null and b/models/research/im2txt/g3doc/COCO_val2014_000000224477.jpg differ
diff --git a/models/research/im2txt/g3doc/example_captions.jpg b/models/research/im2txt/g3doc/example_captions.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b3a8f43247e5c9c39a3f93daaf1ad34837959ec5
Binary files /dev/null and b/models/research/im2txt/g3doc/example_captions.jpg differ
diff --git a/models/research/im2txt/g3doc/show_and_tell_architecture.png b/models/research/im2txt/g3doc/show_and_tell_architecture.png
new file mode 100644
index 0000000000000000000000000000000000000000..984590d54ba4aa089b5740fd69f6dc6216b9047f
Binary files /dev/null and b/models/research/im2txt/g3doc/show_and_tell_architecture.png differ
diff --git a/models/research/im2txt/im2txt/BUILD b/models/research/im2txt/im2txt/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..8c403171153c36ee43cde2788dbfcaf9c7bf4293
--- /dev/null
+++ b/models/research/im2txt/im2txt/BUILD
@@ -0,0 +1,96 @@
+package(default_visibility = [":internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package_group(
+ name = "internal",
+ packages = [
+ "//im2txt/...",
+ ],
+)
+
+py_binary(
+ name = "build_mscoco_data",
+ srcs = [
+ "data/build_mscoco_data.py",
+ ],
+)
+
+sh_binary(
+ name = "download_and_preprocess_mscoco",
+ srcs = ["data/download_and_preprocess_mscoco.sh"],
+ data = [
+ ":build_mscoco_data",
+ ],
+)
+
+py_library(
+ name = "configuration",
+ srcs = ["configuration.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_library(
+ name = "show_and_tell_model",
+ srcs = ["show_and_tell_model.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//im2txt/ops:image_embedding",
+ "//im2txt/ops:image_processing",
+ "//im2txt/ops:inputs",
+ ],
+)
+
+py_test(
+ name = "show_and_tell_model_test",
+ size = "large",
+ srcs = ["show_and_tell_model_test.py"],
+ deps = [
+ ":configuration",
+ ":show_and_tell_model",
+ ],
+)
+
+py_library(
+ name = "inference_wrapper",
+ srcs = ["inference_wrapper.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":show_and_tell_model",
+ "//im2txt/inference_utils:inference_wrapper_base",
+ ],
+)
+
+py_binary(
+ name = "train",
+ srcs = ["train.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":configuration",
+ ":show_and_tell_model",
+ ],
+)
+
+py_binary(
+ name = "evaluate",
+ srcs = ["evaluate.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":configuration",
+ ":show_and_tell_model",
+ ],
+)
+
+py_binary(
+ name = "run_inference",
+ srcs = ["run_inference.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":configuration",
+ ":inference_wrapper",
+ "//im2txt/inference_utils:caption_generator",
+ "//im2txt/inference_utils:vocabulary",
+ ],
+)
diff --git a/models/research/im2txt/im2txt/configuration.py b/models/research/im2txt/im2txt/configuration.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b664eb9f0cd963fb26929d019ec9cdb3282d0a8
--- /dev/null
+++ b/models/research/im2txt/im2txt/configuration.py
@@ -0,0 +1,104 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Image-to-text model and training configurations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+class ModelConfig(object):
+ """Wrapper class for model hyperparameters."""
+
+ def __init__(self):
+ """Sets the default model hyperparameters."""
+ # File pattern of sharded TFRecord file containing SequenceExample protos.
+ # Must be provided in training and evaluation modes.
+ self.input_file_pattern = None
+
+ # Image format ("jpeg" or "png").
+ self.image_format = "jpeg"
+
+ # Approximate number of values per input shard. Used to ensure sufficient
+ # mixing between shards in training.
+ self.values_per_input_shard = 2300
+ # Minimum number of shards to keep in the input queue.
+ self.input_queue_capacity_factor = 2
+ # Number of threads for prefetching SequenceExample protos.
+ self.num_input_reader_threads = 1
+
+ # Name of the SequenceExample context feature containing image data.
+ self.image_feature_name = "image/data"
+ # Name of the SequenceExample feature list containing integer captions.
+ self.caption_feature_name = "image/caption_ids"
+
+    # Number of unique words in the vocab (plus 1, for <UNK>).
+ # The default value is larger than the expected actual vocab size to allow
+ # for differences between tokenizer versions used in preprocessing. There is
+ # no harm in using a value greater than the actual vocab size, but using a
+ # value less than the actual vocab size will result in an error.
+ self.vocab_size = 12000
+
+ # Number of threads for image preprocessing. Should be a multiple of 2.
+ self.num_preprocess_threads = 4
+
+ # Batch size.
+ self.batch_size = 32
+
+ # File containing an Inception v3 checkpoint to initialize the variables
+ # of the Inception model. Must be provided when starting training for the
+ # first time.
+ self.inception_checkpoint_file = None
+
+ # Dimensions of Inception v3 input images.
+ self.image_height = 299
+ self.image_width = 299
+
+ # Scale used to initialize model variables.
+ self.initializer_scale = 0.08
+
+ # LSTM input and output dimensionality, respectively.
+ self.embedding_size = 512
+ self.num_lstm_units = 512
+
+ # If < 1.0, the dropout keep probability applied to LSTM variables.
+ self.lstm_dropout_keep_prob = 0.7
+
+
+class TrainingConfig(object):
+ """Wrapper class for training hyperparameters."""
+
+ def __init__(self):
+ """Sets the default training hyperparameters."""
+ # Number of examples per epoch of training data.
+ self.num_examples_per_epoch = 586363
+
+ # Optimizer for training the model.
+ self.optimizer = "SGD"
+
+ # Learning rate for the initial phase of training.
+ self.initial_learning_rate = 2.0
+ self.learning_rate_decay_factor = 0.5
+ self.num_epochs_per_decay = 8.0
+
+ # Learning rate when fine tuning the Inception v3 parameters.
+ self.train_inception_learning_rate = 0.0005
+
+ # If not None, clip gradients to this value.
+ self.clip_gradients = 5.0
+
+ # How many model checkpoints to keep.
+ self.max_checkpoints_to_keep = 5
diff --git a/models/research/im2txt/im2txt/data/build_mscoco_data.py b/models/research/im2txt/im2txt/data/build_mscoco_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c3e9d977669bf63d8e39128336319b48c0432dd
--- /dev/null
+++ b/models/research/im2txt/im2txt/data/build_mscoco_data.py
@@ -0,0 +1,483 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converts MSCOCO data to TFRecord file format with SequenceExample protos.
+
+The MSCOCO images are expected to reside in JPEG files located in the following
+directory structure:
+
+ train_image_dir/COCO_train2014_000000000151.jpg
+ train_image_dir/COCO_train2014_000000000260.jpg
+ ...
+
+and
+
+ val_image_dir/COCO_val2014_000000000042.jpg
+ val_image_dir/COCO_val2014_000000000073.jpg
+ ...
+
+The MSCOCO annotations JSON files are expected to reside in train_captions_file
+and val_captions_file respectively.
+
+This script converts the combined MSCOCO data into sharded data files consisting
+of 256, 4 and 8 TFRecord files, respectively:
+
+ output_dir/train-00000-of-00256
+ output_dir/train-00001-of-00256
+ ...
+ output_dir/train-00255-of-00256
+
+and
+
+ output_dir/val-00000-of-00004
+ ...
+ output_dir/val-00003-of-00004
+
+and
+
+ output_dir/test-00000-of-00008
+ ...
+ output_dir/test-00007-of-00008
+
+Each TFRecord file contains ~2300 records. Each record within the TFRecord file
+is a serialized SequenceExample proto consisting of precisely one image-caption
+pair. Note that each image has multiple captions (usually 5) and therefore each
+image is replicated multiple times in the TFRecord files.
+
+The SequenceExample proto contains the following fields:
+
+ context:
+ image/image_id: integer MSCOCO image identifier
+ image/data: string containing JPEG encoded image in RGB colorspace
+
+ feature_lists:
+ image/caption: list of strings containing the (tokenized) caption words
+ image/caption_ids: list of integer ids corresponding to the caption words
+
+The captions are tokenized using the NLTK (http://www.nltk.org/) word tokenizer.
+The vocabulary of word identifiers is constructed from the sorted list (by
+descending frequency) of word tokens in the training set. Only tokens appearing
+at least 4 times are considered; all other words get the "unknown" word id.
+
+NOTE: This script will consume around 100GB of disk space because each image
+in the MSCOCO dataset is replicated ~5 times (once per caption) in the output.
+This is done for two reasons:
+ 1. In order to better shuffle the training data.
+ 2. It makes it easier to perform asynchronous preprocessing of each image in
+ TensorFlow.
+
+Running this script using 16 threads may take around 1 hour on a HP Z420.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import Counter
+from collections import namedtuple
+from datetime import datetime
+import json
+import os.path
+import random
+import sys
+import threading
+
+
+
+import nltk.tokenize
+import numpy as np
+from six.moves import xrange
+import tensorflow as tf
+
+tf.flags.DEFINE_string("train_image_dir", "/tmp/train2014/",
+ "Training image directory.")
+tf.flags.DEFINE_string("val_image_dir", "/tmp/val2014",
+ "Validation image directory.")
+
+tf.flags.DEFINE_string("train_captions_file", "/tmp/captions_train2014.json",
+ "Training captions JSON file.")
+tf.flags.DEFINE_string("val_captions_file", "/tmp/captions_val2014.json",
+ "Validation captions JSON file.")
+
+tf.flags.DEFINE_string("output_dir", "/tmp/", "Output data directory.")
+
+tf.flags.DEFINE_integer("train_shards", 256,
+ "Number of shards in training TFRecord files.")
+tf.flags.DEFINE_integer("val_shards", 4,
+ "Number of shards in validation TFRecord files.")
+tf.flags.DEFINE_integer("test_shards", 8,
+ "Number of shards in testing TFRecord files.")
+
+tf.flags.DEFINE_string("start_word", "",
+ "Special word added to the beginning of each sentence.")
+tf.flags.DEFINE_string("end_word", "",
+ "Special word added to the end of each sentence.")
+tf.flags.DEFINE_string("unknown_word", "",
+ "Special word meaning 'unknown'.")
+tf.flags.DEFINE_integer("min_word_count", 4,
+ "The minimum number of occurrences of each word in the "
+ "training set for inclusion in the vocabulary.")
+tf.flags.DEFINE_string("word_counts_output_file", "/tmp/word_counts.txt",
+ "Output vocabulary file of word counts.")
+
+tf.flags.DEFINE_integer("num_threads", 8,
+ "Number of threads to preprocess the images.")
+
+FLAGS = tf.flags.FLAGS
+
+ImageMetadata = namedtuple("ImageMetadata",
+ ["image_id", "filename", "captions"])
+
+
+class Vocabulary(object):
+ """Simple vocabulary wrapper."""
+
+ def __init__(self, vocab, unk_id):
+ """Initializes the vocabulary.
+
+ Args:
+ vocab: A dictionary of word to word_id.
+ unk_id: Id of the special 'unknown' word.
+ """
+ self._vocab = vocab
+ self._unk_id = unk_id
+
+ def word_to_id(self, word):
+ """Returns the integer id of a word string."""
+ if word in self._vocab:
+ return self._vocab[word]
+ else:
+ return self._unk_id
+
+
+class ImageDecoder(object):
+ """Helper class for decoding images in TensorFlow."""
+
+ def __init__(self):
+ # Create a single TensorFlow Session for all image decoding calls.
+ self._sess = tf.Session()
+
+ # TensorFlow ops for JPEG decoding.
+ self._encoded_jpeg = tf.placeholder(dtype=tf.string)
+ self._decode_jpeg = tf.image.decode_jpeg(self._encoded_jpeg, channels=3)
+
+ def decode_jpeg(self, encoded_jpeg):
+ image = self._sess.run(self._decode_jpeg,
+ feed_dict={self._encoded_jpeg: encoded_jpeg})
+ assert len(image.shape) == 3
+ assert image.shape[2] == 3
+ return image
+
+
+def _int64_feature(value):
+ """Wrapper for inserting an int64 Feature into a SequenceExample proto."""
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+
+def _bytes_feature(value):
+ """Wrapper for inserting a bytes Feature into a SequenceExample proto."""
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=[str(value)]))
+
+
+def _int64_feature_list(values):
+ """Wrapper for inserting an int64 FeatureList into a SequenceExample proto."""
+ return tf.train.FeatureList(feature=[_int64_feature(v) for v in values])
+
+
+def _bytes_feature_list(values):
+ """Wrapper for inserting a bytes FeatureList into a SequenceExample proto."""
+ return tf.train.FeatureList(feature=[_bytes_feature(v) for v in values])
+
+
+def _to_sequence_example(image, decoder, vocab):
+ """Builds a SequenceExample proto for an image-caption pair.
+
+ Args:
+ image: An ImageMetadata object.
+ decoder: An ImageDecoder object.
+ vocab: A Vocabulary object.
+
+ Returns:
+ A SequenceExample proto.
+ """
+ with tf.gfile.FastGFile(image.filename, "r") as f:
+ encoded_image = f.read()
+
+ try:
+ decoder.decode_jpeg(encoded_image)
+ except (tf.errors.InvalidArgumentError, AssertionError):
+ print("Skipping file with invalid JPEG data: %s" % image.filename)
+ return
+
+ context = tf.train.Features(feature={
+ "image/image_id": _int64_feature(image.image_id),
+ "image/data": _bytes_feature(encoded_image),
+ })
+
+ assert len(image.captions) == 1
+ caption = image.captions[0]
+ caption_ids = [vocab.word_to_id(word) for word in caption]
+ feature_lists = tf.train.FeatureLists(feature_list={
+ "image/caption": _bytes_feature_list(caption),
+ "image/caption_ids": _int64_feature_list(caption_ids)
+ })
+ sequence_example = tf.train.SequenceExample(
+ context=context, feature_lists=feature_lists)
+
+ return sequence_example
+
+
+def _process_image_files(thread_index, ranges, name, images, decoder, vocab,
+ num_shards):
+ """Processes and saves a subset of images as TFRecord files in one thread.
+
+ Args:
+ thread_index: Integer thread identifier within [0, len(ranges)].
+ ranges: A list of pairs of integers specifying the ranges of the dataset to
+ process in parallel.
+ name: Unique identifier specifying the dataset.
+ images: List of ImageMetadata.
+ decoder: An ImageDecoder object.
+ vocab: A Vocabulary object.
+ num_shards: Integer number of shards for the output files.
+ """
+ # Each thread produces N shards where N = num_shards / num_threads. For
+ # instance, if num_shards = 128, and num_threads = 2, then the first thread
+ # would produce shards [0, 64).
+ num_threads = len(ranges)
+ assert not num_shards % num_threads
+ num_shards_per_batch = int(num_shards / num_threads)
+
+ shard_ranges = np.linspace(ranges[thread_index][0], ranges[thread_index][1],
+ num_shards_per_batch + 1).astype(int)
+ num_images_in_thread = ranges[thread_index][1] - ranges[thread_index][0]
+
+ counter = 0
+ for s in xrange(num_shards_per_batch):
+ # Generate a sharded version of the file name, e.g. 'train-00002-of-00010'
+ shard = thread_index * num_shards_per_batch + s
+ output_filename = "%s-%.5d-of-%.5d" % (name, shard, num_shards)
+ output_file = os.path.join(FLAGS.output_dir, output_filename)
+ writer = tf.python_io.TFRecordWriter(output_file)
+
+ shard_counter = 0
+ images_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
+ for i in images_in_shard:
+ image = images[i]
+
+ sequence_example = _to_sequence_example(image, decoder, vocab)
+ if sequence_example is not None:
+ writer.write(sequence_example.SerializeToString())
+ shard_counter += 1
+ counter += 1
+
+ if not counter % 1000:
+ print("%s [thread %d]: Processed %d of %d items in thread batch." %
+ (datetime.now(), thread_index, counter, num_images_in_thread))
+ sys.stdout.flush()
+
+ writer.close()
+ print("%s [thread %d]: Wrote %d image-caption pairs to %s" %
+ (datetime.now(), thread_index, shard_counter, output_file))
+ sys.stdout.flush()
+ shard_counter = 0
+ print("%s [thread %d]: Wrote %d image-caption pairs to %d shards." %
+ (datetime.now(), thread_index, counter, num_shards_per_batch))
+ sys.stdout.flush()
+
+
+def _process_dataset(name, images, vocab, num_shards):
+ """Processes a complete data set and saves it as a TFRecord.
+
+ Args:
+ name: Unique identifier specifying the dataset.
+ images: List of ImageMetadata.
+ vocab: A Vocabulary object.
+ num_shards: Integer number of shards for the output files.
+ """
+ # Break up each image into a separate entity for each caption.
+ images = [ImageMetadata(image.image_id, image.filename, [caption])
+ for image in images for caption in image.captions]
+
+ # Shuffle the ordering of images. Make the randomization repeatable.
+ random.seed(12345)
+ random.shuffle(images)
+
+ # Break the images into num_threads batches. Batch i is defined as
+ # images[ranges[i][0]:ranges[i][1]].
+ num_threads = min(num_shards, FLAGS.num_threads)
+ spacing = np.linspace(0, len(images), num_threads + 1).astype(np.int)
+ ranges = []
+ threads = []
+ for i in xrange(len(spacing) - 1):
+ ranges.append([spacing[i], spacing[i + 1]])
+
+ # Create a mechanism for monitoring when all threads are finished.
+ coord = tf.train.Coordinator()
+
+ # Create a utility for decoding JPEG images to run sanity checks.
+ decoder = ImageDecoder()
+
+ # Launch a thread for each batch.
+ print("Launching %d threads for spacings: %s" % (num_threads, ranges))
+ for thread_index in xrange(len(ranges)):
+ args = (thread_index, ranges, name, images, decoder, vocab, num_shards)
+ t = threading.Thread(target=_process_image_files, args=args)
+ t.start()
+ threads.append(t)
+
+ # Wait for all the threads to terminate.
+ coord.join(threads)
+ print("%s: Finished processing all %d image-caption pairs in data set '%s'." %
+ (datetime.now(), len(images), name))
+
+
+def _create_vocab(captions):
+ """Creates the vocabulary of word to word_id.
+
+ The vocabulary is saved to disk in a text file of word counts. The id of each
+ word in the file is its corresponding 0-based line number.
+
+ Args:
+ captions: A list of lists of strings.
+
+ Returns:
+ A Vocabulary object.
+ """
+ print("Creating vocabulary.")
+ counter = Counter()
+ for c in captions:
+ counter.update(c)
+ print("Total words:", len(counter))
+
+ # Filter uncommon words and sort by descending count.
+ word_counts = [x for x in counter.items() if x[1] >= FLAGS.min_word_count]
+ word_counts.sort(key=lambda x: x[1], reverse=True)
+ print("Words in vocabulary:", len(word_counts))
+
+ # Write out the word counts file.
+ with tf.gfile.FastGFile(FLAGS.word_counts_output_file, "w") as f:
+ f.write("\n".join(["%s %d" % (w, c) for w, c in word_counts]))
+ print("Wrote vocabulary file:", FLAGS.word_counts_output_file)
+
+ # Create the vocabulary dictionary.
+ reverse_vocab = [x[0] for x in word_counts]
+ unk_id = len(reverse_vocab)
+ vocab_dict = dict([(x, y) for (y, x) in enumerate(reverse_vocab)])
+ vocab = Vocabulary(vocab_dict, unk_id)
+
+ return vocab
+
+
+def _process_caption(caption):
+ """Processes a caption string into a list of tonenized words.
+
+ Args:
+ caption: A string caption.
+
+ Returns:
+ A list of strings; the tokenized caption.
+ """
+ tokenized_caption = [FLAGS.start_word]
+ tokenized_caption.extend(nltk.tokenize.word_tokenize(caption.lower()))
+ tokenized_caption.append(FLAGS.end_word)
+ return tokenized_caption
+
+
+def _load_and_process_metadata(captions_file, image_dir):
+ """Loads image metadata from a JSON file and processes the captions.
+
+ Args:
+ captions_file: JSON file containing caption annotations.
+ image_dir: Directory containing the image files.
+
+ Returns:
+ A list of ImageMetadata.
+ """
+ with tf.gfile.FastGFile(captions_file, "r") as f:
+ caption_data = json.load(f)
+
+ # Extract the filenames.
+ id_to_filename = [(x["id"], x["file_name"]) for x in caption_data["images"]]
+
+ # Extract the captions. Each image_id is associated with multiple captions.
+ id_to_captions = {}
+ for annotation in caption_data["annotations"]:
+ image_id = annotation["image_id"]
+ caption = annotation["caption"]
+ id_to_captions.setdefault(image_id, [])
+ id_to_captions[image_id].append(caption)
+
+ assert len(id_to_filename) == len(id_to_captions)
+ assert set([x[0] for x in id_to_filename]) == set(id_to_captions.keys())
+ print("Loaded caption metadata for %d images from %s" %
+ (len(id_to_filename), captions_file))
+
+ # Process the captions and combine the data into a list of ImageMetadata.
+ print("Processing captions.")
+ image_metadata = []
+ num_captions = 0
+ for image_id, base_filename in id_to_filename:
+ filename = os.path.join(image_dir, base_filename)
+ captions = [_process_caption(c) for c in id_to_captions[image_id]]
+ image_metadata.append(ImageMetadata(image_id, filename, captions))
+ num_captions += len(captions)
+ print("Finished processing %d captions for %d images in %s" %
+ (num_captions, len(id_to_filename), captions_file))
+
+ return image_metadata
+
+
+def main(unused_argv):
+ def _is_valid_num_shards(num_shards):
+ """Returns True if num_shards is compatible with FLAGS.num_threads."""
+ return num_shards < FLAGS.num_threads or not num_shards % FLAGS.num_threads
+
+ assert _is_valid_num_shards(FLAGS.train_shards), (
+ "Please make the FLAGS.num_threads commensurate with FLAGS.train_shards")
+ assert _is_valid_num_shards(FLAGS.val_shards), (
+ "Please make the FLAGS.num_threads commensurate with FLAGS.val_shards")
+ assert _is_valid_num_shards(FLAGS.test_shards), (
+ "Please make the FLAGS.num_threads commensurate with FLAGS.test_shards")
+
+ if not tf.gfile.IsDirectory(FLAGS.output_dir):
+ tf.gfile.MakeDirs(FLAGS.output_dir)
+
+ # Load image metadata from caption files.
+ mscoco_train_dataset = _load_and_process_metadata(FLAGS.train_captions_file,
+ FLAGS.train_image_dir)
+ mscoco_val_dataset = _load_and_process_metadata(FLAGS.val_captions_file,
+ FLAGS.val_image_dir)
+
+ # Redistribute the MSCOCO data as follows:
+ # train_dataset = 100% of mscoco_train_dataset + 85% of mscoco_val_dataset.
+ # val_dataset = 5% of mscoco_val_dataset (for validation during training).
+ # test_dataset = 10% of mscoco_val_dataset (for final evaluation).
+ train_cutoff = int(0.85 * len(mscoco_val_dataset))
+ val_cutoff = int(0.90 * len(mscoco_val_dataset))
+ train_dataset = mscoco_train_dataset + mscoco_val_dataset[0:train_cutoff]
+ val_dataset = mscoco_val_dataset[train_cutoff:val_cutoff]
+ test_dataset = mscoco_val_dataset[val_cutoff:]
+
+ # Create vocabulary from the training captions.
+ train_captions = [c for image in train_dataset for c in image.captions]
+ vocab = _create_vocab(train_captions)
+
+ _process_dataset("train", train_dataset, vocab, FLAGS.train_shards)
+ _process_dataset("val", val_dataset, vocab, FLAGS.val_shards)
+ _process_dataset("test", test_dataset, vocab, FLAGS.test_shards)
+
+
+if __name__ == "__main__":
+ tf.app.run()
diff --git a/models/research/im2txt/im2txt/data/download_and_preprocess_mscoco.sh b/models/research/im2txt/im2txt/data/download_and_preprocess_mscoco.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ab3ff28d576adcbf1992de4c00dfa350dd93b1c3
--- /dev/null
+++ b/models/research/im2txt/im2txt/data/download_and_preprocess_mscoco.sh
@@ -0,0 +1,90 @@
+#!/bin/bash
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Script to download and preprocess the MSCOCO data set.
+#
+# The outputs of this script are sharded TFRecord files containing serialized
+# SequenceExample protocol buffers. See build_mscoco_data.py for details of how
+# the SequenceExample protocol buffers are constructed.
+#
+# usage:
+# ./download_and_preprocess_mscoco.sh
+set -e
+
+if [ -z "$1" ]; then
+ echo "usage download_and_preproces_mscoco.sh [data dir]"
+ exit
+fi
+
+if [ "$(uname)" == "Darwin" ]; then
+ UNZIP="tar -xf"
+else
+ UNZIP="unzip -nq"
+fi
+
+# Create the output directories.
+OUTPUT_DIR="${1%/}"
+SCRATCH_DIR="${OUTPUT_DIR}/raw-data"
+mkdir -p "${OUTPUT_DIR}"
+mkdir -p "${SCRATCH_DIR}"
+CURRENT_DIR=$(pwd)
+WORK_DIR="$0.runfiles/im2txt/im2txt"
+
+# Helper function to download and unpack a .zip file.
+function download_and_unzip() {
+ local BASE_URL=${1}
+ local FILENAME=${2}
+
+ if [ ! -f ${FILENAME} ]; then
+ echo "Downloading ${FILENAME} to $(pwd)"
+ wget -nd -c "${BASE_URL}/${FILENAME}"
+ else
+ echo "Skipping download of ${FILENAME}"
+ fi
+ echo "Unzipping ${FILENAME}"
+ ${UNZIP} ${FILENAME}
+}
+
+cd ${SCRATCH_DIR}
+
+# Download the images.
+BASE_IMAGE_URL="http://msvocds.blob.core.windows.net/coco2014"
+
+TRAIN_IMAGE_FILE="train2014.zip"
+download_and_unzip ${BASE_IMAGE_URL} ${TRAIN_IMAGE_FILE}
+TRAIN_IMAGE_DIR="${SCRATCH_DIR}/train2014"
+
+VAL_IMAGE_FILE="val2014.zip"
+download_and_unzip ${BASE_IMAGE_URL} ${VAL_IMAGE_FILE}
+VAL_IMAGE_DIR="${SCRATCH_DIR}/val2014"
+
+# Download the captions.
+BASE_CAPTIONS_URL="http://msvocds.blob.core.windows.net/annotations-1-0-3"
+CAPTIONS_FILE="captions_train-val2014.zip"
+download_and_unzip ${BASE_CAPTIONS_URL} ${CAPTIONS_FILE}
+TRAIN_CAPTIONS_FILE="${SCRATCH_DIR}/annotations/captions_train2014.json"
+VAL_CAPTIONS_FILE="${SCRATCH_DIR}/annotations/captions_val2014.json"
+
+# Build TFRecords of the image data.
+cd "${CURRENT_DIR}"
+BUILD_SCRIPT="${WORK_DIR}/build_mscoco_data"
+"${BUILD_SCRIPT}" \
+ --train_image_dir="${TRAIN_IMAGE_DIR}" \
+ --val_image_dir="${VAL_IMAGE_DIR}" \
+ --train_captions_file="${TRAIN_CAPTIONS_FILE}" \
+ --val_captions_file="${VAL_CAPTIONS_FILE}" \
+ --output_dir="${OUTPUT_DIR}" \
+  --word_counts_output_file="${OUTPUT_DIR}/word_counts.txt"
diff --git a/models/research/im2txt/im2txt/evaluate.py b/models/research/im2txt/im2txt/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c81a59dab56626cb2c6a19433544f4d239cbd9d
--- /dev/null
+++ b/models/research/im2txt/im2txt/evaluate.py
@@ -0,0 +1,198 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Evaluate the model.
+
+This script should be run concurrently with training so that summaries show up
+in TensorBoard.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import os.path
+import time
+
+
+import numpy as np
+import tensorflow as tf
+
+from im2txt import configuration
+from im2txt import show_and_tell_model
+
+FLAGS = tf.flags.FLAGS
+
+tf.flags.DEFINE_string("input_file_pattern", "",
+ "File pattern of sharded TFRecord input files.")
+tf.flags.DEFINE_string("checkpoint_dir", "",
+ "Directory containing model checkpoints.")
+tf.flags.DEFINE_string("eval_dir", "", "Directory to write event logs.")
+
+tf.flags.DEFINE_integer("eval_interval_secs", 600,
+ "Interval between evaluation runs.")
+tf.flags.DEFINE_integer("num_eval_examples", 10132,
+ "Number of examples for evaluation.")
+
+tf.flags.DEFINE_integer("min_global_step", 5000,
+ "Minimum global step to run evaluation.")
+
+tf.logging.set_verbosity(tf.logging.INFO)
+
+
+def evaluate_model(sess, model, global_step, summary_writer, summary_op):
+ """Computes perplexity-per-word over the evaluation dataset.
+
+ Summaries and perplexity-per-word are written out to the eval directory.
+
+ Args:
+ sess: Session object.
+ model: Instance of ShowAndTellModel; the model to evaluate.
+ global_step: Integer; global step of the model checkpoint.
+ summary_writer: Instance of FileWriter.
+ summary_op: Op for generating model summaries.
+ """
+ # Log model summaries on a single batch.
+ summary_str = sess.run(summary_op)
+ summary_writer.add_summary(summary_str, global_step)
+
+ # Compute perplexity over the entire dataset.
+ num_eval_batches = int(
+ math.ceil(FLAGS.num_eval_examples / model.config.batch_size))
+
+ start_time = time.time()
+ sum_losses = 0.
+ sum_weights = 0.
+ for i in range(num_eval_batches):
+ cross_entropy_losses, weights = sess.run([
+ model.target_cross_entropy_losses,
+ model.target_cross_entropy_loss_weights
+ ])
+ sum_losses += np.sum(cross_entropy_losses * weights)
+ sum_weights += np.sum(weights)
+ if not i % 100:
+ tf.logging.info("Computed losses for %d of %d batches.", i + 1,
+ num_eval_batches)
+ eval_time = time.time() - start_time
+
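+  # Perplexity-per-word is the exponential of the weighted mean cross-entropy:
+  # exp(sum_i loss_i * weight_i / sum_i weight_i).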
+ perplexity = math.exp(sum_losses / sum_weights)
+ tf.logging.info("Perplexity = %f (%.2g sec)", perplexity, eval_time)
+
+ # Log perplexity to the FileWriter.
+ summary = tf.Summary()
+ value = summary.value.add()
+ value.simple_value = perplexity
+ value.tag = "Perplexity"
+ summary_writer.add_summary(summary, global_step)
+
+ # Write the Events file to the eval directory.
+ summary_writer.flush()
+ tf.logging.info("Finished processing evaluation at global step %d.",
+ global_step)
+
+
+def run_once(model, saver, summary_writer, summary_op):
+ """Evaluates the latest model checkpoint.
+
+ Args:
+ model: Instance of ShowAndTellModel; the model to evaluate.
+ saver: Instance of tf.train.Saver for restoring model Variables.
+ summary_writer: Instance of FileWriter.
+ summary_op: Op for generating model summaries.
+ """
+ model_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
+ if not model_path:
+ tf.logging.info("Skipping evaluation. No checkpoint found in: %s",
+ FLAGS.checkpoint_dir)
+ return
+
+ with tf.Session() as sess:
+ # Load model from checkpoint.
+ tf.logging.info("Loading model from checkpoint: %s", model_path)
+ saver.restore(sess, model_path)
+ global_step = tf.train.global_step(sess, model.global_step.name)
+ tf.logging.info("Successfully loaded %s at global step = %d.",
+ os.path.basename(model_path), global_step)
+ if global_step < FLAGS.min_global_step:
+ tf.logging.info("Skipping evaluation. Global step = %d < %d", global_step,
+ FLAGS.min_global_step)
+ return
+
+ # Start the queue runners.
+ coord = tf.train.Coordinator()
+ threads = tf.train.start_queue_runners(coord=coord)
+
+ # Run evaluation on the latest checkpoint.
+ try:
+ evaluate_model(
+ sess=sess,
+ model=model,
+ global_step=global_step,
+ summary_writer=summary_writer,
+ summary_op=summary_op)
+ except Exception as e: # pylint: disable=broad-except
+ tf.logging.error("Evaluation failed.")
+ coord.request_stop(e)
+
+ coord.request_stop()
+ coord.join(threads, stop_grace_period_secs=10)
+
+
+def run():
+ """Runs evaluation in a loop, and logs summaries to TensorBoard."""
+ # Create the evaluation directory if it doesn't exist.
+ eval_dir = FLAGS.eval_dir
+ if not tf.gfile.IsDirectory(eval_dir):
+ tf.logging.info("Creating eval directory: %s", eval_dir)
+ tf.gfile.MakeDirs(eval_dir)
+
+ g = tf.Graph()
+ with g.as_default():
+ # Build the model for evaluation.
+ model_config = configuration.ModelConfig()
+ model_config.input_file_pattern = FLAGS.input_file_pattern
+ model = show_and_tell_model.ShowAndTellModel(model_config, mode="eval")
+ model.build()
+
+ # Create the Saver to restore model Variables.
+ saver = tf.train.Saver()
+
+ # Create the summary operation and the summary writer.
+ summary_op = tf.summary.merge_all()
+ summary_writer = tf.summary.FileWriter(eval_dir)
+
+ g.finalize()
+
+ # Run a new evaluation run every eval_interval_secs.
+ while True:
+ start = time.time()
+ tf.logging.info("Starting evaluation at " + time.strftime(
+ "%Y-%m-%d-%H:%M:%S", time.localtime()))
+ run_once(model, saver, summary_writer, summary_op)
+ time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
+ if time_to_next_eval > 0:
+ time.sleep(time_to_next_eval)
+
+
+def main(unused_argv):
+ assert FLAGS.input_file_pattern, "--input_file_pattern is required"
+ assert FLAGS.checkpoint_dir, "--checkpoint_dir is required"
+ assert FLAGS.eval_dir, "--eval_dir is required"
+ run()
+
+
+if __name__ == "__main__":
+ tf.app.run()
diff --git a/models/research/im2txt/im2txt/inference_utils/BUILD b/models/research/im2txt/im2txt/inference_utils/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..82a15fd3ca487e542c41ab337404f8caa63b8c63
--- /dev/null
+++ b/models/research/im2txt/im2txt/inference_utils/BUILD
@@ -0,0 +1,31 @@
+package(default_visibility = ["//im2txt:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "inference_wrapper_base",
+ srcs = ["inference_wrapper_base.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_library(
+ name = "vocabulary",
+ srcs = ["vocabulary.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_library(
+ name = "caption_generator",
+ srcs = ["caption_generator.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "caption_generator_test",
+ srcs = ["caption_generator_test.py"],
+ deps = [
+ ":caption_generator",
+ ],
+)
diff --git a/models/research/im2txt/im2txt/inference_utils/caption_generator.py b/models/research/im2txt/im2txt/inference_utils/caption_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..f158d3d2330e8f839efdad4cbc4d38811b58d826
--- /dev/null
+++ b/models/research/im2txt/im2txt/inference_utils/caption_generator.py
@@ -0,0 +1,213 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Class for generating captions from an image-to-text model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import heapq
+import math
+
+
+import numpy as np
+
+
+class Caption(object):
+ """Represents a complete or partial caption."""
+
+ def __init__(self, sentence, state, logprob, score, metadata=None):
+ """Initializes the Caption.
+
+ Args:
+ sentence: List of word ids in the caption.
+ state: Model state after generating the previous word.
+ logprob: Log-probability of the caption.
+ score: Score of the caption.
+ metadata: Optional metadata associated with the partial sentence. If not
+ None, a list of strings with the same length as 'sentence'.
+ """
+ self.sentence = sentence
+ self.state = state
+ self.logprob = logprob
+ self.score = score
+ self.metadata = metadata
+
+ def __cmp__(self, other):
+ """Compares Captions by score."""
+ assert isinstance(other, Caption)
+ if self.score == other.score:
+ return 0
+ elif self.score < other.score:
+ return -1
+ else:
+ return 1
+
+ # For Python 3 compatibility (__cmp__ is deprecated).
+ def __lt__(self, other):
+ assert isinstance(other, Caption)
+ return self.score < other.score
+
+ # Also for Python 3 compatibility.
+ def __eq__(self, other):
+ assert isinstance(other, Caption)
+ return self.score == other.score
+
+
+class TopN(object):
+ """Maintains the top n elements of an incrementally provided set."""
+
+ def __init__(self, n):
+ self._n = n
+ self._data = []
+
+ def size(self):
+ assert self._data is not None
+ return len(self._data)
+
+ def push(self, x):
+ """Pushes a new element."""
+ assert self._data is not None
+ if len(self._data) < self._n:
+ heapq.heappush(self._data, x)
+ else:
+ heapq.heappushpop(self._data, x)
+
+ def extract(self, sort=False):
+ """Extracts all elements from the TopN. This is a destructive operation.
+
+ The only method that can be called immediately after extract() is reset().
+
+ Args:
+ sort: Whether to return the elements in descending sorted order.
+
+ Returns:
+ A list of data; the top n elements provided to the set.
+ """
+ assert self._data is not None
+ data = self._data
+ self._data = None
+ if sort:
+ data.sort(reverse=True)
+ return data
+
+ def reset(self):
+ """Returns the TopN to an empty state."""
+ self._data = []
+
+
+class CaptionGenerator(object):
+ """Class to generate captions from an image-to-text model."""
+
+ def __init__(self,
+ model,
+ vocab,
+ beam_size=3,
+ max_caption_length=20,
+ length_normalization_factor=0.0):
+ """Initializes the generator.
+
+ Args:
+ model: Object encapsulating a trained image-to-text model. Must have
+ methods feed_image() and inference_step(). For example, an instance of
+ InferenceWrapperBase.
+ vocab: A Vocabulary object.
+ beam_size: Beam size to use when generating captions.
+ max_caption_length: The maximum caption length before stopping the search.
+ length_normalization_factor: If != 0, a number x such that captions are
+ scored by logprob/length^x, rather than logprob. This changes the
+ relative scores of captions depending on their lengths. For example, if
+ x > 0 then longer captions will be favored.
+ """
+ self.vocab = vocab
+ self.model = model
+
+ self.beam_size = beam_size
+ self.max_caption_length = max_caption_length
+ self.length_normalization_factor = length_normalization_factor
+
+ def beam_search(self, sess, encoded_image):
+ """Runs beam search caption generation on a single image.
+
+ Args:
+ sess: TensorFlow Session object.
+ encoded_image: An encoded image string.
+
+ Returns:
+ A list of Caption sorted by descending score.
+ """
+ # Feed in the image to get the initial state.
+ initial_state = self.model.feed_image(sess, encoded_image)
+
+ initial_beam = Caption(
+ sentence=[self.vocab.start_id],
+ state=initial_state[0],
+ logprob=0.0,
+ score=0.0,
+ metadata=[""])
+ partial_captions = TopN(self.beam_size)
+ partial_captions.push(initial_beam)
+ complete_captions = TopN(self.beam_size)
+
+ # Run beam search.
+ for _ in range(self.max_caption_length - 1):
+ partial_captions_list = partial_captions.extract()
+ partial_captions.reset()
+ input_feed = np.array([c.sentence[-1] for c in partial_captions_list])
+ state_feed = np.array([c.state for c in partial_captions_list])
+
+ softmax, new_states, metadata = self.model.inference_step(sess,
+ input_feed,
+ state_feed)
+
+ for i, partial_caption in enumerate(partial_captions_list):
+ word_probabilities = softmax[i]
+ state = new_states[i]
+ # For this partial caption, get the beam_size most probable next words.
+ # Sort the indexes with numpy, select the last self.beam_size
+ # (3 by default) (ie, the most likely) and then reverse the sorted
+ # indexes with [::-1] to sort them from higher to lower.
+        most_likely_words = np.argsort(word_probabilities)[-self.beam_size:][::-1]
+
+ for w in most_likely_words:
+ p = word_probabilities[w]
+ if p < 1e-12:
+ continue # Avoid log(0).
+ sentence = partial_caption.sentence + [w]
+ logprob = partial_caption.logprob + math.log(p)
+ score = logprob
+ if metadata:
+ metadata_list = partial_caption.metadata + [metadata[i]]
+ else:
+ metadata_list = None
+ if w == self.vocab.end_id:
+ if self.length_normalization_factor > 0:
+ score /= len(sentence)**self.length_normalization_factor
+ beam = Caption(sentence, state, logprob, score, metadata_list)
+ complete_captions.push(beam)
+ else:
+ beam = Caption(sentence, state, logprob, score, metadata_list)
+ partial_captions.push(beam)
+ if partial_captions.size() == 0:
+ # We have run out of partial candidates; happens when beam_size = 1.
+ break
+
+ # If we have no complete captions then fall back to the partial captions.
+ # But never output a mixture of complete and partial captions because a
+ # partial caption could have a higher score than all the complete captions.
+ if not complete_captions.size():
+ complete_captions = partial_captions
+
+ return complete_captions.extract(sort=True)
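
For orientation, a minimal, illustrative sketch of how the Caption and TopN classes above cooperate during beam search: candidate captions are scored by log-probability, and a TopN of size beam_size keeps only the most probable beams. The numbers below are made up; only the two classes defined above are assumed.

```python
# Illustrative only: scores are log-probabilities, so TopN keeps the
# highest-probability partial captions (values below are made up).
import math

from im2txt.inference_utils.caption_generator import Caption, TopN

top = TopN(2)  # a beam of size 2
for sentence, prob in [([0, 2], 0.2), ([0, 3], 0.3), ([0, 4], 0.4)]:
    top.push(Caption(sentence=sentence, state=None,
                     logprob=math.log(prob), score=math.log(prob)))

best = top.extract(sort=True)  # destructive: call reset() before reusing
print([(c.sentence, round(math.exp(c.score), 2)) for c in best])
# -> [([0, 4], 0.4), ([0, 3], 0.3)]
```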
diff --git a/models/research/im2txt/im2txt/inference_utils/caption_generator_test.py b/models/research/im2txt/im2txt/inference_utils/caption_generator_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbd069313ac4ddb10a8463d166ab282b68b2e24d
--- /dev/null
+++ b/models/research/im2txt/im2txt/inference_utils/caption_generator_test.py
@@ -0,0 +1,178 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unit tests for CaptionGenerator."""
+
+import math
+
+import numpy as np
+import tensorflow as tf
+
+from im2txt.inference_utils import caption_generator
+
+
+class FakeVocab(object):
+ """Fake Vocabulary for testing purposes."""
+
+ def __init__(self):
+ self.start_id = 0 # Word id denoting sentence start.
+ self.end_id = 1 # Word id denoting sentence end.
+
+
+class FakeModel(object):
+ """Fake model for testing purposes."""
+
+ def __init__(self):
+ # Number of words in the vocab.
+ self._vocab_size = 12
+
+ # Dimensionality of the nominal model state.
+ self._state_size = 1
+
+ # Map of previous word to the probability distribution of the next word.
+ self._probabilities = {
+ 0: {1: 0.1,
+ 2: 0.2,
+ 3: 0.3,
+ 4: 0.4},
+ 2: {5: 0.1,
+ 6: 0.9},
+ 3: {1: 0.1,
+ 7: 0.4,
+ 8: 0.5},
+ 4: {1: 0.3,
+ 9: 0.3,
+ 10: 0.4},
+ 5: {1: 1.0},
+ 6: {1: 1.0},
+ 7: {1: 1.0},
+ 8: {1: 1.0},
+ 9: {1: 0.5,
+ 11: 0.5},
+ 10: {1: 1.0},
+ 11: {1: 1.0},
+ }
+
+ # pylint: disable=unused-argument
+
+ def feed_image(self, sess, encoded_image):
+ # Return a nominal model state.
+ return np.zeros([1, self._state_size])
+
+ def inference_step(self, sess, input_feed, state_feed):
+ # Compute the matrix of softmax distributions for the next batch of words.
+ batch_size = input_feed.shape[0]
+ softmax_output = np.zeros([batch_size, self._vocab_size])
+ for batch_index, word_id in enumerate(input_feed):
+ for next_word, probability in self._probabilities[word_id].items():
+ softmax_output[batch_index, next_word] = probability
+
+ # Nominal state and metadata.
+ new_state = np.zeros([batch_size, self._state_size])
+ metadata = None
+
+ return softmax_output, new_state, metadata
+
+ # pylint: enable=unused-argument
+
+
+class CaptionGeneratorTest(tf.test.TestCase):
+
+ def _assertExpectedCaptions(self,
+ expected_captions,
+ beam_size=3,
+ max_caption_length=20,
+ length_normalization_factor=0):
+ """Tests that beam search generates the expected captions.
+
+ Args:
+ expected_captions: A sequence of pairs (sentence, probability), where
+ sentence is a list of integer ids and probability is a float in [0, 1].
+ beam_size: Parameter passed to beam_search().
+ max_caption_length: Parameter passed to beam_search().
+ length_normalization_factor: Parameter passed to beam_search().
+ """
+ expected_sentences = [c[0] for c in expected_captions]
+ expected_probabilities = [c[1] for c in expected_captions]
+
+ # Generate captions.
+ generator = caption_generator.CaptionGenerator(
+ model=FakeModel(),
+ vocab=FakeVocab(),
+ beam_size=beam_size,
+ max_caption_length=max_caption_length,
+ length_normalization_factor=length_normalization_factor)
+ actual_captions = generator.beam_search(sess=None, encoded_image=None)
+
+ actual_sentences = [c.sentence for c in actual_captions]
+ actual_probabilities = [math.exp(c.logprob) for c in actual_captions]
+
+ self.assertEqual(expected_sentences, actual_sentences)
+ self.assertAllClose(expected_probabilities, actual_probabilities)
+
+ def testBeamSize(self):
+ # Beam size = 1.
+ expected = [([0, 4, 10, 1], 0.16)]
+ self._assertExpectedCaptions(expected, beam_size=1)
+
+ # Beam size = 2.
+ expected = [([0, 4, 10, 1], 0.16), ([0, 3, 8, 1], 0.15)]
+ self._assertExpectedCaptions(expected, beam_size=2)
+
+ # Beam size = 3.
+ expected = [
+ ([0, 2, 6, 1], 0.18), ([0, 4, 10, 1], 0.16), ([0, 3, 8, 1], 0.15)
+ ]
+ self._assertExpectedCaptions(expected, beam_size=3)
+
+ def testMaxLength(self):
+ # Max length = 1.
+ expected = [([0], 1.0)]
+ self._assertExpectedCaptions(expected, max_caption_length=1)
+
+ # Max length = 2.
+ # There are no complete sentences, so partial sentences are returned.
+ expected = [([0, 4], 0.4), ([0, 3], 0.3), ([0, 2], 0.2)]
+ self._assertExpectedCaptions(expected, max_caption_length=2)
+
+ # Max length = 3.
+ # There is at least one complete sentence, so only complete sentences are
+ # returned.
+ expected = [([0, 4, 1], 0.12), ([0, 3, 1], 0.03)]
+ self._assertExpectedCaptions(expected, max_caption_length=3)
+
+ # Max length = 4.
+ expected = [
+ ([0, 2, 6, 1], 0.18), ([0, 4, 10, 1], 0.16), ([0, 3, 8, 1], 0.15)
+ ]
+ self._assertExpectedCaptions(expected, max_caption_length=4)
+
+ def testLengthNormalization(self):
+ # Length normalization factor = 3.
+ # The longest caption is returned first, despite having low probability,
+ # because it has the highest log(probability)/length**3.
+ expected = [
+ ([0, 4, 9, 11, 1], 0.06),
+ ([0, 2, 6, 1], 0.18),
+ ([0, 4, 10, 1], 0.16),
+ ([0, 3, 8, 1], 0.15),
+ ]
+ self._assertExpectedCaptions(
+ expected, beam_size=4, length_normalization_factor=3)
+
+
+if __name__ == '__main__':
+ tf.test.main()
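
The expected values in the tests above follow directly from chaining FakeModel's transition probabilities; a quick check of the headline numbers (a sketch, mirroring the table defined in FakeModel):

```python
import math

# Probability of a caption = product of FakeModel transition probabilities.
assert math.isclose(0.2 * 0.9 * 1.0, 0.18)        # [0, 2, 6, 1]
assert math.isclose(0.4 * 0.4 * 1.0, 0.16)        # [0, 4, 10, 1]
assert math.isclose(0.3 * 0.5 * 1.0, 0.15)        # [0, 3, 8, 1]
assert math.isclose(0.4 * 0.3 * 0.5 * 1.0, 0.06)  # [0, 4, 9, 11, 1]
```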
diff --git a/models/research/im2txt/im2txt/inference_utils/inference_wrapper_base.py b/models/research/im2txt/im2txt/inference_utils/inference_wrapper_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..e94cd6af474488e4b8175fc959e1dbe33cca18c9
--- /dev/null
+++ b/models/research/im2txt/im2txt/inference_utils/inference_wrapper_base.py
@@ -0,0 +1,181 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base wrapper class for performing inference with an image-to-text model.
+
+Subclasses must implement the following methods:
+
+ build_model():
+ Builds the model for inference and returns the model object.
+
+ feed_image():
+ Takes an encoded image and returns the initial model state, where "state"
+ is a numpy array whose specifics are defined by the subclass, e.g.
+ concatenated LSTM state. It's assumed that feed_image() will be called
+ precisely once at the start of inference for each image. Subclasses may
+ compute and/or save per-image internal context in this method.
+
+ inference_step():
+ Takes a batch of inputs and states at a single time-step. Returns the
+ softmax output corresponding to the inputs, and the new states of the batch.
+ Optionally also returns metadata about the current inference step, e.g. a
+ serialized numpy array containing activations from a particular model layer.
+
+Client usage:
+ 1. Build the model inference graph via build_graph_from_config() or
+ build_graph_from_proto().
+ 2. Call the resulting restore_fn to load the model checkpoint.
+ 3. For each image in a batch of images:
+ a) Call feed_image() once to get the initial state.
+ b) For each step of caption generation, call inference_step().
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+
+
+import tensorflow as tf
+
+# pylint: disable=unused-argument
+
+
+class InferenceWrapperBase(object):
+ """Base wrapper class for performing inference with an image-to-text model."""
+
+ def __init__(self):
+ pass
+
+ def build_model(self, model_config):
+ """Builds the model for inference.
+
+ Args:
+ model_config: Object containing configuration for building the model.
+
+ Returns:
+ model: The model object.
+ """
+ tf.logging.fatal("Please implement build_model in subclass")
+
+ def _create_restore_fn(self, checkpoint_path, saver):
+ """Creates a function that restores a model from checkpoint.
+
+ Args:
+ checkpoint_path: Checkpoint file or a directory containing a checkpoint
+ file.
+ saver: Saver for restoring variables from the checkpoint file.
+
+ Returns:
+ restore_fn: A function such that restore_fn(sess) loads model variables
+ from the checkpoint file.
+
+ Raises:
+ ValueError: If checkpoint_path does not refer to a checkpoint file or a
+ directory containing a checkpoint file.
+ """
+ if tf.gfile.IsDirectory(checkpoint_path):
+ checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
+ if not checkpoint_path:
+ raise ValueError("No checkpoint file found in: %s" % checkpoint_path)
+
+ def _restore_fn(sess):
+ tf.logging.info("Loading model from checkpoint: %s", checkpoint_path)
+ saver.restore(sess, checkpoint_path)
+ tf.logging.info("Successfully loaded checkpoint: %s",
+ os.path.basename(checkpoint_path))
+
+ return _restore_fn
+
+ def build_graph_from_config(self, model_config, checkpoint_path):
+ """Builds the inference graph from a configuration object.
+
+ Args:
+ model_config: Object containing configuration for building the model.
+ checkpoint_path: Checkpoint file or a directory containing a checkpoint
+ file.
+
+ Returns:
+ restore_fn: A function such that restore_fn(sess) loads model variables
+ from the checkpoint file.
+ """
+ tf.logging.info("Building model.")
+ self.build_model(model_config)
+ saver = tf.train.Saver()
+
+ return self._create_restore_fn(checkpoint_path, saver)
+
+ def build_graph_from_proto(self, graph_def_file, saver_def_file,
+ checkpoint_path):
+ """Builds the inference graph from serialized GraphDef and SaverDef protos.
+
+ Args:
+ graph_def_file: File containing a serialized GraphDef proto.
+ saver_def_file: File containing a serialized SaverDef proto.
+ checkpoint_path: Checkpoint file or a directory containing a checkpoint
+ file.
+
+ Returns:
+ restore_fn: A function such that restore_fn(sess) loads model variables
+ from the checkpoint file.
+ """
+ # Load the Graph.
+ tf.logging.info("Loading GraphDef from file: %s", graph_def_file)
+ graph_def = tf.GraphDef()
+ with tf.gfile.FastGFile(graph_def_file, "rb") as f:
+ graph_def.ParseFromString(f.read())
+ tf.import_graph_def(graph_def, name="")
+
+ # Load the Saver.
+ tf.logging.info("Loading SaverDef from file: %s", saver_def_file)
+ saver_def = tf.train.SaverDef()
+ with tf.gfile.FastGFile(saver_def_file, "rb") as f:
+ saver_def.ParseFromString(f.read())
+ saver = tf.train.Saver(saver_def=saver_def)
+
+ return self._create_restore_fn(checkpoint_path, saver)
+
+ def feed_image(self, sess, encoded_image):
+ """Feeds an image and returns the initial model state.
+
+ See comments at the top of file.
+
+ Args:
+ sess: TensorFlow Session object.
+ encoded_image: An encoded image string.
+
+ Returns:
+ state: A numpy array of shape [1, state_size].
+ """
+ tf.logging.fatal("Please implement feed_image in subclass")
+
+ def inference_step(self, sess, input_feed, state_feed):
+ """Runs one step of inference.
+
+ Args:
+ sess: TensorFlow Session object.
+ input_feed: A numpy array of shape [batch_size].
+ state_feed: A numpy array of shape [batch_size, state_size].
+
+ Returns:
+ softmax_output: A numpy array of shape [batch_size, vocab_size].
+ new_state: A numpy array of shape [batch_size, state_size].
+ metadata: Optional. If not None, a string containing metadata about the
+ current inference step (e.g. serialized numpy array containing
+ activations from a particular model layer.).
+ """
+ tf.logging.fatal("Please implement inference_step in subclass")
+
+# pylint: enable=unused-argument
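
A minimal sketch of the client usage described in the docstring above. `MyInferenceWrapper`, `model_config`, the checkpoint path, `encoded_jpeg`, and `last_word_ids` are placeholders; the concrete subclass used in this repository is the InferenceWrapper added later in this diff.

```python
import tensorflow as tf

# Sketch only: MyInferenceWrapper stands in for any InferenceWrapperBase
# subclass; paths and feed values are placeholders.
wrapper = MyInferenceWrapper()
g = tf.Graph()
with g.as_default():
    # 1. Build the inference graph and get a checkpoint-restoring function.
    restore_fn = wrapper.build_graph_from_config(model_config, "/path/to/ckpt")
g.finalize()

with tf.Session(graph=g) as sess:
    restore_fn(sess)                                   # 2. load the checkpoint
    state = wrapper.feed_image(sess, encoded_jpeg)     # 3a. initial state
    softmax, state, _ = wrapper.inference_step(        # 3b. one decode step
        sess, input_feed=last_word_ids, state_feed=state)
```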
diff --git a/models/research/im2txt/im2txt/inference_utils/vocabulary.py b/models/research/im2txt/im2txt/inference_utils/vocabulary.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecf0ada9c2242cb32c2ea9a300d16411f5e83fab
--- /dev/null
+++ b/models/research/im2txt/im2txt/inference_utils/vocabulary.py
@@ -0,0 +1,78 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Vocabulary class for an image-to-text model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+
+class Vocabulary(object):
+ """Vocabulary class for an image-to-text model."""
+
+ def __init__(self,
+ vocab_file,
+               start_word="<S>",
+               end_word="</S>",
+               unk_word="<UNK>"):
+ """Initializes the vocabulary.
+
+ Args:
+ vocab_file: File containing the vocabulary, where the words are the first
+ whitespace-separated token on each line (other tokens are ignored) and
+ the word ids are the corresponding line numbers.
+ start_word: Special word denoting sentence start.
+ end_word: Special word denoting sentence end.
+ unk_word: Special word denoting unknown words.
+ """
+ if not tf.gfile.Exists(vocab_file):
+ tf.logging.fatal("Vocab file %s not found.", vocab_file)
+ tf.logging.info("Initializing vocabulary from file: %s", vocab_file)
+
+ with tf.gfile.GFile(vocab_file, mode="r") as f:
+ reverse_vocab = list(f.readlines())
+ reverse_vocab = [line.split()[0] for line in reverse_vocab]
+ assert start_word in reverse_vocab
+ assert end_word in reverse_vocab
+ if unk_word not in reverse_vocab:
+ reverse_vocab.append(unk_word)
+ vocab = dict([(x, y) for (y, x) in enumerate(reverse_vocab)])
+
+ tf.logging.info("Created vocabulary with %d words" % len(vocab))
+
+ self.vocab = vocab # vocab[word] = id
+ self.reverse_vocab = reverse_vocab # reverse_vocab[id] = word
+
+ # Save special word ids.
+ self.start_id = vocab[start_word]
+ self.end_id = vocab[end_word]
+ self.unk_id = vocab[unk_word]
+
+ def word_to_id(self, word):
+ """Returns the integer word id of a word string."""
+ if word in self.vocab:
+ return self.vocab[word]
+ else:
+ return self.unk_id
+
+ def id_to_word(self, word_id):
+ """Returns the word string of an integer word id."""
+ if word_id >= len(self.reverse_vocab):
+ return self.reverse_vocab[self.unk_id]
+ else:
+ return self.reverse_vocab[word_id]
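
A short sketch of the vocabulary file format this class expects (one word per line, first whitespace-separated token is the word, line number becomes the id) and of the lookup behaviour. The file path and contents below are made up.

```python
import tensorflow as tf

from im2txt.inference_utils.vocabulary import Vocabulary

# Made-up vocab file: "<word> <count>" per line; line number becomes the id.
with tf.gfile.GFile("/tmp/word_counts.txt", "w") as f:
    f.write("<S> 1000\n</S> 1000\na 512\ncat 73\n")

vocab = Vocabulary("/tmp/word_counts.txt")
assert (vocab.start_id, vocab.end_id) == (0, 1)
assert vocab.word_to_id("cat") == 3
assert vocab.word_to_id("zebra") == vocab.unk_id   # unknown words map to <UNK>
assert vocab.id_to_word(3) == "cat"
```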
diff --git a/models/research/im2txt/im2txt/inference_wrapper.py b/models/research/im2txt/im2txt/inference_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..a047a9c8d084fd9e69c937915cea8553c2d51817
--- /dev/null
+++ b/models/research/im2txt/im2txt/inference_wrapper.py
@@ -0,0 +1,51 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Model wrapper class for performing inference with a ShowAndTellModel."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+
+from im2txt import show_and_tell_model
+from im2txt.inference_utils import inference_wrapper_base
+
+
+class InferenceWrapper(inference_wrapper_base.InferenceWrapperBase):
+ """Model wrapper class for performing inference with a ShowAndTellModel."""
+
+ def __init__(self):
+ super(InferenceWrapper, self).__init__()
+
+ def build_model(self, model_config):
+ model = show_and_tell_model.ShowAndTellModel(model_config, mode="inference")
+ model.build()
+ return model
+
+ def feed_image(self, sess, encoded_image):
+ initial_state = sess.run(fetches="lstm/initial_state:0",
+ feed_dict={"image_feed:0": encoded_image})
+ return initial_state
+
+ def inference_step(self, sess, input_feed, state_feed):
+ softmax_output, state_output = sess.run(
+ fetches=["softmax:0", "lstm/state:0"],
+ feed_dict={
+ "input_feed:0": input_feed,
+ "lstm/state_feed:0": state_feed,
+ })
+ return softmax_output, state_output, None
diff --git a/models/research/im2txt/im2txt/ops/BUILD b/models/research/im2txt/im2txt/ops/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..7d48bf3938c7ecfc94ac6498386e7ce214b8be92
--- /dev/null
+++ b/models/research/im2txt/im2txt/ops/BUILD
@@ -0,0 +1,32 @@
+package(default_visibility = ["//im2txt:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "image_processing",
+ srcs = ["image_processing.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_library(
+ name = "image_embedding",
+ srcs = ["image_embedding.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "image_embedding_test",
+ size = "small",
+ srcs = ["image_embedding_test.py"],
+ deps = [
+ ":image_embedding",
+ ],
+)
+
+py_library(
+ name = "inputs",
+ srcs = ["inputs.py"],
+ srcs_version = "PY2AND3",
+)
diff --git a/models/research/im2txt/im2txt/ops/image_embedding.py b/models/research/im2txt/im2txt/ops/image_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..58e3ddaa95fa799f245fe2a46f2e948be7d9ebf2
--- /dev/null
+++ b/models/research/im2txt/im2txt/ops/image_embedding.py
@@ -0,0 +1,114 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Image embedding ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_base
+
+slim = tf.contrib.slim
+
+
+def inception_v3(images,
+ trainable=True,
+ is_training=True,
+ weight_decay=0.00004,
+ stddev=0.1,
+ dropout_keep_prob=0.8,
+ use_batch_norm=True,
+ batch_norm_params=None,
+ add_summaries=True,
+ scope="InceptionV3"):
+ """Builds an Inception V3 subgraph for image embeddings.
+
+ Args:
+ images: A float32 Tensor of shape [batch, height, width, channels].
+ trainable: Whether the inception submodel should be trainable or not.
+ is_training: Boolean indicating training mode or not.
+ weight_decay: Coefficient for weight regularization.
+    stddev: The standard deviation of the truncated normal weight initializer.
+ dropout_keep_prob: Dropout keep probability.
+ use_batch_norm: Whether to use batch normalization.
+ batch_norm_params: Parameters for batch normalization. See
+ tf.contrib.layers.batch_norm for details.
+ add_summaries: Whether to add activation summaries.
+ scope: Optional Variable scope.
+
+ Returns:
+ end_points: A dictionary of activations from inception_v3 layers.
+ """
+ # Only consider the inception model to be in training mode if it's trainable.
+ is_inception_model_training = trainable and is_training
+
+ if use_batch_norm:
+ # Default parameters for batch normalization.
+ if not batch_norm_params:
+ batch_norm_params = {
+ "is_training": is_inception_model_training,
+ "trainable": trainable,
+ # Decay for the moving averages.
+ "decay": 0.9997,
+ # Epsilon to prevent 0s in variance.
+ "epsilon": 0.001,
+ # Collection containing the moving mean and moving variance.
+ "variables_collections": {
+ "beta": None,
+ "gamma": None,
+ "moving_mean": ["moving_vars"],
+ "moving_variance": ["moving_vars"],
+ }
+ }
+ else:
+ batch_norm_params = None
+
+ if trainable:
+ weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
+ else:
+ weights_regularizer = None
+
+ with tf.variable_scope(scope, "InceptionV3", [images]) as scope:
+ with slim.arg_scope(
+ [slim.conv2d, slim.fully_connected],
+ weights_regularizer=weights_regularizer,
+ trainable=trainable):
+ with slim.arg_scope(
+ [slim.conv2d],
+ weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
+ activation_fn=tf.nn.relu,
+ normalizer_fn=slim.batch_norm,
+ normalizer_params=batch_norm_params):
+ net, end_points = inception_v3_base(images, scope=scope)
+ with tf.variable_scope("logits"):
+ shape = net.get_shape()
+ net = slim.avg_pool2d(net, shape[1:3], padding="VALID", scope="pool")
+ net = slim.dropout(
+ net,
+ keep_prob=dropout_keep_prob,
+ is_training=is_inception_model_training,
+ scope="dropout")
+ net = slim.flatten(net, scope="flatten")
+
+ # Add summaries.
+ if add_summaries:
+ for v in end_points.values():
+ tf.contrib.layers.summaries.summarize_activation(v)
+
+ return net
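
A brief sketch of calling the function above on a batch of preprocessed images (TF1-style graph code; the 2048-dimensional output is what the unit test below verifies):

```python
import tensorflow as tf

from im2txt.ops import image_embedding

# Images are expected as float32 in [-1, 1] (see image_processing.process_image).
images = tf.placeholder(tf.float32, shape=[4, 299, 299, 3])
embeddings = image_embedding.inception_v3(images, trainable=False,
                                           is_training=False)
print(embeddings.get_shape().as_list())  # [4, 2048]
```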
diff --git a/models/research/im2txt/im2txt/ops/image_embedding_test.py b/models/research/im2txt/im2txt/ops/image_embedding_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..66324d68eee0ec9c450375c25229d80283fc909f
--- /dev/null
+++ b/models/research/im2txt/im2txt/ops/image_embedding_test.py
@@ -0,0 +1,136 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for tensorflow_models.im2txt.ops.image_embedding."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+from im2txt.ops import image_embedding
+
+
+class InceptionV3Test(tf.test.TestCase):
+
+ def setUp(self):
+ super(InceptionV3Test, self).setUp()
+
+ batch_size = 4
+ height = 299
+ width = 299
+ num_channels = 3
+ self._images = tf.placeholder(tf.float32,
+ [batch_size, height, width, num_channels])
+ self._batch_size = batch_size
+
+ def _countInceptionParameters(self):
+ """Counts the number of parameters in the inception model at top scope."""
+ counter = {}
+ for v in tf.global_variables():
+ name_tokens = v.op.name.split("/")
+ if name_tokens[0] == "InceptionV3":
+ name = "InceptionV3/" + name_tokens[1]
+ num_params = v.get_shape().num_elements()
+ assert num_params
+ counter[name] = counter.get(name, 0) + num_params
+ return counter
+
+ def _verifyParameterCounts(self):
+ """Verifies the number of parameters in the inception model."""
+ param_counts = self._countInceptionParameters()
+ expected_param_counts = {
+ "InceptionV3/Conv2d_1a_3x3": 960,
+ "InceptionV3/Conv2d_2a_3x3": 9312,
+ "InceptionV3/Conv2d_2b_3x3": 18624,
+ "InceptionV3/Conv2d_3b_1x1": 5360,
+ "InceptionV3/Conv2d_4a_3x3": 138816,
+ "InceptionV3/Mixed_5b": 256368,
+ "InceptionV3/Mixed_5c": 277968,
+ "InceptionV3/Mixed_5d": 285648,
+ "InceptionV3/Mixed_6a": 1153920,
+ "InceptionV3/Mixed_6b": 1298944,
+ "InceptionV3/Mixed_6c": 1692736,
+ "InceptionV3/Mixed_6d": 1692736,
+ "InceptionV3/Mixed_6e": 2143872,
+ "InceptionV3/Mixed_7a": 1699584,
+ "InceptionV3/Mixed_7b": 5047872,
+ "InceptionV3/Mixed_7c": 6080064,
+ }
+ self.assertDictEqual(expected_param_counts, param_counts)
+
+ def _assertCollectionSize(self, expected_size, collection):
+ actual_size = len(tf.get_collection(collection))
+ if expected_size != actual_size:
+ self.fail("Found %d items in collection %s (expected %d)." %
+ (actual_size, collection, expected_size))
+
+ def testTrainableTrueIsTrainingTrue(self):
+ embeddings = image_embedding.inception_v3(
+ self._images, trainable=True, is_training=True)
+ self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())
+
+ self._verifyParameterCounts()
+ self._assertCollectionSize(376, tf.GraphKeys.GLOBAL_VARIABLES)
+ self._assertCollectionSize(188, tf.GraphKeys.TRAINABLE_VARIABLES)
+ self._assertCollectionSize(188, tf.GraphKeys.UPDATE_OPS)
+ self._assertCollectionSize(94, tf.GraphKeys.REGULARIZATION_LOSSES)
+ self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
+ self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)
+
+ def testTrainableTrueIsTrainingFalse(self):
+ embeddings = image_embedding.inception_v3(
+ self._images, trainable=True, is_training=False)
+ self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())
+
+ self._verifyParameterCounts()
+ self._assertCollectionSize(376, tf.GraphKeys.GLOBAL_VARIABLES)
+ self._assertCollectionSize(188, tf.GraphKeys.TRAINABLE_VARIABLES)
+ self._assertCollectionSize(0, tf.GraphKeys.UPDATE_OPS)
+ self._assertCollectionSize(94, tf.GraphKeys.REGULARIZATION_LOSSES)
+ self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
+ self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)
+
+ def testTrainableFalseIsTrainingTrue(self):
+ embeddings = image_embedding.inception_v3(
+ self._images, trainable=False, is_training=True)
+ self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())
+
+ self._verifyParameterCounts()
+ self._assertCollectionSize(376, tf.GraphKeys.GLOBAL_VARIABLES)
+ self._assertCollectionSize(0, tf.GraphKeys.TRAINABLE_VARIABLES)
+ self._assertCollectionSize(0, tf.GraphKeys.UPDATE_OPS)
+ self._assertCollectionSize(0, tf.GraphKeys.REGULARIZATION_LOSSES)
+ self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
+ self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)
+
+ def testTrainableFalseIsTrainingFalse(self):
+ embeddings = image_embedding.inception_v3(
+ self._images, trainable=False, is_training=False)
+ self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())
+
+ self._verifyParameterCounts()
+ self._assertCollectionSize(376, tf.GraphKeys.GLOBAL_VARIABLES)
+ self._assertCollectionSize(0, tf.GraphKeys.TRAINABLE_VARIABLES)
+ self._assertCollectionSize(0, tf.GraphKeys.UPDATE_OPS)
+ self._assertCollectionSize(0, tf.GraphKeys.REGULARIZATION_LOSSES)
+ self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
+ self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/research/im2txt/im2txt/ops/image_processing.py b/models/research/im2txt/im2txt/ops/image_processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a7545547d5507febaabebf642ee81b6f94319f6
--- /dev/null
+++ b/models/research/im2txt/im2txt/ops/image_processing.py
@@ -0,0 +1,133 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Helper functions for image preprocessing."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+
+def distort_image(image, thread_id):
+ """Perform random distortions on an image.
+
+ Args:
+ image: A float32 Tensor of shape [height, width, 3] with values in [0, 1).
+ thread_id: Preprocessing thread id used to select the ordering of color
+ distortions. There should be a multiple of 2 preprocessing threads.
+
+ Returns:
+ distorted_image: A float32 Tensor of shape [height, width, 3] with values in
+ [0, 1].
+ """
+ # Randomly flip horizontally.
+ with tf.name_scope("flip_horizontal", values=[image]):
+ image = tf.image.random_flip_left_right(image)
+
+ # Randomly distort the colors based on thread id.
+ color_ordering = thread_id % 2
+ with tf.name_scope("distort_color", values=[image]):
+ if color_ordering == 0:
+ image = tf.image.random_brightness(image, max_delta=32. / 255.)
+ image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
+ image = tf.image.random_hue(image, max_delta=0.032)
+ image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
+ elif color_ordering == 1:
+ image = tf.image.random_brightness(image, max_delta=32. / 255.)
+ image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
+ image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
+ image = tf.image.random_hue(image, max_delta=0.032)
+
+ # The random_* ops do not necessarily clamp.
+ image = tf.clip_by_value(image, 0.0, 1.0)
+
+ return image
+
+
+def process_image(encoded_image,
+ is_training,
+ height,
+ width,
+ resize_height=346,
+ resize_width=346,
+ thread_id=0,
+ image_format="jpeg"):
+ """Decode an image, resize and apply random distortions.
+
+ In training, images are distorted slightly differently depending on thread_id.
+
+ Args:
+ encoded_image: String Tensor containing the image.
+ is_training: Boolean; whether preprocessing for training or eval.
+ height: Height of the output image.
+ width: Width of the output image.
+ resize_height: If > 0, resize height before crop to final dimensions.
+ resize_width: If > 0, resize width before crop to final dimensions.
+ thread_id: Preprocessing thread id used to select the ordering of color
+ distortions. There should be a multiple of 2 preprocessing threads.
+ image_format: "jpeg" or "png".
+
+ Returns:
+ A float32 Tensor of shape [height, width, 3] with values in [-1, 1].
+
+ Raises:
+ ValueError: If image_format is invalid.
+ """
+ # Helper function to log an image summary to the visualizer. Summaries are
+ # only logged in thread 0.
+ def image_summary(name, image):
+ if not thread_id:
+ tf.summary.image(name, tf.expand_dims(image, 0))
+
+ # Decode image into a float32 Tensor of shape [?, ?, 3] with values in [0, 1).
+ with tf.name_scope("decode", values=[encoded_image]):
+ if image_format == "jpeg":
+ image = tf.image.decode_jpeg(encoded_image, channels=3)
+ elif image_format == "png":
+ image = tf.image.decode_png(encoded_image, channels=3)
+ else:
+ raise ValueError("Invalid image format: %s" % image_format)
+ image = tf.image.convert_image_dtype(image, dtype=tf.float32)
+ image_summary("original_image", image)
+
+ # Resize image.
+ assert (resize_height > 0) == (resize_width > 0)
+ if resize_height:
+ image = tf.image.resize_images(image,
+ size=[resize_height, resize_width],
+ method=tf.image.ResizeMethod.BILINEAR)
+
+ # Crop to final dimensions.
+ if is_training:
+ image = tf.random_crop(image, [height, width, 3])
+ else:
+ # Central crop, assuming resize_height > height, resize_width > width.
+ image = tf.image.resize_image_with_crop_or_pad(image, height, width)
+
+ image_summary("resized_image", image)
+
+ # Randomly distort the image.
+ if is_training:
+ image = distort_image(image, thread_id)
+
+ image_summary("final_image", image)
+
+ # Rescale to [-1,1] instead of [0, 1]
+ image = tf.subtract(image, 0.5)
+ image = tf.multiply(image, 2.0)
+ return image
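
A minimal sketch of running the preprocessing above on a single encoded JPEG (the file path is illustrative; eval-style preprocessing, so no random distortion is applied):

```python
import tensorflow as tf

from im2txt.ops import image_processing

with tf.gfile.GFile("/path/to/image.jpg", "rb") as f:   # illustrative path
    encoded = f.read()

image = image_processing.process_image(encoded, is_training=False,
                                        height=299, width=299)
with tf.Session() as sess:
    img = sess.run(image)
print(img.shape)  # (299, 299, 3), values in [-1.0, 1.0]
```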
diff --git a/models/research/im2txt/im2txt/ops/inputs.py b/models/research/im2txt/im2txt/ops/inputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dc90c0ce5dfd5c30fe0e0e543999bb15cc13a8c
--- /dev/null
+++ b/models/research/im2txt/im2txt/ops/inputs.py
@@ -0,0 +1,204 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Input ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+
+def parse_sequence_example(serialized, image_feature, caption_feature):
+ """Parses a tensorflow.SequenceExample into an image and caption.
+
+ Args:
+ serialized: A scalar string Tensor; a single serialized SequenceExample.
+ image_feature: Name of SequenceExample context feature containing image
+ data.
+ caption_feature: Name of SequenceExample feature list containing integer
+ captions.
+
+ Returns:
+ encoded_image: A scalar string Tensor containing a JPEG encoded image.
+ caption: A 1-D uint64 Tensor with dynamically specified length.
+ """
+ context, sequence = tf.parse_single_sequence_example(
+ serialized,
+ context_features={
+ image_feature: tf.FixedLenFeature([], dtype=tf.string)
+ },
+ sequence_features={
+ caption_feature: tf.FixedLenSequenceFeature([], dtype=tf.int64),
+ })
+
+ encoded_image = context[image_feature]
+ caption = sequence[caption_feature]
+ return encoded_image, caption
+
+
+def prefetch_input_data(reader,
+ file_pattern,
+ is_training,
+ batch_size,
+ values_per_shard,
+ input_queue_capacity_factor=16,
+ num_reader_threads=1,
+ shard_queue_name="filename_queue",
+ value_queue_name="input_queue"):
+ """Prefetches string values from disk into an input queue.
+
+ In training the capacity of the queue is important because a larger queue
+ means better mixing of training examples between shards. The minimum number of
+ values kept in the queue is values_per_shard * input_queue_capacity_factor,
+  where input_queue_capacity_factor should be chosen to trade off better mixing
+ with memory usage.
+
+ Args:
+ reader: Instance of tf.ReaderBase.
+ file_pattern: Comma-separated list of file patterns (e.g.
+ /tmp/train_data-?????-of-00100).
+ is_training: Boolean; whether prefetching for training or eval.
+ batch_size: Model batch size used to determine queue capacity.
+ values_per_shard: Approximate number of values per shard.
+ input_queue_capacity_factor: Minimum number of values to keep in the queue
+ in multiples of values_per_shard. See comments above.
+ num_reader_threads: Number of reader threads to fill the queue.
+ shard_queue_name: Name for the shards filename queue.
+ value_queue_name: Name for the values input queue.
+
+ Returns:
+ A Queue containing prefetched string values.
+ """
+ data_files = []
+ for pattern in file_pattern.split(","):
+ data_files.extend(tf.gfile.Glob(pattern))
+ if not data_files:
+ tf.logging.fatal("Found no input files matching %s", file_pattern)
+ else:
+ tf.logging.info("Prefetching values from %d files matching %s",
+ len(data_files), file_pattern)
+
+ if is_training:
+ filename_queue = tf.train.string_input_producer(
+ data_files, shuffle=True, capacity=16, name=shard_queue_name)
+ min_queue_examples = values_per_shard * input_queue_capacity_factor
+ capacity = min_queue_examples + 100 * batch_size
+ values_queue = tf.RandomShuffleQueue(
+ capacity=capacity,
+ min_after_dequeue=min_queue_examples,
+ dtypes=[tf.string],
+ name="random_" + value_queue_name)
+ else:
+ filename_queue = tf.train.string_input_producer(
+ data_files, shuffle=False, capacity=1, name=shard_queue_name)
+ capacity = values_per_shard + 3 * batch_size
+ values_queue = tf.FIFOQueue(
+ capacity=capacity, dtypes=[tf.string], name="fifo_" + value_queue_name)
+
+ enqueue_ops = []
+ for _ in range(num_reader_threads):
+ _, value = reader.read(filename_queue)
+ enqueue_ops.append(values_queue.enqueue([value]))
+ tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner(
+ values_queue, enqueue_ops))
+ tf.summary.scalar(
+ "queue/%s/fraction_of_%d_full" % (values_queue.name, capacity),
+ tf.cast(values_queue.size(), tf.float32) * (1. / capacity))
+
+ return values_queue
+
+
+def batch_with_dynamic_pad(images_and_captions,
+ batch_size,
+ queue_capacity,
+ add_summaries=True):
+ """Batches input images and captions.
+
+ This function splits the caption into an input sequence and a target sequence,
+ where the target sequence is the input sequence right-shifted by 1. Input and
+ target sequences are batched and padded up to the maximum length of sequences
+ in the batch. A mask is created to distinguish real words from padding words.
+
+ Example:
+ Actual captions in the batch ('-' denotes padded character):
+ [
+ [ 1 2 3 4 5 ],
+ [ 1 2 3 4 - ],
+ [ 1 2 3 - - ],
+ ]
+
+ input_seqs:
+ [
+ [ 1 2 3 4 ],
+ [ 1 2 3 - ],
+ [ 1 2 - - ],
+ ]
+
+ target_seqs:
+ [
+ [ 2 3 4 5 ],
+ [ 2 3 4 - ],
+ [ 2 3 - - ],
+ ]
+
+ mask:
+ [
+ [ 1 1 1 1 ],
+ [ 1 1 1 0 ],
+ [ 1 1 0 0 ],
+ ]
+
+ Args:
+ images_and_captions: A list of pairs [image, caption], where image is a
+ Tensor of shape [height, width, channels] and caption is a 1-D Tensor of
+ any length. Each pair will be processed and added to the queue in a
+ separate thread.
+ batch_size: Batch size.
+ queue_capacity: Queue capacity.
+ add_summaries: If true, add caption length summaries.
+
+ Returns:
+ images: A Tensor of shape [batch_size, height, width, channels].
+ input_seqs: An int32 Tensor of shape [batch_size, padded_length].
+ target_seqs: An int32 Tensor of shape [batch_size, padded_length].
+ mask: An int32 0/1 Tensor of shape [batch_size, padded_length].
+ """
+ enqueue_list = []
+ for image, caption in images_and_captions:
+ caption_length = tf.shape(caption)[0]
+ input_length = tf.expand_dims(tf.subtract(caption_length, 1), 0)
+
+ input_seq = tf.slice(caption, [0], input_length)
+ target_seq = tf.slice(caption, [1], input_length)
+ indicator = tf.ones(input_length, dtype=tf.int32)
+ enqueue_list.append([image, input_seq, target_seq, indicator])
+
+ images, input_seqs, target_seqs, mask = tf.train.batch_join(
+ enqueue_list,
+ batch_size=batch_size,
+ capacity=queue_capacity,
+ dynamic_pad=True,
+ name="batch_and_pad")
+
+ if add_summaries:
+ lengths = tf.add(tf.reduce_sum(mask, 1), 1)
+ tf.summary.scalar("caption_length/batch_min", tf.reduce_min(lengths))
+ tf.summary.scalar("caption_length/batch_max", tf.reduce_max(lengths))
+ tf.summary.scalar("caption_length/batch_mean", tf.reduce_mean(lengths))
+
+ return images, input_seqs, target_seqs, mask
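
To make the queue-capacity trade-off in prefetch_input_data concrete, a quick arithmetic sketch using illustrative values in the spirit of the im2txt training defaults (the real values come from the training configuration):

```python
# Illustrative numbers only; the real values come from the training config.
values_per_shard = 2300
input_queue_capacity_factor = 2
batch_size = 32

min_queue_examples = values_per_shard * input_queue_capacity_factor  # 4600
capacity = min_queue_examples + 100 * batch_size                     # 7800
print(min_queue_examples, capacity)
```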
diff --git a/models/research/im2txt/im2txt/run_inference.py b/models/research/im2txt/im2txt/run_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..9848522df162e52394ee8349dab1f5220aeb88f6
--- /dev/null
+++ b/models/research/im2txt/im2txt/run_inference.py
@@ -0,0 +1,85 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Generate captions for images using default beam search parameters."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import os
+
+
+import tensorflow as tf
+
+from im2txt import configuration
+from im2txt import inference_wrapper
+from im2txt.inference_utils import caption_generator
+from im2txt.inference_utils import vocabulary
+
+FLAGS = tf.flags.FLAGS
+
+tf.flags.DEFINE_string("checkpoint_path", "",
+ "Model checkpoint file or directory containing a "
+ "model checkpoint file.")
+tf.flags.DEFINE_string("vocab_file", "", "Text file containing the vocabulary.")
+tf.flags.DEFINE_string("input_files", "",
+ "File pattern or comma-separated list of file patterns "
+ "of image files.")
+
+tf.logging.set_verbosity(tf.logging.INFO)
+
+
+def main(_):
+ # Build the inference graph.
+ g = tf.Graph()
+ with g.as_default():
+ model = inference_wrapper.InferenceWrapper()
+ restore_fn = model.build_graph_from_config(configuration.ModelConfig(),
+ FLAGS.checkpoint_path)
+ g.finalize()
+
+ # Create the vocabulary.
+ vocab = vocabulary.Vocabulary(FLAGS.vocab_file)
+
+ filenames = []
+ for file_pattern in FLAGS.input_files.split(","):
+ filenames.extend(tf.gfile.Glob(file_pattern))
+ tf.logging.info("Running caption generation on %d files matching %s",
+ len(filenames), FLAGS.input_files)
+
+ with tf.Session(graph=g) as sess:
+ # Load the model from checkpoint.
+ restore_fn(sess)
+
+ # Prepare the caption generator. Here we are implicitly using the default
+ # beam search parameters. See caption_generator.py for a description of the
+ # available beam search parameters.
+ generator = caption_generator.CaptionGenerator(model, vocab)
+
+ for filename in filenames:
+ with tf.gfile.GFile(filename, "rb") as f:
+ image = f.read()
+ captions = generator.beam_search(sess, image)
+ print("Captions for image %s:" % os.path.basename(filename))
+ for i, caption in enumerate(captions):
+ # Ignore begin and end words.
+ sentence = [vocab.id_to_word(w) for w in caption.sentence[1:-1]]
+ sentence = " ".join(sentence)
+ print(" %d) %s (p=%f)" % (i, sentence, math.exp(caption.logprob)))
+
+
+if __name__ == "__main__":
+ tf.app.run()
diff --git a/models/research/im2txt/im2txt/show_and_tell_model.py b/models/research/im2txt/im2txt/show_and_tell_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ac29e7fdb80fbefe3594eabc972648a3fb32312
--- /dev/null
+++ b/models/research/im2txt/im2txt/show_and_tell_model.py
@@ -0,0 +1,358 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Image-to-text implementation based on http://arxiv.org/abs/1411.4555.
+
+"Show and Tell: A Neural Image Caption Generator"
+Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+from im2txt.ops import image_embedding
+from im2txt.ops import image_processing
+from im2txt.ops import inputs as input_ops
+
+
+class ShowAndTellModel(object):
+ """Image-to-text implementation based on http://arxiv.org/abs/1411.4555.
+
+ "Show and Tell: A Neural Image Caption Generator"
+ Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan
+ """
+
+ def __init__(self, config, mode, train_inception=False):
+ """Basic setup.
+
+ Args:
+ config: Object containing configuration parameters.
+ mode: "train", "eval" or "inference".
+ train_inception: Whether the inception submodel variables are trainable.
+ """
+ assert mode in ["train", "eval", "inference"]
+ self.config = config
+ self.mode = mode
+ self.train_inception = train_inception
+
+ # Reader for the input data.
+ self.reader = tf.TFRecordReader()
+
+ # To match the "Show and Tell" paper we initialize all variables with a
+ # random uniform initializer.
+ self.initializer = tf.random_uniform_initializer(
+ minval=-self.config.initializer_scale,
+ maxval=self.config.initializer_scale)
+
+ # A float32 Tensor with shape [batch_size, height, width, channels].
+ self.images = None
+
+ # An int32 Tensor with shape [batch_size, padded_length].
+ self.input_seqs = None
+
+ # An int32 Tensor with shape [batch_size, padded_length].
+ self.target_seqs = None
+
+ # An int32 0/1 Tensor with shape [batch_size, padded_length].
+ self.input_mask = None
+
+ # A float32 Tensor with shape [batch_size, embedding_size].
+ self.image_embeddings = None
+
+ # A float32 Tensor with shape [batch_size, padded_length, embedding_size].
+ self.seq_embeddings = None
+
+ # A float32 scalar Tensor; the total loss for the trainer to optimize.
+ self.total_loss = None
+
+ # A float32 Tensor with shape [batch_size * padded_length].
+ self.target_cross_entropy_losses = None
+
+ # A float32 Tensor with shape [batch_size * padded_length].
+ self.target_cross_entropy_loss_weights = None
+
+ # Collection of variables from the inception submodel.
+ self.inception_variables = []
+
+ # Function to restore the inception submodel from checkpoint.
+ self.init_fn = None
+
+ # Global step Tensor.
+ self.global_step = None
+
+ def is_training(self):
+ """Returns true if the model is built for training mode."""
+ return self.mode == "train"
+
+ def process_image(self, encoded_image, thread_id=0):
+ """Decodes and processes an image string.
+
+ Args:
+ encoded_image: A scalar string Tensor; the encoded image.
+ thread_id: Preprocessing thread id used to select the ordering of color
+ distortions.
+
+ Returns:
+ A float32 Tensor of shape [height, width, 3]; the processed image.
+ """
+ return image_processing.process_image(encoded_image,
+ is_training=self.is_training(),
+ height=self.config.image_height,
+ width=self.config.image_width,
+ thread_id=thread_id,
+ image_format=self.config.image_format)
+
+ def build_inputs(self):
+ """Input prefetching, preprocessing and batching.
+
+ Outputs:
+ self.images
+ self.input_seqs
+ self.target_seqs (training and eval only)
+ self.input_mask (training and eval only)
+ """
+ if self.mode == "inference":
+ # In inference mode, images and inputs are fed via placeholders.
+ image_feed = tf.placeholder(dtype=tf.string, shape=[], name="image_feed")
+ input_feed = tf.placeholder(dtype=tf.int64,
+ shape=[None], # batch_size
+ name="input_feed")
+
+ # Process image and insert batch dimensions.
+ images = tf.expand_dims(self.process_image(image_feed), 0)
+ input_seqs = tf.expand_dims(input_feed, 1)
+
+ # No target sequences or input mask in inference mode.
+ target_seqs = None
+ input_mask = None
+ else:
+ # Prefetch serialized SequenceExample protos.
+ input_queue = input_ops.prefetch_input_data(
+ self.reader,
+ self.config.input_file_pattern,
+ is_training=self.is_training(),
+ batch_size=self.config.batch_size,
+ values_per_shard=self.config.values_per_input_shard,
+ input_queue_capacity_factor=self.config.input_queue_capacity_factor,
+ num_reader_threads=self.config.num_input_reader_threads)
+
+ # Image processing and random distortion. Split across multiple threads
+ # with each thread applying a slightly different distortion.
+ assert self.config.num_preprocess_threads % 2 == 0
+ images_and_captions = []
+ for thread_id in range(self.config.num_preprocess_threads):
+ serialized_sequence_example = input_queue.dequeue()
+ encoded_image, caption = input_ops.parse_sequence_example(
+ serialized_sequence_example,
+ image_feature=self.config.image_feature_name,
+ caption_feature=self.config.caption_feature_name)
+ image = self.process_image(encoded_image, thread_id=thread_id)
+ images_and_captions.append([image, caption])
+
+ # Batch inputs.
+ queue_capacity = (2 * self.config.num_preprocess_threads *
+ self.config.batch_size)
+ images, input_seqs, target_seqs, input_mask = (
+ input_ops.batch_with_dynamic_pad(images_and_captions,
+ batch_size=self.config.batch_size,
+ queue_capacity=queue_capacity))
+
+ self.images = images
+ self.input_seqs = input_seqs
+ self.target_seqs = target_seqs
+ self.input_mask = input_mask
+
+ def build_image_embeddings(self):
+ """Builds the image model subgraph and generates image embeddings.
+
+ Inputs:
+ self.images
+
+ Outputs:
+ self.image_embeddings
+ """
+ inception_output = image_embedding.inception_v3(
+ self.images,
+ trainable=self.train_inception,
+ is_training=self.is_training())
+ self.inception_variables = tf.get_collection(
+ tf.GraphKeys.GLOBAL_VARIABLES, scope="InceptionV3")
+
+ # Map inception output into embedding space.
+ with tf.variable_scope("image_embedding") as scope:
+ image_embeddings = tf.contrib.layers.fully_connected(
+ inputs=inception_output,
+ num_outputs=self.config.embedding_size,
+ activation_fn=None,
+ weights_initializer=self.initializer,
+ biases_initializer=None,
+ scope=scope)
+
+ # Save the embedding size in the graph.
+ tf.constant(self.config.embedding_size, name="embedding_size")
+
+ self.image_embeddings = image_embeddings
+
+ def build_seq_embeddings(self):
+ """Builds the input sequence embeddings.
+
+ Inputs:
+ self.input_seqs
+
+ Outputs:
+ self.seq_embeddings
+ """
+ with tf.variable_scope("seq_embedding"), tf.device("/cpu:0"):
+ embedding_map = tf.get_variable(
+ name="map",
+ shape=[self.config.vocab_size, self.config.embedding_size],
+ initializer=self.initializer)
+ seq_embeddings = tf.nn.embedding_lookup(embedding_map, self.input_seqs)
+
+ self.seq_embeddings = seq_embeddings
+
+ def build_model(self):
+ """Builds the model.
+
+ Inputs:
+ self.image_embeddings
+ self.seq_embeddings
+ self.target_seqs (training and eval only)
+ self.input_mask (training and eval only)
+
+ Outputs:
+ self.total_loss (training and eval only)
+ self.target_cross_entropy_losses (training and eval only)
+ self.target_cross_entropy_loss_weights (training and eval only)
+ """
+ # This LSTM cell has biases and outputs tanh(new_c) * sigmoid(o), but the
+ # modified LSTM in the "Show and Tell" paper has no biases and outputs
+ # new_c * sigmoid(o).
+ lstm_cell = tf.contrib.rnn.BasicLSTMCell(
+ num_units=self.config.num_lstm_units, state_is_tuple=True)
+ if self.mode == "train":
+ lstm_cell = tf.contrib.rnn.DropoutWrapper(
+ lstm_cell,
+ input_keep_prob=self.config.lstm_dropout_keep_prob,
+ output_keep_prob=self.config.lstm_dropout_keep_prob)
+
+ with tf.variable_scope("lstm", initializer=self.initializer) as lstm_scope:
+ # Feed the image embeddings to set the initial LSTM state.
+ zero_state = lstm_cell.zero_state(
+ batch_size=self.image_embeddings.get_shape()[0], dtype=tf.float32)
+ _, initial_state = lstm_cell(self.image_embeddings, zero_state)
+
+ # Allow the LSTM variables to be reused.
+ lstm_scope.reuse_variables()
+
+ if self.mode == "inference":
+ # In inference mode, use concatenated states for convenient feeding and
+ # fetching.
+ tf.concat(axis=1, values=initial_state, name="initial_state")
+
+ # Placeholder for feeding a batch of concatenated states.
+ state_feed = tf.placeholder(dtype=tf.float32,
+ shape=[None, sum(lstm_cell.state_size)],
+ name="state_feed")
+ state_tuple = tf.split(value=state_feed, num_or_size_splits=2, axis=1)
+
+ # Run a single LSTM step.
+ lstm_outputs, state_tuple = lstm_cell(
+ inputs=tf.squeeze(self.seq_embeddings, axis=[1]),
+ state=state_tuple)
+
+        # Concatenate the resulting state.
+ tf.concat(axis=1, values=state_tuple, name="state")
+ else:
+ # Run the batch of sequence embeddings through the LSTM.
+ sequence_length = tf.reduce_sum(self.input_mask, 1)
+ lstm_outputs, _ = tf.nn.dynamic_rnn(cell=lstm_cell,
+ inputs=self.seq_embeddings,
+ sequence_length=sequence_length,
+ initial_state=initial_state,
+ dtype=tf.float32,
+ scope=lstm_scope)
+
+ # Stack batches vertically.
+ lstm_outputs = tf.reshape(lstm_outputs, [-1, lstm_cell.output_size])
+
+ with tf.variable_scope("logits") as logits_scope:
+ logits = tf.contrib.layers.fully_connected(
+ inputs=lstm_outputs,
+ num_outputs=self.config.vocab_size,
+ activation_fn=None,
+ weights_initializer=self.initializer,
+ scope=logits_scope)
+
+ if self.mode == "inference":
+ tf.nn.softmax(logits, name="softmax")
+ else:
+ targets = tf.reshape(self.target_seqs, [-1])
+ weights = tf.to_float(tf.reshape(self.input_mask, [-1]))
+
+ # Compute losses.
+ losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets,
+ logits=logits)
+ batch_loss = tf.div(tf.reduce_sum(tf.multiply(losses, weights)),
+ tf.reduce_sum(weights),
+ name="batch_loss")
+ tf.losses.add_loss(batch_loss)
+ total_loss = tf.losses.get_total_loss()
+
+ # Add summaries.
+ tf.summary.scalar("losses/batch_loss", batch_loss)
+ tf.summary.scalar("losses/total_loss", total_loss)
+ for var in tf.trainable_variables():
+ tf.summary.histogram("parameters/" + var.op.name, var)
+
+ self.total_loss = total_loss
+ self.target_cross_entropy_losses = losses # Used in evaluation.
+ self.target_cross_entropy_loss_weights = weights # Used in evaluation.
+
+ def setup_inception_initializer(self):
+ """Sets up the function to restore inception variables from checkpoint."""
+ if self.mode != "inference":
+ # Restore inception variables only.
+ saver = tf.train.Saver(self.inception_variables)
+
+ def restore_fn(sess):
+ tf.logging.info("Restoring Inception variables from checkpoint file %s",
+ self.config.inception_checkpoint_file)
+ saver.restore(sess, self.config.inception_checkpoint_file)
+
+ self.init_fn = restore_fn
+
+ def setup_global_step(self):
+ """Sets up the global step Tensor."""
+ global_step = tf.Variable(
+ initial_value=0,
+ name="global_step",
+ trainable=False,
+ collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.GLOBAL_VARIABLES])
+
+ self.global_step = global_step
+
+ def build(self):
+ """Creates all ops for training and evaluation."""
+ self.build_inputs()
+ self.build_image_embeddings()
+ self.build_seq_embeddings()
+ self.build_model()
+ self.setup_inception_initializer()
+ self.setup_global_step()
diff --git a/models/research/im2txt/im2txt/show_and_tell_model_test.py b/models/research/im2txt/im2txt/show_and_tell_model_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bdfb6e1a3ae3c15bd1c8daf005fe2542436ca8e
--- /dev/null
+++ b/models/research/im2txt/im2txt/show_and_tell_model_test.py
@@ -0,0 +1,200 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for tensorflow_models.im2txt.show_and_tell_model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import numpy as np
+import tensorflow as tf
+
+from im2txt import configuration
+from im2txt import show_and_tell_model
+
+
+class ShowAndTellModel(show_and_tell_model.ShowAndTellModel):
+ """Subclass of ShowAndTellModel without the disk I/O."""
+
+ def build_inputs(self):
+ if self.mode == "inference":
+ # Inference mode doesn't read from disk, so defer to parent.
+ return super(ShowAndTellModel, self).build_inputs()
+ else:
+ # Replace disk I/O with random Tensors.
+ self.images = tf.random_uniform(
+ shape=[self.config.batch_size, self.config.image_height,
+ self.config.image_width, 3],
+ minval=-1,
+ maxval=1)
+ self.input_seqs = tf.random_uniform(
+ [self.config.batch_size, 15],
+ minval=0,
+ maxval=self.config.vocab_size,
+ dtype=tf.int64)
+ self.target_seqs = tf.random_uniform(
+ [self.config.batch_size, 15],
+ minval=0,
+ maxval=self.config.vocab_size,
+ dtype=tf.int64)
+ self.input_mask = tf.ones_like(self.input_seqs)
+
+
+class ShowAndTellModelTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(ShowAndTellModelTest, self).setUp()
+ self._model_config = configuration.ModelConfig()
+
+ def _countModelParameters(self):
+ """Counts the number of parameters in the model at top level scope."""
+ counter = {}
+ for v in tf.global_variables():
+ name = v.op.name.split("/")[0]
+ num_params = v.get_shape().num_elements()
+ assert num_params
+ counter[name] = counter.get(name, 0) + num_params
+ return counter
+
+ def _checkModelParameters(self):
+ """Verifies the number of parameters in the model."""
+ param_counts = self._countModelParameters()
+ expected_param_counts = {
+ "InceptionV3": 21802784,
+ # inception_output_size * embedding_size
+ "image_embedding": 1048576,
+ # vocab_size * embedding_size
+ "seq_embedding": 6144000,
+ # (embedding_size + num_lstm_units + 1) * 4 * num_lstm_units
+ "lstm": 2099200,
+ # (num_lstm_units + 1) * vocab_size
+ "logits": 6156000,
+ "global_step": 1,
+ }
+ self.assertDictEqual(expected_param_counts, param_counts)
+
+ def _checkOutputs(self, expected_shapes, feed_dict=None):
+ """Verifies that the model produces expected outputs.
+
+ Args:
+ expected_shapes: A dict mapping Tensor or Tensor name to expected output
+ shape.
+ feed_dict: Values of Tensors to feed into Session.run().
+ """
+    fetches = list(expected_shapes.keys())  # list() so it can be indexed below (Python 3)
+
+ with self.test_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ outputs = sess.run(fetches, feed_dict)
+
+ for index, output in enumerate(outputs):
+ tensor = fetches[index]
+ expected = expected_shapes[tensor]
+ actual = output.shape
+ if expected != actual:
+ self.fail("Tensor %s has shape %s (expected %s)." %
+ (tensor, actual, expected))
+
+ def testBuildForTraining(self):
+ model = ShowAndTellModel(self._model_config, mode="train")
+ model.build()
+
+ self._checkModelParameters()
+
+ expected_shapes = {
+ # [batch_size, image_height, image_width, 3]
+ model.images: (32, 299, 299, 3),
+ # [batch_size, sequence_length]
+ model.input_seqs: (32, 15),
+ # [batch_size, sequence_length]
+ model.target_seqs: (32, 15),
+ # [batch_size, sequence_length]
+ model.input_mask: (32, 15),
+ # [batch_size, embedding_size]
+ model.image_embeddings: (32, 512),
+ # [batch_size, sequence_length, embedding_size]
+ model.seq_embeddings: (32, 15, 512),
+ # Scalar
+ model.total_loss: (),
+ # [batch_size * sequence_length]
+ model.target_cross_entropy_losses: (480,),
+ # [batch_size * sequence_length]
+ model.target_cross_entropy_loss_weights: (480,),
+ }
+ self._checkOutputs(expected_shapes)
+
+ def testBuildForEval(self):
+ model = ShowAndTellModel(self._model_config, mode="eval")
+ model.build()
+
+ self._checkModelParameters()
+
+ expected_shapes = {
+ # [batch_size, image_height, image_width, 3]
+ model.images: (32, 299, 299, 3),
+ # [batch_size, sequence_length]
+ model.input_seqs: (32, 15),
+ # [batch_size, sequence_length]
+ model.target_seqs: (32, 15),
+ # [batch_size, sequence_length]
+ model.input_mask: (32, 15),
+ # [batch_size, embedding_size]
+ model.image_embeddings: (32, 512),
+ # [batch_size, sequence_length, embedding_size]
+ model.seq_embeddings: (32, 15, 512),
+ # Scalar
+ model.total_loss: (),
+ # [batch_size * sequence_length]
+ model.target_cross_entropy_losses: (480,),
+ # [batch_size * sequence_length]
+ model.target_cross_entropy_loss_weights: (480,),
+ }
+ self._checkOutputs(expected_shapes)
+
+ def testBuildForInference(self):
+ model = ShowAndTellModel(self._model_config, mode="inference")
+ model.build()
+
+ self._checkModelParameters()
+
+ # Test feeding an image to get the initial LSTM state.
+ images_feed = np.random.rand(1, 299, 299, 3)
+ feed_dict = {model.images: images_feed}
+ expected_shapes = {
+ # [batch_size, embedding_size]
+ model.image_embeddings: (1, 512),
+ # [batch_size, 2 * num_lstm_units]
+ "lstm/initial_state:0": (1, 1024),
+ }
+ self._checkOutputs(expected_shapes, feed_dict)
+
+ # Test feeding a batch of inputs and LSTM states to get softmax output and
+ # LSTM states.
+ input_feed = np.random.randint(0, 10, size=3)
+ state_feed = np.random.rand(3, 1024)
+ feed_dict = {"input_feed:0": input_feed, "lstm/state_feed:0": state_feed}
+ expected_shapes = {
+ # [batch_size, 2 * num_lstm_units]
+ "lstm/state:0": (3, 1024),
+ # [batch_size, vocab_size]
+ "softmax:0": (3, 12000),
+ }
+ self._checkOutputs(expected_shapes, feed_dict)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/models/research/im2txt/im2txt/train.py b/models/research/im2txt/im2txt/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..db602735ba11e7f540a4e985333d8a457512c977
--- /dev/null
+++ b/models/research/im2txt/im2txt/train.py
@@ -0,0 +1,114 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Train the model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+from im2txt import configuration
+from im2txt import show_and_tell_model
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.flags.DEFINE_string("input_file_pattern", "",
+ "File pattern of sharded TFRecord input files.")
+tf.flags.DEFINE_string("inception_checkpoint_file", "",
+ "Path to a pretrained inception_v3 model.")
+tf.flags.DEFINE_string("train_dir", "",
+ "Directory for saving and loading model checkpoints.")
+tf.flags.DEFINE_boolean("train_inception", False,
+ "Whether to train inception submodel variables.")
+tf.flags.DEFINE_integer("number_of_steps", 1000000, "Number of training steps.")
+tf.flags.DEFINE_integer("log_every_n_steps", 1,
+ "Frequency at which loss and global step are logged.")
+
+tf.logging.set_verbosity(tf.logging.INFO)
+
+
+def main(unused_argv):
+ assert FLAGS.input_file_pattern, "--input_file_pattern is required"
+ assert FLAGS.train_dir, "--train_dir is required"
+
+ model_config = configuration.ModelConfig()
+ model_config.input_file_pattern = FLAGS.input_file_pattern
+ model_config.inception_checkpoint_file = FLAGS.inception_checkpoint_file
+ training_config = configuration.TrainingConfig()
+
+ # Create training directory.
+ train_dir = FLAGS.train_dir
+ if not tf.gfile.IsDirectory(train_dir):
+ tf.logging.info("Creating training directory: %s", train_dir)
+ tf.gfile.MakeDirs(train_dir)
+
+ # Build the TensorFlow graph.
+ g = tf.Graph()
+ with g.as_default():
+ # Build the model.
+ model = show_and_tell_model.ShowAndTellModel(
+ model_config, mode="train", train_inception=FLAGS.train_inception)
+ model.build()
+
+ # Set up the learning rate.
+ learning_rate_decay_fn = None
+ if FLAGS.train_inception:
+ learning_rate = tf.constant(training_config.train_inception_learning_rate)
+ else:
+ learning_rate = tf.constant(training_config.initial_learning_rate)
+ if training_config.learning_rate_decay_factor > 0:
+ num_batches_per_epoch = (training_config.num_examples_per_epoch /
+ model_config.batch_size)
+ decay_steps = int(num_batches_per_epoch *
+ training_config.num_epochs_per_decay)
+
+ def _learning_rate_decay_fn(learning_rate, global_step):
+ return tf.train.exponential_decay(
+ learning_rate,
+ global_step,
+ decay_steps=decay_steps,
+ decay_rate=training_config.learning_rate_decay_factor,
+ staircase=True)
+
+ learning_rate_decay_fn = _learning_rate_decay_fn
+
+ # Set up the training ops.
+ train_op = tf.contrib.layers.optimize_loss(
+ loss=model.total_loss,
+ global_step=model.global_step,
+ learning_rate=learning_rate,
+ optimizer=training_config.optimizer,
+ clip_gradients=training_config.clip_gradients,
+ learning_rate_decay_fn=learning_rate_decay_fn)
+
+ # Set up the Saver for saving and restoring model checkpoints.
+ saver = tf.train.Saver(max_to_keep=training_config.max_checkpoints_to_keep)
+
+ # Run training.
+ tf.contrib.slim.learning.train(
+ train_op,
+ train_dir,
+ log_every_n_steps=FLAGS.log_every_n_steps,
+ graph=g,
+ global_step=model.global_step,
+ number_of_steps=FLAGS.number_of_steps,
+ init_fn=model.init_fn,
+ saver=saver)
+
+
+if __name__ == "__main__":
+ tf.app.run()
diff --git a/models/research/inception/.gitignore b/models/research/inception/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..58cbf2f4e0d5d39a0e3910d6993508546dad429f
--- /dev/null
+++ b/models/research/inception/.gitignore
@@ -0,0 +1,7 @@
+/bazel-bin
+/bazel-ci_build-cache
+/bazel-genfiles
+/bazel-out
+/bazel-inception
+/bazel-testlogs
+/bazel-tf
diff --git a/models/research/inception/README.md b/models/research/inception/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..beed66cf5cd83a6843ec39b28b5dbd88f1c0d3d0
--- /dev/null
+++ b/models/research/inception/README.md
@@ -0,0 +1,858 @@
+
+
+
+
+**NOTE: For the most part, you will find a newer version of this code at [models/research/slim](https://github.com/tensorflow/models/tree/master/research/slim).** In particular:
+
+* `inception_train.py` and `imagenet_train.py` should no longer be used. The slim editions for running on multiple GPUs are the current best examples.
+* `inception_distributed_train.py` and `imagenet_distributed_train.py` are still valid examples of distributed training.
+
+For performance benchmarking, please see https://www.tensorflow.org/performance/benchmarks.
+
+---
+
+# Inception in TensorFlow
+
+[ImageNet](http://www.image-net.org/) is a common academic data set in machine
+learning for training an image recognition system. Code in this directory
+demonstrates how to use TensorFlow to train and evaluate a type of convolutional
+neural network (CNN) on this academic data set. In particular, we demonstrate
+how to train the Inception v3 architecture as specified in:
+
+_Rethinking the Inception Architecture for Computer Vision_
+
+Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, Zbigniew
+Wojna
+
+http://arxiv.org/abs/1512.00567
+
+This network achieves 21.2% top-1 and 5.6% top-5 error for single-frame
+evaluation with a computational cost of 5 billion multiply-adds per inference,
+while using fewer than 25 million parameters. Below is a visualization of the
+model architecture.
+
+![Inception v3 architecture](g3doc/inception_v3_architecture.png)
+
+## Description of Code
+
+The code base provides three core binaries for:
+
+* Training an Inception v3 network from scratch across multiple GPUs and/or
+ multiple machines using the ImageNet 2012 Challenge training data set.
+* Evaluating an Inception v3 network using the ImageNet 2012 Challenge
+ validation data set.
+* Retraining an Inception v3 network on a novel task and back-propagating the
+ errors to fine tune the network weights.
+
+The training procedure employs synchronous stochastic gradient descent across
+multiple GPUs. The user may specify the number of GPUs they wish to harness. The
+synchronous training performs *batch-splitting* by dividing a given batch across
+multiple GPUs.
+
+The training setup is nearly identical to the section [Training a Model Using
+Multiple GPU Cards](https://www.tensorflow.org/tutorials/deep_cnn/index.html#launching_and_training_the_model_on_multiple_gpu_cards)
+where we have substituted the CIFAR-10 model architecture with Inception v3. The
+primary differences with that setup are:
+
+* Calculate and update the batch-norm statistics during training so that they
+ may be substituted in during evaluation.
+* Specify the model architecture using a (still experimental) higher level
+ language called TensorFlow-Slim.
+
+For more details about TensorFlow-Slim, please see the [Slim README](inception/slim/README.md). Please note that this higher-level language is still
+*experimental* and the API may change over time depending on usage and
+subsequent research.
+
+## Getting Started
+
+Before you run the training script for the first time, you will need to download
+and convert the ImageNet data to native TFRecord format. The TFRecord format
+consists of a set of sharded files where each entry is a serialized `tf.Example`
+proto. Each `tf.Example` proto contains the ImageNet image (JPEG encoded) as
+well as metadata such as label and bounding box information. See
+[`parse_example_proto`](inception/image_processing.py) for details.
+
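+For orientation, here is a minimal sketch of how one such serialized `tf.Example`
+record can be decoded. This is an illustration rather than the library's
+`parse_example_proto`; the field names are assumed from the comments in
+[`build_image_data.py`](inception/data/build_image_data.py) included later in this change.
+
+```python
+import tensorflow as tf  # TensorFlow 1.x, as used throughout this directory
+
+
+def parse_record(serialized_example):
+  """Decodes the image and label from one serialized tf.Example record."""
+  features = tf.parse_single_example(
+      serialized_example,
+      features={
+          'image/encoded': tf.FixedLenFeature([], tf.string),
+          'image/class/label': tf.FixedLenFeature([], tf.int64),
+      })
+  image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
+  return image, features['image/class/label']
+```
+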
+We provide a single [script](inception/data/download_and_preprocess_imagenet.sh) for
+downloading and converting ImageNet data to TFRecord format. Downloading and
+preprocessing the data may take several hours (up to half a day) depending on
+your network and computer speed. Please be patient.
+
+To begin, you will need to sign up for an account with [ImageNet](http://image-net.org) to gain access to the data. Look for the sign up page,
+create an account and request an access key to download the data.
+
+After you have `USERNAME` and `PASSWORD`, you are ready to run our script. Make
+sure that your hard disk has at least 500 GB of free space for downloading and
+storing the data. Here we select `DATA_DIR=$HOME/imagenet-data` as such a
+location but feel free to edit accordingly.
+
+When you run the below script, please enter *USERNAME* and *PASSWORD* when
+prompted. This will occur at the very beginning. Once these values are entered,
+you will not need to interact with the script again.
+
+```shell
+# location of where to place the ImageNet data
+DATA_DIR=$HOME/imagenet-data
+
+# build the preprocessing script.
+cd tensorflow-models/inception
+bazel build //inception:download_and_preprocess_imagenet
+
+# run it
+bazel-bin/inception/download_and_preprocess_imagenet "${DATA_DIR}"
+```
+
+The final line of the script's output should read:
+
+```shell
+2016-02-17 14:30:17.287989: Finished writing all 1281167 images in data set.
+```
+
+When the script finishes, you will find 1024 training files and 128 validation
+files in the `DATA_DIR`. The files will match the patterns
+`train-?????-of-01024` and `validation-?????-of-00128`, respectively.
+
+[Congratulations!](https://www.youtube.com/watch?v=9bZkp7q19f0) You are now
+ready to train or evaluate with the ImageNet data set.
+
+## How to Train from Scratch
+
+**WARNING** Training an Inception v3 network from scratch is a computationally
+intensive task and depending on your compute setup may take several days or even
+weeks.
+
+*Before proceeding* please read the [Convolutional Neural Networks](https://www.tensorflow.org/tutorials/deep_cnn/index.html) tutorial; in
+particular, focus on [Training a Model Using Multiple GPU Cards](https://www.tensorflow.org/tutorials/deep_cnn/index.html#launching_and_training_the_model_on_multiple_gpu_cards). The model training method is nearly identical to that described in the
+CIFAR-10 multi-GPU model training. Briefly, the model training:
+
+* Places an individual model replica on each GPU.
+* Splits the batch across the GPUs.
+* Updates model parameters synchronously by waiting for all GPUs to finish
+ processing a batch of data.
+
+The training procedure is encapsulated by this diagram of how operations and
+variables are placed on CPU and GPUs respectively.
+
+
+
+
+
+Each tower computes the gradients for a portion of the batch and the gradients
+are combined and averaged across the multiple towers in order to provide a
+single update of the Variables stored on the CPU.
+
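+The sketch below is a rough illustration of that pattern, not the actual
+`inception_train.py` code: a toy linear classifier stands in for the Inception v3
+tower, and plain constants replace the real command-line flags. The batch is split
+across GPUs, each tower computes gradients on its slice with shared variables, and
+the per-tower gradients are averaged before a single `apply_gradients` update.
+
+```python
+import tensorflow as tf  # TensorFlow 1.x
+
+num_gpus = 2
+batch_size = 8  # total batch; each tower sees batch_size // num_gpus examples
+
+
+def tower_loss(images, labels):
+  """Toy stand-in for an Inception v3 tower: a single linear classifier."""
+  flat = tf.reshape(images, [-1, 299 * 299 * 3])
+  w = tf.get_variable('w', [299 * 299 * 3, 10])
+  b = tf.get_variable('b', [10], initializer=tf.zeros_initializer())
+  logits = tf.matmul(flat, w) + b
+  return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
+      labels=labels, logits=logits))
+
+
+images = tf.random_uniform([batch_size, 299, 299, 3])
+labels = tf.random_uniform([batch_size], maxval=10, dtype=tf.int64)
+optimizer = tf.train.GradientDescentOptimizer(0.1)
+
+tower_grads = []
+with tf.variable_scope(tf.get_variable_scope()):
+  for i, (image_split, label_split) in enumerate(
+      zip(tf.split(images, num_gpus), tf.split(labels, num_gpus))):
+    with tf.device('/gpu:%d' % i):
+      loss = tower_loss(image_split, label_split)
+      tf.get_variable_scope().reuse_variables()  # all towers share one set of weights
+      tower_grads.append(optimizer.compute_gradients(loss))
+
+# Average each variable's gradient across towers, then apply a single update.
+averaged_grads = []
+for grads_and_vars in zip(*tower_grads):
+  grads = tf.stack([g for g, _ in grads_and_vars])
+  averaged_grads.append((tf.reduce_mean(grads, axis=0), grads_and_vars[0][1]))
+train_op = optimizer.apply_gradients(averaged_grads)
+```
+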
+A crucial aspect of training a network of this size is *training speed* in terms
+of wall-clock time. The training speed is dictated by many factors -- most
+importantly the batch size and the learning rate schedule. Both of these
+parameters are heavily coupled to the hardware setup.
+
+Generally speaking, the batch size is a difficult parameter to tune as it requires
+balancing the memory demands of the model, the memory available on the GPU and the
+speed of computation. Larger batch sizes lead to more efficient computation and
+potentially more efficient training steps.
+
+We have tested several hardware setups for training this model from scratch but
+we emphasize that depending on your hardware setup, you may need to adapt the
+batch size and learning rate schedule.
+
+Please see the comments in `inception_train.py` for a few selected learning rate
+plans based on some selected hardware setups.
+
+To train this model, you simply need to specify the following:
+
+```shell
+# Build the model. Note that we need to make sure the TensorFlow is ready to
+# use before this as this command will not build TensorFlow.
+cd tensorflow-models/inception
+bazel build //inception:imagenet_train
+
+# run it
+bazel-bin/inception/imagenet_train --num_gpus=1 --batch_size=32 --train_dir=/tmp/imagenet_train --data_dir=/tmp/imagenet_data
+```
+
+The model reads in the ImageNet training data from `--data_dir`. If you followed
+the instructions in [Getting Started](#getting-started), then set
+`--data_dir="${DATA_DIR}"`. The script assumes that there exists a set of
+sharded TFRecord files containing the ImageNet data. If you have not created
+TFRecord files, please refer to [Getting Started](#getting-started).
+
+Here is the output of the above command line when running on a Tesla K40c:
+
+```shell
+2016-03-07 12:24:59.922898: step 0, loss = 13.11 (5.3 examples/sec; 6.064 sec/batch)
+2016-03-07 12:25:55.206783: step 10, loss = 13.71 (9.4 examples/sec; 3.394 sec/batch)
+2016-03-07 12:26:28.905231: step 20, loss = 14.81 (9.5 examples/sec; 3.380 sec/batch)
+2016-03-07 12:27:02.699719: step 30, loss = 14.45 (9.5 examples/sec; 3.378 sec/batch)
+2016-03-07 12:27:36.515699: step 40, loss = 13.98 (9.5 examples/sec; 3.376 sec/batch)
+2016-03-07 12:28:10.220956: step 50, loss = 13.92 (9.6 examples/sec; 3.327 sec/batch)
+2016-03-07 12:28:43.658223: step 60, loss = 13.28 (9.6 examples/sec; 3.350 sec/batch)
+...
+```
+
+In this example, a log entry is printed every 10 steps and the line includes the
+total loss (starts around 13.0-14.0) and the speed of processing in terms of
+throughput (examples / sec) and batch speed (sec/batch).
+
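+For instance, at 3.380 sec/batch with `--batch_size=32`, the throughput is
+32 / 3.38, or roughly 9.5 examples/sec, which matches the log above.
+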
+The number of GPU devices is specified by `--num_gpus` (which defaults to 1).
+Specifying `--num_gpus` greater than 1 splits the batch evenly across the
+GPU cards.
+
+```shell
+# Build the model. Note that we need to make sure the TensorFlow is ready to
+# use before this as this command will not build TensorFlow.
+cd tensorflow-models/inception
+bazel build //inception:imagenet_train
+
+# run it
+bazel-bin/inception/imagenet_train --num_gpus=2 --batch_size=64 --train_dir=/tmp/imagenet_train
+```
+
+This model splits the batch of 64 images across 2 GPUs and calculates the
+average gradient by waiting for both GPUs to finish calculating the gradients
+from their respective data (see diagram above). Generally speaking, using more
+GPUs leads to higher throughput as well as the opportunity to use
+larger batch sizes. In turn, larger batch sizes imply better estimates of the
+gradient, enabling the use of higher learning rates. In summary, using more
+GPUs simply results in faster training.
+
+Note that the batch size remains a difficult parameter to tune, as it requires
+balancing the memory demands of the model, the memory available on the GPU and
+the speed of computation (see the discussion above).
+
+Note that there is considerable noise in the loss function on individual steps
+in the previous log. Because of this noise, it is difficult to discern how well
+a model is learning. The solution is to launch TensorBoard
+pointing to the directory containing the events log.
+
+```shell
+tensorboard --logdir=/tmp/imagenet_train
+```
+
+TensorBoard has access to the many Summaries produced by the model that describe
+a multitude of statistics tracking the model behavior and the quality of the
+learned model. In particular, TensorBoard tracks an exponentially smoothed
+version of the loss. In practice, it is far easier to judge how well a model
+learns by monitoring the smoothed version of the loss.
+
+## How to Train from Scratch in a Distributed Setting
+
+**NOTE** Distributed TensorFlow requires version 0.8 or later.
+
+Distributed TensorFlow lets us use multiple machines to train a model faster.
+This is quite different from training with multiple GPU towers on a single
+machine, where all parameter and gradient computation happens in the same place. We
+coordinate the computation across multiple machines by employing a centralized
+repository for parameters that maintains a unified, single copy of model
+parameters. Each individual machine sends gradient updates to the centralized
+parameter repository which coordinates these updates and sends back updated
+parameters to the individual machines running the model training.
+
+We term each machine that runs a copy of the training a `worker` or `replica`.
+We term each machine that maintains model parameters a `ps`, short for
+`parameter server`. Note that we might have more than one machine acting as a
+`ps` as the model parameters may be sharded across multiple machines.
+
+Variables may be updated with synchronous or asynchronous gradient updates. One
+may construct an [`Optimizer`](https://www.tensorflow.org/api_docs/python/train.html#optimizers) in TensorFlow
+that constructs the necessary graph for either case diagrammed below from the
+TensorFlow [Whitepaper](http://download.tensorflow.org/paper/whitepaper2015.pdf):
+
+
+
+
+
+In [a recent paper](https://arxiv.org/abs/1604.00981), synchronous gradient
+updates have been shown to reach higher accuracy in a shorter amount of time.
+In this distributed Inception example we employ synchronous gradient updates.
+
+Note that in this example each replica has a single tower that uses one GPU.
+
+The command-line flags `worker_hosts` and `ps_hosts` specify available servers.
+The same binary will be used for both the `worker` jobs and the `ps` jobs.
+The command-line flag `job_name` specifies what role a task plays, and
+`task_id` identifies which one of those jobs it is
+running. Several things to note here:
+
+* The numbers of `ps` and `worker` tasks are inferred from the lists of hosts
+ specified in the flags. The `task_id` should be within the range `[0,
+ num_ps_tasks)` for `ps` tasks and `[0, num_worker_tasks)` for `worker`
+ tasks.
+* `ps` and `worker` tasks can run on the same machine, as long as that machine
+ has sufficient resources to handle both tasks. Note that the `ps` task does
+ not benefit from a GPU, so it should not attempt to use one (see below).
+* Multiple `worker` tasks can run on the same machine with multiple GPUs so
+ machine_A with 2 GPUs may have 2 workers while machine_B with 1 GPU just has
+ 1 worker.
+* The default learning rate schedule works well for a wide range of replica
+  counts (25, 50, 100), but feel free to tune it for even better results.
+* The command line of both `ps` and `worker` tasks should include the complete
+ list of `ps_hosts` and `worker_hosts`.
+* There is a chief `worker` among all workers which defaults to `worker` 0.
+ The chief will be in charge of initializing all the parameters, writing out
+ the summaries and the checkpoint. The checkpoint and summary will be in the
+ `train_dir` of the host for `worker` 0.
+* Each worker processes a batch_size number of examples but each gradient
+ update is computed from all replicas. Hence, the effective batch size of
+ this model is batch_size * num_workers.
+
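+For example, with the two workers below each running `--batch_size=32`, every
+gradient update is effectively computed from 64 examples.
+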
+```shell
+# Build the model. Note that we need to make sure the TensorFlow is ready to
+# use before this as this command will not build TensorFlow.
+cd tensorflow-models/inception
+bazel build //inception:imagenet_distributed_train
+
+# To start worker 0, go to the worker0 host and run the following (Note that
+# task_id should be in the range [0, num_worker_tasks):
+bazel-bin/inception/imagenet_distributed_train \
+--batch_size=32 \
+--data_dir=$HOME/imagenet-data \
+--job_name='worker' \
+--task_id=0 \
+--ps_hosts='ps0.example.com:2222' \
+--worker_hosts='worker0.example.com:2222,worker1.example.com:2222'
+
+# To start worker 1, go to the worker1 host and run the following (Note that
+# task_id should be in the range [0, num_worker_tasks):
+bazel-bin/inception/imagenet_distributed_train \
+--batch_size=32 \
+--data_dir=$HOME/imagenet-data \
+--job_name='worker' \
+--task_id=1 \
+--ps_hosts='ps0.example.com:2222' \
+--worker_hosts='worker0.example.com:2222,worker1.example.com:2222'
+
+# To start the parameter server (ps), go to the ps host and run the following (Note
+# that task_id should be in the range [0, num_ps_tasks):
+bazel-bin/inception/imagenet_distributed_train \
+--job_name='ps' \
+--task_id=0 \
+--ps_hosts='ps0.example.com:2222' \
+--worker_hosts='worker0.example.com:2222,worker1.example.com:2222'
+```
+
+If you have installed a GPU-compatible version of TensorFlow, the `ps` will also
+try to allocate GPU memory even though it gains nothing from it. This could
+potentially crash a worker on the same machine by leaving it little to no GPU
+memory to allocate. To avoid this, prefix the command that starts the `ps` task
+with `CUDA_VISIBLE_DEVICES=''`:
+
+```shell
+CUDA_VISIBLE_DEVICES='' bazel-bin/inception/imagenet_distributed_train \
+--job_name='ps' \
+--task_id=0 \
+--ps_hosts='ps0.example.com:2222' \
+--worker_hosts='worker0.example.com:2222,worker1.example.com:2222'
+```
+
+If you have run everything correctly, you should see a log in each `worker` job
+that looks like the following. Note that the training speed varies depending on
+your hardware, and the first several steps can take much longer.
+
+```shell
+INFO:tensorflow:PS hosts are: ['ps0.example.com:2222', 'ps1.example.com:2222']
+INFO:tensorflow:Worker hosts are: ['worker0.example.com:2222', 'worker1.example.com:2222']
+I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:206] Initialize HostPortsGrpcChannelCache for job ps -> {ps0.example.com:2222, ps1.example.com:2222}
+I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:206] Initialize HostPortsGrpcChannelCache for job worker -> {localhost:2222, worker1.example.com:2222}
+I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:202] Started server with target: grpc://localhost:2222
+INFO:tensorflow:Created variable global_step:0 with shape () and init
+
+...
+
+INFO:tensorflow:Created variable logits/logits/biases:0 with shape (1001,) and init
+INFO:tensorflow:SyncReplicas enabled: replicas_to_aggregate=2; total_num_replicas=2
+INFO:tensorflow:2016-04-13 01:56:26.405639 Supervisor
+INFO:tensorflow:Started 2 queues for processing input data.
+INFO:tensorflow:global_step/sec: 0
+INFO:tensorflow:Worker 0: 2016-04-13 01:58:40.342404: step 0, loss = 12.97(0.0 examples/sec; 65.428 sec/batch)
+INFO:tensorflow:global_step/sec: 0.0172907
+...
+```
+
+and a log in each `ps` job that looks like the following:
+
+```shell
+INFO:tensorflow:PS hosts are: ['ps0.example.com:2222', 'ps1.example.com:2222']
+INFO:tensorflow:Worker hosts are: ['worker0.example.com:2222', 'worker1.example.com:2222']
+I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:206] Initialize HostPortsGrpcChannelCache for job ps -> {localhost:2222, ps1.example.com:2222}
+I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:206] Initialize HostPortsGrpcChannelCache for job worker -> {worker0.example.com:2222, worker1.example.com:2222}
+I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:202] Started server with target: grpc://localhost:2222
+```
+
+If you compiled TensorFlow (from v1.1-rc3 onward) with VERBS support and you have
+the required device and IB verbs software stack, you can specify `--protocol='grpc+verbs'`
+in order to use Verbs RDMA for tensor passing between workers and ps.
+The `--protocol` flag needs to be added to all tasks (ps and workers);
+the default protocol is grpc.
+
+
+[Congratulations!](https://www.youtube.com/watch?v=9bZkp7q19f0) You are now
+training Inception in a distributed manner.
+
+## How to Evaluate
+
+Evaluating an Inception v3 model on the ImageNet 2012 validation data set
+requires running a separate binary.
+
+The evaluation procedure is nearly identical to [Evaluating a Model](https://www.tensorflow.org/tutorials/deep_cnn/index.html#evaluating_a_model)
+described in the [Convolutional Neural Network](https://www.tensorflow.org/tutorials/deep_cnn/index.html) tutorial.
+
+**WARNING** Be careful not to run the evaluation and training binary on the same
+GPU or else you might run out of memory. Consider running the evaluation on a
+separate GPU if available or suspending the training binary while running the
+evaluation on the same GPU.
+
+Briefly, one can evaluate the model by running:
+
+```shell
+# Build the model. Note that we need to make sure the TensorFlow is ready to
+# use before this as this command will not build TensorFlow.
+cd tensorflow-models/inception
+bazel build //inception:imagenet_eval
+
+# run it
+bazel-bin/inception/imagenet_eval --checkpoint_dir=/tmp/imagenet_train --eval_dir=/tmp/imagenet_eval
+```
+
+Note that we point `--checkpoint_dir` to the location of the checkpoints saved
+by `inception_train.py` above. Running the above command results in the
+following output:
+
+```shell
+2016-02-17 22:32:50.391206: precision @ 1 = 0.735
+...
+```
+
+The script calculates the precision @ 1 over the entire validation data
+periodically. The precision @ 1 measures how often the highest-scoring
+prediction from the model matches the ImageNet label -- in this case, 73.5%. If
+you wish to run the eval just once and not periodically, append the `--run_once`
+option.
+
+Much like the training script, `imagenet_eval.py` also exports summaries that
+may be visualized in TensorBoard. These summaries calculate additional
+statistics on the predictions (e.g. recall @ 5) as well as monitor the
+statistics of the model activations and weights during evaluation.
+
+## How to Fine-Tune a Pre-Trained Model on a New Task
+
+### Getting Started
+
+Much like training the ImageNet model, we must first convert a new data set to
+the sharded TFRecord format in which each entry is a serialized `tf.Example` proto.
+
+We have provided a script demonstrating how to do this for a small data set of
+a few thousand flower images spread across 5 labels:
+
+```shell
+daisy, dandelion, roses, sunflowers, tulips
+```
+
+There is a single automated script that downloads the data set and converts it
+to the TFRecord format. Much like the ImageNet data set, each record in the
+TFRecord format is a serialized `tf.Example` proto whose entries include a
+JPEG-encoded string and an integer label. Please see [`parse_example_proto`](inception/image_processing.py) for details.
+
+The script takes just a few minutes to run, depending on your network connection
+speed for downloading and processing the images. Your hard disk requires 200MB
+of free storage. Here we select `FLOWERS_DATA_DIR=/tmp/flowers-data/` as such a location
+but feel free to edit accordingly.
+
+```shell
+# location of where to place the flowers data
+FLOWERS_DATA_DIR=/tmp/flowers-data/
+
+# build the preprocessing script.
+cd tensorflow-models/inception
+bazel build //inception:download_and_preprocess_flowers
+
+# run it
+bazel-bin/inception/download_and_preprocess_flowers "${FLOWERS_DATA_DIR}"
+```
+
+If the script runs successfully, the final line of the terminal output should
+look like:
+
+```shell
+2016-02-24 20:42:25.067551: Finished writing all 3170 images in data set.
+```
+
+When the script finishes you will find 2 shards for the training and validation
+files in the `DATA_DIR`. The files will match the patterns `train-?????-of-00002`
+and `validation-?????-of-00002`, respectively.
+
+**NOTE** If you wish to prepare a custom image data set for transfer learning,
+you will need to invoke [`build_image_data.py`](inception/data/build_image_data.py) on
+your custom data set. Please see the associated options and assumptions behind
+this script by reading the comments section of [`build_image_data.py`](inception/data/build_image_data.py). Also, if your custom data has a different
+number of examples or classes, you need to change the appropriate values in
+[`imagenet_data.py`](inception/imagenet_data.py).
+
+The second piece you will need is a trained Inception v3 image model. You have
+the option of either training one yourself (See [How to Train from Scratch](#how-to-train-from-scratch) for details) or you can download a pre-trained
+model like so:
+
+```shell
+# location of where to place the Inception v3 model
+INCEPTION_MODEL_DIR=$HOME/inception-v3-model
+mkdir -p ${INCEPTION_MODEL_DIR}
+cd ${INCEPTION_MODEL_DIR}
+
+# download the Inception v3 model
+curl -O http://download.tensorflow.org/models/image/imagenet/inception-v3-2016-03-01.tar.gz
+tar xzf inception-v3-2016-03-01.tar.gz
+
+# this will create a directory called inception-v3 which contains the following files.
+> ls inception-v3
+README.txt
+checkpoint
+model.ckpt-157585
+```
+
+[Congratulations!](https://www.youtube.com/watch?v=9bZkp7q19f0) You are now
+ready to fine-tune your pre-trained Inception v3 model with the flower data set.
+
+### How to Retrain a Trained Model on the Flowers Data
+
+We are now ready to fine-tune a pre-trained Inception-v3 model on the flowers
+data set. This requires two distinct changes to our training procedure:
+
+1. Build the exact same model as previously except we change the number of
+ labels in the final classification layer.
+
+2. Restore all weights from the pre-trained Inception-v3 except for the final
+ classification layer; this will get randomly initialized instead.
+
+We can perform these two operations by specifying two flags:
+`--pretrained_model_checkpoint_path` and `--fine_tune`. The first flag is a
+string that points to the path of a pre-trained Inception-v3 model. If this flag
+is specified, it will load the entire model from the checkpoint before the
+script begins training.
+
+The second flag `--fine_tune` is a boolean that indicates whether the last
+classification layer should be randomly initialized or restored. You may set
+this flag to false if you wish to continue training a pre-trained model from a
+checkpoint. If you set this flag to true, you can train a new classification
+layer from scratch.
+
+In order to understand how `--fine_tune` works, please see the discussion on
+`Variables` in the TensorFlow-Slim [`README.md`](inception/slim/README.md).
+
+Putting this all together you can retrain a pre-trained Inception-v3 model on
+the flowers data set with the following command.
+
+```shell
+# Build the model. Note that we need to make sure the TensorFlow is ready to
+# use before this as this command will not build TensorFlow.
+cd tensorflow-models/inception
+bazel build //inception:flowers_train
+
+# Path to the downloaded Inception-v3 model.
+MODEL_PATH="${INCEPTION_MODEL_DIR}/inception-v3/model.ckpt-157585"
+
+# Directory where the flowers data resides.
+FLOWERS_DATA_DIR=/tmp/flowers-data/
+
+# Directory where to save the checkpoint and events files.
+TRAIN_DIR=/tmp/flowers_train/
+
+# Run the fine-tuning on the flowers data set starting from the pre-trained
+# Imagenet-v3 model.
+bazel-bin/inception/flowers_train \
+ --train_dir="${TRAIN_DIR}" \
+ --data_dir="${FLOWERS_DATA_DIR}" \
+ --pretrained_model_checkpoint_path="${MODEL_PATH}" \
+ --fine_tune=True \
+ --initial_learning_rate=0.001 \
+ --input_queue_memory_factor=1
+```
+
+We have added a few extra options to the training procedure.
+
+* Fine-tuning a model on a separate data set requires significantly lowering the
+ initial learning rate. We set the initial learning rate to 0.001.
+* The flowers data set is quite small so we shrink the size of the shuffling
+ queue of examples. See [Adjusting Memory Demands](#adjusting-memory-demands)
+ for more details.
+
+The training script only reports the loss. To evaluate the quality of the
+fine-tuned model, you will need to run `flowers_eval`:
+
+```shell
+# Build the model. Note that we need to make sure the TensorFlow is ready to
+# use before this as this command will not build TensorFlow.
+cd tensorflow-models/inception
+bazel build //inception:flowers_eval
+
+# Directory where we saved the fine-tuned checkpoint and events files.
+TRAIN_DIR=/tmp/flowers_train/
+
+# Directory where the flowers data resides.
+FLOWERS_DATA_DIR=/tmp/flowers-data/
+
+# Directory where to save the evaluation events files.
+EVAL_DIR=/tmp/flowers_eval/
+
+# Evaluate the fine-tuned model on a hold-out of the flower data set.
+bazel-bin/inception/flowers_eval \
+ --eval_dir="${EVAL_DIR}" \
+ --data_dir="${FLOWERS_DATA_DIR}" \
+ --subset=validation \
+ --num_examples=500 \
+ --checkpoint_dir="${TRAIN_DIR}" \
+ --input_queue_memory_factor=1 \
+ --run_once
+```
+
+We find that the evaluation arrives at roughly 93.4% precision@1 after the model
+has been running for 2000 steps.
+
+```shell
+Successfully loaded model from /tmp/flowers/model.ckpt-1999 at step=1999.
+2016-03-01 16:52:51.761219: starting evaluation on (validation).
+2016-03-01 16:53:05.450419: [20 batches out of 20] (36.5 examples/sec; 0.684sec/batch)
+2016-03-01 16:53:05.450471: precision @ 1 = 0.9340 recall @ 5 = 0.9960 [500 examples]
+```
+
+## How to Construct a New Dataset for Retraining
+
+One can use the existing scripts supplied with this model to build a new dataset
+for training or fine-tuning. The main script to employ is
+[`build_image_data.py`](inception/data/build_image_data.py). Briefly, this script takes a
+structured directory of images and converts it to a sharded `TFRecord` that can
+be read by the Inception model.
+
+In particular, you will need to create a directory of training images that
+reside within `$TRAIN_DIR` and `$VALIDATION_DIR` arranged as such:
+
+```shell
+ $TRAIN_DIR/dog/image0.jpeg
+ $TRAIN_DIR/dog/image1.jpg
+ $TRAIN_DIR/dog/image2.png
+ ...
+ $TRAIN_DIR/cat/weird-image.jpeg
+ $TRAIN_DIR/cat/my-image.jpeg
+ $TRAIN_DIR/cat/my-image.JPG
+ ...
+ $VALIDATION_DIR/dog/imageA.jpeg
+ $VALIDATION_DIR/dog/imageB.jpg
+ $VALIDATION_DIR/dog/imageC.png
+ ...
+ $VALIDATION_DIR/cat/weird-image.PNG
+ $VALIDATION_DIR/cat/that-image.jpg
+ $VALIDATION_DIR/cat/cat.JPG
+ ...
+```
+**NOTE**: This script will append an extra background class indexed at 0, so
+your class labels will range from 0 to num_labels. Using the example above, the
+corresponding class labels generated from `build_image_data.py` will be as
+follows:
+```shell
+0
+1 dog
+2 cat
+```
+
+Each sub-directory in `$TRAIN_DIR` and `$VALIDATION_DIR` corresponds to a unique
+label for the images that reside within that sub-directory. The images may be
+JPEG or PNG images. We do not support other image types currently.
+
+Once the data is arranged in this directory structure, we can run
+`build_image_data.py` on the data to generate the sharded `TFRecord` dataset.
+Each entry of the `TFRecord` is a serialized `tf.Example` protocol buffer. A
+complete list of information contained in the `tf.Example` is described in the
+comments of `build_image_data.py`.
+
+To run `build_image_data.py`, you can run the following command line:
+
+```shell
+# location to where to save the TFRecord data.
+OUTPUT_DIRECTORY=$HOME/my-custom-data/
+
+# build the preprocessing script.
+cd tensorflow-models/inception
+bazel build //inception:build_image_data
+
+# convert the data.
+bazel-bin/inception/build_image_data \
+ --train_directory="${TRAIN_DIR}" \
+ --validation_directory="${VALIDATION_DIR}" \
+ --output_directory="${OUTPUT_DIRECTORY}" \
+ --labels_file="${LABELS_FILE}" \
+ --train_shards=128 \
+ --validation_shards=24 \
+ --num_threads=8
+```
+
+where the `$OUTPUT_DIRECTORY` is the location of the sharded `TFRecords`. The
+`$LABELS_FILE` will be a text file that is read by the script that provides
+a list of all of the labels. For instance, in the case of the flowers data set, the
+`$LABELS_FILE` contained the following data:
+
+```shell
+daisy
+dandelion
+roses
+sunflowers
+tulips
+```
+
+Note that each row of the labels file corresponds to an entry in the final
+classifier in the model. That is, `daisy` corresponds to the classifier for
+entry `1`; `dandelion` is entry `2`, etc. We skip label `0` as a background
+class.
+
+Running this script produces files that look like the following:
+
+```shell
+ $TRAIN_DIR/train-00000-of-00128
+ $TRAIN_DIR/train-00001-of-00128
+ ...
+ $TRAIN_DIR/train-00127-of-00128
+
+and
+
+ $VALIDATION_DIR/validation-00000-of-00024
+ $VALIDATION_DIR/validation-00001-of-00024
+ ...
+ $VALIDATION_DIR/validation-00023-of-00024
+```
+
+where 128 and 24 are the number of shards specified for each dataset,
+respectively. Generally speaking, we aim to select the number of shards such
+that roughly 1024 images reside in each shard. Once this data set is built, you
+are ready to train or fine-tune an Inception model on this data set.
+
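+For example, the 1024 training shards used for the ImageNet data set correspond
+to 1,281,167 / 1024, or roughly 1,250 images per shard.
+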
+Note: if you are piggybacking on the flowers retraining scripts, be sure to
+update `num_classes()` and `num_examples_per_epoch()` in `flowers_data.py`
+to correspond with your data.
+
+## Practical Considerations for Training a Model
+
+The model architecture and training procedure are heavily dependent on the
+hardware used to train the model. If you wish to train or fine-tune this model
+on your machine **you will need to adjust and empirically determine a good set
+of training hyper-parameters for your setup**. What follows are some general
+considerations for novices.
+
+### Finding Good Hyperparameters
+
+Roughly 5-10 hyper-parameters govern the speed at which a network is trained. In
+addition to `--batch_size` and `--num_gpus`, there are several constants defined
+in [inception_train.py](inception/inception_train.py) which dictate the learning
+schedule.
+
+```shell
+RMSPROP_DECAY = 0.9 # Decay term for RMSProp.
+MOMENTUM = 0.9 # Momentum in RMSProp.
+RMSPROP_EPSILON = 1.0 # Epsilon term for RMSProp.
+INITIAL_LEARNING_RATE = 0.1 # Initial learning rate.
+NUM_EPOCHS_PER_DECAY = 30.0 # Epochs after which learning rate decays.
+LEARNING_RATE_DECAY_FACTOR = 0.16 # Learning rate decay factor.
+```
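+
+To make the schedule concrete, here is a small back-of-the-envelope sketch (an
+illustration, not code from this repository) of the staircase exponential decay
+these constants drive, mirroring the `tf.train.exponential_decay` pattern used
+during training:
+
+```python
+INITIAL_LEARNING_RATE = 0.1
+LEARNING_RATE_DECAY_FACTOR = 0.16
+NUM_EPOCHS_PER_DECAY = 30.0
+
+
+def staircase_learning_rate(global_step, num_batches_per_epoch):
+  """Learning rate in effect after `global_step` training steps."""
+  decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)
+  return (INITIAL_LEARNING_RATE *
+          LEARNING_RATE_DECAY_FACTOR ** (global_step // decay_steps))
+
+
+# With the 1,281,167 ImageNet training images and --batch_size=32 there are
+# roughly 40,000 batches per epoch, so the rate first drops from 0.1 to 0.016
+# after about 1.2 million steps (30 epochs).
+```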
+
+There are many papers that discuss the various tricks and trade-offs associated
+with training a model with stochastic gradient descent. For those new to the
+field, some great references are:
+
+* Y Bengio, [Practical recommendations for gradient-based training of deep
+ architectures](http://arxiv.org/abs/1206.5533)
+* I Goodfellow, Y Bengio and A Courville, [Deep Learning](http://www.deeplearningbook.org/)
+
+What follows is a summary of some general advice for identifying appropriate
+model hyper-parameters in the context of this particular model training setup.
+Namely, this library provides *synchronous* updates to model parameters based on
+batch-splitting the model across multiple GPUs.
+
+* Higher learning rates lead to faster training. Too high a learning rate
+ leads to instability and will cause model parameters to diverge to infinity
+ or NaN.
+
+* Larger batch sizes lead to higher quality estimates of the gradient and
+ permit training the model with higher learning rates.
+
+* Often the GPU memory is a bottleneck that prevents employing larger batch
+ sizes. Employing more GPUs allows one to use larger batch sizes because
+ this model splits the batch across the GPUs.
+
+**NOTE** If one wishes to train this model with *asynchronous* gradient updates,
+one will need to substantially alter this model and new considerations need to
+be factored into hyperparameter tuning. See [Large Scale Distributed Deep
+Networks](http://research.google.com/archive/large_deep_networks_nips2012.html)
+for a discussion in this domain.
+
+### Adjusting Memory Demands
+
+Training this model has large memory demands in terms of the CPU and GPU. Let's
+discuss each item in turn.
+
+GPU memory is relatively small compared to CPU memory. Two items dictate the
+amount of GPU memory employed -- model architecture and batch size. Assuming
+that you keep the model architecture fixed, the sole parameter governing the GPU
+demand is the batch size. A good rule of thumb is to employ as large a
+batch size as will fit on the GPU.
+
+If you run out of GPU memory, either lower the `--batch_size` or employ more
+GPUs on your desktop. The model performs batch-splitting across GPUs, thus N
+GPUs can handle N times the batch size of 1 GPU.
+
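+For example, if a batch of 32 images fits comfortably on a single GPU, then
+`--num_gpus=4 --batch_size=128` keeps the per-GPU memory load roughly the same
+while quadrupling the effective batch.
+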
+The model requires a large amount of CPU memory as well. We have tuned the model
+to employ roughly 20GB of CPU memory. Thus, having access to about 40GB of CPU
+memory would be ideal.
+
+If that is not possible, you can tune down the memory demands of the model via
+lowering `--input_queue_memory_factor`. Images are preprocessed asynchronously
+with respect to the main training across `--num_preprocess_threads` threads. The
+preprocessed images are stored in a shuffling queue from which each GPU performs a
+dequeue operation in order to receive a `batch_size` worth of images.
+
+In order to guarantee good shuffling across the data, we maintain a large
+shuffling queue of 1024 x `input_queue_memory_factor` images. For the current
+model architecture, this corresponds to about 4GB of CPU memory. You may lower
+`input_queue_memory_factor` in order to decrease the memory footprint. Keep in
+mind though that lowering this value drastically may result in a model with
+slightly lower predictive accuracy when training from scratch. Please see
+comments in [`image_processing.py`](inception/image_processing.py) for more details.
+
+## Troubleshooting
+
+#### The model runs out of CPU memory.
+
+In lieu of buying more CPU memory, an easy fix is to decrease
+`--input_queue_memory_factor`. See [Adjusting Memory Demands](#adjusting-memory-demands).
+
+#### The model runs out of GPU memory.
+
+The data is not able to fit on the GPU card. The simplest solution is to
+decrease the batch size of the model. Otherwise, you will need to think about a
+more sophisticated method for specifying the training which cuts up the model
+across multiple `session.run()` calls or partitions the model across multiple
+GPUs. See [Using GPUs](https://www.tensorflow.org/how_tos/using_gpu/index.html)
+and [Adjusting Memory Demands](#adjusting-memory-demands) for more information.
+
+#### The model training results in NaN's.
+
+The learning rate of the model is too high. Turn down your learning rate.
+
+#### I wish to train a model with a different image size.
+
+The simplest solution is to artificially resize your images to `299x299` pixels.
+See [Images](https://www.tensorflow.org/api_docs/python/image.html) section for
+many resizing, cropping and padding methods. Note that the entire model
+architecture is predicated on a `299x299` image, thus if you wish to change the
+input image size, then you may need to redesign the entire model architecture.
+
+#### What hardware specification are these hyper-parameters targeted for?
+
+We targeted a desktop with 128GB of CPU RAM connected to 8 NVIDIA Tesla K40 GPU
+cards, but we have run this on desktops with 32GB of CPU RAM and 1 NVIDIA Tesla
+K40. You can get a sense of the various training configurations we tested by
+reading the comments in [`inception_train.py`](inception/inception_train.py).
+
+#### How do I continue training from a checkpoint in distributed setting?
+
+You only need to make sure that the checkpoint is in a location that can be
+reached by all of the `ps` tasks. By specifying the checkpoint location with
+`--train_dir`, the `ps` servers will load the checkpoint before commencing
+training.
diff --git a/models/research/inception/WORKSPACE b/models/research/inception/WORKSPACE
new file mode 100644
index 0000000000000000000000000000000000000000..2d7b4fb254a0fcebe695cb3fd3685af29a02e0b0
--- /dev/null
+++ b/models/research/inception/WORKSPACE
@@ -0,0 +1 @@
+workspace(name = "inception")
diff --git a/models/research/inception/g3doc/inception_v3_architecture.png b/models/research/inception/g3doc/inception_v3_architecture.png
new file mode 100644
index 0000000000000000000000000000000000000000..91fb734a104b2f63114ade7c8f9b2f95ce6334a6
Binary files /dev/null and b/models/research/inception/g3doc/inception_v3_architecture.png differ
diff --git a/models/research/inception/inception/BUILD b/models/research/inception/inception/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..21fc27aa57c14f6a72359cf15d446787c8ea6c2e
--- /dev/null
+++ b/models/research/inception/inception/BUILD
@@ -0,0 +1,198 @@
+# Description:
+# Example TensorFlow models for ImageNet.
+
+package(default_visibility = [":internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package_group(
+ name = "internal",
+ packages = ["//inception/..."],
+)
+
+py_library(
+ name = "dataset",
+ srcs = [
+ "dataset.py",
+ ],
+)
+
+py_library(
+ name = "imagenet_data",
+ srcs = [
+ "imagenet_data.py",
+ ],
+ deps = [
+ ":dataset",
+ ],
+)
+
+py_library(
+ name = "flowers_data",
+ srcs = [
+ "flowers_data.py",
+ ],
+ deps = [
+ ":dataset",
+ ],
+)
+
+py_library(
+ name = "image_processing",
+ srcs = [
+ "image_processing.py",
+ ],
+)
+
+py_library(
+ name = "inception",
+ srcs = [
+ "inception_model.py",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":dataset",
+ "//inception/slim",
+ ],
+)
+
+py_binary(
+ name = "imagenet_eval",
+ srcs = [
+ "imagenet_eval.py",
+ ],
+ deps = [
+ ":imagenet_data",
+ ":inception_eval",
+ ],
+)
+
+py_binary(
+ name = "flowers_eval",
+ srcs = [
+ "flowers_eval.py",
+ ],
+ deps = [
+ ":flowers_data",
+ ":inception_eval",
+ ],
+)
+
+py_library(
+ name = "inception_eval",
+ srcs = [
+ "inception_eval.py",
+ ],
+ deps = [
+ ":image_processing",
+ ":inception",
+ ],
+)
+
+py_binary(
+ name = "imagenet_train",
+ srcs = [
+ "imagenet_train.py",
+ ],
+ deps = [
+ ":imagenet_data",
+ ":inception_train",
+ ],
+)
+
+py_binary(
+ name = "imagenet_distributed_train",
+ srcs = [
+ "imagenet_distributed_train.py",
+ ],
+ deps = [
+ ":imagenet_data",
+ ":inception_distributed_train",
+ ],
+)
+
+py_binary(
+ name = "flowers_train",
+ srcs = [
+ "flowers_train.py",
+ ],
+ deps = [
+ ":flowers_data",
+ ":inception_train",
+ ],
+)
+
+py_library(
+ name = "inception_train",
+ srcs = [
+ "inception_train.py",
+ ],
+ deps = [
+ ":image_processing",
+ ":inception",
+ ],
+)
+
+py_library(
+ name = "inception_distributed_train",
+ srcs = [
+ "inception_distributed_train.py",
+ ],
+ deps = [
+ ":image_processing",
+ ":inception",
+ ],
+)
+
+py_binary(
+ name = "build_image_data",
+ srcs = ["data/build_image_data.py"],
+)
+
+sh_binary(
+ name = "download_and_preprocess_flowers",
+ srcs = ["data/download_and_preprocess_flowers.sh"],
+ data = [
+ ":build_image_data",
+ ],
+)
+
+sh_binary(
+ name = "download_and_preprocess_imagenet",
+ srcs = ["data/download_and_preprocess_imagenet.sh"],
+ data = [
+ "data/download_imagenet.sh",
+ "data/imagenet_2012_validation_synset_labels.txt",
+ "data/imagenet_lsvrc_2015_synsets.txt",
+ "data/imagenet_metadata.txt",
+ "data/preprocess_imagenet_validation_data.py",
+ "data/process_bounding_boxes.py",
+ ":build_imagenet_data",
+ ],
+)
+
+py_binary(
+ name = "build_imagenet_data",
+ srcs = ["data/build_imagenet_data.py"],
+)
+
+filegroup(
+ name = "srcs",
+ srcs = glob(
+ [
+ "**/*.py",
+ "BUILD",
+ ],
+ ),
+)
+
+filegroup(
+ name = "imagenet_metadata",
+ srcs = [
+ "data/imagenet_lsvrc_2015_synsets.txt",
+ "data/imagenet_metadata.txt",
+ ],
+ visibility = ["//visibility:public"],
+)
diff --git a/models/research/inception/inception/data/build_image_data.py b/models/research/inception/inception/data/build_image_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..894388b7f758a46746870f2f0d55d1df7d3fe29b
--- /dev/null
+++ b/models/research/inception/inception/data/build_image_data.py
@@ -0,0 +1,436 @@
+#!/usr/bin/python
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converts image data to TFRecords file format with Example protos.
+
+The image data set is expected to reside in JPEG files located in the
+following directory structure.
+
+ data_dir/label_0/image0.jpeg
+ data_dir/label_0/image1.jpg
+ ...
+ data_dir/label_1/weird-image.jpeg
+ data_dir/label_1/my-image.jpeg
+ ...
+
+where the sub-directory is the unique label associated with these images.
+
+This TensorFlow script converts the training and evaluation data into
+a sharded data set consisting of TFRecord files
+
+ train_directory/train-00000-of-01024
+ train_directory/train-00001-of-01024
+ ...
+ train_directory/train-01023-of-01024
+
+and
+
+ validation_directory/validation-00000-of-00128
+ validation_directory/validation-00001-of-00128
+ ...
+ validation_directory/validation-00127-of-00128
+
+where we have selected 1024 and 128 shards for each data set. Each record
+within the TFRecord file is a serialized Example proto. The Example proto
+contains the following fields:
+
+ image/encoded: string containing JPEG encoded image in RGB colorspace
+ image/height: integer, image height in pixels
+ image/width: integer, image width in pixels
+ image/colorspace: string, specifying the colorspace, always 'RGB'
+ image/channels: integer, specifying the number of channels, always 3
+ image/format: string, specifying the format, always 'JPEG'
+
+ image/filename: string containing the basename of the image file
+ e.g. 'n01440764_10026.JPEG' or 'ILSVRC2012_val_00000293.JPEG'
+ image/class/label: integer specifying the index in a classification layer.
+ The label ranges from [0, num_labels] where 0 is unused and left as
+ the background class.
+ image/class/text: string specifying the human-readable version of the label
+ e.g. 'dog'
+
+If your data set involves bounding boxes, please look at build_imagenet_data.py.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from datetime import datetime
+import os
+import random
+import sys
+import threading
+
+import numpy as np
+import tensorflow as tf
+
+tf.app.flags.DEFINE_string('train_directory', '/tmp/',
+ 'Training data directory')
+tf.app.flags.DEFINE_string('validation_directory', '/tmp/',
+ 'Validation data directory')
+tf.app.flags.DEFINE_string('output_directory', '/tmp/',
+ 'Output data directory')
+
+tf.app.flags.DEFINE_integer('train_shards', 2,
+ 'Number of shards in training TFRecord files.')
+tf.app.flags.DEFINE_integer('validation_shards', 2,
+ 'Number of shards in validation TFRecord files.')
+
+tf.app.flags.DEFINE_integer('num_threads', 2,
+ 'Number of threads to preprocess the images.')
+
+# The labels file contains the list of valid labels.
+# Assumes that the file contains entries as such:
+# dog
+# cat
+# flower
+# where each line corresponds to a label. We map each label contained in
+# the file to an integer, in file order, starting from 1; label 0 is reserved
+# as an unused background class.
+tf.app.flags.DEFINE_string('labels_file', '', 'Labels file')
+
+
+FLAGS = tf.app.flags.FLAGS
+
+
+def _int64_feature(value):
+ """Wrapper for inserting int64 features into Example proto."""
+ if not isinstance(value, list):
+ value = [value]
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
+
+
+def _bytes_feature(value):
+ """Wrapper for inserting bytes features into Example proto."""
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def _convert_to_example(filename, image_buffer, label, text, height, width):
+ """Build an Example proto for an example.
+
+ Args:
+ filename: string, path to an image file, e.g., '/path/to/example.JPG'
+ image_buffer: string, JPEG encoding of RGB image
+ label: integer, identifier for the ground truth for the network
+ text: string, unique human-readable, e.g. 'dog'
+ height: integer, image height in pixels
+ width: integer, image width in pixels
+ Returns:
+ Example proto
+ """
+
+ colorspace = 'RGB'
+ channels = 3
+ image_format = 'JPEG'
+
+ example = tf.train.Example(features=tf.train.Features(feature={
+ 'image/height': _int64_feature(height),
+ 'image/width': _int64_feature(width),
+ 'image/colorspace': _bytes_feature(tf.compat.as_bytes(colorspace)),
+ 'image/channels': _int64_feature(channels),
+ 'image/class/label': _int64_feature(label),
+ 'image/class/text': _bytes_feature(tf.compat.as_bytes(text)),
+ 'image/format': _bytes_feature(tf.compat.as_bytes(image_format)),
+ 'image/filename': _bytes_feature(tf.compat.as_bytes(os.path.basename(filename))),
+ 'image/encoded': _bytes_feature(tf.compat.as_bytes(image_buffer))}))
+ return example
+
+
+class ImageCoder(object):
+ """Helper class that provides TensorFlow image coding utilities."""
+
+ def __init__(self):
+ # Create a single Session to run all image coding calls.
+ self._sess = tf.Session()
+
+ # Initializes function that converts PNG to JPEG data.
+ self._png_data = tf.placeholder(dtype=tf.string)
+ image = tf.image.decode_png(self._png_data, channels=3)
+ self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100)
+
+ # Initializes function that decodes RGB JPEG data.
+ self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
+ self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)
+
+ def png_to_jpeg(self, image_data):
+ return self._sess.run(self._png_to_jpeg,
+ feed_dict={self._png_data: image_data})
+
+ def decode_jpeg(self, image_data):
+ image = self._sess.run(self._decode_jpeg,
+ feed_dict={self._decode_jpeg_data: image_data})
+ assert len(image.shape) == 3
+ assert image.shape[2] == 3
+ return image
+
+
+def _is_png(filename):
+ """Determine if a file contains a PNG format image.
+
+ Args:
+ filename: string, path of the image file.
+
+ Returns:
+ boolean indicating if the image is a PNG.
+ """
+ return filename.endswith('.png')
+
+
+def _process_image(filename, coder):
+ """Process a single image file.
+
+ Args:
+ filename: string, path to an image file e.g., '/path/to/example.JPG'.
+ coder: instance of ImageCoder to provide TensorFlow image coding utils.
+ Returns:
+ image_buffer: string, JPEG encoding of RGB image.
+ height: integer, image height in pixels.
+ width: integer, image width in pixels.
+ """
+ # Read the image file.
+ with tf.gfile.FastGFile(filename, 'rb') as f:
+ image_data = f.read()
+
+  # Convert any PNG to JPEG for consistency.
+ if _is_png(filename):
+ print('Converting PNG to JPEG for %s' % filename)
+ image_data = coder.png_to_jpeg(image_data)
+
+ # Decode the RGB JPEG.
+ image = coder.decode_jpeg(image_data)
+
+ # Check that image converted to RGB
+ assert len(image.shape) == 3
+ height = image.shape[0]
+ width = image.shape[1]
+ assert image.shape[2] == 3
+
+ return image_data, height, width
+
+
+def _process_image_files_batch(coder, thread_index, ranges, name, filenames,
+ texts, labels, num_shards):
+ """Processes and saves list of images as TFRecord in 1 thread.
+
+ Args:
+ coder: instance of ImageCoder to provide TensorFlow image coding utils.
+    thread_index: integer, unique batch index within [0, len(ranges)).
+    ranges: list of pairs of integers specifying the ranges of the batches to
+      analyze in parallel.
+ name: string, unique identifier specifying the data set
+ filenames: list of strings; each string is a path to an image file
+ texts: list of strings; each string is human readable, e.g. 'dog'
+ labels: list of integer; each integer identifies the ground truth
+ num_shards: integer number of shards for this data set.
+ """
+ # Each thread produces N shards where N = int(num_shards / num_threads).
+ # For instance, if num_shards = 128, and the num_threads = 2, then the first
+ # thread would produce shards [0, 64).
+ num_threads = len(ranges)
+ assert not num_shards % num_threads
+ num_shards_per_batch = int(num_shards / num_threads)
+
+ shard_ranges = np.linspace(ranges[thread_index][0],
+ ranges[thread_index][1],
+ num_shards_per_batch + 1).astype(int)
+ num_files_in_thread = ranges[thread_index][1] - ranges[thread_index][0]
+
+ counter = 0
+ for s in range(num_shards_per_batch):
+ # Generate a sharded version of the file name, e.g. 'train-00002-of-00010'
+ shard = thread_index * num_shards_per_batch + s
+ output_filename = '%s-%.5d-of-%.5d' % (name, shard, num_shards)
+ output_file = os.path.join(FLAGS.output_directory, output_filename)
+ writer = tf.python_io.TFRecordWriter(output_file)
+
+ shard_counter = 0
+ files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
+ for i in files_in_shard:
+ filename = filenames[i]
+ label = labels[i]
+ text = texts[i]
+
+ try:
+ image_buffer, height, width = _process_image(filename, coder)
+ except Exception as e:
+ print(e)
+ print('SKIPPED: Unexpected error while decoding %s.' % filename)
+ continue
+
+ example = _convert_to_example(filename, image_buffer, label,
+ text, height, width)
+ writer.write(example.SerializeToString())
+ shard_counter += 1
+ counter += 1
+
+ if not counter % 1000:
+ print('%s [thread %d]: Processed %d of %d images in thread batch.' %
+ (datetime.now(), thread_index, counter, num_files_in_thread))
+ sys.stdout.flush()
+
+ writer.close()
+ print('%s [thread %d]: Wrote %d images to %s' %
+ (datetime.now(), thread_index, shard_counter, output_file))
+ sys.stdout.flush()
+ shard_counter = 0
+ print('%s [thread %d]: Wrote %d images to %d shards.' %
+ (datetime.now(), thread_index, counter, num_files_in_thread))
+ sys.stdout.flush()
+
+
+def _process_image_files(name, filenames, texts, labels, num_shards):
+ """Process and save list of images as TFRecord of Example protos.
+
+ Args:
+ name: string, unique identifier specifying the data set
+ filenames: list of strings; each string is a path to an image file
+ texts: list of strings; each string is human readable, e.g. 'dog'
+ labels: list of integer; each integer identifies the ground truth
+ num_shards: integer number of shards for this data set.
+ """
+ assert len(filenames) == len(texts)
+ assert len(filenames) == len(labels)
+
+  # Break all images into batches, with batch i covering [ranges[i][0], ranges[i][1]).
+  spacing = np.linspace(0, len(filenames), FLAGS.num_threads + 1).astype(int)
+ ranges = []
+ for i in range(len(spacing) - 1):
+ ranges.append([spacing[i], spacing[i + 1]])
+
+ # Launch a thread for each batch.
+ print('Launching %d threads for spacings: %s' % (FLAGS.num_threads, ranges))
+ sys.stdout.flush()
+
+ # Create a mechanism for monitoring when all threads are finished.
+ coord = tf.train.Coordinator()
+
+ # Create a generic TensorFlow-based utility for converting all image codings.
+ coder = ImageCoder()
+
+ threads = []
+ for thread_index in range(len(ranges)):
+ args = (coder, thread_index, ranges, name, filenames,
+ texts, labels, num_shards)
+ t = threading.Thread(target=_process_image_files_batch, args=args)
+ t.start()
+ threads.append(t)
+
+ # Wait for all the threads to terminate.
+ coord.join(threads)
+ print('%s: Finished writing all %d images in data set.' %
+ (datetime.now(), len(filenames)))
+ sys.stdout.flush()
+
+
+def _find_image_files(data_dir, labels_file):
+ """Build a list of all images files and labels in the data set.
+
+ Args:
+ data_dir: string, path to the root directory of images.
+
+ Assumes that the image data set resides in JPEG files located in
+ the following directory structure.
+
+ data_dir/dog/another-image.JPEG
+ data_dir/dog/my-image.jpg
+
+ where 'dog' is the label associated with these images.
+
+ labels_file: string, path to the labels file.
+
+ The list of valid labels are held in this file. Assumes that the file
+ contains entries as such:
+ dog
+ cat
+ flower
+ where each line corresponds to a label. We map each label contained in
+    the file to an integer starting with the integer 1 corresponding to the
+    label contained in the first line (label 0 is reserved as an unused
+    background class).
+
+ Returns:
+ filenames: list of strings; each string is a path to an image file.
+ texts: list of strings; each string is the class, e.g. 'dog'
+ labels: list of integer; each integer identifies the ground truth.
+ """
+ print('Determining list of input files and labels from %s.' % data_dir)
+ unique_labels = [l.strip() for l in tf.gfile.FastGFile(
+ labels_file, 'r').readlines()]
+
+ labels = []
+ filenames = []
+ texts = []
+
+ # Leave label index 0 empty as a background class.
+ label_index = 1
+
+ # Construct the list of JPEG files and labels.
+ for text in unique_labels:
+ jpeg_file_path = '%s/%s/*' % (data_dir, text)
+ matching_files = tf.gfile.Glob(jpeg_file_path)
+
+ labels.extend([label_index] * len(matching_files))
+ texts.extend([text] * len(matching_files))
+ filenames.extend(matching_files)
+
+ if not label_index % 100:
+ print('Finished finding files in %d of %d classes.' % (
+          label_index, len(unique_labels)))
+ label_index += 1
+
+ # Shuffle the ordering of all image files in order to guarantee
+ # random ordering of the images with respect to label in the
+ # saved TFRecord files. Make the randomization repeatable.
+ shuffled_index = list(range(len(filenames)))
+ random.seed(12345)
+ random.shuffle(shuffled_index)
+
+ filenames = [filenames[i] for i in shuffled_index]
+ texts = [texts[i] for i in shuffled_index]
+ labels = [labels[i] for i in shuffled_index]
+
+ print('Found %d JPEG files across %d labels inside %s.' %
+ (len(filenames), len(unique_labels), data_dir))
+ return filenames, texts, labels
+
+
+def _process_dataset(name, directory, num_shards, labels_file):
+ """Process a complete data set and save it as a TFRecord.
+
+ Args:
+ name: string, unique identifier specifying the data set.
+ directory: string, root path to the data set.
+ num_shards: integer number of shards for this data set.
+ labels_file: string, path to the labels file.
+ """
+ filenames, texts, labels = _find_image_files(directory, labels_file)
+ _process_image_files(name, filenames, texts, labels, num_shards)
+
+
+def main(unused_argv):
+ assert not FLAGS.train_shards % FLAGS.num_threads, (
+ 'Please make the FLAGS.num_threads commensurate with FLAGS.train_shards')
+ assert not FLAGS.validation_shards % FLAGS.num_threads, (
+ 'Please make the FLAGS.num_threads commensurate with '
+ 'FLAGS.validation_shards')
+ print('Saving results to %s' % FLAGS.output_directory)
+
+ # Run it!
+ _process_dataset('validation', FLAGS.validation_directory,
+ FLAGS.validation_shards, FLAGS.labels_file)
+ _process_dataset('train', FLAGS.train_directory,
+ FLAGS.train_shards, FLAGS.labels_file)
+
+
+if __name__ == '__main__':
+ tf.app.run()
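
The Example fields documented at the top of build_image_data.py can be checked by reading one record back. The snippet below is only an illustrative sketch, not part of the repository: it assumes the TF 1.x API used throughout this file and a placeholder shard path under the default /tmp output directory.

# Sketch: inspect one Example written by build_image_data.py (TF 1.x API).
import tensorflow as tf

shard_path = '/tmp/train-00000-of-00002'  # placeholder: any shard produced above

for serialized in tf.python_io.tf_record_iterator(shard_path):
    example = tf.train.Example()
    example.ParseFromString(serialized)
    feat = example.features.feature
    label = feat['image/class/label'].int64_list.value[0]
    text = feat['image/class/text'].bytes_list.value[0].decode('utf-8')
    jpeg_bytes = feat['image/encoded'].bytes_list.value[0]
    print('label=%d text=%s encoded_bytes=%d' % (label, text, len(jpeg_bytes)))
    break  # one record is enough for a spot check
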
diff --git a/models/research/inception/inception/data/build_imagenet_data.py b/models/research/inception/inception/data/build_imagenet_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..c054735e782297f990451e29ff4383af24bbe802
--- /dev/null
+++ b/models/research/inception/inception/data/build_imagenet_data.py
@@ -0,0 +1,707 @@
+#!/usr/bin/python
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converts ImageNet data to TFRecords file format with Example protos.
+
+The raw ImageNet data set is expected to reside in JPEG files located in the
+following directory structure.
+
+ data_dir/n01440764/ILSVRC2012_val_00000293.JPEG
+ data_dir/n01440764/ILSVRC2012_val_00000543.JPEG
+ ...
+
+where 'n01440764' is the unique synset label associated with
+these images.
+
+The training data set consists of 1000 sub-directories (i.e. labels)
+each containing 1200 JPEG images for a total of 1.2M JPEG images.
+
+The evaluation data set consists of 1000 sub-directories (i.e. labels)
+each containing 50 JPEG images for a total of 50K JPEG images.
+
+This TensorFlow script converts the training and evaluation data into
+a sharded data set consisting of 1024 and 128 TFRecord files, respectively.
+
+ train_directory/train-00000-of-01024
+ train_directory/train-00001-of-01024
+ ...
+ train_directory/train-01023-of-01024
+
+and
+
+ validation_directory/validation-00000-of-00128
+ validation_directory/validation-00001-of-00128
+ ...
+ validation_directory/validation-00127-of-00128
+
+Each validation TFRecord file contains ~390 records. Each training TFRecord
+file contains ~1250 records. Each record within the TFRecord file is a
+serialized Example proto. The Example proto contains the following fields:
+
+ image/encoded: string containing JPEG encoded image in RGB colorspace
+ image/height: integer, image height in pixels
+ image/width: integer, image width in pixels
+ image/colorspace: string, specifying the colorspace, always 'RGB'
+ image/channels: integer, specifying the number of channels, always 3
+ image/format: string, specifying the format, always 'JPEG'
+
+ image/filename: string containing the basename of the image file
+ e.g. 'n01440764_10026.JPEG' or 'ILSVRC2012_val_00000293.JPEG'
+ image/class/label: integer specifying the index in a classification layer.
+ The label ranges from [1, 1000] where 0 is not used.
+ image/class/synset: string specifying the unique ID of the label,
+ e.g. 'n01440764'
+ image/class/text: string specifying the human-readable version of the label
+ e.g. 'red fox, Vulpes vulpes'
+
+  image/object/bbox/xmin: list of floats specifying the 0+ human annotated
+    bounding boxes (relative coordinates)
+  image/object/bbox/xmax: list of floats specifying the 0+ human annotated
+    bounding boxes (relative coordinates)
+  image/object/bbox/ymin: list of floats specifying the 0+ human annotated
+    bounding boxes (relative coordinates)
+  image/object/bbox/ymax: list of floats specifying the 0+ human annotated
+    bounding boxes (relative coordinates)
+ image/object/bbox/label: integer specifying the index in a classification
+ layer. The label ranges from [1, 1000] where 0 is not used. Note this is
+ always identical to the image label.
+
+Note that the length of xmin is identical to the length of xmax, ymin and ymax
+for each example.
+
+Running this script using 16 threads may take ~2.5 hours on an HP Z420.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from datetime import datetime
+import os
+import random
+import sys
+import threading
+
+import numpy as np
+import six
+import tensorflow as tf
+
+tf.app.flags.DEFINE_string('train_directory', '/tmp/',
+ 'Training data directory')
+tf.app.flags.DEFINE_string('validation_directory', '/tmp/',
+ 'Validation data directory')
+tf.app.flags.DEFINE_string('output_directory', '/tmp/',
+ 'Output data directory')
+
+tf.app.flags.DEFINE_integer('train_shards', 1024,
+ 'Number of shards in training TFRecord files.')
+tf.app.flags.DEFINE_integer('validation_shards', 128,
+ 'Number of shards in validation TFRecord files.')
+
+tf.app.flags.DEFINE_integer('num_threads', 8,
+ 'Number of threads to preprocess the images.')
+
+# The labels file contains the list of valid labels.
+# Assumes that the file contains entries as such:
+# n01440764
+# n01443537
+# n01484850
+# where each line corresponds to a label expressed as a synset. We map
+# each synset contained in the file to an integer (based on the alphabetical
+# ordering). See below for details.
+tf.app.flags.DEFINE_string('labels_file',
+ 'imagenet_lsvrc_2015_synsets.txt',
+ 'Labels file')
+
+# This file contains the mapping from synset to human-readable label.
+# Assumes each line of the file looks like:
+#
+# n02119247 black fox
+# n02119359 silver fox
+# n02119477 red fox, Vulpes fulva
+#
+# where each line corresponds to a unique mapping. Note that each line is
+# formatted as <synset>\t<human readable string>.
+tf.app.flags.DEFINE_string('imagenet_metadata_file',
+ 'imagenet_metadata.txt',
+ 'ImageNet metadata file')
+
+# This file is the output of process_bounding_boxes.py
+# Assumes each line of the file looks like:
+#
+# n00007846_64193.JPEG,0.0060,0.2620,0.7545,0.9940
+#
+# where each line corresponds to one bounding box annotation associated
+# with an image. Each line can be parsed as:
+#
+#   <JPEG file name>, <xmin>, <ymin>, <xmax>, <ymax>
+#
+# Note that there might exist multiple bounding box annotations associated
+# with an image file.
+tf.app.flags.DEFINE_string('bounding_box_file',
+ './imagenet_2012_bounding_boxes.csv',
+ 'Bounding box file')
+
+FLAGS = tf.app.flags.FLAGS
+
+
+def _int64_feature(value):
+ """Wrapper for inserting int64 features into Example proto."""
+ if not isinstance(value, list):
+ value = [value]
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
+
+
+def _float_feature(value):
+ """Wrapper for inserting float features into Example proto."""
+ if not isinstance(value, list):
+ value = [value]
+ return tf.train.Feature(float_list=tf.train.FloatList(value=value))
+
+
+def _bytes_feature(value):
+ """Wrapper for inserting bytes features into Example proto."""
+ if six.PY3 and isinstance(value, six.text_type):
+ value = six.binary_type(value, encoding='utf-8')
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def _convert_to_example(filename, image_buffer, label, synset, human, bbox,
+ height, width):
+ """Build an Example proto for an example.
+
+ Args:
+ filename: string, path to an image file, e.g., '/path/to/example.JPG'
+ image_buffer: string, JPEG encoding of RGB image
+ label: integer, identifier for the ground truth for the network
+ synset: string, unique WordNet ID specifying the label, e.g., 'n02323233'
+ human: string, human-readable label, e.g., 'red fox, Vulpes vulpes'
+    bbox: list of bounding boxes; each box is a list of floats
+      specifying [xmin, ymin, xmax, ymax] in relative coordinates. All boxes
+      are assumed to belong to the same label as the image label.
+ height: integer, image height in pixels
+ width: integer, image width in pixels
+ Returns:
+ Example proto
+ """
+ xmin = []
+ ymin = []
+ xmax = []
+ ymax = []
+ for b in bbox:
+ assert len(b) == 4
+ # pylint: disable=expression-not-assigned
+ [l.append(point) for l, point in zip([xmin, ymin, xmax, ymax], b)]
+ # pylint: enable=expression-not-assigned
+
+ colorspace = 'RGB'
+ channels = 3
+ image_format = 'JPEG'
+
+ example = tf.train.Example(features=tf.train.Features(feature={
+ 'image/height': _int64_feature(height),
+ 'image/width': _int64_feature(width),
+ 'image/colorspace': _bytes_feature(colorspace),
+ 'image/channels': _int64_feature(channels),
+ 'image/class/label': _int64_feature(label),
+ 'image/class/synset': _bytes_feature(synset),
+ 'image/class/text': _bytes_feature(human),
+ 'image/object/bbox/xmin': _float_feature(xmin),
+ 'image/object/bbox/xmax': _float_feature(xmax),
+ 'image/object/bbox/ymin': _float_feature(ymin),
+ 'image/object/bbox/ymax': _float_feature(ymax),
+ 'image/object/bbox/label': _int64_feature([label] * len(xmin)),
+ 'image/format': _bytes_feature(image_format),
+ 'image/filename': _bytes_feature(os.path.basename(filename)),
+ 'image/encoded': _bytes_feature(image_buffer)}))
+ return example
+
+
+class ImageCoder(object):
+ """Helper class that provides TensorFlow image coding utilities."""
+
+ def __init__(self):
+ # Create a single Session to run all image coding calls.
+ self._sess = tf.Session()
+
+ # Initializes function that converts PNG to JPEG data.
+ self._png_data = tf.placeholder(dtype=tf.string)
+ image = tf.image.decode_png(self._png_data, channels=3)
+ self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100)
+
+ # Initializes function that converts CMYK JPEG data to RGB JPEG data.
+ self._cmyk_data = tf.placeholder(dtype=tf.string)
+ image = tf.image.decode_jpeg(self._cmyk_data, channels=0)
+ self._cmyk_to_rgb = tf.image.encode_jpeg(image, format='rgb', quality=100)
+
+ # Initializes function that decodes RGB JPEG data.
+ self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
+ self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)
+
+ def png_to_jpeg(self, image_data):
+ return self._sess.run(self._png_to_jpeg,
+ feed_dict={self._png_data: image_data})
+
+ def cmyk_to_rgb(self, image_data):
+ return self._sess.run(self._cmyk_to_rgb,
+ feed_dict={self._cmyk_data: image_data})
+
+ def decode_jpeg(self, image_data):
+ image = self._sess.run(self._decode_jpeg,
+ feed_dict={self._decode_jpeg_data: image_data})
+ assert len(image.shape) == 3
+ assert image.shape[2] == 3
+ return image
+
+
+def _is_png(filename):
+ """Determine if a file contains a PNG format image.
+
+ Args:
+ filename: string, path of the image file.
+
+ Returns:
+ boolean indicating if the image is a PNG.
+ """
+ # File list from:
+ # https://groups.google.com/forum/embed/?place=forum/torch7#!topic/torch7/fOSTXHIESSU
+ return 'n02105855_2933.JPEG' in filename
+
+
+def _is_cmyk(filename):
+ """Determine if file contains a CMYK JPEG format image.
+
+ Args:
+ filename: string, path of the image file.
+
+ Returns:
+ boolean indicating if the image is a JPEG encoded with CMYK color space.
+ """
+ # File list from:
+ # https://github.com/cytsai/ilsvrc-cmyk-image-list
+ blacklist = ['n01739381_1309.JPEG', 'n02077923_14822.JPEG',
+ 'n02447366_23489.JPEG', 'n02492035_15739.JPEG',
+ 'n02747177_10752.JPEG', 'n03018349_4028.JPEG',
+ 'n03062245_4620.JPEG', 'n03347037_9675.JPEG',
+ 'n03467068_12171.JPEG', 'n03529860_11437.JPEG',
+ 'n03544143_17228.JPEG', 'n03633091_5218.JPEG',
+ 'n03710637_5125.JPEG', 'n03961711_5286.JPEG',
+ 'n04033995_2932.JPEG', 'n04258138_17003.JPEG',
+ 'n04264628_27969.JPEG', 'n04336792_7448.JPEG',
+ 'n04371774_5854.JPEG', 'n04596742_4225.JPEG',
+ 'n07583066_647.JPEG', 'n13037406_4650.JPEG']
+ return filename.split('/')[-1] in blacklist
+
+
+def _process_image(filename, coder):
+ """Process a single image file.
+
+ Args:
+ filename: string, path to an image file e.g., '/path/to/example.JPG'.
+ coder: instance of ImageCoder to provide TensorFlow image coding utils.
+ Returns:
+ image_buffer: string, JPEG encoding of RGB image.
+ height: integer, image height in pixels.
+ width: integer, image width in pixels.
+ """
+ # Read the image file.
+ with tf.gfile.FastGFile(filename, 'rb') as f:
+ image_data = f.read()
+
+ # Clean the dirty data.
+ if _is_png(filename):
+ # 1 image is a PNG.
+ print('Converting PNG to JPEG for %s' % filename)
+ image_data = coder.png_to_jpeg(image_data)
+ elif _is_cmyk(filename):
+ # 22 JPEG images are in CMYK colorspace.
+ print('Converting CMYK to RGB for %s' % filename)
+ image_data = coder.cmyk_to_rgb(image_data)
+
+ # Decode the RGB JPEG.
+ image = coder.decode_jpeg(image_data)
+
+ # Check that image converted to RGB
+ assert len(image.shape) == 3
+ height = image.shape[0]
+ width = image.shape[1]
+ assert image.shape[2] == 3
+
+ return image_data, height, width
+
+
+def _process_image_files_batch(coder, thread_index, ranges, name, filenames,
+ synsets, labels, humans, bboxes, num_shards):
+ """Processes and saves list of images as TFRecord in 1 thread.
+
+ Args:
+ coder: instance of ImageCoder to provide TensorFlow image coding utils.
+    thread_index: integer, unique batch index within [0, len(ranges)).
+    ranges: list of pairs of integers specifying the ranges of the batches to
+      analyze in parallel.
+ name: string, unique identifier specifying the data set
+ filenames: list of strings; each string is a path to an image file
+ synsets: list of strings; each string is a unique WordNet ID
+ labels: list of integer; each integer identifies the ground truth
+ humans: list of strings; each string is a human-readable label
+    bboxes: list of bounding boxes for each image. Each entry may contain 0 or
+      more bounding boxes, one per annotation for that image.
+ num_shards: integer number of shards for this data set.
+ """
+ # Each thread produces N shards where N = int(num_shards / num_threads).
+ # For instance, if num_shards = 128, and the num_threads = 2, then the first
+ # thread would produce shards [0, 64).
+ num_threads = len(ranges)
+ assert not num_shards % num_threads
+ num_shards_per_batch = int(num_shards / num_threads)
+
+ shard_ranges = np.linspace(ranges[thread_index][0],
+ ranges[thread_index][1],
+ num_shards_per_batch + 1).astype(int)
+ num_files_in_thread = ranges[thread_index][1] - ranges[thread_index][0]
+
+ counter = 0
+ for s in range(num_shards_per_batch):
+ # Generate a sharded version of the file name, e.g. 'train-00002-of-00010'
+ shard = thread_index * num_shards_per_batch + s
+ output_filename = '%s-%.5d-of-%.5d' % (name, shard, num_shards)
+ output_file = os.path.join(FLAGS.output_directory, output_filename)
+ writer = tf.python_io.TFRecordWriter(output_file)
+
+ shard_counter = 0
+ files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
+ for i in files_in_shard:
+ filename = filenames[i]
+ label = labels[i]
+ synset = synsets[i]
+ human = humans[i]
+ bbox = bboxes[i]
+
+ image_buffer, height, width = _process_image(filename, coder)
+
+ example = _convert_to_example(filename, image_buffer, label,
+ synset, human, bbox,
+ height, width)
+ writer.write(example.SerializeToString())
+ shard_counter += 1
+ counter += 1
+
+ if not counter % 1000:
+ print('%s [thread %d]: Processed %d of %d images in thread batch.' %
+ (datetime.now(), thread_index, counter, num_files_in_thread))
+ sys.stdout.flush()
+
+ writer.close()
+ print('%s [thread %d]: Wrote %d images to %s' %
+ (datetime.now(), thread_index, shard_counter, output_file))
+ sys.stdout.flush()
+ shard_counter = 0
+ print('%s [thread %d]: Wrote %d images to %d shards.' %
+ (datetime.now(), thread_index, counter, num_files_in_thread))
+ sys.stdout.flush()
+
+
+def _process_image_files(name, filenames, synsets, labels, humans,
+ bboxes, num_shards):
+ """Process and save list of images as TFRecord of Example protos.
+
+ Args:
+ name: string, unique identifier specifying the data set
+ filenames: list of strings; each string is a path to an image file
+ synsets: list of strings; each string is a unique WordNet ID
+ labels: list of integer; each integer identifies the ground truth
+ humans: list of strings; each string is a human-readable label
+    bboxes: list of bounding boxes for each image. Each entry may contain 0 or
+      more bounding boxes, one per annotation for that image.
+ num_shards: integer number of shards for this data set.
+ """
+ assert len(filenames) == len(synsets)
+ assert len(filenames) == len(labels)
+ assert len(filenames) == len(humans)
+ assert len(filenames) == len(bboxes)
+
+  # Break all images into batches, with batch i covering [ranges[i][0], ranges[i][1]).
+  spacing = np.linspace(0, len(filenames), FLAGS.num_threads + 1).astype(int)
+ ranges = []
+ threads = []
+ for i in range(len(spacing) - 1):
+ ranges.append([spacing[i], spacing[i + 1]])
+
+ # Launch a thread for each batch.
+ print('Launching %d threads for spacings: %s' % (FLAGS.num_threads, ranges))
+ sys.stdout.flush()
+
+ # Create a mechanism for monitoring when all threads are finished.
+ coord = tf.train.Coordinator()
+
+ # Create a generic TensorFlow-based utility for converting all image codings.
+ coder = ImageCoder()
+
+ threads = []
+ for thread_index in range(len(ranges)):
+ args = (coder, thread_index, ranges, name, filenames,
+ synsets, labels, humans, bboxes, num_shards)
+ t = threading.Thread(target=_process_image_files_batch, args=args)
+ t.start()
+ threads.append(t)
+
+ # Wait for all the threads to terminate.
+ coord.join(threads)
+ print('%s: Finished writing all %d images in data set.' %
+ (datetime.now(), len(filenames)))
+ sys.stdout.flush()
+
+
+def _find_image_files(data_dir, labels_file):
+ """Build a list of all images files and labels in the data set.
+
+ Args:
+ data_dir: string, path to the root directory of images.
+
+ Assumes that the ImageNet data set resides in JPEG files located in
+ the following directory structure.
+
+ data_dir/n01440764/ILSVRC2012_val_00000293.JPEG
+ data_dir/n01440764/ILSVRC2012_val_00000543.JPEG
+
+ where 'n01440764' is the unique synset label associated with these images.
+
+ labels_file: string, path to the labels file.
+
+ The list of valid labels are held in this file. Assumes that the file
+ contains entries as such:
+ n01440764
+ n01443537
+ n01484850
+ where each line corresponds to a label expressed as a synset. We map
+ each synset contained in the file to an integer (based on the alphabetical
+ ordering) starting with the integer 1 corresponding to the synset
+ contained in the first line.
+
+ The reason we start the integer labels at 1 is to reserve label 0 as an
+ unused background class.
+
+ Returns:
+ filenames: list of strings; each string is a path to an image file.
+ synsets: list of strings; each string is a unique WordNet ID.
+ labels: list of integer; each integer identifies the ground truth.
+ """
+ print('Determining list of input files and labels from %s.' % data_dir)
+ challenge_synsets = [l.strip() for l in
+ tf.gfile.FastGFile(labels_file, 'r').readlines()]
+
+ labels = []
+ filenames = []
+ synsets = []
+
+ # Leave label index 0 empty as a background class.
+ label_index = 1
+
+ # Construct the list of JPEG files and labels.
+ for synset in challenge_synsets:
+ jpeg_file_path = '%s/%s/*.JPEG' % (data_dir, synset)
+ matching_files = tf.gfile.Glob(jpeg_file_path)
+
+ labels.extend([label_index] * len(matching_files))
+ synsets.extend([synset] * len(matching_files))
+ filenames.extend(matching_files)
+
+ if not label_index % 100:
+ print('Finished finding files in %d of %d classes.' % (
+ label_index, len(challenge_synsets)))
+ label_index += 1
+
+ # Shuffle the ordering of all image files in order to guarantee
+ # random ordering of the images with respect to label in the
+ # saved TFRecord files. Make the randomization repeatable.
+ shuffled_index = list(range(len(filenames)))
+ random.seed(12345)
+ random.shuffle(shuffled_index)
+
+ filenames = [filenames[i] for i in shuffled_index]
+ synsets = [synsets[i] for i in shuffled_index]
+ labels = [labels[i] for i in shuffled_index]
+
+ print('Found %d JPEG files across %d labels inside %s.' %
+ (len(filenames), len(challenge_synsets), data_dir))
+ return filenames, synsets, labels
+
+
+def _find_human_readable_labels(synsets, synset_to_human):
+ """Build a list of human-readable labels.
+
+ Args:
+ synsets: list of strings; each string is a unique WordNet ID.
+ synset_to_human: dict of synset to human labels, e.g.,
+ 'n02119022' --> 'red fox, Vulpes vulpes'
+
+ Returns:
+ List of human-readable strings corresponding to each synset.
+ """
+ humans = []
+ for s in synsets:
+ assert s in synset_to_human, ('Failed to find: %s' % s)
+ humans.append(synset_to_human[s])
+ return humans
+
+
+def _find_image_bounding_boxes(filenames, image_to_bboxes):
+ """Find the bounding boxes for a given image file.
+
+ Args:
+ filenames: list of strings; each string is a path to an image file.
+ image_to_bboxes: dictionary mapping image file names to a list of
+ bounding boxes. This list contains 0+ bounding boxes.
+ Returns:
+    List of bounding boxes for each image. Each entry may contain 0 or more
+    bounding boxes, one per annotation for that image.
+ """
+ num_image_bbox = 0
+ bboxes = []
+ for f in filenames:
+ basename = os.path.basename(f)
+ if basename in image_to_bboxes:
+ bboxes.append(image_to_bboxes[basename])
+ num_image_bbox += 1
+ else:
+ bboxes.append([])
+ print('Found %d images with bboxes out of %d images' % (
+ num_image_bbox, len(filenames)))
+ return bboxes
+
+
+def _process_dataset(name, directory, num_shards, synset_to_human,
+ image_to_bboxes):
+ """Process a complete data set and save it as a TFRecord.
+
+ Args:
+ name: string, unique identifier specifying the data set.
+ directory: string, root path to the data set.
+ num_shards: integer number of shards for this data set.
+ synset_to_human: dict of synset to human labels, e.g.,
+ 'n02119022' --> 'red fox, Vulpes vulpes'
+ image_to_bboxes: dictionary mapping image file names to a list of
+ bounding boxes. This list contains 0+ bounding boxes.
+ """
+ filenames, synsets, labels = _find_image_files(directory, FLAGS.labels_file)
+ humans = _find_human_readable_labels(synsets, synset_to_human)
+ bboxes = _find_image_bounding_boxes(filenames, image_to_bboxes)
+ _process_image_files(name, filenames, synsets, labels,
+ humans, bboxes, num_shards)
+
+
+def _build_synset_lookup(imagenet_metadata_file):
+ """Build lookup for synset to human-readable label.
+
+ Args:
+ imagenet_metadata_file: string, path to file containing mapping from
+ synset to human-readable label.
+
+ Assumes each line of the file looks like:
+
+ n02119247 black fox
+ n02119359 silver fox
+ n02119477 red fox, Vulpes fulva
+
+ where each line corresponds to a unique mapping. Note that each line is
+    formatted as <synset>\t<human readable string>.
+
+ Returns:
+ Dictionary of synset to human labels, such as:
+ 'n02119022' --> 'red fox, Vulpes vulpes'
+ """
+ lines = tf.gfile.FastGFile(imagenet_metadata_file, 'r').readlines()
+ synset_to_human = {}
+ for l in lines:
+ if l:
+ parts = l.strip().split('\t')
+ assert len(parts) == 2
+ synset = parts[0]
+ human = parts[1]
+ synset_to_human[synset] = human
+ return synset_to_human
+
+
+def _build_bounding_box_lookup(bounding_box_file):
+ """Build a lookup from image file to bounding boxes.
+
+ Args:
+ bounding_box_file: string, path to file with bounding boxes annotations.
+
+ Assumes each line of the file looks like:
+
+ n00007846_64193.JPEG,0.0060,0.2620,0.7545,0.9940
+
+ where each line corresponds to one bounding box annotation associated
+ with an image. Each line can be parsed as:
+
+    <JPEG file name>, <xmin>, <ymin>, <xmax>, <ymax>
+
+  Note that there might exist multiple bounding box annotations associated
+ with an image file. This file is the output of process_bounding_boxes.py.
+
+ Returns:
+ Dictionary mapping image file names to a list of bounding boxes. This list
+ contains 0+ bounding boxes.
+ """
+ lines = tf.gfile.FastGFile(bounding_box_file, 'r').readlines()
+ images_to_bboxes = {}
+ num_bbox = 0
+ num_image = 0
+ for l in lines:
+ if l:
+ parts = l.split(',')
+ assert len(parts) == 5, ('Failed to parse: %s' % l)
+ filename = parts[0]
+ xmin = float(parts[1])
+ ymin = float(parts[2])
+ xmax = float(parts[3])
+ ymax = float(parts[4])
+ box = [xmin, ymin, xmax, ymax]
+
+ if filename not in images_to_bboxes:
+ images_to_bboxes[filename] = []
+ num_image += 1
+ images_to_bboxes[filename].append(box)
+ num_bbox += 1
+
+ print('Successfully read %d bounding boxes '
+ 'across %d images.' % (num_bbox, num_image))
+ return images_to_bboxes
+
+
+def main(unused_argv):
+ assert not FLAGS.train_shards % FLAGS.num_threads, (
+ 'Please make the FLAGS.num_threads commensurate with FLAGS.train_shards')
+ assert not FLAGS.validation_shards % FLAGS.num_threads, (
+ 'Please make the FLAGS.num_threads commensurate with '
+ 'FLAGS.validation_shards')
+ print('Saving results to %s' % FLAGS.output_directory)
+
+ # Build a map from synset to human-readable label.
+ synset_to_human = _build_synset_lookup(FLAGS.imagenet_metadata_file)
+ image_to_bboxes = _build_bounding_box_lookup(FLAGS.bounding_box_file)
+
+ # Run it!
+ _process_dataset('validation', FLAGS.validation_directory,
+ FLAGS.validation_shards, synset_to_human, image_to_bboxes)
+ _process_dataset('train', FLAGS.train_directory, FLAGS.train_shards,
+ synset_to_human, image_to_bboxes)
+
+
+if __name__ == '__main__':
+ tf.app.run()
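
To make the bounding-box features concrete, the short sketch below (plain Python, not part of the script; the label value is hypothetical) mirrors what _build_bounding_box_lookup and _convert_to_example do with the single CSV line shown in the docstring above.

# Sketch: how one bounding-box CSV line becomes the per-coordinate float lists
# and the repeated image/object/bbox/label feature.
line = 'n00007846_64193.JPEG,0.0060,0.2620,0.7545,0.9940'

parts = line.split(',')
filename = parts[0]
xmin, ymin, xmax, ymax = [float(v) for v in parts[1:]]
bbox = [[xmin, ymin, xmax, ymax]]   # an image may carry several such boxes

# _convert_to_example splits the boxes into per-coordinate lists ...
xmins = [b[0] for b in bbox]
ymins = [b[1] for b in bbox]
xmaxs = [b[2] for b in bbox]
ymaxs = [b[3] for b in bbox]

# ... and repeats the image-level label once per box (label 7 is hypothetical).
label = 7
bbox_labels = [label] * len(xmins)
print(filename, xmins, ymins, xmaxs, ymaxs, bbox_labels)
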
diff --git a/models/research/inception/inception/data/download_and_preprocess_flowers.sh b/models/research/inception/inception/data/download_and_preprocess_flowers.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ee045c164e803ab38be69fb1933134e7f37f1793
--- /dev/null
+++ b/models/research/inception/inception/data/download_and_preprocess_flowers.sh
@@ -0,0 +1,96 @@
+#!/bin/bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Script to download and preprocess the flowers data set. This data set
+# provides a demonstration for how to perform fine-tuning (i.e. transfer
+# learning) from one model to a new data set.
+#
+# This script provides a demonstration for how to prepare an arbitrary
+# data set for training an Inception v3 model.
+#
+# We demonstrate this with the flowers data set, which consists of labeled
+# flower images from 5 classes:
+#
+# daisy, dandelion, roses, sunflowers, tulips
+#
+# The final output of this script is a set of sharded TFRecord files containing
+# serialized Example protocol buffers. See build_image_data.py for
+# details of how the Example protocol buffer contains image data.
+#
+# usage:
+# ./download_and_preprocess_flowers.sh [data-dir]
+set -e
+
+if [ -z "$1" ]; then
+ echo "Usage: download_and_preprocess_flowers.sh [data dir]"
+ exit
+fi
+
+# Create the output and temporary directories.
+DATA_DIR="${1%/}"
+SCRATCH_DIR="${DATA_DIR}/raw-data"
+WORK_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
+mkdir -p "${DATA_DIR}"
+mkdir -p "${SCRATCH_DIR}"
+
+# Download the flowers data.
+DATA_URL="http://download.tensorflow.org/example_images/flower_photos.tgz"
+CURRENT_DIR=$(pwd)
+cd "${DATA_DIR}"
+TARBALL="flower_photos.tgz"
+if [ ! -f ${TARBALL} ]; then
+ echo "Downloading flower data set."
+ curl -o ${TARBALL} "${DATA_URL}"
+else
+ echo "Skipping download of flower data."
+fi
+
+# Note the locations of the train and validation data.
+TRAIN_DIRECTORY="${SCRATCH_DIR}/train"
+VALIDATION_DIRECTORY="${SCRATCH_DIR}/validation"
+
+# Expands the data into the flower_photos/ directory and rename it as the
+# train directory.
+tar xf flower_photos.tgz
+rm -rf "${TRAIN_DIRECTORY}" "${VALIDATION_DIRECTORY}"
+mv flower_photos "${TRAIN_DIRECTORY}"
+
+# Generate a list of 5 labels: daisy, dandelion, roses, sunflowers, tulips
+LABELS_FILE="${SCRATCH_DIR}/labels.txt"
+ls -1 "${TRAIN_DIRECTORY}" | grep -v 'LICENSE' | sed 's/\///' | sort > "${LABELS_FILE}"
+
+# Generate the validation data set.
+while read LABEL; do
+ VALIDATION_DIR_FOR_LABEL="${VALIDATION_DIRECTORY}/${LABEL}"
+ TRAIN_DIR_FOR_LABEL="${TRAIN_DIRECTORY}/${LABEL}"
+
+ # Move the first randomly selected 100 images to the validation set.
+ mkdir -p "${VALIDATION_DIR_FOR_LABEL}"
+ VALIDATION_IMAGES=$(ls -1 "${TRAIN_DIR_FOR_LABEL}" | shuf | head -100)
+ for IMAGE in ${VALIDATION_IMAGES}; do
+ mv -f "${TRAIN_DIRECTORY}/${LABEL}/${IMAGE}" "${VALIDATION_DIR_FOR_LABEL}"
+ done
+done < "${LABELS_FILE}"
+
+# Build the TFRecords version of the image data.
+cd "${CURRENT_DIR}"
+BUILD_SCRIPT="${WORK_DIR}/build_image_data.py"
+OUTPUT_DIRECTORY="${DATA_DIR}"
+"${BUILD_SCRIPT}" \
+ --train_directory="${TRAIN_DIRECTORY}" \
+ --validation_directory="${VALIDATION_DIRECTORY}" \
+ --output_directory="${OUTPUT_DIRECTORY}" \
+ --labels_file="${LABELS_FILE}"
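
The per-class validation split performed by the while-read loop above can also be outlined in Python; the sketch below is an equivalent, not part of the script, and the raw-data directory paths are placeholders.

# Sketch: move 100 randomly selected images per class from train/ to validation/.
import os
import random
import shutil

train_dir = '/tmp/flowers-data/raw-data/train'             # placeholder
validation_dir = '/tmp/flowers-data/raw-data/validation'   # placeholder

for label in sorted(os.listdir(train_dir)):
    label_dir = os.path.join(train_dir, label)
    if not os.path.isdir(label_dir):
        continue  # skip stray files such as LICENSE.txt
    images = os.listdir(label_dir)
    random.shuffle(images)
    dest = os.path.join(validation_dir, label)
    os.makedirs(dest, exist_ok=True)
    for image in images[:100]:
        shutil.move(os.path.join(label_dir, image), os.path.join(dest, image))
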
diff --git a/models/research/inception/inception/data/download_and_preprocess_flowers_mac.sh b/models/research/inception/inception/data/download_and_preprocess_flowers_mac.sh
new file mode 100644
index 0000000000000000000000000000000000000000..154905635b19aeaaea087a8e76afda9b8c624d59
--- /dev/null
+++ b/models/research/inception/inception/data/download_and_preprocess_flowers_mac.sh
@@ -0,0 +1,96 @@
+#!/bin/bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Script to download and preprocess the flowers data set. This data set
+# provides a demonstration for how to perform fine-tuning (i.e. transfer
+# learning) from one model to a new data set.
+#
+# This script provides a demonstration for how to prepare an arbitrary
+# data set for training an Inception v3 model.
+#
+# We demonstrate this with the flowers data set, which consists of labeled
+# flower images from 5 classes:
+#
+# daisy, dandelion, roses, sunflowers, tulips
+#
+# The final output of this script is a set of sharded TFRecord files containing
+# serialized Example protocol buffers. See build_image_data.py for
+# details of how the Example protocol buffer contains image data.
+#
+# usage:
+#  ./download_and_preprocess_flowers_mac.sh [data-dir]
+set -e
+
+if [ -z "$1" ]; then
+  echo "Usage: download_and_preprocess_flowers_mac.sh [data dir]"
+ exit
+fi
+
+# Create the output and temporary directories.
+DATA_DIR="${1%/}"
+SCRATCH_DIR="${DATA_DIR}/raw-data/"
+mkdir -p "${DATA_DIR}"
+mkdir -p "${SCRATCH_DIR}"
+WORK_DIR="$0.runfiles/inception/inception"
+
+# Download the flowers data.
+DATA_URL="http://download.tensorflow.org/example_images/flower_photos.tgz"
+CURRENT_DIR=$(pwd)
+cd "${DATA_DIR}"
+TARBALL="flower_photos.tgz"
+if [ ! -f ${TARBALL} ]; then
+ echo "Downloading flower data set."
+ curl -o ${TARBALL} "${DATA_URL}"
+else
+ echo "Skipping download of flower data."
+fi
+
+# Note the locations of the train and validation data.
+TRAIN_DIRECTORY="${SCRATCH_DIR}train/"
+VALIDATION_DIRECTORY="${SCRATCH_DIR}validation/"
+
+# Expands the data into the flower_photos/ directory and rename it as the
+# train directory.
+tar xf flower_photos.tgz
+rm -rf "${TRAIN_DIRECTORY}" "${VALIDATION_DIRECTORY}"
+mv flower_photos "${TRAIN_DIRECTORY}"
+
+# Generate a list of 5 labels: daisy, dandelion, roses, sunflowers, tulips
+LABELS_FILE="${SCRATCH_DIR}/labels.txt"
+ls -1 "${TRAIN_DIRECTORY}" | grep -v 'LICENSE' | sed 's/\///' | sort > "${LABELS_FILE}"
+
+# Generate the validation data set.
+while read LABEL; do
+ VALIDATION_DIR_FOR_LABEL="${VALIDATION_DIRECTORY}${LABEL}"
+ TRAIN_DIR_FOR_LABEL="${TRAIN_DIRECTORY}${LABEL}"
+
+ # Move the first randomly selected 100 images to the validation set.
+ mkdir -p "${VALIDATION_DIR_FOR_LABEL}"
+ VALIDATION_IMAGES=$(ls -1 "${TRAIN_DIR_FOR_LABEL}" | gshuf | head -100)
+ for IMAGE in ${VALIDATION_IMAGES}; do
+ mv -f "${TRAIN_DIRECTORY}${LABEL}/${IMAGE}" "${VALIDATION_DIR_FOR_LABEL}"
+ done
+done < "${LABELS_FILE}"
+
+# Build the TFRecords version of the image data.
+cd "${CURRENT_DIR}"
+BUILD_SCRIPT="${WORK_DIR}/build_image_data"
+OUTPUT_DIRECTORY="${DATA_DIR}"
+"${BUILD_SCRIPT}" \
+ --train_directory="${TRAIN_DIRECTORY}" \
+ --validation_directory="${VALIDATION_DIRECTORY}" \
+ --output_directory="${OUTPUT_DIRECTORY}" \
+ --labels_file="${LABELS_FILE}"
diff --git a/models/research/inception/inception/data/download_and_preprocess_imagenet.sh b/models/research/inception/inception/data/download_and_preprocess_imagenet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6faae831075d4f6bfdc8bf8797219f7a0e4c1797
--- /dev/null
+++ b/models/research/inception/inception/data/download_and_preprocess_imagenet.sh
@@ -0,0 +1,101 @@
+#!/bin/bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Script to download and preprocess ImageNet Challenge 2012
+# training and validation data set.
+#
+# The final output of this script is a set of sharded TFRecord files containing
+# serialized Example protocol buffers. See build_imagenet_data.py for
+# details of how the Example protocol buffers contain the ImageNet data.
+#
+# The final output of this script appears as such:
+#
+# data_dir/train-00000-of-01024
+# data_dir/train-00001-of-01024
+# ...
+# data_dir/train-01023-of-01024
+#
+# and
+#
+# data_dir/validation-00000-of-00128
+# data_dir/validation-00001-of-00128
+# ...
+# data_dir/validation-00127-of-00128
+#
+# Note that this script may take several hours to run to completion. The
+# conversion of the ImageNet data to TFRecords alone takes 2-3 hours depending
+# on the speed of your machine. Please be patient.
+#
+# **IMPORTANT**
+# To download the raw images, the user must create an account with image-net.org
+# and generate a username and access_key. The latter two are required for
+# downloading the raw images.
+#
+# usage:
+# ./download_and_preprocess_imagenet.sh [data-dir]
+set -e
+
+if [ -z "$1" ]; then
+ echo "Usage: download_and_preprocess_imagenet.sh [data dir]"
+ exit
+fi
+
+# Create the output and temporary directories.
+DATA_DIR="${1%/}"
+SCRATCH_DIR="${DATA_DIR}/raw-data/"
+mkdir -p "${DATA_DIR}"
+mkdir -p "${SCRATCH_DIR}"
+WORK_DIR="$0.runfiles/inception/inception"
+
+# Download the ImageNet data.
+LABELS_FILE="${WORK_DIR}/data/imagenet_lsvrc_2015_synsets.txt"
+DOWNLOAD_SCRIPT="${WORK_DIR}/data/download_imagenet.sh"
+"${DOWNLOAD_SCRIPT}" "${SCRATCH_DIR}" "${LABELS_FILE}"
+
+# Note the locations of the train and validation data.
+TRAIN_DIRECTORY="${SCRATCH_DIR}train/"
+VALIDATION_DIRECTORY="${SCRATCH_DIR}validation/"
+
+# Preprocess the validation data by moving the images into the appropriate
+# sub-directory based on the label (synset) of the image.
+echo "Organizing the validation data into sub-directories."
+PREPROCESS_VAL_SCRIPT="${WORK_DIR}/data/preprocess_imagenet_validation_data.py"
+VAL_LABELS_FILE="${WORK_DIR}/data/imagenet_2012_validation_synset_labels.txt"
+
+"${PREPROCESS_VAL_SCRIPT}" "${VALIDATION_DIRECTORY}" "${VAL_LABELS_FILE}"
+
+# Convert the XML files for bounding box annotations into a single CSV.
+echo "Extracting bounding box information from XML."
+BOUNDING_BOX_SCRIPT="${WORK_DIR}/data/process_bounding_boxes.py"
+BOUNDING_BOX_FILE="${SCRATCH_DIR}/imagenet_2012_bounding_boxes.csv"
+BOUNDING_BOX_DIR="${SCRATCH_DIR}bounding_boxes/"
+
+"${BOUNDING_BOX_SCRIPT}" "${BOUNDING_BOX_DIR}" "${LABELS_FILE}" \
+ | sort > "${BOUNDING_BOX_FILE}"
+echo "Finished downloading and preprocessing the ImageNet data."
+
+# Build the TFRecords version of the ImageNet data.
+BUILD_SCRIPT="${WORK_DIR}/build_imagenet_data"
+OUTPUT_DIRECTORY="${DATA_DIR}"
+IMAGENET_METADATA_FILE="${WORK_DIR}/data/imagenet_metadata.txt"
+
+"${BUILD_SCRIPT}" \
+ --train_directory="${TRAIN_DIRECTORY}" \
+ --validation_directory="${VALIDATION_DIRECTORY}" \
+ --output_directory="${OUTPUT_DIRECTORY}" \
+ --imagenet_metadata_file="${IMAGENET_METADATA_FILE}" \
+ --labels_file="${LABELS_FILE}" \
+ --bounding_box_file="${BOUNDING_BOX_FILE}"
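
Once the pipeline finishes, a simple sanity check is to count the shards it wrote; the sketch below (data_dir is a placeholder for the [data-dir] argument) compares the counts against the 1024/128 defaults used by build_imagenet_data.py.

# Sketch: verify the number of TFRecord shards produced by the ImageNet pipeline.
import glob
import os

data_dir = '/tmp/imagenet-data'  # placeholder for [data-dir]

train_shards = glob.glob(os.path.join(data_dir, 'train-*-of-*'))
validation_shards = glob.glob(os.path.join(data_dir, 'validation-*-of-*'))

print('train shards: %d (expected 1024)' % len(train_shards))
print('validation shards: %d (expected 128)' % len(validation_shards))
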
diff --git a/models/research/inception/inception/data/download_imagenet.sh b/models/research/inception/inception/data/download_imagenet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f6c77781c0bcaad642ec7a38a7ba00693ef8ef83
--- /dev/null
+++ b/models/research/inception/inception/data/download_imagenet.sh
@@ -0,0 +1,104 @@
+#!/bin/bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Script to download ImageNet Challenge 2012 training and validation data set.
+#
+# Downloads and decompresses raw images and bounding boxes.
+#
+# **IMPORTANT**
+# To download the raw images, the user must create an account with image-net.org
+# and generate a username and access_key. The latter two are required for
+# downloading the raw images.
+#
+# usage:
+# ./download_imagenet.sh [dir name] [synsets file]
+set -e
+
+if [ "x$IMAGENET_ACCESS_KEY" == x -o "x$IMAGENET_USERNAME" == x ]; then
+ cat <')
+ sys.exit(-1)
+ data_dir = sys.argv[1]
+ validation_labels_file = sys.argv[2]
+
+ # Read in the 50000 synsets associated with the validation data set.
+ labels = [l.strip() for l in open(validation_labels_file).readlines()]
+ unique_labels = set(labels)
+
+ # Make all sub-directories in the validation data dir.
+ for label in unique_labels:
+ labeled_data_dir = os.path.join(data_dir, label)
+ # Catch error if sub-directory exists
+ try:
+ os.makedirs(labeled_data_dir)
+ except OSError as e:
+ # Raise all errors but 'EEXIST'
+ if e.errno != errno.EEXIST:
+ raise
+
+  # Move all of the images to the appropriate sub-directory.
+ for i in range(len(labels)):
+ basename = 'ILSVRC2012_val_000%.5d.JPEG' % (i + 1)
+ original_filename = os.path.join(data_dir, basename)
+ if not os.path.exists(original_filename):
+ print('Failed to find: %s' % original_filename)
+ sys.exit(-1)
+ new_filename = os.path.join(data_dir, labels[i], basename)
+ os.rename(original_filename, new_filename)
diff --git a/models/research/inception/inception/data/process_bounding_boxes.py b/models/research/inception/inception/data/process_bounding_boxes.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e9fd786e40b6d95b89fcc9f9774aa7f132c1a6f
--- /dev/null
+++ b/models/research/inception/inception/data/process_bounding_boxes.py
@@ -0,0 +1,254 @@
+#!/usr/bin/python
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Process the ImageNet Challenge bounding boxes for TensorFlow model training.
+
+This script is called as
+
+process_bounding_boxes.py <dir> [synsets-file]
+
+Where <dir> is a directory containing the downloaded and unpacked bounding box
+data. If [synsets-file] is supplied, then only the bounding boxes whose
+synsets are contained within this file are returned. Note that the
+[synsets-file] file contains synset ids, one per line.
+
+The script dumps out a CSV text file in which each line contains an entry.
+ n00007846_64193.JPEG,0.0060,0.2620,0.7545,0.9940
+
+The entry can be read as:
+  <JPEG file name>, <xmin>, <ymin>, <xmax>, <ymax>
+
+The bounding box for <JPEG file name> contains two points (xmin, ymin) and
+(xmax, ymax) specifying the lower-left corner and upper-right corner of a
+bounding box in *relative* coordinates.
+
+The user supplies a directory where the XML files reside. The directory
+structure in the directory <dir> is assumed to look like this:
+
+  <dir>/nXXXXXXXX/nXXXXXXXX_YYYY.xml
+
+Each XML file contains a bounding box annotation. The script:
+
+ (1) Parses the XML file and extracts the filename, label and bounding box info.
+
+ (2) The bounding box is specified in the XML files as integer (xmin, ymin) and
+ (xmax, ymax) *relative* to image size displayed to the human annotator. The
+ size of the image displayed to the human annotator is stored in the XML file
+ as integer (height, width).
+
+ Note that the displayed size will differ from the actual size of the image
+ downloaded from image-net.org. To make the bounding box annotation useable,
+ we convert bounding box to floating point numbers relative to displayed
+ height and width of the image.
+
+ Note that each XML file might contain N bounding box annotations.
+
+ Note that the points are all clamped at a range of [0.0, 1.0] because some
+ human annotations extend outside the range of the supplied image.
+
+ See details here: http://image-net.org/download-bboxes
+
+(3) By default, the script outputs all valid bounding boxes. If a
+ [synsets-file] is supplied, only the subset of bounding boxes associated
+ with those synsets are outputted. Importantly, one can supply a list of
+ synsets in the ImageNet Challenge and output the list of bounding boxes
+ associated with the training images of the ILSVRC.
+
+ We use these bounding boxes to inform the random distortion of images
+ supplied to the network.
+
+If you run this script successfully, you will see the following output
+to stderr:
+> Finished processing 544546 XML files.
+> Skipped 0 XML files not in ImageNet Challenge.
+> Skipped 0 bounding boxes not in ImageNet Challenge.
+> Wrote 615299 bounding boxes from 544546 annotated images.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import glob
+import os.path
+import sys
+import xml.etree.ElementTree as ET
+
+
+class BoundingBox(object):
+ pass
+
+
+def GetItem(name, root, index=0):
+ count = 0
+ for item in root.iter(name):
+ if count == index:
+ return item.text
+ count += 1
+ # Failed to find "index" occurrence of item.
+ return -1
+
+
+def GetInt(name, root, index=0):
+ # In some XML annotation files, the point values are not integers but floats,
+ # so we parse them with float() first to avoid a ValueError.
+ return int(float(GetItem(name, root, index)))
+
+
+def FindNumberBoundingBoxes(root):
+ index = 0
+ while True:
+ if GetInt('xmin', root, index) == -1:
+ break
+ index += 1
+ return index
+
+
+def ProcessXMLAnnotation(xml_file):
+ """Process a single XML file containing a bounding box."""
+ # pylint: disable=broad-except
+ try:
+ tree = ET.parse(xml_file)
+ except Exception:
+ print('Failed to parse: ' + xml_file, file=sys.stderr)
+ return None
+ # pylint: enable=broad-except
+ root = tree.getroot()
+
+ num_boxes = FindNumberBoundingBoxes(root)
+ boxes = []
+
+ for index in range(num_boxes):
+ box = BoundingBox()
+ # Grab the 'index' annotation.
+ box.xmin = GetInt('xmin', root, index)
+ box.ymin = GetInt('ymin', root, index)
+ box.xmax = GetInt('xmax', root, index)
+ box.ymax = GetInt('ymax', root, index)
+
+ box.width = GetInt('width', root)
+ box.height = GetInt('height', root)
+ box.filename = GetItem('filename', root) + '.JPEG'
+ box.label = GetItem('name', root)
+
+ xmin = float(box.xmin) / float(box.width)
+ xmax = float(box.xmax) / float(box.width)
+ ymin = float(box.ymin) / float(box.height)
+ ymax = float(box.ymax) / float(box.height)
+
+ # Some images contain bounding box annotations that
+ # extend outside of the supplied image. See, e.g.
+ # n03127925/n03127925_147.xml
+ # Additionally, for some bounding boxes, the min > max
+ # or the box is entirely outside of the image.
+ min_x = min(xmin, xmax)
+ max_x = max(xmin, xmax)
+ box.xmin_scaled = min(max(min_x, 0.0), 1.0)
+ box.xmax_scaled = min(max(max_x, 0.0), 1.0)
+
+ min_y = min(ymin, ymax)
+ max_y = max(ymin, ymax)
+ box.ymin_scaled = min(max(min_y, 0.0), 1.0)
+ box.ymax_scaled = min(max(max_y, 0.0), 1.0)
+
+ boxes.append(box)
+
+ return boxes
+
+if __name__ == '__main__':
+ if len(sys.argv) < 2 or len(sys.argv) > 3:
+ print('Invalid usage\n'
+ 'usage: process_bounding_boxes.py <dir> [synsets-file]',
+ file=sys.stderr)
+ sys.exit(-1)
+
+ xml_files = glob.glob(sys.argv[1] + '/*/*.xml')
+ print('Identified %d XML files in %s' % (len(xml_files), sys.argv[1]),
+ file=sys.stderr)
+
+ if len(sys.argv) == 3:
+ labels = set([l.strip() for l in open(sys.argv[2]).readlines()])
+ print('Identified %d synset IDs in %s' % (len(labels), sys.argv[2]),
+ file=sys.stderr)
+ else:
+ labels = None
+
+ skipped_boxes = 0
+ skipped_files = 0
+ saved_boxes = 0
+ saved_files = 0
+ for file_index, one_file in enumerate(xml_files):
+ # Example: <...>/n06470073/n00141669_6790.xml
+ label = os.path.basename(os.path.dirname(one_file))
+
+ # Determine if the annotation is from an ImageNet Challenge label.
+ if labels is not None and label not in labels:
+ skipped_files += 1
+ continue
+
+ bboxes = ProcessXMLAnnotation(one_file)
+ assert bboxes is not None, 'No bounding boxes found in ' + one_file
+
+ found_box = False
+ for bbox in bboxes:
+ if labels is not None:
+ if bbox.label != label:
+ # Note: There is a slight bug in the bounding box annotation data.
+ # Many of the dog labels have the human label 'Scottish_deerhound'
+ # instead of the synset ID 'n02092002' in the bbox.label field. As a
+ # simple hack to overcome this issue, we only exclude bbox labels
+ # *which are synset IDs* that do not match the original synset label for
+ # the XML file.
+ if bbox.label in labels:
+ skipped_boxes += 1
+ continue
+
+ # Guard against improperly specified boxes.
+ if (bbox.xmin_scaled >= bbox.xmax_scaled or
+ bbox.ymin_scaled >= bbox.ymax_scaled):
+ skipped_boxes += 1
+ continue
+
+ # Note bbox.filename occasionally contains '%s' in the name. This is
+ # data set noise that is fixed by just using the basename of the XML file.
+ image_filename = os.path.splitext(os.path.basename(one_file))[0]
+ print('%s.JPEG,%.4f,%.4f,%.4f,%.4f' %
+ (image_filename,
+ bbox.xmin_scaled, bbox.ymin_scaled,
+ bbox.xmax_scaled, bbox.ymax_scaled))
+
+ saved_boxes += 1
+ found_box = True
+ if found_box:
+ saved_files += 1
+ else:
+ skipped_files += 1
+
+ if not file_index % 5000:
+ print('--> processed %d of %d XML files.' %
+ (file_index + 1, len(xml_files)),
+ file=sys.stderr)
+ print('--> skipped %d boxes and %d XML files.' %
+ (skipped_boxes, skipped_files), file=sys.stderr)
+
+ print('Finished processing %d XML files.' % len(xml_files), file=sys.stderr)
+ print('Skipped %d XML files not in ImageNet Challenge.' % skipped_files,
+ file=sys.stderr)
+ print('Skipped %d bounding boxes not in ImageNet Challenge.' % skipped_boxes,
+ file=sys.stderr)
+ print('Wrote %d bounding boxes from %d annotated images.' %
+ (saved_boxes, saved_files),
+ file=sys.stderr)
+ print('Finished.', file=sys.stderr)
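To make the coordinate handling above concrete, a minimal sketch (not part of the script) of the pixel-to-relative conversion plus the [0.0, 1.0] clamping performed in ProcessXMLAnnotation:

```python
def to_relative_clamped(xmin, ymin, xmax, ymax, width, height):
    """Convert pixel box coords to relative coords clamped to [0.0, 1.0]."""
    rel = lambda value, extent: min(max(float(value) / float(extent), 0.0), 1.0)
    x0, x1 = sorted((rel(xmin, width), rel(xmax, width)))   # also fixes min > max
    y0, y1 = sorted((rel(ymin, height), rel(ymax, height)))
    return x0, y0, x1, y1

# A toy annotation that spills past the edge of a 500x375 displayed image.
print(to_relative_clamped(3, 98, 520, 377, width=500, height=375))
# -> (0.006, 0.2613..., 1.0, 1.0)
```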
diff --git a/models/research/inception/inception/dataset.py b/models/research/inception/inception/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..752c97e03b0361975d64b72892cc94333e353dfb
--- /dev/null
+++ b/models/research/inception/inception/dataset.py
@@ -0,0 +1,103 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Small library that points to a data set.
+
+Methods of Data class:
+ data_files: Returns a python list of all (sharded) data set files.
+ num_examples_per_epoch: Returns the number of examples in the data set.
+ num_classes: Returns the number of classes in the data set.
+ reader: Return a reader for a single entry from the data set.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from abc import ABCMeta
+from abc import abstractmethod
+import os
+
+
+import tensorflow as tf
+
+FLAGS = tf.app.flags.FLAGS
+
+# Basic model parameters.
+tf.app.flags.DEFINE_string('data_dir', '/tmp/mydata',
+ """Path to the processed data, i.e. """
+ """TFRecord of Example protos.""")
+
+
+class Dataset(object):
+ """A simple class for handling data sets."""
+ __metaclass__ = ABCMeta
+
+ def __init__(self, name, subset):
+ """Initialize dataset using a subset and the path to the data."""
+ assert subset in self.available_subsets(), self.available_subsets()
+ self.name = name
+ self.subset = subset
+
+ @abstractmethod
+ def num_classes(self):
+ """Returns the number of classes in the data set."""
+ pass
+ # return 10
+
+ @abstractmethod
+ def num_examples_per_epoch(self):
+ """Returns the number of examples in the data subset."""
+ pass
+ # if self.subset == 'train':
+ # return 10000
+ # if self.subset == 'validation':
+ # return 1000
+
+ @abstractmethod
+ def download_message(self):
+ """Prints a download message for the Dataset."""
+ pass
+
+ def available_subsets(self):
+ """Returns the list of available subsets."""
+ return ['train', 'validation']
+
+ def data_files(self):
+ """Returns a python list of all (sharded) data subset files.
+
+ Returns:
+ python list of all (sharded) data set files.
+ Raises:
+ ValueError: if there are no data_files matching the subset.
+ """
+ tf_record_pattern = os.path.join(FLAGS.data_dir, '%s-*' % self.subset)
+ data_files = tf.gfile.Glob(tf_record_pattern)
+ if not data_files:
+ print('No files found for dataset %s/%s at %s' % (self.name,
+ self.subset,
+ FLAGS.data_dir))
+
+ self.download_message()
+ exit(-1)
+ return data_files
+
+ def reader(self):
+ """Return a reader for a single entry from the data set.
+
+ See io_ops.py for details of Reader class.
+
+ Returns:
+ Reader object that reads the data set.
+ """
+ return tf.TFRecordReader()
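The shard lookup in `data_files()` amounts to a glob over `<data_dir>/<subset>-*`. A self-contained sketch of that pattern with fake shard files (hypothetical names; the real shards come from the image preprocessing step):

```python
import glob
import os
import tempfile

# Create a few empty files that mimic sharded TFRecord names.
data_dir = tempfile.mkdtemp()
for name in ('train-00000-of-00002', 'train-00001-of-00002',
             'validation-00000-of-00001'):
    open(os.path.join(data_dir, name), 'w').close()

subset = 'train'
tf_record_pattern = os.path.join(data_dir, '%s-*' % subset)
print(sorted(glob.glob(tf_record_pattern)))  # matches only the two train-* shards
```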
diff --git a/models/research/inception/inception/flowers_data.py b/models/research/inception/inception/flowers_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..022b5234deef035a6150a54ed74445b510f1b148
--- /dev/null
+++ b/models/research/inception/inception/flowers_data.py
@@ -0,0 +1,52 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Small library that points to the flowers data set.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+
+from inception.dataset import Dataset
+
+
+class FlowersData(Dataset):
+ """Flowers data set."""
+
+ def __init__(self, subset):
+ super(FlowersData, self).__init__('Flowers', subset)
+
+ def num_classes(self):
+ """Returns the number of classes in the data set."""
+ return 5
+
+ def num_examples_per_epoch(self):
+ """Returns the number of examples in the data subset."""
+ if self.subset == 'train':
+ return 3170
+ if self.subset == 'validation':
+ return 500
+
+ def download_message(self):
+ """Instruction to download and extract the tarball from Flowers website."""
+
+ print('Failed to find any Flowers %s files'% self.subset)
+ print('')
+ print('If you have already downloaded and processed the data, then make '
+ 'sure to set --data_dir to point to the directory containing the '
+ 'location of the sharded TFRecords.\n')
+ print('Please see README.md for instructions on how to build '
+ 'the flowers dataset using download_and_preprocess_flowers.\n')
diff --git a/models/research/inception/inception/flowers_eval.py b/models/research/inception/inception/flowers_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae3e9dc14c8dc83368aa83f523ade92e12113554
--- /dev/null
+++ b/models/research/inception/inception/flowers_eval.py
@@ -0,0 +1,40 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A binary to evaluate Inception on the flowers data set.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+from inception import inception_eval
+from inception.flowers_data import FlowersData
+
+FLAGS = tf.app.flags.FLAGS
+
+
+def main(unused_argv=None):
+ dataset = FlowersData(subset=FLAGS.subset)
+ assert dataset.data_files()
+ if tf.gfile.Exists(FLAGS.eval_dir):
+ tf.gfile.DeleteRecursively(FLAGS.eval_dir)
+ tf.gfile.MakeDirs(FLAGS.eval_dir)
+ inception_eval.evaluate(dataset)
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/inception/inception/flowers_train.py b/models/research/inception/inception/flowers_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f044a539d48ef6ce011831210b4bc31eba278f3
--- /dev/null
+++ b/models/research/inception/inception/flowers_train.py
@@ -0,0 +1,41 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A binary to train Inception on the flowers data set.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+
+import tensorflow as tf
+
+from inception import inception_train
+from inception.flowers_data import FlowersData
+
+FLAGS = tf.app.flags.FLAGS
+
+
+def main(_):
+ dataset = FlowersData(subset=FLAGS.subset)
+ assert dataset.data_files()
+ if tf.gfile.Exists(FLAGS.train_dir):
+ tf.gfile.DeleteRecursively(FLAGS.train_dir)
+ tf.gfile.MakeDirs(FLAGS.train_dir)
+ inception_train.train(dataset)
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/inception/inception/image_processing.py b/models/research/inception/inception/image_processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe74f1b3c9958060b15f52df80b11606c7ccf343
--- /dev/null
+++ b/models/research/inception/inception/image_processing.py
@@ -0,0 +1,513 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Read and preprocess image data.
+
+ Image processing occurs on a single image at a time. Image are read and
+ preprocessed in parallel across multiple threads. The resulting images
+ are concatenated together to form a single batch for training or evaluation.
+
+ -- Provide processed image data for a network:
+ inputs: Construct batches of evaluation examples of images.
+ distorted_inputs: Construct batches of training examples of images.
+ batch_inputs: Construct batches of training or evaluation examples of images.
+
+ -- Data processing:
+ parse_example_proto: Parses an Example proto containing a training example
+ of an image.
+
+ -- Image decoding:
+ decode_jpeg: Decode a JPEG encoded string into a 3-D float32 Tensor.
+
+ -- Image preprocessing:
+ image_preprocessing: Decode and preprocess one image for evaluation or training
+ distort_image: Distort one image for training a network.
+ eval_image: Prepare one image for evaluation.
+ distort_color: Distort the color in one image for training.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_integer('batch_size', 32,
+ """Number of images to process in a batch.""")
+tf.app.flags.DEFINE_integer('image_size', 299,
+ """Provide square images of this size.""")
+tf.app.flags.DEFINE_integer('num_preprocess_threads', 4,
+ """Number of preprocessing threads per tower. """
+ """Please make this a multiple of 4.""")
+tf.app.flags.DEFINE_integer('num_readers', 4,
+ """Number of parallel readers during train.""")
+
+# Images are preprocessed asynchronously using multiple threads specified by
+# --num_preprocess_threads and the resulting processed images are stored in a
+# random shuffling queue. The shuffling queue dequeues --batch_size images
+# for processing on a given Inception tower. A larger shuffling queue guarantees
+# better mixing across examples within a batch and results in slightly higher
+# predictive performance in a trained model. Empirically,
+# --input_queue_memory_factor=16 works well. A value of 16 implies a queue size
+# of 1024*16 images. Assuming RGB 299x299 images, this implies a queue size of
+# 16GB. If the machine is memory limited, then decrease this factor to
+# decrease the CPU memory footprint, accordingly.
+tf.app.flags.DEFINE_integer('input_queue_memory_factor', 16,
+ """Size of the queue of preprocessed images. """
+ """Default is ideal but try smaller values, e.g. """
+ """4, 2 or 1, if host memory is constrained. See """
+ """comments in code for more details.""")
+
+
+def inputs(dataset, batch_size=None, num_preprocess_threads=None):
+ """Generate batches of ImageNet images for evaluation.
+
+ Use this function as the inputs for evaluating a network.
+
+ Note that some (minimal) image preprocessing occurs during evaluation
+ including central cropping and resizing of the image to fit the network.
+
+ Args:
+ dataset: instance of Dataset class specifying the dataset.
+ batch_size: integer, number of examples in batch
+ num_preprocess_threads: integer, total number of preprocessing threads;
+ if None, defaults to FLAGS.num_preprocess_threads.
+
+ Returns:
+ images: Images. 4D tensor of size [batch_size, FLAGS.image_size,
+ FLAGS.image_size, 3].
+ labels: 1-D integer Tensor of [batch_size].
+ """
+ if not batch_size:
+ batch_size = FLAGS.batch_size
+
+ # Force all input processing onto CPU in order to reserve the GPU for
+ # the forward inference and back-propagation.
+ with tf.device('/cpu:0'):
+ images, labels = batch_inputs(
+ dataset, batch_size, train=False,
+ num_preprocess_threads=num_preprocess_threads,
+ num_readers=1)
+
+ return images, labels
+
+
+def distorted_inputs(dataset, batch_size=None, num_preprocess_threads=None):
+ """Generate batches of distorted versions of ImageNet images.
+
+ Use this function as the inputs for training a network.
+
+ Distorting images provides a useful technique for augmenting the data
+ set during training in order to make the network invariant to aspects
+ of the image that do not affect the label.
+
+ Args:
+ dataset: instance of Dataset class specifying the dataset.
+ batch_size: integer, number of examples in batch
+ num_preprocess_threads: integer, total number of preprocessing threads;
+ if None, defaults to FLAGS.num_preprocess_threads.
+
+ Returns:
+ images: Images. 4D tensor of size [batch_size, FLAGS.image_size,
+ FLAGS.image_size, 3].
+ labels: 1-D integer Tensor of [batch_size].
+ """
+ if not batch_size:
+ batch_size = FLAGS.batch_size
+
+ # Force all input processing onto CPU in order to reserve the GPU for
+ # the forward inference and back-propagation.
+ with tf.device('/cpu:0'):
+ images, labels = batch_inputs(
+ dataset, batch_size, train=True,
+ num_preprocess_threads=num_preprocess_threads,
+ num_readers=FLAGS.num_readers)
+ return images, labels
+
+
+def decode_jpeg(image_buffer, scope=None):
+ """Decode a JPEG string into one 3-D float image Tensor.
+
+ Args:
+ image_buffer: scalar string Tensor.
+ scope: Optional scope for name_scope.
+ Returns:
+ 3-D float Tensor with values ranging from [0, 1).
+ """
+ with tf.name_scope(values=[image_buffer], name=scope,
+ default_name='decode_jpeg'):
+ # Decode the string as an RGB JPEG.
+ # Note that the resulting image contains an unknown height and width
+ # that is set dynamically by decode_jpeg. In other words, the height
+ # and width of image is unknown at compile-time.
+ image = tf.image.decode_jpeg(image_buffer, channels=3)
+
+ # After this point, all image pixels reside in [0,1)
+ # until the very end, when they're rescaled to (-1, 1). The various
+ # adjust_* ops all require this range for dtype float.
+ image = tf.image.convert_image_dtype(image, dtype=tf.float32)
+ return image
+
+
+def distort_color(image, thread_id=0, scope=None):
+ """Distort the color of the image.
+
+ Each color distortion is non-commutative and thus ordering of the color ops
+ matters. Ideally we would randomly permute the ordering of the color ops.
+ Rather than adding that level of complication, we select a distinct ordering
+ of color ops for each preprocessing thread.
+
+ Args:
+ image: Tensor containing single image.
+ thread_id: preprocessing thread ID.
+ scope: Optional scope for name_scope.
+ Returns:
+ color-distorted image
+ """
+ with tf.name_scope(values=[image], name=scope, default_name='distort_color'):
+ color_ordering = thread_id % 2
+
+ if color_ordering == 0:
+ image = tf.image.random_brightness(image, max_delta=32. / 255.)
+ image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
+ image = tf.image.random_hue(image, max_delta=0.2)
+ image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
+ elif color_ordering == 1:
+ image = tf.image.random_brightness(image, max_delta=32. / 255.)
+ image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
+ image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
+ image = tf.image.random_hue(image, max_delta=0.2)
+
+ # The random_* ops do not necessarily clamp.
+ image = tf.clip_by_value(image, 0.0, 1.0)
+ return image
+
+
+def distort_image(image, height, width, bbox, thread_id=0, scope=None):
+ """Distort one image for training a network.
+
+ Distorting images provides a useful technique for augmenting the data
+ set during training in order to make the network invariant to aspects
+ of the image that do not affect the label.
+
+ Args:
+ image: 3-D float Tensor of image
+ height: integer
+ width: integer
+ bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
+ where each coordinate is [0, 1) and the coordinates are arranged
+ as [ymin, xmin, ymax, xmax].
+ thread_id: integer indicating the preprocessing thread.
+ scope: Optional scope for name_scope.
+ Returns:
+ 3-D float Tensor of distorted image used for training.
+ """
+ with tf.name_scope(values=[image, height, width, bbox], name=scope,
+ default_name='distort_image'):
+ # Each bounding box has shape [1, num_boxes, box coords] and
+ # the coordinates are ordered [ymin, xmin, ymax, xmax].
+
+ # Display the bounding box in the first thread only.
+ if not thread_id:
+ image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
+ bbox)
+ tf.summary.image('image_with_bounding_boxes', image_with_box)
+
+ # A large fraction of image datasets contain a human-annotated bounding
+ # box delineating the region of the image containing the object of interest.
+ # We choose to create a new bounding box for the object which is a randomly
+ # distorted version of the human-annotated bounding box that obeys an allowed
+ # range of aspect ratios, sizes and overlap with the human-annotated
+ # bounding box. If no box is supplied, then we assume the bounding box is
+ # the entire image.
+ sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
+ tf.shape(image),
+ bounding_boxes=bbox,
+ min_object_covered=0.1,
+ aspect_ratio_range=[0.75, 1.33],
+ area_range=[0.05, 1.0],
+ max_attempts=100,
+ use_image_if_no_bounding_boxes=True)
+ bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box
+ if not thread_id:
+ image_with_distorted_box = tf.image.draw_bounding_boxes(
+ tf.expand_dims(image, 0), distort_bbox)
+ tf.summary.image('images_with_distorted_bounding_box',
+ image_with_distorted_box)
+
+ # Crop the image to the specified bounding box.
+ distorted_image = tf.slice(image, bbox_begin, bbox_size)
+
+ # This resizing operation may distort the images because the aspect
+ # ratio is not respected. We select a resize method in a round robin
+ # fashion based on the thread number.
+ # Note that ResizeMethod contains 4 enumerated resizing methods.
+ resize_method = thread_id % 4
+ distorted_image = tf.image.resize_images(distorted_image, [height, width],
+ method=resize_method)
+ # Restore the shape since the dynamic slice based upon the bbox_size loses
+ # the third dimension.
+ distorted_image.set_shape([height, width, 3])
+ if not thread_id:
+ tf.summary.image('cropped_resized_image',
+ tf.expand_dims(distorted_image, 0))
+
+ # Randomly flip the image horizontally.
+ distorted_image = tf.image.random_flip_left_right(distorted_image)
+
+ # Randomly distort the colors.
+ distorted_image = distort_color(distorted_image, thread_id)
+
+ if not thread_id:
+ tf.summary.image('final_distorted_image',
+ tf.expand_dims(distorted_image, 0))
+ return distorted_image
+
+
+def eval_image(image, height, width, scope=None):
+ """Prepare one image for evaluation.
+
+ Args:
+ image: 3-D float Tensor
+ height: integer
+ width: integer
+ scope: Optional scope for name_scope.
+ Returns:
+ 3-D float Tensor of prepared image.
+ """
+ with tf.name_scope(values=[image, height, width], name=scope,
+ default_name='eval_image'):
+ # Crop the central region of the image with an area containing 87.5% of
+ # the original image.
+ image = tf.image.central_crop(image, central_fraction=0.875)
+
+ # Resize the image to the original height and width.
+ image = tf.expand_dims(image, 0)
+ image = tf.image.resize_bilinear(image, [height, width],
+ align_corners=False)
+ image = tf.squeeze(image, [0])
+ return image
+
+
+def image_preprocessing(image_buffer, bbox, train, thread_id=0):
+ """Decode and preprocess one image for evaluation or training.
+
+ Args:
+ image_buffer: JPEG encoded string Tensor
+ bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
+ where each coordinate is [0, 1) and the coordinates are arranged as
+ [ymin, xmin, ymax, xmax].
+ train: boolean
+ thread_id: integer indicating preprocessing thread
+
+ Returns:
+ 3-D float Tensor containing an appropriately scaled image
+
+ Raises:
+ ValueError: if user does not provide bounding box
+ """
+ if bbox is None:
+ raise ValueError('Please supply a bounding box.')
+
+ image = decode_jpeg(image_buffer)
+ height = FLAGS.image_size
+ width = FLAGS.image_size
+
+ if train:
+ image = distort_image(image, height, width, bbox, thread_id)
+ else:
+ image = eval_image(image, height, width)
+
+ # Finally, rescale to [-1,1] instead of [0, 1)
+ image = tf.subtract(image, 0.5)
+ image = tf.multiply(image, 2.0)
+ return image
+
+
+def parse_example_proto(example_serialized):
+ """Parses an Example proto containing a training example of an image.
+
+ The output of the build_image_data.py image preprocessing script is a dataset
+ containing serialized Example protocol buffers. Each Example proto contains
+ the following fields:
+
+ image/height: 462
+ image/width: 581
+ image/colorspace: 'RGB'
+ image/channels: 3
+ image/class/label: 615
+ image/class/synset: 'n03623198'
+ image/class/text: 'knee pad'
+ image/object/bbox/xmin: 0.1
+ image/object/bbox/xmax: 0.9
+ image/object/bbox/ymin: 0.2
+ image/object/bbox/ymax: 0.6
+ image/object/bbox/label: 615
+ image/format: 'JPEG'
+ image/filename: 'ILSVRC2012_val_00041207.JPEG'
+ image/encoded: <JPEG encoded string>
+
+ Args:
+ example_serialized: scalar Tensor tf.string containing a serialized
+ Example protocol buffer.
+
+ Returns:
+ image_buffer: Tensor tf.string containing the contents of a JPEG file.
+ label: Tensor tf.int32 containing the label.
+ bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
+ where each coordinate is [0, 1) and the coordinates are arranged as
+ [ymin, xmin, ymax, xmax].
+ text: Tensor tf.string containing the human-readable label.
+ """
+ # Dense features in Example proto.
+ feature_map = {
+ 'image/encoded': tf.FixedLenFeature([], dtype=tf.string,
+ default_value=''),
+ 'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64,
+ default_value=-1),
+ 'image/class/text': tf.FixedLenFeature([], dtype=tf.string,
+ default_value=''),
+ }
+ sparse_float32 = tf.VarLenFeature(dtype=tf.float32)
+ # Sparse features in Example proto.
+ feature_map.update(
+ {k: sparse_float32 for k in ['image/object/bbox/xmin',
+ 'image/object/bbox/ymin',
+ 'image/object/bbox/xmax',
+ 'image/object/bbox/ymax']})
+
+ features = tf.parse_single_example(example_serialized, feature_map)
+ label = tf.cast(features['image/class/label'], dtype=tf.int32)
+
+ xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0)
+ ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0)
+ xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0)
+ ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0)
+
+ # Note that we impose an ordering of (y, x) just to make life difficult.
+ bbox = tf.concat(axis=0, values=[ymin, xmin, ymax, xmax])
+
+ # Force the variable number of bounding boxes into the shape
+ # [1, num_boxes, coords].
+ bbox = tf.expand_dims(bbox, 0)
+ bbox = tf.transpose(bbox, [0, 2, 1])
+
+ return features['image/encoded'], label, bbox, features['image/class/text']
+
+
+def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None,
+ num_readers=1):
+ """Contruct batches of training or evaluation examples from the image dataset.
+
+ Args:
+ dataset: instance of Dataset class specifying the dataset.
+ See dataset.py for details.
+ batch_size: integer
+ train: boolean
+ num_preprocess_threads: integer, total number of preprocessing threads
+ num_readers: integer, number of parallel readers
+
+ Returns:
+ images: 4-D float Tensor of a batch of images
+ labels: 1-D integer Tensor of [batch_size].
+
+ Raises:
+ ValueError: if data is not found
+ """
+ with tf.name_scope('batch_processing'):
+ data_files = dataset.data_files()
+ if data_files is None:
+ raise ValueError('No data files found for this dataset')
+
+ # Create filename_queue
+ if train:
+ filename_queue = tf.train.string_input_producer(data_files,
+ shuffle=True,
+ capacity=16)
+ else:
+ filename_queue = tf.train.string_input_producer(data_files,
+ shuffle=False,
+ capacity=1)
+ if num_preprocess_threads is None:
+ num_preprocess_threads = FLAGS.num_preprocess_threads
+
+ if num_preprocess_threads % 4:
+ raise ValueError('Please make num_preprocess_threads a multiple '
+ 'of 4 (%d %% 4 != 0).' % num_preprocess_threads)
+
+ if num_readers is None:
+ num_readers = FLAGS.num_readers
+
+ if num_readers < 1:
+ raise ValueError('Please make num_readers at least 1')
+
+ # Approximate number of examples per shard.
+ examples_per_shard = 1024
+ # Size the random shuffle queue to balance between good global
+ # mixing (more examples) and memory use (fewer examples).
+ # 1 image uses 299*299*3*4 bytes = 1MB
+ # The default input_queue_memory_factor is 16 implying a shuffling queue
+ # size: examples_per_shard * 16 * 1MB = 17.6GB
+ min_queue_examples = examples_per_shard * FLAGS.input_queue_memory_factor
+ if train:
+ examples_queue = tf.RandomShuffleQueue(
+ capacity=min_queue_examples + 3 * batch_size,
+ min_after_dequeue=min_queue_examples,
+ dtypes=[tf.string])
+ else:
+ examples_queue = tf.FIFOQueue(
+ capacity=examples_per_shard + 3 * batch_size,
+ dtypes=[tf.string])
+
+ # Create multiple readers to populate the queue of examples.
+ if num_readers > 1:
+ enqueue_ops = []
+ for _ in range(num_readers):
+ reader = dataset.reader()
+ _, value = reader.read(filename_queue)
+ enqueue_ops.append(examples_queue.enqueue([value]))
+
+ tf.train.queue_runner.add_queue_runner(
+ tf.train.queue_runner.QueueRunner(examples_queue, enqueue_ops))
+ example_serialized = examples_queue.dequeue()
+ else:
+ reader = dataset.reader()
+ _, example_serialized = reader.read(filename_queue)
+
+ images_and_labels = []
+ for thread_id in range(num_preprocess_threads):
+ # Parse a serialized Example proto to extract the image and metadata.
+ image_buffer, label_index, bbox, _ = parse_example_proto(
+ example_serialized)
+ image = image_preprocessing(image_buffer, bbox, train, thread_id)
+ images_and_labels.append([image, label_index])
+
+ images, label_index_batch = tf.train.batch_join(
+ images_and_labels,
+ batch_size=batch_size,
+ capacity=2 * num_preprocess_threads * batch_size)
+
+ # Reshape images into these desired dimensions.
+ height = FLAGS.image_size
+ width = FLAGS.image_size
+ depth = 3
+
+ images = tf.cast(images, tf.float32)
+ images = tf.reshape(images, shape=[batch_size, height, width, depth])
+
+ # Display the training images in the visualizer.
+ tf.summary.image('images', images)
+
+ return images, tf.reshape(label_index_batch, [batch_size])
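To make the `--input_queue_memory_factor` comment concrete, a back-of-the-envelope sketch (plain Python, using the 299x299x3 float32 size assumed above) of the shuffling-queue memory footprint:

```python
def queue_memory_gb(input_queue_memory_factor, examples_per_shard=1024,
                    image_size=299, channels=3, bytes_per_value=4):
    """Approximate host memory held by the random shuffling queue, in decimal GB."""
    bytes_per_image = image_size * image_size * channels * bytes_per_value
    min_queue_examples = examples_per_shard * input_queue_memory_factor
    return min_queue_examples * bytes_per_image / 1e9

print('factor=16: ~%.1f GB' % queue_memory_gb(16))  # ~17.6 GB, as quoted above
print('factor=1:  ~%.1f GB' % queue_memory_gb(1))   # for memory-constrained hosts
```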
diff --git a/models/research/inception/inception/imagenet_data.py b/models/research/inception/inception/imagenet_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a6d22e1292632f0899355d5aa7183c3f5f33b2c
--- /dev/null
+++ b/models/research/inception/inception/imagenet_data.py
@@ -0,0 +1,59 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Small library that points to the ImageNet data set.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+
+from inception.dataset import Dataset
+
+
+class ImagenetData(Dataset):
+ """ImageNet data set."""
+
+ def __init__(self, subset):
+ super(ImagenetData, self).__init__('ImageNet', subset)
+
+ def num_classes(self):
+ """Returns the number of classes in the data set."""
+ return 1000
+
+ def num_examples_per_epoch(self):
+ """Returns the number of examples in the data set."""
+ # Bounding box data consists of 615299 bounding boxes for 544546 images.
+ if self.subset == 'train':
+ return 1281167
+ if self.subset == 'validation':
+ return 50000
+
+ def download_message(self):
+ """Instruction to download and extract the tarball from Flowers website."""
+
+ print('Failed to find any ImageNet %s files'% self.subset)
+ print('')
+ print('If you have already downloaded and processed the data, then make '
+ 'sure to set --data_dir to point to the directory containing the '
+ 'location of the sharded TFRecords.\n')
+ print('If you have not downloaded and prepared the ImageNet data in the '
+ 'TFRecord format, you will need to do this at least once. This '
+ 'process could take several hours depending on the speed of your '
+ 'computer and network connection\n')
+ print('Please see README.md for instructions on how to build '
+ 'the ImageNet dataset using download_and_preprocess_imagenet.\n')
+ print('Note that the raw data size is 300 GB and the processed data size '
+ 'is 150 GB. Please ensure you have at least 500GB disk space.')
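A quick arithmetic sketch of how the example counts above translate into batches per epoch (the batch size here is just the --batch_size default from image_processing.py; adjust for your own run):

```python
train_examples = 1281167     # ImagenetData('train').num_examples_per_epoch()
validation_examples = 50000  # ImagenetData('validation').num_examples_per_epoch()
batch_size = 32              # default --batch_size

print('~%d training batches per epoch' % (train_examples // batch_size))
print('%d validation batches cover the full eval set'
      % -(-validation_examples // batch_size))  # ceiling division
```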
diff --git a/models/research/inception/inception/imagenet_distributed_train.py b/models/research/inception/inception/imagenet_distributed_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3615e012f042649b52e37aeaeeb2c3efc07f92c
--- /dev/null
+++ b/models/research/inception/inception/imagenet_distributed_train.py
@@ -0,0 +1,66 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=line-too-long
+"""A binary to train Inception in a distributed manner using multiple systems.
+
+Please see accompanying README.md for details and instructions.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from inception import inception_distributed_train
+from inception.imagenet_data import ImagenetData
+
+FLAGS = tf.app.flags.FLAGS
+
+
+def main(unused_args):
+ assert FLAGS.job_name in ['ps', 'worker'], 'job_name must be ps or worker'
+
+ # Extract all the hostnames for the ps and worker jobs to construct the
+ # cluster spec.
+ ps_hosts = FLAGS.ps_hosts.split(',')
+ worker_hosts = FLAGS.worker_hosts.split(',')
+ tf.logging.info('PS hosts are: %s' % ps_hosts)
+ tf.logging.info('Worker hosts are: %s' % worker_hosts)
+
+ cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts,
+ 'worker': worker_hosts})
+ server = tf.train.Server(
+ {'ps': ps_hosts,
+ 'worker': worker_hosts},
+ job_name=FLAGS.job_name,
+ task_index=FLAGS.task_id,
+ protocol=FLAGS.protocol)
+
+ if FLAGS.job_name == 'ps':
+ # `ps` jobs wait for incoming connections from the workers.
+ server.join()
+ else:
+ # `worker` jobs will actually do the work.
+ dataset = ImagenetData(subset=FLAGS.subset)
+ assert dataset.data_files()
+ # Only the chief checks for or creates train_dir.
+ if FLAGS.task_id == 0:
+ if not tf.gfile.Exists(FLAGS.train_dir):
+ tf.gfile.MakeDirs(FLAGS.train_dir)
+ inception_distributed_train.train(server.target, dataset, cluster_spec)
+
+if __name__ == '__main__':
+ tf.logging.set_verbosity(tf.logging.INFO)
+ tf.app.run()
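For orientation, a minimal sketch (hypothetical hostnames) of how the comma-separated --ps_hosts/--worker_hosts flags above turn into the cluster description handed to tf.train.ClusterSpec and tf.train.Server:

```python
# Hypothetical flag values for a 1-ps / 2-worker job.
ps_hosts_flag = 'machine1:2222'
worker_hosts_flag = 'machine2:2222,machine3:2222'

cluster = {'ps': ps_hosts_flag.split(','),
           'worker': worker_hosts_flag.split(',')}
print(cluster)
# {'ps': ['machine1:2222'], 'worker': ['machine2:2222', 'machine3:2222']}
# Each process is launched with --job_name and --task_id selecting its own
# entry, e.g. --job_name=worker --task_id=1 corresponds to machine3:2222.
```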
diff --git a/models/research/inception/inception/imagenet_eval.py b/models/research/inception/inception/imagenet_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6f8bac2ee71021914715172296d63dd56b5a6f9
--- /dev/null
+++ b/models/research/inception/inception/imagenet_eval.py
@@ -0,0 +1,46 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A binary to evaluate Inception on the ImageNet data set.
+
+Note that using the supplied pre-trained inception checkpoint, the eval should
+achieve:
+ precision @ 1 = 0.7874 recall @ 5 = 0.9436 [50000 examples]
+
+See the README.md for more details.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+from inception import inception_eval
+from inception.imagenet_data import ImagenetData
+
+FLAGS = tf.app.flags.FLAGS
+
+
+def main(unused_argv=None):
+ dataset = ImagenetData(subset=FLAGS.subset)
+ assert dataset.data_files()
+ if tf.gfile.Exists(FLAGS.eval_dir):
+ tf.gfile.DeleteRecursively(FLAGS.eval_dir)
+ tf.gfile.MakeDirs(FLAGS.eval_dir)
+ inception_eval.evaluate(dataset)
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/inception/inception/imagenet_train.py b/models/research/inception/inception/imagenet_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ffb55ee963e5b9f8e31915a78eef518324642aa
--- /dev/null
+++ b/models/research/inception/inception/imagenet_train.py
@@ -0,0 +1,41 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A binary to train Inception on the ImageNet data set.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+
+import tensorflow as tf
+
+from inception import inception_train
+from inception.imagenet_data import ImagenetData
+
+FLAGS = tf.app.flags.FLAGS
+
+
+def main(_):
+ dataset = ImagenetData(subset=FLAGS.subset)
+ assert dataset.data_files()
+ if tf.gfile.Exists(FLAGS.train_dir):
+ tf.gfile.DeleteRecursively(FLAGS.train_dir)
+ tf.gfile.MakeDirs(FLAGS.train_dir)
+ inception_train.train(dataset)
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/inception/inception/inception_distributed_train.py b/models/research/inception/inception/inception_distributed_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1a589acb5fe386fd648ae3fae926ee927c0ca79
--- /dev/null
+++ b/models/research/inception/inception/inception_distributed_train.py
@@ -0,0 +1,314 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A library to train Inception using multiple replicas with synchronous update.
+
+Please see accompanying README.md for details and instructions.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from datetime import datetime
+import os.path
+import time
+
+import numpy as np
+import tensorflow as tf
+
+from inception import image_processing
+from inception import inception_model as inception
+from inception.slim import slim
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string('job_name', '', 'One of "ps", "worker"')
+tf.app.flags.DEFINE_string('ps_hosts', '',
+ """Comma-separated list of hostname:port for the """
+ """parameter server jobs. e.g. """
+ """'machine1:2222,machine2:1111,machine2:2222'""")
+tf.app.flags.DEFINE_string('worker_hosts', '',
+ """Comma-separated list of hostname:port for the """
+ """worker jobs. e.g. """
+ """'machine1:2222,machine2:1111,machine2:2222'""")
+tf.app.flags.DEFINE_string('protocol', 'grpc',
+ """Communication protocol to use in distributed """
+ """execution (default grpc) """)
+
+tf.app.flags.DEFINE_string('train_dir', '/tmp/imagenet_train',
+ """Directory where to write event logs """
+ """and checkpoint.""")
+tf.app.flags.DEFINE_integer('max_steps', 1000000, 'Number of batches to run.')
+tf.app.flags.DEFINE_string('subset', 'train', 'Either "train" or "validation".')
+tf.app.flags.DEFINE_boolean('log_device_placement', False,
+ 'Whether to log device placement.')
+
+# Task ID is used to select the chief and also to access the local_step for
+# each replica to check staleness of the gradients in SyncReplicasOptimizer.
+tf.app.flags.DEFINE_integer(
+ 'task_id', 0, 'Task ID of the worker/replica running the training.')
+
+# More details can be found in the SyncReplicasOptimizer class:
+# tensorflow/python/training/sync_replicas_optimizer.py
+tf.app.flags.DEFINE_integer('num_replicas_to_aggregate', -1,
+ """Number of gradients to collect before """
+ """updating the parameters.""")
+tf.app.flags.DEFINE_integer('save_interval_secs', 10 * 60,
+ 'Save interval seconds.')
+tf.app.flags.DEFINE_integer('save_summaries_secs', 180,
+ 'Save summaries interval seconds.')
+
+# **IMPORTANT**
+# Please note that this learning rate schedule is heavily dependent on the
+# hardware architecture, batch size and any changes to the model architecture
+# specification. Selecting a finely tuned learning rate schedule is an
+# empirical process that requires some experimentation. Please see README.md
+# for more guidance and discussion.
+#
+# Learning rate decay factor selected from https://arxiv.org/abs/1604.00981
+tf.app.flags.DEFINE_float('initial_learning_rate', 0.045,
+ 'Initial learning rate.')
+tf.app.flags.DEFINE_float('num_epochs_per_decay', 2.0,
+ 'Epochs after which learning rate decays.')
+tf.app.flags.DEFINE_float('learning_rate_decay_factor', 0.94,
+ 'Learning rate decay factor.')
+
+# Constants dictating the learning rate schedule.
+RMSPROP_DECAY = 0.9 # Decay term for RMSProp.
+RMSPROP_MOMENTUM = 0.9 # Momentum in RMSProp.
+RMSPROP_EPSILON = 1.0 # Epsilon term for RMSProp.
+
+
+def train(target, dataset, cluster_spec):
+ """Train Inception on a dataset for a number of steps."""
+ # Number of workers and parameter servers are inferred from the workers and ps
+ # hosts string.
+ num_workers = len(cluster_spec.as_dict()['worker'])
+ num_parameter_servers = len(cluster_spec.as_dict()['ps'])
+ # If no value is given, num_replicas_to_aggregate defaults to be the number of
+ # workers.
+ if FLAGS.num_replicas_to_aggregate == -1:
+ num_replicas_to_aggregate = num_workers
+ else:
+ num_replicas_to_aggregate = FLAGS.num_replicas_to_aggregate
+
+ # Both should be greater than 0 in a distributed training.
+ assert num_workers > 0 and num_parameter_servers > 0, (' num_workers and '
+ 'num_parameter_servers'
+ ' must be > 0.')
+
+ # Choose worker 0 as the chief. Note that any worker could be the chief
+ # but there should be only one chief.
+ is_chief = (FLAGS.task_id == 0)
+
+ # Ops are assigned to worker by default.
+ with tf.device('/job:worker/task:%d' % FLAGS.task_id):
+ # Variables and its related init/assign ops are assigned to ps.
+ with slim.scopes.arg_scope(
+ [slim.variables.variable, slim.variables.global_step],
+ device=slim.variables.VariableDeviceChooser(num_parameter_servers)):
+ # Create a variable to count the number of train() calls. This equals the
+ # number of updates applied to the variables.
+ global_step = slim.variables.global_step()
+
+ # Calculate the learning rate schedule.
+ num_batches_per_epoch = (dataset.num_examples_per_epoch() /
+ FLAGS.batch_size)
+ # Decay steps need to be divided by the number of replicas to aggregate.
+ decay_steps = int(num_batches_per_epoch * FLAGS.num_epochs_per_decay /
+ num_replicas_to_aggregate)
+
+ # Decay the learning rate exponentially based on the number of steps.
+ lr = tf.train.exponential_decay(FLAGS.initial_learning_rate,
+ global_step,
+ decay_steps,
+ FLAGS.learning_rate_decay_factor,
+ staircase=True)
+ # Add a summary to track the learning rate.
+ tf.summary.scalar('learning_rate', lr)
+
+ # Create an optimizer that performs gradient descent.
+ opt = tf.train.RMSPropOptimizer(lr,
+ RMSPROP_DECAY,
+ momentum=RMSPROP_MOMENTUM,
+ epsilon=RMSPROP_EPSILON)
+
+ images, labels = image_processing.distorted_inputs(
+ dataset,
+ batch_size=FLAGS.batch_size,
+ num_preprocess_threads=FLAGS.num_preprocess_threads)
+
+ # Number of classes in the Dataset label set plus 1.
+ # Label 0 is reserved for an (unused) background class.
+ num_classes = dataset.num_classes() + 1
+ logits = inception.inference(images, num_classes, for_training=True)
+ # Add classification loss.
+ inception.loss(logits, labels)
+
+ # Gather all of the losses including regularization losses.
+ losses = tf.get_collection(slim.losses.LOSSES_COLLECTION)
+ losses += tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+
+ total_loss = tf.add_n(losses, name='total_loss')
+
+ if is_chief:
+ # Compute the moving average of all individual losses and the
+ # total loss.
+ loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
+ loss_averages_op = loss_averages.apply(losses + [total_loss])
+
+ # Attach a scalar summary to all individual losses and the total loss;
+ # do the same for the averaged version of the losses.
+ for l in losses + [total_loss]:
+ loss_name = l.op.name
+ # Name each loss as '(raw)' and name the moving average version of the
+ # loss as the original loss name.
+ tf.summary.scalar(loss_name + ' (raw)', l)
+ tf.summary.scalar(loss_name, loss_averages.average(l))
+
+ # Add dependency to compute loss_averages.
+ with tf.control_dependencies([loss_averages_op]):
+ total_loss = tf.identity(total_loss)
+
+ # Track the moving averages of all trainable variables.
+ # Note that we maintain a 'double-average' of the BatchNormalization
+ # global statistics.
+ # This is not needed when the number of replicas is small, but it is important
+ # for synchronous distributed training with tens of workers/replicas.
+ exp_moving_averager = tf.train.ExponentialMovingAverage(
+ inception.MOVING_AVERAGE_DECAY, global_step)
+
+ variables_to_average = (
+ tf.trainable_variables() + tf.moving_average_variables())
+
+ # Add histograms for model variables.
+ for var in variables_to_average:
+ tf.summary.histogram(var.op.name, var)
+
+ # Create synchronous replica optimizer.
+ opt = tf.train.SyncReplicasOptimizer(
+ opt,
+ replicas_to_aggregate=num_replicas_to_aggregate,
+ total_num_replicas=num_workers,
+ variable_averages=exp_moving_averager,
+ variables_to_average=variables_to_average)
+
+ batchnorm_updates = tf.get_collection(slim.ops.UPDATE_OPS_COLLECTION)
+ assert batchnorm_updates, 'Batchnorm updates are missing'
+ batchnorm_updates_op = tf.group(*batchnorm_updates)
+ # Add dependency to compute batchnorm_updates.
+ with tf.control_dependencies([batchnorm_updates_op]):
+ total_loss = tf.identity(total_loss)
+
+ # Compute gradients with respect to the loss.
+ grads = opt.compute_gradients(total_loss)
+
+ # Add histograms for gradients.
+ for grad, var in grads:
+ if grad is not None:
+ tf.summary.histogram(var.op.name + '/gradients', grad)
+
+ apply_gradients_op = opt.apply_gradients(grads, global_step=global_step)
+
+ with tf.control_dependencies([apply_gradients_op]):
+ train_op = tf.identity(total_loss, name='train_op')
+
+ # Get chief queue_runners and init_tokens, which is used to synchronize
+ # replicas. More details can be found in SyncReplicasOptimizer.
+ chief_queue_runners = [opt.get_chief_queue_runner()]
+ init_tokens_op = opt.get_init_tokens_op()
+
+ # Create a saver.
+ saver = tf.train.Saver()
+
+ # Build the summary operation based on the TF collection of Summaries.
+ summary_op = tf.summary.merge_all()
+
+ # Build an initialization operation to run below.
+ init_op = tf.global_variables_initializer()
+
+ # We run the summaries in the same thread as the training operations by
+ # passing in None for summary_op to avoid a summary_thread being started.
+ # Running summaries and training operations in parallel could run out of
+ # GPU memory.
+ sv = tf.train.Supervisor(is_chief=is_chief,
+ logdir=FLAGS.train_dir,
+ init_op=init_op,
+ summary_op=None,
+ global_step=global_step,
+ saver=saver,
+ save_model_secs=FLAGS.save_interval_secs)
+
+ tf.logging.info('%s Supervisor' % datetime.now())
+
+ sess_config = tf.ConfigProto(
+ allow_soft_placement=True,
+ log_device_placement=FLAGS.log_device_placement)
+
+ # Get a session.
+ sess = sv.prepare_or_wait_for_session(target, config=sess_config)
+
+ # Start the queue runners.
+ queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
+ sv.start_queue_runners(sess, queue_runners)
+ tf.logging.info('Started %d queues for processing input data.',
+ len(queue_runners))
+
+ if is_chief:
+ sv.start_queue_runners(sess, chief_queue_runners)
+ sess.run(init_tokens_op)
+
+ # Train, checking for Nans. Concurrently run the summary operation at a
+ # specified interval. Note that the summary_op and train_op never run
+ # simultaneously in order to prevent running out of GPU memory.
+ next_summary_time = time.time() + FLAGS.save_summaries_secs
+ while not sv.should_stop():
+ try:
+ start_time = time.time()
+ loss_value, step = sess.run([train_op, global_step])
+ assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
+ if step > FLAGS.max_steps:
+ break
+ duration = time.time() - start_time
+
+ if step % 30 == 0:
+ examples_per_sec = FLAGS.batch_size / float(duration)
+ format_str = ('Worker %d: %s: step %d, loss = %.2f '
+ '(%.1f examples/sec; %.3f sec/batch)')
+ tf.logging.info(format_str %
+ (FLAGS.task_id, datetime.now(), step, loss_value,
+ examples_per_sec, duration))
+
+ # Determine if the summary_op should be run on the chief worker.
+ if is_chief and next_summary_time < time.time():
+ tf.logging.info('Running Summary operation on the chief.')
+ summary_str = sess.run(summary_op)
+ sv.summary_computed(sess, summary_str)
+ tf.logging.info('Finished running Summary operation.')
+
+ # Determine the next time for running the summary.
+ next_summary_time += FLAGS.save_summaries_secs
+ except:
+ if is_chief:
+ tf.logging.info('Chief got exception while running!')
+ raise
+
+ # Stop the supervisor. This also waits for service threads to finish.
+ sv.stop()
+
+ # Save after the training ends.
+ if is_chief:
+ saver.save(sess,
+ os.path.join(FLAGS.train_dir, 'model.ckpt'),
+ global_step=global_step)
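To illustrate the schedule configured in train() above, here is the decay_steps arithmetic and a few points of the staircased exponential decay, using the flag defaults plus the ImageNet training-set size and a hypothetical replica count:

```python
num_examples_per_epoch = 1281167   # ImageNet 'train' subset
batch_size = 32                    # --batch_size default
num_epochs_per_decay = 2.0         # --num_epochs_per_decay default
learning_rate_decay_factor = 0.94  # --learning_rate_decay_factor default
initial_learning_rate = 0.045      # --initial_learning_rate default
num_replicas_to_aggregate = 8      # hypothetical number of workers

num_batches_per_epoch = num_examples_per_epoch / float(batch_size)
decay_steps = int(num_batches_per_epoch * num_epochs_per_decay /
                  num_replicas_to_aggregate)

for global_step in (0, decay_steps, 5 * decay_steps):
    # staircase=True: the exponent is the whole number of completed decay periods.
    lr = (initial_learning_rate *
          learning_rate_decay_factor ** (global_step // decay_steps))
    print('step %6d: lr = %.5f' % (global_step, lr))
```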
diff --git a/models/research/inception/inception/inception_eval.py b/models/research/inception/inception/inception_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7cfc3c399dd82a915b3a49c7ddd4a8565292f69
--- /dev/null
+++ b/models/research/inception/inception/inception_eval.py
@@ -0,0 +1,171 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A library to evaluate Inception on a single GPU.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from datetime import datetime
+import math
+import os.path
+import time
+
+
+import numpy as np
+import tensorflow as tf
+
+from inception import image_processing
+from inception import inception_model as inception
+
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string('eval_dir', '/tmp/imagenet_eval',
+ """Directory where to write event logs.""")
+tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/imagenet_train',
+ """Directory where to read model checkpoints.""")
+
+# Flags governing the frequency of the eval.
+tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5,
+ """How often to run the eval.""")
+tf.app.flags.DEFINE_boolean('run_once', False,
+ """Whether to run eval only once.""")
+
+# Flags governing the data used for the eval.
+tf.app.flags.DEFINE_integer('num_examples', 50000,
+ """Number of examples to run. Note that the eval """
+ """ImageNet dataset contains 50000 examples.""")
+tf.app.flags.DEFINE_string('subset', 'validation',
+ """Either 'validation' or 'train'.""")
+
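+# A typical run from a driver script (paths are placeholders) would override
+# these flags on the command line, e.g. --checkpoint_dir=/tmp/imagenet_train
+# --eval_dir=/tmp/imagenet_eval --run_once, and then call evaluate() below.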
+
+def _eval_once(saver, summary_writer, top_1_op, top_5_op, summary_op):
+ """Runs Eval once.
+
+ Args:
+ saver: Saver.
+ summary_writer: Summary writer.
+ top_1_op: Top 1 op.
+ top_5_op: Top 5 op.
+ summary_op: Summary op.
+ """
+ with tf.Session() as sess:
+ ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
+ if ckpt and ckpt.model_checkpoint_path:
+ if os.path.isabs(ckpt.model_checkpoint_path):
+ # Restores from checkpoint with absolute path.
+ saver.restore(sess, ckpt.model_checkpoint_path)
+ else:
+ # Restores from checkpoint with relative path.
+ saver.restore(sess, os.path.join(FLAGS.checkpoint_dir,
+ ckpt.model_checkpoint_path))
+
+ # Assuming model_checkpoint_path looks something like:
+ # /my-favorite-path/imagenet_train/model.ckpt-0,
+ # extract global_step from it.
+ global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
+ print('Successfully loaded model from %s at step=%s.' %
+ (ckpt.model_checkpoint_path, global_step))
+ else:
+ print('No checkpoint file found')
+ return
+
+ # Start the queue runners.
+ coord = tf.train.Coordinator()
+ try:
+ threads = []
+ for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
+ threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
+ start=True))
+
+ num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
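+      # e.g. with the default num_examples=50000 and a (hypothetical) batch
+      # size of 100, this works out to 500 batches per evaluation pass.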
+ # Counts the number of correct predictions.
+ count_top_1 = 0.0
+ count_top_5 = 0.0
+ total_sample_count = num_iter * FLAGS.batch_size
+ step = 0
+
+ print('%s: starting evaluation on (%s).' % (datetime.now(), FLAGS.subset))
+ start_time = time.time()
+ while step < num_iter and not coord.should_stop():
+ top_1, top_5 = sess.run([top_1_op, top_5_op])
+ count_top_1 += np.sum(top_1)
+ count_top_5 += np.sum(top_5)
+ step += 1
+ if step % 20 == 0:
+ duration = time.time() - start_time
+ sec_per_batch = duration / 20.0
+ examples_per_sec = FLAGS.batch_size / sec_per_batch
+            print('%s: [%d batches out of %d] (%.1f examples/sec; %.3f '
+ 'sec/batch)' % (datetime.now(), step, num_iter,
+ examples_per_sec, sec_per_batch))
+ start_time = time.time()
+
+ # Compute precision @ 1.
+ precision_at_1 = count_top_1 / total_sample_count
+ recall_at_5 = count_top_5 / total_sample_count
+ print('%s: precision @ 1 = %.4f recall @ 5 = %.4f [%d examples]' %
+ (datetime.now(), precision_at_1, recall_at_5, total_sample_count))
+
+ summary = tf.Summary()
+ summary.ParseFromString(sess.run(summary_op))
+ summary.value.add(tag='Precision @ 1', simple_value=precision_at_1)
+ summary.value.add(tag='Recall @ 5', simple_value=recall_at_5)
+ summary_writer.add_summary(summary, global_step)
+
+ except Exception as e: # pylint: disable=broad-except
+ coord.request_stop(e)
+
+ coord.request_stop()
+ coord.join(threads, stop_grace_period_secs=10)
+
+
+def evaluate(dataset):
+ """Evaluate model on Dataset for a number of steps."""
+ with tf.Graph().as_default():
+ # Get images and labels from the dataset.
+ images, labels = image_processing.inputs(dataset)
+
+ # Number of classes in the Dataset label set plus 1.
+ # Label 0 is reserved for an (unused) background class.
+ num_classes = dataset.num_classes() + 1
+
+ # Build a Graph that computes the logits predictions from the
+ # inference model.
+ logits, _ = inception.inference(images, num_classes)
+
+ # Calculate predictions.
+ top_1_op = tf.nn.in_top_k(logits, labels, 1)
+ top_5_op = tf.nn.in_top_k(logits, labels, 5)
+
+ # Restore the moving average version of the learned variables for eval.
+ variable_averages = tf.train.ExponentialMovingAverage(
+ inception.MOVING_AVERAGE_DECAY)
+ variables_to_restore = variable_averages.variables_to_restore()
+ saver = tf.train.Saver(variables_to_restore)
+
+ # Build the summary operation based on the TF collection of Summaries.
+ summary_op = tf.summary.merge_all()
+
+ graph_def = tf.get_default_graph().as_graph_def()
+ summary_writer = tf.summary.FileWriter(FLAGS.eval_dir,
+ graph_def=graph_def)
+
+ while True:
+ _eval_once(saver, summary_writer, top_1_op, top_5_op, summary_op)
+ if FLAGS.run_once:
+ break
+ time.sleep(FLAGS.eval_interval_secs)
diff --git a/models/research/inception/inception/inception_model.py b/models/research/inception/inception/inception_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fedae13ae712f09d23ff020b161d86e87ee46e95
--- /dev/null
+++ b/models/research/inception/inception/inception_model.py
@@ -0,0 +1,157 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Build the Inception v3 network on ImageNet data set.
+
+The Inception v3 architecture is described in http://arxiv.org/abs/1512.00567
+
+Summary of available functions:
+ inference: Compute inference on the model inputs to make a prediction
+ loss: Compute the loss of the prediction with respect to the labels
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+import tensorflow as tf
+
+from inception.slim import slim
+
+FLAGS = tf.app.flags.FLAGS
+
+# If a model is trained using multiple GPUs, prefix all Op names with tower_name
+# to differentiate the operations. Note that this prefix is removed from the
+# names of the summaries when visualizing a model.
+TOWER_NAME = 'tower'
+
+# Batch normalization. Constant governing the exponential moving average of
+# the 'global' mean and variance for all activations.
+BATCHNORM_MOVING_AVERAGE_DECAY = 0.9997
+
+# The decay to use for the moving average.
+MOVING_AVERAGE_DECAY = 0.9999
+
+
+def inference(images, num_classes, for_training=False, restore_logits=True,
+ scope=None):
+ """Build Inception v3 model architecture.
+
+ See here for reference: http://arxiv.org/abs/1512.00567
+
+ Args:
+ images: Images returned from inputs() or distorted_inputs().
+ num_classes: number of classes
+ for_training: If set to `True`, build the inference model for training.
+ Kernels that operate differently for inference during training
+ e.g. dropout, are appropriately configured.
+ restore_logits: whether or not the logits layers should be restored.
+ Useful for fine-tuning a model with different num_classes.
+ scope: optional prefix string identifying the ImageNet tower.
+
+ Returns:
+ Logits. 2-D float Tensor.
+ Auxiliary Logits. 2-D float Tensor of side-head. Used for training only.
+ """
+ # Parameters for BatchNorm.
+ batch_norm_params = {
+ # Decay for the moving averages.
+ 'decay': BATCHNORM_MOVING_AVERAGE_DECAY,
+ # epsilon to prevent 0s in variance.
+ 'epsilon': 0.001,
+ }
+ # Set weight_decay for weights in Conv and FC layers.
+ with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], weight_decay=0.00004):
+ with slim.arg_scope([slim.ops.conv2d],
+ stddev=0.1,
+ activation=tf.nn.relu,
+ batch_norm_params=batch_norm_params):
+ logits, endpoints = slim.inception.inception_v3(
+ images,
+ dropout_keep_prob=0.8,
+ num_classes=num_classes,
+ is_training=for_training,
+ restore_logits=restore_logits,
+ scope=scope)
+
+ # Add summaries for viewing model statistics on TensorBoard.
+ _activation_summaries(endpoints)
+
+ # Grab the logits associated with the side head. Employed during training.
+ auxiliary_logits = endpoints['aux_logits']
+
+ return logits, auxiliary_logits
+
+
+def loss(logits, labels, batch_size=None):
+ """Adds all losses for the model.
+
+  Note the final loss is not returned. Instead, the list of losses is collected
+ by slim.losses. The losses are accumulated in tower_loss() and summed to
+ calculate the total loss.
+
+ Args:
+ logits: List of logits from inference(). Each entry is a 2-D float Tensor.
+ labels: Labels from distorted_inputs or inputs(). 1-D tensor
+ of shape [batch_size]
+ batch_size: integer
+ """
+ if not batch_size:
+ batch_size = FLAGS.batch_size
+
+ # Reshape the labels into a dense Tensor of
+ # shape [FLAGS.batch_size, num_classes].
+ sparse_labels = tf.reshape(labels, [batch_size, 1])
+ indices = tf.reshape(tf.range(batch_size), [batch_size, 1])
+ concated = tf.concat(axis=1, values=[indices, sparse_labels])
+ num_classes = logits[0].get_shape()[-1].value
+ dense_labels = tf.sparse_to_dense(concated,
+ [batch_size, num_classes],
+ 1.0, 0.0)
+
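+  # (For example, labels [2, 0] with num_classes=4 give dense_labels
+  # [[0., 0., 1., 0.], [1., 0., 0., 0.]]; label smoothing is applied inside the
+  # cross entropy loss below.)
+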
+ # Cross entropy loss for the main softmax prediction.
+ slim.losses.cross_entropy_loss(logits[0],
+ dense_labels,
+ label_smoothing=0.1,
+ weight=1.0)
+
+ # Cross entropy loss for the auxiliary softmax head.
+ slim.losses.cross_entropy_loss(logits[1],
+ dense_labels,
+ label_smoothing=0.1,
+ weight=0.4,
+ scope='aux_loss')
+
+
+def _activation_summary(x):
+ """Helper to create summaries for activations.
+
+ Creates a summary that provides a histogram of activations.
+ Creates a summary that measure the sparsity of activations.
+
+ Args:
+ x: Tensor
+ """
+ # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
+ # session. This helps the clarity of presentation on tensorboard.
+ tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
+ tf.summary.histogram(tensor_name + '/activations', x)
+ tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x))
+
+
+def _activation_summaries(endpoints):
+ with tf.name_scope('summaries'):
+ for act in endpoints.values():
+ _activation_summary(act)
diff --git a/models/research/inception/inception/inception_train.py b/models/research/inception/inception/inception_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1c32713b2012aec8a18637ec5dd79a1cc84d90f
--- /dev/null
+++ b/models/research/inception/inception/inception_train.py
@@ -0,0 +1,357 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A library to train Inception using multiple GPUs with synchronous updates.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+from datetime import datetime
+import os.path
+import re
+import time
+
+import numpy as np
+import tensorflow as tf
+
+from inception import image_processing
+from inception import inception_model as inception
+from inception.slim import slim
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string('train_dir', '/tmp/imagenet_train',
+ """Directory where to write event logs """
+ """and checkpoint.""")
+tf.app.flags.DEFINE_integer('max_steps', 10000000,
+ """Number of batches to run.""")
+tf.app.flags.DEFINE_string('subset', 'train',
+ """Either 'train' or 'validation'.""")
+
+# Flags governing the hardware employed for running TensorFlow.
+tf.app.flags.DEFINE_integer('num_gpus', 1,
+ """How many GPUs to use.""")
+tf.app.flags.DEFINE_boolean('log_device_placement', False,
+ """Whether to log device placement.""")
+
+# Flags governing the type of training.
+tf.app.flags.DEFINE_boolean('fine_tune', False,
+ """If set, randomly initialize the final layer """
+ """of weights in order to train the network on a """
+ """new task.""")
+tf.app.flags.DEFINE_string('pretrained_model_checkpoint_path', '',
+ """If specified, restore this pretrained model """
+ """before beginning any training.""")
+
+# **IMPORTANT**
+# Please note that this learning rate schedule is heavily dependent on the
+# hardware architecture, batch size and any changes to the model architecture
+# specification. Selecting a finely tuned learning rate schedule is an
+# empirical process that requires some experimentation. Please see README.md
+# for more guidance and discussion.
+#
+# With 8 Tesla K40's and a batch size = 256, the following setup achieves
+# precision@1 = 73.5% after 100 hours and 100K steps (20 epochs).
+# Learning rate decay factor selected from http://arxiv.org/abs/1404.5997.
+tf.app.flags.DEFINE_float('initial_learning_rate', 0.1,
+ """Initial learning rate.""")
+tf.app.flags.DEFINE_float('num_epochs_per_decay', 30.0,
+ """Epochs after which learning rate decays.""")
+tf.app.flags.DEFINE_float('learning_rate_decay_factor', 0.16,
+ """Learning rate decay factor.""")
+
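+# With the three default flag values above and staircase decay (see train()
+# below), the schedule is roughly lr = 0.1 * 0.16 ** floor(epoch / 30).
+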
+# Constants dictating the learning rate schedule.
+RMSPROP_DECAY = 0.9 # Decay term for RMSProp.
+RMSPROP_MOMENTUM = 0.9 # Momentum in RMSProp.
+RMSPROP_EPSILON = 1.0 # Epsilon term for RMSProp.
+
+
+def _tower_loss(images, labels, num_classes, scope, reuse_variables=None):
+ """Calculate the total loss on a single tower running the ImageNet model.
+
+ We perform 'batch splitting'. This means that we cut up a batch across
+ multiple GPUs. For instance, if the batch size = 32 and num_gpus = 2,
+  then each tower will operate on a batch of 16 images.
+
+ Args:
+ images: Images. 4D tensor of size [batch_size, FLAGS.image_size,
+ FLAGS.image_size, 3].
+ labels: 1-D integer Tensor of [batch_size].
+ num_classes: number of classes
+ scope: unique prefix string identifying the ImageNet tower, e.g.
+ 'tower_0'.
+
+ Returns:
+ Tensor of shape [] containing the total loss for a batch of data
+ """
+ # When fine-tuning a model, we do not restore the logits but instead we
+ # randomly initialize the logits. The number of classes in the output of the
+  # logits is the number of classes in the specified Dataset.
+ restore_logits = not FLAGS.fine_tune
+
+ # Build inference Graph.
+ with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
+ logits = inception.inference(images, num_classes, for_training=True,
+ restore_logits=restore_logits,
+ scope=scope)
+
+ # Build the portion of the Graph calculating the losses. Note that we will
+ # assemble the total_loss using a custom function below.
+ split_batch_size = images.get_shape().as_list()[0]
+ inception.loss(logits, labels, batch_size=split_batch_size)
+
+ # Assemble all of the losses for the current tower only.
+ losses = tf.get_collection(slim.losses.LOSSES_COLLECTION, scope)
+
+ # Calculate the total loss for the current tower.
+ regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+ total_loss = tf.add_n(losses + regularization_losses, name='total_loss')
+
+ # Compute the moving average of all individual losses and the total loss.
+ loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
+ loss_averages_op = loss_averages.apply(losses + [total_loss])
+
+  # Attach a scalar summary to all individual losses and the total loss; do the
+ # same for the averaged version of the losses.
+ for l in losses + [total_loss]:
+ # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
+ # session. This helps the clarity of presentation on TensorBoard.
+ loss_name = re.sub('%s_[0-9]*/' % inception.TOWER_NAME, '', l.op.name)
+ # Name each loss as '(raw)' and name the moving average version of the loss
+ # as the original loss name.
+ tf.summary.scalar(loss_name +' (raw)', l)
+ tf.summary.scalar(loss_name, loss_averages.average(l))
+
+ with tf.control_dependencies([loss_averages_op]):
+ total_loss = tf.identity(total_loss)
+ return total_loss
+
+
+def _average_gradients(tower_grads):
+ """Calculate the average gradient for each shared variable across all towers.
+
+ Note that this function provides a synchronization point across all towers.
+
+ Args:
+ tower_grads: List of lists of (gradient, variable) tuples. The outer list
+ is over individual gradients. The inner list is over the gradient
+ calculation for each tower.
+ Returns:
+ List of pairs of (gradient, variable) where the gradient has been averaged
+ across all towers.
+ """
+ average_grads = []
+ for grad_and_vars in zip(*tower_grads):
+ # Note that each grad_and_vars looks like the following:
+ # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
+ grads = []
+ for g, _ in grad_and_vars:
+ # Add 0 dimension to the gradients to represent the tower.
+ expanded_g = tf.expand_dims(g, 0)
+
+ # Append on a 'tower' dimension which we will average over below.
+ grads.append(expanded_g)
+
+ # Average over the 'tower' dimension.
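+    # e.g. with two towers and per-tower gradients of shape [d], the concat
+    # below produces shape [2, d] and the mean reduces it back to [d].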
+ grad = tf.concat(axis=0, values=grads)
+ grad = tf.reduce_mean(grad, 0)
+
+ # Keep in mind that the Variables are redundant because they are shared
+ # across towers. So .. we will just return the first tower's pointer to
+ # the Variable.
+ v = grad_and_vars[0][1]
+ grad_and_var = (grad, v)
+ average_grads.append(grad_and_var)
+ return average_grads
+
+
+def train(dataset):
+ """Train on dataset for a number of steps."""
+ with tf.Graph().as_default(), tf.device('/cpu:0'):
+ # Create a variable to count the number of train() calls. This equals the
+ # number of batches processed * FLAGS.num_gpus.
+ global_step = tf.get_variable(
+ 'global_step', [],
+ initializer=tf.constant_initializer(0), trainable=False)
+
+ # Calculate the learning rate schedule.
+ num_batches_per_epoch = (dataset.num_examples_per_epoch() /
+ FLAGS.batch_size)
+ decay_steps = int(num_batches_per_epoch * FLAGS.num_epochs_per_decay)
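+    # e.g. for the setup described in the learning-rate comment above (roughly
+    # 5,000 batches per epoch at batch_size=256), this decays the rate about
+    # every 150,000 steps.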
+
+ # Decay the learning rate exponentially based on the number of steps.
+ lr = tf.train.exponential_decay(FLAGS.initial_learning_rate,
+ global_step,
+ decay_steps,
+ FLAGS.learning_rate_decay_factor,
+ staircase=True)
+
+ # Create an optimizer that performs gradient descent.
+ opt = tf.train.RMSPropOptimizer(lr, RMSPROP_DECAY,
+ momentum=RMSPROP_MOMENTUM,
+ epsilon=RMSPROP_EPSILON)
+
+ # Get images and labels for ImageNet and split the batch across GPUs.
+ assert FLAGS.batch_size % FLAGS.num_gpus == 0, (
+ 'Batch size must be divisible by number of GPUs')
+ split_batch_size = int(FLAGS.batch_size / FLAGS.num_gpus)
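+    # e.g. batch_size=32 with num_gpus=2 gives each tower a batch of 16 images
+    # (the 'batch splitting' described in _tower_loss).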
+
+ # Override the number of preprocessing threads to account for the increased
+ # number of GPU towers.
+ num_preprocess_threads = FLAGS.num_preprocess_threads * FLAGS.num_gpus
+ images, labels = image_processing.distorted_inputs(
+ dataset,
+ num_preprocess_threads=num_preprocess_threads)
+
+ input_summaries = copy.copy(tf.get_collection(tf.GraphKeys.SUMMARIES))
+
+ # Number of classes in the Dataset label set plus 1.
+ # Label 0 is reserved for an (unused) background class.
+ num_classes = dataset.num_classes() + 1
+
+ # Split the batch of images and labels for towers.
+ images_splits = tf.split(axis=0, num_or_size_splits=FLAGS.num_gpus, value=images)
+ labels_splits = tf.split(axis=0, num_or_size_splits=FLAGS.num_gpus, value=labels)
+
+ # Calculate the gradients for each model tower.
+ tower_grads = []
+ reuse_variables = None
+ for i in range(FLAGS.num_gpus):
+ with tf.device('/gpu:%d' % i):
+ with tf.name_scope('%s_%d' % (inception.TOWER_NAME, i)) as scope:
+ # Force all Variables to reside on the CPU.
+ with slim.arg_scope([slim.variables.variable], device='/cpu:0'):
+ # Calculate the loss for one tower of the ImageNet model. This
+ # function constructs the entire ImageNet model but shares the
+ # variables across all towers.
+ loss = _tower_loss(images_splits[i], labels_splits[i], num_classes,
+ scope, reuse_variables)
+
+ # Reuse variables for the next tower.
+ reuse_variables = True
+
+ # Retain the summaries from the final tower.
+ summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
+
+ # Retain the Batch Normalization updates operations only from the
+ # final tower. Ideally, we should grab the updates from all towers
+ # but these stats accumulate extremely fast so we can ignore the
+ # other stats from the other towers without significant detriment.
+ batchnorm_updates = tf.get_collection(slim.ops.UPDATE_OPS_COLLECTION,
+ scope)
+
+ # Calculate the gradients for the batch of data on this ImageNet
+ # tower.
+ grads = opt.compute_gradients(loss)
+
+ # Keep track of the gradients across all towers.
+ tower_grads.append(grads)
+
+ # We must calculate the mean of each gradient. Note that this is the
+ # synchronization point across all towers.
+ grads = _average_gradients(tower_grads)
+
+    # Add summaries for the input processing and global_step.
+ summaries.extend(input_summaries)
+
+ # Add a summary to track the learning rate.
+ summaries.append(tf.summary.scalar('learning_rate', lr))
+
+ # Add histograms for gradients.
+ for grad, var in grads:
+ if grad is not None:
+ summaries.append(
+ tf.summary.histogram(var.op.name + '/gradients', grad))
+
+ # Apply the gradients to adjust the shared variables.
+ apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
+
+ # Add histograms for trainable variables.
+ for var in tf.trainable_variables():
+ summaries.append(tf.summary.histogram(var.op.name, var))
+
+ # Track the moving averages of all trainable variables.
+ # Note that we maintain a "double-average" of the BatchNormalization
+    # global statistics. This is more complicated than need be but we employ
+ # this for backward-compatibility with our previous models.
+ variable_averages = tf.train.ExponentialMovingAverage(
+ inception.MOVING_AVERAGE_DECAY, global_step)
+
+ # Another possibility is to use tf.slim.get_variables().
+ variables_to_average = (tf.trainable_variables() +
+ tf.moving_average_variables())
+ variables_averages_op = variable_averages.apply(variables_to_average)
+
+    # Group all updates into a single train op.
+ batchnorm_updates_op = tf.group(*batchnorm_updates)
+ train_op = tf.group(apply_gradient_op, variables_averages_op,
+ batchnorm_updates_op)
+
+ # Create a saver.
+ saver = tf.train.Saver(tf.global_variables())
+
+ # Build the summary operation from the last tower summaries.
+ summary_op = tf.summary.merge(summaries)
+
+ # Build an initialization operation to run below.
+ init = tf.global_variables_initializer()
+
+ # Start running operations on the Graph. allow_soft_placement must be set to
+ # True to build towers on GPU, as some of the ops do not have GPU
+ # implementations.
+ sess = tf.Session(config=tf.ConfigProto(
+ allow_soft_placement=True,
+ log_device_placement=FLAGS.log_device_placement))
+ sess.run(init)
+
+ if FLAGS.pretrained_model_checkpoint_path:
+ assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
+ variables_to_restore = tf.get_collection(
+ slim.variables.VARIABLES_TO_RESTORE)
+ restorer = tf.train.Saver(variables_to_restore)
+ restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
+ print('%s: Pre-trained model restored from %s' %
+ (datetime.now(), FLAGS.pretrained_model_checkpoint_path))
+
+ # Start the queue runners.
+ tf.train.start_queue_runners(sess=sess)
+
+ summary_writer = tf.summary.FileWriter(
+ FLAGS.train_dir,
+ graph=sess.graph)
+
+ for step in range(FLAGS.max_steps):
+ start_time = time.time()
+ _, loss_value = sess.run([train_op, loss])
+ duration = time.time() - start_time
+
+ assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
+
+ if step % 10 == 0:
+ examples_per_sec = FLAGS.batch_size / float(duration)
+ format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
+ 'sec/batch)')
+ print(format_str % (datetime.now(), step, loss_value,
+ examples_per_sec, duration))
+
+ if step % 100 == 0:
+ summary_str = sess.run(summary_op)
+ summary_writer.add_summary(summary_str, step)
+
+ # Save the model checkpoint periodically.
+ if step % 5000 == 0 or (step + 1) == FLAGS.max_steps:
+ checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
+ saver.save(sess, checkpoint_path, global_step=step)
diff --git a/models/research/inception/inception/slim/BUILD b/models/research/inception/inception/slim/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..174e77d5c2654380232174a2bb8b29c6b9affc5d
--- /dev/null
+++ b/models/research/inception/inception/slim/BUILD
@@ -0,0 +1,112 @@
+# Description:
+# Contains the operations and nets for building TensorFlow-Slim models.
+
+package(default_visibility = ["//inception:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "scopes",
+ srcs = ["scopes.py"],
+)
+
+py_test(
+ name = "scopes_test",
+ size = "small",
+ srcs = ["scopes_test.py"],
+ deps = [
+ ":scopes",
+ ],
+)
+
+py_library(
+ name = "variables",
+ srcs = ["variables.py"],
+ deps = [
+ ":scopes",
+ ],
+)
+
+py_test(
+ name = "variables_test",
+ size = "small",
+ srcs = ["variables_test.py"],
+ deps = [
+ ":variables",
+ ],
+)
+
+py_library(
+ name = "losses",
+ srcs = ["losses.py"],
+)
+
+py_test(
+ name = "losses_test",
+ size = "small",
+ srcs = ["losses_test.py"],
+ deps = [
+ ":losses",
+ ],
+)
+
+py_library(
+ name = "ops",
+ srcs = ["ops.py"],
+ deps = [
+ ":losses",
+ ":scopes",
+ ":variables",
+ ],
+)
+
+py_test(
+ name = "ops_test",
+ size = "small",
+ srcs = ["ops_test.py"],
+ deps = [
+ ":ops",
+ ":variables",
+ ],
+)
+
+py_library(
+ name = "inception",
+ srcs = ["inception_model.py"],
+ deps = [
+ ":ops",
+ ":scopes",
+ ],
+)
+
+py_test(
+ name = "inception_test",
+ size = "medium",
+ srcs = ["inception_test.py"],
+ deps = [
+ ":inception",
+ ],
+)
+
+py_library(
+ name = "slim",
+ srcs = ["slim.py"],
+ deps = [
+ ":inception",
+ ":losses",
+ ":ops",
+ ":scopes",
+ ":variables",
+ ],
+)
+
+py_test(
+ name = "collections_test",
+ size = "small",
+ srcs = ["collections_test.py"],
+ deps = [
+ ":slim",
+ ],
+)
diff --git a/models/research/inception/inception/slim/README.md b/models/research/inception/inception/slim/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..36d8b7eb19ae47d8810ed97abe203aa34be50a75
--- /dev/null
+++ b/models/research/inception/inception/slim/README.md
@@ -0,0 +1,621 @@
+# TensorFlow-Slim
+
+TF-Slim is a lightweight library for defining, training and evaluating models in
+TensorFlow. It enables defining complex networks quickly and concisely while
+keeping a model's architecture transparent and its hyperparameters explicit.
+
+[TOC]
+
+## Teaser
+
+As a demonstration of how compact TF-Slim is, compare the code necessary for
+defining the entire [VGG](http://www.robots.ox.ac.uk/~vgg/research/very_deep/) network using TF-Slim to
+the lengthy and verbose code needed to define just the first three layers (out of
+16) using native tensorflow:
+
+```python{.good}
+# VGG16 in TF-Slim.
+def vgg16(inputs):
+ with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], stddev=0.01, weight_decay=0.0005):
+ net = slim.ops.repeat_op(2, inputs, slim.ops.conv2d, 64, [3, 3], scope='conv1')
+ net = slim.ops.max_pool(net, [2, 2], scope='pool1')
+ net = slim.ops.repeat_op(2, net, slim.ops.conv2d, 128, [3, 3], scope='conv2')
+ net = slim.ops.max_pool(net, [2, 2], scope='pool2')
+ net = slim.ops.repeat_op(3, net, slim.ops.conv2d, 256, [3, 3], scope='conv3')
+ net = slim.ops.max_pool(net, [2, 2], scope='pool3')
+ net = slim.ops.repeat_op(3, net, slim.ops.conv2d, 512, [3, 3], scope='conv4')
+ net = slim.ops.max_pool(net, [2, 2], scope='pool4')
+ net = slim.ops.repeat_op(3, net, slim.ops.conv2d, 512, [3, 3], scope='conv5')
+ net = slim.ops.max_pool(net, [2, 2], scope='pool5')
+ net = slim.ops.flatten(net, scope='flatten5')
+ net = slim.ops.fc(net, 4096, scope='fc6')
+ net = slim.ops.dropout(net, 0.5, scope='dropout6')
+ net = slim.ops.fc(net, 4096, scope='fc7')
+ net = slim.ops.dropout(net, 0.5, scope='dropout7')
+ net = slim.ops.fc(net, 1000, activation=None, scope='fc8')
+ return net
+```
+
+```python{.bad}
+# Layers 1-3 (out of 16) of VGG16 in native tensorflow.
+def vgg16(inputs):
+ with tf.name_scope('conv1_1') as scope:
+ kernel = tf.Variable(tf.truncated_normal([3, 3, 3, 64], dtype=tf.float32, stddev=1e-1), name='weights')
+ conv = tf.nn.conv2d(inputs, kernel, [1, 1, 1, 1], padding='SAME')
+ biases = tf.Variable(tf.constant(0.0, shape=[64], dtype=tf.float32), trainable=True, name='biases')
+ bias = tf.nn.bias_add(conv, biases)
+ conv1 = tf.nn.relu(bias, name=scope)
+ with tf.name_scope('conv1_2') as scope:
+ kernel = tf.Variable(tf.truncated_normal([3, 3, 64, 64], dtype=tf.float32, stddev=1e-1), name='weights')
+ conv = tf.nn.conv2d(conv1, kernel, [1, 1, 1, 1], padding='SAME')
+ biases = tf.Variable(tf.constant(0.0, shape=[64], dtype=tf.float32), trainable=True, name='biases')
+ bias = tf.nn.bias_add(conv, biases)
+ conv1 = tf.nn.relu(bias, name=scope)
+ with tf.name_scope('pool1'):
+ pool1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID', name='pool1')
+```
+
+## Why TF-Slim?
+
+TF-Slim offers several advantages over just the built-in tensorflow libraries:
+
+* Allows one to define models much more compactly by eliminating boilerplate
+ code. This is accomplished through the use of [argument scoping](./scopes.py)
+ and numerous high level [operations](./ops.py). These tools increase
+ readability and maintainability, reduce the likelihood of an error from
+ copy-and-pasting hyperparameter values, and simplify hyperparameter tuning.
+* Makes developing models simple by providing commonly used [loss functions](./losses.py).
+* Provides a concise [definition](./inception_model.py) of [Inception v3](http://arxiv.org/abs/1512.00567) network architecture ready to be used
+ out-of-the-box or subsumed into new models.
+
+Additionally TF-Slim was designed with several principles in mind:
+
+* The various modules of TF-Slim (scopes, variables, ops, losses) are
+ independent. This flexibility allows users to pick and choose components of
+ TF-Slim completely à la carte.
+* TF-Slim is written using a Functional Programming style. That means it's
+ super-lightweight and can be used right alongside any of TensorFlow's native
+ operations.
+* Makes re-using network architectures easy. This allows users to build new
+ networks on top of existing ones as well as to fine-tune pre-trained models
+ on new tasks.
+
+## What are the various components of TF-Slim?
+
+TF-Slim is composed of several parts which were designed to exist independently.
+These include:
+
+* [scopes.py](./scopes.py): provides a new scope named `arg_scope` that allows
+ a user to define default arguments for specific operations within that
+ scope.
+* [variables.py](./variables.py): provides convenience wrappers for variable
+ creation and manipulation.
+* [ops.py](./ops.py): provides high level operations for building models using
+ tensorflow.
+* [losses.py](./losses.py): contains commonly used loss functions.
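+
+Taken together, these modules are usually accessed through the aggregating
+`slim` module in this directory (see [slim.py](./slim.py)). A minimal sketch,
+with placeholder inputs, of how the pieces combine:
+
+```python
+from inception.slim import slim
+
+images, labels = ...
+
+# scopes + ops: define a tiny model with shared defaults.
+with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], stddev=0.01, weight_decay=0.0005):
+  net = slim.ops.conv2d(images, 64, [3, 3], scope='conv1')
+  net = slim.ops.flatten(net, scope='flatten1')
+  logits = slim.ops.fc(net, 10, activation=None, scope='logits')
+
+# losses: register the classification loss with TF-Slim.
+loss = slim.losses.cross_entropy_loss(logits, labels)
+
+# variables: inspect what the model created.
+model_variables = slim.variables.get_variables()
+```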
+
+## Defining Models
+
+Models can be succinctly defined using TF-Slim by combining its variables,
+operations and scopes. Each of these elements is defined below.
+
+### Variables
+
+Creating [`Variables`](https://www.tensorflow.org/how_tos/variables/index.html)
+in native tensorflow requires either a predefined value or an initialization
+mechanism (random, normally distributed). Furthermore, if a variable needs to be
+created on a specific device, such as a GPU, the specification must be [made
+explicit](https://www.tensorflow.org/how_tos/using_gpu/index.html). To alleviate
+the code required for variable creation, TF-Slim provides a set of thin wrapper
+functions in [variables.py](./variables.py) which allow callers to easily define
+variables.
+
+For example, to create a `weight` variable, initialize it using a truncated
+normal distribution, regularize it with an `l2_loss` and place it on the `CPU`,
+one need only declare the following:
+
+```python
+weights = variables.variable('weights',
+ shape=[10, 10, 3 , 3],
+ initializer=tf.truncated_normal_initializer(stddev=0.1),
+ regularizer=lambda t: losses.l2_loss(t, weight=0.05),
+ device='/cpu:0')
+```
+
+In addition to the functionality provided by `tf.Variable`, `slim.variables`
+keeps track of the variables created by `slim.ops` to define a model, which
+allows one to distinguish variables that belong to the model versus other
+variables.
+
+```python
+# Get all the variables defined by the model.
+model_variables = slim.variables.get_variables()
+
+# Get all the variables with the same given name, i.e. 'weights', 'biases'.
+weights = slim.variables.get_variables_by_name('weights')
+biases = slim.variables.get_variables_by_name('biases')
+
+# Get all the variables in VARIABLES_TO_RESTORE collection.
+variables_to_restore = tf.get_collection(slim.variables.VARIABLES_TO_RESTORE)
+
+
+weights = variables.variable('weights',
+ shape=[10, 10, 3 , 3],
+ initializer=tf.truncated_normal_initializer(stddev=0.1),
+ regularizer=lambda t: losses.l2_loss(t, weight=0.05),
+ device='/cpu:0')
+```
+
+### Operations (Layers)
+
+While the set of TensorFlow operations is quite extensive, builders of neural
+networks typically think of models in terms of "layers". A layer, such as a
+Convolutional Layer, a Fully Connected Layer or a BatchNorm Layer, is more
+abstract than a single TensorFlow operation and typically involves many such
+operations. For example, a Convolutional Layer in a neural network is built
+using several steps:
+
+1. Creating the weight variables
+2. Creating the bias variables
+3. Convolving the weights with the input from the previous layer
+4. Adding the biases to the result of the convolution.
+
+In Python code this can be rather laborious:
+
+```python
+input = ...
+with tf.name_scope('conv1_1') as scope:
+ kernel = tf.Variable(tf.truncated_normal([3, 3, 64, 128], dtype=tf.float32,
+ stddev=1e-1), name='weights')
+ conv = tf.nn.conv2d(input, kernel, [1, 1, 1, 1], padding='SAME')
+ biases = tf.Variable(tf.constant(0.0, shape=[128], dtype=tf.float32),
+ trainable=True, name='biases')
+ bias = tf.nn.bias_add(conv, biases)
+ conv1 = tf.nn.relu(bias, name=scope)
+```
+
+To alleviate the need to duplicate this code repeatedly, TF-Slim provides a
+number of convenient operations defined at the (more abstract) level of neural
+network layers. For example, compare the code above to an invocation of the
+TF-Slim code:
+
+```python
+input = ...
+net = slim.ops.conv2d(input, 128, [3, 3], scope='conv1_1')
+```
+
+TF-Slim provides numerous operations used in building neural networks which
+roughly correspond to such layers. These include:
+
+Layer | TF-Slim Op
+--------------------- | ------------------------
+Convolutional Layer | [ops.conv2d](./ops.py)
+Fully Connected Layer | [ops.fc](./ops.py)
+BatchNorm layer | [ops.batch_norm](./ops.py)
+Max Pooling Layer | [ops.max_pool](./ops.py)
+Avg Pooling Layer | [ops.avg_pool](./ops.py)
+Dropout Layer | [ops.dropout](./ops.py)
+
+[ops.py](./ops.py) also includes operations that are not really "layers" per se,
+but are often used to manipulate hidden unit representations during inference:
+
+Operation | TF-Slim Op
+--------- | ---------------------
+Flatten | [ops.flatten](./ops.py)
+
+TF-Slim also provides a meta-operation called `repeat_op` that allows one to
+repeatedly perform the same operation. Consider the following snippet from the
+[VGG](https://www.robots.ox.ac.uk/~vgg/research/very_deep/) network whose layers
+perform several convolutions in a row between pooling layers:
+
+```python
+net = ...
+net = slim.ops.conv2d(net, 256, [3, 3], scope='conv3_1')
+net = slim.ops.conv2d(net, 256, [3, 3], scope='conv3_2')
+net = slim.ops.conv2d(net, 256, [3, 3], scope='conv3_3')
+net = slim.ops.max_pool(net, [2, 2], scope='pool3')
+```
+
+This clear duplication of code can be removed via a standard loop:
+
+```python
+net = ...
+for i in range(3):
+ net = slim.ops.conv2d(net, 256, [3, 3], scope='conv3_%d' % (i+1))
+net = slim.ops.max_pool(net, [2, 2], scope='pool3')
+```
+
+While this does reduce the amount of duplication, it can be made even cleaner by
+using the `RepeatOp`:
+
+```python
+net = slim.ops.repeat_op(3, net, slim.ops.conv2d, 256, [3, 3], scope='conv3')
+net = slim.ops.max_pool(net, [2, 2], scope='pool2')
+```
+
+Notice that the RepeatOp not only applies the same arguments in-line, it is also
+smart enough to unroll the scopes such that the scope assigned to each
+subsequent call of `ops.conv2d` is appended with an underscore and iteration
+number. More concretely, the scopes in the example above would be 'conv3_1',
+'conv3_2' and 'conv3_3'.
+
+### Scopes
+
+In addition to the types of scope mechanisms in TensorFlow ([name_scope](https://www.tensorflow.org/api_docs/python/framework.html#name_scope),
+[variable_scope](https://www.tensorflow.org/api_docs/python/state_ops.html#variable_scope)),
+TF-Slim adds a new scoping mechanism called "argument scope" or [arg_scope](./scopes.py). This new scope allows a user to specify one or more operations and
+a set of arguments which will be passed to each of the operations defined in the
+`arg_scope`. This functionality is best illustrated by example. Consider the
+following code snippet:
+
+```python
+net = slim.ops.conv2d(inputs, 64, [11, 11], 4, padding='SAME', stddev=0.01, weight_decay=0.0005, scope='conv1')
+net = slim.ops.conv2d(net, 128, [11, 11], padding='VALID', stddev=0.01, weight_decay=0.0005, scope='conv2')
+net = slim.ops.conv2d(net, 256, [11, 11], padding='SAME', stddev=0.01, weight_decay=0.0005, scope='conv3')
+```
+
+It should be clear that these three Convolution layers share many of the same
+hyperparameters. Two have the same padding, all three have the same weight_decay
+and standard deviation of their weights. Not only do the duplicated values make
+the code more difficult to read, they also add the additional burden of needing
+to double-check that all of the values are identical in each step. One
+solution would be to specify default values using variables:
+
+```python
+padding='SAME'
+stddev=0.01
+weight_decay=0.0005
+net = slim.ops.conv2d(inputs, 64, [11, 11], 4, padding=padding, stddev=stddev, weight_decay=weight_decay, scope='conv1')
+net = slim.ops.conv2d(net, 128, [11, 11], padding='VALID', stddev=stddev, weight_decay=weight_decay, scope='conv2')
+net = slim.ops.conv2d(net, 256, [11, 11], padding=padding, stddev=stddev, weight_decay=weight_decay, scope='conv3')
+
+```
+
+This solution ensures that all three convolutions share the exact same variable
+values but doesn't reduce the code clutter. By using an `arg_scope`, we can both
+ensure that each layer uses the same values and simplify the code:
+
+```python
+ with slim.arg_scope([slim.ops.conv2d], padding='SAME', stddev=0.01, weight_decay=0.0005):
+ net = slim.ops.conv2d(inputs, 64, [11, 11], scope='conv1')
+ net = slim.ops.conv2d(net, 128, [11, 11], padding='VALID', scope='conv2')
+ net = slim.ops.conv2d(net, 256, [11, 11], scope='conv3')
+```
+
+As the example illustrates, the use of arg_scope makes the code cleaner, simpler
+and easier to maintain. Notice that while argument values are specified in the
+arg_scope, they can be overwritten locally. In particular, while the padding
+argument has been set to 'SAME', the second convolution overrides it with the
+value of 'VALID'.
+
+One can also nest `arg_scope`s and use multiple operations in the same scope.
+For example:
+
+```python
+with arg_scope([slim.ops.conv2d, slim.ops.fc], stddev=0.01, weight_decay=0.0005):
+ with arg_scope([slim.ops.conv2d], padding='SAME'), slim.arg_scope([slim.ops.fc], bias=1.0):
+ net = slim.ops.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
+ net = slim.ops.conv2d(net, 256, [5, 5], stddev=0.03, scope='conv2')
+ net = slim.ops.flatten(net)
+ net = slim.ops.fc(net, 1000, activation=None, scope='fc')
+```
+
+In this example, the first `arg_scope` applies the same `stddev` and
+`weight_decay` arguments to the `conv2d` and `fc` ops in its scope. In the
+second `arg_scope`, additional default arguments to `conv2d` only are specified.
+
+In addition to `arg_scope`, TF-Slim provides several decorators that wrap the
+use of tensorflow arg scopes. These include `@AddArgScope`, `@AddNameScope`,
+`@AddVariableScope`, `@AddOpScope` and `@AddVariableOpScope`. To illustrate
+their use, consider the following example.
+
+```python
+def MyNewOp(inputs):
+ varA = ...
+ varB = ...
+ outputs = tf.multiply(varA, inputs) + varB
+ return outputs
+
+```
+
+In this example, the user has created a new op which creates two variables. To
+ensure that these variables exist within a certain variable scope (to avoid
+collisions with variables with the same name), in standard TF, the op must be
+called within a variable scope:
+
+```python
+inputs = ...
+with tf.variable_scope('layer1'):
+ outputs = MyNewOp(inputs)
+```
+
+As an alternative, one can use TF-Slim's decorators to decorate the function and
+simplify the call:
+
+```python
+@AddVariableScope
+def MyNewOp(inputs):
+ ...
+ return outputs
+
+
+inputs = ...
+outputs = MyNewOp('layer1')
+```
+
+The `@AddVariableScope` decorator simply applies the `tf.variable_scope` scoping
+to the called function taking "layer1" as its argument. This allows the code to
+be written more concisely.
+
+### Losses
+
+The loss function defines a quantity that we want to minimize. For
+classification problems, this is typically the cross entropy between the true
+(one-hot) distribution and the predicted probability distribution across
+classes. For regression problems, this is often the sum-of-squares differences
+between the predicted and true values.
+
+Certain models, such as multi-task learning models, require the use of multiple
+loss functions simultaneously. In other words, the loss function ultimately being
+minimized is the sum of various other loss functions. For example, consider a
+model that predicts both the type of scene in an image as well as the depth from
+the camera of each pixel. This model's loss function would be the sum of the
+classification loss and depth prediction loss.
+
+TF-Slim provides an easy-to-use mechanism for defining and keeping track of loss
+functions via the [losses.py](./losses.py) module. Consider the simple case
+where we want to train the VGG network:
+
+```python
+# Load the images and labels.
+images, labels = ...
+
+# Create the model.
+predictions = ...
+
+# Define the loss functions and get the total loss.
+loss = losses.cross_entropy_loss(predictions, labels)
+```
+
+In this example, we start by creating the model (using TF-Slim's VGG
+implementation), and add the standard classification loss. Now, let's turn to the
+case where we have a multi-task model that produces multiple outputs:
+
+```python
+# Load the images and labels.
+images, scene_labels, depth_labels = ...
+
+# Create the model.
+scene_predictions, depth_predictions = CreateMultiTaskModel(images)
+
+# Define the loss functions and get the total loss.
+classification_loss = slim.losses.cross_entropy_loss(scene_predictions, scene_labels)
+sum_of_squares_loss = slim.losses.l2loss(depth_predictions - depth_labels)
+
+# The following two lines have the same effect:
+total_loss1 = classification_loss + sum_of_squares_loss
+total_loss2 = tf.add_n(tf.get_collection(slim.losses.LOSSES_COLLECTION))
+```
+
+In this example, we have two losses which we add by calling
+`losses.cross_entropy_loss` and `losses.l2loss`. We can obtain the
+total loss by adding them together (`total_loss1`) or by calling
+`losses.GetTotalLoss()`. How did this work? When you create a loss function via
+TF-Slim, TF-Slim adds the loss to a special TensorFlow collection of loss
+functions. This enables you to either manage the total loss manually, or allow
+TF-Slim to manage them for you.
+
+What if you want to let TF-Slim manage the losses for you but have a custom loss
+function? [losses.py](./losses.py) also has a function that adds this loss to
+TF-Slims collection. For example:
+
+```python
+# Load the images and labels.
+images, scene_labels, depth_labels, pose_labels = ...
+
+# Create the model.
+scene_predictions, depth_predictions, pose_predictions = CreateMultiTaskModel(images)
+
+# Define the loss functions and get the total loss.
+classification_loss = slim.losses.cross_entropy_loss(scene_predictions, scene_labels)
+sum_of_squares_loss = slim.losses.l2loss(depth_predictions - depth_labels)
+pose_loss = MyCustomLossFunction(pose_predictions, pose_labels)
+tf.add_to_collection(slim.losses.LOSSES_COLLECTION, pose_loss) # Letting TF-Slim know about the additional loss.
+
+# The following two lines have the same effect:
+total_loss1 = classification_loss + sum_of_squares_loss + pose_loss
+total_loss2 = losses.GetTotalLoss()
+```
+
+In this example, we can again either produce the total loss function manually or
+let TF-Slim know about the additional loss and let TF-Slim handle the losses.
+
+## Putting the Pieces Together
+
+By combining TF-Slim Variables, Operations and scopes, we can write what would
+normally be a very complex network in very few lines of code. For example, the entire [VGG](https://www.robots.ox.ac.uk/~vgg/research/very_deep/) architecture can be
+defined with just the following snippet:
+
+```python
+with arg_scope([slim.ops.conv2d, slim.ops.fc], stddev=0.01, weight_decay=0.0005):
+ net = slim.ops.repeat_op(2, inputs, slim.ops.conv2d, 64, [3, 3], scope='conv1')
+ net = slim.ops.max_pool(net, [2, 2], scope='pool1')
+ net = slim.ops.repeat_op(2, net, slim.ops.conv2d, 128, [3, 3], scope='conv2')
+ net = slim.ops.max_pool(net, [2, 2], scope='pool2')
+ net = slim.ops.repeat_op(3, net, slim.ops.conv2d, 256, [3, 3], scope='conv3')
+ net = slim.ops.max_pool(net, [2, 2], scope='pool3')
+ net = slim.ops.repeat_op(3, net, slim.ops.conv2d, 512, [3, 3], scope='conv4')
+ net = slim.ops.max_pool(net, [2, 2], scope='pool4')
+ net = slim.ops.repeat_op(3, net, slim.ops.conv2d, 512, [3, 3], scope='conv5')
+ net = slim.ops.max_pool(net, [2, 2], scope='pool5')
+ net = slim.ops.flatten(net, scope='flatten5')
+ net = slim.ops.fc(net, 4096, scope='fc6')
+ net = slim.ops.dropout(net, 0.5, scope='dropout6')
+ net = slim.ops.fc(net, 4096, scope='fc7')
+ net = slim.ops.dropout(net, 0.5, scope='dropout7')
+ net = slim.ops.fc(net, 1000, activation=None, scope='fc8')
+return net
+```
+
+## Re-using previously defined network architectures and pre-trained models.
+
+### Brief Recap on Restoring Variables from a Checkpoint
+
+After a model has been trained, it can be restored using `tf.train.Saver()`
+which restores `Variables` from a given checkpoint. For many cases,
+`tf.train.Saver()` provides a simple mechanism to restore all or just a few
+variables.
+
+```python
+# Create some variables.
+v1 = tf.Variable(..., name="v1")
+v2 = tf.Variable(..., name="v2")
+...
+# Add ops to restore all the variables.
+restorer = tf.train.Saver()
+
+# Add ops to restore some variables.
+restorer = tf.train.Saver([v1, v2])
+
+# Later, launch the model, use the saver to restore variables from disk, and
+# do some work with the model.
+with tf.Session() as sess:
+ # Restore variables from disk.
+ restorer.restore(sess, "/tmp/model.ckpt")
+ print("Model restored.")
+ # Do some work with the model
+ ...
+```
+
+See [Restoring Variables](https://www.tensorflow.org/versions/r0.7/how_tos/variables/index.html#restoring-variables)
+and [Choosing which Variables to Save and Restore](https://www.tensorflow.org/versions/r0.7/how_tos/variables/index.html#choosing-which-variables-to-save-and-restore)
+sections of the [Variables](https://www.tensorflow.org/versions/r0.7/how_tos/variables/index.html) page for
+more details.
+
+### Using slim.variables to Track which Variables need to be Restored
+
+It is often desirable to fine-tune a pre-trained model on an entirely new
+dataset or even a new task. In these situations, one must specify which layers
+of the model should be reused (and consequently loaded from a checkpoint) and
+which layers are new. Indicating which variables or layers should be restored is
+a process that quickly becomes cumbersome when done manually.
+
+To help keep track of which variables to restore, `slim.variables` provides a
+`restore` argument when creating each Variable. By default, all variables are
+marked as `restore=True`, which results in all variables defined by the model
+being restored.
+
+```python
+# Create some variables.
+v1 = slim.variables.variable(name="v1", ..., restore=False)
+v2 = slim.variables.variable(name="v2", ...) # By default restore=True
+...
+# Get list of variables to restore (which contains only 'v2')
+variables_to_restore = tf.get_collection(slim.variables.VARIABLES_TO_RESTORE)
+restorer = tf.train.Saver(variables_to_restore)
+with tf.Session() as sess:
+ # Restore variables from disk.
+ restorer.restore(sess, "/tmp/model.ckpt")
+ print("Model restored.")
+ # Do some work with the model
+ ...
+```
+
+Additionally, every layer in `slim.ops` that creates slim.variables (such as
+`slim.ops.conv2d`, `slim.ops.fc`, `slim.ops.batch_norm`) also has a `restore`
+argument which controls whether the variables created by that layer should be
+restored or not.
+
+```python
+# Create a small network.
+net = slim.ops.conv2d(images, 32, [7, 7], stride=2, scope='conv1')
+net = slim.ops.conv2d(net, 64, [3, 3], scope='conv2')
+net = slim.ops.conv2d(net, 128, [3, 3], scope='conv3')
+net = slim.ops.max_pool(net, [3, 3], stride=2, scope='pool3')
+net = slim.ops.flatten(net)
+net = slim.ops.fc(net, 10, scope='logits', restore=False)
+...
+
+# VARIABLES_TO_RESTORE would contain the 'weights' and 'bias' defined by 'conv1'
+# 'conv2' and 'conv3' but not the ones defined by 'logits'
+variables_to_restore = tf.get_collection(slim.variables.VARIABLES_TO_RESTORE)
+
+# Create a restorer that would restore only the needed variables.
+restorer = tf.train.Saver(variables_to_restore)
+
+# Create a saver that would save all the variables (including 'logits').
+saver = tf.train.Saver()
+with tf.Session() as sess:
+ # Restore variables from disk.
+ restorer.restore(sess, "/tmp/model.ckpt")
+ print("Model restored.")
+
+ # Do some work with the model
+ ...
+ saver.save(sess, "/tmp/new_model.ckpt")
+```
+
+Note: When restoring variables from a checkpoint, the `Saver` locates the
+variable names in a checkpoint file and maps them to variables in the current
+graph. Above, we created a saver by passing to it a list of variables. In this
+case, the names of the variables to locate in the checkpoint file were
+implicitly obtained from each provided variable's `var.op.name`.
+
+This works well when the variable names in the checkpoint file match those in
+the graph. However, sometimes, we want to restore a model from a checkpoint
+whose variables have different names than those in the current graph. In this case,
+we must provide the `Saver` a dictionary that maps from each checkpoint variable
+name to each graph variable. Consider the following example where the checkpoint
+variables names are obtained via a simple function:
+
+```python
+# Assuming that 'conv1/weights' should be restored from 'vgg16/conv1/weights'
+def name_in_checkpoint(var):
+ return 'vgg16/' + var.op.name
+
+# Assuming that 'conv1/weights' and 'conv1/bias' should be restored from 'conv1/params1' and 'conv1/params2'
+def name_in_checkpoint(var):
+ if "weights" in var.op.name:
+ return var.op.name.replace("weights", "params1")
+ if "bias" in var.op.name:
+ return var.op.name.replace("bias", "params2")
+
+variables_to_restore = tf.get_collection(slim.variables.VARIABLES_TO_RESTORE)
+variables_to_restore = {name_in_checkpoint(var):var for var in variables_to_restore}
+restorer = tf.train.Saver(variables_to_restore)
+with tf.Session() as sess:
+ # Restore variables from disk.
+ restorer.restore(sess, "/tmp/model.ckpt")
+```
+
+### Reusing the VGG16 network defined in TF-Slim on a different task, e.g. PASCAL-VOC.
+
+Assuming one already has a pre-trained VGG16 model, one just needs to replace
+the last layer `fc8` with a new layer `fc8_pascal` and use `restore=False`.
+
+```python
+def vgg16_pascal(inputs):
+ with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], stddev=0.01, weight_decay=0.0005):
+ net = slim.ops.repeat_op(2, inputs, slim.ops.conv2d, 64, [3, 3], scope='conv1')
+ net = slim.ops.max_pool(net, [2, 2], scope='pool1')
+ net = slim.ops.repeat_op(2, net, slim.ops.conv2d, 128, [3, 3], scope='conv2')
+ net = slim.ops.max_pool(net, [2, 2], scope='pool2')
+ net = slim.ops.repeat_op(3, net, slim.ops.conv2d, 256, [3, 3], scope='conv3')
+ net = slim.ops.max_pool(net, [2, 2], scope='pool3')
+ net = slim.ops.repeat_op(3, net, slim.ops.conv2d, 512, [3, 3], scope='conv4')
+ net = slim.ops.max_pool(net, [2, 2], scope='pool4')
+ net = slim.ops.repeat_op(3, net, slim.ops.conv2d, 512, [3, 3], scope='conv5')
+ net = slim.ops.max_pool(net, [2, 2], scope='pool5')
+ net = slim.ops.flatten(net, scope='flatten5')
+ net = slim.ops.fc(net, 4096, scope='fc6')
+ net = slim.ops.dropout(net, 0.5, scope='dropout6')
+ net = slim.ops.fc(net, 4096, scope='fc7')
+ net = slim.ops.dropout(net, 0.5, scope='dropout7')
+ # To reuse vgg16 on PASCAL-VOC, just change the last layer.
+ net = slim.ops.fc(net, 21, activation=None, scope='fc8_pascal', restore=False)
+ return net
+```
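+
+A sketch of the corresponding fine-tuning setup, reusing the restore machinery
+shown earlier (the checkpoint path is a placeholder):
+
+```python
+images, labels = ...
+predictions = vgg16_pascal(images)
+
+# Everything except 'fc8_pascal' (created above with restore=False) is restored.
+variables_to_restore = tf.get_collection(slim.variables.VARIABLES_TO_RESTORE)
+restorer = tf.train.Saver(variables_to_restore)
+
+with tf.Session() as sess:
+  restorer.restore(sess, "/tmp/model.ckpt")
+  # Train as usual; 'fc8_pascal' starts from its random initialization.
+```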
+
+## Authors
+
+Sergio Guadarrama and Nathan Silberman
diff --git a/models/research/inception/inception/slim/collections_test.py b/models/research/inception/inception/slim/collections_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a1f170edaaedae337df8e0b552a03dd82b263d4
--- /dev/null
+++ b/models/research/inception/inception/slim/collections_test.py
@@ -0,0 +1,181 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for inception."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from inception.slim import slim
+
+
+def get_variables(scope=None):
+ return slim.variables.get_variables(scope)
+
+
+def get_variables_by_name(name):
+ return slim.variables.get_variables_by_name(name)
+
+
+class CollectionsTest(tf.test.TestCase):
+
+ def testVariables(self):
+ batch_size = 5
+ height, width = 299, 299
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ with slim.arg_scope([slim.ops.conv2d],
+ batch_norm_params={'decay': 0.9997}):
+ slim.inception.inception_v3(inputs)
+ self.assertEqual(len(get_variables()), 388)
+ self.assertEqual(len(get_variables_by_name('weights')), 98)
+ self.assertEqual(len(get_variables_by_name('biases')), 2)
+ self.assertEqual(len(get_variables_by_name('beta')), 96)
+ self.assertEqual(len(get_variables_by_name('gamma')), 0)
+ self.assertEqual(len(get_variables_by_name('moving_mean')), 96)
+ self.assertEqual(len(get_variables_by_name('moving_variance')), 96)
+
+ def testVariablesWithoutBatchNorm(self):
+ batch_size = 5
+ height, width = 299, 299
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ with slim.arg_scope([slim.ops.conv2d],
+ batch_norm_params=None):
+ slim.inception.inception_v3(inputs)
+ self.assertEqual(len(get_variables()), 196)
+ self.assertEqual(len(get_variables_by_name('weights')), 98)
+ self.assertEqual(len(get_variables_by_name('biases')), 98)
+ self.assertEqual(len(get_variables_by_name('beta')), 0)
+ self.assertEqual(len(get_variables_by_name('gamma')), 0)
+ self.assertEqual(len(get_variables_by_name('moving_mean')), 0)
+ self.assertEqual(len(get_variables_by_name('moving_variance')), 0)
+
+ def testVariablesByLayer(self):
+ batch_size = 5
+ height, width = 299, 299
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ with slim.arg_scope([slim.ops.conv2d],
+ batch_norm_params={'decay': 0.9997}):
+ slim.inception.inception_v3(inputs)
+ self.assertEqual(len(get_variables()), 388)
+ self.assertEqual(len(get_variables('conv0')), 4)
+ self.assertEqual(len(get_variables('conv1')), 4)
+ self.assertEqual(len(get_variables('conv2')), 4)
+ self.assertEqual(len(get_variables('conv3')), 4)
+ self.assertEqual(len(get_variables('conv4')), 4)
+ self.assertEqual(len(get_variables('mixed_35x35x256a')), 28)
+ self.assertEqual(len(get_variables('mixed_35x35x288a')), 28)
+ self.assertEqual(len(get_variables('mixed_35x35x288b')), 28)
+ self.assertEqual(len(get_variables('mixed_17x17x768a')), 16)
+ self.assertEqual(len(get_variables('mixed_17x17x768b')), 40)
+ self.assertEqual(len(get_variables('mixed_17x17x768c')), 40)
+ self.assertEqual(len(get_variables('mixed_17x17x768d')), 40)
+ self.assertEqual(len(get_variables('mixed_17x17x768e')), 40)
+ self.assertEqual(len(get_variables('mixed_8x8x2048a')), 36)
+ self.assertEqual(len(get_variables('mixed_8x8x2048b')), 36)
+ self.assertEqual(len(get_variables('logits')), 2)
+ self.assertEqual(len(get_variables('aux_logits')), 10)
+
+ def testVariablesToRestore(self):
+ batch_size = 5
+ height, width = 299, 299
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ with slim.arg_scope([slim.ops.conv2d],
+ batch_norm_params={'decay': 0.9997}):
+ slim.inception.inception_v3(inputs)
+ variables_to_restore = tf.get_collection(
+ slim.variables.VARIABLES_TO_RESTORE)
+ self.assertEqual(len(variables_to_restore), 388)
+ self.assertListEqual(variables_to_restore, get_variables())
+
+ def testVariablesToRestoreWithoutLogits(self):
+ batch_size = 5
+ height, width = 299, 299
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ with slim.arg_scope([slim.ops.conv2d],
+ batch_norm_params={'decay': 0.9997}):
+ slim.inception.inception_v3(inputs, restore_logits=False)
+ variables_to_restore = tf.get_collection(
+ slim.variables.VARIABLES_TO_RESTORE)
+ self.assertEqual(len(variables_to_restore), 384)
+
+ def testRegularizationLosses(self):
+ batch_size = 5
+ height, width = 299, 299
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], weight_decay=0.00004):
+ slim.inception.inception_v3(inputs)
+ losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+ self.assertEqual(len(losses), len(get_variables_by_name('weights')))
+
+ def testTotalLossWithoutRegularization(self):
+ batch_size = 5
+ height, width = 299, 299
+ num_classes = 1001
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ dense_labels = tf.random_uniform((batch_size, num_classes))
+ with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], weight_decay=0):
+ logits, end_points = slim.inception.inception_v3(
+ inputs,
+ num_classes=num_classes)
+ # Cross entropy loss for the main softmax prediction.
+ slim.losses.cross_entropy_loss(logits,
+ dense_labels,
+ label_smoothing=0.1,
+ weight=1.0)
+ # Cross entropy loss for the auxiliary softmax head.
+ slim.losses.cross_entropy_loss(end_points['aux_logits'],
+ dense_labels,
+ label_smoothing=0.1,
+ weight=0.4,
+ scope='aux_loss')
+ losses = tf.get_collection(slim.losses.LOSSES_COLLECTION)
+ self.assertEqual(len(losses), 2)
+
+ def testTotalLossWithRegularization(self):
+ batch_size = 5
+ height, width = 299, 299
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ dense_labels = tf.random_uniform((batch_size, num_classes))
+ with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], weight_decay=0.00004):
+ logits, end_points = slim.inception.inception_v3(inputs, num_classes)
+ # Cross entropy loss for the main softmax prediction.
+ slim.losses.cross_entropy_loss(logits,
+ dense_labels,
+ label_smoothing=0.1,
+ weight=1.0)
+ # Cross entropy loss for the auxiliary softmax head.
+ slim.losses.cross_entropy_loss(end_points['aux_logits'],
+ dense_labels,
+ label_smoothing=0.1,
+ weight=0.4,
+ scope='aux_loss')
+ losses = tf.get_collection(slim.losses.LOSSES_COLLECTION)
+ self.assertEqual(len(losses), 2)
+ reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+ self.assertEqual(len(reg_losses), 98)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/inception/inception/slim/inception_model.py b/models/research/inception/inception/slim/inception_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..6136ab1ba68716f4f135110a4d5c518b732b23df
--- /dev/null
+++ b/models/research/inception/inception/slim/inception_model.py
@@ -0,0 +1,356 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Inception-v3 expressed in TensorFlow-Slim.
+
+ Usage:
+
+ # Parameters for BatchNorm.
+ batch_norm_params = {
+ # Decay for the batch_norm moving averages.
+ 'decay': BATCHNORM_MOVING_AVERAGE_DECAY,
+ # epsilon to prevent 0s in variance.
+ 'epsilon': 0.001,
+ }
+ # Set weight_decay for weights in Conv and FC layers.
+ with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], weight_decay=0.00004):
+ with slim.arg_scope([slim.ops.conv2d],
+ stddev=0.1,
+ activation=tf.nn.relu,
+ batch_norm_params=batch_norm_params):
+ # Force all Variables to reside on the CPU.
+ with slim.arg_scope([slim.variables.variable], device='/cpu:0'):
+ logits, endpoints = slim.inception.inception_v3(
+ images,
+ dropout_keep_prob=0.8,
+ num_classes=num_classes,
+ is_training=for_training,
+ restore_logits=restore_logits,
+ scope=scope)
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from inception.slim import ops
+from inception.slim import scopes
+
+
+def inception_v3(inputs,
+ dropout_keep_prob=0.8,
+ num_classes=1000,
+ is_training=True,
+ restore_logits=True,
+ scope=''):
+ """Latest Inception from http://arxiv.org/abs/1512.00567.
+
+ "Rethinking the Inception Architecture for Computer Vision"
+
+ Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens,
+ Zbigniew Wojna
+
+ Args:
+ inputs: a tensor of size [batch_size, height, width, channels].
+ dropout_keep_prob: dropout keep_prob.
+ num_classes: number of predicted classes.
+ is_training: whether is training or not.
+ restore_logits: whether or not the logits layers should be restored.
+ Useful for fine-tuning a model with different num_classes.
+ scope: Optional scope for name_scope.
+
+ Returns:
+    the 'logits' Tensor and a dictionary of end_points (including 'aux_logits').
+ """
+ # end_points will collect relevant activations for external use, for example
+ # summaries or losses.
+ end_points = {}
+ with tf.name_scope(scope, 'inception_v3', [inputs]):
+ with scopes.arg_scope([ops.conv2d, ops.fc, ops.batch_norm, ops.dropout],
+ is_training=is_training):
+ with scopes.arg_scope([ops.conv2d, ops.max_pool, ops.avg_pool],
+ stride=1, padding='VALID'):
+ # 299 x 299 x 3
+ end_points['conv0'] = ops.conv2d(inputs, 32, [3, 3], stride=2,
+ scope='conv0')
+ # 149 x 149 x 32
+ end_points['conv1'] = ops.conv2d(end_points['conv0'], 32, [3, 3],
+ scope='conv1')
+ # 147 x 147 x 32
+ end_points['conv2'] = ops.conv2d(end_points['conv1'], 64, [3, 3],
+ padding='SAME', scope='conv2')
+ # 147 x 147 x 64
+ end_points['pool1'] = ops.max_pool(end_points['conv2'], [3, 3],
+ stride=2, scope='pool1')
+ # 73 x 73 x 64
+ end_points['conv3'] = ops.conv2d(end_points['pool1'], 80, [1, 1],
+ scope='conv3')
+ # 73 x 73 x 80.
+ end_points['conv4'] = ops.conv2d(end_points['conv3'], 192, [3, 3],
+ scope='conv4')
+ # 71 x 71 x 192.
+ end_points['pool2'] = ops.max_pool(end_points['conv4'], [3, 3],
+ stride=2, scope='pool2')
+ # 35 x 35 x 192.
+ net = end_points['pool2']
+ # Inception blocks
+ with scopes.arg_scope([ops.conv2d, ops.max_pool, ops.avg_pool],
+ stride=1, padding='SAME'):
+ # mixed: 35 x 35 x 256.
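+          # (output depth: 64 + 64 + 96 + 32 = 256 across the four branches below)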
+ with tf.variable_scope('mixed_35x35x256a'):
+ with tf.variable_scope('branch1x1'):
+ branch1x1 = ops.conv2d(net, 64, [1, 1])
+ with tf.variable_scope('branch5x5'):
+ branch5x5 = ops.conv2d(net, 48, [1, 1])
+ branch5x5 = ops.conv2d(branch5x5, 64, [5, 5])
+ with tf.variable_scope('branch3x3dbl'):
+ branch3x3dbl = ops.conv2d(net, 64, [1, 1])
+ branch3x3dbl = ops.conv2d(branch3x3dbl, 96, [3, 3])
+ branch3x3dbl = ops.conv2d(branch3x3dbl, 96, [3, 3])
+ with tf.variable_scope('branch_pool'):
+ branch_pool = ops.avg_pool(net, [3, 3])
+ branch_pool = ops.conv2d(branch_pool, 32, [1, 1])
+ net = tf.concat(axis=3, values=[branch1x1, branch5x5, branch3x3dbl, branch_pool])
+ end_points['mixed_35x35x256a'] = net
+ # mixed_1: 35 x 35 x 288.
+ with tf.variable_scope('mixed_35x35x288a'):
+ with tf.variable_scope('branch1x1'):
+ branch1x1 = ops.conv2d(net, 64, [1, 1])
+ with tf.variable_scope('branch5x5'):
+ branch5x5 = ops.conv2d(net, 48, [1, 1])
+ branch5x5 = ops.conv2d(branch5x5, 64, [5, 5])
+ with tf.variable_scope('branch3x3dbl'):
+ branch3x3dbl = ops.conv2d(net, 64, [1, 1])
+ branch3x3dbl = ops.conv2d(branch3x3dbl, 96, [3, 3])
+ branch3x3dbl = ops.conv2d(branch3x3dbl, 96, [3, 3])
+ with tf.variable_scope('branch_pool'):
+ branch_pool = ops.avg_pool(net, [3, 3])
+ branch_pool = ops.conv2d(branch_pool, 64, [1, 1])
+ net = tf.concat(axis=3, values=[branch1x1, branch5x5, branch3x3dbl, branch_pool])
+ end_points['mixed_35x35x288a'] = net
+ # mixed_2: 35 x 35 x 288.
+ with tf.variable_scope('mixed_35x35x288b'):
+ with tf.variable_scope('branch1x1'):
+ branch1x1 = ops.conv2d(net, 64, [1, 1])
+ with tf.variable_scope('branch5x5'):
+ branch5x5 = ops.conv2d(net, 48, [1, 1])
+ branch5x5 = ops.conv2d(branch5x5, 64, [5, 5])
+ with tf.variable_scope('branch3x3dbl'):
+ branch3x3dbl = ops.conv2d(net, 64, [1, 1])
+ branch3x3dbl = ops.conv2d(branch3x3dbl, 96, [3, 3])
+ branch3x3dbl = ops.conv2d(branch3x3dbl, 96, [3, 3])
+ with tf.variable_scope('branch_pool'):
+ branch_pool = ops.avg_pool(net, [3, 3])
+ branch_pool = ops.conv2d(branch_pool, 64, [1, 1])
+ net = tf.concat(axis=3, values=[branch1x1, branch5x5, branch3x3dbl, branch_pool])
+ end_points['mixed_35x35x288b'] = net
+ # mixed_3: 17 x 17 x 768.
+ with tf.variable_scope('mixed_17x17x768a'):
+ with tf.variable_scope('branch3x3'):
+ branch3x3 = ops.conv2d(net, 384, [3, 3], stride=2, padding='VALID')
+ with tf.variable_scope('branch3x3dbl'):
+ branch3x3dbl = ops.conv2d(net, 64, [1, 1])
+ branch3x3dbl = ops.conv2d(branch3x3dbl, 96, [3, 3])
+ branch3x3dbl = ops.conv2d(branch3x3dbl, 96, [3, 3],
+ stride=2, padding='VALID')
+ with tf.variable_scope('branch_pool'):
+ branch_pool = ops.max_pool(net, [3, 3], stride=2, padding='VALID')
+ net = tf.concat(axis=3, values=[branch3x3, branch3x3dbl, branch_pool])
+ end_points['mixed_17x17x768a'] = net
+ # mixed4: 17 x 17 x 768.
+ with tf.variable_scope('mixed_17x17x768b'):
+ with tf.variable_scope('branch1x1'):
+ branch1x1 = ops.conv2d(net, 192, [1, 1])
+ with tf.variable_scope('branch7x7'):
+ branch7x7 = ops.conv2d(net, 128, [1, 1])
+ branch7x7 = ops.conv2d(branch7x7, 128, [1, 7])
+ branch7x7 = ops.conv2d(branch7x7, 192, [7, 1])
+ with tf.variable_scope('branch7x7dbl'):
+ branch7x7dbl = ops.conv2d(net, 128, [1, 1])
+ branch7x7dbl = ops.conv2d(branch7x7dbl, 128, [7, 1])
+ branch7x7dbl = ops.conv2d(branch7x7dbl, 128, [1, 7])
+ branch7x7dbl = ops.conv2d(branch7x7dbl, 128, [7, 1])
+ branch7x7dbl = ops.conv2d(branch7x7dbl, 192, [1, 7])
+ with tf.variable_scope('branch_pool'):
+ branch_pool = ops.avg_pool(net, [3, 3])
+ branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
+ net = tf.concat(axis=3, values=[branch1x1, branch7x7, branch7x7dbl, branch_pool])
+ end_points['mixed_17x17x768b'] = net
+ # mixed_5: 17 x 17 x 768.
+ with tf.variable_scope('mixed_17x17x768c'):
+ with tf.variable_scope('branch1x1'):
+ branch1x1 = ops.conv2d(net, 192, [1, 1])
+ with tf.variable_scope('branch7x7'):
+ branch7x7 = ops.conv2d(net, 160, [1, 1])
+ branch7x7 = ops.conv2d(branch7x7, 160, [1, 7])
+ branch7x7 = ops.conv2d(branch7x7, 192, [7, 1])
+ with tf.variable_scope('branch7x7dbl'):
+ branch7x7dbl = ops.conv2d(net, 160, [1, 1])
+ branch7x7dbl = ops.conv2d(branch7x7dbl, 160, [7, 1])
+ branch7x7dbl = ops.conv2d(branch7x7dbl, 160, [1, 7])
+ branch7x7dbl = ops.conv2d(branch7x7dbl, 160, [7, 1])
+ branch7x7dbl = ops.conv2d(branch7x7dbl, 192, [1, 7])
+ with tf.variable_scope('branch_pool'):
+ branch_pool = ops.avg_pool(net, [3, 3])
+ branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
+ net = tf.concat(axis=3, values=[branch1x1, branch7x7, branch7x7dbl, branch_pool])
+ end_points['mixed_17x17x768c'] = net
+ # mixed_6: 17 x 17 x 768.
+ with tf.variable_scope('mixed_17x17x768d'):
+ with tf.variable_scope('branch1x1'):
+ branch1x1 = ops.conv2d(net, 192, [1, 1])
+ with tf.variable_scope('branch7x7'):
+ branch7x7 = ops.conv2d(net, 160, [1, 1])
+ branch7x7 = ops.conv2d(branch7x7, 160, [1, 7])
+ branch7x7 = ops.conv2d(branch7x7, 192, [7, 1])
+ with tf.variable_scope('branch7x7dbl'):
+ branch7x7dbl = ops.conv2d(net, 160, [1, 1])
+ branch7x7dbl = ops.conv2d(branch7x7dbl, 160, [7, 1])
+ branch7x7dbl = ops.conv2d(branch7x7dbl, 160, [1, 7])
+ branch7x7dbl = ops.conv2d(branch7x7dbl, 160, [7, 1])
+ branch7x7dbl = ops.conv2d(branch7x7dbl, 192, [1, 7])
+ with tf.variable_scope('branch_pool'):
+ branch_pool = ops.avg_pool(net, [3, 3])
+ branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
+ net = tf.concat(axis=3, values=[branch1x1, branch7x7, branch7x7dbl, branch_pool])
+ end_points['mixed_17x17x768d'] = net
+ # mixed_7: 17 x 17 x 768.
+ with tf.variable_scope('mixed_17x17x768e'):
+ with tf.variable_scope('branch1x1'):
+ branch1x1 = ops.conv2d(net, 192, [1, 1])
+ with tf.variable_scope('branch7x7'):
+ branch7x7 = ops.conv2d(net, 192, [1, 1])
+ branch7x7 = ops.conv2d(branch7x7, 192, [1, 7])
+ branch7x7 = ops.conv2d(branch7x7, 192, [7, 1])
+ with tf.variable_scope('branch7x7dbl'):
+ branch7x7dbl = ops.conv2d(net, 192, [1, 1])
+ branch7x7dbl = ops.conv2d(branch7x7dbl, 192, [7, 1])
+ branch7x7dbl = ops.conv2d(branch7x7dbl, 192, [1, 7])
+ branch7x7dbl = ops.conv2d(branch7x7dbl, 192, [7, 1])
+ branch7x7dbl = ops.conv2d(branch7x7dbl, 192, [1, 7])
+ with tf.variable_scope('branch_pool'):
+ branch_pool = ops.avg_pool(net, [3, 3])
+ branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
+ net = tf.concat(axis=3, values=[branch1x1, branch7x7, branch7x7dbl, branch_pool])
+ end_points['mixed_17x17x768e'] = net
+ # Auxiliary Head logits
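+        # This head provides an extra, down-weighted classification loss during
+        # training (see the weight=0.4 aux loss in the accompanying tests); it is
+        # not used for the final predictions.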
+ aux_logits = tf.identity(end_points['mixed_17x17x768e'])
+ with tf.variable_scope('aux_logits'):
+ aux_logits = ops.avg_pool(aux_logits, [5, 5], stride=3,
+ padding='VALID')
+ aux_logits = ops.conv2d(aux_logits, 128, [1, 1], scope='proj')
+ # Shape of feature map before the final layer.
+ shape = aux_logits.get_shape()
+ aux_logits = ops.conv2d(aux_logits, 768, shape[1:3], stddev=0.01,
+ padding='VALID')
+ aux_logits = ops.flatten(aux_logits)
+ aux_logits = ops.fc(aux_logits, num_classes, activation=None,
+ stddev=0.001, restore=restore_logits)
+ end_points['aux_logits'] = aux_logits
+ # mixed_8: 8 x 8 x 1280.
+          # Note that the scope below is left unchanged so as not to invalidate
+          # previously trained checkpoints.
+          # TODO: fix the scope when appropriate.
+ with tf.variable_scope('mixed_17x17x1280a'):
+ with tf.variable_scope('branch3x3'):
+ branch3x3 = ops.conv2d(net, 192, [1, 1])
+ branch3x3 = ops.conv2d(branch3x3, 320, [3, 3], stride=2,
+ padding='VALID')
+ with tf.variable_scope('branch7x7x3'):
+ branch7x7x3 = ops.conv2d(net, 192, [1, 1])
+ branch7x7x3 = ops.conv2d(branch7x7x3, 192, [1, 7])
+ branch7x7x3 = ops.conv2d(branch7x7x3, 192, [7, 1])
+ branch7x7x3 = ops.conv2d(branch7x7x3, 192, [3, 3],
+ stride=2, padding='VALID')
+ with tf.variable_scope('branch_pool'):
+ branch_pool = ops.max_pool(net, [3, 3], stride=2, padding='VALID')
+ net = tf.concat(axis=3, values=[branch3x3, branch7x7x3, branch_pool])
+ end_points['mixed_17x17x1280a'] = net
+ # mixed_9: 8 x 8 x 2048.
+ with tf.variable_scope('mixed_8x8x2048a'):
+ with tf.variable_scope('branch1x1'):
+ branch1x1 = ops.conv2d(net, 320, [1, 1])
+ with tf.variable_scope('branch3x3'):
+ branch3x3 = ops.conv2d(net, 384, [1, 1])
+ branch3x3 = tf.concat(axis=3, values=[ops.conv2d(branch3x3, 384, [1, 3]),
+ ops.conv2d(branch3x3, 384, [3, 1])])
+ with tf.variable_scope('branch3x3dbl'):
+ branch3x3dbl = ops.conv2d(net, 448, [1, 1])
+ branch3x3dbl = ops.conv2d(branch3x3dbl, 384, [3, 3])
+ branch3x3dbl = tf.concat(axis=3, values=[ops.conv2d(branch3x3dbl, 384, [1, 3]),
+ ops.conv2d(branch3x3dbl, 384, [3, 1])])
+ with tf.variable_scope('branch_pool'):
+ branch_pool = ops.avg_pool(net, [3, 3])
+ branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
+ net = tf.concat(axis=3, values=[branch1x1, branch3x3, branch3x3dbl, branch_pool])
+ end_points['mixed_8x8x2048a'] = net
+ # mixed_10: 8 x 8 x 2048.
+ with tf.variable_scope('mixed_8x8x2048b'):
+ with tf.variable_scope('branch1x1'):
+ branch1x1 = ops.conv2d(net, 320, [1, 1])
+ with tf.variable_scope('branch3x3'):
+ branch3x3 = ops.conv2d(net, 384, [1, 1])
+ branch3x3 = tf.concat(axis=3, values=[ops.conv2d(branch3x3, 384, [1, 3]),
+ ops.conv2d(branch3x3, 384, [3, 1])])
+ with tf.variable_scope('branch3x3dbl'):
+ branch3x3dbl = ops.conv2d(net, 448, [1, 1])
+ branch3x3dbl = ops.conv2d(branch3x3dbl, 384, [3, 3])
+ branch3x3dbl = tf.concat(axis=3, values=[ops.conv2d(branch3x3dbl, 384, [1, 3]),
+ ops.conv2d(branch3x3dbl, 384, [3, 1])])
+ with tf.variable_scope('branch_pool'):
+ branch_pool = ops.avg_pool(net, [3, 3])
+ branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
+ net = tf.concat(axis=3, values=[branch1x1, branch3x3, branch3x3dbl, branch_pool])
+ end_points['mixed_8x8x2048b'] = net
+ # Final pooling and prediction
+ with tf.variable_scope('logits'):
+ shape = net.get_shape()
+ net = ops.avg_pool(net, shape[1:3], padding='VALID', scope='pool')
+ # 1 x 1 x 2048
+ net = ops.dropout(net, dropout_keep_prob, scope='dropout')
+ net = ops.flatten(net, scope='flatten')
+ # 2048
+ logits = ops.fc(net, num_classes, activation=None, scope='logits',
+ restore=restore_logits)
+ # 1000
+ end_points['logits'] = logits
+ end_points['predictions'] = tf.nn.softmax(logits, name='predictions')
+ return logits, end_points
+
+
+def inception_v3_parameters(weight_decay=0.00004, stddev=0.1,
+ batch_norm_decay=0.9997, batch_norm_epsilon=0.001):
+ """Yields the scope with the default parameters for inception_v3.
+
+ Args:
+ weight_decay: the weight decay for weights variables.
+    stddev: standard deviation of the truncated gaussian weight distribution.
+ batch_norm_decay: decay for the moving average of batch_norm momentums.
+ batch_norm_epsilon: small float added to variance to avoid dividing by zero.
+
+ Yields:
+    an arg_scope with the parameters needed for inception_v3.
+ """
+ # Set weight_decay for weights in Conv and FC layers.
+ with scopes.arg_scope([ops.conv2d, ops.fc],
+ weight_decay=weight_decay):
+ # Set stddev, activation and parameters for batch_norm.
+ with scopes.arg_scope([ops.conv2d],
+ stddev=stddev,
+ activation=tf.nn.relu,
+ batch_norm_params={
+ 'decay': batch_norm_decay,
+ 'epsilon': batch_norm_epsilon}) as arg_scope:
+ yield arg_scope
diff --git a/models/research/inception/inception/slim/inception_test.py b/models/research/inception/inception/slim/inception_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..231dea298f4b761aa90224df1c263873bc890ac5
--- /dev/null
+++ b/models/research/inception/inception/slim/inception_test.py
@@ -0,0 +1,134 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slim.inception."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from inception.slim import inception_model as inception
+
+
+class InceptionTest(tf.test.TestCase):
+
+ def testBuildLogits(self):
+ batch_size = 5
+ height, width = 299, 299
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ logits, _ = inception.inception_v3(inputs, num_classes)
+ self.assertTrue(logits.op.name.startswith('logits'))
+ self.assertListEqual(logits.get_shape().as_list(),
+ [batch_size, num_classes])
+
+ def testBuildEndPoints(self):
+ batch_size = 5
+ height, width = 299, 299
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ _, end_points = inception.inception_v3(inputs, num_classes)
+ self.assertTrue('logits' in end_points)
+ logits = end_points['logits']
+ self.assertListEqual(logits.get_shape().as_list(),
+ [batch_size, num_classes])
+ self.assertTrue('aux_logits' in end_points)
+ aux_logits = end_points['aux_logits']
+ self.assertListEqual(aux_logits.get_shape().as_list(),
+ [batch_size, num_classes])
+ pre_pool = end_points['mixed_8x8x2048b']
+ self.assertListEqual(pre_pool.get_shape().as_list(),
+ [batch_size, 8, 8, 2048])
+
+ def testVariablesSetDevice(self):
+ batch_size = 5
+ height, width = 299, 299
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ # Force all Variables to reside on the device.
+ with tf.variable_scope('on_cpu'), tf.device('/cpu:0'):
+ inception.inception_v3(inputs, num_classes)
+ with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
+ inception.inception_v3(inputs, num_classes)
+ for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_cpu'):
+ self.assertDeviceEqual(v.device, '/cpu:0')
+ for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_gpu'):
+ self.assertDeviceEqual(v.device, '/gpu:0')
+
+ def testHalfSizeImages(self):
+ batch_size = 5
+ height, width = 150, 150
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ logits, end_points = inception.inception_v3(inputs, num_classes)
+ self.assertTrue(logits.op.name.startswith('logits'))
+ self.assertListEqual(logits.get_shape().as_list(),
+ [batch_size, num_classes])
+ pre_pool = end_points['mixed_8x8x2048b']
+ self.assertListEqual(pre_pool.get_shape().as_list(),
+ [batch_size, 3, 3, 2048])
+
+  def testUnknownBatchSize(self):
+ batch_size = 1
+ height, width = 299, 299
+ num_classes = 1000
+ with self.test_session() as sess:
+ inputs = tf.placeholder(tf.float32, (None, height, width, 3))
+ logits, _ = inception.inception_v3(inputs, num_classes)
+ self.assertTrue(logits.op.name.startswith('logits'))
+ self.assertListEqual(logits.get_shape().as_list(),
+ [None, num_classes])
+ images = tf.random_uniform((batch_size, height, width, 3))
+ sess.run(tf.global_variables_initializer())
+ output = sess.run(logits, {inputs: images.eval()})
+ self.assertEquals(output.shape, (batch_size, num_classes))
+
+ def testEvaluation(self):
+ batch_size = 2
+ height, width = 299, 299
+ num_classes = 1000
+ with self.test_session() as sess:
+ eval_inputs = tf.random_uniform((batch_size, height, width, 3))
+ logits, _ = inception.inception_v3(eval_inputs, num_classes,
+ is_training=False)
+ predictions = tf.argmax(logits, 1)
+ sess.run(tf.global_variables_initializer())
+ output = sess.run(predictions)
+ self.assertEquals(output.shape, (batch_size,))
+
+ def testTrainEvalWithReuse(self):
+ train_batch_size = 5
+ eval_batch_size = 2
+ height, width = 150, 150
+ num_classes = 1000
+ with self.test_session() as sess:
+ train_inputs = tf.random_uniform((train_batch_size, height, width, 3))
+ inception.inception_v3(train_inputs, num_classes)
+ tf.get_variable_scope().reuse_variables()
+ eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3))
+ logits, _ = inception.inception_v3(eval_inputs, num_classes,
+ is_training=False)
+ predictions = tf.argmax(logits, 1)
+ sess.run(tf.global_variables_initializer())
+ output = sess.run(predictions)
+ self.assertEquals(output.shape, (eval_batch_size,))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/inception/inception/slim/losses.py b/models/research/inception/inception/slim/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..78298d092fab3afc264e427fb060602c27ea97b0
--- /dev/null
+++ b/models/research/inception/inception/slim/losses.py
@@ -0,0 +1,174 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains convenience wrappers for various Neural Network TensorFlow losses.
+
+ All the losses defined here add themselves to the LOSSES_COLLECTION
+ collection.
+
+  l1_loss: Define an L1 loss, useful for regularization, i.e. lasso.
+  l2_loss: Define an L2 loss, useful for regularization, i.e. weight decay.
+ cross_entropy_loss: Define a cross entropy loss using
+ softmax_cross_entropy_with_logits. Useful for classification.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+# In order to gather all losses in a network, the user should use this
+# key for get_collection, i.e.:
+# losses = tf.get_collection(slim.losses.LOSSES_COLLECTION)
+LOSSES_COLLECTION = '_losses'
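+
+# A typical total training loss (a sketch, not enforced by this module) sums the
+# entries of this collection together with any regularization losses, e.g.:
+#   total_loss = tf.add_n(tf.get_collection(LOSSES_COLLECTION) +
+#                         tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))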
+
+
+def l1_regularizer(weight=1.0, scope=None):
+ """Define a L1 regularizer.
+
+ Args:
+ weight: scale the loss by this factor.
+ scope: Optional scope for name_scope.
+
+ Returns:
+ a regularizer function.
+ """
+ def regularizer(tensor):
+ with tf.name_scope(scope, 'L1Regularizer', [tensor]):
+ l1_weight = tf.convert_to_tensor(weight,
+ dtype=tensor.dtype.base_dtype,
+ name='weight')
+ return tf.multiply(l1_weight, tf.reduce_sum(tf.abs(tensor)), name='value')
+ return regularizer
+
+
+def l2_regularizer(weight=1.0, scope=None):
+ """Define a L2 regularizer.
+
+ Args:
+ weight: scale the loss by this factor.
+ scope: Optional scope for name_scope.
+
+ Returns:
+ a regularizer function.
+ """
+ def regularizer(tensor):
+ with tf.name_scope(scope, 'L2Regularizer', [tensor]):
+ l2_weight = tf.convert_to_tensor(weight,
+ dtype=tensor.dtype.base_dtype,
+ name='weight')
+ return tf.multiply(l2_weight, tf.nn.l2_loss(tensor), name='value')
+ return regularizer
+
+
+def l1_l2_regularizer(weight_l1=1.0, weight_l2=1.0, scope=None):
+ """Define a L1L2 regularizer.
+
+ Args:
+ weight_l1: scale the L1 loss by this factor.
+ weight_l2: scale the L2 loss by this factor.
+ scope: Optional scope for name_scope.
+
+ Returns:
+ a regularizer function.
+ """
+ def regularizer(tensor):
+ with tf.name_scope(scope, 'L1L2Regularizer', [tensor]):
+ weight_l1_t = tf.convert_to_tensor(weight_l1,
+ dtype=tensor.dtype.base_dtype,
+ name='weight_l1')
+ weight_l2_t = tf.convert_to_tensor(weight_l2,
+ dtype=tensor.dtype.base_dtype,
+ name='weight_l2')
+ reg_l1 = tf.multiply(weight_l1_t, tf.reduce_sum(tf.abs(tensor)),
+ name='value_l1')
+ reg_l2 = tf.multiply(weight_l2_t, tf.nn.l2_loss(tensor),
+ name='value_l2')
+ return tf.add(reg_l1, reg_l2, name='value')
+ return regularizer
+
+
+def l1_loss(tensor, weight=1.0, scope=None):
+ """Define a L1Loss, useful for regularize, i.e. lasso.
+
+ Args:
+ tensor: tensor to regularize.
+ weight: scale the loss by this factor.
+ scope: Optional scope for name_scope.
+
+ Returns:
+ the L1 loss op.
+ """
+ with tf.name_scope(scope, 'L1Loss', [tensor]):
+ weight = tf.convert_to_tensor(weight,
+ dtype=tensor.dtype.base_dtype,
+ name='loss_weight')
+ loss = tf.multiply(weight, tf.reduce_sum(tf.abs(tensor)), name='value')
+ tf.add_to_collection(LOSSES_COLLECTION, loss)
+ return loss
+
+
+def l2_loss(tensor, weight=1.0, scope=None):
+ """Define a L2Loss, useful for regularize, i.e. weight decay.
+
+ Args:
+ tensor: tensor to regularize.
+ weight: an optional weight to modulate the loss.
+ scope: Optional scope for name_scope.
+
+ Returns:
+ the L2 loss op.
+ """
+ with tf.name_scope(scope, 'L2Loss', [tensor]):
+ weight = tf.convert_to_tensor(weight,
+ dtype=tensor.dtype.base_dtype,
+ name='loss_weight')
+ loss = tf.multiply(weight, tf.nn.l2_loss(tensor), name='value')
+ tf.add_to_collection(LOSSES_COLLECTION, loss)
+ return loss
+
+
+def cross_entropy_loss(logits, one_hot_labels, label_smoothing=0,
+ weight=1.0, scope=None):
+ """Define a Cross Entropy loss using softmax_cross_entropy_with_logits.
+
+ It can scale the loss by weight factor, and smooth the labels.
+
+ Args:
+    logits: [batch_size, num_classes] logits outputs of the network.
+ one_hot_labels: [batch_size, num_classes] target one_hot_encoded labels.
+ label_smoothing: if greater than 0 then smooth the labels.
+ weight: scale the loss by this factor.
+ scope: Optional scope for name_scope.
+
+ Returns:
+ A tensor with the softmax_cross_entropy loss.
+ """
+ logits.get_shape().assert_is_compatible_with(one_hot_labels.get_shape())
+ with tf.name_scope(scope, 'CrossEntropyLoss', [logits, one_hot_labels]):
+ num_classes = one_hot_labels.get_shape()[-1].value
+ one_hot_labels = tf.cast(one_hot_labels, logits.dtype)
+ if label_smoothing > 0:
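+      # Smoothed targets: y_smooth = y * (1 - eps) + eps / num_classes, so each
+      # row still sums to 1 and remains a valid probability distribution.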
+ smooth_positives = 1.0 - label_smoothing
+ smooth_negatives = label_smoothing / num_classes
+ one_hot_labels = one_hot_labels * smooth_positives + smooth_negatives
+ cross_entropy = tf.contrib.nn.deprecated_flipped_softmax_cross_entropy_with_logits(
+ logits, one_hot_labels, name='xentropy')
+
+ weight = tf.convert_to_tensor(weight,
+ dtype=logits.dtype.base_dtype,
+ name='loss_weight')
+ loss = tf.multiply(weight, tf.reduce_mean(cross_entropy), name='value')
+ tf.add_to_collection(LOSSES_COLLECTION, loss)
+ return loss
diff --git a/models/research/inception/inception/slim/losses_test.py b/models/research/inception/inception/slim/losses_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e267f6520779f63be0becf41ceccc7de494e14f7
--- /dev/null
+++ b/models/research/inception/inception/slim/losses_test.py
@@ -0,0 +1,177 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slim.losses."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+from inception.slim import losses
+
+
+class LossesTest(tf.test.TestCase):
+
+ def testL1Loss(self):
+ with self.test_session():
+ shape = [5, 5, 5]
+ num_elem = 5 * 5 * 5
+ weights = tf.constant(1.0, shape=shape)
+ wd = 0.01
+ loss = losses.l1_loss(weights, wd)
+ self.assertEquals(loss.op.name, 'L1Loss/value')
+ self.assertAlmostEqual(loss.eval(), num_elem * wd, 5)
+
+ def testL2Loss(self):
+ with self.test_session():
+ shape = [5, 5, 5]
+ num_elem = 5 * 5 * 5
+ weights = tf.constant(1.0, shape=shape)
+ wd = 0.01
+ loss = losses.l2_loss(weights, wd)
+ self.assertEquals(loss.op.name, 'L2Loss/value')
+ self.assertAlmostEqual(loss.eval(), num_elem * wd / 2, 5)
+
+
+class RegularizersTest(tf.test.TestCase):
+
+ def testL1Regularizer(self):
+ with self.test_session():
+ shape = [5, 5, 5]
+ num_elem = 5 * 5 * 5
+ tensor = tf.constant(1.0, shape=shape)
+ loss = losses.l1_regularizer()(tensor)
+ self.assertEquals(loss.op.name, 'L1Regularizer/value')
+ self.assertAlmostEqual(loss.eval(), num_elem, 5)
+
+ def testL1RegularizerWithScope(self):
+ with self.test_session():
+ shape = [5, 5, 5]
+ num_elem = 5 * 5 * 5
+ tensor = tf.constant(1.0, shape=shape)
+ loss = losses.l1_regularizer(scope='L1')(tensor)
+ self.assertEquals(loss.op.name, 'L1/value')
+ self.assertAlmostEqual(loss.eval(), num_elem, 5)
+
+ def testL1RegularizerWithWeight(self):
+ with self.test_session():
+ shape = [5, 5, 5]
+ num_elem = 5 * 5 * 5
+ tensor = tf.constant(1.0, shape=shape)
+ weight = 0.01
+ loss = losses.l1_regularizer(weight)(tensor)
+ self.assertEquals(loss.op.name, 'L1Regularizer/value')
+ self.assertAlmostEqual(loss.eval(), num_elem * weight, 5)
+
+ def testL2Regularizer(self):
+ with self.test_session():
+ shape = [5, 5, 5]
+ num_elem = 5 * 5 * 5
+ tensor = tf.constant(1.0, shape=shape)
+ loss = losses.l2_regularizer()(tensor)
+ self.assertEquals(loss.op.name, 'L2Regularizer/value')
+ self.assertAlmostEqual(loss.eval(), num_elem / 2, 5)
+
+ def testL2RegularizerWithScope(self):
+ with self.test_session():
+ shape = [5, 5, 5]
+ num_elem = 5 * 5 * 5
+ tensor = tf.constant(1.0, shape=shape)
+ loss = losses.l2_regularizer(scope='L2')(tensor)
+ self.assertEquals(loss.op.name, 'L2/value')
+ self.assertAlmostEqual(loss.eval(), num_elem / 2, 5)
+
+ def testL2RegularizerWithWeight(self):
+ with self.test_session():
+ shape = [5, 5, 5]
+ num_elem = 5 * 5 * 5
+ tensor = tf.constant(1.0, shape=shape)
+ weight = 0.01
+ loss = losses.l2_regularizer(weight)(tensor)
+ self.assertEquals(loss.op.name, 'L2Regularizer/value')
+ self.assertAlmostEqual(loss.eval(), num_elem * weight / 2, 5)
+
+ def testL1L2Regularizer(self):
+ with self.test_session():
+ shape = [5, 5, 5]
+ num_elem = 5 * 5 * 5
+ tensor = tf.constant(1.0, shape=shape)
+ loss = losses.l1_l2_regularizer()(tensor)
+ self.assertEquals(loss.op.name, 'L1L2Regularizer/value')
+ self.assertAlmostEqual(loss.eval(), num_elem + num_elem / 2, 5)
+
+ def testL1L2RegularizerWithScope(self):
+ with self.test_session():
+ shape = [5, 5, 5]
+ num_elem = 5 * 5 * 5
+ tensor = tf.constant(1.0, shape=shape)
+ loss = losses.l1_l2_regularizer(scope='L1L2')(tensor)
+ self.assertEquals(loss.op.name, 'L1L2/value')
+ self.assertAlmostEqual(loss.eval(), num_elem + num_elem / 2, 5)
+
+ def testL1L2RegularizerWithWeights(self):
+ with self.test_session():
+ shape = [5, 5, 5]
+ num_elem = 5 * 5 * 5
+ tensor = tf.constant(1.0, shape=shape)
+ weight_l1 = 0.01
+ weight_l2 = 0.05
+ loss = losses.l1_l2_regularizer(weight_l1, weight_l2)(tensor)
+ self.assertEquals(loss.op.name, 'L1L2Regularizer/value')
+ self.assertAlmostEqual(loss.eval(),
+ num_elem * weight_l1 + num_elem * weight_l2 / 2, 5)
+
+
+class CrossEntropyLossTest(tf.test.TestCase):
+
+ def testCrossEntropyLossAllCorrect(self):
+ with self.test_session():
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1]])
+ loss = losses.cross_entropy_loss(logits, labels)
+ self.assertEquals(loss.op.name, 'CrossEntropyLoss/value')
+ self.assertAlmostEqual(loss.eval(), 0.0, 3)
+
+ def testCrossEntropyLossAllWrong(self):
+ with self.test_session():
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[0, 0, 1],
+ [1, 0, 0],
+ [0, 1, 0]])
+ loss = losses.cross_entropy_loss(logits, labels)
+ self.assertEquals(loss.op.name, 'CrossEntropyLoss/value')
+ self.assertAlmostEqual(loss.eval(), 10.0, 3)
+
+ def testCrossEntropyLossAllWrongWithWeight(self):
+ with self.test_session():
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[0, 0, 1],
+ [1, 0, 0],
+ [0, 1, 0]])
+ loss = losses.cross_entropy_loss(logits, labels, weight=0.5)
+ self.assertEquals(loss.op.name, 'CrossEntropyLoss/value')
+ self.assertAlmostEqual(loss.eval(), 5.0, 3)
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/inception/inception/slim/ops.py b/models/research/inception/inception/slim/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..54fda4eb81f3a138d9bb2748c21164b88570ede9
--- /dev/null
+++ b/models/research/inception/inception/slim/ops.py
@@ -0,0 +1,473 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains convenience wrappers for typical Neural Network TensorFlow layers.
+
+ Additionally it maintains a collection with update_ops that need to be
+  run after the ops have been computed, for example to update moving means
+ and moving variances of batch_norm.
+
+ Ops that have different behavior during training or eval have an is_training
+  parameter. Additionally, Ops that contain variables.variable have a trainable
+  parameter, which controls whether the op's variables are trainable or not.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+from tensorflow.python.training import moving_averages
+
+from inception.slim import losses
+from inception.slim import scopes
+from inception.slim import variables
+
+# Used to keep the update ops done by batch_norm.
+UPDATE_OPS_COLLECTION = '_update_ops_'
+
+
+@scopes.add_arg_scope
+def batch_norm(inputs,
+ decay=0.999,
+ center=True,
+ scale=False,
+ epsilon=0.001,
+ moving_vars='moving_vars',
+ activation=None,
+ is_training=True,
+ trainable=True,
+ restore=True,
+ scope=None,
+ reuse=None):
+ """Adds a Batch Normalization layer.
+
+ Args:
+ inputs: a tensor of size [batch_size, height, width, channels]
+ or [batch_size, channels].
+ decay: decay for the moving average.
+ center: If True, subtract beta. If False, beta is not created and ignored.
+ scale: If True, multiply by gamma. If False, gamma is
+ not used. When the next layer is linear (also e.g. ReLU), this can be
+ disabled since the scaling can be done by the next layer.
+ epsilon: small float added to variance to avoid dividing by zero.
+ moving_vars: collection to store the moving_mean and moving_variance.
+ activation: activation function.
+ is_training: whether or not the model is in training mode.
+    trainable: whether or not the variables should be trainable.
+ restore: whether or not the variables should be marked for restore.
+ scope: Optional scope for variable_scope.
+ reuse: whether or not the layer and its variables should be reused. To be
+ able to reuse the layer scope must be given.
+
+ Returns:
+ a tensor representing the output of the operation.
+
+ """
+ inputs_shape = inputs.get_shape()
+ with tf.variable_scope(scope, 'BatchNorm', [inputs], reuse=reuse):
+ axis = list(range(len(inputs_shape) - 1))
+ params_shape = inputs_shape[-1:]
+ # Allocate parameters for the beta and gamma of the normalization.
+ beta, gamma = None, None
+ if center:
+ beta = variables.variable('beta',
+ params_shape,
+ initializer=tf.zeros_initializer(),
+ trainable=trainable,
+ restore=restore)
+ if scale:
+ gamma = variables.variable('gamma',
+ params_shape,
+ initializer=tf.ones_initializer(),
+ trainable=trainable,
+ restore=restore)
+ # Create moving_mean and moving_variance add them to
+ # GraphKeys.MOVING_AVERAGE_VARIABLES collections.
+ moving_collections = [moving_vars, tf.GraphKeys.MOVING_AVERAGE_VARIABLES]
+ moving_mean = variables.variable('moving_mean',
+ params_shape,
+ initializer=tf.zeros_initializer(),
+ trainable=False,
+ restore=restore,
+ collections=moving_collections)
+ moving_variance = variables.variable('moving_variance',
+ params_shape,
+ initializer=tf.ones_initializer(),
+ trainable=False,
+ restore=restore,
+ collections=moving_collections)
+ if is_training:
+ # Calculate the moments based on the individual batch.
+ mean, variance = tf.nn.moments(inputs, axis)
+
+ update_moving_mean = moving_averages.assign_moving_average(
+ moving_mean, mean, decay)
+ tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
+ update_moving_variance = moving_averages.assign_moving_average(
+ moving_variance, variance, decay)
+ tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)
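+      # The moving averages are only updated when these collected ops are run;
+      # the training loop is expected to fetch UPDATE_OPS_COLLECTION and run
+      # them (e.g. via tf.group) alongside the train op.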
+ else:
+ # Just use the moving_mean and moving_variance.
+ mean = moving_mean
+ variance = moving_variance
+ # Normalize the activations.
+ outputs = tf.nn.batch_normalization(
+ inputs, mean, variance, beta, gamma, epsilon)
+ outputs.set_shape(inputs.get_shape())
+ if activation:
+ outputs = activation(outputs)
+ return outputs
+
+
+def _two_element_tuple(int_or_tuple):
+ """Converts `int_or_tuple` to height, width.
+
+ Several of the functions that follow accept arguments as either
+ a tuple of 2 integers or a single integer. A single integer
+ indicates that the 2 values of the tuple are the same.
+
+ This functions normalizes the input value by always returning a tuple.
+
+ Args:
+ int_or_tuple: A list of 2 ints, a single int or a tf.TensorShape.
+
+ Returns:
+ A tuple with 2 values.
+
+ Raises:
+    ValueError: If `int_or_tuple` is not well formed.
+ """
+ if isinstance(int_or_tuple, (list, tuple)):
+ if len(int_or_tuple) != 2:
+ raise ValueError('Must be a list with 2 elements: %s' % int_or_tuple)
+ return int(int_or_tuple[0]), int(int_or_tuple[1])
+ if isinstance(int_or_tuple, int):
+ return int(int_or_tuple), int(int_or_tuple)
+ if isinstance(int_or_tuple, tf.TensorShape):
+ if len(int_or_tuple) == 2:
+ return int_or_tuple[0], int_or_tuple[1]
+ raise ValueError('Must be an int, a list with 2 elements or a TensorShape of '
+ 'length 2')
+
+
+@scopes.add_arg_scope
+def conv2d(inputs,
+ num_filters_out,
+ kernel_size,
+ stride=1,
+ padding='SAME',
+ activation=tf.nn.relu,
+ stddev=0.01,
+ bias=0.0,
+ weight_decay=0,
+ batch_norm_params=None,
+ is_training=True,
+ trainable=True,
+ restore=True,
+ scope=None,
+ reuse=None):
+ """Adds a 2D convolution followed by an optional batch_norm layer.
+
+ conv2d creates a variable called 'weights', representing the convolutional
+ kernel, that is convolved with the input. If `batch_norm_params` is None, a
+ second variable called 'biases' is added to the result of the convolution
+ operation.
+
+ Args:
+ inputs: a tensor of size [batch_size, height, width, channels].
+ num_filters_out: the number of output filters.
+ kernel_size: a list of length 2: [kernel_height, kernel_width] of
+      the filters. Can be an int if both values are the same.
+ stride: a list of length 2: [stride_height, stride_width].
+ Can be an int if both strides are the same. Note that presently
+ both strides must have the same value.
+ padding: one of 'VALID' or 'SAME'.
+ activation: activation function.
+    stddev: standard deviation of the truncated gaussian weight distribution.
+ bias: the initial value of the biases.
+ weight_decay: the weight decay.
+    batch_norm_params: parameters for batch_norm. If None, batch_norm is not used.
+ is_training: whether or not the model is in training mode.
+    trainable: whether or not the variables should be trainable.
+ restore: whether or not the variables should be marked for restore.
+ scope: Optional scope for variable_scope.
+ reuse: whether or not the layer and its variables should be reused. To be
+ able to reuse the layer scope must be given.
+ Returns:
+ a tensor representing the output of the operation.
+
+ """
+ with tf.variable_scope(scope, 'Conv', [inputs], reuse=reuse):
+ kernel_h, kernel_w = _two_element_tuple(kernel_size)
+ stride_h, stride_w = _two_element_tuple(stride)
+ num_filters_in = inputs.get_shape()[-1]
+ weights_shape = [kernel_h, kernel_w,
+ num_filters_in, num_filters_out]
+ weights_initializer = tf.truncated_normal_initializer(stddev=stddev)
+ l2_regularizer = None
+ if weight_decay and weight_decay > 0:
+ l2_regularizer = losses.l2_regularizer(weight_decay)
+ weights = variables.variable('weights',
+ shape=weights_shape,
+ initializer=weights_initializer,
+ regularizer=l2_regularizer,
+ trainable=trainable,
+ restore=restore)
+ conv = tf.nn.conv2d(inputs, weights, [1, stride_h, stride_w, 1],
+ padding=padding)
+ if batch_norm_params is not None:
+ with scopes.arg_scope([batch_norm], is_training=is_training,
+ trainable=trainable, restore=restore):
+ outputs = batch_norm(conv, **batch_norm_params)
+ else:
+ bias_shape = [num_filters_out,]
+ bias_initializer = tf.constant_initializer(bias)
+ biases = variables.variable('biases',
+ shape=bias_shape,
+ initializer=bias_initializer,
+ trainable=trainable,
+ restore=restore)
+ outputs = tf.nn.bias_add(conv, biases)
+ if activation:
+ outputs = activation(outputs)
+ return outputs
+
+
+@scopes.add_arg_scope
+def fc(inputs,
+ num_units_out,
+ activation=tf.nn.relu,
+ stddev=0.01,
+ bias=0.0,
+ weight_decay=0,
+ batch_norm_params=None,
+ is_training=True,
+ trainable=True,
+ restore=True,
+ scope=None,
+ reuse=None):
+ """Adds a fully connected layer followed by an optional batch_norm layer.
+
+ FC creates a variable called 'weights', representing the fully connected
+  weight matrix, that is multiplied by the input. If `batch_norm_params` is None, a
+ second variable called 'biases' is added to the result of the initial
+ vector-matrix multiplication.
+
+ Args:
+ inputs: a [B x N] tensor where B is the batch size and N is the number of
+ input units in the layer.
+ num_units_out: the number of output units in the layer.
+ activation: activation function.
+ stddev: the standard deviation for the weights.
+ bias: the initial value of the biases.
+ weight_decay: the weight decay.
+    batch_norm_params: parameters for batch_norm. If None, batch_norm is not used.
+ is_training: whether or not the model is in training mode.
+    trainable: whether or not the variables should be trainable.
+ restore: whether or not the variables should be marked for restore.
+ scope: Optional scope for variable_scope.
+ reuse: whether or not the layer and its variables should be reused. To be
+ able to reuse the layer scope must be given.
+
+ Returns:
+ the tensor variable representing the result of the series of operations.
+ """
+ with tf.variable_scope(scope, 'FC', [inputs], reuse=reuse):
+ num_units_in = inputs.get_shape()[1]
+ weights_shape = [num_units_in, num_units_out]
+ weights_initializer = tf.truncated_normal_initializer(stddev=stddev)
+ l2_regularizer = None
+ if weight_decay and weight_decay > 0:
+ l2_regularizer = losses.l2_regularizer(weight_decay)
+ weights = variables.variable('weights',
+ shape=weights_shape,
+ initializer=weights_initializer,
+ regularizer=l2_regularizer,
+ trainable=trainable,
+ restore=restore)
+ if batch_norm_params is not None:
+ outputs = tf.matmul(inputs, weights)
+ with scopes.arg_scope([batch_norm], is_training=is_training,
+ trainable=trainable, restore=restore):
+ outputs = batch_norm(outputs, **batch_norm_params)
+ else:
+ bias_shape = [num_units_out,]
+ bias_initializer = tf.constant_initializer(bias)
+ biases = variables.variable('biases',
+ shape=bias_shape,
+ initializer=bias_initializer,
+ trainable=trainable,
+ restore=restore)
+ outputs = tf.nn.xw_plus_b(inputs, weights, biases)
+ if activation:
+ outputs = activation(outputs)
+ return outputs
+
+
+def one_hot_encoding(labels, num_classes, scope=None):
+ """Transform numeric labels into onehot_labels.
+
+ Args:
+ labels: [batch_size] target labels.
+ num_classes: total number of classes.
+ scope: Optional scope for name_scope.
+ Returns:
+ one hot encoding of the labels.
+ """
+ with tf.name_scope(scope, 'OneHotEncoding', [labels]):
+ batch_size = labels.get_shape()[0]
+ indices = tf.expand_dims(tf.range(0, batch_size), 1)
+ labels = tf.cast(tf.expand_dims(labels, 1), indices.dtype)
+ concated = tf.concat(axis=1, values=[indices, labels])
+ onehot_labels = tf.sparse_to_dense(
+ concated, tf.stack([batch_size, num_classes]), 1.0, 0.0)
+ onehot_labels.set_shape([batch_size, num_classes])
+ return onehot_labels
+
+
+@scopes.add_arg_scope
+def max_pool(inputs, kernel_size, stride=2, padding='VALID', scope=None):
+ """Adds a Max Pooling layer.
+
+ It is assumed by the wrapper that the pooling is only done per image and not
+ in depth or batch.
+
+ Args:
+ inputs: a tensor of size [batch_size, height, width, depth].
+ kernel_size: a list of length 2: [kernel_height, kernel_width] of the
+ pooling kernel over which the op is computed. Can be an int if both
+ values are the same.
+ stride: a list of length 2: [stride_height, stride_width].
+ Can be an int if both strides are the same. Note that presently
+ both strides must have the same value.
+ padding: the padding method, either 'VALID' or 'SAME'.
+ scope: Optional scope for name_scope.
+
+ Returns:
+ a tensor representing the results of the pooling operation.
+ Raises:
+ ValueError: if 'kernel_size' is not a 2-D list
+ """
+ with tf.name_scope(scope, 'MaxPool', [inputs]):
+ kernel_h, kernel_w = _two_element_tuple(kernel_size)
+ stride_h, stride_w = _two_element_tuple(stride)
+ return tf.nn.max_pool(inputs,
+ ksize=[1, kernel_h, kernel_w, 1],
+ strides=[1, stride_h, stride_w, 1],
+ padding=padding)
+
+
+@scopes.add_arg_scope
+def avg_pool(inputs, kernel_size, stride=2, padding='VALID', scope=None):
+ """Adds a Avg Pooling layer.
+
+ It is assumed by the wrapper that the pooling is only done per image and not
+ in depth or batch.
+
+ Args:
+ inputs: a tensor of size [batch_size, height, width, depth].
+ kernel_size: a list of length 2: [kernel_height, kernel_width] of the
+ pooling kernel over which the op is computed. Can be an int if both
+ values are the same.
+ stride: a list of length 2: [stride_height, stride_width].
+ Can be an int if both strides are the same. Note that presently
+ both strides must have the same value.
+ padding: the padding method, either 'VALID' or 'SAME'.
+ scope: Optional scope for name_scope.
+
+ Returns:
+ a tensor representing the results of the pooling operation.
+ """
+ with tf.name_scope(scope, 'AvgPool', [inputs]):
+ kernel_h, kernel_w = _two_element_tuple(kernel_size)
+ stride_h, stride_w = _two_element_tuple(stride)
+ return tf.nn.avg_pool(inputs,
+ ksize=[1, kernel_h, kernel_w, 1],
+ strides=[1, stride_h, stride_w, 1],
+ padding=padding)
+
+
+@scopes.add_arg_scope
+def dropout(inputs, keep_prob=0.5, is_training=True, scope=None):
+ """Returns a dropout layer applied to the input.
+
+ Args:
+ inputs: the tensor to pass to the Dropout layer.
+ keep_prob: the probability of keeping each input unit.
+ is_training: whether or not the model is in training mode. If so, dropout is
+    applied and values scaled. Otherwise, the inputs are returned unchanged.
+ scope: Optional scope for name_scope.
+
+ Returns:
+ a tensor representing the output of the operation.
+ """
+ if is_training and keep_prob > 0:
+ with tf.name_scope(scope, 'Dropout', [inputs]):
+ return tf.nn.dropout(inputs, keep_prob)
+ else:
+ return inputs
+
+
+def flatten(inputs, scope=None):
+ """Flattens the input while maintaining the batch_size.
+
+ Assumes that the first dimension represents the batch.
+
+ Args:
+ inputs: a tensor of size [batch_size, ...].
+ scope: Optional scope for name_scope.
+
+ Returns:
+ a flattened tensor with shape [batch_size, k].
+ Raises:
+ ValueError: if inputs.shape is wrong.
+ """
+ if len(inputs.get_shape()) < 2:
+ raise ValueError('Inputs must be have a least 2 dimensions')
+ dims = inputs.get_shape()[1:]
+ k = dims.num_elements()
+ with tf.name_scope(scope, 'Flatten', [inputs]):
+ return tf.reshape(inputs, [-1, k])
+
+
+def repeat_op(repetitions, inputs, op, *args, **kwargs):
+ """Build a sequential Tower starting from inputs by using an op repeatedly.
+
+ It creates new scopes for each operation by increasing the counter.
+ Example: given repeat_op(3, _, ops.conv2d, 64, [3, 3], scope='conv1')
+ it will repeat the given op under the following variable_scopes:
+ conv1/Conv
+ conv1/Conv_1
+ conv1/Conv_2
+
+ Args:
+    repetitions: number of repetitions.
+ inputs: a tensor of size [batch_size, height, width, channels].
+ op: an operation.
+ *args: args for the op.
+ **kwargs: kwargs for the op.
+
+ Returns:
+    a tensor resulting from applying the operation op, repetitions times.
+ Raises:
+ ValueError: if the op is unknown or wrong.
+ """
+ scope = kwargs.pop('scope', None)
+ with tf.variable_scope(scope, 'RepeatOp', [inputs]):
+ tower = inputs
+ for _ in range(repetitions):
+ tower = op(tower, *args, **kwargs)
+ return tower
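+
+
+# Example (a sketch mirroring the docstring above):
+#   net = repeat_op(3, images, conv2d, 64, [3, 3], scope='conv1')
+# stacks three 3x3 convolutions under the scopes conv1/Conv, conv1/Conv_1 and
+# conv1/Conv_2.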
diff --git a/models/research/inception/inception/slim/ops_test.py b/models/research/inception/inception/slim/ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..13dc5d9aacf6e283540a406d419a67d2d7215161
--- /dev/null
+++ b/models/research/inception/inception/slim/ops_test.py
@@ -0,0 +1,687 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slim.ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import numpy as np
+import tensorflow as tf
+
+from inception.slim import ops
+from inception.slim import scopes
+from inception.slim import variables
+
+
+class ConvTest(tf.test.TestCase):
+
+ def testCreateConv(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.conv2d(images, 32, [3, 3])
+ self.assertEquals(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
+
+ def testCreateSquareConv(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.conv2d(images, 32, 3)
+ self.assertEquals(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
+
+ def testCreateConvWithTensorShape(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.conv2d(images, 32, images.get_shape()[1:3])
+ self.assertEquals(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
+
+ def testCreateFullyConv(self):
+ height, width = 6, 6
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 32), seed=1)
+ output = ops.conv2d(images, 64, images.get_shape()[1:3], padding='VALID')
+ self.assertEquals(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 64])
+
+ def testCreateVerticalConv(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.conv2d(images, 32, [3, 1])
+ self.assertEquals(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, height, width, 32])
+
+ def testCreateHorizontalConv(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.conv2d(images, 32, [1, 3])
+ self.assertEquals(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, height, width, 32])
+
+ def testCreateConvWithStride(self):
+ height, width = 6, 6
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.conv2d(images, 32, [3, 3], stride=2)
+ self.assertEquals(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, height/2, width/2, 32])
+
+ def testCreateConvCreatesWeightsAndBiasesVars(self):
+ height, width = 3, 3
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ with self.test_session():
+ self.assertFalse(variables.get_variables('conv1/weights'))
+ self.assertFalse(variables.get_variables('conv1/biases'))
+ ops.conv2d(images, 32, [3, 3], scope='conv1')
+ self.assertTrue(variables.get_variables('conv1/weights'))
+ self.assertTrue(variables.get_variables('conv1/biases'))
+
+ def testCreateConvWithScope(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.conv2d(images, 32, [3, 3], scope='conv1')
+ self.assertEquals(output.op.name, 'conv1/Relu')
+
+ def testCreateConvWithoutActivation(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.conv2d(images, 32, [3, 3], activation=None)
+ self.assertEquals(output.op.name, 'Conv/BiasAdd')
+
+ def testCreateConvValid(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.conv2d(images, 32, [3, 3], padding='VALID')
+ self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 32])
+
+ def testCreateConvWithWD(self):
+ height, width = 3, 3
+ with self.test_session() as sess:
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ ops.conv2d(images, 32, [3, 3], weight_decay=0.01)
+ wd = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
+ self.assertEquals(wd.op.name,
+ 'Conv/weights/Regularizer/L2Regularizer/value')
+ sess.run(tf.global_variables_initializer())
+ self.assertTrue(sess.run(wd) <= 0.01)
+
+ def testCreateConvWithoutWD(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ ops.conv2d(images, 32, [3, 3], weight_decay=0)
+ self.assertEquals(
+ tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES), [])
+
+ def testReuseVars(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ ops.conv2d(images, 32, [3, 3], scope='conv1')
+ self.assertEquals(len(variables.get_variables()), 2)
+ ops.conv2d(images, 32, [3, 3], scope='conv1', reuse=True)
+ self.assertEquals(len(variables.get_variables()), 2)
+
+ def testNonReuseVars(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ ops.conv2d(images, 32, [3, 3])
+ self.assertEquals(len(variables.get_variables()), 2)
+ ops.conv2d(images, 32, [3, 3])
+ self.assertEquals(len(variables.get_variables()), 4)
+
+ def testReuseConvWithWD(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ ops.conv2d(images, 32, [3, 3], weight_decay=0.01, scope='conv1')
+ self.assertEquals(len(variables.get_variables()), 2)
+ self.assertEquals(
+ len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
+ ops.conv2d(images, 32, [3, 3], weight_decay=0.01, scope='conv1',
+ reuse=True)
+ self.assertEquals(len(variables.get_variables()), 2)
+ self.assertEquals(
+ len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
+
+ def testConvWithBatchNorm(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 32), seed=1)
+ with scopes.arg_scope([ops.conv2d], batch_norm_params={'decay': 0.9}):
+ net = ops.conv2d(images, 32, [3, 3])
+ net = ops.conv2d(net, 32, [3, 3])
+ self.assertEquals(len(variables.get_variables()), 8)
+ self.assertEquals(len(variables.get_variables('Conv/BatchNorm')), 3)
+ self.assertEquals(len(variables.get_variables('Conv_1/BatchNorm')), 3)
+
+ def testReuseConvWithBatchNorm(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 32), seed=1)
+ with scopes.arg_scope([ops.conv2d], batch_norm_params={'decay': 0.9}):
+ net = ops.conv2d(images, 32, [3, 3], scope='Conv')
+ net = ops.conv2d(net, 32, [3, 3], scope='Conv', reuse=True)
+ self.assertEquals(len(variables.get_variables()), 4)
+ self.assertEquals(len(variables.get_variables('Conv/BatchNorm')), 3)
+ self.assertEquals(len(variables.get_variables('Conv_1/BatchNorm')), 0)
+
+
+class FCTest(tf.test.TestCase):
+
+ def testCreateFC(self):
+ height, width = 3, 3
+ with self.test_session():
+ inputs = tf.random_uniform((5, height * width * 3), seed=1)
+ output = ops.fc(inputs, 32)
+ self.assertEquals(output.op.name, 'FC/Relu')
+ self.assertListEqual(output.get_shape().as_list(), [5, 32])
+
+ def testCreateFCWithScope(self):
+ height, width = 3, 3
+ with self.test_session():
+ inputs = tf.random_uniform((5, height * width * 3), seed=1)
+ output = ops.fc(inputs, 32, scope='fc1')
+ self.assertEquals(output.op.name, 'fc1/Relu')
+
+ def testCreateFcCreatesWeightsAndBiasesVars(self):
+ height, width = 3, 3
+ inputs = tf.random_uniform((5, height * width * 3), seed=1)
+ with self.test_session():
+ self.assertFalse(variables.get_variables('fc1/weights'))
+ self.assertFalse(variables.get_variables('fc1/biases'))
+ ops.fc(inputs, 32, scope='fc1')
+ self.assertTrue(variables.get_variables('fc1/weights'))
+ self.assertTrue(variables.get_variables('fc1/biases'))
+
+ def testReuseVars(self):
+ height, width = 3, 3
+ inputs = tf.random_uniform((5, height * width * 3), seed=1)
+ with self.test_session():
+ ops.fc(inputs, 32, scope='fc1')
+ self.assertEquals(len(variables.get_variables('fc1')), 2)
+ ops.fc(inputs, 32, scope='fc1', reuse=True)
+ self.assertEquals(len(variables.get_variables('fc1')), 2)
+
+ def testNonReuseVars(self):
+ height, width = 3, 3
+ inputs = tf.random_uniform((5, height * width * 3), seed=1)
+ with self.test_session():
+ ops.fc(inputs, 32)
+ self.assertEquals(len(variables.get_variables('FC')), 2)
+ ops.fc(inputs, 32)
+ self.assertEquals(len(variables.get_variables('FC')), 4)
+
+ def testCreateFCWithoutActivation(self):
+ height, width = 3, 3
+ with self.test_session():
+ inputs = tf.random_uniform((5, height * width * 3), seed=1)
+ output = ops.fc(inputs, 32, activation=None)
+ self.assertEquals(output.op.name, 'FC/xw_plus_b')
+
+ def testCreateFCWithWD(self):
+ height, width = 3, 3
+ with self.test_session() as sess:
+ inputs = tf.random_uniform((5, height * width * 3), seed=1)
+ ops.fc(inputs, 32, weight_decay=0.01)
+ wd = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
+ self.assertEquals(wd.op.name,
+ 'FC/weights/Regularizer/L2Regularizer/value')
+ sess.run(tf.global_variables_initializer())
+ self.assertTrue(sess.run(wd) <= 0.01)
+
+ def testCreateFCWithoutWD(self):
+ height, width = 3, 3
+ with self.test_session():
+ inputs = tf.random_uniform((5, height * width * 3), seed=1)
+ ops.fc(inputs, 32, weight_decay=0)
+ self.assertEquals(
+ tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES), [])
+
+ def testReuseFCWithWD(self):
+ height, width = 3, 3
+ with self.test_session():
+ inputs = tf.random_uniform((5, height * width * 3), seed=1)
+ ops.fc(inputs, 32, weight_decay=0.01, scope='fc')
+ self.assertEquals(len(variables.get_variables()), 2)
+ self.assertEquals(
+ len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
+ ops.fc(inputs, 32, weight_decay=0.01, scope='fc', reuse=True)
+ self.assertEquals(len(variables.get_variables()), 2)
+ self.assertEquals(
+ len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
+
+ def testFCWithBatchNorm(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height * width * 3), seed=1)
+ with scopes.arg_scope([ops.fc], batch_norm_params={}):
+ net = ops.fc(images, 27)
+ net = ops.fc(net, 27)
+ self.assertEquals(len(variables.get_variables()), 8)
+ self.assertEquals(len(variables.get_variables('FC/BatchNorm')), 3)
+ self.assertEquals(len(variables.get_variables('FC_1/BatchNorm')), 3)
+
+ def testReuseFCWithBatchNorm(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height * width * 3), seed=1)
+ with scopes.arg_scope([ops.fc], batch_norm_params={'decay': 0.9}):
+ net = ops.fc(images, 27, scope='fc1')
+ net = ops.fc(net, 27, scope='fc1', reuse=True)
+ self.assertEquals(len(variables.get_variables()), 4)
+ self.assertEquals(len(variables.get_variables('fc1/BatchNorm')), 3)
+
+
+class MaxPoolTest(tf.test.TestCase):
+
+ def testCreateMaxPool(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.max_pool(images, [3, 3])
+ self.assertEquals(output.op.name, 'MaxPool/MaxPool')
+ self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
+
+ def testCreateSquareMaxPool(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.max_pool(images, 3)
+ self.assertEquals(output.op.name, 'MaxPool/MaxPool')
+ self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
+
+ def testCreateMaxPoolWithScope(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.max_pool(images, [3, 3], scope='pool1')
+ self.assertEquals(output.op.name, 'pool1/MaxPool')
+
+ def testCreateMaxPoolSAME(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.max_pool(images, [3, 3], padding='SAME')
+ self.assertListEqual(output.get_shape().as_list(), [5, 2, 2, 3])
+
+ def testCreateMaxPoolStrideSAME(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.max_pool(images, [3, 3], stride=1, padding='SAME')
+ self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
+
+ def testGlobalMaxPool(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.max_pool(images, images.get_shape()[1:3], stride=1)
+ self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
+
+
+class AvgPoolTest(tf.test.TestCase):
+
+ def testCreateAvgPool(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.avg_pool(images, [3, 3])
+ self.assertEquals(output.op.name, 'AvgPool/AvgPool')
+ self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
+
+ def testCreateSquareAvgPool(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.avg_pool(images, 3)
+ self.assertEquals(output.op.name, 'AvgPool/AvgPool')
+ self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
+
+ def testCreateAvgPoolWithScope(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.avg_pool(images, [3, 3], scope='pool1')
+ self.assertEquals(output.op.name, 'pool1/AvgPool')
+
+ def testCreateAvgPoolSAME(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.avg_pool(images, [3, 3], padding='SAME')
+ self.assertListEqual(output.get_shape().as_list(), [5, 2, 2, 3])
+
+ def testCreateAvgPoolStrideSAME(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.avg_pool(images, [3, 3], stride=1, padding='SAME')
+ self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
+
+ def testGlobalAvgPool(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.avg_pool(images, images.get_shape()[1:3], stride=1)
+ self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
+
+
+class OneHotEncodingTest(tf.test.TestCase):
+
+ def testOneHotEncodingCreate(self):
+ with self.test_session():
+ labels = tf.constant([0, 1, 2])
+ output = ops.one_hot_encoding(labels, num_classes=3)
+ self.assertEquals(output.op.name, 'OneHotEncoding/SparseToDense')
+ self.assertListEqual(output.get_shape().as_list(), [3, 3])
+
+ def testOneHotEncoding(self):
+ with self.test_session():
+ labels = tf.constant([0, 1, 2])
+ one_hot_labels = tf.constant([[1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1]])
+ output = ops.one_hot_encoding(labels, num_classes=3)
+ self.assertAllClose(output.eval(), one_hot_labels.eval())
+
+
+class DropoutTest(tf.test.TestCase):
+
+ def testCreateDropout(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.dropout(images)
+ self.assertEquals(output.op.name, 'Dropout/dropout/mul')
+ output.get_shape().assert_is_compatible_with(images.get_shape())
+
+ def testCreateDropoutNoTraining(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1, name='images')
+ output = ops.dropout(images, is_training=False)
+ self.assertEquals(output, images)
+
+
+class FlattenTest(tf.test.TestCase):
+
+ def testFlatten4D(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1, name='images')
+ output = ops.flatten(images)
+ self.assertEquals(output.get_shape().num_elements(),
+ images.get_shape().num_elements())
+ self.assertEqual(output.get_shape()[0], images.get_shape()[0])
+
+ def testFlatten3D(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width), seed=1, name='images')
+ output = ops.flatten(images)
+ self.assertEquals(output.get_shape().num_elements(),
+ images.get_shape().num_elements())
+ self.assertEqual(output.get_shape()[0], images.get_shape()[0])
+
+ def testFlattenBatchSize(self):
+ height, width = 3, 3
+ with self.test_session() as sess:
+ images = tf.random_uniform((5, height, width, 3), seed=1, name='images')
+ inputs = tf.placeholder(tf.int32, (None, height, width, 3))
+ output = ops.flatten(inputs)
+ self.assertEquals(output.get_shape().as_list(),
+ [None, height * width * 3])
+ output = sess.run(output, {inputs: images.eval()})
+ self.assertEquals(output.size,
+ images.get_shape().num_elements())
+ self.assertEqual(output.shape[0], images.get_shape()[0])
+
+
+class BatchNormTest(tf.test.TestCase):
+
+ def testCreateOp(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ output = ops.batch_norm(images)
+ self.assertTrue(output.op.name.startswith('BatchNorm/batchnorm'))
+ self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
+
+ def testCreateVariables(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ ops.batch_norm(images)
+ beta = variables.get_variables_by_name('beta')[0]
+ self.assertEquals(beta.op.name, 'BatchNorm/beta')
+ gamma = variables.get_variables_by_name('gamma')
+ self.assertEquals(gamma, [])
+ moving_mean = tf.moving_average_variables()[0]
+ moving_variance = tf.moving_average_variables()[1]
+ self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
+ self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
+
+ def testCreateVariablesWithScale(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ ops.batch_norm(images, scale=True)
+ beta = variables.get_variables_by_name('beta')[0]
+ gamma = variables.get_variables_by_name('gamma')[0]
+ self.assertEquals(beta.op.name, 'BatchNorm/beta')
+ self.assertEquals(gamma.op.name, 'BatchNorm/gamma')
+ moving_mean = tf.moving_average_variables()[0]
+ moving_variance = tf.moving_average_variables()[1]
+ self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
+ self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
+
+ def testCreateVariablesWithoutCenterWithScale(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ ops.batch_norm(images, center=False, scale=True)
+ beta = variables.get_variables_by_name('beta')
+ self.assertEquals(beta, [])
+ gamma = variables.get_variables_by_name('gamma')[0]
+ self.assertEquals(gamma.op.name, 'BatchNorm/gamma')
+ moving_mean = tf.moving_average_variables()[0]
+ moving_variance = tf.moving_average_variables()[1]
+ self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
+ self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
+
+ def testCreateVariablesWithoutCenterWithoutScale(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ ops.batch_norm(images, center=False, scale=False)
+ beta = variables.get_variables_by_name('beta')
+ self.assertEquals(beta, [])
+ gamma = variables.get_variables_by_name('gamma')
+ self.assertEquals(gamma, [])
+ moving_mean = tf.moving_average_variables()[0]
+ moving_variance = tf.moving_average_variables()[1]
+ self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
+ self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
+
+ def testMovingAverageVariables(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ ops.batch_norm(images, scale=True)
+ moving_mean = tf.moving_average_variables()[0]
+ moving_variance = tf.moving_average_variables()[1]
+ self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
+ self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
+
+ def testUpdateOps(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ ops.batch_norm(images)
+ update_ops = tf.get_collection(ops.UPDATE_OPS_COLLECTION)
+ update_moving_mean = update_ops[0]
+ update_moving_variance = update_ops[1]
+ self.assertEquals(update_moving_mean.op.name,
+ 'BatchNorm/AssignMovingAvg')
+ self.assertEquals(update_moving_variance.op.name,
+ 'BatchNorm/AssignMovingAvg_1')
+
+ def testReuseVariables(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ ops.batch_norm(images, scale=True, scope='bn')
+ ops.batch_norm(images, scale=True, scope='bn', reuse=True)
+ beta = variables.get_variables_by_name('beta')
+ gamma = variables.get_variables_by_name('gamma')
+ self.assertEquals(len(beta), 1)
+ self.assertEquals(len(gamma), 1)
+ moving_vars = tf.get_collection('moving_vars')
+ self.assertEquals(len(moving_vars), 2)
+
+ def testReuseUpdateOps(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ ops.batch_norm(images, scope='bn')
+ self.assertEquals(len(tf.get_collection(ops.UPDATE_OPS_COLLECTION)), 2)
+ ops.batch_norm(images, scope='bn', reuse=True)
+ self.assertEquals(len(tf.get_collection(ops.UPDATE_OPS_COLLECTION)), 4)
+
+ def testCreateMovingVars(self):
+ height, width = 3, 3
+ with self.test_session():
+ images = tf.random_uniform((5, height, width, 3), seed=1)
+ _ = ops.batch_norm(images, moving_vars='moving_vars')
+ moving_mean = tf.get_collection('moving_vars',
+ 'BatchNorm/moving_mean')
+ self.assertEquals(len(moving_mean), 1)
+ self.assertEquals(moving_mean[0].op.name, 'BatchNorm/moving_mean')
+ moving_variance = tf.get_collection('moving_vars',
+ 'BatchNorm/moving_variance')
+ self.assertEquals(len(moving_variance), 1)
+ self.assertEquals(moving_variance[0].op.name, 'BatchNorm/moving_variance')
+
+ def testComputeMovingVars(self):
+ height, width = 3, 3
+ with self.test_session() as sess:
+ image_shape = (10, height, width, 3)
+ image_values = np.random.rand(*image_shape)
+ expected_mean = np.mean(image_values, axis=(0, 1, 2))
+ expected_var = np.var(image_values, axis=(0, 1, 2))
+ images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
+ output = ops.batch_norm(images, decay=0.1)
+ update_ops = tf.get_collection(ops.UPDATE_OPS_COLLECTION)
+ with tf.control_dependencies(update_ops):
+ output = tf.identity(output)
+ # Initialize all variables
+ sess.run(tf.global_variables_initializer())
+ moving_mean = variables.get_variables('BatchNorm/moving_mean')[0]
+ moving_variance = variables.get_variables('BatchNorm/moving_variance')[0]
+ mean, variance = sess.run([moving_mean, moving_variance])
+ # After initialization moving_mean == 0 and moving_variance == 1.
+ self.assertAllClose(mean, [0] * 3)
+ self.assertAllClose(variance, [1] * 3)
+ for _ in range(10):
+ sess.run([output])
+ mean = moving_mean.eval()
+ variance = moving_variance.eval()
+ # After 10 updates with decay 0.1 moving_mean == expected_mean and
+ # moving_variance == expected_var.
+ self.assertAllClose(mean, expected_mean)
+ self.assertAllClose(variance, expected_var)
+
+ def testEvalMovingVars(self):
+ height, width = 3, 3
+ with self.test_session() as sess:
+ image_shape = (10, height, width, 3)
+ image_values = np.random.rand(*image_shape)
+ expected_mean = np.mean(image_values, axis=(0, 1, 2))
+ expected_var = np.var(image_values, axis=(0, 1, 2))
+ images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
+ output = ops.batch_norm(images, decay=0.1, is_training=False)
+ update_ops = tf.get_collection(ops.UPDATE_OPS_COLLECTION)
+ with tf.control_dependencies(update_ops):
+ output = tf.identity(output)
+ # Initialize all variables
+ sess.run(tf.global_variables_initializer())
+ moving_mean = variables.get_variables('BatchNorm/moving_mean')[0]
+ moving_variance = variables.get_variables('BatchNorm/moving_variance')[0]
+ mean, variance = sess.run([moving_mean, moving_variance])
+ # After initialization moving_mean == 0 and moving_variance == 1.
+ self.assertAllClose(mean, [0] * 3)
+ self.assertAllClose(variance, [1] * 3)
+      # Simulate assignment from saver restore.
+ init_assigns = [tf.assign(moving_mean, expected_mean),
+ tf.assign(moving_variance, expected_var)]
+ sess.run(init_assigns)
+ for _ in range(10):
+ sess.run([output], {images: np.random.rand(*image_shape)})
+ mean = moving_mean.eval()
+ variance = moving_variance.eval()
+ # Although we feed different images, the moving_mean and moving_variance
+ # shouldn't change.
+ self.assertAllClose(mean, expected_mean)
+ self.assertAllClose(variance, expected_var)
+
+ def testReuseVars(self):
+ height, width = 3, 3
+ with self.test_session() as sess:
+ image_shape = (10, height, width, 3)
+ image_values = np.random.rand(*image_shape)
+ expected_mean = np.mean(image_values, axis=(0, 1, 2))
+ expected_var = np.var(image_values, axis=(0, 1, 2))
+ images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
+ output = ops.batch_norm(images, decay=0.1, is_training=False)
+ update_ops = tf.get_collection(ops.UPDATE_OPS_COLLECTION)
+ with tf.control_dependencies(update_ops):
+ output = tf.identity(output)
+ # Initialize all variables
+ sess.run(tf.global_variables_initializer())
+ moving_mean = variables.get_variables('BatchNorm/moving_mean')[0]
+ moving_variance = variables.get_variables('BatchNorm/moving_variance')[0]
+ mean, variance = sess.run([moving_mean, moving_variance])
+ # After initialization moving_mean == 0 and moving_variance == 1.
+ self.assertAllClose(mean, [0] * 3)
+ self.assertAllClose(variance, [1] * 3)
+      # Simulate assignment from saver restore.
+ init_assigns = [tf.assign(moving_mean, expected_mean),
+ tf.assign(moving_variance, expected_var)]
+ sess.run(init_assigns)
+ for _ in range(10):
+ sess.run([output], {images: np.random.rand(*image_shape)})
+ mean = moving_mean.eval()
+ variance = moving_variance.eval()
+ # Although we feed different images, the moving_mean and moving_variance
+ # shouldn't change.
+ self.assertAllClose(mean, expected_mean)
+ self.assertAllClose(variance, expected_var)
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/inception/inception/slim/scopes.py b/models/research/inception/inception/slim/scopes.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c2fb0a2efa7d30eaddb36fc30265f30cbaeb9ef
--- /dev/null
+++ b/models/research/inception/inception/slim/scopes.py
@@ -0,0 +1,170 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains the new arg_scope used for TF-Slim ops.
+
+ Allows one to define models much more compactly by eliminating boilerplate
+ code. This is accomplished through the use of argument scoping (arg_scope).
+
+ Example of how to use scopes.arg_scope:
+
+ with scopes.arg_scope(ops.conv2d, padding='SAME',
+ stddev=0.01, weight_decay=0.0005):
+ net = ops.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
+ net = ops.conv2d(net, 256, [5, 5], scope='conv2')
+
+ The first call to conv2d will overwrite padding:
+ ops.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
+ stddev=0.01, weight_decay=0.0005, scope='conv1')
+
+  The second call to conv2d will use predefined args:
+ ops.conv2d(inputs, 256, [5, 5], padding='SAME',
+ stddev=0.01, weight_decay=0.0005, scope='conv2')
+
+ Example of how to reuse an arg_scope:
+ with scopes.arg_scope(ops.conv2d, padding='SAME',
+ stddev=0.01, weight_decay=0.0005) as conv2d_arg_scope:
+ net = ops.conv2d(net, 256, [5, 5], scope='conv1')
+ ....
+
+ with scopes.arg_scope(conv2d_arg_scope):
+ net = ops.conv2d(net, 256, [5, 5], scope='conv2')
+
+ Example of how to use scopes.add_arg_scope:
+
+ @scopes.add_arg_scope
+  def conv2d(*args, **kwargs):
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import functools
+
+from tensorflow.python.framework import ops
+
+_ARGSTACK_KEY = ("__arg_stack",)
+
+_DECORATED_OPS = set()
+
+
+def _get_arg_stack():
+ stack = ops.get_collection(_ARGSTACK_KEY)
+ if stack:
+ return stack[0]
+ else:
+ stack = [{}]
+ ops.add_to_collection(_ARGSTACK_KEY, stack)
+ return stack
+
+
+def _current_arg_scope():
+ stack = _get_arg_stack()
+ return stack[-1]
+
+
+def _add_op(op):
+ key_op = (op.__module__, op.__name__)
+ if key_op not in _DECORATED_OPS:
+ _DECORATED_OPS.add(key_op)
+
+
+@contextlib.contextmanager
+def arg_scope(list_ops_or_scope, **kwargs):
+ """Stores the default arguments for the given set of list_ops.
+
+ For usage, please see examples at top of the file.
+
+ Args:
+    list_ops_or_scope: List or tuple of operations to set argument scope for or
+      a dictionary containing the current scope. When list_ops_or_scope is a
+      dict, kwargs must be empty. When list_ops_or_scope is a list or tuple,
+      then every op in it needs to be decorated with @add_arg_scope to work.
+ **kwargs: keyword=value that will define the defaults for each op in
+ list_ops. All the ops need to accept the given set of arguments.
+
+ Yields:
+ the current_scope, which is a dictionary of {op: {arg: value}}
+ Raises:
+ TypeError: if list_ops is not a list or a tuple.
+    ValueError: if any op in list_ops has not been decorated with @add_arg_scope.
+ """
+ if isinstance(list_ops_or_scope, dict):
+ # Assumes that list_ops_or_scope is a scope that is being reused.
+ if kwargs:
+ raise ValueError("When attempting to re-use a scope by suppling a"
+ "dictionary, kwargs must be empty.")
+ current_scope = list_ops_or_scope.copy()
+ try:
+ _get_arg_stack().append(current_scope)
+ yield current_scope
+ finally:
+ _get_arg_stack().pop()
+ else:
+ # Assumes that list_ops_or_scope is a list/tuple of ops with kwargs.
+ if not isinstance(list_ops_or_scope, (list, tuple)):
+ raise TypeError("list_ops_or_scope must either be a list/tuple or reused"
+ "scope (i.e. dict)")
+ try:
+ current_scope = _current_arg_scope().copy()
+ for op in list_ops_or_scope:
+ key_op = (op.__module__, op.__name__)
+ if not has_arg_scope(op):
+ raise ValueError("%s is not decorated with @add_arg_scope", key_op)
+ if key_op in current_scope:
+ current_kwargs = current_scope[key_op].copy()
+ current_kwargs.update(kwargs)
+ current_scope[key_op] = current_kwargs
+ else:
+ current_scope[key_op] = kwargs.copy()
+ _get_arg_stack().append(current_scope)
+ yield current_scope
+ finally:
+ _get_arg_stack().pop()
+
+
+def add_arg_scope(func):
+ """Decorates a function with args so it can be used within an arg_scope.
+
+ Args:
+ func: function to decorate.
+
+ Returns:
+    The decorated function func_with_args().
+ """
+ @functools.wraps(func)
+ def func_with_args(*args, **kwargs):
+ current_scope = _current_arg_scope()
+ current_args = kwargs
+ key_func = (func.__module__, func.__name__)
+ if key_func in current_scope:
+ current_args = current_scope[key_func].copy()
+ current_args.update(kwargs)
+ return func(*args, **current_args)
+ _add_op(func)
+ return func_with_args
+
+
+def has_arg_scope(func):
+ """Checks whether a func has been decorated with @add_arg_scope or not.
+
+ Args:
+ func: function to check.
+
+ Returns:
+ a boolean.
+ """
+ key_op = (func.__module__, func.__name__)
+ return key_op in _DECORATED_OPS
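+
+
+# Putting the pieces together (a sketch, not part of the original module):
+#
+#   @add_arg_scope
+#   def my_op(x, stddev=1.0, padding='VALID'):
+#     ...
+#
+#   with arg_scope([my_op], stddev=0.01, padding='SAME'):
+#     my_op(x)                   # runs with stddev=0.01, padding='SAME'
+#     my_op(x, padding='VALID')  # explicit kwargs override the scoped defaults
+#
+# has_arg_scope(my_op) returns True once the decorator has registered the op.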
diff --git a/models/research/inception/inception/slim/scopes_test.py b/models/research/inception/inception/slim/scopes_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd349399ed7300dde38ac9bcb9818abc9d0680b4
--- /dev/null
+++ b/models/research/inception/inception/slim/scopes_test.py
@@ -0,0 +1,162 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests slim.scopes."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+from inception.slim import scopes
+
+
+@scopes.add_arg_scope
+def func1(*args, **kwargs):
+ return (args, kwargs)
+
+
+@scopes.add_arg_scope
+def func2(*args, **kwargs):
+ return (args, kwargs)
+
+
+class ArgScopeTest(tf.test.TestCase):
+
+ def testEmptyArgScope(self):
+ with self.test_session():
+ self.assertEqual(scopes._current_arg_scope(), {})
+
+ def testCurrentArgScope(self):
+ func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
+ key_op = (func1.__module__, func1.__name__)
+ current_scope = {key_op: func1_kwargs.copy()}
+ with self.test_session():
+ with scopes.arg_scope([func1], a=1, b=None, c=[1]) as scope:
+ self.assertDictEqual(scope, current_scope)
+
+ def testCurrentArgScopeNested(self):
+ func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
+ func2_kwargs = {'b': 2, 'd': [2]}
+ key = lambda f: (f.__module__, f.__name__)
+ current_scope = {key(func1): func1_kwargs.copy(),
+ key(func2): func2_kwargs.copy()}
+ with self.test_session():
+ with scopes.arg_scope([func1], a=1, b=None, c=[1]):
+ with scopes.arg_scope([func2], b=2, d=[2]) as scope:
+ self.assertDictEqual(scope, current_scope)
+
+ def testReuseArgScope(self):
+ func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
+ key_op = (func1.__module__, func1.__name__)
+ current_scope = {key_op: func1_kwargs.copy()}
+ with self.test_session():
+ with scopes.arg_scope([func1], a=1, b=None, c=[1]) as scope1:
+ pass
+ with scopes.arg_scope(scope1) as scope:
+ self.assertDictEqual(scope, current_scope)
+
+ def testReuseArgScopeNested(self):
+ func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
+ func2_kwargs = {'b': 2, 'd': [2]}
+ key = lambda f: (f.__module__, f.__name__)
+ current_scope1 = {key(func1): func1_kwargs.copy()}
+ current_scope2 = {key(func1): func1_kwargs.copy(),
+ key(func2): func2_kwargs.copy()}
+ with self.test_session():
+ with scopes.arg_scope([func1], a=1, b=None, c=[1]) as scope1:
+ with scopes.arg_scope([func2], b=2, d=[2]) as scope2:
+ pass
+ with scopes.arg_scope(scope1):
+ self.assertDictEqual(scopes._current_arg_scope(), current_scope1)
+ with scopes.arg_scope(scope2):
+ self.assertDictEqual(scopes._current_arg_scope(), current_scope2)
+
+ def testSimpleArgScope(self):
+ func1_args = (0,)
+ func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
+ with self.test_session():
+ with scopes.arg_scope([func1], a=1, b=None, c=[1]):
+ args, kwargs = func1(0)
+ self.assertTupleEqual(args, func1_args)
+ self.assertDictEqual(kwargs, func1_kwargs)
+
+ def testSimpleArgScopeWithTuple(self):
+ func1_args = (0,)
+ func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
+ with self.test_session():
+ with scopes.arg_scope((func1,), a=1, b=None, c=[1]):
+ args, kwargs = func1(0)
+ self.assertTupleEqual(args, func1_args)
+ self.assertDictEqual(kwargs, func1_kwargs)
+
+ def testOverwriteArgScope(self):
+ func1_args = (0,)
+ func1_kwargs = {'a': 1, 'b': 2, 'c': [1]}
+ with scopes.arg_scope([func1], a=1, b=None, c=[1]):
+ args, kwargs = func1(0, b=2)
+ self.assertTupleEqual(args, func1_args)
+ self.assertDictEqual(kwargs, func1_kwargs)
+
+ def testNestedArgScope(self):
+ func1_args = (0,)
+ func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
+ with scopes.arg_scope([func1], a=1, b=None, c=[1]):
+ args, kwargs = func1(0)
+ self.assertTupleEqual(args, func1_args)
+ self.assertDictEqual(kwargs, func1_kwargs)
+ func1_kwargs['b'] = 2
+ with scopes.arg_scope([func1], b=2):
+ args, kwargs = func1(0)
+ self.assertTupleEqual(args, func1_args)
+ self.assertDictEqual(kwargs, func1_kwargs)
+
+ def testSharedArgScope(self):
+ func1_args = (0,)
+ func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
+ with scopes.arg_scope([func1, func2], a=1, b=None, c=[1]):
+ args, kwargs = func1(0)
+ self.assertTupleEqual(args, func1_args)
+ self.assertDictEqual(kwargs, func1_kwargs)
+ args, kwargs = func2(0)
+ self.assertTupleEqual(args, func1_args)
+ self.assertDictEqual(kwargs, func1_kwargs)
+
+ def testSharedArgScopeTuple(self):
+ func1_args = (0,)
+ func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
+ with scopes.arg_scope((func1, func2), a=1, b=None, c=[1]):
+ args, kwargs = func1(0)
+ self.assertTupleEqual(args, func1_args)
+ self.assertDictEqual(kwargs, func1_kwargs)
+ args, kwargs = func2(0)
+ self.assertTupleEqual(args, func1_args)
+ self.assertDictEqual(kwargs, func1_kwargs)
+
+ def testPartiallySharedArgScope(self):
+ func1_args = (0,)
+ func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
+ func2_args = (1,)
+ func2_kwargs = {'a': 1, 'b': None, 'd': [2]}
+ with scopes.arg_scope([func1, func2], a=1, b=None):
+ with scopes.arg_scope([func1], c=[1]), scopes.arg_scope([func2], d=[2]):
+ args, kwargs = func1(0)
+ self.assertTupleEqual(args, func1_args)
+ self.assertDictEqual(kwargs, func1_kwargs)
+ args, kwargs = func2(1)
+ self.assertTupleEqual(args, func2_args)
+ self.assertDictEqual(kwargs, func2_kwargs)
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/inception/inception/slim/slim.py b/models/research/inception/inception/slim/slim.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7a5c0f8c52b66db899835480c331ffafdc386e2
--- /dev/null
+++ b/models/research/inception/inception/slim/slim.py
@@ -0,0 +1,24 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TF-Slim grouped API. Please see README.md for details and usage."""
+# pylint: disable=unused-import
+
+# Collapse tf-slim into a single namespace.
+from inception.slim import inception_model as inception
+from inception.slim import losses
+from inception.slim import ops
+from inception.slim import scopes
+from inception.slim import variables
+from inception.slim.scopes import arg_scope
diff --git a/models/research/inception/inception/slim/variables.py b/models/research/inception/inception/slim/variables.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d967b79e9563724b1114995a732cfd4dd486afd
--- /dev/null
+++ b/models/research/inception/inception/slim/variables.py
@@ -0,0 +1,289 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains convenience wrappers for creating variables in TF-Slim.
+
+The variables module is typically used for defining model variables from the
+ops routines (see slim.ops). Such variables are used for training, evaluation
+and inference of models.
+
+All the variables created through this module are added to the
+MODEL_VARIABLES collection. If you create a model variable outside slim, it can
+be added with slim.variables.add_variable(external_variable, reuse).
+
+Usage:
+ weights_initializer = tf.truncated_normal_initializer(stddev=0.01)
+ l2_regularizer = lambda t: losses.l2_loss(t, weight=0.0005)
+ weights = variables.variable('weights',
+ shape=[100, 100],
+ initializer=weights_initializer,
+ regularizer=l2_regularizer,
+ device='/cpu:0')
+
+ biases = variables.variable('biases',
+ shape=[100],
+ initializer=tf.zeros_initializer(),
+ device='/cpu:0')
+
+ # More complex example.
+
+ net = slim.ops.conv2d(input, 32, [3, 3], scope='conv1')
+ net = slim.ops.conv2d(net, 64, [3, 3], scope='conv2')
+ with slim.arg_scope([variables.variable], restore=False):
+ net = slim.ops.conv2d(net, 64, [3, 3], scope='conv3')
+
+ # Get all model variables from all the layers.
+ model_variables = slim.variables.get_variables()
+
+  # Get all model variables from a specific layer, i.e. 'conv1'.
+ conv1_variables = slim.variables.get_variables('conv1')
+
+ # Get all weights from all the layers.
+ weights = slim.variables.get_variables_by_name('weights')
+
+  # Get all biases from all the layers.
+ biases = slim.variables.get_variables_by_name('biases')
+
+ # Get all variables to restore.
+ # (i.e. only those created by 'conv1' and 'conv2')
+ variables_to_restore = slim.variables.get_variables_to_restore()
+
+************************************************
+* Initializing model variables from a checkpoint
+************************************************
+
+# Create some variables.
+v1 = slim.variables.variable(name="v1", ..., restore=False)
+v2 = slim.variables.variable(name="v2", ...) # By default restore=True
+...
+# The list of variables to restore should only contain 'v2'.
+variables_to_restore = slim.variables.get_variables_to_restore()
+restorer = tf.train.Saver(variables_to_restore)
+with tf.Session() as sess:
+ # Restore variables from disk.
+ restorer.restore(sess, "/tmp/model.ckpt")
+ print("Model restored.")
+ # Do some work with the model
+ ...
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from inception.slim import scopes
+
+# Collection containing all the variables created using slim.variables
+MODEL_VARIABLES = '_model_variables_'
+
+# Collection containing the slim.variables that are created with restore=True.
+VARIABLES_TO_RESTORE = '_variables_to_restore_'
+
+
+def add_variable(var, restore=True):
+ """Adds a variable to the MODEL_VARIABLES collection.
+
+ Optionally it will add the variable to the VARIABLES_TO_RESTORE collection.
+ Args:
+ var: a variable.
+ restore: whether the variable should be added to the
+ VARIABLES_TO_RESTORE collection.
+
+ """
+ collections = [MODEL_VARIABLES]
+ if restore:
+ collections.append(VARIABLES_TO_RESTORE)
+ for collection in collections:
+ if var not in tf.get_collection(collection):
+ tf.add_to_collection(collection, var)
+
+
+def get_variables(scope=None, suffix=None):
+ """Gets the list of variables, filtered by scope and/or suffix.
+
+ Args:
+ scope: an optional scope for filtering the variables to return.
+ suffix: an optional suffix for filtering the variables to return.
+
+ Returns:
+ a copied list of variables with scope and suffix.
+ """
+ candidates = tf.get_collection(MODEL_VARIABLES, scope)[:]
+ if suffix is not None:
+ candidates = [var for var in candidates if var.op.name.endswith(suffix)]
+ return candidates
+
+
+def get_variables_to_restore():
+ """Gets the list of variables to restore.
+
+ Returns:
+ a copied list of variables.
+ """
+ return tf.get_collection(VARIABLES_TO_RESTORE)[:]
+
+
+def get_variables_by_name(given_name, scope=None):
+ """Gets the list of variables that were given that name.
+
+ Args:
+ given_name: name given to the variable without scope.
+ scope: an optional scope for filtering the variables to return.
+
+ Returns:
+ a copied list of variables with the given name and prefix.
+ """
+ return get_variables(scope=scope, suffix=given_name)
+
+
+def get_unique_variable(name):
+ """Gets the variable uniquely identified by that name.
+
+ Args:
+ name: a name that uniquely identifies the variable.
+
+ Returns:
+ a tensorflow variable.
+
+ Raises:
+ ValueError: if no variable uniquely identified by the name exists.
+ """
+ candidates = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, name)
+ if not candidates:
+    raise ValueError('Could not find variable %s' % name)
+
+ for candidate in candidates:
+ if candidate.op.name == name:
+ return candidate
+  raise ValueError('Variable %s does not uniquely identify a variable' % name)
+
+
+class VariableDeviceChooser(object):
+ """Slim device chooser for variables.
+
+  When using parameter servers it assigns variables to them in a round-robin
+  fashion and forces CPU:0 placement. When not using parameter servers it
+  allows either GPU:0 or CPU:0 placement.
+ """
+
+ def __init__(self,
+ num_parameter_servers=0,
+ ps_device='/job:ps',
+ placement='CPU:0'):
+ """Initialize VariableDeviceChooser.
+
+ Args:
+ num_parameter_servers: number of parameter servers.
+ ps_device: string representing the parameter server device.
+      placement: string representing the placement of the variable, either
+        CPU:0 or GPU:0. When using parameter servers it is forced to CPU:0.
+ """
+ self._num_ps = num_parameter_servers
+ self._ps_device = ps_device
+ self._placement = placement if num_parameter_servers == 0 else 'CPU:0'
+ self._next_task_id = 0
+
+ def __call__(self, op):
+ device_string = ''
+ if self._num_ps > 0:
+ task_id = self._next_task_id
+ self._next_task_id = (self._next_task_id + 1) % self._num_ps
+ device_string = '%s/task:%d' % (self._ps_device, task_id)
+ device_string += '/%s' % self._placement
+ return device_string
+
+
+# TODO(sguada) Remove once get_variable is able to colocate op.devices.
+def variable_device(device, name):
+ """Fix the variable device to colocate its ops."""
+ if callable(device):
+ var_name = tf.get_variable_scope().name + '/' + name
+ var_def = tf.NodeDef(name=var_name, op='Variable')
+ device = device(var_def)
+ if device is None:
+ device = ''
+ return device
+
+
+@scopes.add_arg_scope
+def global_step(device=''):
+ """Returns the global step variable.
+
+ Args:
+    device: Optional device to place the variable. It can be a string or a
+ function that is called to get the device for the variable.
+
+ Returns:
+ the tensor representing the global step variable.
+ """
+ global_step_ref = tf.get_collection(tf.GraphKeys.GLOBAL_STEP)
+ if global_step_ref:
+ return global_step_ref[0]
+ else:
+ collections = [
+ VARIABLES_TO_RESTORE,
+ tf.GraphKeys.GLOBAL_VARIABLES,
+ tf.GraphKeys.GLOBAL_STEP,
+ ]
+ # Get the device for the variable.
+ with tf.device(variable_device(device, 'global_step')):
+ return tf.get_variable('global_step', shape=[], dtype=tf.int64,
+ initializer=tf.zeros_initializer(),
+ trainable=False, collections=collections)
+
+
+@scopes.add_arg_scope
+def variable(name, shape=None, dtype=tf.float32, initializer=None,
+ regularizer=None, trainable=True, collections=None, device='',
+ restore=True):
+ """Gets an existing variable with these parameters or creates a new one.
+
+  It also adds itself to a group with its name.
+
+ Args:
+ name: the name of the new or existing variable.
+ shape: shape of the new or existing variable.
+ dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
+ initializer: initializer for the variable if one is created.
+ regularizer: a (Tensor -> Tensor or None) function; the result of
+ applying it on a newly created variable will be added to the collection
+ GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
+ collections: A list of collection names to which the Variable will be added.
+ Note that the variable is always also added to the tf.GraphKeys.GLOBAL_VARIABLES
+ and MODEL_VARIABLES collections.
+    device: Optional device to place the variable. It can be a string or a
+ function that is called to get the device for the variable.
+ restore: whether the variable should be added to the
+ VARIABLES_TO_RESTORE collection.
+
+ Returns:
+ The created or existing variable.
+ """
+ collections = list(collections or [])
+
+ # Make sure variables are added to tf.GraphKeys.GLOBAL_VARIABLES and MODEL_VARIABLES
+ collections += [tf.GraphKeys.GLOBAL_VARIABLES, MODEL_VARIABLES]
+ # Add to VARIABLES_TO_RESTORE if necessary
+ if restore:
+ collections.append(VARIABLES_TO_RESTORE)
+ # Remove duplicates
+ collections = set(collections)
+ # Get the device for the variable.
+ with tf.device(variable_device(device, name)):
+ return tf.get_variable(name, shape=shape, dtype=dtype,
+ initializer=initializer, regularizer=regularizer,
+ trainable=trainable, collections=collections)
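+
+
+# A minimal sketch (illustration only, assuming a default graph): variables
+# created through this wrapper are always tracked by slim in addition to any
+# user-supplied collections:
+#
+#   w = variable('weights', shape=[10, 10],
+#                initializer=tf.truncated_normal_initializer(stddev=0.01),
+#                collections=['my_collection'])
+#   assert w in tf.get_collection(MODEL_VARIABLES)
+#   assert w in tf.get_collection('my_collection')
+#   assert w in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)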
diff --git a/models/research/inception/inception/slim/variables_test.py b/models/research/inception/inception/slim/variables_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8c1944dfeb0fba7ad99f104b0c366c41d737c63
--- /dev/null
+++ b/models/research/inception/inception/slim/variables_test.py
@@ -0,0 +1,392 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slim.variables."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from inception.slim import scopes
+from inception.slim import variables
+
+
+class VariablesTest(tf.test.TestCase):
+
+ def testCreateVariable(self):
+ with self.test_session():
+ with tf.variable_scope('A'):
+ a = variables.variable('a', [5])
+ self.assertEquals(a.op.name, 'A/a')
+ self.assertListEqual(a.get_shape().as_list(), [5])
+
+ def testGetVariables(self):
+ with self.test_session():
+ with tf.variable_scope('A'):
+ a = variables.variable('a', [5])
+ with tf.variable_scope('B'):
+ b = variables.variable('a', [5])
+ self.assertEquals([a, b], variables.get_variables())
+ self.assertEquals([a], variables.get_variables('A'))
+ self.assertEquals([b], variables.get_variables('B'))
+
+ def testGetVariablesSuffix(self):
+ with self.test_session():
+ with tf.variable_scope('A'):
+ a = variables.variable('a', [5])
+ with tf.variable_scope('A'):
+ b = variables.variable('b', [5])
+ self.assertEquals([a], variables.get_variables(suffix='a'))
+ self.assertEquals([b], variables.get_variables(suffix='b'))
+
+ def testGetVariableWithSingleVar(self):
+ with self.test_session():
+ with tf.variable_scope('parent'):
+ a = variables.variable('child', [5])
+ self.assertEquals(a, variables.get_unique_variable('parent/child'))
+
+ def testGetVariableWithDistractors(self):
+ with self.test_session():
+ with tf.variable_scope('parent'):
+ a = variables.variable('child', [5])
+ with tf.variable_scope('child'):
+ variables.variable('grandchild1', [7])
+ variables.variable('grandchild2', [9])
+ self.assertEquals(a, variables.get_unique_variable('parent/child'))
+
+ def testGetVariableThrowsExceptionWithNoMatch(self):
+ var_name = 'cant_find_me'
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ variables.get_unique_variable(var_name)
+
+ def testGetThrowsExceptionWithChildrenButNoMatch(self):
+ var_name = 'parent/child'
+ with self.test_session():
+ with tf.variable_scope(var_name):
+ variables.variable('grandchild1', [7])
+ variables.variable('grandchild2', [9])
+ with self.assertRaises(ValueError):
+ variables.get_unique_variable(var_name)
+
+ def testGetVariablesToRestore(self):
+ with self.test_session():
+ with tf.variable_scope('A'):
+ a = variables.variable('a', [5])
+ with tf.variable_scope('B'):
+ b = variables.variable('a', [5])
+ self.assertEquals([a, b], variables.get_variables_to_restore())
+
+ def testNoneGetVariablesToRestore(self):
+ with self.test_session():
+ with tf.variable_scope('A'):
+ a = variables.variable('a', [5], restore=False)
+ with tf.variable_scope('B'):
+ b = variables.variable('a', [5], restore=False)
+ self.assertEquals([], variables.get_variables_to_restore())
+ self.assertEquals([a, b], variables.get_variables())
+
+ def testGetMixedVariablesToRestore(self):
+ with self.test_session():
+ with tf.variable_scope('A'):
+ a = variables.variable('a', [5])
+ b = variables.variable('b', [5], restore=False)
+ with tf.variable_scope('B'):
+ c = variables.variable('c', [5])
+ d = variables.variable('d', [5], restore=False)
+ self.assertEquals([a, b, c, d], variables.get_variables())
+ self.assertEquals([a, c], variables.get_variables_to_restore())
+
+ def testReuseVariable(self):
+ with self.test_session():
+ with tf.variable_scope('A'):
+ a = variables.variable('a', [])
+ with tf.variable_scope('A', reuse=True):
+ b = variables.variable('a', [])
+ self.assertEquals(a, b)
+ self.assertListEqual([a], variables.get_variables())
+
+ def testVariableWithDevice(self):
+ with self.test_session():
+ with tf.variable_scope('A'):
+ a = variables.variable('a', [], device='cpu:0')
+ b = variables.variable('b', [], device='cpu:1')
+ self.assertDeviceEqual(a.device, 'cpu:0')
+ self.assertDeviceEqual(b.device, 'cpu:1')
+
+ def testVariableWithDeviceFromScope(self):
+ with self.test_session():
+ with tf.device('/cpu:0'):
+ a = variables.variable('a', [])
+ b = variables.variable('b', [], device='cpu:1')
+ self.assertDeviceEqual(a.device, 'cpu:0')
+ self.assertDeviceEqual(b.device, 'cpu:1')
+
+ def testVariableWithDeviceFunction(self):
+ class DevFn(object):
+
+ def __init__(self):
+ self.counter = -1
+
+ def __call__(self, op):
+ self.counter += 1
+ return 'cpu:%d' % self.counter
+
+ with self.test_session():
+ with scopes.arg_scope([variables.variable], device=DevFn()):
+ a = variables.variable('a', [])
+ b = variables.variable('b', [])
+ c = variables.variable('c', [], device='cpu:12')
+ d = variables.variable('d', [])
+ with tf.device('cpu:99'):
+ e_init = tf.constant(12)
+ e = variables.variable('e', initializer=e_init)
+ self.assertDeviceEqual(a.device, 'cpu:0')
+ self.assertDeviceEqual(a.initial_value.device, 'cpu:0')
+ self.assertDeviceEqual(b.device, 'cpu:1')
+ self.assertDeviceEqual(b.initial_value.device, 'cpu:1')
+ self.assertDeviceEqual(c.device, 'cpu:12')
+ self.assertDeviceEqual(c.initial_value.device, 'cpu:12')
+ self.assertDeviceEqual(d.device, 'cpu:2')
+ self.assertDeviceEqual(d.initial_value.device, 'cpu:2')
+ self.assertDeviceEqual(e.device, 'cpu:3')
+ self.assertDeviceEqual(e.initial_value.device, 'cpu:99')
+
+ def testVariableWithReplicaDeviceSetter(self):
+ with self.test_session():
+ with tf.device(tf.train.replica_device_setter(ps_tasks=2)):
+ a = variables.variable('a', [])
+ b = variables.variable('b', [])
+ c = variables.variable('c', [], device='cpu:12')
+ d = variables.variable('d', [])
+ with tf.device('cpu:99'):
+ e_init = tf.constant(12)
+ e = variables.variable('e', initializer=e_init)
+ # The values below highlight how the replica_device_setter puts initial
+ # values on the worker job, and how it merges explicit devices.
+ self.assertDeviceEqual(a.device, '/job:ps/task:0/cpu:0')
+ self.assertDeviceEqual(a.initial_value.device, '/job:worker/cpu:0')
+ self.assertDeviceEqual(b.device, '/job:ps/task:1/cpu:0')
+ self.assertDeviceEqual(b.initial_value.device, '/job:worker/cpu:0')
+ self.assertDeviceEqual(c.device, '/job:ps/task:0/cpu:12')
+ self.assertDeviceEqual(c.initial_value.device, '/job:worker/cpu:12')
+ self.assertDeviceEqual(d.device, '/job:ps/task:1/cpu:0')
+ self.assertDeviceEqual(d.initial_value.device, '/job:worker/cpu:0')
+ self.assertDeviceEqual(e.device, '/job:ps/task:0/cpu:0')
+ self.assertDeviceEqual(e.initial_value.device, '/job:worker/cpu:99')
+
+ def testVariableWithVariableDeviceChooser(self):
+
+ with tf.Graph().as_default():
+ device_fn = variables.VariableDeviceChooser(num_parameter_servers=2)
+ with scopes.arg_scope([variables.variable], device=device_fn):
+ a = variables.variable('a', [])
+ b = variables.variable('b', [])
+ c = variables.variable('c', [], device='cpu:12')
+ d = variables.variable('d', [])
+ with tf.device('cpu:99'):
+ e_init = tf.constant(12)
+ e = variables.variable('e', initializer=e_init)
+ # The values below highlight how the VariableDeviceChooser puts initial
+ # values on the same device as the variable job.
+ self.assertDeviceEqual(a.device, '/job:ps/task:0/cpu:0')
+ self.assertDeviceEqual(a.initial_value.device, a.device)
+ self.assertDeviceEqual(b.device, '/job:ps/task:1/cpu:0')
+ self.assertDeviceEqual(b.initial_value.device, b.device)
+ self.assertDeviceEqual(c.device, '/cpu:12')
+ self.assertDeviceEqual(c.initial_value.device, c.device)
+ self.assertDeviceEqual(d.device, '/job:ps/task:0/cpu:0')
+ self.assertDeviceEqual(d.initial_value.device, d.device)
+ self.assertDeviceEqual(e.device, '/job:ps/task:1/cpu:0')
+ self.assertDeviceEqual(e.initial_value.device, '/cpu:99')
+
+ def testVariableGPUPlacement(self):
+
+ with tf.Graph().as_default():
+ device_fn = variables.VariableDeviceChooser(placement='gpu:0')
+ with scopes.arg_scope([variables.variable], device=device_fn):
+ a = variables.variable('a', [])
+ b = variables.variable('b', [])
+ c = variables.variable('c', [], device='cpu:12')
+ d = variables.variable('d', [])
+ with tf.device('cpu:99'):
+ e_init = tf.constant(12)
+ e = variables.variable('e', initializer=e_init)
+ # The values below highlight how the VariableDeviceChooser puts initial
+ # values on the same device as the variable job.
+ self.assertDeviceEqual(a.device, '/gpu:0')
+ self.assertDeviceEqual(a.initial_value.device, a.device)
+ self.assertDeviceEqual(b.device, '/gpu:0')
+ self.assertDeviceEqual(b.initial_value.device, b.device)
+ self.assertDeviceEqual(c.device, '/cpu:12')
+ self.assertDeviceEqual(c.initial_value.device, c.device)
+ self.assertDeviceEqual(d.device, '/gpu:0')
+ self.assertDeviceEqual(d.initial_value.device, d.device)
+ self.assertDeviceEqual(e.device, '/gpu:0')
+ self.assertDeviceEqual(e.initial_value.device, '/cpu:99')
+
+ def testVariableCollection(self):
+ with self.test_session():
+ a = variables.variable('a', [], collections='A')
+ b = variables.variable('b', [], collections='B')
+ self.assertEquals(a, tf.get_collection('A')[0])
+ self.assertEquals(b, tf.get_collection('B')[0])
+
+ def testVariableCollections(self):
+ with self.test_session():
+ a = variables.variable('a', [], collections=['A', 'C'])
+ b = variables.variable('b', [], collections=['B', 'C'])
+ self.assertEquals(a, tf.get_collection('A')[0])
+ self.assertEquals(b, tf.get_collection('B')[0])
+
+ def testVariableCollectionsWithArgScope(self):
+ with self.test_session():
+ with scopes.arg_scope([variables.variable], collections='A'):
+ a = variables.variable('a', [])
+ b = variables.variable('b', [])
+ self.assertListEqual([a, b], tf.get_collection('A'))
+
+ def testVariableCollectionsWithArgScopeNested(self):
+ with self.test_session():
+ with scopes.arg_scope([variables.variable], collections='A'):
+ a = variables.variable('a', [])
+ with scopes.arg_scope([variables.variable], collections='B'):
+ b = variables.variable('b', [])
+ self.assertEquals(a, tf.get_collection('A')[0])
+ self.assertEquals(b, tf.get_collection('B')[0])
+
+ def testVariableCollectionsWithArgScopeNonNested(self):
+ with self.test_session():
+ with scopes.arg_scope([variables.variable], collections='A'):
+ a = variables.variable('a', [])
+ with scopes.arg_scope([variables.variable], collections='B'):
+ b = variables.variable('b', [])
+ variables.variable('c', [])
+ self.assertListEqual([a], tf.get_collection('A'))
+ self.assertListEqual([b], tf.get_collection('B'))
+
+ def testVariableRestoreWithArgScopeNested(self):
+ with self.test_session():
+ with scopes.arg_scope([variables.variable], restore=True):
+ a = variables.variable('a', [])
+ with scopes.arg_scope([variables.variable],
+ trainable=False,
+ collections=['A', 'B']):
+ b = variables.variable('b', [])
+ c = variables.variable('c', [])
+ self.assertListEqual([a, b, c], variables.get_variables_to_restore())
+ self.assertListEqual([a, c], tf.trainable_variables())
+ self.assertListEqual([b], tf.get_collection('A'))
+ self.assertListEqual([b], tf.get_collection('B'))
+
+
+class GetVariablesByNameTest(tf.test.TestCase):
+
+ def testGetVariableGivenNameScoped(self):
+ with self.test_session():
+ with tf.variable_scope('A'):
+ a = variables.variable('a', [5])
+ b = variables.variable('b', [5])
+ self.assertEquals([a], variables.get_variables_by_name('a'))
+ self.assertEquals([b], variables.get_variables_by_name('b'))
+
+ def testGetVariablesByNameReturnsByValueWithScope(self):
+ with self.test_session():
+ with tf.variable_scope('A'):
+ a = variables.variable('a', [5])
+ matched_variables = variables.get_variables_by_name('a')
+
+ # If variables.get_variables_by_name returns the list by reference, the
+ # following append should persist, and be returned, in subsequent calls
+ # to variables.get_variables_by_name('a').
+ matched_variables.append(4)
+
+ matched_variables = variables.get_variables_by_name('a')
+ self.assertEquals([a], matched_variables)
+
+ def testGetVariablesByNameReturnsByValueWithoutScope(self):
+ with self.test_session():
+ a = variables.variable('a', [5])
+ matched_variables = variables.get_variables_by_name('a')
+
+ # If variables.get_variables_by_name returns the list by reference, the
+ # following append should persist, and be returned, in subsequent calls
+ # to variables.get_variables_by_name('a').
+ matched_variables.append(4)
+
+ matched_variables = variables.get_variables_by_name('a')
+ self.assertEquals([a], matched_variables)
+
+
+class GlobalStepTest(tf.test.TestCase):
+
+ def testStable(self):
+ with tf.Graph().as_default():
+ gs = variables.global_step()
+ gs2 = variables.global_step()
+ self.assertTrue(gs is gs2)
+
+ def testDevice(self):
+ with tf.Graph().as_default():
+ with scopes.arg_scope([variables.global_step], device='/gpu:0'):
+ gs = variables.global_step()
+ self.assertDeviceEqual(gs.device, '/gpu:0')
+
+ def testDeviceFn(self):
+ class DevFn(object):
+
+ def __init__(self):
+ self.counter = -1
+
+ def __call__(self, op):
+ self.counter += 1
+ return '/cpu:%d' % self.counter
+
+ with tf.Graph().as_default():
+ with scopes.arg_scope([variables.global_step], device=DevFn()):
+ gs = variables.global_step()
+ gs2 = variables.global_step()
+ self.assertDeviceEqual(gs.device, '/cpu:0')
+ self.assertEquals(gs, gs2)
+ self.assertDeviceEqual(gs2.device, '/cpu:0')
+
+ def testReplicaDeviceSetter(self):
+ device_fn = tf.train.replica_device_setter(2)
+ with tf.Graph().as_default():
+ with scopes.arg_scope([variables.global_step], device=device_fn):
+ gs = variables.global_step()
+ gs2 = variables.global_step()
+ self.assertEquals(gs, gs2)
+ self.assertDeviceEqual(gs.device, '/job:ps/task:0')
+ self.assertDeviceEqual(gs.initial_value.device, '/job:ps/task:0')
+ self.assertDeviceEqual(gs2.device, '/job:ps/task:0')
+ self.assertDeviceEqual(gs2.initial_value.device, '/job:ps/task:0')
+
+ def testVariableWithVariableDeviceChooser(self):
+
+ with tf.Graph().as_default():
+ device_fn = variables.VariableDeviceChooser()
+ with scopes.arg_scope([variables.global_step], device=device_fn):
+ gs = variables.global_step()
+ gs2 = variables.global_step()
+ self.assertEquals(gs, gs2)
+ self.assertDeviceEqual(gs.device, 'cpu:0')
+ self.assertDeviceEqual(gs.initial_value.device, gs.device)
+ self.assertDeviceEqual(gs2.device, 'cpu:0')
+ self.assertDeviceEqual(gs2.initial_value.device, gs2.device)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/keypointnet/CONTRIBUTING.md b/models/research/keypointnet/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..939e5341e74dc2371c8b47f0e27b50581bed5f63
--- /dev/null
+++ b/models/research/keypointnet/CONTRIBUTING.md
@@ -0,0 +1,28 @@
+# How to Contribute
+
+We'd love to accept your patches and contributions to this project. There are
+just a few small guidelines you need to follow.
+
+## Contributor License Agreement
+
+Contributions to this project must be accompanied by a Contributor License
+Agreement. You (or your employer) retain the copyright to your contribution;
+this simply gives us permission to use and redistribute your contributions as
+part of the project. Head over to <https://cla.developers.google.com/> to see
+your current agreements on file or to sign a new one.
+
+You generally only need to submit a CLA once, so if you've already submitted one
+(even if it was for a different project), you probably don't need to do it
+again.
+
+## Code reviews
+
+All submissions, including submissions by project members, require review. We
+use GitHub pull requests for this purpose. Consult
+[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
+information on using pull requests.
+
+## Community Guidelines
+
+This project follows [Google's Open Source Community
+Guidelines](https://opensource.google.com/conduct/).
diff --git a/models/research/keypointnet/LICENSE b/models/research/keypointnet/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..d645695673349e3947e8e5ae42332d0ac3164cd7
--- /dev/null
+++ b/models/research/keypointnet/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/models/research/keypointnet/README.md b/models/research/keypointnet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8de88ca5a18816984302a9c20639364a7c8cde53
--- /dev/null
+++ b/models/research/keypointnet/README.md
@@ -0,0 +1,46 @@
+
+
+
+
+# KeypointNet
+This is an implementation of the keypoint network proposed in "Discovery of
+Latent 3D Keypoints via End-to-end Geometric Reasoning
+[[pdf](https://arxiv.org/pdf/1807.03146.pdf)]". Given a single 2D image of a
+known class, this network can predict a set of 3D keypoints that are consistent
+across viewing angles of the same object and across object instances. These
+keypoints and their detectors are discovered and learned automatically without
+keypoint location supervision [[demo](https://keypointnet.github.io)].
+
+## Datasets:
+ ShapeNet renderings for
+ [Cars](https://storage.googleapis.com/discovery-3dkeypoints-data/cars_with_keypoints.zip),
+ [Planes](https://storage.googleapis.com/discovery-3dkeypoints-data/planes_with_keypoints.zip),
+ [Chairs](https://storage.googleapis.com/discovery-3dkeypoints-data/chairs_with_keypoints.zip).
+
+ Each set contains:
+1. tfrecords
+2. train.txt, a list of tfrecords used for training.
+3. dev.txt, a list of tfrecords used for validation.
+4. test.txt, a list of tfrecords used for testing.
+5. projection.txt, storing the global 4x4 camera projection matrix.
+6. job.txt, storing ShapeNet's object IDs in each tfrecord.
+
+## Training:
+ Run `main.py --model_dir=MODEL_DIR --dset=DSET`
+
+ where MODEL_DIR is a folder for storing model checkpoints (see [tf.estimator](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator)), and DSET should point to the folder containing the tfrecords (downloaded above).
+
+## Inference:
+ Run `main.py --model_dir=MODEL_DIR --input=INPUT --predict`
+
+ where MODEL_DIR is the model checkpoint folder, and INPUT is a folder containing png or jpeg test images.
+ We trained the network with a total batch size of 256 (8 x 32 replicas). You may have to tune the learning rate if your batch size is different.
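+
+ Hyperparameters such as the learning rate can also be overridden on the command line through the `--hparams` flag, e.g. `--hparams=learning_rate=5.0e-4,num_kp=10` (illustrative values); the flag takes a comma-separated list of `name=value` pairs covering the defaults defined in `_default_hparams()` in `main.py`.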
+
+## Code credit:
+ Supasorn Suwajanakorn
+
+## Contact:
+ supasorn@gmail.com, [snavely,tompson,mnorouzi]@google.com
+
+
+(This is not an officially supported Google product)
diff --git a/models/research/keypointnet/main.py b/models/research/keypointnet/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..04b30159404e01529c898ee75fb1ed78f705f539
--- /dev/null
+++ b/models/research/keypointnet/main.py
@@ -0,0 +1,697 @@
+# Copyright 2018 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""KeypointNet!!
+
+A reimplementation of 'Discovery of Latent 3D Keypoints via End-to-end
+Geometric Reasoning' keypoint network. Given a single 2D image of a known class,
+this network can predict a set of 3D keypoints that are consistent across
+viewing angles of the same object and across object instances. These keypoints
+and their detectors are discovered and learned automatically without
+keypoint location supervision.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+from scipy import misc
+import sys
+import tensorflow as tf
+import tensorflow.contrib.slim as slim
+import utils
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_boolean("predict", False, "Running inference if true")
+tf.app.flags.DEFINE_string(
+ "input",
+ "",
+ "Input folder containing images")
+tf.app.flags.DEFINE_string("model_dir", None, "Estimator model_dir")
+tf.app.flags.DEFINE_string(
+ "dset",
+ "",
+ "Path to the directory containing the dataset.")
+tf.app.flags.DEFINE_integer("steps", 200000, "Training steps")
+tf.app.flags.DEFINE_integer("batch_size", 8, "Size of mini-batch.")
+tf.app.flags.DEFINE_string(
+ "hparams", "",
+ "A comma-separated list of `name=value` hyperparameter values. This flag "
+ "is used to override hyperparameter settings either when manually "
+ "selecting hyperparameters or when using Vizier.")
+tf.app.flags.DEFINE_integer(
+ "sync_replicas", -1,
+ "If > 0, use SyncReplicasOptimizer and use this many replicas per sync.")
+
+# Fixed input size 128 x 128.
+vw = vh = 128
+
+
+def create_input_fn(split, batch_size):
+ """Returns input_fn for tf.estimator.Estimator.
+
+  Reads tfrecords and constructs input_fn for either training or eval. All
+  tfrecords not in test.txt or dev.txt will be assigned to the training set.
+
+ Args:
+ split: A string indicating the split. Can be either 'train' or 'validation'.
+ batch_size: The batch size!
+
+ Returns:
+ input_fn for tf.estimator.Estimator.
+
+ Raises:
+ IOError: If test.txt or dev.txt are not found.
+ """
+
+ if (not os.path.exists(os.path.join(FLAGS.dset, "test.txt")) or
+ not os.path.exists(os.path.join(FLAGS.dset, "dev.txt"))):
+ raise IOError("test.txt or dev.txt not found")
+
+ with open(os.path.join(FLAGS.dset, "test.txt"), "r") as f:
+ testset = [x.strip() for x in f.readlines()]
+
+ with open(os.path.join(FLAGS.dset, "dev.txt"), "r") as f:
+ validset = [x.strip() for x in f.readlines()]
+
+ files = os.listdir(FLAGS.dset)
+ filenames = []
+ for f in files:
+ sp = os.path.splitext(f)
+ if sp[1] != ".tfrecord" or sp[0] in testset:
+ continue
+
+ if ((split == "validation" and sp[0] in validset) or
+ (split == "train" and sp[0] not in validset)):
+ filenames.append(os.path.join(FLAGS.dset, f))
+
+ def input_fn():
+ """input_fn for tf.estimator.Estimator."""
+
+ def parser(serialized_example):
+ """Parses a single tf.Example into image and label tensors."""
+ fs = tf.parse_single_example(
+ serialized_example,
+ features={
+ "img0": tf.FixedLenFeature([], tf.string),
+ "img1": tf.FixedLenFeature([], tf.string),
+ "mv0": tf.FixedLenFeature([16], tf.float32),
+ "mvi0": tf.FixedLenFeature([16], tf.float32),
+ "mv1": tf.FixedLenFeature([16], tf.float32),
+ "mvi1": tf.FixedLenFeature([16], tf.float32),
+ })
+
+ fs["img0"] = tf.div(tf.to_float(tf.image.decode_png(fs["img0"], 4)), 255)
+ fs["img1"] = tf.div(tf.to_float(tf.image.decode_png(fs["img1"], 4)), 255)
+
+ fs["img0"].set_shape([vh, vw, 4])
+ fs["img1"].set_shape([vh, vw, 4])
+
+ # fs["lr0"] = [fs["mv0"][0]]
+ # fs["lr1"] = [fs["mv1"][0]]
+
+ fs["lr0"] = tf.convert_to_tensor([fs["mv0"][0]])
+ fs["lr1"] = tf.convert_to_tensor([fs["mv1"][0]])
+
+ return fs
+
+ np.random.shuffle(filenames)
+ dataset = tf.data.TFRecordDataset(filenames)
+ dataset = dataset.map(parser, num_parallel_calls=4)
+ dataset = dataset.shuffle(400).repeat().batch(batch_size)
+ dataset = dataset.prefetch(buffer_size=256)
+
+ return dataset.make_one_shot_iterator().get_next(), None
+
+ return input_fn
+
+
+class Transformer(object):
+ """A utility for projecting 3D points to 2D coordinates and vice versa.
+
+ 3D points are represented in 4D-homogeneous world coordinates. The pixel
+ coordinates are represented in normalized device coordinates [-1, 1].
+ See https://learnopengl.com/Getting-started/Coordinate-Systems.
+ """
+
+ def __get_matrix(self, lines):
+ return np.array([[float(y) for y in x.strip().split(" ")] for x in lines])
+
+ def __read_projection_matrix(self, filename):
+ if not os.path.exists(filename):
+ filename = "/cns/vz-d/home/supasorn/datasets/cars/projection.txt"
+ with open(filename, "r") as f:
+ lines = f.readlines()
+ return self.__get_matrix(lines)
+
+ def __init__(self, w, h, dataset_dir):
+ self.w = w
+ self.h = h
+ p = self.__read_projection_matrix(dataset_dir + "projection.txt")
+
+    # Transpose of the inverse projection matrix.
+ self.pinv_t = tf.constant([[1.0 / p[0, 0], 0, 0,
+ 0], [0, 1.0 / p[1, 1], 0, 0], [0, 0, 1, 0],
+ [0, 0, 0, 1]])
+ self.f = p[0, 0]
+
+ def project(self, xyzw):
+ """Projects homogeneous 3D coordinates to normalized device coordinates."""
+
+ z = xyzw[:, :, 2:3] + 1e-8
+ return tf.concat([-self.f * xyzw[:, :, :2] / z, z], axis=2)
+
+ def unproject(self, xyz):
+ """Unprojects normalized device coordinates with depth to 3D coordinates."""
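+    # Inverts project(): given NDC (u, v) and depth z, this recovers the
+    # homogeneous 3D coordinates by undoing the perspective divide and the
+    # focal-length scaling stored in pinv_t.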
+
+ z = xyz[:, :, 2:]
+ xy = -xyz * z
+
+ def batch_matmul(a, b):
+ return tf.reshape(
+ tf.matmul(tf.reshape(a, [-1, a.shape[2].value]), b),
+ [-1, a.shape[1].value, a.shape[2].value])
+
+ return batch_matmul(
+ tf.concat([xy[:, :, :2], z, tf.ones_like(z)], axis=2), self.pinv_t)
+
+
+def meshgrid(h):
+ """Returns a meshgrid ranging from [-1, 1] in x, y axes."""
+
+ r = np.arange(0.5, h, 1) / (h / 2) - 1
+ ranx, rany = tf.meshgrid(r, -r)
+ return tf.to_float(ranx), tf.to_float(rany)
+
+
+def estimate_rotation(xyz0, xyz1, pconf, noise):
+ """Estimates the rotation between two sets of keypoints.
+
+ The rotation is estimated by first subtracting mean from each set of keypoints
+ and computing SVD of the covariance matrix.
+
+ Args:
+ xyz0: [batch, num_kp, 3] The first set of keypoints.
+ xyz1: [batch, num_kp, 3] The second set of keypoints.
+ pconf: [batch, num_kp] The weights used to compute the rotation estimate.
+ noise: A number indicating the noise added to the keypoints.
+
+ Returns:
+ [batch, 3, 3] A batch of transposed 3 x 3 rotation matrices.
+ """
+
+ xyz0 += tf.random_normal(tf.shape(xyz0), mean=0, stddev=noise)
+ xyz1 += tf.random_normal(tf.shape(xyz1), mean=0, stddev=noise)
+
+ pconf2 = tf.expand_dims(pconf, 2)
+ cen0 = tf.reduce_sum(xyz0 * pconf2, 1, keepdims=True)
+ cen1 = tf.reduce_sum(xyz1 * pconf2, 1, keepdims=True)
+
+ x = xyz0 - cen0
+ y = xyz1 - cen1
+
+ cov = tf.matmul(tf.matmul(x, tf.matrix_diag(pconf), transpose_a=True), y)
+ _, u, v = tf.svd(cov, full_matrices=True)
+
+ d = tf.matrix_determinant(tf.matmul(v, u, transpose_b=True))
+ ud = tf.concat(
+ [u[:, :, :-1], u[:, :, -1:] * tf.expand_dims(tf.expand_dims(d, 1), 1)],
+ axis=2)
+ return tf.matmul(ud, v, transpose_b=True)
+
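+# Illustrative NumPy sketch (not used by the model) of the weighted
+# Kabsch/Procrustes alignment performed by estimate_rotation above, shown for a
+# single unbatched example and without the added noise. Names are hypothetical.
+def _estimate_rotation_np_sketch(xyz0, xyz1, pconf):
+  """Returns the transposed rotation best aligning xyz0 to xyz1 (NumPy sketch)."""
+  # Weighted centroids; pconf is assumed to sum to 1 (as in model_fn).
+  cen0 = (xyz0 * pconf[:, None]).sum(axis=0, keepdims=True)
+  cen1 = (xyz1 * pconf[:, None]).sum(axis=0, keepdims=True)
+  x, y = xyz0 - cen0, xyz1 - cen1
+  # Weighted covariance and its SVD.
+  cov = x.T.dot(np.diag(pconf)).dot(y)
+  u, _, vh = np.linalg.svd(cov, full_matrices=True)
+  d = np.linalg.det(vh.T.dot(u.T))  # reflection correction
+  u[:, -1] *= d
+  return u.dot(vh)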
+
+def relative_pose_loss(xyz0, xyz1, rot, pconf, noise):
+ """Computes the relative pose loss (chordal, angular).
+
+ Args:
+ xyz0: [batch, num_kp, 3] The first set of keypoints.
+ xyz1: [batch, num_kp, 3] The second set of keypoints.
+ rot: [batch, 4, 4] The ground-truth rotation matrices.
+ pconf: [batch, num_kp] The weights used to compute the rotation estimate.
+ noise: A number indicating the noise added to the keypoints.
+
+ Returns:
+ A tuple (chordal loss, angular loss).
+ """
+
+ r_transposed = estimate_rotation(xyz0, xyz1, pconf, noise)
+ rotation = rot[:, :3, :3]
+ frob_sqr = tf.reduce_sum(tf.square(r_transposed - rotation), axis=[1, 2])
+ frob = tf.sqrt(frob_sqr)
+
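+  # For 3x3 rotations, ||R1 - R2||_F = 2 * sqrt(2) * sin(theta / 2), where theta
+  # is the relative rotation angle, so the angular term below recovers the angle.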
+ return tf.reduce_mean(frob_sqr), \
+ 2.0 * tf.reduce_mean(tf.asin(tf.minimum(1.0, frob / (2 * math.sqrt(2)))))
+
+
+def separation_loss(xyz, delta):
+ """Computes the separation loss.
+
+ Args:
+ xyz: [batch, num_kp, 3] Input keypoints.
+    delta: A separation threshold. Incur 0 cost if the squared distance >= delta.
+
+ Returns:
+    The separation loss.
+ """
+
+ num_kp = tf.shape(xyz)[1]
+ t1 = tf.tile(xyz, [1, num_kp, 1])
+
+ t2 = tf.reshape(tf.tile(xyz, [1, 1, num_kp]), tf.shape(t1))
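+  # t1 repeats the whole keypoint set and t2 repeats each keypoint, so together
+  # they enumerate every ordered pair and diffsq holds all pairwise differences.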
+ diffsq = tf.square(t1 - t2)
+
+ # -> [batch, num_kp ^ 2]
+ lensqr = tf.reduce_sum(diffsq, axis=2)
+
+ return (tf.reduce_sum(tf.maximum(-lensqr + delta, 0.0)) / tf.to_float(
+ num_kp * FLAGS.batch_size * 2))
+
+
+def consistency_loss(uv0, uv1, pconf):
+ """Computes multi-view consistency loss between two sets of keypoints.
+
+ Args:
+ uv0: [batch, num_kp, 2] The first set of keypoint 2D coordinates.
+ uv1: [batch, num_kp, 2] The second set of keypoint 2D coordinates.
+ pconf: [batch, num_kp] The weights used to compute the rotation estimate.
+
+ Returns:
+ The consistency loss.
+ """
+
+ # [batch, num_kp, 2]
+ wd = tf.square(uv0 - uv1) * tf.expand_dims(pconf, 2)
+ wd = tf.reduce_sum(wd, axis=[1, 2])
+ return tf.reduce_mean(wd)
+
+
+def variance_loss(probmap, ranx, rany, uv):
+  """Computes the variance loss as part of silhouette consistency.
+
+ Args:
+ probmap: [batch, num_kp, h, w] The distribution map of keypoint locations.
+ ranx: X-axis meshgrid.
+ rany: Y-axis meshgrid.
+ uv: [batch, num_kp, 2] Keypoint locations (in NDC).
+
+ Returns:
+ The variance loss.
+ """
+
+ ran = tf.stack([ranx, rany], axis=2)
+
+ sh = tf.shape(ran)
+ # [batch, num_kp, vh, vw, 2]
+ ran = tf.reshape(ran, [1, 1, sh[0], sh[1], 2])
+
+ sh = tf.shape(uv)
+ uv = tf.reshape(uv, [sh[0], sh[1], 1, 1, 2])
+
+ diff = tf.reduce_sum(tf.square(uv - ran), axis=4)
+ diff *= probmap
+
+ return tf.reduce_mean(tf.reduce_sum(diff, axis=[2, 3]))
+
+
+def dilated_cnn(images, num_filters, is_training):
+ """Constructs a base dilated convolutional network.
+
+ Args:
+ images: [batch, h, w, 3] Input RGB images.
+ num_filters: The number of filters for all layers.
+ is_training: True if this function is called during training.
+
+ Returns:
+ Output of this dilated CNN.
+ """
+
+ net = images
+
+ with slim.arg_scope(
+ [slim.conv2d, slim.fully_connected],
+ normalizer_fn=slim.batch_norm,
+ activation_fn=lambda x: tf.nn.leaky_relu(x, alpha=0.1),
+ normalizer_params={"is_training": is_training}):
+ for i, r in enumerate([1, 1, 2, 4, 8, 16, 1, 2, 4, 8, 16, 1]):
+ net = slim.conv2d(net, num_filters, [3, 3], rate=r, scope="dconv%d" % i)
+
+ return net
+
+
+def orientation_network(images, num_filters, is_training):
+ """Constructs a network that infers the orientation of an object.
+
+ Args:
+ images: [batch, h, w, 3] Input RGB images.
+ num_filters: The number of filters for all layers.
+ is_training: True if this function is called during training.
+
+ Returns:
+ Output of the orientation network.
+ """
+
+ with tf.variable_scope("OrientationNetwork"):
+ net = dilated_cnn(images, num_filters, is_training)
+
+ modules = 2
+ prob = slim.conv2d(net, 2, [3, 3], rate=1, activation_fn=None)
+ prob = tf.transpose(prob, [0, 3, 1, 2])
+
+ prob = tf.reshape(prob, [-1, modules, vh * vw])
+ prob = tf.nn.softmax(prob)
+ ranx, rany = meshgrid(vh)
+
+ prob = tf.reshape(prob, [-1, 2, vh, vw])
+
+ sx = tf.reduce_sum(prob * ranx, axis=[2, 3])
+ sy = tf.reduce_sum(prob * rany, axis=[2, 3]) # -> batch x modules
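+    # sx, sy are the probability-weighted mean coordinates over the map, i.e. a
+    # spatial soft-argmax of each orientation module's distribution.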
+
+ out_xy = tf.reshape(tf.stack([sx, sy], -1), [-1, modules, 2])
+
+ return out_xy
+
+
+def keypoint_network(rgba,
+ num_filters,
+ num_kp,
+ is_training,
+ lr_gt=None,
+ anneal=1):
+ """Constructs our main keypoint network that predicts 3D keypoints.
+
+ Args:
+ rgba: [batch, h, w, 4] Input RGB images with alpha channel.
+ num_filters: The number of filters for all layers.
+ num_kp: The number of keypoints.
+ is_training: True if this function is called during training.
+    lr_gt: The ground-truth orientation flag used at the beginning of training;
+      the network's own prediction is linearly annealed in as training proceeds.
+ anneal: A number between [0, 1] where 1 means using the ground-truth
+ orientation and 0 means using our estimate.
+
+ Returns:
+ uv: [batch, num_kp, 2] 2D locations of keypoints.
+ z: [batch, num_kp] The depth of keypoints.
+ orient: [batch, 2, 2] Two 2D coordinates that correspond to [1, 0, 0] and
+ [-1, 0, 0] in object space.
+    sill: The silhouette loss.
+ variance: The variance loss.
+ prob_viz: A visualization of all predicted keypoints.
+ prob_vizs: A list of visualizations of each keypoint.
+
+ """
+
+ images = rgba[:, :, :, :3]
+
+  # [batch, 2, 2]
+ orient = orientation_network(images, num_filters * 0.5, is_training)
+
+ # [batch, 1]
+ lr_estimated = tf.maximum(0.0, tf.sign(orient[:, 0, :1] - orient[:, 1, :1]))
+
+ if lr_gt is None:
+ lr = lr_estimated
+ else:
+ lr_gt = tf.maximum(0.0, tf.sign(lr_gt[:, :1]))
+ lr = tf.round(lr_gt * anneal + lr_estimated * (1 - anneal))
+
+ lrtiled = tf.tile(
+ tf.expand_dims(tf.expand_dims(lr, 1), 1),
+ [1, images.shape[1], images.shape[2], 1])
+
+ images = tf.concat([images, lrtiled], axis=3)
+
+ mask = rgba[:, :, :, 3]
+ mask = tf.cast(tf.greater(mask, tf.zeros_like(mask)), dtype=tf.float32)
+
+ net = dilated_cnn(images, num_filters, is_training)
+
+ # The probability distribution map.
+ prob = slim.conv2d(
+ net, num_kp, [3, 3], rate=1, scope="conv_xy", activation_fn=None)
+
+ # We added the fixed camera distance as a bias.
+ z = -30 + slim.conv2d(
+ net, num_kp, [3, 3], rate=1, scope="conv_z", activation_fn=None)
+
+ prob = tf.transpose(prob, [0, 3, 1, 2])
+ z = tf.transpose(z, [0, 3, 1, 2])
+
+ prob = tf.reshape(prob, [-1, num_kp, vh * vw])
+ prob = tf.nn.softmax(prob, name="softmax")
+
+ ranx, rany = meshgrid(vh)
+ prob = tf.reshape(prob, [-1, num_kp, vh, vw])
+
+ # These are for visualizing the distribution maps.
+ prob_viz = tf.expand_dims(tf.reduce_sum(prob, 1), 3)
+ prob_vizs = [tf.expand_dims(prob[:, i, :, :], 3) for i in range(num_kp)]
+
+ sx = tf.reduce_sum(prob * ranx, axis=[2, 3])
+ sy = tf.reduce_sum(prob * rany, axis=[2, 3]) # -> batch x num_kp
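+  # Each keypoint's 2D location is the expected coordinate under its
+  # distribution map (a spatial soft-argmax), which keeps it differentiable.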
+
+ # [batch, num_kp]
+ sill = tf.reduce_sum(prob * tf.expand_dims(mask, 1), axis=[2, 3])
+ sill = tf.reduce_mean(-tf.log(sill + 1e-12))
+
+ z = tf.reduce_sum(prob * z, axis=[2, 3])
+ uv = tf.reshape(tf.stack([sx, sy], -1), [-1, num_kp, 2])
+
+ variance = variance_loss(prob, ranx, rany, uv)
+
+ return uv, z, orient, sill, variance, prob_viz, prob_vizs
+
+
+def model_fn(features, labels, mode, hparams):
+ """Returns model_fn for tf.estimator.Estimator."""
+
+ del labels
+
+ is_training = (mode == tf.estimator.ModeKeys.TRAIN)
+ t = Transformer(vw, vh, FLAGS.dset)
+
+ def func1(x):
+ return tf.transpose(tf.reshape(features[x], [-1, 4, 4]), [0, 2, 1])
+
+ mv = [func1("mv%d" % i) for i in range(2)]
+ mvi = [func1("mvi%d" % i) for i in range(2)]
+
+ uvz = [None] * 2
+ uvz_proj = [None] * 2 # uvz coordinates projected on to the other view.
+ viz = [None] * 2
+ vizs = [None] * 2
+
+ loss_sill = 0
+ loss_variance = 0
+ loss_con = 0
+ loss_sep = 0
+ loss_lr = 0
+
+ for i in range(2):
+ with tf.variable_scope("KeypointNetwork", reuse=i > 0):
+ # anneal: 1 = using ground-truth, 0 = using our estimate orientation.
+ anneal = tf.to_float(hparams.lr_anneal_end - tf.train.get_global_step())
+ anneal = tf.clip_by_value(
+ anneal / (hparams.lr_anneal_end - hparams.lr_anneal_start), 0.0, 1.0)
+
+ uv, z, orient, sill, variance, viz[i], vizs[i] = keypoint_network(
+ features["img%d" % i],
+ hparams.num_filters,
+ hparams.num_kp,
+ is_training,
+ lr_gt=features["lr%d" % i],
+ anneal=anneal)
+
+ # x-positive/negative axes (dominant direction).
+ xp_axis = tf.tile(
+ tf.constant([[[1.0, 0, 0, 1], [-1.0, 0, 0, 1]]]),
+ [tf.shape(orient)[0], 1, 1])
+
+ # [batch, 2, 4] = [batch, 2, 4] x [batch, 4, 4]
+ xp = tf.matmul(xp_axis, mv[i])
+
+ # [batch, 2, 3]
+ xp = t.project(xp)
+
+ loss_lr += tf.losses.mean_squared_error(orient[:, :, :2], xp[:, :, :2])
+ loss_variance += variance
+ loss_sill += sill
+
+ uv = tf.reshape(uv, [-1, hparams.num_kp, 2])
+ z = tf.reshape(z, [-1, hparams.num_kp, 1])
+
+ # [batch, num_kp, 3]
+ uvz[i] = tf.concat([uv, z], axis=2)
+
+ world_coords = tf.matmul(t.unproject(uvz[i]), mvi[i])
+
+ # [batch, num_kp, 3]
+ uvz_proj[i] = t.project(tf.matmul(world_coords, mv[1 - i]))
+
+ pconf = tf.ones(
+ [tf.shape(uv)[0], tf.shape(uv)[1]], dtype=tf.float32) / hparams.num_kp
+
+ for i in range(2):
+ loss_con += consistency_loss(uvz_proj[i][:, :, :2], uvz[1 - i][:, :, :2],
+ pconf)
+ loss_sep += separation_loss(
+ t.unproject(uvz[i])[:, :, :3], hparams.sep_delta)
+
+ chordal, angular = relative_pose_loss(
+ t.unproject(uvz[0])[:, :, :3],
+ t.unproject(uvz[1])[:, :, :3], tf.matmul(mvi[0], mv[1]), pconf,
+ hparams.noise)
+
+ loss = (
+ hparams.loss_pose * angular +
+ hparams.loss_con * loss_con +
+ hparams.loss_sep * loss_sep +
+ hparams.loss_sill * loss_sill +
+ hparams.loss_lr * loss_lr +
+ hparams.loss_variance * loss_variance
+ )
+
+ def touint8(img):
+ return tf.cast(img * 255.0, tf.uint8)
+
+ with tf.variable_scope("output"):
+ tf.summary.image("0_img0", touint8(features["img0"][:, :, :, :3]))
+ tf.summary.image("1_combined", viz[0])
+ for i in range(hparams.num_kp):
+ tf.summary.image("2_f%02d" % i, vizs[0][i])
+
+ with tf.variable_scope("stats"):
+ tf.summary.scalar("anneal", anneal)
+ tf.summary.scalar("closs", loss_con)
+ tf.summary.scalar("seploss", loss_sep)
+ tf.summary.scalar("angular", angular)
+ tf.summary.scalar("chordal", chordal)
+ tf.summary.scalar("lrloss", loss_lr)
+ tf.summary.scalar("sill", loss_sill)
+ tf.summary.scalar("vloss", loss_variance)
+
+ return {
+ "loss": loss,
+ "predictions": {
+ "img0": features["img0"],
+ "img1": features["img1"],
+ "uvz0": uvz[0],
+ "uvz1": uvz[1]
+ },
+ "eval_metric_ops": {
+ "closs": tf.metrics.mean(loss_con),
+ "angular_loss": tf.metrics.mean(angular),
+ "chordal_loss": tf.metrics.mean(chordal),
+ }
+ }
+
+
+def predict(input_folder, hparams):
+ """Predicts keypoints on all images in input_folder."""
+
+ cols = plt.cm.get_cmap("rainbow")(
+ np.linspace(0, 1.0, hparams.num_kp))[:, :4]
+
+ img = tf.placeholder(tf.float32, shape=(1, 128, 128, 4))
+
+ with tf.variable_scope("KeypointNetwork"):
+ ret = keypoint_network(
+ img, hparams.num_filters, hparams.num_kp, False)
+
+ uv = tf.reshape(ret[0], [-1, hparams.num_kp, 2])
+ z = tf.reshape(ret[1], [-1, hparams.num_kp, 1])
+ uvz = tf.concat([uv, z], axis=2)
+
+ sess = tf.Session()
+ saver = tf.train.Saver()
+ ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
+
+ print("loading model: ", ckpt.model_checkpoint_path)
+ saver.restore(sess, ckpt.model_checkpoint_path)
+
+  files = [x for x in os.listdir(input_folder)
+           if os.path.splitext(x)[1].lower() in [".jpg", ".jpeg", ".png"]]
+
+ output_folder = os.path.join(input_folder, "output")
+ if not os.path.exists(output_folder):
+ os.mkdir(output_folder)
+
+ for f in files:
+ orig = misc.imread(os.path.join(input_folder, f)).astype(float) / 255
+ if orig.shape[2] == 3:
+ orig = np.concatenate((orig, np.ones_like(orig[:, :, :1])), axis=2)
+
+ uv_ret = sess.run(uvz, feed_dict={img: np.expand_dims(orig, 0)})
+
+ utils.draw_ndc_points(orig, uv_ret.reshape(hparams.num_kp, 3), cols)
+ misc.imsave(os.path.join(output_folder, f), orig)
+
+
+def _default_hparams():
+ """Returns default or overridden user-specified hyperparameters."""
+
+ hparams = tf.contrib.training.HParams(
+ num_filters=64, # Number of filters.
+      num_kp=10,  # Number of keypoints.
+
+ loss_pose=0.2, # Pose Loss.
+ loss_con=1.0, # Multiview consistency Loss.
+      loss_sep=1.0,  # Separation Loss.
+      loss_sill=1.0,  # Silhouette Loss.
+ loss_lr=1.0, # Orientation Loss.
+      loss_variance=0.5,  # Variance Loss (part of silhouette loss).
+
+      sep_delta=0.05,  # Separation threshold.
+ noise=0.1, # Noise added during estimating rotation.
+
+ learning_rate=1.0e-3,
+ lr_anneal_start=30000, # When to anneal in the orientation prediction.
+ lr_anneal_end=60000, # When to use the prediction completely.
+ )
+ if FLAGS.hparams:
+ hparams = hparams.parse(FLAGS.hparams)
+ return hparams
+
+
+def main(argv):
+ del argv
+
+ hparams = _default_hparams()
+
+ if FLAGS.predict:
+ predict(FLAGS.input, hparams)
+ else:
+ utils.train_and_eval(
+ model_dir=FLAGS.model_dir,
+ model_fn=model_fn,
+ input_fn=create_input_fn,
+ hparams=hparams,
+ steps=FLAGS.steps,
+ batch_size=FLAGS.batch_size,
+ save_checkpoints_secs=600,
+ eval_throttle_secs=1800,
+ eval_steps=5,
+ sync_replicas=FLAGS.sync_replicas,
+ )
+
+
+if __name__ == "__main__":
+ sys.excepthook = utils.colored_hook(
+ os.path.dirname(os.path.realpath(__file__)))
+ tf.app.run()
diff --git a/models/research/keypointnet/tools/gen_tfrecords.py b/models/research/keypointnet/tools/gen_tfrecords.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f973b7fe5f16951dbfa01edd2a759b96b4f79db
--- /dev/null
+++ b/models/research/keypointnet/tools/gen_tfrecords.py
@@ -0,0 +1,99 @@
+# Copyright 2018 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""An example script to generate a tfrecord file from a folder containing the
+renderings.
+
+Example usage:
+ python gen_tfrecords.py --input=FOLDER --output=output.tfrecord
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import os
+from scipy import misc
+import tensorflow as tf
+
+FLAGS = tf.app.flags.FLAGS
+tf.app.flags.DEFINE_string("input", "", "Input folder containing images")
+tf.app.flags.DEFINE_string("output", "", "Output tfrecord.")
+
+
+def get_matrix(lines):
+ return np.array([[float(y) for y in x.strip().split(" ")] for x in lines])
+
+
+def read_model_view_matrices(filename):
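+  # Each .txt produced by tools/render.py stores two 4x4 model-view matrices
+  # (one per rendered view), written one row per line.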
+ with open(filename, "r") as f:
+ lines = f.readlines()
+ return get_matrix(lines[:4]), get_matrix(lines[4:])
+
+
+def bytes_feature(values):
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
+
+
+def generate():
+ with tf.python_io.TFRecordWriter(FLAGS.output) as tfrecord_writer:
+ with tf.Graph().as_default():
+ im0 = tf.placeholder(dtype=tf.uint8)
+ im1 = tf.placeholder(dtype=tf.uint8)
+ encoded0 = tf.image.encode_png(im0)
+ encoded1 = tf.image.encode_png(im1)
+
+ with tf.Session() as sess:
+ count = 0
+ indir = FLAGS.input + "/"
+ while tf.gfile.Exists(indir + "%06d.txt" % count):
+ print("saving %06d" % count)
+ image0 = misc.imread(indir + "%06d.png" % (count * 2))
+ image1 = misc.imread(indir + "%06d.png" % (count * 2 + 1))
+
+ mat0, mat1 = read_model_view_matrices(indir + "%06d.txt" % count)
+
+ mati0 = np.linalg.inv(mat0).flatten()
+ mati1 = np.linalg.inv(mat1).flatten()
+ mat0 = mat0.flatten()
+ mat1 = mat1.flatten()
+
+ st0, st1 = sess.run([encoded0, encoded1],
+ feed_dict={im0: image0, im1: image1})
+
+ example = tf.train.Example(features=tf.train.Features(feature={
+ 'img0': bytes_feature(st0),
+ 'img1': bytes_feature(st1),
+ 'mv0': tf.train.Feature(
+ float_list=tf.train.FloatList(value=mat0)),
+ 'mvi0': tf.train.Feature(
+ float_list=tf.train.FloatList(value=mati0)),
+ 'mv1': tf.train.Feature(
+ float_list=tf.train.FloatList(value=mat1)),
+ 'mvi1': tf.train.Feature(
+ float_list=tf.train.FloatList(value=mati1)),
+ }))
+
+ tfrecord_writer.write(example.SerializeToString())
+ count += 1
+
+
+def main(argv):
+ del argv
+ generate()
+
+
+if __name__ == "__main__":
+ tf.app.run()
diff --git a/models/research/keypointnet/tools/render.py b/models/research/keypointnet/tools/render.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a8872675d83cc414d6348dbc7a56e924541b8d7
--- /dev/null
+++ b/models/research/keypointnet/tools/render.py
@@ -0,0 +1,310 @@
+# Copyright 2018 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Script to render object views from ShapeNet obj models.
+
+Example usage:
+ blender -b --python render.py -- -m model.obj -o output/ -s 128 -n 120 -fov 5
+
+"""
+from __future__ import print_function
+
+import argparse
+import itertools
+import json
+from math import pi
+import os
+import random
+import sys
+from mathutils import Vector
+import math
+import mathutils
+import time
+import copy
+
+import bpy
+
+sys.path.append(os.path.dirname(__file__))
+
+BG_LUMINANCE = 0
+
+
+def look_at(obj_camera, point):
+ loc_camera = obj_camera.location
+ direction = point - loc_camera
+  # Point the camera's '-Z' axis at the target and use its 'Y' axis as up.
+ rot_quat = direction.to_track_quat('-Z', 'Y')
+
+ obj_camera.rotation_euler = rot_quat.to_euler()
+
+
+def roll_camera(obj_camera):
+ roll_rotate = mathutils.Euler(
+ (0, 0, random.random() * math.pi - math.pi * 0.5), 'XYZ')
+ obj_camera.rotation_euler = (obj_camera.rotation_euler.to_matrix() *
+ roll_rotate.to_matrix()).to_euler()
+
+
+def norm(x):
+ return math.sqrt(x[0] * x[0] + x[1] * x[1] + x[2] * x[2])
+
+
+def normalize(x):
+ n = norm(x)
+ x[0] /= n
+ x[1] /= n
+ x[2] /= n
+
+
+def random_top_sphere():
+ xyz = [random.normalvariate(0, 1) for x in range(3)]
+ normalize(xyz)
+
+ if xyz[2] < 0:
+ xyz[2] *= -1
+ return xyz
+
+
+def perturb_sphere(loc, size):
+ while True:
+ xyz = [random.normalvariate(0, 1) for x in range(3)]
+ normalize(xyz)
+
+ nloc = [loc[i] + xyz[i] * random.random() * size for i in range(3)]
+ normalize(nloc)
+
+ if nloc[2] >= 0:
+ return nloc
+
+
+def perturb(loc, size):
+ while True:
+ nloc = [loc[i] + random.random() * size * 2 - size for i in range(3)]
+ if nloc[2] >= 0:
+ return nloc
+
+ bpy.ops.object.mode_set()
+
+
+def delete_all_objects():
+ bpy.ops.object.select_by_type(type="MESH")
+ bpy.ops.object.delete(use_global=False)
+
+
+def set_scene(render_size, fov, alpha=False):
+ """Set up default scene properties."""
+ delete_all_objects()
+
+ cam = bpy.data.cameras["Camera"]
+ cam.angle = fov * pi / 180
+
+ light = bpy.data.objects["Lamp"]
+ light.location = (0, 0, 1)
+ look_at(light, Vector((0.0, 0, 0)))
+ bpy.data.lamps['Lamp'].type = "HEMI"
+ bpy.data.lamps['Lamp'].energy = 1
+ bpy.data.lamps['Lamp'].use_specular = False
+ bpy.data.lamps['Lamp'].use_diffuse = True
+
+ bpy.context.scene.world.horizon_color = (
+ BG_LUMINANCE, BG_LUMINANCE, BG_LUMINANCE)
+
+ bpy.context.scene.render.resolution_x = render_size
+ bpy.context.scene.render.resolution_y = render_size
+ bpy.context.scene.render.resolution_percentage = 100
+
+ bpy.context.scene.render.use_antialiasing = True
+ bpy.context.scene.render.antialiasing_samples = '5'
+
+
+def get_modelview_matrix():
+ cam = bpy.data.objects["Camera"]
+ bpy.context.scene.update()
+
+  # When applied to an object in CV coordinates (i.e. to_blender * obj),
+  # this gives the object in Blender coordinates.
+ to_blender = mathutils.Matrix(
+ ((1., 0., 0., 0.),
+ (0., 0., -1., 0.),
+ (0., 1., 0., 0.),
+ (0., 0., 0., 1.)))
+ return cam.matrix_world.inverted() * to_blender
+
+
+def print_matrix(f, mat):
+ for i in range(4):
+ for j in range(4):
+ f.write("%lf " % mat[i][j])
+ f.write("\n")
+
+
+def mul(loc, v):
+ return [loc[i] * v for i in range(3)]
+
+
+def merge_all():
+ bpy.ops.object.select_by_type(type="MESH")
+ bpy.context.scene.objects.active = bpy.context.selected_objects[0]
+ bpy.ops.object.join()
+ obj = bpy.context.scene.objects.active
+ bpy.ops.object.origin_set(type="ORIGIN_CENTER_OF_MASS")
+ return obj
+
+
+def insert_frame(obj, frame_number):
+ obj.keyframe_insert(data_path="location", frame=frame_number)
+ obj.keyframe_insert(data_path="rotation_euler", frame=frame_number)
+ obj.keyframe_insert(data_path="scale", frame=frame_number)
+
+
+def render(output_prefix):
+ bpy.context.scene.render.filepath = output_prefix
+ bpy.context.scene.render.image_settings.file_format = "PNG"
+ bpy.context.scene.render.alpha_mode = "TRANSPARENT"
+ bpy.context.scene.render.image_settings.color_mode = "RGBA"
+ bpy.ops.render.render(write_still=True, animation=True)
+
+
+def render_obj(
+ obj_fn, save_dir, n, perturb_size, rotate=False, roll=False, scale=1.0):
+
+ # Load object.
+ bpy.ops.import_scene.obj(filepath=obj_fn)
+ cur_obj = merge_all()
+
+ scale = 2.0 / max(cur_obj.dimensions) * scale
+ cur_obj.scale = (scale, scale, scale)
+ # Using the center of mass as the origin doesn't really work, because Blender
+ # assumes the object is a solid shell. This seems to generate better-looking
+ # rotations.
+
+ bpy.ops.object.origin_set(type='ORIGIN_GEOMETRY', center='BOUNDS')
+
+ # bpy.ops.mesh.primitive_cube_add(location=(0, 0, 1))
+ # cube = bpy.data.objects["Cube"]
+ # cube.scale = (0.2, 0.2, 0.2)
+
+ for polygon in cur_obj.data.polygons:
+ polygon.use_smooth = True
+
+ bpy.ops.object.select_all(action="DESELECT")
+
+ camera = bpy.data.objects["Camera"]
+
+ # os.system("mkdir " + save_dir)
+ for i in range(n):
+ fo = open(save_dir + "/%06d.txt" % i, "w")
+ d = 30
+ shift = 0.2
+ if rotate:
+ t = 1.0 * i / (n-1) * 2 * math.pi
+ loc = [math.sin(t), math.cos(t), 1]
+
+ normalize(loc)
+ camera.location = mul(loc, d)
+ look_at(camera, Vector((0.0, 0, 0)))
+
+ print_matrix(fo, get_modelview_matrix())
+ print_matrix(fo, get_modelview_matrix())
+
+ insert_frame(camera, 2 * i)
+ insert_frame(camera, 2 * i + 1)
+
+ else:
+ loc = random_top_sphere()
+
+ camera.location = mul(loc, d)
+ look_at(camera, Vector((0.0, 0, 0)))
+
+ if roll:
+ roll_camera(camera)
+ camera.location = perturb(mul(loc, d), shift)
+
+ print_matrix(fo, get_modelview_matrix())
+ insert_frame(camera, 2 * i)
+
+ if perturb_size > 0:
+ loc = perturb_sphere(loc, perturb_size)
+ else:
+ loc = random_top_sphere()
+
+ camera.location = mul(loc, d)
+ look_at(camera, Vector((0.0, 0, 0)))
+ if roll:
+ roll_camera(camera)
+ camera.location = perturb(mul(loc, d), shift)
+
+ print_matrix(fo, get_modelview_matrix())
+ insert_frame(camera, 2 * i + 1)
+
+ fo.close()
+
+ # Create a bunch of views of the object
+ bpy.context.scene.frame_start = 0
+ bpy.context.scene.frame_end = 2 * n - 1
+
+ stem = os.path.join(save_dir, '######')
+ render(stem)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-m', '--model', dest='model',
+ required=True,
+ help='Path to model obj file.')
+ parser.add_argument('-o', '--output_dir', dest='output_dir',
+ required=True,
+ help='Where to output files.')
+ parser.add_argument('-s', '--output_size', dest='output_size',
+ required=True,
+                      help='Width and height of output in pixels, e.g. 128.')
+ parser.add_argument('-n', '--num_frames', dest='n', type=int,
+ required=True,
+ help='Number of frames to generate per clip.')
+
+ parser.add_argument('-scale', '--scale', dest='scale', type=float,
+ help='object scaling', default=1)
+
+ parser.add_argument('-perturb', '--perturb', dest='perturb', type=float,
+ help='sphere perturbation', default=0)
+
+ parser.add_argument('-rotate', '--rotate', dest='rotate', action='store_true',
+ help='render rotating test set')
+
+ parser.add_argument('-roll', '--roll', dest='roll', action='store_true',
+ help='add roll')
+
+ parser.add_argument(
+ '-fov', '--fov', dest='fov', type=float, required=True,
+ help='field of view')
+
+ if '--' not in sys.argv:
+ parser.print_help()
+ exit(1)
+
+ argv = sys.argv[sys.argv.index('--') + 1:]
+ args, _ = parser.parse_known_args(argv)
+
+ random.seed(args.model + str(time.time()) + str(os.getpid()))
+ # random.seed(0)
+
+ set_scene(int(args.output_size), args.fov)
+ render_obj(
+ args.model, args.output_dir, args.n, args.perturb, args.rotate,
+ args.roll, args.scale)
+ exit()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/models/research/keypointnet/utils.py b/models/research/keypointnet/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..148b7a3ed843638cff597be0c462b7e335df9857
--- /dev/null
+++ b/models/research/keypointnet/utils.py
@@ -0,0 +1,307 @@
+# Copyright 2018 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Utility functions for KeypointNet.
+
+These are helper / tensorflow related functions. The actual implementation and
+algorithm is in main.py.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numpy as np
+import os
+import re
+import tensorflow as tf
+import tensorflow.contrib.slim as slim
+import time
+import traceback
+
+
+class TrainingHook(tf.train.SessionRunHook):
+ """A utility for displaying training information such as the loss, percent
+ completed, estimated finish date and time."""
+
+ def __init__(self, steps):
+ self.steps = steps
+
+ self.last_time = time.time()
+ self.last_est = self.last_time
+
+ self.eta_interval = int(math.ceil(0.1 * self.steps))
+ self.current_interval = 0
+
+ def before_run(self, run_context):
+ graph = tf.get_default_graph()
+ return tf.train.SessionRunArgs(
+ {"loss": graph.get_collection("total_loss")[0]})
+
+ def after_run(self, run_context, run_values):
+ step = run_context.session.run(tf.train.get_global_step())
+ now = time.time()
+
+ if self.current_interval < self.eta_interval:
+ self.duration = now - self.last_est
+ self.current_interval += 1
+ if step % self.eta_interval == 0:
+ self.duration = now - self.last_est
+ self.last_est = now
+
+ eta_time = float(self.steps - step) / self.current_interval * \
+ self.duration
+ m, s = divmod(eta_time, 60)
+ h, m = divmod(m, 60)
+ eta = "%d:%02d:%02d" % (h, m, s)
+
+ print("%.2f%% (%d/%d): %.3e t %.3f @ %s (%s)" % (
+ step * 100.0 / self.steps,
+ step,
+ self.steps,
+ run_values.results["loss"],
+ now - self.last_time,
+ time.strftime("%a %d %H:%M:%S", time.localtime(time.time() + eta_time)),
+ eta))
+
+ self.last_time = now
+
+
+def standard_model_fn(
+ func, steps, run_config=None, sync_replicas=0, optimizer_fn=None):
+ """Creates model_fn for tf.Estimator.
+
+ Args:
+ func: A model_fn with prototype model_fn(features, labels, mode, hparams).
+ steps: Training steps.
+    run_config: tf.estimator.RunConfig (usually passed in from TF_CONFIG).
+ sync_replicas: The number of replicas used to compute gradient for
+ synchronous training.
+ optimizer_fn: The type of the optimizer. Default to Adam.
+
+ Returns:
+ model_fn for tf.estimator.Estimator.
+ """
+
+ def fn(features, labels, mode, params):
+ """Returns model_fn for tf.estimator.Estimator."""
+
+ is_training = (mode == tf.estimator.ModeKeys.TRAIN)
+ ret = func(features, labels, mode, params)
+
+ tf.add_to_collection("total_loss", ret["loss"])
+ train_op = None
+
+ training_hooks = []
+ if is_training:
+ training_hooks.append(TrainingHook(steps))
+
+ if optimizer_fn is None:
+ optimizer = tf.train.AdamOptimizer(params.learning_rate)
+ else:
+ optimizer = optimizer_fn
+
+ if run_config is not None and run_config.num_worker_replicas > 1:
+ sr = sync_replicas
+ if sr <= 0:
+ sr = run_config.num_worker_replicas
+
+ optimizer = tf.train.SyncReplicasOptimizer(
+ optimizer,
+ replicas_to_aggregate=sr,
+ total_num_replicas=run_config.num_worker_replicas)
+
+ training_hooks.append(
+ optimizer.make_session_run_hook(
+ run_config.is_chief, num_tokens=run_config.num_worker_replicas))
+
+ optimizer = tf.contrib.estimator.clip_gradients_by_norm(optimizer, 5)
+ train_op = slim.learning.create_train_op(ret["loss"], optimizer)
+
+ if "eval_metric_ops" not in ret:
+ ret["eval_metric_ops"] = {}
+
+ return tf.estimator.EstimatorSpec(
+ mode=mode,
+ predictions=ret["predictions"],
+ loss=ret["loss"],
+ train_op=train_op,
+ eval_metric_ops=ret["eval_metric_ops"],
+ training_hooks=training_hooks)
+ return fn
+
+
+def train_and_eval(
+ model_dir,
+ steps,
+ batch_size,
+ model_fn,
+ input_fn,
+ hparams,
+ keep_checkpoint_every_n_hours=0.5,
+ save_checkpoints_secs=180,
+ save_summary_steps=50,
+ eval_steps=20,
+ eval_start_delay_secs=10,
+ eval_throttle_secs=300,
+ sync_replicas=0):
+ """Trains and evaluates our model. Supports local and distributed training.
+
+ Args:
+ model_dir: The output directory for trained parameters, checkpoints, etc.
+ steps: Training steps.
+ batch_size: Batch size.
+ model_fn: A func with prototype model_fn(features, labels, mode, hparams).
+ input_fn: A input function for the tf.estimator.Estimator.
+ hparams: tf.HParams containing a set of hyperparameters.
+ keep_checkpoint_every_n_hours: Number of hours between each checkpoint
+ to be saved.
+ save_checkpoints_secs: Save checkpoints every this many seconds.
+ save_summary_steps: Save summaries every this many steps.
+ eval_steps: Number of steps to evaluate model.
+ eval_start_delay_secs: Start evaluating after waiting for this many seconds.
+ eval_throttle_secs: Do not re-evaluate unless the last evaluation was
+      started at least this many seconds ago.
+ sync_replicas: Number of synchronous replicas for distributed training.
+
+ Returns:
+ None
+ """
+
+ run_config = tf.estimator.RunConfig(
+ keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
+ save_checkpoints_secs=save_checkpoints_secs,
+ save_summary_steps=save_summary_steps)
+
+ estimator = tf.estimator.Estimator(
+ model_dir=model_dir,
+ model_fn=standard_model_fn(
+ model_fn,
+ steps,
+ run_config,
+ sync_replicas=sync_replicas),
+ params=hparams, config=run_config)
+
+ train_spec = tf.estimator.TrainSpec(
+ input_fn=input_fn(split="train", batch_size=batch_size),
+ max_steps=steps)
+
+ eval_spec = tf.estimator.EvalSpec(
+ input_fn=input_fn(split="validation", batch_size=batch_size),
+ steps=eval_steps,
+ start_delay_secs=eval_start_delay_secs,
+ throttle_secs=eval_throttle_secs)
+
+ tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
+
+
+def draw_circle(rgb, u, v, col, r):
+ """Draws a simple anti-aliasing circle in-place.
+
+ Args:
+ rgb: Input image to be modified.
+ u: Horizontal coordinate.
+ v: Vertical coordinate.
+ col: Color.
+ r: Radius.
+ """
+
+ ir = int(math.ceil(r))
+ for i in range(-ir-1, ir+2):
+ for j in range(-ir-1, ir+2):
+ nu = int(round(u + i))
+ nv = int(round(v + j))
+ if nu < 0 or nu >= rgb.shape[1] or nv < 0 or nv >= rgb.shape[0]:
+ continue
+
+ du = abs(nu - u)
+ dv = abs(nv - v)
+
+ # need sqrt to keep scale
+ t = math.sqrt(du * du + dv * dv) - math.sqrt(r * r)
+ if t < 0:
+ rgb[nv, nu, :] = col
+ else:
+ t = 1 - t
+ if t > 0:
+ # t = t ** 0.3
+ rgb[nv, nu, :] = col * t + rgb[nv, nu, :] * (1-t)
+
+
+def draw_ndc_points(rgb, xy, cols):
+ """Draws keypoints onto an input image.
+
+ Args:
+ rgb: Input image to be modified.
+ xy: [n x 2] matrix of 2D locations.
+ cols: A list of colors for the keypoints.
+ """
+
+ vh, vw = rgb.shape[0], rgb.shape[1]
+
+ for j in range(len(cols)):
+ x, y = xy[j, :2]
+ x = (min(max(x, -1), 1) * vw / 2 + vw / 2) - 0.5
+ y = vh - 0.5 - (min(max(y, -1), 1) * vh / 2 + vh / 2)
+
+ x = int(round(x))
+ y = int(round(y))
+ if x < 0 or y < 0 or x >= vw or y >= vh:
+ continue
+
+ rad = 1.5
+ rad *= rgb.shape[0] / 128.0
+ draw_circle(rgb, x, y, np.array([0.0, 0.0, 0.0, 1.0]), rad * 1.5)
+ draw_circle(rgb, x, y, cols[j], rad)
+
+
+def colored_hook(home_dir):
+ """Colorizes python's error message.
+
+ Args:
+ home_dir: directory where code resides (to highlight your own files).
+ Returns:
+ The traceback hook.
+ """
+
+ def hook(type_, value, tb):
+ def colorize(text, color, own=0):
+ """Returns colorized text."""
+ endcolor = "\x1b[0m"
+ codes = {
+ "green": "\x1b[0;32m",
+ "green_own": "\x1b[1;32;40m",
+ "red": "\x1b[0;31m",
+ "red_own": "\x1b[1;31m",
+ "yellow": "\x1b[0;33m",
+ "yellow_own": "\x1b[1;33m",
+ "black": "\x1b[0;90m",
+ "black_own": "\x1b[1;90m",
+ "cyan": "\033[1;36m",
+ }
+ return codes[color + ("_own" if own else "")] + text + endcolor
+
+ for filename, line_num, func, text in traceback.extract_tb(tb):
+ basename = os.path.basename(filename)
+ own = (home_dir in filename) or ("/" not in filename)
+
+ print(colorize("\"" + basename + '"', "green", own) + " in " + func)
+ print("%s: %s" % (
+ colorize("%5d" % line_num, "red", own),
+ colorize(text, "yellow", own)))
+ print(" %s" % colorize(filename, "black", own))
+
+ print(colorize("%s: %s" % (type_.__name__, value), "cyan"))
+ return hook
diff --git a/models/research/learned_optimizer/.gitignore b/models/research/learned_optimizer/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/learned_optimizer/BUILD b/models/research/learned_optimizer/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..629c9a06b51d10eb7cab69ed0d9dd0bfa52fd2f0
--- /dev/null
+++ b/models/research/learned_optimizer/BUILD
@@ -0,0 +1,33 @@
+# Learning to Optimize Learning (LOL)
+
+package(default_visibility = ["//visibility:public"])
+
+# Libraries
+# =========
+
+py_library(
+ name = "metaopt",
+ srcs = ["metaopt.py"],
+ deps = [
+ "//learned_optimizer/problems:datasets",
+ "//learned_optimizer/problems:problem_generator",
+ ],
+)
+
+# Binaries
+# ========
+py_binary(
+ name = "metarun",
+ srcs = ["metarun.py"],
+ deps = [
+ ":metaopt",
+ "//learned_optimizer/optimizer:coordinatewise_rnn",
+ "//learned_optimizer/optimizer:global_learning_rate",
+ "//learned_optimizer/optimizer:hierarchical_rnn",
+ "//learned_optimizer/optimizer:learning_rate_schedule",
+ "//learned_optimizer/optimizer:trainable_adam",
+ "//learned_optimizer/problems:problem_sets",
+ "//learned_optimizer/problems:problem_spec",
+ ],
+)
+
diff --git a/models/research/learned_optimizer/README.md b/models/research/learned_optimizer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6a32514f053f97bc64dc87c4ec972c8223a83fe2
--- /dev/null
+++ b/models/research/learned_optimizer/README.md
@@ -0,0 +1,47 @@
+
+
+
+
+# Learned Optimizer
+
+Code for [Learned Optimizers that Scale and Generalize](https://arxiv.org/abs/1703.04813).
+
+## Requirements
+
+* Bazel ([install](https://bazel.build/versions/master/docs/install.html))
+* TensorFlow >= v1.3
+* Python 2.7.x
+
+## Training a Learned Optimizer
+
+## Code Overview
+In the top-level directory, ```metaopt.py``` contains the code to train and test a learned optimizer. ```metarun.py``` packages the actual training procedure into a
+single file, defining and exposing many flags to tune the procedure, from selecting the optimizer type and problem set to more fine-grained hyperparameter settings.
+There is no testing binary; testing can be done ad-hoc via ```metaopt.test_optimizer``` by passing an optimizer object and a directory with a checkpoint.
+
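+A minimal sketch of such an ad-hoc test is shown below. Only ```metaopt.test_optimizer``` and the problem-set helpers used by ```metarun.py``` come from this repository; the specific problem choice, the plain gradient-descent optimizer, and the import paths are illustrative assumptions.
+
+```python
+# Illustrative sketch: evaluate an off-the-shelf optimizer on one training problem.
+import tensorflow as tf
+
+from learned_optimizer import metaopt
+from learned_optimizer.problems import problem_sets as ps
+
+# Each problem-set entry is a (problem_spec, dataset, batch_size) tuple.
+problem_spec, dataset, batch_size = ps.quadratic_problems()[0]
+problem = problem_spec.build()
+
+# Any tf.train.Optimizer works here; to test a learned optimizer instead,
+# pass its instance and point `logdir` at a checkpoint directory.
+optimizer = tf.train.GradientDescentOptimizer(0.01)
+
+objective_values, parameters, records = metaopt.test_optimizer(
+    optimizer,
+    problem,
+    num_iter=100,
+    dataset=dataset,
+    batch_size=batch_size,
+    record_every=10)
+```
+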
+The ```optimizer``` directory contains a base ```trainable_optimizer.py``` class and a number of extensions, including the ```hierarchical_rnn``` optimizer used in
+the paper, a ```coordinatewise_rnn``` optimizer that more closely matches previous work, and a number of simpler optimizers to demonstrate the basic mechanics of
+a learnable optimizer.
+
+The ```problems``` directory contains the code to build the problems that were used in the meta-training set.
+
+### Binaries
+```metarun.py```: meta-training of a learned optimizer
+
+### Command-Line Flags
+The flags most relevant to meta-training are defined in ```metarun.py```. The default values will meta-train a HierarchicalRNN optimizer with the hyperparameter
+settings used in the paper.
+
+### Using a Learned Optimizer as a Black Box
+The ```trainable_optimizer``` inherits from ```tf.train.Optimizer```, so a properly instantiated version can be used to train any model with any API that accepts
+this class. There are two caveats (a sketch illustrating both follows the list):
+
+1. If using the Hierarchical RNN optimizer, the apply_gradients return type must be changed (see comments inline for what exactly must be removed).
+
+2. Care must be taken to restore the variables from the optimizer without overriding them. Optimizer variables should be loaded manually using a pretrained checkpoint
+and a ```tf.train.Saver``` with only the optimizer variables. Then, when constructing the session, ensure that any automatic variable initialization does not
+re-initialize the loaded optimizer variables.
+
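+The sketch below illustrates both caveats. The toy loss, the ```HierarchicalRNN``` constructor arguments, and the checkpoint directory are assumptions for illustration; only the "LOL" variable scope, the tuple-returning ```apply_gradients```, and the restore/initialize pattern follow the code in ```metaopt.py```.
+
+```python
+# Illustrative sketch of black-box use of a pretrained learned optimizer.
+import tensorflow as tf
+
+from learned_optimizer.optimizer import hierarchical_rnn
+
+# A toy quadratic loss standing in for a real model (assumption).
+w = tf.get_variable("w", shape=[10], initializer=tf.random_normal_initializer())
+loss = tf.reduce_sum(tf.square(w))
+grads = tf.gradients(loss, [w])
+
+# Constructor arguments follow metarun.py's HRNN_CELL_SIZES; remaining
+# options are assumed to keep their defaults.
+optimizer = hierarchical_rnn.HierarchicalRNN([10, 20, 20])
+
+with tf.Session() as sess:
+  # Model variables must be initialized before apply_gradients (see metaopt.py).
+  sess.run(tf.variables_initializer([w]))
+
+  train_op = optimizer.apply_gradients(zip(grads, [w]))
+  if isinstance(train_op, (tuple, list)):
+    # Caveat 1: the Hierarchical RNN apply_gradients returns a tuple.
+    train_op = train_op[0]
+
+  # Caveat 2: restore only the optimizer's variables ("LOL" scope) from a
+  # pretrained checkpoint, then initialize everything else so the restored
+  # values are not overwritten.
+  opt_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="LOL")
+  other_vars = [v for v in tf.global_variables()
+                if v not in opt_vars and v is not w]
+  tf.train.Saver(var_list=opt_vars).restore(
+      sess, tf.train.latest_checkpoint("/tmp/lol/"))
+  sess.run(tf.variables_initializer(other_vars))
+
+  for _ in range(100):
+    sess.run(train_op)
+```
+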
+## Contact for Issues
+
+* Olga Wichrowska (@olganw), Niru Maheswaranathan (@nirum)
diff --git a/models/research/learned_optimizer/metaopt.py b/models/research/learned_optimizer/metaopt.py
new file mode 100644
index 0000000000000000000000000000000000000000..62c06272d3096ed63296744792c8742826380536
--- /dev/null
+++ b/models/research/learned_optimizer/metaopt.py
@@ -0,0 +1,639 @@
+# Copyright 2017 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Helper utilities for training and testing optimizers."""
+
+from collections import defaultdict
+import random
+import sys
+import time
+
+import numpy as np
+from six.moves import xrange
+import tensorflow as tf
+
+from learned_optimizer.optimizer import trainable_optimizer
+from learned_optimizer.optimizer import utils
+from learned_optimizer.problems import datasets
+from learned_optimizer.problems import problem_generator
+
+tf.app.flags.DEFINE_integer("ps_tasks", 0,
+ """Number of tasks in the ps job.
+ If 0 no ps job is used.""")
+tf.app.flags.DEFINE_float("nan_l2_reg", 1e-2,
+ """Strength of l2-reg when NaNs are encountered.""")
+tf.app.flags.DEFINE_float("l2_reg", 0.,
+ """Lambda value for parameter regularization.""")
+# Default is 0.9
+tf.app.flags.DEFINE_float("rms_decay", 0.9,
+ """Decay value for the RMSProp metaoptimizer.""")
+# Default is 1e-10
+tf.app.flags.DEFINE_float("rms_epsilon", 1e-20,
+ """Epsilon value for the RMSProp metaoptimizer.""")
+tf.app.flags.DEFINE_boolean("set_profiling", False,
+ """Enable memory usage and computation time """
+ """tracing for tensorflow nodes (available in """
+ """TensorBoard).""")
+tf.app.flags.DEFINE_boolean("reset_rnn_params", True,
+ """Reset the parameters of the optimizer
+ from one meta-iteration to the next.""")
+
+FLAGS = tf.app.flags.FLAGS
+OPTIMIZER_SCOPE = "LOL"
+OPT_SUM_COLLECTION = "LOL_summaries"
+
+
+def sigmoid_weights(n, slope=0.1, offset=5):
+ """Generates a sigmoid, scaled to sum to 1.
+
+ This function is used to generate weights that serve to mask out
+ the early objective values of an optimization problem such that
+ initial variation in the objective is phased out (hence the sigmoid
+ starts at zero and ramps up to the maximum value, and the total
+ weight is normalized to sum to one)
+
+ Args:
+ n: the number of samples
+ slope: slope of the sigmoid (Default: 0.1)
+ offset: threshold of the sigmoid (Default: 5)
+
+ Returns:
+    y_normalized: A numpy array of length n containing the sigmoid weights,
+      normalized to sum to one.
+ """
+ x = np.arange(n)
+ y = 1. / (1. + np.exp(-slope * (x-offset)))
+ y_normalized = y / np.sum(y)
+ return y_normalized
+
+
+def sample_numiter(scale, min_steps=50):
+ """Samples a number of iterations from an exponential distribution.
+
+ Args:
+ scale: parameter for the exponential distribution
+ min_steps: minimum number of steps to run (additive)
+
+ Returns:
+ num_steps: An integer equal to a rounded sample from the exponential
+ distribution + the value of min_steps.
+ """
+ return int(np.round(np.random.exponential(scale=scale)) + min_steps)
+
+
+def train_optimizer(logdir,
+ optimizer_spec,
+ problems_and_data,
+ num_problems,
+ num_meta_iterations,
+ num_unroll_func,
+ num_partial_unroll_itrs_func,
+ learning_rate=1e-4,
+ gradient_clip=5.,
+ is_chief=False,
+ select_random_problems=True,
+ callbacks=None,
+ obj_train_max_multiplier=-1,
+ out=sys.stdout):
+ """Trains the meta-parameters of this optimizer.
+
+ Args:
+ logdir: a directory filepath for storing model checkpoints (must exist)
+ optimizer_spec: specification for an Optimizer (see utils.Spec)
+ problems_and_data: a list of tuples containing three elements: a problem
+ specification (see utils.Spec), a dataset (see datasets.Dataset), and
+ a batch_size (int) for generating a problem and corresponding dataset. If
+ the problem doesn't have data, set dataset to None.
+ num_problems: the number of problems to sample during meta-training
+ num_meta_iterations: the number of iterations (steps) to run the
+ meta-optimizer for on each subproblem.
+ num_unroll_func: called once per meta iteration and returns the number of
+ unrolls to do for that meta iteration.
+ num_partial_unroll_itrs_func: called once per unroll and returns the number
+ of iterations to do for that unroll.
+ learning_rate: learning rate of the RMSProp meta-optimizer (Default: 1e-4)
+ gradient_clip: value to clip gradients at (Default: 5.0)
+ is_chief: whether this is the chief task (Default: False)
+ select_random_problems: whether to select training problems randomly
+ (Default: True)
+ callbacks: a list of callback functions that is run after every random
+ problem draw
+ obj_train_max_multiplier: the maximum increase in the objective value over
+ a single training run. Ignored if < 0.
+ out: where to write output to, e.g. a file handle (Default: sys.stdout)
+
+ Raises:
+ ValueError: If one of the subproblems has a negative objective value.
+ """
+
+ if select_random_problems:
+ # iterate over random draws of problem / dataset pairs
+ sampler = (random.choice(problems_and_data) for _ in range(num_problems))
+ else:
+ # iterate over a random shuffle of problems, looping if necessary
+    num_repeats = (num_problems // len(problems_and_data)) + 1
+ random.shuffle(problems_and_data)
+ sampler = (problems_and_data * num_repeats)[:num_problems]
+
+ for problem_itr, (problem_spec, dataset, batch_size) in enumerate(sampler):
+
+ # timer used to time how long it takes to initialize a problem
+ problem_start_time = time.time()
+
+ # if dataset is None, use the EMPTY_DATASET
+ if dataset is None:
+ dataset = datasets.EMPTY_DATASET
+ batch_size = dataset.size
+
+ # build a new graph for this problem
+ graph = tf.Graph()
+ real_device_setter = tf.train.replica_device_setter(FLAGS.ps_tasks)
+
+ def custom_device_setter(op):
+ # Places the local variables onto the workers.
+ if trainable_optimizer.is_local_state_variable(op):
+ return "/job:worker"
+ else:
+ return real_device_setter(op)
+
+ if real_device_setter:
+ device_setter = custom_device_setter
+ else:
+ device_setter = None
+
+ with graph.as_default(), graph.device(device_setter):
+
+ # initialize a problem
+ problem = problem_spec.build()
+
+ # build the optimizer
+ opt = optimizer_spec.build()
+
+ # get the meta-objective for training the optimizer
+ train_output = opt.train(problem, dataset)
+
+ state_keys = opt.state_keys
+ for key, val in zip(state_keys, train_output.output_state[0]):
+ finite_val = utils.make_finite(val, replacement=tf.zeros_like(val))
+ tf.summary.histogram("State/{}".format(key), finite_val,
+ collections=[OPT_SUM_COLLECTION])
+
+ tf.summary.scalar("MetaObjective", train_output.metaobj,
+ collections=[OPT_SUM_COLLECTION])
+
+ # Per-problem meta-objective
+ tf.summary.scalar(problem_spec.callable.__name__ + "_MetaObjective",
+ train_output.metaobj,
+ collections=[OPT_SUM_COLLECTION])
+
+ # create the meta-train_op
+ global_step = tf.Variable(0, name="global_step", trainable=False)
+ meta_parameters = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
+ scope=OPTIMIZER_SCOPE)
+ # parameter regularization
+ reg_l2 = FLAGS.l2_reg * sum([tf.reduce_sum(param ** 2)
+ for param in meta_parameters])
+
+ # compute the meta-gradients
+ meta_opt = tf.train.RMSPropOptimizer(learning_rate, decay=FLAGS.rms_decay,
+ use_locking=True,
+ epsilon=FLAGS.rms_epsilon)
+ grads_and_vars = meta_opt.compute_gradients(train_output.metaobj + reg_l2,
+ meta_parameters)
+
+ # clip the gradients
+ clipped_grads_and_vars = []
+ for grad, var in grads_and_vars:
+ clipped_grad = tf.clip_by_value(
+ utils.make_finite(grad, replacement=tf.zeros_like(var)),
+ -gradient_clip, gradient_clip)
+ clipped_grads_and_vars.append((clipped_grad, var))
+
+ # histogram summary of grads and vars
+ for grad, var in grads_and_vars:
+ tf.summary.histogram(
+ var.name + "_rawgrad",
+ utils.make_finite(
+ grad, replacement=tf.zeros_like(grad)),
+ collections=[OPT_SUM_COLLECTION])
+ for grad, var in clipped_grads_and_vars:
+ tf.summary.histogram(var.name + "_var", var,
+ collections=[OPT_SUM_COLLECTION])
+ tf.summary.histogram(var.name + "_grad", grad,
+ collections=[OPT_SUM_COLLECTION])
+
+ # builds the train and summary operations
+ train_op = meta_opt.apply_gradients(clipped_grads_and_vars,
+ global_step=global_step)
+
+ # only grab summaries defined for LOL, not inside the problem
+ summary_op = tf.summary.merge_all(key=OPT_SUM_COLLECTION)
+
+ # make sure the state gets propagated after the gradients and summaries
+ # were computed.
+ with tf.control_dependencies([train_op, summary_op]):
+ propagate_loop_state_ops = []
+ for dest, src in zip(
+ train_output.init_loop_vars, train_output.output_loop_vars):
+ propagate_loop_state_ops.append(dest.assign(src))
+ propagate_loop_state_op = tf.group(*propagate_loop_state_ops)
+
+ # create the supervisor
+ sv = tf.train.Supervisor(
+ graph=graph,
+ is_chief=is_chief,
+ logdir=logdir,
+ summary_op=None,
+ save_model_secs=0, # we save checkpoints manually
+ global_step=global_step,
+ )
+
+ with sv.managed_session() as sess:
+
+ init_time = time.time() - problem_start_time
+ out.write("--------- Problem #{} ---------\n".format(problem_itr))
+ out.write("{callable.__name__}{args}{kwargs}\n".format(
+ **problem_spec.__dict__))
+ out.write("Took {} seconds to initialize.\n".format(init_time))
+ out.flush()
+
+ # For profiling summaries
+ if FLAGS.set_profiling:
+ summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)
+
+ # used to store information during training
+ metadata = defaultdict(list)
+
+ for k in range(num_meta_iterations):
+
+ if sv.should_stop():
+ break
+
+ problem.init_fn(sess)
+
+ # set run options (for profiling)
+ full_trace_opt = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
+ run_options = full_trace_opt if FLAGS.set_profiling else None
+ run_metadata = tf.RunMetadata() if FLAGS.set_profiling else None
+
+ num_unrolls = num_unroll_func()
+ partial_unroll_iters = [
+ num_partial_unroll_itrs_func() for _ in xrange(num_unrolls)
+ ]
+ total_num_iter = sum(partial_unroll_iters)
+
+ objective_weights = [np.ones(num) / float(num)
+ for num in partial_unroll_iters]
+ db = dataset.batch_indices(total_num_iter, batch_size)
+ dataset_batches = []
+ last_index = 0
+ for num in partial_unroll_iters:
+ dataset_batches.append(db[last_index:last_index + num])
+ last_index += num
+
+ train_start_time = time.time()
+
+ unroll_itr = 0
+ additional_log_info = ""
+
+ for unroll_itr in range(num_unrolls):
+ first_unroll = unroll_itr == 0
+ if FLAGS.reset_rnn_params:
+ reset_state = first_unroll and k == 0
+ else:
+ reset_state = first_unroll
+
+ feed = {
+ train_output.obj_weights: objective_weights[unroll_itr],
+ train_output.batches: dataset_batches[unroll_itr],
+ train_output.first_unroll: first_unroll,
+ train_output.reset_state: reset_state,
+ }
+
+ # run the train and summary ops
+ # when a "save_diagnostics" flag is turned on
+ fetches_list = [
+ train_output.metaobj,
+ train_output.problem_objectives,
+ train_output.initial_obj,
+ summary_op,
+ clipped_grads_and_vars,
+ train_op
+ ]
+ if unroll_itr + 1 < num_unrolls:
+ fetches_list += [propagate_loop_state_op]
+
+ fetched = sess.run(fetches_list, feed_dict=feed,
+ options=run_options, run_metadata=run_metadata)
+ meta_obj = fetched[0]
+ sub_obj = fetched[1]
+ init_obj = fetched[2]
+ summ = fetched[3]
+ meta_grads_and_params = fetched[4]
+
+ # assert that the subproblem objectives are non-negative
+ # (this is so that we can rescale the objective by the initial value
+ # and not worry about rescaling by a negative value)
+ if np.any(sub_obj < 0):
+ raise ValueError(
+ "Training problem objectives must be nonnegative.")
+ # If the objective has increased more than we want, exit this
+ # training run and start over on another meta iteration.
+ if obj_train_max_multiplier > 0 and (
+ sub_obj[-1] > (init_obj +
+ abs(init_obj) * (obj_train_max_multiplier - 1))):
+ msg = " Broke early at {} out of {} unrolls. ".format(
+ unroll_itr + 1, num_unrolls)
+ additional_log_info += msg
+ break
+
+ # only the chief task is allowed to write the summary
+ if is_chief:
+ sv.summary_computed(sess, summ)
+
+ metadata["subproblem_objs"].append(sub_obj)
+ # store training metadata to pass to the callback
+ metadata["meta_objs"].append(meta_obj)
+ metadata["meta_grads_and_params"].append(meta_grads_and_params)
+
+ optimization_time = time.time() - train_start_time
+
+ if FLAGS.set_profiling:
+ summary_name = "%02d_iter%04d_%02d" % (FLAGS.task, problem_itr, k)
+ summary_writer.add_run_metadata(run_metadata, summary_name)
+
+ metadata["global_step"].append(sess.run(global_step))
+ metadata["runtimes"].append(optimization_time)
+
+ # write a diagnostic message to the output
+ args = (k, meta_obj, optimization_time,
+ sum(partial_unroll_iters[:unroll_itr+1]))
+ out.write(" [{:02}] {}, {} seconds, {} iters ".format(*args))
+ out.write("(unrolled {} steps)".format(
+ ", ".join([str(s) for s in partial_unroll_iters[:unroll_itr+1]])))
+ out.write("{}\n".format(additional_log_info))
+ out.flush()
+
+ if FLAGS.set_profiling:
+ summary_writer.close()
+
+ # force a checkpoint save before we load a new problem
+ # only the chief task has the save_path and can write the checkpoint
+ if is_chief:
+ sv.saver.save(sess, sv.save_path, global_step=global_step)
+
+ # run the callbacks on the chief
+ if is_chief and callbacks is not None:
+ for callback in callbacks:
+ if hasattr(callback, "__call__"):
+ problem_name = problem_spec.callable.__name__
+ callback(problem_name, problem_itr, logdir, metadata)
+
+
+def test_optimizer(optimizer,
+ problem,
+ num_iter,
+ dataset=datasets.EMPTY_DATASET,
+ batch_size=None,
+ seed=None,
+ graph=None,
+ logdir=None,
+ record_every=None):
+ """Tests an optimization algorithm on a given problem.
+
+ Args:
+ optimizer: Either a tf.train.Optimizer instance, or an Optimizer instance
+ inheriting from trainable_optimizer.py
+ problem: A Problem instance that defines an optimization problem to solve
+ num_iter: The number of iterations of the optimizer to run
+ dataset: The dataset to train the problem against
+ batch_size: The number of samples per batch. If None (default), the
+ batch size is set to the full batch (dataset.size)
+ seed: A random seed used for drawing the initial parameters, or a list of
+ numpy arrays used to explicitly initialize the parameters.
+ graph: The tensorflow graph to execute (if None, uses the default graph)
+ logdir: A directory containing model checkpoints. If given, then the
+ parameters of the optimizer are loaded from the latest checkpoint
+ in this folder.
+ record_every: if an integer, stores the parameters, objective, and gradient
+      every record_every iterations. If None, nothing is stored.
+
+ Returns:
+ objective_values: A list of the objective values during optimization
+ parameters: The parameters obtained after training
+ records: A dictionary containing lists of the parameters and gradients
+ during optimization saved every record_every iterations (empty if
+ record_every is set to None)
+ """
+
+ if dataset is None:
+ dataset = datasets.EMPTY_DATASET
+ batch_size = dataset.size
+ else:
+ # default batch size is the entire dataset
+ batch_size = dataset.size if batch_size is None else batch_size
+
+ graph = tf.get_default_graph() if graph is None else graph
+ with graph.as_default():
+
+ # define the parameters of the optimization problem
+ if isinstance(seed, (list, tuple)):
+ # seed is a list of arrays
+ params = problem_generator.init_fixed_variables(seed)
+ else:
+ # seed is an int or None
+ params = problem.init_variables(seed)
+
+ data_placeholder = tf.placeholder(tf.float32)
+ labels_placeholder = tf.placeholder(tf.int32)
+
+ # get the problem objective and gradient(s)
+ obj = problem.objective(params, data_placeholder, labels_placeholder)
+ gradients = problem.gradients(obj, params)
+
+ vars_to_preinitialize = params
+
+ with tf.Session(graph=graph) as sess:
+ # initialize the parameter scope variables; necessary for apply_gradients
+ sess.run(tf.variables_initializer(vars_to_preinitialize))
+ coord = tf.train.Coordinator()
+ threads = tf.train.start_queue_runners(sess=sess, coord=coord)
+
+ # create the train operation and training variables
+ try:
+ train_op, real_params = optimizer.apply_gradients(zip(gradients, params))
+ obj = problem.objective(real_params, data_placeholder, labels_placeholder)
+ except TypeError:
+ # If all goes well, this exception should only be thrown when we are using
+ # a non-hrnn optimizer.
+ train_op = optimizer.apply_gradients(zip(gradients, params))
+
+ vars_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
+ scope=OPTIMIZER_SCOPE)
+ vars_to_initialize = list(
+ set(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)) -
+ set(vars_to_restore) - set(vars_to_preinitialize))
+ # load or initialize optimizer variables
+ if logdir is not None:
+        restorer = tf.train.Saver(var_list=vars_to_restore)
+ ckpt = tf.train.latest_checkpoint(logdir)
+ restorer.restore(sess, ckpt)
+ else:
+ sess.run(tf.variables_initializer(vars_to_restore))
+ # initialize all the other variables
+ sess.run(tf.variables_initializer(vars_to_initialize))
+
+ problem.init_fn(sess)
+
+ # generate the minibatch indices
+ batch_inds = dataset.batch_indices(num_iter, batch_size)
+
+ # run the train operation for n iterations and save the objectives
+ records = defaultdict(list)
+ objective_values = []
+ for itr, batch in enumerate(batch_inds):
+
+ # data to feed in
+ feed = {data_placeholder: dataset.data[batch],
+ labels_placeholder: dataset.labels[batch]}
+ full_feed = {data_placeholder: dataset.data,
+ labels_placeholder: dataset.labels}
+
+ # record stuff
+ if record_every is not None and (itr % record_every) == 0:
+ def grad_value(g):
+ if isinstance(g, tf.IndexedSlices):
+ return g.values
+ else:
+ return g
+
+ records_fetch = {}
+ for p in params:
+ for key in optimizer.get_slot_names():
+ v = optimizer.get_slot(p, key)
+ records_fetch[p.name + "_" + key] = v
+ gav_fetch = [(grad_value(g), v) for g, v in zip(gradients, params)]
+
+ _, gav_eval, records_eval = sess.run(
+ (obj, gav_fetch, records_fetch), feed_dict=feed)
+ full_obj_eval = sess.run([obj], feed_dict=full_feed)
+
+ records["objective"].append(full_obj_eval)
+ records["grad_norm"].append([np.linalg.norm(g.ravel())
+ for g, _ in gav_eval])
+ records["param_norm"].append([np.linalg.norm(v.ravel())
+ for _, v in gav_eval])
+ records["grad"].append([g for g, _ in gav_eval])
+ records["param"].append([v for _, v in gav_eval])
+ records["iter"].append(itr)
+
+ for k, v in records_eval.iteritems():
+ records[k].append(v)
+
+ # run the optimization train operation
+ objective_values.append(sess.run([train_op, obj], feed_dict=feed)[1])
+
+ # final parameters
+ parameters = [sess.run(p) for p in params]
+ coord.request_stop()
+ coord.join(threads)
+
+ return objective_values, parameters, records
+
+
+def run_wall_clock_test(optimizer,
+ problem,
+ num_steps,
+ dataset=datasets.EMPTY_DATASET,
+ seed=None,
+ logdir=None,
+ batch_size=None):
+ """Runs optimization with the given parameters and return average iter time.
+
+ Args:
+ optimizer: The tf.train.Optimizer instance
+ problem: The problem to optimize (a problem_generator.Problem)
+ num_steps: The number of steps to run optimization for
+ dataset: The dataset to train the problem against
+ seed: The seed used for drawing the initial parameters, or a list of
+ numpy arrays used to explicitly initialize the parameters
+ logdir: A directory containing model checkpoints. If given, then the
+ parameters of the optimizer are loaded from the latest checkpoint
+ in this folder.
+ batch_size: The number of samples per batch.
+
+ Returns:
+ The average time in seconds for a single optimization iteration.
+ """
+ if dataset is None:
+ dataset = datasets.EMPTY_DATASET
+ batch_size = dataset.size
+ else:
+ # default batch size is the entire dataset
+ batch_size = dataset.size if batch_size is None else batch_size
+
+ # define the parameters of the optimization problem
+ if isinstance(seed, (list, tuple)):
+ # seed is a list of arrays
+ params = problem_generator.init_fixed_variables(seed)
+ else:
+ # seed is an int or None
+ params = problem.init_variables(seed)
+
+ data_placeholder = tf.placeholder(tf.float32)
+ labels_placeholder = tf.placeholder(tf.int32)
+
+ obj = problem.objective(params, data_placeholder, labels_placeholder)
+ gradients = problem.gradients(obj, params)
+ vars_to_preinitialize = params
+
+ with tf.Session(graph=tf.get_default_graph()) as sess:
+ # initialize the parameter scope variables; necessary for apply_gradients
+ sess.run(tf.variables_initializer(vars_to_preinitialize))
+ train_op = optimizer.apply_gradients(zip(gradients, params))
+ if isinstance(train_op, tuple) or isinstance(train_op, list):
+ # LOL apply_gradients returns a tuple. Regular optimizers do not.
+ train_op = train_op[0]
+ vars_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
+ scope=OPTIMIZER_SCOPE)
+ vars_to_initialize = list(
+ set(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)) -
+ set(vars_to_restore) - set(vars_to_preinitialize))
+ # load or initialize optimizer variables
+ if logdir is not None:
+      restorer = tf.train.Saver(var_list=vars_to_restore)
+ ckpt = tf.train.latest_checkpoint(logdir)
+ restorer.restore(sess, ckpt)
+ else:
+ sess.run(tf.variables_initializer(vars_to_restore))
+ # initialize all the other variables
+ sess.run(tf.variables_initializer(vars_to_initialize))
+
+ problem.init_fn(sess)
+
+ # generate the minibatch indices
+ batch_inds = dataset.batch_indices(num_steps, batch_size)
+
+ avg_iter_time = []
+ for batch in batch_inds:
+ # data to feed in
+ feed = {data_placeholder: dataset.data[batch],
+ labels_placeholder: dataset.labels[batch]}
+
+ # run the optimization train operation
+ start = time.time()
+ sess.run([train_op], feed_dict=feed)
+ avg_iter_time.append(time.time() - start)
+
+ return np.median(np.array(avg_iter_time))
diff --git a/models/research/learned_optimizer/metarun.py b/models/research/learned_optimizer/metarun.py
new file mode 100644
index 0000000000000000000000000000000000000000..45a29623c7fd1381cef590c4e8440d8749585b72
--- /dev/null
+++ b/models/research/learned_optimizer/metarun.py
@@ -0,0 +1,394 @@
+# Copyright 2017 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Scripts for meta-optimization."""
+
+from __future__ import print_function
+
+import os
+
+import tensorflow as tf
+
+import metaopt
+from learned_optimizer.optimizer import coordinatewise_rnn
+from learned_optimizer.optimizer import global_learning_rate
+from learned_optimizer.optimizer import hierarchical_rnn
+from learned_optimizer.optimizer import learning_rate_schedule
+from learned_optimizer.optimizer import trainable_adam
+from learned_optimizer.problems import problem_sets as ps
+from learned_optimizer.problems import problem_spec
+
+tf.app.flags.DEFINE_string("train_dir", "/tmp/lol/",
+ """Directory to store parameters and results.""")
+
+tf.app.flags.DEFINE_integer("task", 0,
+ """Task id of the replica running the training.""")
+tf.app.flags.DEFINE_integer("worker_tasks", 1,
+ """Number of tasks in the worker job.""")
+
+tf.app.flags.DEFINE_integer("num_problems", 1000,
+ """Number of sub-problems to run.""")
+tf.app.flags.DEFINE_integer("num_meta_iterations", 5,
+ """Number of meta-iterations to optimize.""")
+tf.app.flags.DEFINE_integer("num_unroll_scale", 40,
+ """The scale parameter of the exponential
+ distribution from which the number of partial
+ unrolls is drawn""")
+tf.app.flags.DEFINE_integer("min_num_unrolls", 1,
+ """The minimum number of unrolls per problem.""")
+tf.app.flags.DEFINE_integer("num_partial_unroll_itr_scale", 200,
+ """The scale parameter of the exponential
+ distribution from which the number of iterations
+ per unroll is drawn.""")
+tf.app.flags.DEFINE_integer("min_num_itr_partial_unroll", 50,
+ """The minimum number of iterations for one
+ unroll.""")
+
+tf.app.flags.DEFINE_string("optimizer", "HierarchicalRNN",
+ """Which meta-optimizer to train.""")
+
+# CoordinatewiseRNN-specific flags
+tf.app.flags.DEFINE_integer("cell_size", 20,
+ """Size of the RNN hidden state in each layer.""")
+tf.app.flags.DEFINE_integer("num_cells", 2,
+ """Number of RNN layers.""")
+tf.app.flags.DEFINE_string("cell_cls", "GRUCell",
+ """Type of RNN cell to use.""")
+
+# Metaoptimization parameters
+tf.app.flags.DEFINE_float("meta_learning_rate", 1e-6,
+ """The learning rate for the meta-optimizer.""")
+tf.app.flags.DEFINE_float("gradient_clip_level", 1e4,
+ """The level to clip gradients to.""")
+
+# Training set selection
+tf.app.flags.DEFINE_boolean("include_quadratic_problems", False,
+ """Include non-noisy quadratic problems.""")
+tf.app.flags.DEFINE_boolean("include_noisy_quadratic_problems", True,
+ """Include noisy quadratic problems.""")
+tf.app.flags.DEFINE_boolean("include_large_quadratic_problems", True,
+ """Include very large quadratic problems.""")
+tf.app.flags.DEFINE_boolean("include_bowl_problems", True,
+ """Include 2D bowl problems.""")
+tf.app.flags.DEFINE_boolean("include_softmax_2_class_problems", True,
+ """Include 2-class logistic regression problems.""")
+tf.app.flags.DEFINE_boolean("include_noisy_softmax_2_class_problems", True,
+ """Include noisy 2-class logistic regression
+ problems.""")
+tf.app.flags.DEFINE_boolean("include_optimization_test_problems", True,
+ """Include non-noisy versions of classic
+ optimization test problems, e.g. Rosenbrock.""")
+tf.app.flags.DEFINE_boolean("include_noisy_optimization_test_problems", True,
+ """Include gradient-noise versions of classic
+ optimization test problems, e.g. Rosenbrock""")
+tf.app.flags.DEFINE_boolean("include_fully_connected_random_2_class_problems",
+ True, """Include MLP problems for 2 classes.""")
+tf.app.flags.DEFINE_boolean("include_matmul_problems", True,
+ """Include matrix multiplication problems.""")
+tf.app.flags.DEFINE_boolean("include_log_objective_problems", True,
+ """Include problems where the objective is the log
+ objective of another problem, e.g. Bowl.""")
+tf.app.flags.DEFINE_boolean("include_rescale_problems", True,
+ """Include problems where the parameters are scaled
+ version of the original parameters.""")
+tf.app.flags.DEFINE_boolean("include_norm_problems", True,
+ """Include problems where the objective is the
+ N-norm of another problem, e.g. Quadratic.""")
+tf.app.flags.DEFINE_boolean("include_sum_problems", True,
+ """Include problems where the objective is the sum
+ of the objectives of the subproblems that make
+ up the problem parameters. Per-problem tensors
+ are still independent of each other.""")
+tf.app.flags.DEFINE_boolean("include_sparse_gradient_problems", True,
+ """Include problems where the gradient is set to 0
+ with some high probability.""")
+tf.app.flags.DEFINE_boolean("include_sparse_softmax_problems", False,
+ """Include sparse softmax problems.""")
+tf.app.flags.DEFINE_boolean("include_one_hot_sparse_softmax_problems", False,
+ """Include one-hot sparse softmax problems.""")
+tf.app.flags.DEFINE_boolean("include_noisy_bowl_problems", True,
+ """Include noisy bowl problems.""")
+tf.app.flags.DEFINE_boolean("include_noisy_norm_problems", True,
+ """Include noisy norm problems.""")
+tf.app.flags.DEFINE_boolean("include_noisy_sum_problems", True,
+ """Include noisy sum problems.""")
+tf.app.flags.DEFINE_boolean("include_sum_of_quadratics_problems", False,
+ """Include sum of quadratics problems.""")
+tf.app.flags.DEFINE_boolean("include_projection_quadratic_problems", False,
+ """Include projection quadratic problems.""")
+tf.app.flags.DEFINE_boolean("include_outward_snake_problems", False,
+ """Include outward snake problems.""")
+tf.app.flags.DEFINE_boolean("include_dependency_chain_problems", False,
+ """Include dependency chain problems.""")
+tf.app.flags.DEFINE_boolean("include_min_max_well_problems", False,
+ """Include min-max well problems.""")
+
+# Optimizer parameters: initialization and scale values
+tf.app.flags.DEFINE_float("min_lr", 1e-6,
+ """The minimum initial learning rate.""")
+tf.app.flags.DEFINE_float("max_lr", 1e-2,
+ """The maximum initial learning rate.""")
+
+# Optimizer parameters: small features.
+tf.app.flags.DEFINE_boolean("zero_init_lr_weights", True,
+ """Whether to initialize the learning rate weights
+ to 0 rather than the scaled random initialization
+ used for other RNN variables.""")
+tf.app.flags.DEFINE_boolean("use_relative_lr", True,
+ """Whether to use the relative learning rate as an
+ input during training. Can only be used if
+ learnable_decay is also True.""")
+tf.app.flags.DEFINE_boolean("use_extreme_indicator", False,
+ """Whether to use the extreme indicator for learning
+ rates as an input during training. Can only be
+ used if learnable_decay is also True.""")
+tf.app.flags.DEFINE_boolean("use_log_means_squared", True,
+ """Whether to track the log of the mean squared
+ grads instead of the means squared grads.""")
+tf.app.flags.DEFINE_boolean("use_problem_lr_mean", True,
+ """Whether to use the mean over all learning rates
+ in the problem when calculating the relative
+ learning rate.""")
+
+# Optimizer parameters: major features
+tf.app.flags.DEFINE_boolean("learnable_decay", True,
+ """Whether to learn weights that dynamically
+ modulate the input scale via RMS decay.""")
+tf.app.flags.DEFINE_boolean("dynamic_output_scale", True,
+ """Whether to learn weights that dynamically
+ modulate the output scale.""")
+tf.app.flags.DEFINE_boolean("use_log_objective", True,
+ """Whether to use the log of the scaled objective
+ rather than just the scaled obj for training.""")
+tf.app.flags.DEFINE_boolean("use_attention", False,
+ """Whether to learn where to attend.""")
+tf.app.flags.DEFINE_boolean("use_second_derivatives", True,
+ """Whether to use second derivatives.""")
+tf.app.flags.DEFINE_integer("num_gradient_scales", 4,
+ """How many different timescales to keep for
+ gradient history. If > 1, also learns a scale
+ factor for gradient history.""")
+tf.app.flags.DEFINE_float("max_log_lr", 33,
+ """The maximum log learning rate allowed.""")
+tf.app.flags.DEFINE_float("objective_training_max_multiplier", -1,
+ """How much the objective can grow before training on
+ this problem / param pair is terminated. Sets a max
+ on the objective value when multiplied by the
+ initial objective. If <= 0, not used.""")
+tf.app.flags.DEFINE_boolean("use_gradient_shortcut", True,
+ """Whether to add a learned affine projection of the
+ gradient to the update delta in addition to the
+ gradient function computed by the RNN.""")
+tf.app.flags.DEFINE_boolean("use_lr_shortcut", False,
+ """Whether to add the difference between the current
+ learning rate and the desired learning rate to
+ the RNN input.""")
+tf.app.flags.DEFINE_boolean("use_grad_products", True,
+ """Whether to use gradient products in the input to
+ the RNN. Only applicable when num_gradient_scales
+ > 1.""")
+tf.app.flags.DEFINE_boolean("use_multiple_scale_decays", False,
+ """Whether to use many-timescale scale decays.""")
+tf.app.flags.DEFINE_boolean("use_numerator_epsilon", False,
+ """Whether to use epsilon in the numerator of the
+ log objective.""")
+tf.app.flags.DEFINE_boolean("learnable_inp_decay", True,
+ """Whether to learn input decay weight and bias.""")
+tf.app.flags.DEFINE_boolean("learnable_rnn_init", True,
+ """Whether to learn RNN state initialization.""")
+
+FLAGS = tf.app.flags.FLAGS
+
+# The Size of the RNN hidden state in each layer:
+# [PerParam, PerTensor, Global]. The length of this list must be 1, 2, or 3.
+# If less than 3, the Global and/or PerTensor RNNs will not be created.
+
+HRNN_CELL_SIZES = [10, 20, 20]
+
+
+
+def register_optimizers():
+ opts = {}
+ opts["CoordinatewiseRNN"] = coordinatewise_rnn.CoordinatewiseRNN
+ opts["GlobalLearningRate"] = global_learning_rate.GlobalLearningRate
+ opts["HierarchicalRNN"] = hierarchical_rnn.HierarchicalRNN
+ opts["LearningRateSchedule"] = learning_rate_schedule.LearningRateSchedule
+ opts["TrainableAdam"] = trainable_adam.TrainableAdam
+ return opts
+
+
+def main(unused_argv):
+ """Runs the main script."""
+
+ opts = register_optimizers()
+
+ # Choose a set of problems to optimize. By default this includes quadratics,
+ # 2-dimensional bowls, 2-class softmax problems, and non-noisy optimization
+ # test problems (e.g. Rosenbrock, Beale)
+ problems_and_data = []
+
+ if FLAGS.include_sparse_softmax_problems:
+ problems_and_data.extend(ps.sparse_softmax_2_class_sparse_problems())
+
+ if FLAGS.include_one_hot_sparse_softmax_problems:
+ problems_and_data.extend(
+ ps.one_hot_sparse_softmax_2_class_sparse_problems())
+
+ if FLAGS.include_quadratic_problems:
+ problems_and_data.extend(ps.quadratic_problems())
+
+ if FLAGS.include_noisy_quadratic_problems:
+ problems_and_data.extend(ps.quadratic_problems_noisy())
+
+ if FLAGS.include_large_quadratic_problems:
+ problems_and_data.extend(ps.quadratic_problems_large())
+
+ if FLAGS.include_bowl_problems:
+ problems_and_data.extend(ps.bowl_problems())
+
+ if FLAGS.include_noisy_bowl_problems:
+ problems_and_data.extend(ps.bowl_problems_noisy())
+
+ if FLAGS.include_softmax_2_class_problems:
+ problems_and_data.extend(ps.softmax_2_class_problems())
+
+ if FLAGS.include_noisy_softmax_2_class_problems:
+ problems_and_data.extend(ps.softmax_2_class_problems_noisy())
+
+ if FLAGS.include_optimization_test_problems:
+ problems_and_data.extend(ps.optimization_test_problems())
+
+ if FLAGS.include_noisy_optimization_test_problems:
+ problems_and_data.extend(ps.optimization_test_problems_noisy())
+
+ if FLAGS.include_fully_connected_random_2_class_problems:
+ problems_and_data.extend(ps.fully_connected_random_2_class_problems())
+
+ if FLAGS.include_matmul_problems:
+ problems_and_data.extend(ps.matmul_problems())
+
+ if FLAGS.include_log_objective_problems:
+ problems_and_data.extend(ps.log_objective_problems())
+
+ if FLAGS.include_rescale_problems:
+ problems_and_data.extend(ps.rescale_problems())
+
+ if FLAGS.include_norm_problems:
+ problems_and_data.extend(ps.norm_problems())
+
+ if FLAGS.include_noisy_norm_problems:
+ problems_and_data.extend(ps.norm_problems_noisy())
+
+ if FLAGS.include_sum_problems:
+ problems_and_data.extend(ps.sum_problems())
+
+ if FLAGS.include_noisy_sum_problems:
+ problems_and_data.extend(ps.sum_problems_noisy())
+
+ if FLAGS.include_sparse_gradient_problems:
+ problems_and_data.extend(ps.sparse_gradient_problems())
+ if FLAGS.include_fully_connected_random_2_class_problems:
+ problems_and_data.extend(ps.sparse_gradient_problems_mlp())
+
+ if FLAGS.include_min_max_well_problems:
+ problems_and_data.extend(ps.min_max_well_problems())
+
+ if FLAGS.include_sum_of_quadratics_problems:
+ problems_and_data.extend(ps.sum_of_quadratics_problems())
+
+ if FLAGS.include_projection_quadratic_problems:
+ problems_and_data.extend(ps.projection_quadratic_problems())
+
+ if FLAGS.include_outward_snake_problems:
+ problems_and_data.extend(ps.outward_snake_problems())
+
+ if FLAGS.include_dependency_chain_problems:
+ problems_and_data.extend(ps.dependency_chain_problems())
+
+ # log directory
+ logdir = os.path.join(FLAGS.train_dir,
+ "{}_{}_{}_{}".format(FLAGS.optimizer,
+ FLAGS.cell_cls,
+ FLAGS.cell_size,
+ FLAGS.num_cells))
+
+ # get the optimizer class and arguments
+ optimizer_cls = opts[FLAGS.optimizer]
+
+ assert len(HRNN_CELL_SIZES) in [1, 2, 3]
+ optimizer_args = (HRNN_CELL_SIZES,)
+
+ optimizer_kwargs = {
+ "init_lr_range": (FLAGS.min_lr, FLAGS.max_lr),
+ "learnable_decay": FLAGS.learnable_decay,
+ "dynamic_output_scale": FLAGS.dynamic_output_scale,
+ "cell_cls": getattr(tf.contrib.rnn, FLAGS.cell_cls),
+ "use_attention": FLAGS.use_attention,
+ "use_log_objective": FLAGS.use_log_objective,
+ "num_gradient_scales": FLAGS.num_gradient_scales,
+ "zero_init_lr_weights": FLAGS.zero_init_lr_weights,
+ "use_log_means_squared": FLAGS.use_log_means_squared,
+ "use_relative_lr": FLAGS.use_relative_lr,
+ "use_extreme_indicator": FLAGS.use_extreme_indicator,
+ "max_log_lr": FLAGS.max_log_lr,
+ "obj_train_max_multiplier": FLAGS.objective_training_max_multiplier,
+ "use_problem_lr_mean": FLAGS.use_problem_lr_mean,
+ "use_gradient_shortcut": FLAGS.use_gradient_shortcut,
+ "use_second_derivatives": FLAGS.use_second_derivatives,
+ "use_lr_shortcut": FLAGS.use_lr_shortcut,
+ "use_grad_products": FLAGS.use_grad_products,
+ "use_multiple_scale_decays": FLAGS.use_multiple_scale_decays,
+ "use_numerator_epsilon": FLAGS.use_numerator_epsilon,
+ "learnable_inp_decay": FLAGS.learnable_inp_decay,
+ "learnable_rnn_init": FLAGS.learnable_rnn_init,
+ }
+ optimizer_spec = problem_spec.Spec(
+ optimizer_cls, optimizer_args, optimizer_kwargs)
+
+ # make log directory
+ tf.gfile.MakeDirs(logdir)
+
+ is_chief = FLAGS.task == 0
+ # if this is a distributed run, make the chief run through problems in order
+ select_random_problems = FLAGS.worker_tasks == 1 or not is_chief
+
+ def num_unrolls():
+ return metaopt.sample_numiter(FLAGS.num_unroll_scale, FLAGS.min_num_unrolls)
+
+ def num_partial_unroll_itrs():
+ return metaopt.sample_numiter(FLAGS.num_partial_unroll_itr_scale,
+ FLAGS.min_num_itr_partial_unroll)
+
+ # run it
+ metaopt.train_optimizer(
+ logdir,
+ optimizer_spec,
+ problems_and_data,
+ FLAGS.num_problems,
+ FLAGS.num_meta_iterations,
+ num_unrolls,
+ num_partial_unroll_itrs,
+ learning_rate=FLAGS.meta_learning_rate,
+ gradient_clip=FLAGS.gradient_clip_level,
+ is_chief=is_chief,
+ select_random_problems=select_random_problems,
+ obj_train_max_multiplier=FLAGS.objective_training_max_multiplier,
+ callbacks=[])
+
+ return 0
+
+
+if __name__ == "__main__":
+ tf.app.run()
diff --git a/models/research/learned_optimizer/optimizer/BUILD b/models/research/learned_optimizer/optimizer/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..8953e7592ace416b786be2a6fa59f4c537c82644
--- /dev/null
+++ b/models/research/learned_optimizer/optimizer/BUILD
@@ -0,0 +1,69 @@
+package(default_visibility = ["//visibility:public"])
+
+# Libraries
+# =========
+py_library(
+ name = "coordinatewise_rnn",
+ srcs = ["coordinatewise_rnn.py"],
+ deps = [
+ ":trainable_optimizer",
+ ":utils",
+ ],
+)
+
+py_library(
+ name = "global_learning_rate",
+ srcs = ["global_learning_rate.py"],
+ deps = [
+ ":trainable_optimizer",
+ ],
+)
+
+py_library(
+ name = "hierarchical_rnn",
+ srcs = ["hierarchical_rnn.py"],
+ deps = [
+ ":rnn_cells",
+ ":trainable_optimizer",
+ ":utils",
+ ],
+)
+
+py_library(
+ name = "learning_rate_schedule",
+ srcs = ["learning_rate_schedule.py"],
+ deps = [
+ ":trainable_optimizer",
+ ],
+)
+
+py_library(
+ name = "rnn_cells",
+ srcs = ["rnn_cells.py"],
+ deps = [
+ ":utils",
+ ],
+)
+
+py_library(
+ name = "trainable_adam",
+ srcs = ["trainable_adam.py"],
+ deps = [
+ ":trainable_optimizer",
+ ":utils",
+ ],
+)
+
+py_library(
+ name = "trainable_optimizer",
+ srcs = ["trainable_optimizer.py"],
+ deps = [
+ ],
+)
+
+py_library(
+ name = "utils",
+ srcs = ["utils.py"],
+ deps = [
+ ],
+)
diff --git a/models/research/learned_optimizer/optimizer/coordinatewise_rnn.py b/models/research/learned_optimizer/optimizer/coordinatewise_rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d699504b7a3d86643bea6b295d20b2434131a99
--- /dev/null
+++ b/models/research/learned_optimizer/optimizer/coordinatewise_rnn.py
@@ -0,0 +1,316 @@
+# Copyright 2017 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Collection of trainable optimizers for meta-optimization."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import numpy as np
+import tensorflow as tf
+
+from learned_optimizer.optimizer import utils
+from learned_optimizer.optimizer import trainable_optimizer as opt
+
+
+# Default was 1e-3
+tf.app.flags.DEFINE_float("crnn_rnn_readout_scale", 0.5,
+ """The initialization scale for the RNN readouts.""")
+tf.app.flags.DEFINE_float("crnn_default_decay_var_init", 2.2,
+ """The default initializer value for any decay/
+ momentum style variables and constants.
+                          sigmoid(2.2) ~ 0.9, sigmoid(-2.2) ~ 0.1.""")
+
+FLAGS = tf.flags.FLAGS
+
+
+class CoordinatewiseRNN(opt.TrainableOptimizer):
+ """RNN that operates on each coordinate of the problem independently."""
+
+ def __init__(self,
+ cell_sizes,
+ cell_cls,
+ init_lr_range=(1., 1.),
+ dynamic_output_scale=True,
+ learnable_decay=True,
+ zero_init_lr_weights=False,
+ **kwargs):
+ """Initializes the RNN per-parameter optimizer.
+
+ Args:
+ cell_sizes: List of hidden state sizes for each RNN cell in the network
+ cell_cls: tf.contrib.rnn class for specifying the RNN cell type
+ init_lr_range: the range in which to initialize the learning rates.
+ dynamic_output_scale: whether to learn weights that dynamically modulate
+ the output scale (default: True)
+ learnable_decay: whether to learn weights that dynamically modulate the
+ input scale via RMS style decay (default: True)
+ zero_init_lr_weights: whether to initialize the lr weights to zero
+ **kwargs: args passed to TrainableOptimizer's constructor
+
+ Raises:
+ ValueError: If the init lr range is not of length 2.
+ ValueError: If the init lr range is not a valid range (min > max).
+ """
+ if len(init_lr_range) != 2:
+ raise ValueError(
+ "Initial LR range must be len 2, was {}".format(len(init_lr_range)))
+ if init_lr_range[0] > init_lr_range[1]:
+ raise ValueError("Initial LR range min is greater than max.")
+ self.init_lr_range = init_lr_range
+
+ self.zero_init_lr_weights = zero_init_lr_weights
+ self.reuse_vars = False
+
+ # create the RNN cell
+ with tf.variable_scope(opt.OPTIMIZER_SCOPE):
+ self.component_cells = [cell_cls(sz) for sz in cell_sizes]
+ self.cell = tf.contrib.rnn.MultiRNNCell(self.component_cells)
+
+ # random normal initialization scaled by the output size
+ scale_factor = FLAGS.crnn_rnn_readout_scale / math.sqrt(cell_sizes[-1])
+ scaled_init = tf.random_normal_initializer(0., scale_factor)
+
+ # weights for projecting the hidden state to a parameter update
+ self.update_weights = tf.get_variable("update_weights",
+ shape=(cell_sizes[-1], 1),
+ initializer=scaled_init)
+
+ self._initialize_decay(learnable_decay, (cell_sizes[-1], 1), scaled_init)
+
+ self._initialize_lr(dynamic_output_scale, (cell_sizes[-1], 1),
+ scaled_init)
+
+ state_size = sum([sum(state_size) for state_size in self.cell.state_size])
+ self._init_vector = tf.get_variable(
+ "init_vector", shape=[1, state_size],
+ initializer=tf.random_uniform_initializer(-1., 1.))
+
+ state_keys = ["rms", "rnn", "learning_rate", "decay"]
+ super(CoordinatewiseRNN, self).__init__("cRNN", state_keys, **kwargs)
+
+ def _initialize_decay(
+ self, learnable_decay, weights_tensor_shape, scaled_init):
+ """Initializes the decay weights and bias variables or tensors.
+
+ Args:
+ learnable_decay: Whether to use learnable decay.
+ weights_tensor_shape: The shape the weight tensor should take.
+ scaled_init: The scaled initialization for the weights tensor.
+ """
+ if learnable_decay:
+
+ # weights for projecting the hidden state to the RMS decay term
+ self.decay_weights = tf.get_variable("decay_weights",
+ shape=weights_tensor_shape,
+ initializer=scaled_init)
+ self.decay_bias = tf.get_variable(
+ "decay_bias", shape=(1,),
+ initializer=tf.constant_initializer(
+ FLAGS.crnn_default_decay_var_init))
+ else:
+ self.decay_weights = tf.zeros_like(self.update_weights)
+ self.decay_bias = tf.constant(FLAGS.crnn_default_decay_var_init)
+
+ def _initialize_lr(
+ self, dynamic_output_scale, weights_tensor_shape, scaled_init):
+ """Initializes the learning rate weights and bias variables or tensors.
+
+ Args:
+ dynamic_output_scale: Whether to use a dynamic output scale.
+ weights_tensor_shape: The shape the weight tensor should take.
+ scaled_init: The scaled initialization for the weights tensor.
+ """
+ if dynamic_output_scale:
+ zero_init = tf.constant_initializer(0.)
+ wt_init = zero_init if self.zero_init_lr_weights else scaled_init
+ self.lr_weights = tf.get_variable("learning_rate_weights",
+ shape=weights_tensor_shape,
+ initializer=wt_init)
+ self.lr_bias = tf.get_variable("learning_rate_bias", shape=(1,),
+ initializer=zero_init)
+ else:
+ self.lr_weights = tf.zeros_like(self.update_weights)
+ self.lr_bias = tf.zeros([1, 1])
+
+ def _initialize_state(self, var):
+ """Return a dictionary mapping names of state variables to their values."""
+ vectorized_shape = [var.get_shape().num_elements(), 1]
+
+ min_lr = self.init_lr_range[0]
+ max_lr = self.init_lr_range[1]
+ if min_lr == max_lr:
+ init_lr = tf.constant(min_lr, shape=vectorized_shape)
+ else:
+ actual_vals = tf.random_uniform(vectorized_shape,
+ np.log(min_lr),
+ np.log(max_lr))
+ init_lr = tf.exp(actual_vals)
+
+ ones = tf.ones(vectorized_shape)
+ rnn_init = ones * self._init_vector
+
+ return {
+ "rms": tf.ones(vectorized_shape),
+ "learning_rate": init_lr,
+ "rnn": rnn_init,
+ "decay": tf.ones(vectorized_shape),
+ }
+
+ def _compute_update(self, param, grad, state):
+ """Update parameters given the gradient and state.
+
+ Args:
+ param: tensor of parameters
+ grad: tensor of gradients with the same shape as param
+ state: a dictionary containing any state for the optimizer
+
+ Returns:
+ updated_param: updated parameters
+ updated_state: updated state variables in a dictionary
+ """
+
+ with tf.variable_scope(opt.OPTIMIZER_SCOPE) as scope:
+
+ if self.reuse_vars:
+ scope.reuse_variables()
+ else:
+ self.reuse_vars = True
+
+ param_shape = tf.shape(param)
+
+ (grad_values, decay_state, rms_state, rnn_state, learning_rate_state,
+ grad_indices) = self._extract_gradients_and_internal_state(
+ grad, state, param_shape)
+
+ # Vectorize and scale the gradients.
+ grad_scaled, rms = utils.rms_scaling(grad_values, decay_state, rms_state)
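+ # rms_scaling (see utils) keeps a decayed average of the squared gradients
+ # and divides the gradient by its square root, so the RNN sees inputs of
+ # roughly unit scale regardless of the raw gradient magnitude.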
+
+ # Apply the RNN update.
+ rnn_state_tuples = self._unpack_rnn_state_into_tuples(rnn_state)
+ rnn_output, rnn_state_tuples = self.cell(grad_scaled, rnn_state_tuples)
+ rnn_state = self._pack_tuples_into_rnn_state(rnn_state_tuples)
+
+ # Compute the update direction (a linear projection of the RNN output).
+ delta = utils.project(rnn_output, self.update_weights)
+
+ # The updated decay is an affine projection of the hidden state
+ decay = utils.project(rnn_output, self.decay_weights,
+ bias=self.decay_bias, activation=tf.nn.sigmoid)
+
+ # Compute the change in learning rate (an affine projection of the RNN
+ # state, passed through a 2x sigmoid, so the change is bounded).
+ learning_rate_change = 2. * utils.project(rnn_output, self.lr_weights,
+ bias=self.lr_bias,
+ activation=tf.nn.sigmoid)
+
+ # Update the learning rate.
+ new_learning_rate = learning_rate_change * learning_rate_state
+
+ # Apply the update to the parameters.
+ update = tf.reshape(new_learning_rate * delta, tf.shape(grad_values))
+
+ if isinstance(grad, tf.IndexedSlices):
+ update = utils.stack_tensor(update, grad_indices, param,
+ param_shape[:1])
+ rms = utils.update_slices(rms, grad_indices, state["rms"], param_shape)
+ new_learning_rate = utils.update_slices(new_learning_rate, grad_indices,
+ state["learning_rate"],
+ param_shape)
+ rnn_state = utils.update_slices(rnn_state, grad_indices, state["rnn"],
+ param_shape)
+ decay = utils.update_slices(decay, grad_indices, state["decay"],
+ param_shape)
+
+ new_param = param - update
+
+ # Collect the update and new state.
+ new_state = {
+ "rms": rms,
+ "learning_rate": new_learning_rate,
+ "rnn": rnn_state,
+ "decay": decay,
+ }
+
+ return new_param, new_state
+
+ def _extract_gradients_and_internal_state(self, grad, state, param_shape):
+ """Extracts the gradients and relevant internal state.
+
+ If the gradient is sparse, extracts the appropriate slices from the state.
+
+ Args:
+ grad: The current gradient.
+ state: The current state.
+ param_shape: The shape of the parameter (used if gradient is sparse).
+
+ Returns:
+ grad_values: The gradient value tensor.
+ decay_state: The current decay state.
+ rms_state: The current rms state.
+ rnn_state: The current state of the internal rnns.
+ learning_rate_state: The current learning rate state.
+ grad_indices: The indices for the gradient tensor, if sparse.
+ None otherwise.
+ """
+ if isinstance(grad, tf.IndexedSlices):
+ grad_indices, grad_values = utils.accumulate_sparse_gradients(grad)
+ decay_state = utils.slice_tensor(state["decay"], grad_indices,
+ param_shape)
+ rms_state = utils.slice_tensor(state["rms"], grad_indices, param_shape)
+ rnn_state = utils.slice_tensor(state["rnn"], grad_indices, param_shape)
+ learning_rate_state = utils.slice_tensor(state["learning_rate"],
+ grad_indices, param_shape)
+ decay_state.set_shape([None, 1])
+ rms_state.set_shape([None, 1])
+ else:
+ grad_values = grad
+ grad_indices = None
+
+ decay_state = state["decay"]
+ rms_state = state["rms"]
+ rnn_state = state["rnn"]
+ learning_rate_state = state["learning_rate"]
+ return (grad_values, decay_state, rms_state, rnn_state, learning_rate_state,
+ grad_indices)
+
+ def _unpack_rnn_state_into_tuples(self, rnn_state):
+ """Creates state tuples from the rnn state vector."""
+ rnn_state_tuples = []
+ cur_state_pos = 0
+ for cell in self.component_cells:
+ total_state_size = sum(cell.state_size)
+ cur_state = tf.slice(rnn_state, [0, cur_state_pos],
+ [-1, total_state_size])
+ cur_state_tuple = tf.split(value=cur_state, num_or_size_splits=2,
+ axis=1)
+ rnn_state_tuples.append(cur_state_tuple)
+ cur_state_pos += total_state_size
+ return rnn_state_tuples
+
+ def _pack_tuples_into_rnn_state(self, rnn_state_tuples):
+ """Creates a single state vector concatenated along column axis."""
+ rnn_state = None
+ for new_state_tuple in rnn_state_tuples:
+ new_c, new_h = new_state_tuple
+ if rnn_state is None:
+ rnn_state = tf.concat([new_c, new_h], axis=1)
+ else:
+ rnn_state = tf.concat([rnn_state, tf.concat([new_c, new_h], 1)], axis=1)
+ return rnn_state
+
diff --git a/models/research/learned_optimizer/optimizer/global_learning_rate.py b/models/research/learned_optimizer/optimizer/global_learning_rate.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcf102fff054e9fe9e92d4379538f6394314fe1c
--- /dev/null
+++ b/models/research/learned_optimizer/optimizer/global_learning_rate.py
@@ -0,0 +1,40 @@
+# Copyright 2017 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A trainable optimizer that learns a single global learning rate."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from learned_optimizer.optimizer import trainable_optimizer
+
+
+class GlobalLearningRate(trainable_optimizer.TrainableOptimizer):
+ """Optimizes for a single global learning rate."""
+
+ def __init__(self, initial_rate=1e-3, **kwargs):
+ """Initializes the global learning rate."""
+ with tf.variable_scope(trainable_optimizer.OPTIMIZER_SCOPE):
+ initializer = tf.constant_initializer(initial_rate)
+ self.learning_rate = tf.get_variable("global_learning_rate", shape=(),
+ initializer=initializer)
+ super(GlobalLearningRate, self).__init__("GLR", [], **kwargs)
+
+ def _compute_update(self, param, grad, state):
+ return param - tf.scalar_mul(self.learning_rate, grad), state
+
diff --git a/models/research/learned_optimizer/optimizer/hierarchical_rnn.py b/models/research/learned_optimizer/optimizer/hierarchical_rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..953b72b5d04724a11a0e95385bbe0c6a0d91289d
--- /dev/null
+++ b/models/research/learned_optimizer/optimizer/hierarchical_rnn.py
@@ -0,0 +1,792 @@
+# Copyright 2017 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Collection of trainable optimizers for meta-optimization."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.ops import state_ops
+from learned_optimizer.optimizer import rnn_cells
+from learned_optimizer.optimizer import trainable_optimizer as opt
+from learned_optimizer.optimizer import utils
+
+# Default was 0.1
+tf.app.flags.DEFINE_float("biasgrucell_scale", 0.5,
+ """The scale for the internal BiasGRUCell vars.""")
+# Default was 0
+tf.app.flags.DEFINE_float("biasgrucell_gate_bias_init", 2.2,
+ """The bias for the internal BiasGRUCell reset and
+ update gate variables.""")
+# Default was 1e-3
+tf.app.flags.DEFINE_float("hrnn_rnn_readout_scale", 0.5,
+ """The initialization scale for the RNN readouts.""")
+tf.app.flags.DEFINE_float("hrnn_default_decay_var_init", 2.2,
+ """The default initializer value for any decay/
+ momentum style variables and constants.
+ sigmoid(2.2) ~ 0.9, sigmoid(-2.2) ~ 0.1.""")
+# Default was 2.2
+tf.app.flags.DEFINE_float("scale_decay_bias_init", 3.2,
+ """The initialization for the scale decay bias. This
+ is the initial bias for the timescale for the
+ exponential avg of the mean square gradients.""")
+tf.app.flags.DEFINE_float("learning_rate_momentum_logit_init", 3.2,
+ """Initialization for the learning rate momentum.""")
+# Default was 0.1
+tf.app.flags.DEFINE_float("hrnn_affine_scale", 0.5,
+ """The initialization scale for the weight matrix of
+ the bias variables in layer0 and 1 of the hrnn.""")
+
+FLAGS = tf.flags.FLAGS
+
+
+class HierarchicalRNN(opt.TrainableOptimizer):
+ """3 level hierarchical RNN.
+
+ Optionally uses second order gradient information and has decoupled evaluation
+ and update locations.
+ """
+
+ def __init__(self, level_sizes, init_lr_range=(1e-6, 1e-2),
+ learnable_decay=True, dynamic_output_scale=True,
+ use_attention=False, use_log_objective=True,
+ num_gradient_scales=4, zero_init_lr_weights=True,
+ use_log_means_squared=True, use_relative_lr=True,
+ use_extreme_indicator=False, max_log_lr=33,
+ obj_train_max_multiplier=-1, use_problem_lr_mean=False,
+ use_gradient_shortcut=False, use_lr_shortcut=False,
+ use_grad_products=False, use_multiple_scale_decays=False,
+ learnable_inp_decay=True, learnable_rnn_init=True,
+ random_seed=None, **kwargs):
+ """Initializes the RNN per-parameter optimizer.
+
+ The hierarchy consists of up to three levels:
+ Level 0: per parameter RNN
+ Level 1: per tensor RNN
+ Level 2: global RNN
+
+ Args:
+ level_sizes: list or tuple with 1, 2, or 3 integers, the number of units
+ in each RNN in the hierarchy (level0, level1, level2).
+ length 1: only coordinatewise RNNs will be used
+ length 2: coordinatewise and tensor-level RNNs will be used
+ length 3: a single global-level RNN will be used in addition to the
+ coordinatewise and tensor-level RNNs
+ init_lr_range: the range in which to initialize the learning rates
+ learnable_decay: whether to learn weights that dynamically modulate the
+ input scale via RMS style decay
+ dynamic_output_scale: whether to learn weights that dynamically modulate
+ the output scale
+ use_attention: whether to use attention to train the optimizer
+ use_log_objective: whether to train on the log of the objective
+ num_gradient_scales: the number of scales to use for gradient history
+ zero_init_lr_weights: whether to initialize the lr weights to zero
+ use_log_means_squared: whether to track the log of the means_squared,
+ used as a measure of signal vs. noise in gradient.
+ use_relative_lr: whether to use the relative learning rate as an
+ input during training (requires learnable_decay=True)
+ use_extreme_indicator: whether to use the extreme indicator for learning
+ rates as an input during training (requires learnable_decay=True)
+ max_log_lr: the maximum log learning rate allowed during train or test
+ obj_train_max_multiplier: max objective increase during a training run
+ use_problem_lr_mean: whether to use the mean over all learning rates in
+ the problem when calculating the relative learning rate as opposed to
+ the per-tensor mean
+ use_gradient_shortcut: Whether to add a learned affine projection of the
+ gradient to the update delta in addition to the gradient function
+ computed by the RNN
+ use_lr_shortcut: Whether to add as input the difference between the log lr
+ and the desired log lr (1e-3)
+ use_grad_products: Whether to use gradient products in the rnn input.
+ Only applicable if num_gradient_scales > 1
+ use_multiple_scale_decays: Whether to use multiple scales for the scale
+ decay, as with input decay
+ learnable_inp_decay: Whether to learn the input decay weights and bias.
+ learnable_rnn_init: Whether to learn the RNN state initialization.
+ random_seed: Random seed for random variable initializers. (Default: None)
+ **kwargs: args passed to TrainableOptimizer's constructor
+
+ Raises:
+ ValueError: If level_sizes is not a length 1, 2, or 3 list.
+ ValueError: If there are any non-integer sizes in level_sizes.
+ ValueError: If the init lr range is not of length 2.
+ ValueError: If the init lr range is not a valid range (min > max).
+ """
+ if len(level_sizes) not in [1, 2, 3]:
+ raise ValueError("HierarchicalRNN only supports 1, 2, or 3 levels in the "
+ "hierarchy, but {} were requested.".format(
+ len(level_sizes)))
+ if any(not isinstance(level, int) for level in level_sizes):
+ raise ValueError("Level sizes must be integer values, were {}".format(
+ level_sizes))
+ if len(init_lr_range) != 2:
+ raise ValueError(
+ "Initial LR range must be len 2, was {}".format(len(init_lr_range)))
+ if init_lr_range[0] > init_lr_range[1]:
+ raise ValueError("Initial LR range min is greater than max.")
+
+ self.learnable_decay = learnable_decay
+ self.dynamic_output_scale = dynamic_output_scale
+ self.use_attention = use_attention
+ self.use_log_objective = use_log_objective
+ self.num_gradient_scales = num_gradient_scales
+ self.zero_init_lr_weights = zero_init_lr_weights
+ self.use_log_means_squared = use_log_means_squared
+ self.use_relative_lr = use_relative_lr
+ self.use_extreme_indicator = use_extreme_indicator
+ self.max_log_lr = max_log_lr
+ self.use_problem_lr_mean = use_problem_lr_mean
+ self.use_gradient_shortcut = use_gradient_shortcut
+ self.use_lr_shortcut = use_lr_shortcut
+ self.use_grad_products = use_grad_products
+ self.use_multiple_scale_decays = use_multiple_scale_decays
+ self.learnable_inp_decay = learnable_inp_decay
+ self.learnable_rnn_init = learnable_rnn_init
+
+ self.random_seed = random_seed
+
+ self.num_layers = len(level_sizes)
+ self.init_lr_range = init_lr_range
+
+ self.reuse_vars = None
+ self.reuse_global_state = None
+ self.cells = []
+ self.init_vectors = []
+
+ with tf.variable_scope(opt.OPTIMIZER_SCOPE):
+
+ self._initialize_rnn_cells(level_sizes)
+
+ # get the cell size for the per-parameter RNN (Level 0)
+ cell_size = level_sizes[0]
+
+ # Random normal initialization scaled by the output size. This is the
+ # scale for the RNN *readouts*. RNN internal weight scale is set in the
+ # BiasGRUCell call.
+ scale_factor = FLAGS.hrnn_rnn_readout_scale / math.sqrt(cell_size)
+ scaled_init = tf.random_normal_initializer(0., scale_factor,
+ seed=self.random_seed)
+
+ # weights for projecting the hidden state to a parameter update
+ self.update_weights = tf.get_variable("update_weights",
+ shape=(cell_size, 1),
+ initializer=scaled_init)
+
+ if self.use_attention:
+ # weights for projecting the hidden state to the location at which the
+ # gradient is attended
+ self.attention_weights = tf.get_variable(
+ "attention_weights",
+ initializer=self.update_weights.initialized_value())
+
+ # weights for projecting the hidden state to the RMS decay term
+ self._initialize_scale_decay((cell_size, 1), scaled_init)
+ self._initialize_input_decay((cell_size, 1), scaled_init)
+
+ self._initialize_lr((cell_size, 1), scaled_init)
+
+ state_keys = ["parameter", "layer", "scl_decay", "inp_decay", "true_param"]
+
+ if self.dynamic_output_scale:
+ state_keys.append("log_learning_rate")
+
+ for i in range(self.num_gradient_scales):
+ state_keys.append("grad_accum{}".format(i + 1))
+ state_keys.append("ms{}".format(i + 1))
+
+ super(HierarchicalRNN, self).__init__(
+ "hRNN", state_keys, use_attention=use_attention,
+ use_log_objective=use_log_objective,
+ obj_train_max_multiplier=obj_train_max_multiplier, **kwargs)
+
+ def _initialize_rnn_cells(self, level_sizes):
+ """Initializes the RNN cells to use in the hierarchical RNN."""
+
+ # RNN Cell layers (0 -> lowest, 1 -> middle, 2 -> global)
+ for level in range(self.num_layers):
+ scope = "Level{}_RNN".format(level)
+ with tf.variable_scope(scope):
+ hcell = rnn_cells.BiasGRUCell(
+ level_sizes[level],
+ scale=FLAGS.biasgrucell_scale,
+ gate_bias_init=FLAGS.biasgrucell_gate_bias_init,
+ random_seed=self.random_seed)
+ self.cells.append(hcell)
+ if self.learnable_rnn_init:
+ self.init_vectors.append(tf.Variable(
+ tf.random_uniform([1, hcell.state_size], -1., 1.,
+ seed=self.random_seed),
+ name="init_vector"))
+ else:
+ self.init_vectors.append(
+ tf.random_uniform([1, hcell.state_size], -1., 1.,
+ seed=self.random_seed))
+
+ def _initialize_scale_decay(self, weights_tensor_shape, scaled_init):
+ """Initializes the scale decay weights and bias variables or tensors.
+
+ Args:
+ weights_tensor_shape: The shape the weight tensor should take.
+ scaled_init: The scaled initialization for the weights tensor.
+ """
+ if self.learnable_decay:
+ self.scl_decay_weights = tf.get_variable("scl_decay_weights",
+ shape=weights_tensor_shape,
+ initializer=scaled_init)
+ scl_decay_bias_init = tf.constant_initializer(
+ FLAGS.scale_decay_bias_init)
+ self.scl_decay_bias = tf.get_variable("scl_decay_bias",
+ shape=(1,),
+ initializer=scl_decay_bias_init)
+ else:
+ self.scl_decay_weights = tf.zeros_like(self.update_weights)
+ self.scl_decay_bias = tf.log(0.93 / (1. - 0.93))
+
+ def _initialize_input_decay(self, weights_tensor_shape, scaled_init):
+ """Initializes the input scale decay weights and bias variables or tensors.
+
+ Args:
+ weights_tensor_shape: The shape the weight tensor should take.
+ scaled_init: The scaled initialization for the weights tensor.
+ """
+ if (self.learnable_decay and self.num_gradient_scales > 1 and
+ self.learnable_inp_decay):
+ self.inp_decay_weights = tf.get_variable("inp_decay_weights",
+ shape=weights_tensor_shape,
+ initializer=scaled_init)
+ inp_decay_bias_init = tf.constant_initializer(
+ FLAGS.hrnn_default_decay_var_init)
+ self.inp_decay_bias = tf.get_variable("inp_decay_bias",
+ shape=(1,),
+ initializer=inp_decay_bias_init)
+ else:
+ self.inp_decay_weights = tf.zeros_like(self.update_weights)
+ self.inp_decay_bias = tf.log(0.89 / (1. - 0.89))
+
+ def _initialize_lr(self, weights_tensor_shape, scaled_init):
+ """Initializes the learning rate weights and bias variables or tensors.
+
+ Args:
+ weights_tensor_shape: The shape the weight tensor should take.
+ scaled_init: The scaled initialization for the weights tensor.
+ """
+ if self.dynamic_output_scale:
+ zero_init = tf.constant_initializer(0.)
+ wt_init = zero_init if self.zero_init_lr_weights else scaled_init
+ self.lr_weights = tf.get_variable("learning_rate_weights",
+ shape=weights_tensor_shape,
+ initializer=wt_init)
+ self.lr_bias = tf.get_variable("learning_rate_bias", shape=(1,),
+ initializer=zero_init)
+ else:
+ self.lr_weights = tf.zeros_like(self.update_weights)
+ self.lr_bias = tf.zeros([1, 1])
+
+ def _initialize_state(self, var):
+ """Return a dictionary mapping names of state variables to their values."""
+ var_vectorized = tf.reshape(var, [-1, 1])
+ ndim = var_vectorized.get_shape().as_list()[0]
+
+ state = {
+ # parameter init tensor is [var_ndim x layer0_cell_size]
+ "parameter": tf.ones([ndim, 1]) * self.init_vectors[0],
+ "scl_decay": tf.zeros_like(var_vectorized),
+ "inp_decay": tf.zeros_like(var_vectorized),
+ "true_param": var,
+ }
+
+ if self.num_layers > 1:
+ # layer init tensor is [1 x layer1_cell_size]
+ state["layer"] = tf.ones([1, 1]) * self.init_vectors[1]
+
+ if self.dynamic_output_scale:
+ min_lr = self.init_lr_range[0]
+ max_lr = self.init_lr_range[1]
+ if min_lr == max_lr:
+ log_init_lr = tf.log(min_lr * tf.ones_like(var_vectorized))
+ else:
+ # Use a random offset to increase the likelihood that the average of the
+ # LRs for this variable is different from the LRs for other variables.
+ actual_vals = tf.random_uniform(var_vectorized.get_shape().as_list(),
+ np.log(min_lr) / 2.,
+ np.log(max_lr) / 2.,
+ seed=self.random_seed)
+ offset = tf.random_uniform((), np.log(min_lr) / 2., np.log(max_lr) / 2.,
+ seed=self.random_seed)
+ log_init_lr = actual_vals + offset
+ # Clip the log learning rate to the flag at the top end, and to
+ # (log(min int32) - 1) at the bottom
+ clipped = tf.clip_by_value(log_init_lr, -33, self.max_log_lr)
+ state["log_learning_rate"] = clipped
+
+ for i in range(self.num_gradient_scales):
+ state["grad_accum{}".format(i + 1)] = tf.zeros_like(var_vectorized)
+ state["ms{}".format(i + 1)] = tf.zeros_like(var_vectorized)
+
+ return state
+
+ def _initialize_global_state(self):
+ if self.num_layers < 3:
+ return []
+ rnn_global_init = tf.ones([1, 1]) * self.init_vectors[2]
+ return [rnn_global_init]
+
+ def _compute_updates(self, params, grads, states, global_state):
+ # Store the updated parameters and states.
+ updated_params = []
+ updated_attention = []
+ updated_states = []
+
+ with tf.variable_scope(opt.OPTIMIZER_SCOPE):
+
+ mean_log_lr = self._compute_mean_log_lr(states)
+
+ # Iterate over the layers.
+ for param, grad_unflat, state in zip(params, grads, states):
+
+ with tf.variable_scope("PerTensor", reuse=self.reuse_vars):
+ self.reuse_vars = True
+ grad = tf.reshape(grad_unflat, [-1, 1])
+
+ # Create the RNN input. We will optionally extend it with additional
+ # features such as curvature and gradient signal vs. noise.
+ (grads_scaled, mean_squared_gradients,
+ grads_accum) = self._compute_scaled_and_ms_grads(grad, state)
+ rnn_input = [g for g in grads_scaled]
+
+ self._extend_rnn_input(rnn_input, state, grads_scaled,
+ mean_squared_gradients, mean_log_lr)
+
+ # Concatenate any features we've collected.
+ rnn_input_tensor = tf.concat(rnn_input, 1)
+
+ layer_state, new_param_state = self._update_rnn_cells(
+ state, global_state, rnn_input_tensor,
+ len(rnn_input) != len(grads_scaled))
+
+ (scl_decay, inp_decay, new_log_lr, update_step, lr_attend,
+ attention_delta) = self._compute_rnn_state_projections(
+ state, new_param_state, grads_scaled)
+
+ # Apply updates and store state variables.
+ if self.use_attention:
+ truth = state["true_param"]
+ updated_param = truth - update_step
+ attention_step = tf.reshape(lr_attend * attention_delta,
+ truth.get_shape())
+ updated_attention.append(truth - attention_step)
+ else:
+ updated_param = param - update_step
+ updated_attention.append(updated_param)
+ updated_params.append(updated_param)
+
+ # Collect the new state.
+ new_state = {
+ "parameter": new_param_state,
+ "scl_decay": scl_decay,
+ "inp_decay": inp_decay,
+ "true_param": updated_param,
+ }
+ if layer_state is not None:
+ new_state["layer"] = layer_state
+
+ if self.dynamic_output_scale:
+ new_state["log_learning_rate"] = new_log_lr
+
+ for i in range(self.num_gradient_scales):
+ new_state["grad_accum{}".format(i + 1)] = grads_accum[i]
+ new_state["ms{}".format(i + 1)] = mean_squared_gradients[i]
+ updated_states.append(new_state)
+
+ updated_global_state = self._compute_updated_global_state([layer_state],
+ global_state)
+
+ return (updated_params, updated_states, [updated_global_state],
+ updated_attention)
+
+ def _compute_mean_log_lr(self, states):
+ """Computes the mean log learning rate across all variables."""
+ if self.use_problem_lr_mean and self.use_relative_lr:
+
+ sum_log_lr = 0.
+ count_log_lr = 0.
+ for state in states:
+ sum_log_lr += tf.reduce_sum(state["log_learning_rate"])
+ # Note: get_shape().num_elements() is the number of elements in the original tensor.
+ count_log_lr += state["log_learning_rate"].get_shape().num_elements()
+ return sum_log_lr / count_log_lr
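+ # (If either flag is off this falls through and returns None, which is safe
+ # because the value is only consumed when both flags are set.)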
+
+ def _compute_scaled_and_ms_grads(self, grad, state):
+ """Computes the scaled gradient and the mean squared gradients.
+
+ Gradients are also accumulated across different timescales if appropriate.
+
+ Args:
+ grad: The gradient tensor for this layer.
+ state: The optimizer state for this layer.
+
+ Returns:
+ The scaled gradients, mean squared gradients, and accumulated gradients.
+ """
+ input_decays = [state["inp_decay"]]
+ scale_decays = [state["scl_decay"]]
+ if self.use_multiple_scale_decays and self.num_gradient_scales > 1:
+ for i in range(self.num_gradient_scales - 1):
+ scale_decays.append(tf.sqrt(scale_decays[i]))
+
+ for i in range(self.num_gradient_scales - 1):
+ # Each accumulator on twice the timescale of the one before.
+ input_decays.append(tf.sqrt(input_decays[i]))
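+ # (For a decay d close to 1, 1 - sqrt(d) is roughly (1 - d) / 2, so taking
+ # the square root roughly doubles the averaging window 1 / (1 - d).)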
+ grads_accum = []
+ grads_scaled = []
+ mean_squared_gradients = []
+
+ # populate the scaled gradients and associated mean_squared values
+ if self.num_gradient_scales > 0:
+ for i, decay in enumerate(input_decays):
+ if self.num_gradient_scales == 1:
+ # We don't accumulate if no scales, just take the current gradient.
+ grad_accum = grad
+ else:
+ # The state vars are 1-indexed.
+ old_accum = state["grad_accum{}".format(i + 1)]
+ grad_accum = grad * (1. - decay) + old_accum * decay
+
+ grads_accum.append(grad_accum)
+
+ sd = scale_decays[i if self.use_multiple_scale_decays else 0]
+ grad_scaled, ms = utils.rms_scaling(grad_accum, sd,
+ state["ms{}".format(i + 1)],
+ update_ms=True)
+ grads_scaled.append(grad_scaled)
+ mean_squared_gradients.append(ms)
+
+ return grads_scaled, mean_squared_gradients, grads_accum
+
+ def _extend_rnn_input(self, rnn_input, state, grads_scaled,
+ mean_squared_gradients, mean_log_lr):
+ """Computes additional rnn inputs and adds them to the rnn_input list."""
+ if self.num_gradient_scales > 1 and self.use_grad_products:
+ # This gives a measure of curvature relative to input averaging
+ # lengthscale and to the learning rate
+ grad_products = [a * b for a, b in
+ zip(grads_scaled[:-1], grads_scaled[1:])]
+ rnn_input.extend([g for g in grad_products])
+
+ if self.use_log_means_squared:
+ log_means_squared = [tf.log(ms + 1e-16)
+ for ms in mean_squared_gradients]
+
+ avg = tf.reduce_mean(log_means_squared, axis=0)
+ # This gives a measure of the signal vs. noise contribution to the
+ # gradient, at the current averaging lengthscale. If all the noise
+ # is averaged out, and if updates are small, these will be 0.
+ mean_log_means_squared = [m - avg for m in log_means_squared]
+
+ rnn_input.extend([m for m in mean_log_means_squared])
+
+ if self.use_relative_lr or self.use_extreme_indicator:
+ if not self.dynamic_output_scale:
+ raise Exception("Relative LR and Extreme Indicator features "
+ "require dynamic_output_scale to be set to True.")
+ log_lr_vec = tf.reshape(state["log_learning_rate"], [-1, 1])
+ if self.use_relative_lr:
+ if self.use_problem_lr_mean:
+ # Learning rate of this dimension vs. rest of target problem.
+ relative_lr = log_lr_vec - mean_log_lr
+ else:
+ # Learning rate of this dimension vs. rest of tensor.
+ relative_lr = log_lr_vec - tf.reduce_mean(log_lr_vec)
+ rnn_input.append(relative_lr)
+ if self.use_extreme_indicator:
+ # Indicator of extremely large or extremely small learning rate.
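+ # It is positive when the learning rate exceeds 1, negative when it drops
+ # below 1e-6, and exactly zero anywhere in between.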
+ extreme_indicator = (tf.nn.relu(log_lr_vec - tf.log(1.)) -
+ tf.nn.relu(tf.log(1e-6) - log_lr_vec))
+ rnn_input.append(extreme_indicator)
+
+ if self.use_lr_shortcut:
+ log_lr_vec = tf.reshape(state["log_learning_rate"], [-1, 1])
+ rnn_input.append(log_lr_vec - tf.log(1e-3))
+
+ def _update_rnn_cells(self, state, global_state, rnn_input_tensor,
+ use_additional_features):
+ """Updates the component RNN cells with the given state and tensor.
+
+ Args:
+ state: The current state of the optimizer.
+ global_state: The current global RNN state.
+ rnn_input_tensor: The input tensor to the RNN.
+ use_additional_features: Whether the rnn input tensor contains additional
+ features beyond the scaled gradients (affects whether the rnn input
+ tensor is used as input to the RNN.)
+
+ Returns:
+ layer_state: The new state of the per-tensor RNN.
+ new_param_state: The new state of the per-parameter RNN.
+ """
+ # lowest level (per parameter)
+ # input -> gradient for this parameter
+ # bias -> output from the layer RNN
+ with tf.variable_scope("Layer0_RNN"):
+ total_bias = None
+ if self.num_layers > 1:
+ sz = 3 * self.cells[0].state_size # size of the concatenated bias
+ param_bias = utils.affine([state["layer"]], sz,
+ scope="Param/Affine",
+ scale=FLAGS.hrnn_affine_scale,
+ random_seed=self.random_seed)
+ total_bias = param_bias
+ if self.num_layers == 3:
+ global_bias = utils.affine(global_state, sz,
+ scope="Global/Affine",
+ scale=FLAGS.hrnn_affine_scale,
+ random_seed=self.random_seed)
+ total_bias += global_bias
+
+ new_param_state, _ = self.cells[0](
+ rnn_input_tensor, state["parameter"], bias=total_bias)
+
+ if self.num_layers > 1:
+ # middle level (per layer)
+ # input -> average hidden state from each parameter in this layer
+ # bias -> output from the RNN at the global level
+ with tf.variable_scope("Layer1_RNN"):
+ if not use_additional_features:
+ # Restore old behavior and only add the mean of the new params.
+ layer_input = tf.reduce_mean(new_param_state, 0, keep_dims=True)
+ else:
+ layer_input = tf.reduce_mean(
+ tf.concat((new_param_state, rnn_input_tensor), 1), 0,
+ keep_dims=True)
+ if self.num_layers == 3:
+ sz = 3 * self.cells[1].state_size
+ layer_bias = utils.affine(global_state, sz,
+ scale=FLAGS.hrnn_affine_scale,
+ random_seed=self.random_seed)
+ layer_state, _ = self.cells[1](
+ layer_input, state["layer"], bias=layer_bias)
+ else:
+ layer_state, _ = self.cells[1](layer_input, state["layer"])
+ else:
+ layer_state = None
+
+ return layer_state, new_param_state
+
+ def _compute_rnn_state_projections(self, state, new_param_state,
+ grads_scaled):
+ """Computes the RNN state-based updates to parameters and update steps."""
+ # Compute the update direction (a linear projection of the RNN output).
+ update_weights = self.update_weights
+
+ update_delta = utils.project(new_param_state, update_weights)
+ if self.use_gradient_shortcut:
+ # Include an affine projection of just the direction of the gradient
+ # so that RNN hidden states are freed up to store more complex
+ # functions of the gradient and other parameters.
+ grads_scaled_tensor = tf.concat([g for g in grads_scaled], 1)
+ update_delta += utils.affine(grads_scaled_tensor, 1,
+ scope="GradsToDelta",
+ include_bias=False,
+ vec_mean=1. / len(grads_scaled),
+ random_seed=self.random_seed)
+ if self.dynamic_output_scale:
+ denom = tf.sqrt(tf.reduce_mean(update_delta ** 2) + 1e-16)
+
+ update_delta /= denom
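+ # Normalizing the delta to (approximately) unit RMS means the step size is
+ # controlled entirely by the learned per-parameter learning rate.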
+
+ if self.use_attention:
+ attention_weights = self.attention_weights
+ attention_delta = utils.project(new_param_state,
+ attention_weights)
+ if self.use_gradient_shortcut:
+ attention_delta += utils.affine(grads_scaled_tensor, 1,
+ scope="GradsToAttnDelta",
+ include_bias=False,
+ vec_mean=1. / len(grads_scaled),
+ random_seed=self.random_seed)
+ if self.dynamic_output_scale:
+ attention_delta /= tf.sqrt(
+ tf.reduce_mean(attention_delta ** 2) + 1e-16)
+ else:
+ attention_delta = None
+
+ # The updated decay is an affine projection of the hidden state.
+ scl_decay = utils.project(new_param_state, self.scl_decay_weights,
+ bias=self.scl_decay_bias,
+ activation=tf.nn.sigmoid)
+ # This is only used if learnable_decay and num_gradient_scales > 1
+ inp_decay = utils.project(new_param_state, self.inp_decay_weights,
+ bias=self.inp_decay_bias,
+ activation=tf.nn.sigmoid)
+
+ # Also update the learning rate.
+ lr_param, lr_attend, new_log_lr = self._compute_new_learning_rate(
+ state, new_param_state)
+
+ update_step = tf.reshape(lr_param * update_delta,
+ state["true_param"].get_shape())
+
+ return (scl_decay, inp_decay, new_log_lr, update_step, lr_attend,
+ attention_delta)
+
+ def _compute_new_learning_rate(self, state, new_param_state):
+ if self.dynamic_output_scale:
+ # Compute the change in learning rate (an affine projection of the
+ # RNN state, passed through a sigmoid or log depending on flags).
+ # Update the learning rate, w/ momentum.
+ lr_change = utils.project(new_param_state, self.lr_weights,
+ bias=self.lr_bias)
+ step_log_lr = state["log_learning_rate"] + lr_change
+
+ # Clip the log learning rate to the flag at the top end, and to
+ # (log(min int32) - 1) at the bottom
+
+ # Check out this hack: we want to be able to compute the gradient
+ # of the downstream result w.r.t lr weights and bias, even if the
+ # value of step_log_lr is outside the clip range. So we clip,
+ # subtract off step_log_lr, and wrap all that in a stop_gradient so
+ # TF never tries to take the gradient of the clip... or the
+ # subtraction. Then we add BACK step_log_lr so that downstream still
+ # receives the clipped value. But the GRADIENT of step_log_lr will
+ # be the gradient of the unclipped value, which we added back in
+ # after stop_gradients.
+ step_log_lr += tf.stop_gradient(
+ tf.clip_by_value(step_log_lr, -33, self.max_log_lr)
+ - step_log_lr)
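+ # In other words, this is the straight-through pattern
+ #   y = x + stop_gradient(clip_by_value(x, lo, hi) - x)
+ # whose forward value is the clipped x while dy/dx == 1 everywhere.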
+
+ lr_momentum_logit = tf.get_variable(
+ "learning_rate_momentum_logit",
+ initializer=FLAGS.learning_rate_momentum_logit_init)
+ lrm = tf.nn.sigmoid(lr_momentum_logit)
+ new_log_lr = (lrm * state["log_learning_rate"] +
+ (1. - lrm) * step_log_lr)
+ param_stepsize_offset = tf.get_variable("param_stepsize_offset",
+ initializer=-1.)
+ lr_param = tf.exp(step_log_lr + param_stepsize_offset)
+ lr_attend = tf.exp(step_log_lr) if self.use_attention else lr_param
+ else:
+ # Dynamic output scale is off, LR param is always 1.
+ lr_param = 2. * utils.project(new_param_state, self.lr_weights,
+ bias=self.lr_bias,
+ activation=tf.nn.sigmoid)
+ new_log_lr = None
+ lr_attend = lr_param
+
+ return lr_param, lr_attend, new_log_lr
+
+ def _compute_updated_global_state(self, layer_states, global_state):
+ """Computes the new global state gives the layers states and old state.
+
+ Args:
+ layer_states: The current layer states.
+ global_state: The old global state.
+
+ Returns:
+ The updated global state.
+ """
+ updated_global_state = []
+ if self.num_layers == 3:
+ # highest (global) layer
+ # input -> average hidden state from each layer-specific RNN
+ # bias -> None
+ with tf.variable_scope("Layer2_RNN", reuse=self.reuse_global_state):
+ self.reuse_global_state = True
+ global_input = tf.reduce_mean(tf.concat(layer_states, 0), 0,
+ keep_dims=True)
+ updated_global_state, _ = self.cells[2](global_input, global_state[0])
+ return updated_global_state
+
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+ """Overwrites the tf.train.Optimizer interface for applying gradients."""
+
+ # Pull out the variables.
+ grads_and_vars = tuple(grads_and_vars) # Make sure repeat iteration works.
+ for g, v in grads_and_vars:
+ if not isinstance(g, (tf.Tensor, tf.IndexedSlices, type(None))):
+ raise TypeError(
+ "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
+ if not isinstance(v, tf.Variable):
+ raise TypeError(
+ "Variable must be a tf.Variable: %s" % v)
+ if g is not None:
+ self._assert_valid_dtypes([g, v])
+ var_list = [v for g, v in grads_and_vars if g is not None]
+ if not var_list:
+ raise ValueError("No gradients provided for any variable: %s" %
+ (grads_and_vars,))
+
+ # Create slots for the variables.
+ with tf.control_dependencies(None):
+ self._create_slots(var_list)
+
+ # Store update ops in this list.
+ with tf.op_scope([], name, self._name) as name:
+
+ # Prepare the global state.
+ with tf.variable_scope(self._name, reuse=self.reuse_global_state):
+ gs = self._initialize_global_state()
+ if gs:
+ global_state = [tf.get_variable("global_state", initializer=gs[0])]
+ else:
+ global_state = []
+
+ # Get the states for each variable in the list.
+ states = [{key: self.get_slot(var, key) for key in self.get_slot_names()}
+ for var in var_list]
+
+ # Compute updated values.
+ grads, params = zip(*grads_and_vars)
+ args = (params, grads, states, global_state)
+ updates = self._compute_updates(*args)
+ new_params, new_states, new_global_state, new_attention = updates
+ # Assign op for new global state.
+ update_ops = [tf.assign(gs, ngs)
+ for gs, ngs in zip(global_state, new_global_state)]
+
+ # Create the assign ops for the params and state variables.
+ args = (params, states, new_params, new_attention, new_states)
+ for var, state, new_var, new_var_attend, new_state in zip(*args):
+ # Assign updates to the state variables.
+ state_assign_ops = [tf.assign(state_var, new_state[key])
+ for key, state_var in state.items()]
+
+ # Update the parameter.
+ with tf.control_dependencies(state_assign_ops):
+ if self.use_attention:
+ # Assign to the attended location, rather than the actual location
+ # so that the gradients are computed where attention is.
+ param_update_op = var.assign(new_var_attend)
+ else:
+ param_update_op = var.assign(new_var)
+
+ with tf.name_scope("update_" + var.op.name): #, tf.colocate_with(var):
+ update_ops.append(param_update_op)
+
+ real_params = [self.get_slot(var, "true_param") for var in var_list]
+
+ if global_step is None:
+ # NOTE: if using the optimizer in a non-test-optimizer setting (e.g.
+ # on Inception), remove the real_params return value. Otherwise
+ # the code will throw an error.
+ return self._finish(update_ops, name), real_params
+ else:
+ with tf.control_dependencies([self._finish(update_ops, "update")]):
+ return state_ops.assign_add(global_step, 1, name=name).op, real_params
diff --git a/models/research/learned_optimizer/optimizer/learning_rate_schedule.py b/models/research/learned_optimizer/optimizer/learning_rate_schedule.py
new file mode 100644
index 0000000000000000000000000000000000000000..53db8addd3d152bfa02630ec6e37f0cc1776abc8
--- /dev/null
+++ b/models/research/learned_optimizer/optimizer/learning_rate_schedule.py
@@ -0,0 +1,60 @@
+# Copyright 2017 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A trainable optimizer that learns a learning rate schedule."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from learned_optimizer.optimizer import trainable_optimizer
+
+
+class LearningRateSchedule(trainable_optimizer.TrainableOptimizer):
+ """Learns a learning rate schedule over a fixed number of iterations."""
+
+ def __init__(self, initial_rate=0.0, n_steps=1000, **kwargs):
+ """Initializes the learning rates."""
+ self.max_index = tf.constant(n_steps-1, dtype=tf.int32)
+
+ with tf.variable_scope(trainable_optimizer.OPTIMIZER_SCOPE):
+ initializer = tf.constant_initializer(initial_rate)
+ self.learning_rates = tf.get_variable("learning_rates",
+ shape=([n_steps,]),
+ initializer=initializer)
+
+ super(LearningRateSchedule, self).__init__("LRS", ["itr"], **kwargs)
+
+ def _initialize_state(self, var):
+ """Return a dictionary mapping names of state variables to their values."""
+ return {
+ "itr": tf.constant(0, dtype=tf.int32),
+ }
+
+ def _compute_update(self, param, grad, state):
+ """Compute updates of parameters."""
+
+ # get the learning rate at the current index, if the index
+ # is greater than the number of available learning rates,
+ # use the last one
+ index = tf.minimum(state["itr"], self.max_index)
+ learning_rate = tf.gather(self.learning_rates, index)
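+ # e.g. with n_steps=3, iterations 0-2 use learning_rates[0..2] and every
+ # later iteration keeps reusing learning_rates[2].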
+
+ # update the parameters: parameter - learning_rate * gradient
+ updated_param = param - tf.scalar_mul(learning_rate, grad)
+
+ return updated_param, {"itr": state["itr"] + 1}
diff --git a/models/research/learned_optimizer/optimizer/rnn_cells.py b/models/research/learned_optimizer/optimizer/rnn_cells.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d68de04ca5318bb0f264d4f4647ddbc6fbe08e0
--- /dev/null
+++ b/models/research/learned_optimizer/optimizer/rnn_cells.py
@@ -0,0 +1,68 @@
+# Copyright 2017 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Custom RNN cells for hierarchical RNNs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from learned_optimizer.optimizer import utils
+
+
+class BiasGRUCell(tf.contrib.rnn.RNNCell):
+ """GRU cell (cf. http://arxiv.org/abs/1406.1078) with an additional bias."""
+
+ def __init__(self, num_units, activation=tf.tanh, scale=0.1,
+ gate_bias_init=0., random_seed=None):
+ self._num_units = num_units
+ self._activation = activation
+ self._scale = scale
+ self._gate_bias_init = gate_bias_init
+ self._random_seed = random_seed
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ def __call__(self, inputs, state, bias=None):
+ # Split the injected bias vector into a bias for the r, u, and c updates.
+ if bias is None:
+ bias = tf.zeros((1, 3))
+
+ r_bias, u_bias, c_bias = tf.split(bias, 3, 1)
+
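+ # What follows is a standard GRU step with the injected biases folded in
+ # (activation defaults to tanh):
+ #   r = sigmoid(affine([x, h]) + r_bias)
+ #   u = sigmoid(affine([x, h]) + u_bias)
+ #   c = activation(affine([x, r * h]) + c_bias)
+ #   new_h = u * h + (1 - u) * c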
+ with tf.variable_scope(type(self).__name__): # "BiasGRUCell"
+ with tf.variable_scope("gates"): # Reset gate and update gate.
+ proj = utils.affine([inputs, state], 2 * self._num_units,
+ scale=self._scale, bias_init=self._gate_bias_init,
+ random_seed=self._random_seed)
+ r_lin, u_lin = tf.split(proj, 2, 1)
+ r, u = tf.nn.sigmoid(r_lin + r_bias), tf.nn.sigmoid(u_lin + u_bias)
+
+ with tf.variable_scope("candidate"):
+ proj = utils.affine([inputs, r * state], self._num_units,
+ scale=self._scale, random_seed=self._random_seed)
+ c = self._activation(proj + c_bias)
+
+ new_h = u * state + (1 - u) * c
+
+ return new_h, new_h
diff --git a/models/research/learned_optimizer/optimizer/trainable_adam.py b/models/research/learned_optimizer/optimizer/trainable_adam.py
new file mode 100644
index 0000000000000000000000000000000000000000..638217f1b723da8633dc7a82623392eaaf190829
--- /dev/null
+++ b/models/research/learned_optimizer/optimizer/trainable_adam.py
@@ -0,0 +1,210 @@
+# Copyright 2017 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A trainable ADAM optimizer that learns its internal variables."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from learned_optimizer.optimizer import trainable_optimizer as opt
+from learned_optimizer.optimizer import utils
+
+
+class TrainableAdam(opt.TrainableOptimizer):
+ """Adam optimizer with learnable scalar parameters.
+
+ See Kingma et. al., 2014 for algorithm (http://arxiv.org/abs/1412.6980).
+ """
+
+ def __init__(self,
+ learning_rate=1e-3,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=1e-8,
+ **kwargs):
+ """Initializes the TrainableAdam optimizer with the given initial values.
+
+ Args:
+ learning_rate: The learning rate (default: 1e-3).
+ beta1: The exponential decay rate for the 1st moment estimates.
+ beta2: The exponential decay rate for the 2nd moment estimates.
+ epsilon: A small constant for numerical stability.
+ **kwargs: Any additional keyword arguments for TrainableOptimizer.
+
+ Raises:
+ ValueError: if the learning rate or epsilon is not positive
+ ValueError: if beta1 or beta2 is not in (0, 1).
+ """
+ if learning_rate <= 0:
+ raise ValueError("Learning rate must be positive.")
+ if epsilon <= 0:
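+# Minimal usage sketch (illustrative only; assumes a TF 1.x graph/session and
+# an existing scalar `loss` tensor):
+#
+#   opt = GlobalLearningRate(initial_rate=1e-2)
+#   train_op = opt.minimize(loss)
+#   with tf.Session() as sess:
+#     sess.run(tf.global_variables_initializer())
+#     sess.run(train_op)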
+ raise ValueError("Epsilon must be positive.")
+ if not 0 < beta1 < 1 or not 0 < beta2 < 1:
+ raise ValueError("Beta values must be between 0 and 1, exclusive.")
+
+ self._reuse_vars = False
+
+ with tf.variable_scope(opt.OPTIMIZER_SCOPE):
+ def inv_sigmoid(x):
+ return np.log(x / (1.0 - x))
+
+ self.log_learning_rate = tf.get_variable(
+ "log_learning_rate",
+ shape=[],
+ initializer=tf.constant_initializer(np.log(learning_rate)))
+ self.beta1_logit = tf.get_variable(
+ "beta1_logit",
+ shape=[],
+ initializer=tf.constant_initializer(inv_sigmoid(beta1)))
+ self.beta2_logit = tf.get_variable(
+ "beta2_logit",
+ shape=[],
+ initializer=tf.constant_initializer(inv_sigmoid(beta2)))
+ self.log_epsilon = tf.get_variable(
+ "log_epsilon",
+ shape=[],
+ initializer=tf.constant_initializer(np.log(epsilon)))
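+ # The learnable scalars are stored in unconstrained form (logs for the
+ # learning rate and epsilon, logits for the betas) so gradient updates keep
+ # their effective values in valid ranges.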
+
+ # Key names are derived from Algorithm 1 described in
+ # https://arxiv.org/pdf/1412.6980.pdf
+ state_keys = ["m", "v", "t"]
+ super(TrainableAdam, self).__init__("Adam", state_keys, **kwargs)
+
+ def _initialize_state(self, var):
+ """Returns a dictionary mapping names of state variables to their values."""
+ vectorized_shape = var.get_shape().num_elements(), 1
+
+ return {key: tf.zeros(vectorized_shape) for key in self.state_keys}
+
+ def _compute_update(self, param, grad, state):
+ """Calculates the new internal state and parameters.
+
+ If the gradient is sparse, updates the appropriate slices in the internal
+ state and stacks the update tensor.
+
+ Args:
+ param: A tensor of parameters.
+ grad: A tensor of gradients with the same shape as param.
+ state: A dictionary containing any state for the optimizer.
+
+ Returns:
+ updated_param: The updated parameters.
+ updated_state: The updated state variables in a dictionary.
+ """
+
+ with tf.variable_scope(opt.OPTIMIZER_SCOPE) as scope:
+
+ if self._reuse_vars:
+ scope.reuse_variables()
+ else:
+ self._reuse_vars = True
+
+ (grad_values, first_moment, second_moment, timestep, grad_indices
+ ) = self._extract_gradients_and_internal_state(
+ grad, state, tf.shape(param))
+
+ beta1 = tf.nn.sigmoid(self.beta1_logit)
+ beta2 = tf.nn.sigmoid(self.beta2_logit)
+ epsilon = tf.exp(self.log_epsilon) + 1e-10
+ learning_rate = tf.exp(self.log_learning_rate)
+
+ old_grad_shape = tf.shape(grad_values)
+ grad_values = tf.reshape(grad_values, [-1, 1])
+
+ new_timestep = timestep + 1
+ new_first_moment = self._update_adam_estimate(
+ first_moment, grad_values, beta1)
+ new_second_moment = self._update_adam_estimate(
+ second_moment, tf.square(grad_values), beta2)
+
+ debiased_first_moment = self._debias_adam_estimate(
+ new_first_moment, beta1, new_timestep)
+ debiased_second_moment = self._debias_adam_estimate(
+ new_second_moment, beta2, new_timestep)
+
+ # Propagating through the square root of 0 is very bad for stability.
+ update = (learning_rate * debiased_first_moment /
+ (tf.sqrt(debiased_second_moment + 1e-10) + epsilon))
+
+ update = tf.reshape(update, old_grad_shape)
+
+ if grad_indices is not None:
+ param_shape = tf.shape(param)
+ update = utils.stack_tensor(
+ update, grad_indices, param, param_shape[:1])
+ new_first_moment = utils.update_slices(
+ new_first_moment, grad_indices, state["m"], param_shape)
+ new_second_moment = utils.update_slices(
+ new_second_moment, grad_indices, state["v"], param_shape)
+ new_timestep = utils.update_slices(
+ new_timestep, grad_indices, state["t"], param_shape)
+
+ new_param = param - update
+
+ # collect the update and new state
+ new_state = {
+ "m": new_first_moment,
+ "v": new_second_moment,
+ "t": new_timestep
+ }
+
+ return new_param, new_state
+
+ def _update_adam_estimate(self, estimate, value, beta):
+ """Returns a beta-weighted average of estimate and value."""
+ return (beta * estimate) + ((1 - beta) * value)
+
+ def _debias_adam_estimate(self, estimate, beta, t_step):
+ """Returns a debiased estimate based on beta and the timestep."""
+ return estimate / (1 - tf.pow(beta, t_step))
+
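+ # For reference, with g = grad these helpers realize the usual Adam step
+ # (Kingma & Ba, 2014):
+ #   m_t = beta1 * m_{t-1} + (1 - beta1) * g
+ #   v_t = beta2 * v_{t-1} + (1 - beta2) * g**2
+ #   m_hat = m_t / (1 - beta1**t),  v_hat = v_t / (1 - beta2**t)
+ #   param = param - lr * m_hat / (sqrt(v_hat) + epsilon)
+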
+ def _extract_gradients_and_internal_state(self, grad, state, param_shape):
+ """Extracts the gradients and relevant internal state.
+
+ If the gradient is sparse, extracts the appropriate slices from the state.
+
+ Args:
+ grad: The current gradient.
+ state: The current state.
+ param_shape: The shape of the parameter (used if gradient is sparse).
+
+ Returns:
+ grad_values: The gradient value tensor.
+ first_moment: The first moment tensor (internal state).
+ second_moment: The second moment tensor (internal state).
+ timestep: The current timestep (internal state).
+ grad_indices: The indices for the gradient tensor, if sparse.
+ None otherwise.
+ """
+ grad_values = grad
+ grad_indices = None
+ first_moment = state["m"]
+ second_moment = state["v"]
+ timestep = state["t"]
+
+ if isinstance(grad, tf.IndexedSlices):
+ grad_indices, grad_values = utils.accumulate_sparse_gradients(grad)
+ first_moment = utils.slice_tensor(
+ first_moment, grad_indices, param_shape)
+ second_moment = utils.slice_tensor(
+ second_moment, grad_indices, param_shape)
+ timestep = utils.slice_tensor(timestep, grad_indices, param_shape)
+
+ return grad_values, first_moment, second_moment, timestep, grad_indices
+
diff --git a/models/research/learned_optimizer/optimizer/trainable_optimizer.py b/models/research/learned_optimizer/optimizer/trainable_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..955112a9dd1d3b0af5ae2f5f0fe8eff65d2dbfc7
--- /dev/null
+++ b/models/research/learned_optimizer/optimizer/trainable_optimizer.py
@@ -0,0 +1,574 @@
+# Copyright 2017 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A base class definition for trainable optimizers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import itertools
+
+import tensorflow as tf
+
+from tensorflow.python.framework import tensor_shape
+
+OPTIMIZER_SCOPE = "LOL"
+_LOCAL_VARIABLE_PREFIX = "local_state_"
+_LOCAL_STATE_VARIABLE_COLLECTION = "local_state_collection"
+EPSILON = 1e-6
+
+
+class TrainableOptimizer(tf.train.Optimizer):
+ """Base class for trainable optimizers.
+
+ A trainable optimizer is an optimizer that has parameters that can themselves
+ be learned (meta-optimized).
+
+ Subclasses must implement:
+ _compute_update(self, param, grad, state)
+ """
+
+ def __init__(self, name, state_keys, use_attention=False,
+ use_log_objective=False, obj_train_max_multiplier=-1,
+ use_second_derivatives=True, use_numerator_epsilon=False,
+ **kwargs):
+ """Initializes the optimizer with the given name and settings.
+
+ Args:
+ name: The name string for this optimizer.
+ state_keys: The names of any required state variables (list)
+ use_attention: Whether this optimizer uses attention (Default: False)
+ use_log_objective: Whether this optimizer uses the logarithm of the
+ objective when computing the loss (Default: False)
+ obj_train_max_multiplier: The maximum multiplier for the increase in the
+ objective before meta-training is stopped. If <= 0, meta-training is
+ not stopped early. (Default: -1)
+ use_second_derivatives: Whether this optimizer uses second derivatives in
+ meta-training. This should be set to False if some second derivatives
+ in the meta-training problem set are not defined in Tensorflow.
+ (Default: True)
+ use_numerator_epsilon: Whether to use epsilon in the numerator when
+ scaling the problem objective during meta-training. (Default: False)
+ **kwargs: Any additional keyword arguments.
+ """
+ self.use_second_derivatives = use_second_derivatives
+ self.state_keys = sorted(state_keys)
+ self.use_attention = use_attention
+ self.use_log_objective = use_log_objective
+ self.obj_train_max_multiplier = obj_train_max_multiplier
+ self.use_numerator_epsilon = use_numerator_epsilon
+
+ use_locking = False
+ super(TrainableOptimizer, self).__init__(use_locking, name)
+
+ def _create_slots(self, var_list):
+ """Creates all slots needed by the variables.
+
+ Args:
+ var_list: A list of `Variable` objects.
+ """
+ for var in var_list:
+ init_states = self._initialize_state(var)
+ for slot_name in sorted(init_states):
+ slot_var_name = "{}_{}".format(self.get_name(), slot_name)
+ value = init_states[slot_name]
+ self._get_or_make_slot(var, value, slot_name, slot_var_name)
+
+ def _initialize_state(self, var):
+ """Initializes any state required for this variable.
+
+ Args:
+ var: a tensor containing parameters to be optimized
+
+ Returns:
+ state: a dictionary mapping state keys to initial state values (tensors)
+ """
+ return {}
+
+ def _initialize_global_state(self):
+ """Initializes any global state values."""
+ return []
+
+ def _apply_common(self, grad, var):
+ """Applies the optimizer updates to the variables.
+
+ Note: this should only get called via _apply_dense or _apply_sparse when
+ using the optimizer via optimizer.minimize or optimizer.apply_gradients.
+ During meta-training, the optimizer.train function should be used to
+ construct an optimization path that is differentiable.
+
+ Args:
+ grad: A tensor representing the gradient.
+ var: A tf.Variable with the same shape as grad.
+
+ Returns:
+ update_op: A tensorflow op that assigns new values to the variable, and
+ also defines dependencies that update the state variables for the
+ optimizer.
+ """
+ state = {key: self.get_slot(var, key) for key in self.get_slot_names()}
+ new_var, new_state = self._compute_update(var, grad, state)
+ state_assign_ops = [tf.assign(state_var, new_state[key])
+ for key, state_var in state.items()]
+ with tf.control_dependencies(state_assign_ops):
+ update_op = var.assign(new_var)
+
+ return update_op
+
+ def _apply_dense(self, grad, var):
+ """Adds ops to apply dense gradients to 'var'."""
+ return self._apply_common(grad, var)
+
+ def _apply_sparse(self, grad, var):
+ """Adds ops to apply sparse gradients to 'var'."""
+ return self._apply_common(grad, var)
+
+ def _compute_update(self, param, grad, state):
+ """Computes the update step for optimization.
+
+ Args:
+ param: A tensor of parameters to optimize.
+ grad: The gradient tensor of the objective with respect to the parameters.
+ (It has the same shape as param.)
+ state: A dictionary containing any extra state required by the optimizer.
+
+ Returns:
+ updated_params: The updated parameters.
+ updated_state: The dictionary of updated state variable(s).
+ """
+ raise NotImplementedError
+
+ def _compute_updates(self, params, grads, states, global_state):
+ """Maps the compute update functions for each parameter.
+
+ This function can be overridden by a subclass if the subclass wants to
+ combine information across the different parameters in the list.
+
+ Args:
+ params: A list of parameter tensors.
+ grads: A list of gradients corresponding to each parameter.
+ states: A list of state variables corresponding to each parameter.
+ global_state: A list of global state variables for the problem.
+
+ Returns:
+ new_params: The updated parameters.
+ new_states: The updated states.
+ new_global_state: The updated global state.
+ attention_params: A list of attention parameters. This is the same as
+ new_params if the optimizer does not use attention.
+ """
+ # Zip up the arguments to _compute_update.
+ args = zip(params, grads, states)
+
+ # Call compute_update on each set of parameter/gradient/state args.
+ new_params, new_states = zip(*list(
+ itertools.starmap(self._compute_update, args)))
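+ # starmap calls self._compute_update(param, grad, state) once per variable;
+ # zip(*...) then separates the (new_param, new_state) pairs into two lists.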
+
+ # Global state is unused in the basic case, just pass it through.
+ return list(new_params), list(new_states), global_state, list(new_params)
+
+ def train(self, problem, dataset):
+ """Creates graph operations to train the optimizer.
+
+ Args:
+ problem: A problem_generator.Problem instance to train on.
+ dataset: A datasets.Dataset tuple to use when training.
+
+ Returns:
+ meta_objective: A tensorflow operation for computing the meta-objective
+ obj_weights: A tensor placeholder for feeding in the objective weights
+ obj_values: The subproblem objective values during optimization
+ batches: The batch indexes tensor for overriding with feed_dict
+ first_unroll: A placeholder signifying if this is a first unroll
+ (this will propagate the gradients slightly differently).
+ reset_state: A placeholder signifying that the rnn state should be reset.
+ output_state: The final state of the optimizer
+ init_loop_vars_to_override: Local variables that can be assigned to
+ propagate the optimizer and problem state for unrolling
+ final_loop_vals: Final values of the loop variables that can be
+ assigned to init_loop_vars_to_override.
+ """
+
+ # Placeholder for the objective weights
+ obj_weights = tf.placeholder(tf.float32)
+ num_iter = tf.shape(obj_weights)[0]
+
+ # Unpack the dataset and generate the minibatches for training
+ data, labels = dataset
+ # Convert the ndarrays to tensors so we can pass them back in via feed_dict
+ data = tf.constant(data)
+ labels = tf.constant(labels)
+ batches = tf.placeholder(tf.int32)
+ first_unroll = tf.placeholder_with_default(False, [])
+ reset_state = tf.placeholder_with_default(False, [])
+
+ training_output = collections.namedtuple("TrainingOutput",
+ ["metaobj",
+ "obj_weights",
+ "problem_objectives",
+ "initial_obj",
+ "batches",
+ "first_unroll",
+ "reset_state",
+ "output_state",
+ "init_loop_vars",
+ "output_loop_vars"])
+
+ def loop_body(itr, obj_accum, params, attend_params, flattened_states,
+ global_state, all_obj, unused_init_obj, data,
+ labels, batches):
+ """Body of the meta-training while loop for optimizing a sub-problem.
+
+ Args:
+ itr: The current meta-training iteration.
+ obj_accum: The accumulated objective over all training steps so far.
+ params: The parameters of the sub-problem.
+ attend_params: The parameters of the sub-problems at the attended
+ location.
+ flattened_states: The states of the trainable optimizer, sorted and
+ flattened into a list (since a while loop can't handle nested lists
+ or dictionaries).
+ global_state: The global state of the optimizer.
+ all_obj: The list of all objective values in the training process.
+ unused_init_obj: The initial objective (unused here, but needed in the
+ variable list because it's used in a stopping condition in the
+ loop_cond.)
+ data: The data for this problem.
+ labels: The labels corresponding to the data.
+ batches: The batch indexes needed for shuffled minibatch creation.
+
+ Returns:
+ itr: The updated meta-training iteration.
+ obj_accum: The updated accumulated objective.
+ params: The new parameters of the sub-problem.
+ attend_params: The new parameters of the sub-problems at the attended
+ location.
+ flattened_states: The new states of the trainable optimizer.
+ global_state: The updated global state.
+      all_obj: The updated list of all objective values.
+ unused_init_obj: The initial objective.
+ data: The data for this problem.
+ labels: The labels corresponding to the data.
+ batches: The batch indexes needed for shuffled minibatch creation.
+ """
+ batch_indices = tf.gather(batches, itr)
+ batch_data = tf.gather(data, batch_indices)
+ batch_labels = tf.gather(labels, batch_indices)
+
+ # Compute the objective over the entire dataset (full batch).
+ obj = problem.objective(params, data, labels)
+
+ # Compute the gradients on just the current batch
+ if self.use_attention:
+ current_obj = problem.objective(attend_params, batch_data, batch_labels)
+ grads = problem.gradients(current_obj, attend_params)
+ else:
+ current_obj = problem.objective(params, batch_data, batch_labels)
+ grads = problem.gradients(current_obj, params)
+
+ if not self.use_second_derivatives:
+ new_grads = []
+ for grad in grads:
+ if isinstance(grad, tf.IndexedSlices):
+ new_grads.append(
+ tf.IndexedSlices(tf.stop_gradient(grad.values), grad.indices))
+ else:
+ new_grads.append(tf.stop_gradient(grad))
+ grads = new_grads
+
+ # store the objective value for the entire problem at each iteration
+ all_obj = tf.concat([all_obj, tf.reshape(obj, (1,))], 0)
+
+ # accumulate the weighted objective for the entire dataset
+ acc = tf.gather(obj_weights, itr) * obj
+
+ obj_accum = tf.add(obj_accum, acc)
+ # Set the shape to keep the shape invariant for obj_accum. Without this,
+ # the graph builder thinks the tensor shape is unknown on the 2nd iter.
+ obj_accum.set_shape([])
+
+ # convert flattened_states to dictionaries
+ dict_states = [dict(zip(self.state_keys, flat_state))
+ for flat_state in flattened_states]
+
+ # compute the new parameters and states
+ args = (params, grads, dict_states, global_state)
+ updates = self._compute_updates(*args)
+ new_params, new_states, new_global_state, new_attend_params = updates
+
+      # Flatten the states, materializing the result as a list so the loop
+      # variables keep a concrete structure under Python 3.
+      new_flattened_states = list(map(flatten_and_sort, new_states))
+
+ return [itr + 1, obj_accum, new_params, new_attend_params,
+ new_flattened_states, new_global_state, all_obj, unused_init_obj,
+ data, labels, batches]
+
+ def loop_cond(itr, obj_accum, unused_params, unused_attend_params,
+ unused_flattened_states, unused_global_state, all_obj,
+ init_obj, *args):
+ """Termination conditions of the sub-problem optimization loop."""
+ del args # unused
+
+ cond1 = tf.less(itr, num_iter) # We've run < num_iter times
+ cond2 = tf.is_finite(obj_accum) # The objective is still finite
+
+ if self.obj_train_max_multiplier > 0:
+ current_obj = tf.gather(all_obj, itr)
+ # Account for negative init_obj too
+ max_diff = (self.obj_train_max_multiplier - 1) * tf.abs(init_obj)
+ max_obj = init_obj + max_diff
+ # The objective is a reasonable multiplier of the original objective
+ cond3 = tf.less(current_obj, max_obj)
+
+ return tf.logical_and(tf.logical_and(cond1, cond2), cond3,
+ name="training_loop_cond")
+ else:
+ return tf.logical_and(cond1, cond2, name="training_loop_cond")
+
+ init = self._initialize_training_loop_parameters(
+ problem, data, labels, batches, first_unroll, reset_state)
+ loop_vars, invariants, initial_obj, init_loop_vars_to_override = init
+
+ loop_output = tf.while_loop(loop_cond, loop_body, loop_vars,
+ swap_memory=True, shape_invariants=invariants)
+ meta_obj, problem_objectives = loop_output[1], loop_output[6]
+
+ # The meta objective is normalized by the initial objective at the start of
+ # the series of partial unrolls.
+ scaled_meta_objective = self.scale_objective(
+ meta_obj, problem_objectives, initial_obj)
+
+ final_loop_vals = (
+ [initial_obj] + loop_output[2] + loop_output[3] + loop_output[5])
+ final_loop_vals.extend(itertools.chain(*loop_output[4]))
+
+ return training_output(scaled_meta_objective,
+ obj_weights,
+ problem_objectives,
+ initial_obj,
+ batches,
+ first_unroll,
+ reset_state,
+ loop_output[4],
+ init_loop_vars_to_override,
+ final_loop_vals)
+
+ def _initialize_training_loop_parameters(
+ self, problem, data, labels, batches, first_unroll, reset_state):
+ """Initializes the vars and params needed for the training process.
+
+ Args:
+ problem: The problem being optimized.
+ data: The data for the problem.
+ labels: The corresponding labels for the data.
+ batches: The indexes needed to create shuffled batches of the data.
+ first_unroll: Whether this is the first unroll in a partial unrolling.
+ reset_state: Whether RNN state variables should be reset.
+
+ Returns:
+ loop_vars: The while loop variables for training.
+ invariants: The corresponding variable shapes (required by while loop).
+ initial_obj: The initial objective (used later for scaling).
+ init_loop_vars_to_override: The loop vars that can be overridden when
+ performing training via partial unrolls.
+ """
+ # Extract these separately so we don't have to make inter-variable
+ # dependencies.
+ initial_tensors = problem.init_tensors()
+
+ return_initial_tensor_values = first_unroll
+ initial_params_vars, initial_params = local_state_variables(
+ initial_tensors, return_initial_tensor_values)
+ initial_attend_params_vars, initial_attend_params = local_state_variables(
+ initial_tensors, return_initial_tensor_values)
+ # Recalculate the initial objective for the list on each partial unroll with
+ # the new initial_params. initial_obj holds the value from the very first
+ # unroll.
+ initial_obj_init = problem.objective(initial_params, data, labels)
+ return_initial_obj_init = first_unroll
+ [initial_obj_var], [initial_obj] = local_state_variables(
+ [initial_obj_init], return_initial_obj_init)
+
+ # Initialize the loop variables.
+ initial_itr = tf.constant(0, dtype=tf.int32)
+ initial_meta_obj = tf.constant(0, dtype=tf.float32)
+ # N.B. the use of initial_obj_init here rather than initial_obj
+ initial_problem_objectives = tf.reshape(initial_obj_init, (1,))
+
+ # Initialize the extra state.
+ initial_state_vars = []
+ initial_state = []
+ state_shapes = []
+ return_initial_state_values = reset_state
+ for param in initial_tensors:
+ param_state_vars, param_state = local_state_variables(
+ flatten_and_sort(self._initialize_state(param)),
+ return_initial_state_values)
+
+ initial_state_vars.append(param_state_vars)
+ initial_state.append(param_state)
+ state_shapes.append([f.get_shape() for f in param_state])
+
+ # Initialize any global (problem-level) state.
+ initial_global_state_vars, initial_global_state = local_state_variables(
+ self._initialize_global_state(), return_initial_state_values)
+
+ global_shapes = []
+ for item in initial_global_state:
+ global_shapes.append(item.get_shape())
+
+ # build the list of loop variables:
+ loop_vars = [
+ initial_itr,
+ initial_meta_obj,
+ initial_params, # Local variables.
+ initial_attend_params, # Local variables.
+ initial_state, # Local variables.
+ initial_global_state, # Local variables.
+ initial_problem_objectives,
+ initial_obj, # Local variable.
+ data,
+ labels,
+ batches,
+ ]
+
+ invariants = [
+ initial_itr.get_shape(),
+ initial_meta_obj.get_shape(),
+ [t.get_shape() for t in initial_params],
+ [t.get_shape() for t in initial_attend_params],
+ state_shapes,
+ global_shapes,
+ tensor_shape.TensorShape([None]), # The problem objectives list grows
+ initial_obj.get_shape(),
+ tensor_shape.unknown_shape(), # Placeholder shapes are unknown
+ tensor_shape.unknown_shape(),
+ tensor_shape.unknown_shape(),
+ ]
+
+ # Initialize local variables that we will override with final tensors at the
+ # next iter.
+ init_loop_vars_to_override = (
+ [initial_obj_var] + initial_params_vars + initial_attend_params_vars +
+ initial_global_state_vars)
+ init_loop_vars_to_override.extend(itertools.chain(*initial_state_vars))
+
+ return loop_vars, invariants, initial_obj, init_loop_vars_to_override
+
+ def scale_objective(self, total_obj, all_objs, initial_obj,
+ obj_scale_eps=1e-6):
+ """Normalizes the objective based on the initial objective value.
+
+ Args:
+ total_obj: The total accumulated objective over the training run.
+ all_objs: A list of all the individual objectives over the training run.
+ initial_obj: The initial objective value.
+ obj_scale_eps: The epsilon value to use in computations for stability.
+
+ Returns:
+ The scaled objective as a single value.
+ """
+ if self.use_log_objective:
+ if self.use_numerator_epsilon:
+ scaled_problem_obj = ((all_objs + obj_scale_eps) /
+ (initial_obj + obj_scale_eps))
+ log_scaled_problem_obj = tf.log(scaled_problem_obj)
+ else:
+ scaled_problem_obj = all_objs / (initial_obj + obj_scale_eps)
+ log_scaled_problem_obj = tf.log(scaled_problem_obj + obj_scale_eps)
+ return tf.reduce_mean(log_scaled_problem_obj)
+ else:
+ return total_obj / (initial_obj + obj_scale_eps)
+
+
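+# NOTE: illustrative sketch, not part of the original file. It shows the
+# (param, grad, state) -> (new_param, new_state) contract that _compute_update
+# above expects from subclasses, using a classical momentum rule. The name and
+# hyperparameters are hypothetical.
+def _example_momentum_compute_update(param, grad, state, learning_rate=0.01,
+                                     momentum=0.9):
+  """Applies one momentum step; `state` must contain a "momentum" slot."""
+  # A real subclass would return {"momentum": tf.zeros_like(param)} from
+  # _initialize_state so this slot exists before the first update.
+  new_momentum = momentum * state["momentum"] + grad
+  new_param = param - learning_rate * new_momentum
+  return new_param, {"momentum": new_momentum}
+
+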
+def local_state_variables(init_values, return_init_values):
+ """Create local variables initialized from init_values.
+
+ This will create local variables from a list of init_values. Each variable
+ will be named based on the value's shape and dtype.
+
+  As a convenience, a boolean tensor lets you choose whether the returned
+  values come from the created local variables or from the original init
+  values.
+
+ Args:
+ init_values: iterable of tensors
+ return_init_values: boolean tensor
+
+ Returns:
+ local_vars: list of the created local variables.
+ vals: if return_init_values is true, then this returns the values of
+ init_values. Otherwise it returns the values of the local_vars.
+ """
+ if not init_values:
+ return [], []
+
+ # This generates a harmless warning when saving the metagraph.
+ variable_use_count = tf.get_collection_ref(_LOCAL_STATE_VARIABLE_COLLECTION)
+ if not variable_use_count:
+ variable_use_count.append(collections.defaultdict(int))
+ variable_use_count = variable_use_count[0]
+
+ local_vars = []
+ with tf.variable_scope(OPTIMIZER_SCOPE):
+    # We can't use the init_value as an initializer because init_value may
+    # itself depend on some problem variables. This would create
+    # inter-variable initialization order dependencies, which TensorFlow
+    # does not handle gracefully.
+ for init_value in init_values:
+ name = create_local_state_variable_name(init_value)
+ unique_name = name + "_" + str(variable_use_count[name])
+ variable_use_count[name] += 1
+ # The overarching idea here is to be able to reuse variables between
+ # different sessions on the same TensorFlow master without errors. By
+ # uniquifying based on the type and name we mirror the checks made inside
+ # TensorFlow, while still allowing some memory reuse. Ultimately this is a
+ # hack due to the broken Session.reset().
+ local_vars.append(
+ tf.get_local_variable(
+ unique_name,
+ initializer=tf.zeros(
+ init_value.get_shape(), dtype=init_value.dtype)))
+
+ # It makes things a lot simpler if we use the init_value the first
+ # iteration, instead of the variable itself. It allows us to propagate
+ # gradients through it as well as simplifying initialization. The variable
+ # ends up assigned to after the first iteration.
+ vals = tf.cond(return_init_values, lambda: init_values, lambda: local_vars)
+ if len(init_values) == 1:
+ # tf.cond extracts elements from singleton lists.
+ vals = [vals]
+ return local_vars, vals
+
+
+def create_local_state_variable_name(tensor):
+ """Create a name of the variable based on its type and shape."""
+ if not tensor.get_shape().is_fully_defined():
+ raise ValueError("Need a fully specified shape to create a local variable.")
+
+ return (_LOCAL_VARIABLE_PREFIX + "_".join(
+ map(str, tensor.get_shape().as_list())) + "_" + tensor.dtype.name)
+
+
+def is_local_state_variable(op):
+ """Returns if this op is a local state variable created for training."""
+ return op.node_def.op in ["Variable", "VariableV2"] and op.name.startswith(
+ OPTIMIZER_SCOPE + "/" + _LOCAL_VARIABLE_PREFIX)
+
+
+def flatten_and_sort(dictionary):
+ """Flattens a dictionary into a list of values sorted by the keys."""
+ return [dictionary[k] for k in sorted(dictionary.keys())]
diff --git a/models/research/learned_optimizer/optimizer/utils.py b/models/research/learned_optimizer/optimizer/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..58744f4cb7919a84ecc8702ff1236e4c0a03f218
--- /dev/null
+++ b/models/research/learned_optimizer/optimizer/utils.py
@@ -0,0 +1,278 @@
+# Copyright 2017 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities and helper functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+
+def make_finite(t, replacement):
+ """Replaces non-finite tensor values with the replacement value."""
+ return tf.where(tf.is_finite(t), t, replacement)
+
+
+def asinh(x):
+ """Computes the inverse hyperbolic sine function (in tensorflow)."""
+ return tf.log(x + tf.sqrt(1. + x ** 2))
+
+
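+# NOTE: illustrative sketch, not part of the original file. It checks the
+# identity used by asinh() above against numpy's arcsinh; the helper name is
+# hypothetical.
+def _example_check_asinh():
+  x = np.linspace(-3., 3., 7).astype("float32")
+  with tf.Session() as sess:
+    tf_val = sess.run(asinh(tf.constant(x)))
+  return np.allclose(tf_val, np.arcsinh(x), atol=1e-5)  # expected: True
+
+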
+def affine(inputs, output_size, scope="Affine", scale=0.1, vec_mean=0.,
+ include_bias=True, bias_init=0., random_seed=None):
+ """Computes an affine function of the inputs.
+
+ Creates or recalls tensorflow variables "Matrix" and "Bias"
+ to generate an affine operation on the input.
+
+ If the inputs are a list of tensors, they are concatenated together.
+
+ Initial weights for the matrix are drawn from a Gaussian with zero
+ mean and standard deviation that is the given scale divided by the
+ square root of the input dimension. Initial weights for the bias are
+ set to zero.
+
+ Args:
+ inputs: List of tensors with shape (batch_size, input_size)
+ output_size: Size (dimension) of the output
+ scope: Variable scope for these parameters (default: "Affine")
+ scale: Initial weight scale for the matrix parameters (default: 0.1),
+ this constant is divided by the sqrt of the input size to get the
+ std. deviation of the initial weights
+ vec_mean: The mean for the random initializer
+ include_bias: Whether to include the bias term
+ bias_init: The initializer bias (default 0.)
+ random_seed: Random seed for random initializers. (Default: None)
+
+ Returns:
+ output: Tensor with shape (batch_size, output_size)
+ """
+
+ # Concatenate the input arguments.
+ x = tf.concat(inputs, 1)
+
+ with tf.variable_scope(scope):
+ input_size = x.get_shape().as_list()[1]
+
+ sigma = scale / np.sqrt(input_size)
+ rand_init = tf.random_normal_initializer(mean=vec_mean, stddev=sigma,
+ seed=random_seed)
+
+ matrix = tf.get_variable("Matrix", [input_size, output_size],
+ dtype=tf.float32, initializer=rand_init)
+
+ if include_bias:
+ bias = tf.get_variable("Bias", [output_size], dtype=tf.float32,
+ initializer=tf.constant_initializer(bias_init,
+ tf.float32))
+ else:
+ bias = 0.
+ output = tf.matmul(x, matrix) + bias
+
+ return output
+
+
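+# NOTE: illustrative sketch, not part of the original file. It shows that
+# affine() concatenates a list of inputs along the feature axis before the
+# matmul; the names and sizes are hypothetical.
+def _example_affine():
+  inputs = [tf.zeros([4, 3]), tf.ones([4, 2])]  # concatenated to shape (4, 5)
+  output = affine(inputs, output_size=7, scope="ExampleAffine")
+  with tf.Session() as sess:
+    sess.run(tf.global_variables_initializer())
+    return sess.run(output).shape  # expected: (4, 7)
+
+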
+def project(inputs, weights, bias=0., activation=tf.identity):
+ """Computes an affine or linear projection of the inputs.
+
+ Projects the inputs onto the given weight vector and (optionally)
+ adds a bias and passes the result through an activation function.
+
+ Args:
+ inputs: matrix of inputs with shape [batch_size, dim]
+ weights: weight matrix with shape [dim, output_dim]
+ bias: bias vector with shape [output_dim] (default: 0)
+ activation: nonlinear activation function (default: tf.identity)
+
+ Returns:
+ outputs: an op which computes activation(inputs @ weights + bias)
+ """
+ return activation(tf.matmul(inputs, weights) + bias)
+
+
+def new_mean_squared(grad_vec, decay, ms):
+ """Calculates the new accumulated mean squared of the gradient.
+
+ Args:
+ grad_vec: the vector for the current gradient
+ decay: the decay term
+ ms: the previous mean_squared value
+
+ Returns:
+ the new mean_squared value
+ """
+ decay_size = decay.get_shape().num_elements()
+ decay_check_ops = [
+ tf.assert_less_equal(decay, 1., summarize=decay_size),
+ tf.assert_greater_equal(decay, 0., summarize=decay_size)]
+
+ with tf.control_dependencies(decay_check_ops):
+ grad_squared = tf.square(grad_vec)
+
+ # If the previous mean_squared is the 0 vector, don't use the decay and just
+ # return the full grad_squared. This should only happen on the first timestep.
+ decay = tf.cond(tf.reduce_all(tf.equal(ms, 0.)),
+ lambda: tf.zeros_like(decay, dtype=tf.float32), lambda: decay)
+
+ # Update the running average of squared gradients.
+ epsilon = 1e-12
+ return (1. - decay) * (grad_squared + epsilon) + decay * ms
+
+
+def rms_scaling(gradient, decay, ms, update_ms=True):
+ """Vectorizes and scales a tensor of gradients.
+
+ Args:
+ gradient: the current gradient
+ decay: the current decay value.
+ ms: the previous mean squared value
+ update_ms: Whether to update the mean squared value (default: True)
+
+ Returns:
+ The scaled gradient and the new ms value if update_ms is True,
+ the old ms value otherwise.
+ """
+
+ # Vectorize the gradients and compute the squared gradients.
+ grad_vec = tf.reshape(gradient, [-1, 1])
+
+ if update_ms:
+ ms = new_mean_squared(grad_vec, decay, ms)
+
+ # Scale the current gradients by the RMS, squashed by the asinh function.
+ scaled_gradient = asinh(grad_vec / tf.sqrt(ms + 1e-16))
+
+ return scaled_gradient, ms
+
+
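+# NOTE: illustrative sketch, not part of the original file. On the first step
+# ms is all zeros, so new_mean_squared ignores the decay and each entry is
+# normalized by its own magnitude; the asinh squash then maps gradients of
+# very different raw scales to values near asinh(1) ~= 0.88. Names are
+# hypothetical.
+def _example_rms_scaling():
+  gradient = tf.constant([[1e-3], [1.0], [1e3]])
+  decay = tf.constant([[0.9], [0.9], [0.9]])
+  ms = tf.zeros([3, 1])
+  scaled, new_ms = rms_scaling(gradient, decay, ms)
+  with tf.Session() as sess:
+    return sess.run([scaled, new_ms])  # scaled entries are all close to 0.88
+
+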
+def accumulate_sparse_gradients(grad):
+ """Accumulates repeated indices of a sparse gradient update.
+
+ Args:
+ grad: a tf.IndexedSlices gradient
+
+ Returns:
+ grad_indices: unique indices
+ grad_values: gradient values corresponding to the indices
+ """
+
+ grad_indices, grad_segments = tf.unique(grad.indices)
+ grad_values = tf.unsorted_segment_sum(grad.values, grad_segments,
+ tf.shape(grad_indices)[0])
+ return grad_indices, grad_values
+
+
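+# NOTE: illustrative sketch, not part of the original file. Repeated indices
+# in an IndexedSlices gradient are summed into a single row; names are
+# hypothetical.
+def _example_accumulate_sparse_gradients():
+  grad = tf.IndexedSlices(
+      values=tf.constant([[1., 1.], [2., 2.], [3., 3.]]),
+      indices=tf.constant([0, 2, 0]))
+  indices, values = accumulate_sparse_gradients(grad)
+  with tf.Session() as sess:
+    # expected: indices [0, 2] and values [[4., 4.], [2., 2.]]
+    return sess.run([indices, values])
+
+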
+def slice_tensor(dense_tensor, indices, head_dims):
+ """Extracts slices from a partially flattened dense tensor.
+
+ indices is assumed to index into the first dimension of head_dims.
+ dense_tensor is assumed to have a shape [D_0, D_1, ...] such that
+ prod(head_dims) == D_0. This function will extract slices along the
+ first_dimension of head_dims.
+
+ Example:
+
+ Consider a tensor with shape head_dims = [100, 2] and a dense_tensor with
+ shape [200, 3]. Note that the first dimension of dense_tensor equals the
+ product of head_dims. This function will reshape dense_tensor such that
+ its shape is now [100, 2, 3] (i.e. the first dimension became head-dims)
+ and then slice it along the first dimension. After slicing, the slices will
+ have their initial dimensions flattened just as they were in dense_tensor
+ (e.g. if there are 4 indices, the return value will have a shape of [4, 3]).
+
+ Args:
+ dense_tensor: a N-D dense tensor. Shape: [D_0, D_1, ...]
+ indices: a 1-D integer tensor. Shape: [K]
+ head_dims: True dimensions of the dense_tensor's first dimension.
+
+ Returns:
+ Extracted slices. Shape [K, D_1, ...]
+ """
+
+ tail_dims = tf.shape(dense_tensor)[1:]
+ dense_tensor = tf.reshape(dense_tensor,
+ tf.concat([head_dims, tail_dims], 0))
+
+ slices = tf.gather(dense_tensor, indices)
+ # NOTE(siege): This kills the shape annotation.
+ return tf.reshape(slices, tf.concat([[-1], tail_dims], 0))
+
+
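+# NOTE: illustrative sketch, not part of the original file, mirroring the
+# docstring example above on a smaller tensor; names are hypothetical.
+def _example_slice_tensor():
+  dense = tf.reshape(tf.range(24, dtype=tf.float32), [8, 3])
+  # head_dims [4, 2] views the leading dimension as 4 blocks of 2 rows each;
+  # indices [1, 3] pick the 2nd and 4th block.
+  sliced = slice_tensor(dense, tf.constant([1, 3]), tf.constant([4, 2]))
+  with tf.Session() as sess:
+    return sess.run(sliced)  # shape (4, 3): rows 2-3 and 6-7 of `dense`
+
+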
+def stack_tensor(slices, indices, dense_tensor, head_dims):
+ """Reconsititutes a tensor from slices and corresponding indices.
+
+ This is an inverse operation to slice_tensor. Missing slices are set to 0.
+
+ Args:
+ slices: a tensor. Shape [K, D_1, ...]
+ indices: a 1-D integer tensor. Shape: [K]
+ dense_tensor: the original tensor the slices were taken
+ from. Shape: [D_0, D_1, ...]
+ head_dims: True dimensions of the dense_tensor's first dimension.
+
+ Returns:
+    Reconstituted tensor. Shape: [D_0, D_1, ...]
+ """
+ # NOTE(siege): This cast shouldn't be necessary.
+ indices = tf.cast(indices, tf.int32)
+
+ tail_dims = tf.shape(dense_tensor)[1:]
+ dense_shape = tf.concat([head_dims, tail_dims], 0)
+
+ slices = tf.reshape(slices, tf.concat([[-1], dense_shape[1:]], 0))
+ indices = tf.expand_dims(indices, -1)
+
+ return tf.reshape(tf.scatter_nd(indices, slices, dense_shape),
+ tf.shape(dense_tensor))
+
+
+def update_slices(slices, indices, dense_tensor, head_dims):
+ """Reconstitutes a tensor from slices and corresponding indices.
+
+  Like stack_tensor, but instead of setting missing slices to 0, sets them to
+  what they were in the original tensor. The return value is reshaped to the
+  same shape as dense_tensor.
+
+ Args:
+ slices: a tensor. Shape [K, D_1, ...]
+ indices: a 1-D integer tensor. Shape: [K]
+ dense_tensor: the original tensor the slices were taken
+ from. Shape: [D_0, D_1, ...]
+ head_dims: True dimensions of the dense_tensor's first dimension.
+
+ Returns:
+    Reconstituted tensor. Shape: [D_0, D_1, ...]
+ """
+ # NOTE(siege): This cast shouldn't be necessary.
+ indices = tf.cast(indices, tf.int32)
+
+ tail_dims = tf.shape(dense_tensor)[1:]
+ dense_shape = tf.concat([head_dims, tail_dims], 0)
+
+ update_mask_vals = tf.fill(tf.shape(indices), 1)
+ reshaped_indices = tf.expand_dims(indices, -1)
+ update_mask = tf.equal(
+ tf.scatter_nd(reshaped_indices, update_mask_vals, head_dims[:1]), 1)
+
+ reshaped_dense_slices = tf.reshape(
+ stack_tensor(slices, indices, dense_tensor, head_dims), dense_shape)
+ reshaped_dense_tensor = tf.reshape(dense_tensor, dense_shape)
+
+ return tf.reshape(
+ tf.where(update_mask, reshaped_dense_slices, reshaped_dense_tensor),
+ tf.shape(dense_tensor))
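+
+
+# NOTE: illustrative sketch, not part of the original file. It contrasts
+# stack_tensor (missing rows become zero) with update_slices (missing rows
+# keep their original values); names are hypothetical.
+def _example_stack_vs_update():
+  dense = tf.reshape(tf.range(6, dtype=tf.float32), [3, 2])  # [[0,1],[2,3],[4,5]]
+  slices = tf.constant([[10., 10.]])
+  indices = tf.constant([1])
+  head_dims = tf.constant([3])
+  stacked = stack_tensor(slices, indices, dense, head_dims)   # rows 0 and 2 -> 0
+  updated = update_slices(slices, indices, dense, head_dims)  # rows 0 and 2 kept
+  with tf.Session() as sess:
+    return sess.run([stacked, updated])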
diff --git a/models/research/learned_optimizer/problems/BUILD b/models/research/learned_optimizer/problems/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..c704618821b36ca23f221f724888cde4e5d5a5ad
--- /dev/null
+++ b/models/research/learned_optimizer/problems/BUILD
@@ -0,0 +1,43 @@
+package(default_visibility = ["//visibility:public"])
+
+# Libraries
+# =====
+
+py_library(
+ name = "datasets",
+ srcs = ["datasets.py"],
+ deps = [
+ ],
+)
+
+py_library(
+ name = "model_adapter",
+ srcs = ["model_adapter.py"],
+ deps = [
+ ":problem_generator",
+ ],
+)
+
+py_library(
+ name = "problem_generator",
+ srcs = ["problem_generator.py"],
+ deps = [
+ ":problem_spec",
+ ],
+)
+
+py_library(
+ name = "problem_sets",
+ srcs = ["problem_sets.py"],
+ deps = [
+ ":datasets",
+ ":model_adapter",
+ ":problem_generator",
+ ],
+)
+
+py_library(
+ name = "problem_spec",
+ srcs = ["problem_spec.py"],
+ deps = [],
+)
diff --git a/models/research/learned_optimizer/problems/datasets.py b/models/research/learned_optimizer/problems/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..edf3df6532178b0e60ab93c78611d2313798e639
--- /dev/null
+++ b/models/research/learned_optimizer/problems/datasets.py
@@ -0,0 +1,218 @@
+# Copyright 2017 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Functions to generate or load datasets for supervised learning."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import namedtuple
+
+import numpy as np
+from sklearn.datasets import make_classification
+
+MAX_SEED = 4294967295
+
+
+class Dataset(namedtuple("Dataset", "data labels")):
+ """Helper class for managing a supervised learning dataset.
+
+ Args:
+ data: an array of type float32 with N samples, each of which is the set
+ of features for that sample. (Shape (N, D_i), where N is the number of
+ samples and D_i is the number of features for that sample.)
+ labels: an array of type int32 or int64 with N elements, indicating the
+ class label for the corresponding set of features in data.
+ """
+ # Since this is an immutable object, we don't need to reserve slots.
+ __slots__ = ()
+
+ @property
+ def size(self):
+ """Dataset size (number of samples)."""
+ return len(self.data)
+
+ def batch_indices(self, num_batches, batch_size):
+ """Creates indices of shuffled minibatches.
+
+ Args:
+ num_batches: the number of batches to generate
+ batch_size: the size of each batch
+
+ Returns:
+ batch_indices: a list of minibatch indices, arranged so that the dataset
+ is randomly shuffled.
+
+ Raises:
+ ValueError: if the data and labels have different lengths
+ """
+ if len(self.data) != len(self.labels):
+ raise ValueError("Labels and data must have the same number of samples.")
+
+ batch_indices = []
+
+ # Follows logic in mnist.py to ensure we cover the entire dataset.
+ index_in_epoch = 0
+ dataset_size = len(self.data)
+ dataset_indices = np.arange(dataset_size)
+ np.random.shuffle(dataset_indices)
+
+ for _ in range(num_batches):
+ start = index_in_epoch
+ index_in_epoch += batch_size
+ if index_in_epoch > dataset_size:
+
+ # Finished epoch, reshuffle.
+ np.random.shuffle(dataset_indices)
+
+ # Start next epoch.
+ start = 0
+ index_in_epoch = batch_size
+
+ end = index_in_epoch
+ batch_indices.append(dataset_indices[start:end].tolist())
+
+ return batch_indices
+
+
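+# NOTE: illustrative sketch, not part of the original file. It builds a tiny
+# Dataset and shows that batch_indices reshuffles whenever an epoch boundary
+# is crossed; names are hypothetical.
+def _example_batch_indices():
+  data = np.arange(10, dtype="float32").reshape(5, 2)
+  labels = np.array([0, 1, 0, 1, 0], dtype="int32")
+  dataset = Dataset(data, labels)
+  # 4 batches of size 2 over 5 samples: the epoch ends part-way through, so
+  # the indices are reshuffled and the count restarts from the new ordering.
+  return dataset.batch_indices(num_batches=4, batch_size=2)
+
+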
+def noisy_parity_class(n_samples,
+ n_classes=2,
+ n_context_ids=5,
+ noise_prob=0.25,
+ random_seed=None):
+ """Returns a randomly generated sparse-to-sparse dataset.
+
+ The label is a parity class of a set of context classes.
+
+ Args:
+ n_samples: number of samples (data points)
+ n_classes: number of class labels (default: 2)
+ n_context_ids: how many classes to take the parity of (default: 5).
+ noise_prob: how often to corrupt the label (default: 0.25)
+ random_seed: seed used for drawing the random data (default: None)
+ Returns:
+ dataset: A Dataset namedtuple containing the generated data and labels
+ """
+ np.random.seed(random_seed)
+ x = np.random.randint(0, n_classes, [n_samples, n_context_ids])
+ noise = np.random.binomial(1, noise_prob, [n_samples])
+ y = (np.sum(x, 1) + noise) % n_classes
+ return Dataset(x.astype("float32"), y.astype("int32"))
+
+
+def random(n_features, n_samples, n_classes=2, sep=1.0, random_seed=None):
+ """Returns a randomly generated classification dataset.
+
+ Args:
+ n_features: number of features (dependent variables)
+ n_samples: number of samples (data points)
+ n_classes: number of class labels (default: 2)
+ sep: separation of the two classes, a higher value corresponds to
+ an easier classification problem (default: 1.0)
+ random_seed: seed used for drawing the random data (default: None)
+
+ Returns:
+ dataset: A Dataset namedtuple containing the generated data and labels
+ """
+ # Generate the problem data.
+ x, y = make_classification(n_samples=n_samples,
+ n_features=n_features,
+ n_informative=n_features,
+ n_redundant=0,
+ n_classes=n_classes,
+ class_sep=sep,
+ random_state=random_seed)
+
+ return Dataset(x.astype("float32"), y.astype("int32"))
+
+
+def random_binary(n_features, n_samples, random_seed=None):
+ """Returns a randomly generated dataset of binary values.
+
+ Args:
+ n_features: number of features (dependent variables)
+ n_samples: number of samples (data points)
+ random_seed: seed used for drawing the random data (default: None)
+
+ Returns:
+ dataset: A Dataset namedtuple containing the generated data and labels
+ """
+ random_seed = (np.random.randint(MAX_SEED) if random_seed is None
+ else random_seed)
+ np.random.seed(random_seed)
+
+ x = np.random.randint(2, size=(n_samples, n_features))
+ y = np.zeros((n_samples, 1))
+
+ return Dataset(x.astype("float32"), y.astype("int32"))
+
+
+def random_symmetric(n_features, n_samples, random_seed=None):
+ """Returns a randomly generated dataset of values and their negatives.
+
+ Args:
+ n_features: number of features (dependent variables)
+ n_samples: number of samples (data points)
+ random_seed: seed used for drawing the random data (default: None)
+
+ Returns:
+ dataset: A Dataset namedtuple containing the generated data and labels
+ """
+ random_seed = (np.random.randint(MAX_SEED) if random_seed is None
+ else random_seed)
+ np.random.seed(random_seed)
+
+ x1 = np.random.normal(size=(int(n_samples/2), n_features))
+ x = np.concatenate((x1, -x1), axis=0)
+ y = np.zeros((n_samples, 1))
+
+ return Dataset(x.astype("float32"), y.astype("int32"))
+
+
+def random_mlp(n_features, n_samples, random_seed=None, n_layers=6, width=20):
+ """Returns a generated output of an MLP with random weights.
+
+ Args:
+ n_features: number of features (dependent variables)
+ n_samples: number of samples (data points)
+ random_seed: seed used for drawing the random data (default: None)
+ n_layers: number of layers in random MLP
+ width: width of the layers in random MLP
+
+ Returns:
+ dataset: A Dataset namedtuple containing the generated data and labels
+ """
+ random_seed = (np.random.randint(MAX_SEED) if random_seed is None
+ else random_seed)
+ np.random.seed(random_seed)
+
+ x = np.random.normal(size=(n_samples, n_features))
+ y = x
+ n_in = n_features
+ scale_factor = np.sqrt(2.) / np.sqrt(n_features)
+ for _ in range(n_layers):
+ weights = np.random.normal(size=(n_in, width)) * scale_factor
+ y = np.dot(y, weights).clip(min=0)
+ n_in = width
+
+ y = y[:, 0]
+ y[y > 0] = 1
+
+ return Dataset(x.astype("float32"), y.astype("int32"))
+
+
+EMPTY_DATASET = Dataset(np.array([], dtype="float32"),
+ np.array([], dtype="int32"))
diff --git a/models/research/learned_optimizer/problems/model_adapter.py b/models/research/learned_optimizer/problems/model_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..8455992366dd46172e2a78471004779b1a4f091b
--- /dev/null
+++ b/models/research/learned_optimizer/problems/model_adapter.py
@@ -0,0 +1,190 @@
+# Copyright 2017 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Implementation of the ModelAdapter class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import mock
+import tensorflow as tf
+
+from learned_optimizer.problems import problem_generator as pg
+
+
+class ModelAdapter(pg.Problem):
+ """Adapts Tensorflow models/graphs into a form suitable for meta-training.
+
+ This class adapts an existing TensorFlow graph into a form suitable for
+ meta-training a learned optimizer.
+ """
+
+ def __init__(self, make_loss_and_init_fn):
+ """Wraps a model in the Problem interface.
+
+    The make_loss_and_init_fn argument is a callable that returns a tuple of
+    two other callables, as follows.
+
+    The first will construct most of the graph and return the problem loss. It
+    is essential that this graph contains all of the model's variables, but
+    none of its queues.
+
+    The second will construct the model initialization graph given a list of
+    parameters and return a callable that is passed a tf.Session instance and
+    initializes the model's parameters.
+
+    An example value for this argument would look like this:
+
+ ```python
+ def make_loss_and_init_fn():
+ inputs = queued_reader()
+
+ def make_loss():
+ return create_model_with_variables(inputs)
+
+ def make_init_fn(parameters):
+        saver = tf.train.Saver(parameters)
+        def init_fn(sess):
+          saver.restore(sess, ...)
+ return init_fn
+
+ return make_loss, make_init_fn
+ ```
+
+ Args:
+      make_loss_and_init_fn: a callable, as described above
+ """
+ make_loss_fn, make_init_fn = make_loss_and_init_fn()
+
+ self.make_loss_fn = make_loss_fn
+ self.parameters, self.constants = _get_variables(make_loss_fn)
+
+ if make_init_fn is not None:
+ init_fn = make_init_fn(self.parameters + self.constants)
+ else:
+ init_op = tf.initialize_variables(self.parameters + self.constants)
+ init_fn = lambda sess: sess.run(init_op)
+
+ tf.logging.info("ModelAdapter parameters: %s",
+ [op.name for op in self.parameters])
+ tf.logging.info("ModelAdapter constants: %s",
+ [op.name for op in self.constants])
+
+ super(ModelAdapter, self).__init__(
+ [], random_seed=None, noise_stdev=0.0, init_fn=init_fn)
+
+ def init_tensors(self, seed=None):
+ """Returns a list of tensors with the given shape."""
+ return self.parameters
+
+ def init_variables(self, seed=None):
+ """Returns a list of variables with the given shape."""
+ # NOTE(siege): This is awkward, as these are not set as trainable.
+ return self.parameters
+
+ def objective(self, parameters, data=None, labels=None):
+ """Computes the objective given a list of parameters.
+
+ Args:
+ parameters: The parameters to optimize (as a list of tensors)
+ data: An optional batch of data for calculating objectives
+ labels: An optional batch of corresponding labels
+
+ Returns:
+ A scalar tensor representing the objective value
+ """
+ # We need to set up a mapping based on the original parameter names, because
+ # the parameters passed can be arbitrary tensors.
+ parameter_mapping = {
+ old_p.name: p
+ for old_p, p in zip(self.parameters, parameters)
+ }
+
+ with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+ return _make_with_custom_variables(self.make_loss_fn, parameter_mapping)
+
+
+def _get_variables(func):
+ """Calls func, returning any variables created.
+
+ The created variables are modified to not be trainable, and are placed into
+ the LOCAL_VARIABLES collection.
+
+ Args:
+ func: Function to be called.
+
+ Returns:
+    A tuple (variables, constants) where the first element lists the variables
+    that were requested as trainable (the model parameters) and the second
+    lists the variables that were not (the constants).
+ """
+ variables = []
+ constants = []
+
+ # We need to create these variables like normal, so grab the original
+ # constructor before we mock it.
+ original_init = tf.Variable.__init__
+
+ def custom_init(self, *args, **kwargs):
+    trainable = kwargs.get("trainable", True)
+ kwargs["trainable"] = False
+    # Placing these variables in the LOCAL_VARIABLES collection keeps them out
+    # of the optimizer's checkpoints.
+ kwargs["collections"] = [tf.GraphKeys.LOCAL_VARIABLES]
+ original_init(self, *args, **kwargs)
+ if trainable:
+ variables.append(self)
+ else:
+ constants.append(self)
+
+ # This name-scope is just a nicety for TensorBoard.
+ with tf.name_scope("unused_graph"):
+ with mock.patch.object(tf.Variable, "__init__", custom_init):
+ func()
+
+ return variables, constants
+
+
+def _make_with_custom_variables(func, variable_mapping):
+ """Calls func and replaces the value of some variables created in it.
+
+ Args:
+ func: Function to be called.
+ variable_mapping: A mapping of variable name to the replacement tensor or
+ tf.Variable.
+
+ Returns:
+ The return value of func is returned.
+ """
+ original_value = tf.Variable.value
+
+ def custom_value(self):
+ if self.name in variable_mapping:
+ replacement = variable_mapping[self.name]
+ tf.logging.info("Replaced %s with %s" % (self.name, replacement))
+
+      # The value() method needs to return a tensor, so if the replacement is
+      # a tf.Variable we call the original value() on it. This has to be done
+      # manually; otherwise we could end up in an infinite loop.
+ if isinstance(replacement, tf.Variable):
+ replacement = original_value(replacement)
+
+ return replacement
+ else:
+ return original_value(self)
+
+ with mock.patch.object(tf.Variable, "value", custom_value):
+ with mock.patch.object(tf.Variable, "_AsTensor", custom_value):
+ return func()
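+
+
+# NOTE: illustrative sketch, not part of the original file. It shows the shape
+# of a make_loss_and_init_fn value that ModelAdapter expects; the tiny model
+# inside is hypothetical.
+def _example_make_loss_and_init_fn():
+  def make_loss():
+    # All of the model's variables must be created inside this callable.
+    w = tf.get_variable("w", [10, 1])
+    x = tf.random_normal([32, 10])
+    return tf.reduce_mean(tf.square(tf.matmul(x, w)))
+
+  def make_init_fn(parameters):
+    init_op = tf.variables_initializer(parameters)
+    def init_fn(sess):
+      sess.run(init_op)
+    return init_fn
+
+  return make_loss, make_init_fn
+# Usage (sketch): adapter = ModelAdapter(_example_make_loss_and_init_fn)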
diff --git a/models/research/learned_optimizer/problems/problem_generator.py b/models/research/learned_optimizer/problems/problem_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..abe1008faadbb04163bc27e0b991e3ec4ba9e6bc
--- /dev/null
+++ b/models/research/learned_optimizer/problems/problem_generator.py
@@ -0,0 +1,1016 @@
+# Copyright 2017 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Generates toy optimization problems.
+
+This module contains a base class, Problem, that defines a minimal interface
+for optimization problems, and a few specific problem types that subclass it.
+
+Test functions for optimization: http://www.sfu.ca/~ssurjano/optimization.html
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from learned_optimizer.problems import problem_spec as prob_spec
+
+tf.app.flags.DEFINE_float("l2_reg_scale", 1e-3,
+ """Scaling factor for parameter value regularization
+ in softmax classifier problems.""")
+FLAGS = tf.app.flags.FLAGS
+
+EPSILON = 1e-6
+MAX_SEED = 4294967295
+PARAMETER_SCOPE = "parameters"
+
+_Spec = prob_spec.Spec
+
+
+class Problem(object):
+ """Base class for optimization problems.
+
+ This defines an interface for optimization problems, including objective and
+ gradients functions and a feed_generator function that yields data to pass to
+ feed_dict in tensorflow.
+
+ Subclasses of Problem must (at the minimum) override the objective method,
+ which computes the objective/loss/cost to minimize, and specify the desired
+ shape of the parameters in a list in the param_shapes attribute.
+ """
+
+ def __init__(self, param_shapes, random_seed, noise_stdev, init_fn=None):
+ """Initializes a global random seed for the problem.
+
+ Args:
+ param_shapes: A list of tuples defining the expected shapes of the
+ parameters for this problem
+      random_seed: Either an integer or None (in which case the seed is
+        randomly drawn)
+ noise_stdev: Strength (standard deviation) of added gradient noise
+ init_fn: A function taking a tf.Session object that is used to
+ initialize the problem's variables.
+
+ Raises:
+ ValueError: If the random_seed is not an integer and not None
+ """
+ if random_seed is not None and not isinstance(random_seed, int):
+ raise ValueError("random_seed must be an integer or None")
+
+ # Pick a random seed.
+ self.random_seed = (np.random.randint(MAX_SEED) if random_seed is None
+ else random_seed)
+
+ # Store the noise level.
+ self.noise_stdev = noise_stdev
+
+ # Set the random seed to ensure any random data in the problem is the same.
+ np.random.seed(self.random_seed)
+
+ # Store the parameter shapes.
+ self.param_shapes = param_shapes
+
+ if init_fn is not None:
+ self.init_fn = init_fn
+ else:
+ self.init_fn = lambda _: None
+
+ def init_tensors(self, seed=None):
+ """Returns a list of tensors with the given shape."""
+ return [tf.random_normal(shape, seed=seed) for shape in self.param_shapes]
+
+ def init_variables(self, seed=None):
+ """Returns a list of variables with the given shape."""
+ with tf.variable_scope(PARAMETER_SCOPE):
+ params = [tf.Variable(param) for param in self.init_tensors(seed)]
+ return params
+
+ def objective(self, parameters, data=None, labels=None):
+ """Computes the objective given a list of parameters.
+
+ Args:
+ parameters: The parameters to optimize (as a list of tensors)
+ data: An optional batch of data for calculating objectives
+ labels: An optional batch of corresponding labels
+
+ Returns:
+ A scalar tensor representing the objective value
+ """
+ raise NotImplementedError
+
+ def gradients(self, objective, parameters):
+ """Compute gradients of the objective with respect to the parameters.
+
+ Args:
+ objective: The objective op (e.g. output of self.objective())
+ parameters: A list of tensors (the parameters to optimize)
+
+ Returns:
+ A list of tensors representing the gradient for each parameter,
+ returned in the same order as the given list
+ """
+ grads = tf.gradients(objective, list(parameters))
+ noisy_grads = []
+
+ for grad in grads:
+ if isinstance(grad, tf.IndexedSlices):
+ noise = self.noise_stdev * tf.random_normal(tf.shape(grad.values))
+ new_grad = tf.IndexedSlices(grad.values + noise, grad.indices)
+ else:
+ new_grad = grad + self.noise_stdev * tf.random_normal(grad.get_shape())
+ noisy_grads.append(new_grad)
+
+ return noisy_grads
+
+
+class Quadratic(Problem):
+ """Optimizes a random quadratic function.
+
+ The objective is: f(x) = (1/2) ||Wx - y||_2^2
+ where W is a random Gaussian matrix and y is a random Gaussian vector.
+ """
+
+ def __init__(self, ndim, random_seed=None, noise_stdev=0.0):
+ """Initializes a random quadratic problem."""
+ param_shapes = [(ndim, 1)]
+ super(Quadratic, self).__init__(param_shapes, random_seed, noise_stdev)
+
+ # Generate a random problem instance.
+ self.w = np.random.randn(ndim, ndim).astype("float32")
+ self.y = np.random.randn(ndim, 1).astype("float32")
+
+ def objective(self, params, data=None, labels=None):
+ """Quadratic objective (see base class for details)."""
+ return tf.nn.l2_loss(tf.matmul(self.w, params[0]) - self.y)
+
+
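+# NOTE: illustrative sketch, not part of the original file. It drives the
+# Problem interface above with a standard TensorFlow optimizer to show how
+# objective() and gradients() fit together; names are hypothetical.
+def _example_minimize_quadratic():
+  problem = Quadratic(ndim=3, random_seed=0)
+  params = problem.init_variables()
+  obj = problem.objective(params)
+  grads = problem.gradients(obj, params)
+  train_op = tf.train.GradientDescentOptimizer(0.01).apply_gradients(
+      list(zip(grads, params)))
+  with tf.Session() as sess:
+    sess.run(tf.global_variables_initializer())
+    for _ in range(100):
+      sess.run(train_op)
+    return sess.run(obj)  # should be well below the initial objective
+
+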
+class SoftmaxClassifier(Problem):
+ """Helper functions for supervised softmax classification problems."""
+
+ def init_tensors(self, seed=None):
+ """Returns a list of tensors with the given shape."""
+ return [tf.random_normal(shape, seed=seed) * 1.2 / np.sqrt(shape[0])
+ for shape in self.param_shapes]
+
+ def inference(self, params, data):
+ """Computes logits given parameters and data.
+
+ Args:
+ params: List of parameter tensors or variables
+ data: Batch of features with samples along the first dimension
+
+ Returns:
+ logits: Un-normalized logits with shape (num_samples, num_classes)
+ """
+ raise NotImplementedError
+
+ def objective(self, params, data, labels):
+ """Computes the softmax cross entropy.
+
+ Args:
+ params: List of parameter tensors or variables
+ data: Batch of features with samples along the first dimension
+ labels: Vector of labels with the same number of samples as the data
+
+ Returns:
+ loss: Softmax cross entropy loss averaged over the samples in the batch
+
+ Raises:
+ ValueError: If the objective is to be computed over >2 classes, because
+ this operation is broken in tensorflow at the moment.
+ """
+ # Forward pass.
+ logits = self.inference(params, data)
+
+ # Compute the loss.
+ l2reg = [tf.reduce_sum(param ** 2) for param in params]
+ if int(logits.get_shape()[1]) == 2:
+ labels = tf.cast(labels, tf.float32)
+ losses = tf.nn.sigmoid_cross_entropy_with_logits(
+ labels=labels, logits=logits[:, 0])
+ else:
+ raise ValueError("Unable to compute softmax cross entropy for more than"
+ " 2 classes.")
+
+ return tf.reduce_mean(losses) + tf.reduce_mean(l2reg) * FLAGS.l2_reg_scale
+
+ def argmax(self, logits):
+ """Samples the most likely class label given the logits.
+
+ Args:
+ logits: Un-normalized logits with shape (num_samples, num_classes)
+
+ Returns:
+ predictions: Predicted class labels, has shape (num_samples,)
+ """
+ return tf.cast(tf.argmax(tf.nn.softmax(logits), 1), tf.int32)
+
+ def accuracy(self, params, data, labels):
+ """Computes the accuracy (fraction of correct classifications).
+
+ Args:
+ params: List of parameter tensors or variables
+ data: Batch of features with samples along the first dimension
+ labels: Vector of labels with the same number of samples as the data
+
+ Returns:
+ accuracy: Fraction of correct classifications across the batch
+ """
+ predictions = self.argmax(self.inference(params, data))
+ return tf.contrib.metrics.accuracy(predictions, tf.cast(labels, tf.int32))
+
+
+class SoftmaxRegression(SoftmaxClassifier):
+ """Builds a softmax regression problem."""
+
+ def __init__(self, n_features, n_classes, activation=tf.identity,
+ random_seed=None, noise_stdev=0.0):
+ self.activation = activation
+ self.n_features = n_features
+ param_shapes = [(n_features, n_classes), (n_classes,)]
+ super(SoftmaxRegression, self).__init__(param_shapes,
+ random_seed,
+ noise_stdev)
+
+ def inference(self, params, data):
+ features = tf.reshape(data, (-1, self.n_features))
+ return tf.matmul(features, params[0]) + params[1]
+
+
+class SparseSoftmaxRegression(SoftmaxClassifier):
+ """Builds a sparse input softmax regression problem."""
+
+ def __init__(self,
+ n_features,
+ n_classes,
+ activation=tf.identity,
+ random_seed=None,
+ noise_stdev=0.0):
+ self.activation = activation
+ self.n_features = n_features
+ param_shapes = [(n_classes, n_features), (n_features, n_classes), (
+ n_classes,)]
+ super(SparseSoftmaxRegression, self).__init__(param_shapes, random_seed,
+ noise_stdev)
+
+ def inference(self, params, data):
+ all_embeddings, softmax_weights, softmax_bias = params
+ embeddings = tf.nn.embedding_lookup(all_embeddings, tf.cast(data, tf.int32))
+ embeddings = tf.reduce_sum(embeddings, 1)
+ return tf.matmul(embeddings, softmax_weights) + softmax_bias
+
+
+class OneHotSparseSoftmaxRegression(SoftmaxClassifier):
+ """Builds a sparse input softmax regression problem.
+
+ This is identical to SparseSoftmaxRegression, but without using embedding
+ ops.
+ """
+
+ def __init__(self,
+ n_features,
+ n_classes,
+ activation=tf.identity,
+ random_seed=None,
+ noise_stdev=0.0):
+ self.activation = activation
+ self.n_features = n_features
+ self.n_classes = n_classes
+ param_shapes = [(n_classes, n_features), (n_features, n_classes), (
+ n_classes,)]
+ super(OneHotSparseSoftmaxRegression, self).__init__(param_shapes,
+ random_seed,
+ noise_stdev)
+
+ def inference(self, params, data):
+ all_embeddings, softmax_weights, softmax_bias = params
+ num_ids = tf.shape(data)[1]
+ one_hot_embeddings = tf.one_hot(tf.cast(data, tf.int32), self.n_classes)
+ one_hot_embeddings = tf.reshape(one_hot_embeddings, [-1, self.n_classes])
+ embeddings = tf.matmul(one_hot_embeddings, all_embeddings)
+ embeddings = tf.reshape(embeddings, [-1, num_ids, self.n_features])
+ embeddings = tf.reduce_sum(embeddings, 1)
+ return tf.matmul(embeddings, softmax_weights) + softmax_bias
+
+
+class FullyConnected(SoftmaxClassifier):
+ """Builds a multi-layer perceptron classifier."""
+
+ def __init__(self, n_features, n_classes, hidden_sizes=(32, 64),
+ activation=tf.nn.sigmoid, random_seed=None, noise_stdev=0.0):
+ """Initializes an multi-layer perceptron classification problem."""
+ # Store the number of features and activation function.
+ self.n_features = n_features
+ self.activation = activation
+
+ # Define the network as a list of weight + bias shapes for each layer.
+ param_shapes = []
+ for ix, sz in enumerate(hidden_sizes + (n_classes,)):
+
+      # The previous layer's size (n_features if this is the input layer).
+ prev_size = n_features if ix == 0 else hidden_sizes[ix - 1]
+
+ # Weight shape for this layer.
+ param_shapes.append((prev_size, sz))
+
+ # Bias shape for this layer.
+ param_shapes.append((sz,))
+
+ super(FullyConnected, self).__init__(param_shapes, random_seed, noise_stdev)
+
+ def inference(self, params, data):
+ # Flatten the features into a vector.
+ features = tf.reshape(data, (-1, self.n_features))
+
+ # Pass the data through the network.
+ preactivations = tf.matmul(features, params[0]) + params[1]
+
+ for layer in range(2, len(self.param_shapes), 2):
+ net = self.activation(preactivations)
+ preactivations = tf.matmul(net, params[layer]) + params[layer + 1]
+
+ return preactivations
+
+ def accuracy(self, params, data, labels):
+ """Computes the accuracy (fraction of correct classifications).
+
+ Args:
+ params: List of parameter tensors or variables
+ data: Batch of features with samples along the first dimension
+ labels: Vector of labels with the same number of samples as the data
+
+ Returns:
+ accuracy: Fraction of correct classifications across the batch
+ """
+ predictions = self.argmax(self.activation(self.inference(params, data)))
+ return tf.contrib.metrics.accuracy(predictions, tf.cast(labels, tf.int32))
+
+
+class ConvNet(SoftmaxClassifier):
+ """Builds an N-layer convnet for image classification."""
+
+ def __init__(self,
+ image_shape,
+ n_classes,
+ filter_list,
+ activation=tf.nn.relu,
+ random_seed=None,
+ noise_stdev=0.0):
+ # Number of channels, number of pixels in x- and y- dimensions.
+ n_channels, px, py = image_shape
+
+ # Store the activation.
+ self.activation = activation
+
+ param_shapes = []
+ input_size = n_channels
+ for fltr in filter_list:
+ # Add conv2d filters.
+ param_shapes.append((fltr[0], fltr[1], input_size, fltr[2]))
+ input_size = fltr[2]
+
+ # Number of units in the final (dense) layer.
+ self.affine_size = input_size * px * py
+
+ param_shapes.append((self.affine_size, n_classes)) # affine weights
+ param_shapes.append((n_classes,)) # affine bias
+
+ super(ConvNet, self).__init__(param_shapes, random_seed, noise_stdev)
+
+ def init_tensors(self, seed=None):
+ """Returns a list of tensors with the given shape."""
+ return [tf.random_normal(shape, mean=0., stddev=0.01, seed=seed)
+ for shape in self.param_shapes]
+
+ def inference(self, params, data):
+
+ # Unpack.
+ w_conv_list = params[:-2]
+ output_w, output_b = params[-2:]
+
+ conv_input = data
+ for w_conv in w_conv_list:
+ layer = tf.nn.conv2d(conv_input, w_conv, strides=[1] * 4, padding="SAME")
+ output = self.activation(layer)
+ conv_input = output
+
+ # Flatten.
+ flattened = tf.reshape(conv_input, (-1, self.affine_size))
+
+ # Fully connected layer.
+ return tf.matmul(flattened, output_w) + output_b
+
+
+class Bowl(Problem):
+ """A 2D quadratic bowl."""
+
+ def __init__(self, condition_number, angle=0.0,
+ random_seed=None, noise_stdev=0.0):
+ assert condition_number > 0, "Condition number must be positive."
+
+ # Define parameter shapes.
+ param_shapes = [(2, 1)]
+ super(Bowl, self).__init__(param_shapes, random_seed, noise_stdev)
+
+ self.condition_number = condition_number
+ self.angle = angle
+ self._build_matrix(condition_number, angle)
+
+ def _build_matrix(self, condition_number, angle):
+ """Builds the Hessian matrix."""
+ hessian = np.array([[condition_number, 0.], [0., 1.]], dtype="float32")
+
+ # Build the rotation matrix.
+ rotation_matrix = np.array([
+ [np.cos(angle), -np.sin(angle)],
+ [np.sin(angle), np.cos(angle)]
+ ])
+
+ # The objective is 0.5 * || Ax ||_2^2
+ # where the data matrix (A) is: sqrt(Hessian).dot(rotation_matrix).
+ self.matrix = np.sqrt(hessian).dot(rotation_matrix)
+
+ def objective(self, params, data=None, labels=None):
+ mtx = tf.constant(self.matrix, dtype=tf.float32)
+ return tf.nn.l2_loss(tf.matmul(mtx, params[0]))
+
+ def surface(self, xlim=5, ylim=5, n=50):
+ xm, ym = _mesh(xlim, ylim, n)
+ pts = np.vstack([xm.ravel(), ym.ravel()])
+ zm = 0.5 * np.linalg.norm(self.matrix.dot(pts), axis=0) ** 2
+ return xm, ym, zm.reshape(n, n)
+
+
+class Problem2D(Problem):
+
+ def __init__(self, random_seed=None, noise_stdev=0.0):
+ param_shapes = [(2,)]
+ super(Problem2D, self).__init__(param_shapes, random_seed, noise_stdev)
+
+ def surface(self, n=50, xlim=5, ylim=5):
+ """Computes the objective surface over a 2d mesh."""
+
+ # Create a mesh over the given coordinate ranges.
+ xm, ym = _mesh(xlim, ylim, n)
+
+ with tf.Graph().as_default(), tf.Session() as sess:
+
+ # Ops to compute the objective at every (x, y) point.
+ x = tf.placeholder(tf.float32, shape=xm.shape)
+ y = tf.placeholder(tf.float32, shape=ym.shape)
+ obj = self.objective([[x, y]])
+
+ # Run the computation.
+ zm = sess.run(obj, feed_dict={x: xm, y: ym})
+
+ return xm, ym, zm
+
+
+class Rosenbrock(Problem2D):
+ """See https://en.wikipedia.org/wiki/Rosenbrock_function.
+
+  This function has a single global minimum at [1, 1].
+  The objective value at this point is zero.
+ """
+
+ def init_tensors(self, seed=None):
+ """Returns a list of tensors with the given shape."""
+ return [tf.random_uniform(shape, minval=-5., maxval=10., seed=seed)
+ for shape in self.param_shapes]
+
+ def objective(self, params, data=None, labels=None):
+ x, y = tf.split(params[0], 2, axis=0)
+ obj = (1 - x)**2 + 100 * (y - x**2)**2
+ return tf.squeeze(obj)
+
+
+def make_rosenbrock_loss_and_init(device=None):
+ """A variable-backed version of Rosenbrock problem.
+
+ See the Rosenbrock class for details.
+
+ Args:
+ device: Where to place the ops of this problem.
+
+ Returns:
+ A tuple of two callables, first of which creates the loss and the second
+ creates the parameter initializer function.
+ """
+ def make_rosenbrock_loss():
+ with tf.name_scope("optimizee"):
+ with tf.device(device):
+ x = tf.get_variable("x", [1])
+ y = tf.get_variable("y", [1])
+ c = tf.get_variable(
+ "c", [1],
+ initializer=tf.constant_initializer(100.0),
+ trainable=False)
+ obj = (1 - x)**2 + c * (y - x**2)**2
+ return tf.squeeze(obj)
+
+ def make_init_fn(parameters):
+ with tf.device(device):
+ init_op = tf.variables_initializer(parameters)
+ def init_fn(sess):
+ tf.logging.info("Initializing model parameters.")
+ sess.run(init_op)
+ return init_fn
+
+ return make_rosenbrock_loss, make_init_fn
+
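A minimal sketch (TensorFlow 1.x assumed) of how the two returned callables fit together. Initializing tf.global_variables() rather than only the trainable parameters is a choice made here so that the non-trainable constant "c" is initialized as well:

```
import tensorflow as tf  # TF 1.x API, as used throughout this file

make_loss, make_init_fn = make_rosenbrock_loss_and_init()

with tf.Graph().as_default():
  loss = make_loss()                             # builds x, y, c and the loss
  init_fn = make_init_fn(tf.global_variables())
  with tf.Session() as sess:
    init_fn(sess)
    print(sess.run(loss))                        # objective at the random init
```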
+
+class Saddle(Problem2D):
+ """Loss surface around a saddle point."""
+
+ def objective(self, params, data=None, labels=None):
+ x, y = tf.split(params[0], 2, axis=0)
+ obj = x ** 2 - y ** 2
+ return tf.squeeze(obj)
+
+
+class LogSumExp(Problem2D):
+ """2D function defined by the log of the sum of exponentials."""
+
+ def objective(self, params, data=None, labels=None):
+ x, y = tf.split(params[0], 2, axis=0)
+ obj = tf.log(tf.exp(x + 3. * y - 0.1) +
+ tf.exp(x - 3. * y - 0.1) +
+ tf.exp(-x - 0.1) + 1.0)
+ return tf.squeeze(obj)
+
+
+class Ackley(Problem2D):
+ """Ackley's function (contains many local minima)."""
+
+ def init_tensors(self, seed=None):
+ """Returns a list of tensors with the given shape."""
+ return [tf.random_uniform(shape, minval=-32.768, maxval=32.768, seed=seed)
+ for shape in self.param_shapes]
+
+ def objective(self, params, data=None, labels=None):
+ x, y = tf.split(params[0], 2, axis=0)
+ obj = (-20 * tf.exp(-0.2 * tf.sqrt(0.5 * (x ** 2 + y ** 2))) -
+ tf.exp(0.5 * (tf.cos(2 * np.pi * x) + tf.cos(2 * np.pi * y))) +
+ tf.exp(1.0) + 20.)
+ return tf.squeeze(obj)
+
+
+class Beale(Problem2D):
+ """Beale function (a multimodal function with sharp peaks)."""
+
+ def init_tensors(self, seed=None):
+ """Returns a list of tensors with the given shape."""
+ return [tf.random_uniform(shape, minval=-4.5, maxval=4.5, seed=seed)
+ for shape in self.param_shapes]
+
+ def objective(self, params, data=None, labels=None):
+ x, y = tf.split(params[0], 2, axis=0)
+ obj = ((1.5 - x + x * y) ** 2 +
+ (2.25 - x + x * y ** 2) ** 2 +
+ (2.625 - x + x * y ** 3) ** 2)
+ return tf.squeeze(obj)
+
+
+class Booth(Problem2D):
+ """Booth's function (has a long valley along one dimension)."""
+
+ def init_tensors(self, seed=None):
+ """Returns a list of tensors with the given shape."""
+ return [tf.random_uniform(shape, minval=-10., maxval=10., seed=seed)
+ for shape in self.param_shapes]
+
+ def objective(self, params, data=None, labels=None):
+ x, y = tf.split(params[0], 2, axis=0)
+ obj = (x + 2 * y - 7) ** 2 + (2 * x + y - 5) ** 2
+ return tf.squeeze(obj)
+
+
+class StyblinskiTang(Problem2D):
+ """Styblinski-Tang function (a bumpy function in two dimensions)."""
+
+ def init_tensors(self, seed=None):
+ """Returns a list of tensors with the given shape."""
+ return [tf.random_uniform(shape, minval=-5., maxval=5., seed=seed)
+ for shape in self.param_shapes]
+
+ def objective(self, params, data=None, labels=None):
+ params = tf.split(params[0], 2, axis=0)
+ obj = 0.5 * tf.reduce_sum([x ** 4 - 16 * x ** 2 + 5 * x
+ for x in params], 0) + 80.
+ return tf.squeeze(obj)
+
+
+class Matyas(Problem2D):
+ """Matyas function (a function with a single global minimum in a valley)."""
+
+ def init_tensors(self, seed=None):
+ """Returns a list of tensors with the given shape."""
+ return [tf.random_uniform(shape, minval=-10, maxval=10, seed=seed)
+ for shape in self.param_shapes]
+
+ def objective(self, params, data=None, labels=None):
+ x, y = tf.split(params[0], 2, axis=0)
+ obj = 0.26 * (x ** 2 + y ** 2) - 0.48 * x * y
+ return tf.squeeze(obj)
+
+
+class Branin(Problem2D):
+ """Branin function (a function with three global minima)."""
+
+ def init_tensors(self, seed=None):
+ """Returns a list of tensors with the given shape."""
+ x1 = tf.random_uniform((1,), minval=-5., maxval=10.,
+ seed=seed)
+ x2 = tf.random_uniform((1,), minval=0., maxval=15.,
+ seed=seed)
+ return [tf.concat([x1, x2], 0)]
+
+ def objective(self, params, data=None, labels=None):
+ x, y = tf.split(params[0], 2, axis=0)
+
+ # Define some constants.
+ a = 1.
+ b = 5.1 / (4. * np.pi ** 2)
+ c = 5 / np.pi
+ r = 6.
+ s = 10.
+ t = 1 / (8. * np.pi)
+
+ # Evaluate the function.
+ obj = a * (y - b * x ** 2 + c * x - r) ** 2 + s * (1 - t) * tf.cos(x) + s
+ return tf.squeeze(obj)
+
+
+class Michalewicz(Problem2D):
+ """Michalewicz function (has steep ridges and valleys)."""
+
+ def init_tensors(self, seed=None):
+ """Returns a list of tensors with the given shape."""
+ return [tf.random_uniform(shape, minval=0., maxval=np.pi, seed=seed)
+ for shape in self.param_shapes]
+
+ def objective(self, params, data=None, labels=None):
+ x, y = tf.split(params[0], 2, axis=0)
+ m = 5 # Defines how steep the ridges are (larger m => steeper ridges).
+ obj = 2. - (tf.sin(x) * tf.sin(x ** 2 / np.pi) ** (2 * m) +
+ tf.sin(y) * tf.sin(2 * y ** 2 / np.pi) ** (2 * m))
+ return tf.squeeze(obj)
+
+
+class Rescale(Problem):
+ """Takes an existing problem, and rescales all the parameters."""
+
+ def __init__(self, problem_spec, scale=10., noise_stdev=0.0):
+ self.problem = problem_spec.build()
+ self.param_shapes = self.problem.param_shapes
+ self.scale = scale
+
+ super(Rescale, self).__init__(self.param_shapes, random_seed=None,
+ noise_stdev=noise_stdev)
+
+ def init_tensors(self, seed=None):
+ params_raw = self.problem.init_tensors(seed=seed)
+ params = [t * self.scale for t in params_raw]
+ return params
+
+ def objective(self, params, data=None, labels=None):
+ params_raw = [t/self.scale for t in params]
+
+ problem_obj = self.problem.objective(params_raw, data, labels)
+ return problem_obj
+
+
+class SumTask(Problem):
+ """Takes a list of problems and modifies the objective to be their sum."""
+
+ def __init__(self, problem_specs, noise_stdev=0.0):
+ self.problems = [ps.build() for ps in problem_specs]
+ self.param_shapes = []
+ for prob in self.problems:
+ self.param_shapes += prob.param_shapes
+
+ super(SumTask, self).__init__(self.param_shapes, random_seed=None,
+ noise_stdev=noise_stdev)
+
+ def init_tensors(self, seed=None):
+ tensors = []
+ for prob in self.problems:
+ tensors += prob.init_tensors(seed=seed)
+ return tensors
+
+ def objective(self, params, data=None, labels=None):
+ obj = 0.
+ index = 0
+ for prob in self.problems:
+ num_params = len(prob.param_shapes)
+ obj += prob.objective(params[index:index + num_params])
+ index += num_params
+ return obj
+
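A hedged sketch of composing problems with SumTask. It assumes the Spec wrapper from problem_spec.py (shown later in this diff) and the default init_tensors provided by the base Problem class earlier in this file:

```
from learned_optimizer.problems import problem_generator as pg
from learned_optimizer.problems.problem_spec import Spec

task = pg.SumTask([Spec(pg.Bowl, (5.0,), {}),
                   Spec(pg.Rosenbrock, (), {})])
params = task.init_tensors(seed=0)  # one (2, 1) tensor and one (2,) tensor
loss = task.objective(params)       # Bowl objective + Rosenbrock objective
```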
+
+class IsotropicQuadratic(Problem):
+ """An isotropic quadratic problem."""
+
+ def objective(self, params, data=None, labels=None):
+ return sum([tf.reduce_sum(param ** 2) for param in params])
+
+
+class Norm(Problem):
+ """Takes an existing problem and modifies the objective to be its N-norm."""
+
+ def __init__(self, ndim, random_seed=None, noise_stdev=0.0, norm_power=2.):
+ param_shapes = [(ndim, 1)]
+ super(Norm, self).__init__(param_shapes, random_seed, noise_stdev)
+
+ # Generate a random problem instance.
+ self.w = np.random.randn(ndim, ndim).astype("float32")
+ self.y = np.random.randn(ndim, 1).astype("float32")
+ self.norm_power = norm_power
+
+ def objective(self, params, data=None, labels=None):
+ diff = tf.matmul(self.w, params[0]) - self.y
+ exp = 1. / self.norm_power
+ loss = tf.reduce_sum((tf.abs(diff) + EPSILON) ** self.norm_power) ** exp
+ return loss
+
+
+class LogObjective(Problem):
+ """Takes an existing problem and modifies the objective to be its log."""
+
+ def __init__(self, problem_spec):
+ self.problem = problem_spec.build()
+ self.param_shapes = self.problem.param_shapes
+
+ super(LogObjective, self).__init__(self.param_shapes,
+ random_seed=None,
+ noise_stdev=0.0)
+
+ def objective(self, params, data=None, labels=None):
+ problem_obj = self.problem.objective(params, data, labels)
+ return tf.log(problem_obj + EPSILON) - tf.log(EPSILON)
+
+
+class SparseProblem(Problem):
+ """Takes a problem and sets gradients to 0 with the given probability."""
+
+ def __init__(self,
+ problem_spec,
+ zero_probability=0.99,
+ random_seed=None,
+ noise_stdev=0.0):
+ self.problem = problem_spec.build()
+ self.param_shapes = self.problem.param_shapes
+ self.zero_prob = zero_probability
+
+ super(SparseProblem, self).__init__(self.param_shapes,
+ random_seed=random_seed,
+ noise_stdev=noise_stdev)
+
+ def objective(self, parameters, data=None, labels=None):
+ return self.problem.objective(parameters, data, labels)
+
+ def gradients(self, objective, parameters):
+ grads = tf.gradients(objective, list(parameters))
+
+ new_grads = []
+ for grad in grads:
+ mask = tf.greater(self.zero_prob, tf.random_uniform(grad.get_shape()))
+ zero_grad = tf.zeros_like(grad, dtype=tf.float32)
+ noisy_grad = grad + self.noise_stdev * tf.random_normal(grad.get_shape())
+ new_grads.append(tf.where(mask, zero_grad, noisy_grad))
+ return new_grads
+
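A NumPy-only illustration (not part of the original file) of the masking performed in gradients() above: each gradient entry is zeroed with probability zero_prob, and the surviving entries optionally receive Gaussian noise:

```
import numpy as np

zero_prob, noise_stdev = 0.99, 0.0
grad = np.random.randn(5, 3)

mask = zero_prob > np.random.uniform(size=grad.shape)          # True => zero
noisy_grad = grad + noise_stdev * np.random.randn(*grad.shape)
sparse_grad = np.where(mask, 0.0, noisy_grad)
print((sparse_grad == 0).mean())  # roughly zero_prob
```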
+
+class DependencyChain(Problem):
+ """A problem in which parameters must be optimized in order.
+
+ A sequence of parameters which all need to be brought to 0, but where each
+ parameter in the sequence can't be brought to 0 until the preceding one
+ has been. This should take a long time to optimize, with steady
+ (or accelerating) progress throughout the entire process.
+ """
+
+ def __init__(self, ndim, random_seed=None, noise_stdev=0.):
+ param_shapes = [(ndim + 1,)]
+ self.ndim = ndim
+ super(DependencyChain, self).__init__(
+ param_shapes, random_seed, noise_stdev)
+
+ def objective(self, params, data=None, labels=None):
+ terms = params[0][0]**2 + params[0][1:]**2 / (params[0][:-1]**2 + EPSILON)
+ return tf.reduce_sum(terms)
+
+
+class MinMaxWell(Problem):
+ """Problem with global min when both the min and max (absolute) params are 1.
+
+ The gradient for all but two parameters (the min and max) is zero. This
+  should therefore encourage the optimizer to behave sensibly even when
+  parameters have zero gradients, as is common e.g. for some deep neural nets.
+ """
+
+ def __init__(self, ndim, random_seed=None, noise_stdev=0.):
+ param_shapes = [(ndim,)]
+ self.ndim = ndim
+ super(MinMaxWell, self).__init__(param_shapes, random_seed, noise_stdev)
+
+ def objective(self, params, data=None, labels=None):
+ params_sqr = params[0]**2
+ min_sqr = tf.reduce_min(params_sqr)
+ max_sqr = tf.reduce_max(params_sqr)
+ epsilon = 1e-12
+
+ return max_sqr + 1./min_sqr - 2. + epsilon
+
+
+class OutwardSnake(Problem):
+ """A winding path out to infinity.
+
+ Ideal step length stays constant along the entire path.
+ """
+
+ def __init__(self, ndim, random_seed=None, noise_stdev=0.):
+ param_shapes = [(ndim,)]
+ self.ndim = ndim
+ super(OutwardSnake, self).__init__(param_shapes, random_seed, noise_stdev)
+
+ def objective(self, params, data, labels=None):
+ radius = tf.sqrt(tf.reduce_sum(params[0]**2))
+ rad_loss = tf.reduce_sum(1. / (radius + 1e-6) * data[:, 0])
+
+ sin_dist = params[0][1:] - tf.cos(params[0][:-1]) * np.pi
+ sin_loss = tf.reduce_sum((sin_dist * data[:, 1:])**2)
+
+ return rad_loss + sin_loss
+
+
+class ProjectionQuadratic(Problem):
+ """Dataset consists of different directions to probe. Global min is at 0."""
+
+ def __init__(self, ndim, random_seed=None, noise_stdev=0.):
+ param_shapes = [(1, ndim)]
+ super(ProjectionQuadratic, self).__init__(
+ param_shapes, random_seed, noise_stdev)
+
+ def objective(self, params, data, labels=None):
+ return tf.reduce_sum((params[0] * data)**2)
+
+
+class SumOfQuadratics(Problem):
+  """Sum of quadratics centered at the data points (global minimum at 0)."""
+
+ def __init__(self, ndim, random_seed=None, noise_stdev=0.):
+ param_shapes = [(1, ndim)]
+ super(SumOfQuadratics, self).__init__(
+ param_shapes, random_seed, noise_stdev)
+
+ def objective(self, params, data, labels=None):
+ epsilon = 1e-12
+ # Assume dataset is designed so that the global minimum is at params=0.
+ # Subtract loss at params=0, so that global minimum has objective value
+ # epsilon (added to avoid floating point issues).
+ return (tf.reduce_sum((params[0] - data)**2) - tf.reduce_sum(data**2) +
+ epsilon)
+
+
+class MatMulAlgorithm(Problem):
+ """A 6-th order polynomial optimization problem.
+
+ This problem is parametrized by n and k. A solution to this problem with
+ objective value exactly zero defines a matrix multiplication algorithm of
+ n x n matrices using k multiplications between matrices. When applied
+ recursively, such an algorithm has complexity O(n^(log_n(k))).
+
+ Given n, it is not known in general which values of k in [n^2, n^3] have a
+ solution. There is always a solution with k = n^3 (this is the naive
+ algorithm).
+
+ In the special case n = 2, it is known that there are solutions for k = {7, 8}
+ but not for k <= 6. For n = 3, it is known that there are exact solutions for
+ 23 <= k <= 27, and there are asymptotic solutions for k = {21, 22}, but the
+ other cases are unknown.
+
+ For a given n and k, if one solution exists then infinitely many solutions
+ exist due to permutation and scaling symmetries in the parameters.
+
+ This is a very hard problem for some values of n and k (e.g. n = 3, k = 21),
+ but very easy for other values (e.g. n = 2, k = 7).
+
+ For a given n and k, the specific formulation of this problem is as follows.
+ Let theta_a, theta_b, theta_c be parameter matrices with respective dimensions
+ [n**2, k], [n**2, k], [k, n**2]. Then for any matrices a, b with shape [n, n],
+ we can form the matrix c with shape [n, n] via the operation:
+ ((vec(a) * theta_a) .* (vec(b) * theta_b)) * theta_c = vec(c), (#)
+ where vec(x) is the operator that flattens a matrix with shape [n, n] into a
+ row vector with shape [1, n**2], * denotes matrix multiplication and .*
+ denotes elementwise multiplication.
+
+ This operation, parameterized by theta_a, theta_b, theta_c, is a matrix
+  multiplication algorithm iff c = a*b for all [n, n] matrices a and b. In
+  fact it suffices to verify all combinations of one-hot matrices a and b,
+  of which there are n**4. This gives a batch of n**4 matrix
+ triplets (a, b, c) such that equation (#) must hold for each triplet. We solve
+ for theta_a, theta_b, theta_c by minimizing the sum of squares of errors
+ across this batch.
+
+ Finally, theta_c can be computed from theta_a and theta_b. Therefore it
+ suffices to learn theta_a and theta_b, from which theta_c and therefore the
+ objective value can be computed.
+ """
+
+ def __init__(self, n, k):
+ assert isinstance(n, int), "n must be an integer"
+ assert isinstance(k, int), "k must be an integer"
+ assert n >= 2, "Must have n >= 2"
+ assert k >= n**2 and k <= n**3, "Must have n**2 <= k <= n**3"
+
+ param_shapes = [(n**2, k), (n**2, k)] # theta_a, theta_b
+ super(MatMulAlgorithm, self).__init__(
+ param_shapes, random_seed=None, noise_stdev=0.0)
+
+ self.n = n
+ self.k = k
+
+ # Build a batch of all combinations of one-hot matrices a, b, and their
+ # respective products c. Correctness on this batch is a necessary and
+ # sufficient condition for the algorithm to be valid. The number of matrices
+ # in {a, b, c}_3d is n**4 and each matrix is n x n.
+ onehots = np.identity(n**2).reshape(n**2, n, n)
+ a_3d = np.repeat(onehots, n**2, axis=0)
+ b_3d = np.tile(onehots, [n**2, 1, 1])
+ c_3d = np.matmul(a_3d, b_3d)
+
+ # Convert the batch to 2D Tensors.
+ self.a = tf.constant(a_3d.reshape(n**4, n**2), tf.float32, name="a")
+ self.b = tf.constant(b_3d.reshape(n**4, n**2), tf.float32, name="b")
+ self.c = tf.constant(c_3d.reshape(n**4, n**2), tf.float32, name="c")
+
+ def init_tensors(self, seed=None):
+ # Initialize params such that the columns of theta_a and theta_b have L2
+ # norm 1.
+ def _param_initializer(shape, seed=None):
+ x = tf.random_normal(shape, dtype=tf.float32, seed=seed)
+ return tf.transpose(tf.nn.l2_normalize(tf.transpose(x), 1))
+
+ return [_param_initializer(shape, seed) for shape in self.param_shapes]
+
+ def objective(self, parameters, data=None, labels=None):
+ theta_a = parameters[0]
+ theta_b = parameters[1]
+
+ # Compute theta_c from theta_a and theta_b.
+ p = tf.matmul(self.a, theta_a) * tf.matmul(self.b, theta_b)
+ p_trans = tf.transpose(p, name="p_trans")
+ p_inv = tf.matmul(
+ tf.matrix_inverse(tf.matmul(p_trans, p)), p_trans, name="p_inv")
+ theta_c = tf.matmul(p_inv, self.c, name="theta_c")
+
+ # Compute the "predicted" value of c.
+ c_hat = tf.matmul(p, theta_c, name="c_hat")
+
+ # Compute the loss (sum of squared errors).
+ loss = tf.reduce_sum((c_hat - self.c)**2, name="loss")
+
+ return loss
+
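As a concrete instance of equation (#) above, the NumPy sketch below (illustrative, not part of the original file) encodes Strassen's algorithm for n = 2, k = 7 as theta_a, theta_b, theta_c and checks that it reproduces c = a*b, i.e. that it would achieve objective value exactly zero:

```
import numpy as np

# Rows index vec(a) = [a11, a12, a21, a22]; columns index the k = 7 products.
theta_a = np.array([[1, 0, 1, 0, 1, -1, 0],
                    [0, 0, 0, 0, 1, 0, 1],
                    [0, 1, 0, 0, 0, 1, 0],
                    [1, 1, 0, 1, 0, 0, -1]], dtype=float)
theta_b = np.array([[1, 1, 0, -1, 0, 1, 0],
                    [0, 0, 1, 0, 0, 1, 0],
                    [0, 0, 0, 1, 0, 0, 1],
                    [1, 0, -1, 0, 1, 0, 1]], dtype=float)
# Rows index the 7 products; columns index vec(c) = [c11, c12, c21, c22].
theta_c = np.array([[1, 0, 0, 1],
                    [0, 0, 1, -1],
                    [0, 1, 0, 1],
                    [1, 0, 1, 0],
                    [-1, 1, 0, 0],
                    [0, 0, 0, 1],
                    [1, 0, 0, 0]], dtype=float)

a = np.random.randn(2, 2)
b = np.random.randn(2, 2)
products = a.reshape(1, 4).dot(theta_a) * b.reshape(1, 4).dot(theta_b)
c = products.dot(theta_c).reshape(2, 2)
assert np.allclose(c, a.dot(b))  # equation (#) holds, so the loss is zero
```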
+
+def matmul_problem_sequence(n, k_min, k_max):
+ """Helper to generate a sequence of matrix multiplication problems."""
+ return [(_Spec(MatMulAlgorithm, (n, k), {}), None, None)
+ for k in range(k_min, k_max + 1)]
+
+
+def init_fixed_variables(arrays):
+ with tf.variable_scope(PARAMETER_SCOPE):
+ params = [tf.Variable(arr.astype("float32")) for arr in arrays]
+ return params
+
+
+def _mesh(xlim, ylim, n):
+ """Creates a 2D meshgrid covering the given ranges.
+
+ Args:
+ xlim: int that defines the desired x-range (-xlim, xlim)
+ ylim: int that defines the desired y-range (-ylim, ylim)
+ n: number of points in each dimension of the mesh
+
+ Returns:
+ xm: 2D array of x-values in the mesh
+ ym: 2D array of y-values in the mesh
+ """
+ return np.meshgrid(np.linspace(-xlim, xlim, n),
+ np.linspace(-ylim, ylim, n))
diff --git a/models/research/learned_optimizer/problems/problem_sets.py b/models/research/learned_optimizer/problems/problem_sets.py
new file mode 100644
index 0000000000000000000000000000000000000000..eaf9273b87ef69c6b3087330bdf46c8de7107a15
--- /dev/null
+++ b/models/research/learned_optimizer/problems/problem_sets.py
@@ -0,0 +1,561 @@
+# Copyright 2017 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Groups of problems of different types for optimizer training."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from learned_optimizer.problems import datasets
+from learned_optimizer.problems import model_adapter
+from learned_optimizer.problems import problem_generator as pg
+from learned_optimizer.problems import problem_spec
+
+_Spec = problem_spec.Spec
+
+
+def quadratic_problems():
+ return [
+ (_Spec(pg.Quadratic, (20,), {}), None, None),
+ (_Spec(pg.Quadratic, (25,), {}), None, None),
+ (_Spec(pg.Quadratic, (50,), {}), None, None),
+ (_Spec(pg.Quadratic, (100,), {}), None, None),
+ ]
+
+
+# Note: this group contains one non-noisy problem for historical reasons. The
+# original training set before the refactor included this set of quadratics.
+def quadratic_problems_noisy():
+ return [
+ (_Spec(pg.Quadratic, (20,), {"noise_stdev": 0.5}), None, None),
+ (_Spec(pg.Quadratic, (25,), {"noise_stdev": 0.0}), None, None),
+ (_Spec(pg.Quadratic, (50,), {"noise_stdev": 1.0}), None, None),
+ (_Spec(pg.Quadratic, (100,), {"noise_stdev": 2.0}), None, None),
+ ]
+
+
+def quadratic_problems_large():
+ return [
+ (_Spec(pg.Quadratic, (784,), {}), None, None),
+ (_Spec(pg.Quadratic, (1024,), {}), None, None),
+ (_Spec(pg.Quadratic, (2048,), {}), None, None),
+ ]
+
+
+def bowl_problems():
+ return [
+ (_Spec(pg.Bowl, (0.1,), {"noise_stdev": 0.0}), None, None),
+ (_Spec(pg.Bowl, (1.0,), {"noise_stdev": 0.0}), None, None),
+ (_Spec(pg.Bowl, (5.0,), {"noise_stdev": 0.0}), None, None),
+ (_Spec(pg.Bowl, (5.0,), {"noise_stdev": 0.0, "angle": np.pi / 4.}),
+ None, None),
+ ]
+
+
+def bowl_problems_noisy():
+ return [
+ (_Spec(pg.Bowl, (0.1,), {"noise_stdev": 0.1}), None, None),
+ (_Spec(pg.Bowl, (1.0,), {"noise_stdev": 0.1}), None, None),
+ (_Spec(pg.Bowl, (5.0,), {"noise_stdev": 0.1}), None, None),
+ (_Spec(pg.Bowl, (5.0,), {"noise_stdev": 0.1, "angle": np.pi / 4.}),
+ None, None),
+ ]
+
+
+def sparse_softmax_2_class_sparse_problems():
+ return [(_Spec(pg.SparseSoftmaxRegression, (5, 2), {"noise_stdev": 0.0}),
+ datasets.noisy_parity_class(5, random_seed=123), 23),]
+
+
+def one_hot_sparse_softmax_2_class_sparse_problems():
+ return [
+ (_Spec(pg.OneHotSparseSoftmaxRegression, (5, 2), {"noise_stdev": 0.0}),
+ datasets.noisy_parity_class(5, random_seed=123), 23),
+ ]
+
+
+def softmax_2_class_problems():
+ return [
+ (_Spec(pg.SoftmaxRegression, (10, 2), {}), datasets.random(
+ 10, 1000, random_seed=123, sep=2.0), 100),
+ (_Spec(pg.SoftmaxRegression, (100, 2), {}), datasets.random(
+ 100, 1000, random_seed=123), 50),
+ (_Spec(pg.SoftmaxRegression, (200, 2), {}), datasets.random(
+ 200, 1000, random_seed=123, sep=1.5), 20),
+ (_Spec(pg.SoftmaxRegression, (256, 2), {}), datasets.random(
+ 256, 1000, random_seed=123, sep=1.5), 100),
+ ]
+
+
+def softmax_2_class_problems_noisy():
+ return [
+ (_Spec(pg.SoftmaxRegression, (10, 2), {"noise_stdev": 0.5}),
+ datasets.random(10, 1000, random_seed=123, sep=2.0), 100),
+ (_Spec(pg.SoftmaxRegression, (100, 2), {"noise_stdev": 0.1}),
+ datasets.random(100, 1000, random_seed=123), 50),
+ (_Spec(pg.SoftmaxRegression, (200, 2), {"noise_stdev": 0.1}),
+ datasets.random(200, 1000, random_seed=123, sep=1.5), 20),
+ (_Spec(pg.SoftmaxRegression, (256, 2), {"noise_stdev": 0.5}),
+ datasets.random(256, 1000, random_seed=123, sep=1.5), 100),
+ ]
+
+
+def optimization_test_problems():
+ return [
+ (_Spec(pg.Ackley, (), {}), None, None),
+ (_Spec(pg.Beale, (), {}), None, None),
+ (_Spec(pg.Booth, (), {}), None, None),
+ (_Spec(pg.Branin, (), {}), None, None),
+ (_Spec(pg.LogSumExp, (), {}), None, None),
+ (_Spec(pg.Matyas, (), {}), None, None),
+ (_Spec(pg.Michalewicz, (), {}), None, None),
+ (_Spec(pg.Rosenbrock, (), {}), None, None),
+ (_Spec(pg.StyblinskiTang, (), {}), None, None),
+ ]
+
+
+def optimization_test_problems_noisy():
+ return [
+ (_Spec(pg.Ackley, (), {"noise_stdev": 1.}), None, None),
+ (_Spec(pg.Beale, (), {"noise_stdev": 1.}), None, None),
+ (_Spec(pg.Booth, (), {"noise_stdev": 1.}), None, None),
+ (_Spec(pg.Branin, (), {"noise_stdev": 1.}), None, None),
+ (_Spec(pg.LogSumExp, (), {"noise_stdev": 1.}), None, None),
+ (_Spec(pg.Matyas, (), {"noise_stdev": 1.}), None, None),
+ (_Spec(pg.Michalewicz, (), {"noise_stdev": 1.}), None, None),
+ (_Spec(pg.Rosenbrock, (), {"noise_stdev": 1.}), None, None),
+ (_Spec(pg.StyblinskiTang, (), {"noise_stdev": 1.}), None, None),
+ ]
+
+
+def fully_connected_random_2_class_problems():
+ return [
+ (_Spec(pg.FullyConnected, (8, 2),
+ {"hidden_sizes": (8, 5,), "activation": tf.nn.sigmoid}),
+ datasets.random_mlp(8, 1000), 10),
+ (_Spec(pg.FullyConnected, (12, 2),
+ {"hidden_sizes": (8, 5, 3), "activation": tf.nn.sigmoid}),
+ datasets.random_mlp(12, 1000), 200),
+ (_Spec(pg.FullyConnected, (5, 2),
+ {"hidden_sizes": (4, 4, 4, 4,), "activation": tf.nn.sigmoid}),
+ datasets.random_mlp(5, 1000), 100),
+ (_Spec(pg.FullyConnected, (11, 2),
+ {"hidden_sizes": (4, 5, 6,), "activation": tf.nn.sigmoid}),
+ datasets.random_mlp(11, 1000), 64),
+ (_Spec(pg.FullyConnected, (9, 2),
+ {"hidden_sizes": (8,), "activation": tf.nn.sigmoid}),
+ datasets.random_mlp(9, 1000), 128),
+ (_Spec(pg.FullyConnected, (7, 2),
+ {"hidden_sizes": (8, 5,), "activation": tf.nn.sigmoid}),
+ datasets.random_mlp(7, 1000), 16),
+ (_Spec(pg.FullyConnected, (8, 2),
+ {"hidden_sizes": (32, 64,), "activation": tf.nn.sigmoid}),
+ datasets.random_mlp(8, 1000), 10),
+ (_Spec(pg.FullyConnected, (12, 2),
+ {"hidden_sizes": (16, 8, 3), "activation": tf.nn.sigmoid}),
+ datasets.random_mlp(12, 1000), 200),
+ (_Spec(pg.FullyConnected, (5, 2),
+ {"hidden_sizes": (8, 8, 8, 8,), "activation": tf.nn.sigmoid}),
+ datasets.random_mlp(5, 1000), 100),
+ (_Spec(pg.FullyConnected, (11, 2),
+ {"hidden_sizes": (10, 12, 12,), "activation": tf.nn.sigmoid}),
+ datasets.random_mlp(11, 1000), 64),
+ (_Spec(pg.FullyConnected, (9, 2),
+ {"hidden_sizes": (32,), "activation": tf.nn.sigmoid}),
+ datasets.random_mlp(9, 1000), 128),
+ (_Spec(pg.FullyConnected, (7, 2),
+ {"hidden_sizes": (32, 64,), "activation": tf.nn.sigmoid}),
+ datasets.random_mlp(7, 1000), 16),
+ ]
+
+
+def matmul_problems():
+ return sum([
+ pg.matmul_problem_sequence(2, 5, 8),
+ pg.matmul_problem_sequence(3, 19, 24)], [])
+
+
+def log_objective_problems():
+ return [
+ (_Spec(pg.LogObjective, [_Spec(pg.Quadratic, (20,), {})], {}),
+ None, None),
+ (_Spec(pg.LogObjective, [_Spec(pg.Quadratic, (50,), {})], {}),
+ None, None),
+ (_Spec(pg.LogObjective, [_Spec(pg.Quadratic, (100,), {})], {}),
+ None, None),
+ (_Spec(pg.LogObjective, [_Spec(pg.Bowl, (0.1,), {})], {}), None, None),
+ (_Spec(pg.LogObjective, [_Spec(pg.Bowl, (1.0,), {})], {}), None, None),
+ (_Spec(pg.LogObjective, [_Spec(pg.Bowl, (5.0,), {})], {}), None, None),
+ ]
+
+
+def sparse_gradient_problems():
+ return [
+ (_Spec(pg.SparseProblem, [_Spec(pg.Quadratic, (20,), {})], {}),
+ None, None),
+ (_Spec(pg.SparseProblem, [_Spec(pg.Quadratic, (50,), {})], {}),
+ None, None),
+ (_Spec(pg.SparseProblem, [_Spec(pg.Quadratic, (100,), {})], {}),
+ None, None),
+ (_Spec(pg.SparseProblem, [_Spec(pg.Bowl, (0.1,), {})], {}), None, None),
+ (_Spec(pg.SparseProblem, [_Spec(pg.Bowl, (1.0,), {})], {}), None, None),
+ (_Spec(pg.SparseProblem, [_Spec(pg.Bowl, (5.0,), {})], {}), None, None),
+ ]
+
+
+def sparse_gradient_problems_mlp():
+ return [
+ (_Spec(pg.SparseProblem, [
+ _Spec(pg.FullyConnected, (8, 2), {
+ "hidden_sizes": (8, 5,),
+ "activation": tf.nn.sigmoid
+ })
+ ], {}), datasets.random_mlp(8, 1000), 10),
+ (_Spec(pg.SparseProblem, [
+ _Spec(pg.FullyConnected, (12, 2), {
+ "hidden_sizes": (8, 5, 3),
+ "activation": tf.nn.sigmoid
+ })
+ ], {}), datasets.random_mlp(12, 1000), 200),
+ (_Spec(pg.SparseProblem, [
+ _Spec(pg.FullyConnected, (5, 2), {
+ "hidden_sizes": (4, 4, 4, 4,),
+ "activation": tf.nn.sigmoid
+ })
+ ], {}), datasets.random_mlp(5, 1000), 100),
+ ]
+
+
+def rescale_problems():
+ return [
+ (_Spec(pg.Rescale, [_Spec(pg.Norm, (18,), {"norm_power": 2.5})],
+ {"scale": 0.123}), None, None),
+ (_Spec(pg.Rescale, [_Spec(pg.Norm, (18,), {"norm_power": 1.5})],
+ {"scale": 8}), None, None),
+ (_Spec(pg.Rescale, [_Spec(pg.Norm, (18,), {"norm_power": 2.})],
+ {"scale": 50}), None, None),
+ (_Spec(pg.Rescale, [_Spec(pg.Norm, (18,), {"norm_power": 3.})],
+ {"scale": 200}), None, None),
+ (_Spec(pg.Rescale, [_Spec(pg.Norm, (18,), {"norm_power": 1.})],
+ {"scale": 1000}), None, None),
+ (_Spec(pg.Rescale, [_Spec(pg.Quadratic, (20,), {})], {"scale": 0.1}),
+ None, None),
+ (_Spec(pg.Rescale, [_Spec(pg.Quadratic, (25,), {})], {"scale": 10.}),
+ None, None),
+ (_Spec(pg.Rescale, [_Spec(pg.Quadratic, (50,), {})], {"scale": 350.}),
+ None, None),
+ (_Spec(pg.Rescale, [_Spec(pg.Quadratic, (100,), {})], {"scale": 132}),
+ None, None),
+ ]
+
+
+def norm_problems():
+ return [
+      # Norm powers < 1 cause NaN gradients early in training.
+ (_Spec(pg.Norm, (27,), {"norm_power": 1.}), None, None),
+ (_Spec(pg.Norm, (25,), {"norm_power": 2.}), None, None),
+ (_Spec(pg.Norm, (22,), {"norm_power": 3.}), None, None),
+ ]
+
+
+def norm_problems_noisy():
+ return [
+      # Norm powers < 1 cause NaN gradients early in training.
+ (_Spec(pg.Norm, (19,), {"noise_stdev": .1, "norm_power": 1.}),
+ None, None),
+ (_Spec(pg.Norm, (26,), {"noise_stdev": .1, "norm_power": 2.}),
+ None, None),
+ (_Spec(pg.Norm, (23,), {"noise_stdev": .1, "norm_power": 3.}),
+ None, None),
+ ]
+
+
+def sum_problems():
+ return [
+ (_Spec(pg.SumTask, [[
+ _Spec(pg.Quadratic, (11,), {}),
+ _Spec(pg.Quadratic, (3,), {}),
+ _Spec(pg.Quadratic, (9,), {}),
+ _Spec(pg.Quadratic, (7,), {}),
+ _Spec(pg.Quadratic, (5,), {}),
+ _Spec(pg.Quadratic, (13,), {}),
+ _Spec(pg.Quadratic, (12,), {})
+ ]], {}), None, None),
+ (_Spec(pg.SumTask, [[
+ _Spec(pg.Norm, (18,), {"norm_power": 3}),
+ _Spec(pg.Quadratic, (25,), {}),
+ _Spec(pg.Rosenbrock, (), {})
+ ]], {}), None, None),
+ (_Spec(pg.SumTask, [[
+ _Spec(pg.Rosenbrock, (), {}),
+ _Spec(pg.LogSumExp, (), {}),
+ _Spec(pg.Ackley, (), {}),
+ _Spec(pg.Beale, (), {}),
+ _Spec(pg.Booth, (), {}),
+ _Spec(pg.StyblinskiTang, (), {}),
+ _Spec(pg.Matyas, (), {}),
+ _Spec(pg.Branin, (), {}),
+ _Spec(pg.Michalewicz, (), {})
+ ]], {}), None, None),
+ (_Spec(pg.SumTask, [[
+ _Spec(pg.Rosenbrock, (), {}),
+ _Spec(pg.LogSumExp, (), {}),
+ _Spec(pg.Ackley, (), {}),
+ _Spec(pg.Beale, (), {}),
+ _Spec(pg.Booth, (), {}),
+ _Spec(pg.StyblinskiTang, (), {}),
+ _Spec(pg.Matyas, (), {}),
+ _Spec(pg.Branin, (), {}),
+ _Spec(pg.Michalewicz, (), {}),
+ _Spec(pg.Quadratic, (5,), {}),
+ _Spec(pg.Quadratic, (13,), {})
+ ]], {}), None, None),
+ (_Spec(pg.SumTask, [[
+ _Spec(pg.Quadratic, (11,), {}),
+ _Spec(pg.Quadratic, (3,), {})
+ ]], {}), None, None),
+ (_Spec(pg.SumTask, [[
+ _Spec(pg.Rosenbrock, (), {}),
+ _Spec(pg.LogSumExp, (), {}),
+ _Spec(pg.Ackley, (), {})
+ ]], {}), None, None),
+ ]
+
+
+def sum_problems_noisy():
+ return [
+ (_Spec(pg.SumTask, [[
+ _Spec(pg.Quadratic, (11,), {"noise_stdev": 0.1}),
+ _Spec(pg.Quadratic, (3,), {"noise_stdev": 0.1}),
+ _Spec(pg.Quadratic, (9,), {"noise_stdev": 0.1}),
+ _Spec(pg.Quadratic, (7,), {"noise_stdev": 0.1}),
+ _Spec(pg.Quadratic, (5,), {"noise_stdev": 0.1}),
+ _Spec(pg.Quadratic, (13,), {"noise_stdev": 0.1}),
+ _Spec(pg.Quadratic, (12,), {"noise_stdev": 0.1})
+ ]], {}), None, None),
+ (_Spec(pg.SumTask, [[
+ _Spec(pg.Rosenbrock, (), {}),
+ _Spec(pg.LogSumExp, (), {}),
+ _Spec(pg.Ackley, (), {}),
+ _Spec(pg.Beale, (), {}),
+ _Spec(pg.Booth, (), {}),
+ _Spec(pg.StyblinskiTang, (), {}),
+ _Spec(pg.Matyas, (), {}),
+ _Spec(pg.Branin, (), {}),
+ _Spec(pg.Michalewicz, (), {}),
+ _Spec(pg.Quadratic, (5,), {}),
+ _Spec(pg.Quadratic, (13,), {"noise_stdev": 0.5})
+ ]], {}), None, None),
+ ]
+
+
+def dependency_chain_problems():
+ return [
+ (_Spec(pg.DependencyChain, (20,), {}), datasets.random_binary(
+ 20, 1000), 100),
+ (_Spec(pg.DependencyChain, (12,), {}), datasets.random_binary(
+ 12, 200), 10),
+ (_Spec(pg.DependencyChain, (56,), {}), datasets.random_binary(
+ 56, 5000), 100),
+ (_Spec(pg.DependencyChain, (64,), {}), datasets.random_binary(
+ 64, 1000), 50),
+ (_Spec(pg.DependencyChain, (13,), {}), datasets.random_binary(
+ 13, 10000), 50),
+ (_Spec(pg.DependencyChain, (20,), {}), datasets.random_binary(
+ 20, 1000), 128),
+ (_Spec(pg.DependencyChain, (12,), {}), datasets.random_binary(
+ 12, 300), 16),
+ (_Spec(pg.DependencyChain, (56,), {}), datasets.random_binary(
+ 56, 5000), 128),
+ (_Spec(pg.DependencyChain, (64,), {}), datasets.random_binary(
+ 64, 1000), 64),
+ (_Spec(pg.DependencyChain, (13,), {}), datasets.random_binary(
+ 13, 10000), 32),
+ ]
+
+
+def outward_snake_problems():
+ return [
+ (_Spec(pg.OutwardSnake, (20,), {}), datasets.random_binary(
+ 20, 1000), 100),
+ (_Spec(pg.OutwardSnake, (12,), {}), datasets.random_binary(
+ 12, 200), 10),
+ (_Spec(pg.OutwardSnake, (56,), {}), datasets.random_binary(
+ 56, 5000), 100),
+ (_Spec(pg.OutwardSnake, (64,), {}), datasets.random_binary(
+ 64, 1000), 50),
+ (_Spec(pg.OutwardSnake, (13,), {}), datasets.random_binary(
+ 13, 10000), 50),
+ (_Spec(pg.OutwardSnake, (20,), {}), datasets.random_binary(
+ 20, 1000), 128),
+ (_Spec(pg.OutwardSnake, (12,), {}), datasets.random_binary(
+ 12, 300), 16),
+ (_Spec(pg.OutwardSnake, (56,), {}), datasets.random_binary(
+ 56, 5000), 128),
+ (_Spec(pg.OutwardSnake, (64,), {}), datasets.random_binary(
+ 64, 1000), 64),
+ (_Spec(pg.OutwardSnake, (13,), {}), datasets.random_binary(
+ 13, 10000), 32),
+ ]
+
+
+def min_max_well_problems():
+ return [
+ (_Spec(pg.MinMaxWell, (20,), {}), None, None),
+ (_Spec(pg.MinMaxWell, (12,), {}), None, None),
+ (_Spec(pg.MinMaxWell, (56,), {}), None, None),
+ (_Spec(pg.MinMaxWell, (64,), {}), None, None),
+ (_Spec(pg.MinMaxWell, (13,), {}), None, None),
+ ]
+
+
+def sum_of_quadratics_problems():
+ return [
+ (_Spec(pg.SumOfQuadratics, (20,), {}),
+ datasets.random_symmetric(20, 1000), 100),
+ (_Spec(pg.SumOfQuadratics, (12,), {}),
+ datasets.random_symmetric(12, 100), 10),
+ (_Spec(pg.SumOfQuadratics, (56,), {}),
+ datasets.random_symmetric(56, 5000), 100),
+ (_Spec(pg.SumOfQuadratics, (64,), {}),
+ datasets.random_symmetric(64, 1000), 50),
+ (_Spec(pg.SumOfQuadratics, (13,), {}),
+ datasets.random_symmetric(13, 10000), 50),
+ (_Spec(pg.SumOfQuadratics, (20,), {}),
+ datasets.random_symmetric(20, 1000), 128),
+ (_Spec(pg.SumOfQuadratics, (12,), {}),
+ datasets.random_symmetric(12, 100), 16),
+ (_Spec(pg.SumOfQuadratics, (56,), {}),
+ datasets.random_symmetric(56, 5000), 128),
+ (_Spec(pg.SumOfQuadratics, (64,), {}),
+ datasets.random_symmetric(64, 1000), 64),
+ (_Spec(pg.SumOfQuadratics, (13,), {}),
+ datasets.random_symmetric(13, 10000), 32),
+ ]
+
+
+def projection_quadratic_problems():
+ return [
+ (_Spec(pg.ProjectionQuadratic, (20,), {}),
+ datasets.random_symmetric(20, 1000), 100),
+ (_Spec(pg.ProjectionQuadratic, (12,), {}),
+ datasets.random_symmetric(12, 100), 10),
+ (_Spec(pg.ProjectionQuadratic, (56,), {}),
+ datasets.random_symmetric(56, 5000), 100),
+ (_Spec(pg.ProjectionQuadratic, (64,), {}),
+ datasets.random_symmetric(64, 1000), 50),
+ (_Spec(pg.ProjectionQuadratic, (13,), {}),
+ datasets.random_symmetric(13, 10000), 50),
+ (_Spec(pg.ProjectionQuadratic, (20,), {}),
+ datasets.random_symmetric(20, 1000), 128),
+ (_Spec(pg.ProjectionQuadratic, (12,), {}),
+ datasets.random_symmetric(12, 100), 16),
+ (_Spec(pg.ProjectionQuadratic, (56,), {}),
+ datasets.random_symmetric(56, 5000), 128),
+ (_Spec(pg.ProjectionQuadratic, (64,), {}),
+ datasets.random_symmetric(64, 1000), 64),
+ (_Spec(pg.ProjectionQuadratic, (13,), {}),
+ datasets.random_symmetric(13, 10000), 32),
+ ]
+
+
+def adapter_rosenbrock_local():
+ return [(_Spec(model_adapter.ModelAdapter,
+ (pg.make_rosenbrock_loss_and_init,), {}), None, None),]
+
+
+def adapter_rosenbrock_worker():
+ return [(_Spec(model_adapter.ModelAdapter,
+ (pg.make_rosenbrock_loss_and_init,),
+ {"device": "/job:worker"}), None, None),]
+
+
+def _test_problem_mlp_scaled_init_small():
+ return [
+ np.random.randn(10, 32) * np.sqrt(2./10),
+ np.random.randn(32,) * 0.1,
+ np.random.randn(32, 64) * np.sqrt(2./32.),
+ np.random.randn(64,) * 0.1,
+ np.random.randn(64, 2) * np.sqrt(2./64.),
+ np.random.randn(2,) * 0.1
+ ]
+
+
+def _test_problem_mlp_scaled_init_large():
+ return [
+ np.random.randn(20, 32) * np.sqrt(2./20),
+ np.random.randn(32,) * 0.1,
+ np.random.randn(32, 64) * np.sqrt(2./32.),
+ np.random.randn(64,) * 0.1,
+ np.random.randn(64, 10) * np.sqrt(2./64.),
+ np.random.randn(10,) * 0.1
+ ]
+
+
+def _test_problem_mlp_scaled_init_mnist():
+ return [
+ np.random.randn(784, 64) * np.sqrt(2./784.),
+ np.random.randn(64,) * 0.1,
+ np.random.randn(64, 10) * np.sqrt(2./ 64.),
+ np.random.randn(10,) * 0.1,
+ ]
+
+
+# Wrap this construction in a function to avoid UnparsedFlagAccessError
+def test_problems():
+ """Test problems for visualizations."""
+ # Unlike the training problem sets, these test problems are made up of
+  # length-5 tuples. The final two items in each tuple are the name of the
+  # problem and the initialization (a random seed or explicit initial
+  # parameter values) used for testing consistency.
+ tp = [
+ (_Spec(pg.Quadratic, (20,), {"random_seed": 1234}), None, None,
+ "quad_problem", 5678),
+ (_Spec(pg.Quadratic, (20,), {"noise_stdev": 1.0, "random_seed": 1234}),
+ None, None, "quad_problem_noise", 5678),
+ (_Spec(pg.Rosenbrock, (), {"random_seed": 1234}), None, None,
+ "rosenbrock", 5678),
+ (_Spec(pg.Rosenbrock, (), {"random_seed": 1234, "noise_stdev": 1.0}),
+ None, None, "rosenbrock_noise", 5678),
+ (_Spec(pg.SoftmaxRegression, (10, 2), {}), datasets.random(
+ 10, 10000, random_seed=1234), 100, "softmax", 5678),
+ (_Spec(pg.SoftmaxRegression, (10, 2), {"noise_stdev": 1.0}),
+ datasets.random(10, 10000, random_seed=1234), 100, "softmax_noise",
+ 5678),
+ (_Spec(pg.FullyConnected, (10, 2), {}), datasets.random(
+ 10, 10000, random_seed=1234), 100, "mlp_small",
+ _test_problem_mlp_scaled_init_small()),
+ (_Spec(pg.FullyConnected, (20, 10), {}), datasets.random(
+ 20, 10000, n_classes=10, random_seed=1234), 100, "mlp_large",
+ _test_problem_mlp_scaled_init_large()),
+ (_Spec(pg.FullyConnected, (784, 10),
+ {"hidden_sizes": (64,), "activation": tf.nn.sigmoid}),
+ datasets.mnist(), 64, "mlp_mnist_sigmoid",
+ _test_problem_mlp_scaled_init_mnist()),
+ (_Spec(pg.FullyConnected, (784, 10),
+ {"hidden_sizes": (64,), "activation": tf.nn.relu}),
+ datasets.mnist(), 64, "mlp_mnist_relu",
+ _test_problem_mlp_scaled_init_mnist()),
+ (_Spec(pg.ConvNet, ((1, 28, 28), 10, [(3, 3, 8), (5, 5, 8)]),
+ {"activation": tf.nn.sigmoid}), datasets.mnist(), 64,
+ "convnet_mnist_sigmoid", None),
+ (_Spec(pg.ConvNet, ((1, 28, 28), 10, [(3, 3, 8), (5, 5, 8)]),
+ {"activation": tf.nn.relu}), datasets.mnist(), 64,
+ "convnet_mnist_relu", None),
+ ]
+ return tp
diff --git a/models/research/learned_optimizer/problems/problem_spec.py b/models/research/learned_optimizer/problems/problem_spec.py
new file mode 100644
index 0000000000000000000000000000000000000000..e30c47b277e5c8b3b8aba3b8d691a2af3a595ef6
--- /dev/null
+++ b/models/research/learned_optimizer/problems/problem_spec.py
@@ -0,0 +1,33 @@
+# Copyright 2017 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Wrapper around a training problem."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import namedtuple
+
+
+class Spec(namedtuple("Spec", "callable args kwargs")):
+ """Syntactic sugar for keeping track of a function/class + args."""
+
+ # Since this is an immutable object, we don't need to reserve slots.
+ __slots__ = ()
+
+ def build(self):
+ """Returns the output of the callable."""
+ return self.callable(*self.args, **self.kwargs)
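A small usage sketch (illustrative only): Spec simply freezes a constructor call so it can be rebuilt later, which is how the problem sets in problem_sets.py use it. The Bowl arguments below are arbitrary:

```
from learned_optimizer.problems import problem_generator as pg
from learned_optimizer.problems.problem_spec import Spec

spec = Spec(pg.Bowl, (5.0,), {"angle": 0.785})
problem = spec.build()           # equivalent to pg.Bowl(5.0, angle=0.785)
print(problem.condition_number)  # 5.0
```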
diff --git a/models/research/learning_to_remember_rare_events/README.md b/models/research/learning_to_remember_rare_events/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2eeadea784d4d22efc88c56e482c5d5374c90e24
--- /dev/null
+++ b/models/research/learning_to_remember_rare_events/README.md
@@ -0,0 +1,61 @@
+
+
+
+
+---
+
+Code for the Memory Module as described
+in "Learning to Remember Rare Events" by
+Lukasz Kaiser, Ofir Nachum, Aurko Roy, and Samy Bengio
+published as a conference paper at ICLR 2017.
+
+Requirements:
+* TensorFlow (see tensorflow.org for how to install)
+* Some basic command-line utilities (git, unzip).
+
+Description:
+
+The general memory module is located in memory.py.
+Some code is provided to see the memory module in
+action on the standard Omniglot dataset.
+Download and setup the dataset using data_utils.py
+and then run the training script train.py
+(see example commands below).
+
+Note that the structure and parameters of the model
+are optimized for the data preparation as provided.
+
+Quick Start:
+
+First download and set-up Omniglot data by running
+
+```
+python data_utils.py
+```
+
+Then run the training script:
+
+```
+python train.py --memory_size=8192 \
+ --batch_size=16 --validation_length=50 \
+ --episode_width=5 --episode_length=30
+```
+
+The first validation batch may look like this (although it is noisy):
+```
+0-shot: 0.040, 1-shot: 0.404, 2-shot: 0.516, 3-shot: 0.604,
+ 4-shot: 0.656, 5-shot: 0.684
+```
+At step 500 you may see something like this:
+```
+0-shot: 0.036, 1-shot: 0.836, 2-shot: 0.900, 3-shot: 0.940,
+ 4-shot: 0.944, 5-shot: 0.916
+```
+At step 4000 you may see something like this:
+```
+0-shot: 0.044, 1-shot: 0.960, 2-shot: 1.000, 3-shot: 0.988,
+ 4-shot: 0.972, 5-shot: 0.992
+```
+
+Maintained by Ofir Nachum (ofirnachum) and
+Lukasz Kaiser (lukaszkaiser).
diff --git a/models/research/learning_to_remember_rare_events/data_utils.py b/models/research/learning_to_remember_rare_events/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..03d5dafb251d4e058a6780b447aabdcd1a84a1d4
--- /dev/null
+++ b/models/research/learning_to_remember_rare_events/data_utils.py
@@ -0,0 +1,243 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+"""Data loading and other utilities.
+
+Use this file to first copy over and pre-process the Omniglot dataset.
+Simply call
+ python data_utils.py
+"""
+
+import logging
+import os
+import subprocess
+from six.moves import cPickle as pickle
+
+import numpy as np
+from scipy.misc import imresize
+from scipy.misc import imrotate
+from scipy.ndimage import imread
+from six.moves import xrange
+import tensorflow as tf
+
+
+MAIN_DIR = ''
+REPO_LOCATION = 'https://github.com/brendenlake/omniglot.git'
+REPO_DIR = os.path.join(MAIN_DIR, 'omniglot')
+DATA_DIR = os.path.join(REPO_DIR, 'python')
+TRAIN_DIR = os.path.join(DATA_DIR, 'images_background')
+TEST_DIR = os.path.join(DATA_DIR, 'images_evaluation')
+DATA_FILE_FORMAT = os.path.join(MAIN_DIR, '%s_omni.pkl')
+
+TRAIN_ROTATIONS = True # augment training data with rotations
+TEST_ROTATIONS = False # augment testing data with rotations
+IMAGE_ORIGINAL_SIZE = 105
+IMAGE_NEW_SIZE = 28
+
+
+def get_data():
+ """Get data in form suitable for episodic training.
+
+ Returns:
+ Train and test data as dictionaries mapping
+ label to list of examples.
+ """
+ with tf.gfile.GFile(DATA_FILE_FORMAT % 'train', 'rb') as f:
+ processed_train_data = pickle.load(f)
+ with tf.gfile.GFile(DATA_FILE_FORMAT % 'test', 'rb') as f:
+ processed_test_data = pickle.load(f)
+
+ train_data = {}
+ test_data = {}
+
+ for data, processed_data in zip([train_data, test_data],
+ [processed_train_data, processed_test_data]):
+ for image, label in zip(processed_data['images'],
+ processed_data['labels']):
+ if label not in data:
+ data[label] = []
+ data[label].append(image.reshape([-1]).astype('float32'))
+
+ intersection = set(train_data.keys()) & set(test_data.keys())
+ assert not intersection, 'Train and test data intersect.'
+ ok_num_examples = [len(ll) == 20 for _, ll in train_data.items()]
+ assert all(ok_num_examples), 'Bad number of examples in train data.'
+ ok_num_examples = [len(ll) == 20 for _, ll in test_data.items()]
+ assert all(ok_num_examples), 'Bad number of examples in test data.'
+
+ logging.info('Number of labels in train data: %d.', len(train_data))
+ logging.info('Number of labels in test data: %d.', len(test_data))
+
+ return train_data, test_data
+
+
+def crawl_directory(directory, augment_with_rotations=False,
+ first_label=0):
+ """Crawls data directory and returns stuff."""
+ label_idx = first_label
+ images = []
+ labels = []
+ info = []
+
+ # traverse root directory
+ for root, _, files in os.walk(directory):
+ logging.info('Reading files from %s', root)
+ fileflag = 0
+ for file_name in files:
+ full_file_name = os.path.join(root, file_name)
+ img = imread(full_file_name, flatten=True)
+ for i, angle in enumerate([0, 90, 180, 270]):
+ if not augment_with_rotations and i > 0:
+ break
+
+ images.append(imrotate(img, angle))
+ labels.append(label_idx + i)
+ info.append(full_file_name)
+
+ fileflag = 1
+
+ if fileflag:
+ label_idx += 4 if augment_with_rotations else 1
+
+ return images, labels, info
+
+
+def resize_images(images, new_width, new_height):
+ """Resize images to new dimensions."""
+ resized_images = np.zeros([images.shape[0], new_width, new_height],
+ dtype=np.float32)
+
+ for i in range(images.shape[0]):
+ resized_images[i, :, :] = imresize(images[i, :, :],
+ [new_width, new_height],
+ interp='bilinear',
+ mode=None)
+ return resized_images
+
+
+def write_datafiles(directory, write_file,
+ resize=True, rotate=False,
+ new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE,
+ first_label=0):
+ """Load and preprocess images from a directory and write them to a file.
+
+ Args:
+ directory: Directory of alphabet sub-directories.
+ write_file: Filename to write to.
+ resize: Whether to resize the images.
+ rotate: Whether to augment the dataset with rotations.
+ new_width: New resize width.
+ new_height: New resize height.
+ first_label: Label to start with.
+
+ Returns:
+ Number of new labels created.
+ """
+
+ # these are the default sizes for Omniglot:
+ imgwidth = IMAGE_ORIGINAL_SIZE
+ imgheight = IMAGE_ORIGINAL_SIZE
+
+ logging.info('Reading the data.')
+ images, labels, info = crawl_directory(directory,
+ augment_with_rotations=rotate,
+ first_label=first_label)
+
+ images_np = np.zeros([len(images), imgwidth, imgheight], dtype=np.bool)
+ labels_np = np.zeros([len(labels)], dtype=np.uint32)
+ for i in xrange(len(images)):
+ images_np[i, :, :] = images[i]
+ labels_np[i] = labels[i]
+
+ if resize:
+ logging.info('Resizing images.')
+ resized_images = resize_images(images_np, new_width, new_height)
+
+ logging.info('Writing resized data in float32 format.')
+ data = {'images': resized_images,
+ 'labels': labels_np,
+ 'info': info}
+ with tf.gfile.GFile(write_file, 'w') as f:
+ pickle.dump(data, f)
+ else:
+ logging.info('Writing original sized data in boolean format.')
+ data = {'images': images_np,
+ 'labels': labels_np,
+ 'info': info}
+ with tf.gfile.GFile(write_file, 'w') as f:
+ pickle.dump(data, f)
+
+ return len(np.unique(labels_np))
+
+
+def maybe_download_data():
+ """Download Omniglot repo if it does not exist."""
+ if os.path.exists(REPO_DIR):
+ logging.info('It appears that Git repo already exists.')
+ else:
+ logging.info('It appears that Git repo does not exist.')
+ logging.info('Cloning now.')
+
+ subprocess.check_output('git clone %s' % REPO_LOCATION, shell=True)
+
+ if os.path.exists(TRAIN_DIR):
+ logging.info('It appears that train data has already been unzipped.')
+ else:
+ logging.info('It appears that train data has not been unzipped.')
+ logging.info('Unzipping now.')
+
+ subprocess.check_output('unzip %s.zip -d %s' % (TRAIN_DIR, DATA_DIR),
+ shell=True)
+
+ if os.path.exists(TEST_DIR):
+ logging.info('It appears that test data has already been unzipped.')
+ else:
+ logging.info('It appears that test data has not been unzipped.')
+ logging.info('Unzipping now.')
+
+ subprocess.check_output('unzip %s.zip -d %s' % (TEST_DIR, DATA_DIR),
+ shell=True)
+
+
+def preprocess_omniglot():
+ """Download and prepare raw Omniglot data.
+
+ Downloads the data from GitHub if it does not exist.
+ Then load the images, augment with rotations if desired.
+ Resize the images and write them to a pickle file.
+ """
+
+ maybe_download_data()
+
+ directory = TRAIN_DIR
+ write_file = DATA_FILE_FORMAT % 'train'
+ num_labels = write_datafiles(
+ directory, write_file, resize=True, rotate=TRAIN_ROTATIONS,
+ new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE)
+
+ directory = TEST_DIR
+ write_file = DATA_FILE_FORMAT % 'test'
+ write_datafiles(directory, write_file, resize=True, rotate=TEST_ROTATIONS,
+ new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE,
+ first_label=num_labels)
+
+
+def main(unused_argv):
+ logging.basicConfig(level=logging.INFO)
+ preprocess_omniglot()
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/learning_to_remember_rare_events/memory.py b/models/research/learning_to_remember_rare_events/memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f40ff57f9434994f08b1ad97dc23142bb23daaa
--- /dev/null
+++ b/models/research/learning_to_remember_rare_events/memory.py
@@ -0,0 +1,392 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+"""Memory module for storing "nearest neighbors".
+
+Implements a key-value memory for generalized one-shot learning
+as described in the paper
+"Learning to Remember Rare Events"
+by Lukasz Kaiser, Ofir Nachum, Aurko Roy, Samy Bengio,
+published as a conference paper at ICLR 2017.
+"""
+
+import numpy as np
+from six.moves import xrange
+import tensorflow as tf
+
+
+class Memory(object):
+ """Memory module."""
+
+ def __init__(self, key_dim, memory_size, vocab_size,
+ choose_k=256, alpha=0.1, correct_in_top=1, age_noise=8.0,
+ var_cache_device='', nn_device=''):
+ self.key_dim = key_dim
+ self.memory_size = memory_size
+ self.vocab_size = vocab_size
+ self.choose_k = min(choose_k, memory_size)
+ self.alpha = alpha
+ self.correct_in_top = correct_in_top
+ self.age_noise = age_noise
+ self.var_cache_device = var_cache_device # Variables are cached here.
+ self.nn_device = nn_device # Device to perform nearest neighbour matmul.
+
+ caching_device = var_cache_device if var_cache_device else None
+ self.update_memory = tf.constant(True) # Can be fed "false" if needed.
+ self.mem_keys = tf.get_variable(
+ 'memkeys', [self.memory_size, self.key_dim], trainable=False,
+ initializer=tf.random_uniform_initializer(-0.0, 0.0),
+ caching_device=caching_device)
+ self.mem_vals = tf.get_variable(
+ 'memvals', [self.memory_size], dtype=tf.int32, trainable=False,
+ initializer=tf.constant_initializer(0, tf.int32),
+ caching_device=caching_device)
+ self.mem_age = tf.get_variable(
+ 'memage', [self.memory_size], dtype=tf.float32, trainable=False,
+ initializer=tf.constant_initializer(0.0), caching_device=caching_device)
+ self.recent_idx = tf.get_variable(
+ 'recent_idx', [self.vocab_size], dtype=tf.int32, trainable=False,
+ initializer=tf.constant_initializer(0, tf.int32))
+
+ # variable for projecting query vector into memory key
+ self.query_proj = tf.get_variable(
+ 'memory_query_proj', [self.key_dim, self.key_dim], dtype=tf.float32,
+ initializer=tf.truncated_normal_initializer(0, 0.01),
+ caching_device=caching_device)
+
+ def get(self):
+ return self.mem_keys, self.mem_vals, self.mem_age, self.recent_idx
+
+ def set(self, k, v, a, r=None):
+ return tf.group(
+ self.mem_keys.assign(k),
+ self.mem_vals.assign(v),
+ self.mem_age.assign(a),
+ (self.recent_idx.assign(r) if r is not None else tf.group()))
+
+ def clear(self):
+ return tf.variables_initializer([self.mem_keys, self.mem_vals, self.mem_age,
+ self.recent_idx])
+
+ def get_hint_pool_idxs(self, normalized_query):
+ """Get small set of idxs to compute nearest neighbor queries on.
+
+ This is an expensive look-up on the whole memory that is used to
+ avoid more expensive operations later on.
+
+ Args:
+ normalized_query: A Tensor of shape [None, key_dim].
+
+ Returns:
+ A Tensor of shape [None, choose_k] of indices in memory
+ that are closest to the queries.
+
+ """
+ # look up in large memory, no gradients
+ with tf.device(self.nn_device):
+ similarities = tf.matmul(tf.stop_gradient(normalized_query),
+ self.mem_keys, transpose_b=True, name='nn_mmul')
+ _, hint_pool_idxs = tf.nn.top_k(
+ tf.stop_gradient(similarities), k=self.choose_k, name='nn_topk')
+ return hint_pool_idxs
+
+ def make_update_op(self, upd_idxs, upd_keys, upd_vals,
+ batch_size, use_recent_idx, intended_output):
+ """Function that creates all the update ops."""
+ mem_age_incr = self.mem_age.assign_add(tf.ones([self.memory_size],
+ dtype=tf.float32))
+ with tf.control_dependencies([mem_age_incr]):
+ mem_age_upd = tf.scatter_update(
+ self.mem_age, upd_idxs, tf.zeros([batch_size], dtype=tf.float32))
+
+ mem_key_upd = tf.scatter_update(
+ self.mem_keys, upd_idxs, upd_keys)
+ mem_val_upd = tf.scatter_update(
+ self.mem_vals, upd_idxs, upd_vals)
+
+ if use_recent_idx:
+ recent_idx_upd = tf.scatter_update(
+ self.recent_idx, intended_output, upd_idxs)
+ else:
+ recent_idx_upd = tf.group()
+
+ return tf.group(mem_age_upd, mem_key_upd, mem_val_upd, recent_idx_upd)
+
+ def query(self, query_vec, intended_output, use_recent_idx=True):
+ """Queries memory for nearest neighbor.
+
+ Args:
+ query_vec: A batch of vectors to query (embedding of input to model).
+ intended_output: The values that would be the correct output of the
+ memory.
+ use_recent_idx: Whether to always insert at least one instance of a
+ correct memory fetch.
+
+ Returns:
+ A tuple (result, mask, teacher_loss).
+ result: The result of the memory look up.
+ mask: The affinity of the query to the result.
+ teacher_loss: The loss for training the memory module.
+ """
+
+ batch_size = tf.shape(query_vec)[0]
+ output_given = intended_output is not None
+
+ # prepare query for memory lookup
+ query_vec = tf.matmul(query_vec, self.query_proj)
+ normalized_query = tf.nn.l2_normalize(query_vec, dim=1)
+
+ hint_pool_idxs = self.get_hint_pool_idxs(normalized_query)
+
+ if output_given and use_recent_idx: # add at least one correct memory
+ most_recent_hint_idx = tf.gather(self.recent_idx, intended_output)
+ hint_pool_idxs = tf.concat(
+ axis=1,
+ values=[hint_pool_idxs, tf.expand_dims(most_recent_hint_idx, 1)])
+ choose_k = tf.shape(hint_pool_idxs)[1]
+
+ with tf.device(self.var_cache_device):
+ # create small memory and look up with gradients
+ my_mem_keys = tf.stop_gradient(tf.gather(self.mem_keys, hint_pool_idxs,
+ name='my_mem_keys_gather'))
+ similarities = tf.matmul(tf.expand_dims(normalized_query, 1),
+ my_mem_keys, adjoint_b=True, name='batch_mmul')
+ hint_pool_sims = tf.squeeze(similarities, [1], name='hint_pool_sims')
+ hint_pool_mem_vals = tf.gather(self.mem_vals, hint_pool_idxs,
+ name='hint_pool_mem_vals')
+ # Calculate softmax mask on the top-k if requested.
+ # Softmax temperature. Say we have K elements at dist x and one at (x+a).
+      # Softmax of the last is e^(tm*(x+a)) / (K*e^(tm*x) + e^(tm*(x+a)))
+      #   = e^(tm*a) / (K + e^(tm*a)).
+      # To make that 20% we'd need e^(tm*a) ~= 0.2K, so tm = log(0.2K)/a.
+ softmax_temp = max(1.0, np.log(0.2 * self.choose_k) / self.alpha)
+ mask = tf.nn.softmax(hint_pool_sims[:, :choose_k - 1] * softmax_temp)
+
+ # prepare returned values
+ nearest_neighbor = tf.to_int32(
+ tf.argmax(hint_pool_sims[:, :choose_k - 1], 1))
+
+ no_teacher_idxs = tf.gather(
+ tf.reshape(hint_pool_idxs, [-1]),
+ nearest_neighbor + choose_k * tf.range(batch_size))
+
+ with tf.device(self.var_cache_device):
+ result = tf.gather(self.mem_vals, tf.reshape(no_teacher_idxs, [-1]))
+
+ if not output_given:
+ teacher_loss = None
+ return result, mask, teacher_loss
+
+ # prepare hints from the teacher on hint pool
+ teacher_hints = tf.to_float(
+ tf.abs(tf.expand_dims(intended_output, 1) - hint_pool_mem_vals))
+ teacher_hints = 1.0 - tf.minimum(1.0, teacher_hints)
+
+ teacher_vals, teacher_hint_idxs = tf.nn.top_k(
+ hint_pool_sims * teacher_hints, k=1)
+ neg_teacher_vals, _ = tf.nn.top_k(
+ hint_pool_sims * (1 - teacher_hints), k=1)
+
+ # bring back idxs to full memory
+ teacher_idxs = tf.gather(
+ tf.reshape(hint_pool_idxs, [-1]),
+ teacher_hint_idxs[:, 0] + choose_k * tf.range(batch_size))
+
+ # zero-out teacher_vals if there are no hints
+ teacher_vals *= (
+ 1 - tf.to_float(tf.equal(0.0, tf.reduce_sum(teacher_hints, 1))))
+
+ # we'll determine whether to do an update to memory based on whether
+ # memory was queried correctly
+ sliced_hints = tf.slice(teacher_hints, [0, 0], [-1, self.correct_in_top])
+ incorrect_memory_lookup = tf.equal(0.0, tf.reduce_sum(sliced_hints, 1))
+
+ # loss based on triplet loss
+ teacher_loss = (tf.nn.relu(neg_teacher_vals - teacher_vals + self.alpha)
+ - self.alpha)
+
+ # prepare memory updates
+ update_keys = normalized_query
+ update_vals = intended_output
+
+ fetched_idxs = teacher_idxs # correctly fetched from memory
+ with tf.device(self.var_cache_device):
+ fetched_keys = tf.gather(self.mem_keys, fetched_idxs, name='fetched_keys')
+ fetched_vals = tf.gather(self.mem_vals, fetched_idxs, name='fetched_vals')
+
+ # do memory updates here
+ fetched_keys_upd = update_keys + fetched_keys # Momentum-like update
+ fetched_keys_upd = tf.nn.l2_normalize(fetched_keys_upd, dim=1)
+ # Randomize age a bit, e.g., to select different ones in parallel workers.
+ mem_age_with_noise = self.mem_age + tf.random_uniform(
+ [self.memory_size], - self.age_noise, self.age_noise)
+
+ _, oldest_idxs = tf.nn.top_k(mem_age_with_noise, k=batch_size, sorted=False)
+
+ with tf.control_dependencies([result]):
+ upd_idxs = tf.where(incorrect_memory_lookup,
+ oldest_idxs,
+ fetched_idxs)
+ # upd_idxs = tf.Print(upd_idxs, [upd_idxs], "UPD IDX", summarize=8)
+ upd_keys = tf.where(incorrect_memory_lookup,
+ update_keys,
+ fetched_keys_upd)
+ upd_vals = tf.where(incorrect_memory_lookup,
+ update_vals,
+ fetched_vals)
+
+ def make_update_op():
+ return self.make_update_op(upd_idxs, upd_keys, upd_vals,
+ batch_size, use_recent_idx, intended_output)
+
+ update_op = tf.cond(self.update_memory, make_update_op, tf.no_op)
+
+ with tf.control_dependencies([update_op]):
+ result = tf.identity(result)
+ mask = tf.identity(mask)
+ teacher_loss = tf.identity(teacher_loss)
+
+ return result, mask, tf.reduce_mean(teacher_loss)
+
+
+class LSHMemory(Memory):
+ """Memory employing locality sensitive hashing.
+
+ Note: Not fully tested.
+ """
+
+ def __init__(self, key_dim, memory_size, vocab_size,
+ choose_k=256, alpha=0.1, correct_in_top=1, age_noise=8.0,
+ var_cache_device='', nn_device='',
+ num_hashes=None, num_libraries=None):
+ super(LSHMemory, self).__init__(
+ key_dim, memory_size, vocab_size,
+ choose_k=choose_k, alpha=alpha, correct_in_top=1, age_noise=age_noise,
+ var_cache_device=var_cache_device, nn_device=nn_device)
+
+ self.num_libraries = num_libraries or int(self.choose_k ** 0.5)
+ self.num_per_hash_slot = max(1, self.choose_k // self.num_libraries)
+ self.num_hashes = (num_hashes or
+ int(np.log2(self.memory_size / self.num_per_hash_slot)))
+ self.num_hashes = min(max(self.num_hashes, 1), 20)
+ self.num_hash_slots = 2 ** self.num_hashes
+
+ # hashing vectors
+ self.hash_vecs = [
+ tf.get_variable(
+ 'hash_vecs%d' % i, [self.num_hashes, self.key_dim],
+ dtype=tf.float32, trainable=False,
+ initializer=tf.truncated_normal_initializer(0, 1))
+ for i in xrange(self.num_libraries)]
+
+ # map representing which hash slots map to which mem keys
+ self.hash_slots = [
+ tf.get_variable(
+ 'hash_slots%d' % i, [self.num_hash_slots, self.num_per_hash_slot],
+ dtype=tf.int32, trainable=False,
+ initializer=tf.random_uniform_initializer(maxval=self.memory_size,
+ dtype=tf.int32))
+ for i in xrange(self.num_libraries)]
+
+ def get(self): # not implemented
+ return self.mem_keys, self.mem_vals, self.mem_age, self.recent_idx
+
+ def set(self, k, v, a, r=None): # not implemented
+ return tf.group(
+ self.mem_keys.assign(k),
+ self.mem_vals.assign(v),
+ self.mem_age.assign(a),
+ (self.recent_idx.assign(r) if r is not None else tf.group()))
+
+ def clear(self):
+ return tf.variables_initializer([self.mem_keys, self.mem_vals, self.mem_age,
+ self.recent_idx] + self.hash_slots)
+
+ def get_hash_slots(self, query):
+ """Gets hashed-to buckets for batch of queries.
+
+ Args:
+ query: 2-d Tensor of query vectors.
+
+ Returns:
+ A list of hashed-to buckets for each hash function.
+ """
+
+ binary_hash = [
+ tf.less(tf.matmul(query, self.hash_vecs[i], transpose_b=True), 0)
+ for i in xrange(self.num_libraries)]
+ hash_slot_idxs = [
+ tf.reduce_sum(
+ tf.to_int32(binary_hash[i]) *
+ tf.constant([[2 ** i for i in xrange(self.num_hashes)]],
+ dtype=tf.int32), 1)
+ for i in xrange(self.num_libraries)]
+ return hash_slot_idxs
+
+ def get_hint_pool_idxs(self, normalized_query):
+ """Get small set of idxs to compute nearest neighbor queries on.
+
+ This is an expensive look-up on the whole memory that is used to
+ avoid more expensive operations later on.
+
+ Args:
+ normalized_query: A Tensor of shape [None, key_dim].
+
+ Returns:
+ A Tensor of shape [None, choose_k] of indices in memory
+ that are closest to the queries.
+
+ """
+ # get hash of query vecs
+ hash_slot_idxs = self.get_hash_slots(normalized_query)
+
+ # grab mem idxs in the hash slots
+ hint_pool_idxs = [
+ tf.maximum(tf.minimum(
+ tf.gather(self.hash_slots[i], idxs),
+ self.memory_size - 1), 0)
+ for i, idxs in enumerate(hash_slot_idxs)]
+
+ return tf.concat(axis=1, values=hint_pool_idxs)
+
+ def make_update_op(self, upd_idxs, upd_keys, upd_vals,
+ batch_size, use_recent_idx, intended_output):
+ """Function that creates all the update ops."""
+ base_update_op = super(LSHMemory, self).make_update_op(
+ upd_idxs, upd_keys, upd_vals,
+ batch_size, use_recent_idx, intended_output)
+
+ # compute hash slots to be updated
+ hash_slot_idxs = self.get_hash_slots(upd_keys)
+
+ # make updates
+ update_ops = []
+ with tf.control_dependencies([base_update_op]):
+ for i, slot_idxs in enumerate(hash_slot_idxs):
+ # for each slot, choose which entry to replace
+ entry_idx = tf.random_uniform([batch_size],
+ maxval=self.num_per_hash_slot,
+ dtype=tf.int32)
+ entry_mul = 1 - tf.one_hot(entry_idx, self.num_per_hash_slot,
+ dtype=tf.int32)
+ entry_add = (tf.expand_dims(upd_idxs, 1) *
+ tf.one_hot(entry_idx, self.num_per_hash_slot,
+ dtype=tf.int32))
+
+ mul_op = tf.scatter_mul(self.hash_slots[i], slot_idxs, entry_mul)
+ with tf.control_dependencies([mul_op]):
+ add_op = tf.scatter_add(self.hash_slots[i], slot_idxs, entry_add)
+ update_ops.append(add_op)
+
+ return tf.group(*update_ops)
diff --git a/models/research/learning_to_remember_rare_events/model.py b/models/research/learning_to_remember_rare_events/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a6b460047fda3349c04d0e024c035f69a300461
--- /dev/null
+++ b/models/research/learning_to_remember_rare_events/model.py
@@ -0,0 +1,302 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+"""Model using memory component.
+
+The model embeds images using a standard CNN architecture.
+These embeddings are used as keys to the memory component,
+which returns nearest neighbors.
+"""
+
+import tensorflow as tf
+
+import memory
+
+FLAGS = tf.flags.FLAGS
+
+
+class BasicClassifier(object):
+
+ def __init__(self, output_dim):
+ self.output_dim = output_dim
+
+ def core_builder(self, memory_val, x, y):
+ del x, y
+ y_pred = memory_val
+ loss = 0.0
+
+ return loss, y_pred
+
+
+class LeNet(object):
+ """Standard CNN architecture."""
+
+ def __init__(self, image_size, num_channels, hidden_dim):
+ self.image_size = image_size
+ self.num_channels = num_channels
+ self.hidden_dim = hidden_dim
+ self.matrix_init = tf.truncated_normal_initializer(stddev=0.1)
+ self.vector_init = tf.constant_initializer(0.0)
+
+ def core_builder(self, x):
+ """Embeds x using standard CNN architecture.
+
+ Args:
+ x: Batch of images as a 2-d Tensor [batch_size, -1].
+
+ Returns:
+ A 2-d Tensor [batch_size, hidden_dim] of embedded images.
+ """
+
+ ch1 = 32 * 2 # number of channels in 1st layer
+ ch2 = 64 * 2 # number of channels in 2nd layer
+ conv1_weights = tf.get_variable('conv1_w',
+ [3, 3, self.num_channels, ch1],
+ initializer=self.matrix_init)
+ conv1_biases = tf.get_variable('conv1_b', [ch1],
+ initializer=self.vector_init)
+ conv1a_weights = tf.get_variable('conv1a_w',
+ [3, 3, ch1, ch1],
+ initializer=self.matrix_init)
+ conv1a_biases = tf.get_variable('conv1a_b', [ch1],
+ initializer=self.vector_init)
+
+ conv2_weights = tf.get_variable('conv2_w', [3, 3, ch1, ch2],
+ initializer=self.matrix_init)
+ conv2_biases = tf.get_variable('conv2_b', [ch2],
+ initializer=self.vector_init)
+ conv2a_weights = tf.get_variable('conv2a_w', [3, 3, ch2, ch2],
+ initializer=self.matrix_init)
+ conv2a_biases = tf.get_variable('conv2a_b', [ch2],
+ initializer=self.vector_init)
+
+ # fully connected
+ fc1_weights = tf.get_variable(
+ 'fc1_w', [self.image_size // 4 * self.image_size // 4 * ch2,
+ self.hidden_dim], initializer=self.matrix_init)
+ fc1_biases = tf.get_variable('fc1_b', [self.hidden_dim],
+ initializer=self.vector_init)
+
+ # define model
+ x = tf.reshape(x,
+ [-1, self.image_size, self.image_size, self.num_channels])
+ batch_size = tf.shape(x)[0]
+
+ conv1 = tf.nn.conv2d(x, conv1_weights,
+ strides=[1, 1, 1, 1], padding='SAME')
+ relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_biases))
+ conv1 = tf.nn.conv2d(relu1, conv1a_weights,
+ strides=[1, 1, 1, 1], padding='SAME')
+ relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1a_biases))
+
+ pool1 = tf.nn.max_pool(relu1, ksize=[1, 2, 2, 1],
+ strides=[1, 2, 2, 1], padding='SAME')
+
+ conv2 = tf.nn.conv2d(pool1, conv2_weights,
+ strides=[1, 1, 1, 1], padding='SAME')
+ relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_biases))
+ conv2 = tf.nn.conv2d(relu2, conv2a_weights,
+ strides=[1, 1, 1, 1], padding='SAME')
+ relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2a_biases))
+
+ pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1],
+ strides=[1, 2, 2, 1], padding='SAME')
+
+ reshape = tf.reshape(pool2, [batch_size, -1])
+ hidden = tf.matmul(reshape, fc1_weights) + fc1_biases
+
+ return hidden
+
+
+class Model(object):
+ """Model for coordinating between CNN embedder and Memory module."""
+
+ def __init__(self, input_dim, output_dim, rep_dim, memory_size, vocab_size,
+ learning_rate=0.0001, use_lsh=False):
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.rep_dim = rep_dim
+ self.memory_size = memory_size
+ self.vocab_size = vocab_size
+ self.learning_rate = learning_rate
+ self.use_lsh = use_lsh
+
+ self.embedder = self.get_embedder()
+ self.memory = self.get_memory()
+ self.classifier = self.get_classifier()
+
+ self.global_step = tf.train.get_or_create_global_step()
+
+ def get_embedder(self):
+ return LeNet(int(self.input_dim ** 0.5), 1, self.rep_dim)
+
+ def get_memory(self):
+ cls = memory.LSHMemory if self.use_lsh else memory.Memory
+ return cls(self.rep_dim, self.memory_size, self.vocab_size)
+
+ def get_classifier(self):
+ return BasicClassifier(self.output_dim)
+
+ def core_builder(self, x, y, keep_prob, use_recent_idx=True):
+ embeddings = self.embedder.core_builder(x)
+ if keep_prob < 1.0:
+ embeddings = tf.nn.dropout(embeddings, keep_prob)
+ memory_val, _, teacher_loss = self.memory.query(
+ embeddings, y, use_recent_idx=use_recent_idx)
+ loss, y_pred = self.classifier.core_builder(memory_val, x, y)
+
+ return loss + teacher_loss, y_pred
+
+ def train(self, x, y):
+ loss, _ = self.core_builder(x, y, keep_prob=0.3)
+ gradient_ops = self.training_ops(loss)
+ return loss, gradient_ops
+
+ def eval(self, x, y):
+ _, y_preds = self.core_builder(x, y, keep_prob=1.0,
+ use_recent_idx=False)
+ return y_preds
+
+ def get_xy_placeholders(self):
+ return (tf.placeholder(tf.float32, [None, self.input_dim]),
+ tf.placeholder(tf.int32, [None]))
+
+ def setup(self):
+ """Sets up all components of the computation graph."""
+
+ self.x, self.y = self.get_xy_placeholders()
+
+ # This context creates variables
+ with tf.variable_scope('core', reuse=None):
+ self.loss, self.gradient_ops = self.train(self.x, self.y)
+ # And this one re-uses them (thus the `reuse=True`)
+ with tf.variable_scope('core', reuse=True):
+ self.y_preds = self.eval(self.x, self.y)
+
+ def training_ops(self, loss):
+ opt = self.get_optimizer()
+ params = tf.trainable_variables()
+ gradients = tf.gradients(loss, params)
+ clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
+ return opt.apply_gradients(zip(clipped_gradients, params),
+ global_step=self.global_step)
+
+ def get_optimizer(self):
+ return tf.train.AdamOptimizer(learning_rate=self.learning_rate,
+ epsilon=1e-4)
+
+ def one_step(self, sess, x, y):
+ outputs = [self.loss, self.gradient_ops]
+ return sess.run(outputs, feed_dict={self.x: x, self.y: y})
+
+ def episode_step(self, sess, x, y, clear_memory=False):
+ """Performs training steps on episodic input.
+
+ Args:
+ sess: A Tensorflow Session.
+ x: A list of batches of images defining the episode.
+ y: A list of batches of labels corresponding to x.
+ clear_memory: Whether to clear the memory before the episode.
+
+ Returns:
+ List of losses the same length as the episode.
+ """
+
+ outputs = [self.loss, self.gradient_ops]
+
+ if clear_memory:
+ self.clear_memory(sess)
+
+ losses = []
+ for xx, yy in zip(x, y):
+ out = sess.run(outputs, feed_dict={self.x: xx, self.y: yy})
+ loss = out[0]
+ losses.append(loss)
+
+ return losses
+
+ def predict(self, sess, x, y=None):
+ """Predict the labels on a single batch of examples.
+
+ Args:
+ sess: A Tensorflow Session.
+ x: A batch of images.
+ y: The labels for the images in x.
+ This allows for updating the memory.
+
+ Returns:
+ Predicted y.
+ """
+
+ # Storing current memory state to restore it after prediction
+ mem_keys, mem_vals, mem_age, _ = self.memory.get()
+ cur_memory = (
+ tf.identity(mem_keys),
+ tf.identity(mem_vals),
+ tf.identity(mem_age),
+ None,
+ )
+
+ outputs = [self.y_preds]
+ if y is None:
+ ret = sess.run(outputs, feed_dict={self.x: x})
+ else:
+ ret = sess.run(outputs, feed_dict={self.x: x, self.y: y})
+
+ # Restoring memory state
+ self.memory.set(*cur_memory)
+
+ return ret
+
+ def episode_predict(self, sess, x, y, clear_memory=False):
+ """Predict the labels on an episode of examples.
+
+ Args:
+ sess: A Tensorflow Session.
+ x: A list of batches of images.
+ y: A list of labels for the images in x.
+ This allows for updating the memory.
+ clear_memory: Whether to clear the memory before the episode.
+
+ Returns:
+ List of predicted y.
+ """
+
+ # Storing current memory state to restore it after prediction
+ mem_keys, mem_vals, mem_age, _ = self.memory.get()
+ cur_memory = (
+ tf.identity(mem_keys),
+ tf.identity(mem_vals),
+ tf.identity(mem_age),
+ None,
+ )
+
+ if clear_memory:
+ self.clear_memory(sess)
+
+ outputs = [self.y_preds]
+ y_preds = []
+ for xx, yy in zip(x, y):
+ out = sess.run(outputs, feed_dict={self.x: xx, self.y: yy})
+ y_pred = out[0]
+ y_preds.append(y_pred)
+
+ # Restoring memory state
+ self.memory.set(*cur_memory)
+
+ return y_preds
+
+ def clear_memory(self, sess):
+ sess.run([self.memory.clear()])
diff --git a/models/research/learning_to_remember_rare_events/train.py b/models/research/learning_to_remember_rare_events/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5c6d06b5ee02e73128ee2b23f3b399d29b1e212
--- /dev/null
+++ b/models/research/learning_to_remember_rare_events/train.py
@@ -0,0 +1,242 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+r"""Script for training model.
+
+Simple command to get up and running:
+ python train.py --memory_size=8192 \
+ --batch_size=16 --validation_length=50 \
+ --episode_width=5 --episode_length=30
+"""
+
+import logging
+import os
+import random
+
+import numpy as np
+from six.moves import xrange
+import tensorflow as tf
+
+import data_utils
+import model
+
+FLAGS = tf.flags.FLAGS
+
+tf.flags.DEFINE_integer('rep_dim', 128,
+ 'dimension of keys to use in memory')
+tf.flags.DEFINE_integer('episode_length', 100, 'length of episode')
+tf.flags.DEFINE_integer('episode_width', 5,
+ 'number of distinct labels in a single episode')
+tf.flags.DEFINE_integer('memory_size', None, 'number of slots in memory. '
+ 'Leave as None to default to episode length')
+tf.flags.DEFINE_integer('batch_size', 16, 'batch size')
+tf.flags.DEFINE_integer('num_episodes', 100000, 'number of training episodes')
+tf.flags.DEFINE_integer('validation_frequency', 20,
+ 'every so many training episodes, '
+ 'assess validation accuracy')
+tf.flags.DEFINE_integer('validation_length', 10,
+ 'number of episodes to use to compute '
+ 'validation accuracy')
+tf.flags.DEFINE_integer('seed', 888, 'random seed for training sampling')
+tf.flags.DEFINE_string('save_dir', '', 'directory to save model to')
+tf.flags.DEFINE_bool('use_lsh', False,
+ 'use locality-sensitive hashing '
+ '(NOTE: not fully tested)')
+
+
+class Trainer(object):
+ """Class that takes care of training, validating, and checkpointing model."""
+
+ def __init__(self, train_data, valid_data, input_dim, output_dim=None):
+ self.train_data = train_data
+ self.valid_data = valid_data
+ self.input_dim = input_dim
+
+ self.rep_dim = FLAGS.rep_dim
+ self.episode_length = FLAGS.episode_length
+ self.episode_width = FLAGS.episode_width
+ self.batch_size = FLAGS.batch_size
+ self.memory_size = (self.episode_length * self.batch_size
+ if FLAGS.memory_size is None else FLAGS.memory_size)
+ self.use_lsh = FLAGS.use_lsh
+
+ self.output_dim = (output_dim if output_dim is not None
+ else self.episode_width)
+
+ def get_model(self):
+ # vocab size is the number of distinct values that
+ # could go into the memory key-value storage
+ vocab_size = self.episode_width * self.batch_size
+ return model.Model(
+ self.input_dim, self.output_dim, self.rep_dim, self.memory_size,
+ vocab_size, use_lsh=self.use_lsh)
+
+ def sample_episode_batch(self, data,
+ episode_length, episode_width, batch_size):
+ """Generates a random batch for training or validation.
+
+ Structures each element of the batch as an 'episode'.
+ Each episode contains episode_length examples and
+ episode_width distinct labels.
+
+ Args:
+ data: A dictionary mapping label to list of examples.
+ episode_length: Number of examples in each episode.
+ episode_width: Distinct number of labels in each episode.
+ batch_size: Batch size (number of episodes).
+
+ Returns:
+ A tuple (x, y) where x is a list of batches of examples
+ with size episode_length and y is a list of batches of labels.
+ """
+
+ episodes_x = [[] for _ in xrange(episode_length)]
+ episodes_y = [[] for _ in xrange(episode_length)]
+ assert len(data) >= episode_width
+ keys = data.keys()
+ for b in xrange(batch_size):
+ episode_labels = random.sample(keys, episode_width)
+ remainder = episode_length % episode_width
+ remainders = [0] * (episode_width - remainder) + [1] * remainder
+ episode_x = [
+ random.sample(data[lab],
+ r + (episode_length - remainder) // episode_width)
+ for lab, r in zip(episode_labels, remainders)]
+ episode = sum([[(x, i, ii) for ii, x in enumerate(xx)]
+ for i, xx in enumerate(episode_x)], [])
+ random.shuffle(episode)
+ # Arrange episode so that each distinct label is seen before moving to
+ # 2nd showing
+ episode.sort(key=lambda elem: elem[2])
+ assert len(episode) == episode_length
+ for i in xrange(episode_length):
+ episodes_x[i].append(episode[i][0])
+ episodes_y[i].append(episode[i][1] + b * episode_width)
+
+ return ([np.array(xx).astype('float32') for xx in episodes_x],
+ [np.array(yy).astype('int32') for yy in episodes_y])
+
+ def compute_correct(self, ys, y_preds):
+ return np.mean(np.equal(y_preds, np.array(ys)))
+
+ def individual_compute_correct(self, y, y_pred):
+ return y_pred == y
+
+ def run(self):
+ """Performs training.
+
+ Trains a model using episodic training.
+ Every so often, runs some evaluations on validation data.
+ """
+
+ train_data, valid_data = self.train_data, self.valid_data
+ input_dim, output_dim = self.input_dim, self.output_dim
+ rep_dim, episode_length = self.rep_dim, self.episode_length
+ episode_width, memory_size = self.episode_width, self.memory_size
+ batch_size = self.batch_size
+
+ train_size = len(train_data)
+ valid_size = len(valid_data)
+ logging.info('train_size (number of labels) %d', train_size)
+ logging.info('valid_size (number of labels) %d', valid_size)
+ logging.info('input_dim %d', input_dim)
+ logging.info('output_dim %d', output_dim)
+ logging.info('rep_dim %d', rep_dim)
+ logging.info('episode_length %d', episode_length)
+ logging.info('episode_width %d', episode_width)
+ logging.info('memory_size %d', memory_size)
+ logging.info('batch_size %d', batch_size)
+
+ assert all(len(v) >= float(episode_length) / episode_width
+ for v in train_data.values())
+ assert all(len(v) >= float(episode_length) / episode_width
+ for v in valid_data.values())
+
+ output_dim = episode_width
+ self.model = self.get_model()
+ self.model.setup()
+
+ sess = tf.Session()
+ sess.run(tf.global_variables_initializer())
+
+ saver = tf.train.Saver(max_to_keep=10)
+ ckpt = None
+ if FLAGS.save_dir:
+ ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir)
+ if ckpt and ckpt.model_checkpoint_path:
+ logging.info('restoring from %s', ckpt.model_checkpoint_path)
+ saver.restore(sess, ckpt.model_checkpoint_path)
+
+ logging.info('starting now')
+ losses = []
+ random.seed(FLAGS.seed)
+ np.random.seed(FLAGS.seed)
+ for i in xrange(FLAGS.num_episodes):
+ x, y = self.sample_episode_batch(
+ train_data, episode_length, episode_width, batch_size)
+ outputs = self.model.episode_step(sess, x, y, clear_memory=True)
+ loss = outputs
+ losses.append(loss)
+
+ if i % FLAGS.validation_frequency == 0:
+ logging.info('episode batch %d, avg train loss %f',
+ i, np.mean(losses))
+ losses = []
+
+ # validation
+ correct = []
+ num_shots = episode_length // episode_width
+ correct_by_shot = dict((k, []) for k in xrange(num_shots))
+ for _ in xrange(FLAGS.validation_length):
+ x, y = self.sample_episode_batch(
+ valid_data, episode_length, episode_width, 1)
+ outputs = self.model.episode_predict(
+ sess, x, y, clear_memory=True)
+ y_preds = outputs
+ correct.append(self.compute_correct(np.array(y), y_preds))
+
+ # compute per-shot accuracies
+ seen_counts = [0] * episode_width
+ # loop over episode steps
+ for yy, yy_preds in zip(y, y_preds):
+ # loop over batch examples
+ yyy, yyy_preds = int(yy[0]), int(yy_preds[0])
+ count = seen_counts[yyy % episode_width]
+ if count in correct_by_shot:
+ correct_by_shot[count].append(
+ self.individual_compute_correct(yyy, yyy_preds))
+ seen_counts[yyy % episode_width] = count + 1
+
+ logging.info('validation overall accuracy %f', np.mean(correct))
+ logging.info('%d-shot: %.3f, ' * num_shots,
+ *sum([[k, np.mean(correct_by_shot[k])]
+ for k in xrange(num_shots)], []))
+
+ if saver and FLAGS.save_dir:
+ saved_file = saver.save(sess,
+ os.path.join(FLAGS.save_dir, 'model.ckpt'),
+ global_step=self.model.global_step)
+ logging.info('saved model to %s', saved_file)
+
+
+def main(unused_argv):
+ train_data, valid_data = data_utils.get_data()
+ trainer = Trainer(train_data, valid_data, data_utils.IMAGE_NEW_SIZE ** 2)
+ trainer.run()
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.INFO)
+ tf.app.run()
diff --git a/models/research/learning_unsupervised_learning/.gitignore b/models/research/learning_unsupervised_learning/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..0d20b6487c61e7d1bde93acf4a14b7a89083a16d
--- /dev/null
+++ b/models/research/learning_unsupervised_learning/.gitignore
@@ -0,0 +1 @@
+*.pyc
diff --git a/models/research/learning_unsupervised_learning/README.md b/models/research/learning_unsupervised_learning/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0e38717f5de29df28959062889abeb1ce578feea
--- /dev/null
+++ b/models/research/learning_unsupervised_learning/README.md
@@ -0,0 +1,40 @@
+
+
+
+
+# Learning Unsupervised Learning Rules
+This repository contains code and weights for the learned update rule
+presented in "Learning Unsupervised Learning Rules." At this time, this
+code cannot meta-train the update rule.
+
+### Structure
+`run_eval.py` contains the main training loop. This constructs an op
+that runs one iteration of the learned update rule and assigns the
+results to variables. Additionally, it loads the weights from our
+pre-trained model.
+
+The base model and the update rule architecture definition can be found in
+`architectures/more_local_weight_update.py`. For a complete description
+of the model, see our [paper](https://arxiv.org/abs/1804.00222).
+
+### Dependencies
+[absl](https://github.com/abseil/abseil-py), [tensorflow](https://tensorflow.org), [sonnet](https://github.com/deepmind/sonnet)
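+
+A minimal environment sketch (the pip package names below are assumptions, not
+taken from this README; the code uses TF1-era APIs, so a 1.x TensorFlow release
+is expected):
+
+```bash
+pip install absl-py dm-sonnet tensorflow
+```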
+
+### Usage
+
+First, download the [pre-trained optimizer model weights](https://storage.googleapis.com/learning_unsupervised_learning/200_tf_graph.zip) and extract the archive.
+
+```bash
+# move to the folder above this folder
+cd path_to/research/learning_unsupervised_learning/../
+
+# launch the eval script
+python -m learning_unsupervised_learning.run_eval \
+--train_log_dir="/tmp/learning_unsupervised_learning" \
+--checkpoint_dir="/path/to/downloaded/model/tf_graph_data.ckpt"
+```
+
+### Contact
+Luke Metz, Niru Maheswaranathan, Github: @lukemetz, @nirum. Email: {lmetz, nirum}@google.com
+
+
diff --git a/models/research/learning_unsupervised_learning/__init__.py b/models/research/learning_unsupervised_learning/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/learning_unsupervised_learning/architectures/__init__.py b/models/research/learning_unsupervised_learning/architectures/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..af9545f26da538aa986b19a96b6cfa2bc7459227
--- /dev/null
+++ b/models/research/learning_unsupervised_learning/architectures/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+import more_local_weight_update
diff --git a/models/research/learning_unsupervised_learning/architectures/common.py b/models/research/learning_unsupervised_learning/architectures/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..43a2d4f8965ecd337abd3a072a7ecb789df21910
--- /dev/null
+++ b/models/research/learning_unsupervised_learning/architectures/common.py
@@ -0,0 +1,153 @@
+# Copyright 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sonnet as snt
+import tensorflow as tf
+import numpy as np
+import collections
+from learning_unsupervised_learning import utils
+
+from tensorflow.python.util import nest
+
+from learning_unsupervised_learning import variable_replace
+
+
+class LinearBatchNorm(snt.AbstractModule):
+ """Module that does a Linear layer then a BatchNorm followed by an activation fn"""
+ def __init__(self, size, activation_fn=tf.nn.relu, name="LinearBatchNorm"):
+ self.size = size
+ self.activation_fn = activation_fn
+ super(LinearBatchNorm, self).__init__(name=name)
+
+ def _build(self, x):
+ x = tf.to_float(x)
+ initializers={"w": tf.truncated_normal_initializer(stddev=0.01)}
+ lin = snt.Linear(self.size, use_bias=False, initializers=initializers)
+ z = lin(x)
+
+ scale = tf.constant(1., dtype=tf.float32)
+ offset = tf.get_variable(
+ "b",
+ shape=[1, z.shape.as_list()[1]],
+ initializer=tf.truncated_normal_initializer(stddev=0.1),
+ dtype=tf.float32
+ )
+
+ mean, var = tf.nn.moments(z, [0], keep_dims=True)
+ z = ((z - mean) * tf.rsqrt(var + 1e-6)) * scale + offset
+
+ x_p = self.activation_fn(z)
+
+ return z, x_p
+
+  # This needs to work by string name, sadly, due to how the variable replace
+  # works, and would also work even if the custom getter approach was used.
+  # This is verbose, but it should at least be clear as to what is going on.
+  # TODO(lmetz) a better way to do this (the next 3 functions:
+  # _raw_name, w(), b() )
+ def _raw_name(self, var_name):
+ """Return just the name of the variable, not the scopes."""
+ return var_name.split("/")[-1].split(":")[0]
+
+
+ @property
+ def w(self):
+ var_list = snt.get_variables_in_module(self)
+ w = [x for x in var_list if self._raw_name(x.name) == "w"]
+ assert len(w) == 1
+ return w[0]
+
+ @property
+ def b(self):
+ var_list = snt.get_variables_in_module(self)
+ b = [x for x in var_list if self._raw_name(x.name) == "b"]
+ assert len(b) == 1
+ return b[0]
+
+
+
+class Linear(snt.AbstractModule):
+ def __init__(self, size, use_bias=True, init_const_mag=True):
+ self.size = size
+ self.use_bias = use_bias
+ self.init_const_mag = init_const_mag
+ super(Linear, self).__init__(name="commonLinear")
+
+ def _build(self, x):
+ if self.init_const_mag:
+ initializers={"w": tf.truncated_normal_initializer(stddev=0.01)}
+ else:
+ initializers={}
+ lin = snt.Linear(self.size, use_bias=self.use_bias, initializers=initializers)
+ z = lin(x)
+ return z
+
+  # This needs to work by string name, sadly, due to how the variable replace
+  # works, and would also work even if the custom getter approach was used.
+  # This is verbose, but it should at least be clear as to what is going on.
+  # TODO(lmetz) a better way to do this (the next 3 functions:
+  # _raw_name, w(), b() )
+ def _raw_name(self, var_name):
+ """Return just the name of the variable, not the scopes."""
+ return var_name.split("/")[-1].split(":")[0]
+
+ @property
+ def w(self):
+ var_list = snt.get_variables_in_module(self)
+ if self.use_bias:
+ assert len(var_list) == 2, "Found not 2 but %d" % len(var_list)
+ else:
+ assert len(var_list) == 1, "Found not 1 but %d" % len(var_list)
+ w = [x for x in var_list if self._raw_name(x.name) == "w"]
+ assert len(w) == 1
+ return w[0]
+
+ @property
+ def b(self):
+ var_list = snt.get_variables_in_module(self)
+ assert len(var_list) == 2, "Found not 2 but %d" % len(var_list)
+ b = [x for x in var_list if self._raw_name(x.name) == "b"]
+ assert len(b) == 1
+ return b[0]
+
+
+def transformer_at_state(base_model, new_variables):
+ """Get the base_model that has been transformed to use the variables
+  in new_variables.
+ Args:
+ base_model: snt.Module
+ Goes from batch to features
+ new_variables: list
+ New list of variables to use
+ Returns:
+ func: callable of same api as base_model.
+ """
+ assert not variable_replace.in_variable_replace_scope()
+
+ def _feature_transformer(input_data):
+ """Feature transformer at the end of training."""
+ initial_variables = base_model.get_variables()
+ replacement = collections.OrderedDict(
+ utils.eqzip(initial_variables, new_variables))
+ with variable_replace.variable_replace(replacement):
+ features = base_model(input_data)
+ return features
+
+ return _feature_transformer
diff --git a/models/research/learning_unsupervised_learning/architectures/more_local_weight_update.py b/models/research/learning_unsupervised_learning/architectures/more_local_weight_update.py
new file mode 100644
index 0000000000000000000000000000000000000000..117549af0f21f9e5148435b73f664a08013f8786
--- /dev/null
+++ b/models/research/learning_unsupervised_learning/architectures/more_local_weight_update.py
@@ -0,0 +1,861 @@
+# Copyright 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import numpy as np
+import sonnet as snt
+import tensorflow as tf
+
+from learning_unsupervised_learning.architectures import common
+from learning_unsupervised_learning import optimizers
+from learning_unsupervised_learning import utils
+from learning_unsupervised_learning import summary_utils
+
+OptState = collections.namedtuple('OptState',
+ ['variables', 'opt_state', 'index'])
+
+BaseModelOutputs = collections.namedtuple(
+ 'BaseModelOutputs', ['xs', 'zs', 'mods', 'batch', 'backward_mods'])
+
+
+class GradChannelReadout(snt.AbstractModule):
+ """Perform a linear readout and reshape from input 3 tensor."""
+
+ def __init__(self,
+ num_grad_channels,
+ device,
+ perm=(2, 0, 1),
+ name='GradChannelReadout'):
+ """Args:
+
+ num_grad_channels: int
+        number of channels to read out to.
+      device: str or callable
+        device to place weights on.
+      perm: list or tuple
+        transpose permutation applied to the output.
+ """
+
+ self.num_grad_channels = num_grad_channels
+ self.device = device
+ self.perm = perm
+ super(GradChannelReadout, self).__init__(name=name)
+
+ def _build(self, h):
+ with tf.device(self.device):
+ mod = snt.Linear(self.num_grad_channels)
+ ret = snt.BatchApply(mod)(h)
+ # return as [num_grad_channels] x [bs] x [num units]
+ return tf.transpose(ret, perm=self.perm)
+
+
+def get_weight_stats(x, axis):
+ """ Compute weight statistics over the given axis.
+
+ Args:
+ x: tf.Tensor
+ a batch of activations.
+ axis: int
+ axis to perform statistics over.
+ Returns:
+    list of tf.Tensor
+      the statistics (l1, l2, mean, std) over the axis, each reshaped to
+      [-1, 1, 1]; an empty list when x is None.
+ """
+ if x is None:
+ return []
+
+ stats = []
+ l1 = tf.reduce_mean(tf.abs(x), axis=axis)
+ l2 = tf.sqrt(tf.reduce_mean(x**2, axis=axis) + 1e-6)
+
+ mean, var = tf.nn.moments(x, [axis])
+ stats.extend([l1, l2, mean, tf.sqrt(var + 1e-8)])
+
+ stats = [tf.reshape(s, [-1, 1, 1]) for s in stats]
+
+ return stats
+
+
+class AddUnitBatchStatistics(snt.AbstractModule):
+ """Compute some number of statistics over units and concat them on."""
+
+ def __init__(self, name='AddUnitBatchStatistics'):
+ super(AddUnitBatchStatistics, self).__init__(name=name)
+
+ def _build(self, x):
+ # [channel, bs, 1]
+ output = x
+ for d in [0, 1]:
+ stats = []
+ l1 = tf.reduce_mean(tf.abs(x), axis=d, keepdims=True)
+ l2 = tf.sqrt(tf.reduce_mean(x**2, axis=d, keepdims=True) + 1e-6)
+
+ mean, var = tf.nn.moments(x, [d], keepdims=True)
+ stats.extend([l1, l2, mean, tf.sqrt(var + 1e-8)])
+
+ to_add = tf.concat(stats, axis=2) # [channels/1, units/1, stats]
+ output += snt.BatchApply(snt.Linear(x.shape.as_list()[2]))(to_add)
+ return output
+
+
+class ConcatUnitConv(snt.AbstractModule):
+ """Do a small number of convolutions over units and concat / add them on."""
+
+ def __init__(self, add=True):
+ self.add = add
+ super(ConcatUnitConv, self).__init__(name='ConcatUnitConv')
+
+ def _build(self, x):
+ # x is [units, bs, 1]
+ net = tf.transpose(x, [1, 0, 2]) # now [bs x units x 1]
+ channels = x.shape.as_list()[2]
+ mod = snt.Conv1D(output_channels=channels, kernel_shape=[3])
+ net = mod(net)
+ net = snt.BatchNorm(axis=[0, 1])(net, is_training=False)
+ net = tf.nn.relu(net)
+ mod = snt.Conv1D(output_channels=channels, kernel_shape=[3])
+ net = mod(net)
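+    # Illustrative example: with num_hashes = 3, a query whose projections onto
+    # the three hash vectors are (negative, positive, negative) gives the bits
+    # [1, 0, 1], which index hash slot 1*2^0 + 0*2^1 + 1*2^2 = 5.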
+ net = snt.BatchNorm(axis=[0, 1])(net, is_training=False)
+ net = tf.nn.relu(net)
+ to_concat = tf.transpose(net, [1, 0, 2])
+ if self.add:
+ return x + to_concat
+ else:
+ return tf.concat([x, to_concat], 2)
+
+
+class MoreLocalWeightUpdateProcess(snt.AbstractModule):
+
+ def __init__(
+ self,
+ remote_device,
+ local_device,
+ top_delta_size=64,
+ top_delta_layers=2,
+ compute_h_size=64,
+ compute_h_layers=1,
+ delta_dim=32,
+ num_grad_channels=4,
+ normalize_epsilon=1.,
+ ):
+ self.local_device = local_device
+ self.remote_device = remote_device
+ self.top_delta_size = top_delta_size
+ self.top_delta_layers = top_delta_layers
+ self.compute_h_size = compute_h_size
+ self.compute_h_layers = compute_h_layers
+ self.delta_dim = delta_dim
+ self.num_grad_channels = num_grad_channels
+    self.normalize_epsilon = normalize_epsilon
+
+ with tf.device(local_device):
+ self.opt = optimizers.UnrollableGradientDescentRollingOptimizer(
+ learning_rate=1e-4)
+
+ # lazily initialized for readouts
+ self.readout_mods = {}
+
+ super(MoreLocalWeightUpdateProcess,
+ self).__init__(name='MoreLocalWeightUpdateProcess')
+
+ with tf.device(remote_device):
+ self()
+
+ def normalize(self, change_w, normalize_epsilon=None):
+ if normalize_epsilon is None:
+ normalize_epsilon = self.normalize_epsilon
+
+ # normalize the weights per receptive-field, rather than per-matrix
+ var = tf.reduce_mean(tf.square(change_w), axis=0, keepdims=True)
+ change_w = (change_w) / tf.sqrt(normalize_epsilon + var)
+ return change_w
+
+ def _build(self):
+ pass
+
+ @snt.reuse_variables
+ def compute_top_delta(self, z):
+ """ parameterization of topD. This converts the top level activation
+ to an error signal.
+ Args:
+ z: tf.Tensor
+ batch of final layer post activations
+ Returns
+ delta: tf.Tensor
+ the error signal
+ """
+ s_idx = 0
+ with tf.variable_scope('compute_top_delta'), tf.device(self.remote_device):
+ # typically this takes [BS, length, input_channels],
+ # We are applying this such that we convolve over the batch dimension.
+ act = tf.expand_dims(tf.transpose(z, [1, 0]), 2) # [channels, BS, 1]
+
+ mod = snt.Conv1D(output_channels=self.top_delta_size, kernel_shape=[5])
+ act = mod(act)
+
+ act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
+ act = tf.nn.relu(act)
+
+ bs = act.shape.as_list()[0]
+ act = tf.transpose(act, [2, 1, 0])
+ act = snt.Conv1D(output_channels=bs, kernel_shape=[3])(act)
+ act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
+ act = tf.nn.relu(act)
+ act = snt.Conv1D(output_channels=bs, kernel_shape=[3])(act)
+ act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
+ act = tf.nn.relu(act)
+ act = tf.transpose(act, [2, 1, 0])
+
+ prev_act = act
+ for i in range(self.top_delta_layers):
+ mod = snt.Conv1D(output_channels=self.top_delta_size, kernel_shape=[3])
+ act = mod(act)
+
+ act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
+ act = tf.nn.relu(act)
+
+ prev_act = act
+
+ mod = snt.Conv1D(output_channels=self.delta_dim, kernel_shape=[3])
+ act = mod(act)
+
+ # [bs, feature_channels, delta_channels]
+ act = tf.transpose(act, [1, 0, 2])
+ return act
+
+ @snt.reuse_variables
+ def compute_h(self,
+ x,
+ z,
+ d,
+ bias,
+ W_bot,
+ W_top,
+ compute_perc=1.0,
+ compute_units=None):
+ """z = [BS, n_units] a = [BS, n_units] b = [BS, n_units] d = [BS, n_units, delta_channels]
+
+ """
+
+ s_idx = 0
+ if compute_perc != 1.0:
+ assert compute_units is None
+
+ with tf.device(self.remote_device):
+ inp_feat = [x, z]
+ inp_feat = [tf.transpose(f, [1, 0]) for f in inp_feat]
+
+ units = x.shape.as_list()[1]
+ bs = x.shape.as_list()[0]
+
+ # add unit ID, to help the network differentiate units
+ id_theta = tf.linspace(0., (4) * np.pi, units)
+ assert bs is not None
+ id_theta_bs = tf.reshape(id_theta, [-1, 1]) * tf.ones([1, bs])
+ inp_feat += [tf.sin(id_theta_bs), tf.cos(id_theta_bs)]
+
+ # list of [units, BS, 1]
+ inp_feat = [tf.expand_dims(f, 2) for f in inp_feat]
+
+ d_trans = tf.transpose(d, [1, 0, 2])
+
+ if compute_perc != 1.0:
+ compute_units = int(compute_perc * inp_feat.shape.as_list()[0])
+
+ # add weight matrix statistics, both from above and below
+ w_stats_bot = get_weight_stats(W_bot, 0)
+ w_stats_top = get_weight_stats(W_top, 1)
+ w_stats = w_stats_bot + w_stats_top
+ if W_bot is None or W_top is None:
+ # if it's an edge layer (top or bottom), just duplicate the stats for
+ # the weight matrix that does exist
+ w_stats = w_stats + w_stats
+ w_stats = [tf.ones([1, x.shape[0], 1]) * ww for ww in w_stats]
+ # w_stats is a list, with entries with shape UNITS x 1 x channels
+
+ if compute_units is None:
+ inp_feat_in = inp_feat
+ d_trans_in = d_trans
+ w_stats_in = w_stats
+ bias_in = tf.transpose(bias)
+ else:
+ # only run on a subset of the activations.
+ mask = tf.random_uniform(
+ minval=0,
+ maxval=1,
+ dtype=tf.float32,
+ shape=inp_feat[0].shape.as_list()[0:1])
+ _, ind = tf.nn.top_k(mask, k=compute_units)
+ ind = tf.reshape(ind, [-1, 1])
+
+ inp_feat_in = [tf.gather_nd(xx, ind) for xx in inp_feat]
+ w_stats_in = [tf.gather_nd(xx, ind) for xx in w_stats]
+ d_trans_in = tf.gather_nd(d_trans, ind)
+ bias_in = tf.gather_nd(tf.transpose(bias), ind)
+
+ w_stats_in = tf.concat(w_stats_in, 2)
+ w_stats_in_norm = w_stats_in * tf.rsqrt(
+ tf.reduce_mean(w_stats_in**2) + 1e-6)
+
+ act = tf.concat(inp_feat_in + [d_trans_in], 2)
+ act = snt.BatchNorm(axis=[0, 1])(act, is_training=True)
+
+ bias_dense = tf.reshape(bias_in, [-1, 1, 1]) * tf.ones([1, bs, 1])
+ act = tf.concat([w_stats_in_norm, bias_dense, act], 2)
+
+ mod = snt.Conv1D(output_channels=self.compute_h_size, kernel_shape=[3])
+ act = mod(act)
+
+ act = snt.BatchNorm(axis=[0, 1])(act, is_training=True)
+ act = tf.nn.relu(act)
+
+ act2 = ConcatUnitConv()(act)
+ act = act2
+
+ prev_act = act
+ for i in range(self.compute_h_layers):
+ mod = snt.Conv1D(output_channels=self.compute_h_size, kernel_shape=[3])
+ act = mod(act)
+
+ act = snt.BatchNorm(axis=[0, 1])(act, is_training=True)
+ act = tf.nn.relu(act)
+
+ act = ConcatUnitConv()(act)
+
+ prev_act = act
+
+ h = act
+ if compute_units is not None:
+ shape = inp_feat[0].shape.as_list()[:1] + h.shape.as_list()[1:]
+ h = tf.scatter_nd(ind, h, shape=shape)
+
+ h = tf.transpose(h, [1, 0, 2]) # [bs, units, channels]
+
+ return h
+
+ ## wrappers to allow forward and backward to have different variables
+ @snt.reuse_variables
+ def merge_change_w_forward(self, change_w_terms, global_prefix='', prefix=''):
+ return self.merge_change_w(
+ change_w_terms, global_prefix=global_prefix, prefix=prefix)
+
+ @snt.reuse_variables
+ def merge_change_w_backward(self, change_w_terms, global_prefix='',
+ prefix=''):
+ return self.merge_change_w(
+ change_w_terms, global_prefix=global_prefix, prefix=prefix)
+
+ def merge_change_w(self, change_w_terms, global_prefix='', prefix=''):
+ with tf.device(
+ self.remote_device), tf.name_scope(global_prefix + '_merge_change_w'):
+ w_base = change_w_terms['w_base']
+
+ for kk in sorted(change_w_terms.keys()):
+ name = global_prefix + 'change_w_plane_%s' % kk
+ delta_w = change_w_terms[kk]
+ mean, var = tf.nn.moments(delta_w, [0, 1])
+ root_mean_square = tf.sqrt(tf.reduce_mean(delta_w**2) + 1e-6)
+
+ for kk in sorted(change_w_terms.keys()):
+ change_w_terms[kk] = self.normalize(change_w_terms[kk])
+
+ initializers = {
+ 'w': tf.constant_initializer(0.1),
+ 'b': tf.zeros_initializer()
+ }
+ mod = snt.Linear(
+ 1,
+ name=global_prefix + '_weight_readout_coeffs',
+ initializers=initializers)
+
+ change_w_terms_list = [
+ change_w_terms[kk] for kk in sorted(change_w_terms.keys())
+ ]
+ stack_terms = tf.stack(change_w_terms_list, axis=-1)
+ change_w = tf.squeeze(
+ snt.BatchApply(mod)(stack_terms), axis=-1) / len(change_w_terms)
+
+ # only allow perpendicular updates, or updates which grow length. don't
+ # allow length to decay towards zero.
+ ip = tf.reduce_mean(change_w * w_base)
+ # zero out any updates that shrink length
+ ip = tf.nn.relu(ip)
+ change_w -= w_base * ip
+ change_w /= tf.sqrt(len(change_w_terms) * 1.)
+
+ change_w = self.normalize(change_w)
+
+ # encourage the receptive field to not collapse to 0
+ change_w -= w_base / 7. # This is an arbitrary scale choice
+
+ return tf.identity(change_w)
+
+ @snt.reuse_variables
+ def bias_readout(self, h):
+ with tf.device(self.remote_device):
+ mod = snt.Linear(1, name='bias_readout')
+ ret = snt.BatchApply(mod)(h)
+ return tf.squeeze(ret, 2)
+
+ @snt.reuse_variables
+ def next_delta(self, z, h, d):
+ with tf.device(self.remote_device):
+ return d * tf.expand_dims(tf.nn.sigmoid(z), 2) + self.to_delta_size(h)
+
+ @utils.create_variables_in_class_scope
+ def get_readout_mod(self, name):
+ if name not in self.readout_mods:
+ self.readout_mods[name] = GradChannelReadout(
+ self.num_grad_channels, device=self.remote_device, name=name)
+
+ return self.readout_mods[name]
+
+ @utils.create_variables_in_class_scope
+ def low_rank_readout(self, name, h1, h2, psd=False):
+ BS = h1.shape.as_list()[0]
+ r_t = self.get_readout_mod(name + '_top')(h1)
+ if psd:
+ r_b = r_t
+ else:
+ r_b = self.get_readout_mod(name + '_bottom')(h2)
+ return tf.reduce_mean(tf.matmul(r_b, r_t, transpose_a=True), axis=0) / BS
+
+ @snt.reuse_variables
+ def to_delta_size(self, h):
+ with tf.device(self.remote_device):
+ mod = snt.Linear(self.delta_dim)
+ return snt.BatchApply(mod)(h)
+
+ @snt.reuse_variables
+ def initial_state(self, variables):
+ """The inner optimization state.
+
+ Args:
+ variables: list of tf.Variable
+ list of variables to get the initial state of.
+ Returns:
+ opt_state: OptState
+ """
+
+ with tf.device(self.local_device):
+ initial_opt_state = self.opt.get_state(variables)
+
+ return OptState(
+ variables=variables, opt_state=initial_opt_state, index=tf.constant(0))
+
+ @snt.reuse_variables
+ def compute_next_state(self, grads, learning_rate, cur_state,
+ cur_transformer):
+
+ summaries = []
+ with tf.device(self.local_device):
+ with tf.control_dependencies(summaries):
+ new_vars, new_state = self.opt.compute_updates(
+ cur_state.variables, grads, learning_rate, cur_state.opt_state)
+ pass
+
+ return OptState(
+ variables=tuple(new_vars),
+ opt_state=new_state,
+ index=cur_state.index + 1)
+
+ def assign_state(self, base_model, next_state):
+ var_ups = [
+ v.assign(nv) for v, nv in utils.eqzip(base_model.get_variables(),
+ next_state.variables)
+ ]
+
+ opt_ups = self.opt.assign_state(next_state.opt_state)
+
+ return tf.group(opt_ups, *var_ups)
+
+ def local_variables(self):
+ return list(self.opt.get_variables())
+
+ def remote_variables(self):
+ train = list(
+ snt.get_variables_in_module(self, tf.GraphKeys.TRAINABLE_VARIABLES))
+ train += list(
+ snt.get_variables_in_module(self,
+ tf.GraphKeys.MOVING_AVERAGE_VARIABLES))
+ return train
+
+
+class MoreLocalWeightUpdateWLearner(snt.AbstractModule):
+ """The BaseModel that the UnsupervisedUpdateRule acts on.
+ """
+
+ def __init__(self,
+ remote_device,
+ local_device,
+ inner_size=128,
+ output_size=32,
+ n_layers=4,
+ shuffle_input=True,
+ activation_fn=tf.nn.relu,
+ identical_updates=True,
+ **kwargs):
+ self.local_device = local_device
+ self.remote_device = remote_device
+ self.inner_size = inner_size
+ self.n_layers = n_layers
+ self.shuffle_input = shuffle_input
+ self.activation_fn = activation_fn
+ self.identical_updates = identical_updates
+
+ self.output_size = output_size
+    if output_size is None:
+ self.output_size = inner_size
+
+ self.shuffle_ind = None
+
+ super(MoreLocalWeightUpdateWLearner, self).__init__(
+ name='LocalWeightUpdateWLearner', **kwargs)
+
+ @snt.reuse_variables
+ def get_shuffle_ind(self, size):
+ if self.shuffle_ind is None:
+ # put the shuffle in tf memory to make the eval jobs
+ # re-entrant.
+ shuffle_ind_val = np.random.permutation(size)
+ shuffle_ind = tf.get_variable(
+ name='shuffle_ind', dtype=tf.int64, initializer=shuffle_ind_val)
+ unshuffle_ind = tf.scatter_nd(
+ tf.reshape(shuffle_ind, [-1, 1]), tf.range(size), [size])
+
+ return shuffle_ind, unshuffle_ind
+
+ def _build(self, batch):
+ image = batch.image
+ x0 = snt.BatchFlatten()(image)
+ if self.shuffle_input:
+ size = x0.shape.as_list()[1]
+ shuffle_ind, unshuffle_ind = self.get_shuffle_ind(size)
+ x0 = tf.gather(x0, shuffle_ind, axis=1)
+
+ xs = [x0]
+ mods = []
+ zs = []
+ init = {}
+
+ for i in range(self.n_layers):
+ mod = common.LinearBatchNorm(
+ self.inner_size, activation_fn=self.activation_fn)
+ z, x = mod(xs[i])
+ xs.append(x)
+ zs.append(z)
+ mods.append(mod)
+
+ mod = common.LinearBatchNorm(
+ self.output_size, activation_fn=self.activation_fn)
+ z, x = mod(xs[-1])
+ mods.append(mod)
+
+ xs.append(x)
+ zs.append(z)
+
+ embedding_x = xs[-1]
+
+ # make a random set of backward mods
+ backward_mods = []
+ for i, (x, x_p1) in enumerate(zip(xs[0:-1], xs[1:])):
+ m = common.LinearBatchNorm(
+ x_p1.shape.as_list()[1], activation_fn=tf.identity)
+ _ = m(x)
+ backward_mods.append(m)
+
+ shape = image.shape.as_list()[1:4]
+
+ for mods_p, prefix in [(mods, 'forward'), (backward_mods, 'backward')]:
+ if self.shuffle_input:
+ unshuf_w = tf.gather(mods_p[0].w, unshuffle_ind, axis=0)
+ else:
+ unshuf_w = mods_p[0].w
+ img = summary_utils.first_layer_weight_image(unshuf_w, shape)
+ tf.summary.image(prefix + '_w0_receptive_field', img)
+
+ for i, m in enumerate(mods_p[0:]):
+ img = summary_utils.inner_layer_weight_image(m.w)
+ tf.summary.image(prefix + '_w%d' % (i + 1), img)
+
+ img = summary_utils.sorted_images(image, batch.label_onehot)
+ tf.summary.image('inputs', img)
+
+ # log out pre-activations and activations
+ for all_vis, base_name in [(xs, 'x'), (zs, 'z')]:
+ for i, x_vis in enumerate(all_vis):
+ img = summary_utils.activation_image(x_vis, batch.label_onehot)
+ tf.summary.image('%s%d' % (base_name, i), img)
+
+ embedding_x = tf.identity(embedding_x)
+
+ outputs = BaseModelOutputs(
+ xs=xs, zs=zs, mods=mods, batch=batch, backward_mods=backward_mods)
+
+ return embedding_x, outputs
+
+ def compute_next_h_d(self, meta_opt, w_bot, w_top, bias, x, z, d, backward_w):
+ """ Propogate error back down the network while computing hidden state.
+ """
+ if z is None:
+ z = x
+
+ h = meta_opt.compute_h(x, z, d, bias, w_bot,
+ w_top) # [bs x 60 x h_channels]
+
+ # compute the next d
+ delta = meta_opt.next_delta(z, h, d)
+
+ if backward_w is not None:
+
+ def delta_matmul(w, delta):
+ d = tf.transpose(delta, [0, 2, 1]) # [bs x delta_channels x n_units)
+ d = snt.BatchApply(lambda x: tf.matmul(x, w, transpose_b=True))(d)
+ d = tf.transpose(d, [0, 2, 1])
+ return d
+
+ # replace the "backward pass" with a random matrix.
+ d = delta_matmul(backward_w, delta) # [bs x 60 x delta_channels]
+ var = tf.reduce_mean(tf.square(d), [2], keepdims=True)
+ d = d * tf.rsqrt(1e-6 + var)
+
+ return h, d
+
+ def weight_change_for_layer(self, meta_opt, l_idx, w_base, b_base, upper_h,
+ lower_h, upper_x, lower_x, prefix, include_bias):
+ """Compute the change in weights for each layer.
+    This computes something roughly analogous to a gradient.
+ """
+ reduce_upper_h = upper_h
+ reduce_lower_h = lower_h
+
+ BS = lower_x.shape.as_list()[0]
+
+ change_w_terms = dict()
+
+ # initial weight value normalized
+ # normalize the weights per receptive-field, rather than per-matrix
+ weight_scale = tf.rsqrt(
+ tf.reduce_mean(w_base**2, axis=0, keepdims=True) + 1e-6)
+ w_base *= weight_scale
+
+ change_w_terms['w_base'] = w_base
+
+ # this will act to decay larger weights towards zero
+ change_w_terms['large_decay'] = w_base**2 * tf.sign(w_base)
+
+ # term based on activations
+ ux0 = upper_x - tf.reduce_mean(upper_x, axis=0, keepdims=True)
+ uxs0 = ux0 * tf.rsqrt(tf.reduce_mean(ux0**2, axis=0, keepdims=True) + 1e-6)
+ change_U = tf.matmul(uxs0, uxs0, transpose_a=True) / BS
+ change_U /= tf.sqrt(float(change_U.shape.as_list()[0]))
+
+ cw = tf.matmul(w_base, change_U)
+ cw_scale = tf.rsqrt(tf.reduce_mean(cw**2 + 1e-8))
+ cw *= cw_scale
+ change_w_terms['decorr_x'] = cw
+
+ # hebbian term
+ lx0 = lower_x - tf.reduce_mean(lower_x, axis=0, keepdims=True)
+ lxs0 = lx0 * tf.rsqrt(tf.reduce_mean(lx0**2, axis=0, keepdims=True) + 1e-6)
+ cw = tf.matmul(lxs0, uxs0, transpose_a=True) / BS
+ change_w_terms['hebb'] = -cw
+
+ # 0th order term
+ w_term = meta_opt.low_rank_readout(prefix + 'weight_readout_0', upper_h,
+ lower_h)
+ change_w_terms['0_order'] = w_term
+
+    # rbf term (weight update scaled by distance from 0)
+ w_term = meta_opt.low_rank_readout(prefix + 'weight_readout_rbf',
+ reduce_upper_h, reduce_lower_h)
+ change_w_terms['rbf'] = tf.exp(-w_base**2) * w_term
+
+ # 1st order term (weight dependent update to weights)
+ w_term = meta_opt.low_rank_readout(prefix + 'weight_readout_1',
+ reduce_upper_h, reduce_lower_h)
+ change_w_terms['1_order'] = w_base * w_term
+
+ # more terms based on single layer readouts.
+ for update_type in ['lin', 'sqr']:
+ for h_source, h_source_name in [(reduce_upper_h, 'upper'),
+ (reduce_lower_h, 'lower')]:
+ structures = ['symm']
+ if update_type == 'lin' and h_source_name == 'upper':
+ structures += ['psd']
+ for structure in structures:
+ name = update_type + '_' + h_source_name + '_' + structure
+ if structure == 'symm':
+ change_U = meta_opt.low_rank_readout(prefix + name, h_source,
+ h_source)
+ change_U = (change_U + tf.transpose(change_U)) / tf.sqrt(2.)
+ change_U = tf.matrix_set_diag(change_U,
+ tf.zeros(
+ [change_U.shape.as_list()[0]]))
+ elif structure == 'psd':
+ change_U = meta_opt.low_rank_readout(
+ prefix + name, h_source, None, psd=True)
+ else:
+ assert False
+ change_U /= tf.sqrt(float(change_U.shape.as_list()[0]))
+
+ if update_type == 'lin':
+ sign_multiplier = tf.ones_like(w_base)
+ w_base_l = w_base
+ elif update_type == 'sqr':
+ sign_multiplier = tf.sign(w_base)
+ w_base_l = tf.sqrt(1. + w_base**2) - 1.
+
+ if h_source_name == 'upper':
+ cw = tf.matmul(w_base_l, change_U) # [N^l-1 x N^l]
+ elif h_source_name == 'lower':
+ cw = tf.matmul(change_U, w_base_l)
+ change_w_terms[name] = cw * sign_multiplier
+
+
+ if prefix == 'forward':
+ change_w = meta_opt.merge_change_w_forward(
+ change_w_terms, global_prefix=prefix, prefix='l%d' % l_idx)
+ elif prefix == 'backward':
+ change_w = meta_opt.merge_change_w_backward(
+ change_w_terms, global_prefix=prefix, prefix='l%d' % l_idx)
+ else:
+ assert (False)
+
+ if not include_bias:
+ return change_w
+
+ change_b = tf.reduce_mean(meta_opt.bias_readout(upper_h), [0])
+
+ # force nonlinearities to be exercised -- biases can't all be increased without bound
+ change_b_mean = tf.reduce_mean(change_b)
+ offset = -tf.nn.relu(-change_b_mean)
+ change_b -= offset
+
+ var = tf.reduce_mean(tf.square(change_b), [0], keepdims=True)
+ change_b = (change_b) / tf.sqrt(0.5 + var)
+ return change_w, change_b
+
+ def compute_next_state(self, outputs, meta_opt, previous_state):
+ zs = outputs.zs
+ xs = outputs.xs
+ batch = outputs.batch
+ mods = outputs.mods
+ backward_mods = outputs.backward_mods
+ variables = self.get_variables()
+
+ rev_mods = mods[::-1]
+ rev_backward_mods = backward_mods[::-1]
+ rev_xs = xs[::-1]
+ rev_zs = zs[::-1] + [None]
+
+ to_top = xs[-1]
+
+ # variables that change in the loop
+ hs = []
+ d = meta_opt.compute_top_delta(to_top) # [bs x 32 x delta_channels]
+
+ iterator = utils.eqzip(rev_backward_mods + [None], rev_mods + [None],
+ [None] + rev_mods, rev_xs, rev_zs)
+ for (backward_mod, lower_mod, upper_mod, x, z) in iterator:
+ w_bot = None
+ if not lower_mod is None:
+ w_bot = previous_state.variables[variables.index(lower_mod.w)]
+ w_top = None
+ if not upper_mod is None:
+ w_top = previous_state.variables[variables.index(upper_mod.w)]
+ backward_w = None
+ if backward_mod is not None:
+ backward_w = previous_state.variables[variables.index(backward_mod.w)]
+ if lower_mod is not None:
+ bias = previous_state.variables[variables.index(lower_mod.b)]
+ else:
+ bias = tf.zeros([x.shape[1]])
+
+ h, d = self.compute_next_h_d(
+ meta_opt=meta_opt,
+ w_bot=w_bot,
+ w_top=w_top,
+ bias=bias,
+ backward_w=backward_w,
+ x=x,
+ z=z,
+ d=d)
+ hs.append(h)
+
+ w_forward_var_idx = [variables.index(mod.w) for mod in rev_mods]
+ w_backward_var_idx = [variables.index(mod.w) for mod in rev_backward_mods]
+ b_var_idx = [variables.index(mod.b) for mod in rev_mods]
+
+ # storage location for outputs of below loop
+ grads = [None for _ in previous_state.variables]
+
+ # over-ride learning rate for perturbation variables
+ learning_rate = [None for _ in previous_state.variables]
+
+ # This is a map -- no state is shared cross loop
+ for l_idx, w_forward_idx, w_backward_idx, b_idx, upper_h, lower_h, lower_x, upper_x in utils.eqzip(
+ range(len(w_forward_var_idx)), w_forward_var_idx, w_backward_var_idx,
+ b_var_idx, hs[:-1], hs[1:], xs[::-1][1:], xs[::-1][:-1]):
+
+ b_base = previous_state.variables[b_idx]
+ change_w_forward, change_b = self.weight_change_for_layer(
+ meta_opt=meta_opt,
+ l_idx=l_idx,
+ w_base=previous_state.variables[w_forward_idx],
+ b_base=b_base,
+ upper_h=upper_h,
+ lower_h=lower_h,
+ upper_x=upper_x,
+ lower_x=lower_x,
+ prefix='forward',
+ include_bias=True)
+
+ if self.identical_updates:
+ change_w_backward = change_w_forward
+ else:
+ change_w_backward = self.weight_change_for_layer(
+ meta_opt=meta_opt,
+ l_idx=l_idx,
+ w_base=previous_state.variables[w_backward_idx],
+ b_base=b_base,
+ upper_h=upper_h,
+ lower_h=lower_h,
+ upper_x=upper_x,
+ lower_x=lower_x,
+ prefix='backward',
+ include_bias=False)
+
+ grads[w_forward_idx] = change_w_forward
+
+ grads[w_backward_idx] = change_w_backward
+
+ grads[b_idx] = change_b
+
+ cur_transformer = common.transformer_at_state(self,
+ previous_state.variables)
+ next_state = meta_opt.compute_next_state(
+ grads,
+ learning_rate=learning_rate,
+ cur_state=previous_state,
+ cur_transformer=lambda x: cur_transformer(x)[0])
+ return next_state
+
+ def initial_state(self, meta_opt):
+ return meta_opt.initial_state(self.get_variables())
diff --git a/models/research/learning_unsupervised_learning/datasets/__init__.py b/models/research/learning_unsupervised_learning/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9949cd96ca8f2fe1c39705a5ca8570de9cad5a66
--- /dev/null
+++ b/models/research/learning_unsupervised_learning/datasets/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from learning_unsupervised_learning.datasets import mnist
diff --git a/models/research/learning_unsupervised_learning/datasets/common.py b/models/research/learning_unsupervised_learning/datasets/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..11f65ceab57a4114ca3876b3cb6eed86e2263745
--- /dev/null
+++ b/models/research/learning_unsupervised_learning/datasets/common.py
@@ -0,0 +1,29 @@
+# Copyright 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import tensorflow as tf
+import numpy as np
+
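+# The namedtuples below are the batch format passed between the dataset
+# objects and the models: `image` is a float image tensor, `label` holds
+# integer class ids, `label_onehot` their one-hot encoding, and the regression
+# variant additionally carries a `regression_target` tensor.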
+ImageLabelOnehot = collections.namedtuple('ImageLabelOnehot',
+ ['image', 'label', 'label_onehot'])
+ImageLabelOnehotRegression = collections.namedtuple(
+ "ImageLabelOnehotRegression",
+ ["image", "label", "label_onehot", "regression_target"])
diff --git a/models/research/learning_unsupervised_learning/datasets/mnist.py b/models/research/learning_unsupervised_learning/datasets/mnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ee595d99ad2523042454f038b4665095f501caf
--- /dev/null
+++ b/models/research/learning_unsupervised_learning/datasets/mnist.py
@@ -0,0 +1,74 @@
+# Copyright 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+import sonnet as snt
+import tensorflow as tf
+from tensorflow.python.keras.datasets import mnist
+from learning_unsupervised_learning.datasets import common
+
+class Mnist(snt.AbstractModule):
+ def __init__(self, device, batch_size=128, name="Mnist"):
+ self.device = device
+ self.batch_size = batch_size
+
+ self._make_dataset()
+ self.iterator = None
+
+ super(Mnist, self).__init__(name=name)
+
+ def _make_dataset(self):
+ (x_train, y_train), (x_test, y_test) = mnist.load_data()
+
+ x_train = x_train.reshape(60000, 784)
+ x_test = x_test.reshape(10000, 784)
+
+ dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+ dataset = dataset.repeat()
+ dataset = dataset.shuffle(self.batch_size * 3)
+ dataset = dataset.batch(self.batch_size)
+ def _map_fn(image, label):
+ image = tf.to_float(image) / 255.
+ label.set_shape([self.batch_size])
+ label = tf.cast(label, dtype=tf.int32)
+ label_onehot = tf.one_hot(label, 10)
+ image = tf.reshape(image, [self.batch_size, 28, 28, 1])
+ return common.ImageLabelOnehot(
+ image=image, label=label, label_onehot=label_onehot)
+
+ self.dataset = dataset.map(_map_fn)
+
+ def _build(self):
+ if self.iterator is None:
+ self.iterator = self.dataset.make_one_shot_iterator()
+ batch = self.iterator.get_next()
+ [b.set_shape([self.batch_size] + b.shape.as_list()[1:]) for b in batch]
+ return batch
+
+
+class TinyMnist(Mnist):
+ def __init__(self, *args, **kwargs):
+ kwargs.setdefault("name", "TinyMnist")
+ super(TinyMnist, self).__init__(*args, **kwargs)
+
+ def _make_dataset(self):
+ super(TinyMnist, self)._make_dataset()
+
+ def _map_fn(batch):
+ new_img = tf.image.resize_images(batch.image, [14, 14])
+ return common.ImageLabelOnehot(
+ image=new_img, label=batch.label, label_onehot=batch.label_onehot)
+
+ self.dataset = self.dataset.map(_map_fn)
diff --git a/models/research/learning_unsupervised_learning/evaluation.py b/models/research/learning_unsupervised_learning/evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ec40e99a672f9420200653b92818374e0e84d78
--- /dev/null
+++ b/models/research/learning_unsupervised_learning/evaluation.py
@@ -0,0 +1,76 @@
+# Copyright 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+"""Evaluation job.
+
+This sits on the side and performs evaluation on a saved model.
+This is a separate process for ease of use and stability of numbers.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from learning_unsupervised_learning import utils
+
+
+def construct_evaluation_graph(theta_process_fn=None,
+ w_learner_fn=None,
+ dataset_fn=None,
+ meta_objectives=None,
+ ):
+ """Construct the evaluation graph.
+ """
+ if meta_objectives is None:
+ meta_objectives = []
+
+ tf.train.create_global_step()
+
+ local_device = ""
+ remote_device = ""
+
+ meta_opt = theta_process_fn(
+ remote_device=remote_device, local_device=local_device)
+
+ base_model = w_learner_fn(
+ remote_device=remote_device, local_device=local_device)
+
+ train_dataset = dataset_fn(device=local_device)
+
+ # construct variables
+ x, outputs = base_model(train_dataset())
+ initial_state = base_model.initial_state(meta_opt, max_steps=10)
+ next_state = base_model.compute_next_state(outputs, meta_opt, initial_state)
+ with utils.state_barrier_context(next_state):
+ train_one_step_op = meta_opt.assign_state(base_model, next_state)
+
+ meta_objs = []
+ for meta_obj_fn in meta_objectives:
+ meta_obj = meta_obj_fn(local_device="", remote_device="")
+ meta_objs.append(meta_obj)
+ J = meta_obj(train_dataset, lambda x: base_model(x)[0])
+ tf.summary.scalar(str(meta_obj.__class__.__name__)+"_J", tf.reduce_mean(J))
+
+ # TODO(lmetz) this is kinda error prone.
+ # We should share the construction of the global variables across train and
+ # make sure both sets of savable variables are the same
+ checkpoint_vars = meta_opt.remote_variables() + [tf.train.get_global_step()]
+ for meta_obj in meta_objs:
+ checkpoint_vars.extend(meta_obj.remote_variables())
+
+ return checkpoint_vars, train_one_step_op, (base_model, train_dataset)
diff --git a/models/research/learning_unsupervised_learning/meta_objective/__init__.py b/models/research/learning_unsupervised_learning/meta_objective/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..54c46145e3c3a9f19110f92197f1d3cb2afe31fb
--- /dev/null
+++ b/models/research/learning_unsupervised_learning/meta_objective/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+from learning_unsupervised_learning.meta_objective import sklearn
+from learning_unsupervised_learning.meta_objective import linear_regression
diff --git a/models/research/learning_unsupervised_learning/meta_objective/linear_regression.py b/models/research/learning_unsupervised_learning/meta_objective/linear_regression.py
new file mode 100644
index 0000000000000000000000000000000000000000..b49fc2529ccba08a6b47019cd7546f8fb409b28b
--- /dev/null
+++ b/models/research/learning_unsupervised_learning/meta_objective/linear_regression.py
@@ -0,0 +1,258 @@
+# Copyright 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+
+"""Closed form linear regression.
+
+Can be differentiated through.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import numpy as np
+import sonnet as snt
+import tensorflow as tf
+
+from learning_unsupervised_learning import utils
+from learning_unsupervised_learning import variable_replace
+
+
+def solve_ridge(x, y, ridge_factor):
+ with tf.name_scope("solve_ridge"):
+    # Append a column of ones to the end of the feature matrix for the bias term.
+ A = tf.concat([x, tf.ones((x.shape.as_list()[0], 1))], axis=1)
+
+ # Analytic solution for the ridge regression loss
+ inv_target = tf.matmul(A, A, transpose_a=True)
+ np_diag_penalty = ridge_factor * np.ones(
+ A.shape.as_list()[1], dtype="float32")
+ # Remove penalty on bias component of weights
+ np_diag_penalty[-1] = 0.
+ diag_penalty = tf.constant(np_diag_penalty)
+ inv_target += tf.diag(diag_penalty)
+
+ inv = tf.matrix_inverse(inv_target)
+ w = tf.matmul(inv, tf.matmul(A, y, transpose_a=True))
+ return w
+
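+# For reference, the closed form computed by solve_ridge (a sketch of the math,
+# not additional code): with A = [x, 1] and D = diag(ridge_factor, ...,
+# ridge_factor, 0), it returns w = (A^T A + D)^{-1} A^T y, so the appended bias
+# row is left unpenalized. E.g. x of shape [32, 8] and y of shape [32, 3] give
+# w of shape [9, 3], with w[-1] acting as the bias.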
+
+class LinearRegressionMetaObjective(snt.AbstractModule):
+  """A meta-objective based on training ridge regression with an analytic solution.
+
+ This is used to evaluate the performance of a given feature set trained in
+ some other manner.
+ """
+
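+  # Typical use (a sketch; `dataset_fn` and `feature_fn` are placeholder names):
+  #   loss = LinearRegressionMetaObjective()(dataset_fn, feature_fn)
+  # where dataset_fn() yields ImageLabelOnehot-style batches and feature_fn maps
+  # such a batch to a [batch_size, n_features] feature tensor.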
+ def __init__(self,
+ local_device=None,
+ remote_device=None,
+ zero_one_labels=True,
+ normalize_y_hat=True,
+ normalize_act=False,
+ averages=1,
+ ridge_factor=0.1,
+ center_y=True,
+ hinge_loss=False,
+ samples_per_class=10,
+ test_train_scalar=1.0,
+ ):
+ self._local_device = local_device
+ self._remote_device = remote_device
+ self.zero_one_labels = zero_one_labels
+ self.normalize_y_hat = normalize_y_hat
+ self.normalize_act = normalize_act
+ self.ridge_factor = ridge_factor
+ self.averages = averages
+ self.samples_per_class = samples_per_class
+ self.center_y=center_y
+ self.test_train_scalar=test_train_scalar
+ self.hinge_loss = hinge_loss
+
+ self.dataset_map = {}
+
+ super(LinearRegressionMetaObjective,
+ self).__init__(name="LinearRegressionMetaObjective")
+
+ def _build(self, dataset, feature_transformer):
+ if self.samples_per_class is not None:
+ if dataset not in self.dataset_map:
+        # Dataset construction must happen outside of any while_loop frames.
+ with tf.control_dependencies(None):
+ self.dataset_map[dataset] = utils.sample_n_per_class(
+ dataset, self.samples_per_class)
+
+ dataset = self.dataset_map[dataset]
+
+ stats = collections.defaultdict(list)
+ losses = []
+ # TODO(lmetz) move this to ingraph control flow?
+ for _ in xrange(self.averages):
+ loss, stat = self._build_once(dataset, feature_transformer)
+ losses.append(loss)
+ for k, v in stat.items():
+ stats[k].append(v)
+ stats = {k: tf.add_n(v) / float(len(v)) for k, v in stats.items()}
+
+ summary_updates = []
+ for k, v in stats.items():
+ tf.summary.scalar(k, v)
+
+ with tf.control_dependencies(summary_updates):
+ return tf.add_n(losses) / float(len(losses))
+
+ def _build_once(self, dataset, feature_transformer):
+ with tf.device(self._local_device):
+ batch = dataset()
+ num_classes = batch.label_onehot.shape.as_list()[1]
+
+ regression_mod = snt.Linear(num_classes)
+
+ if self.normalize_act:
+
+ def normalize_transformer(x):
+ unnorm_x = feature_transformer(x)
+ return tf.nn.l2_normalize(unnorm_x, 0)
+
+ feature_transformer_wrap = normalize_transformer
+ else:
+ feature_transformer_wrap = feature_transformer
+
+ # construct the variables of the right shape in the sonnet module by
+ # calling a forward pass through the regressor.
+ with utils.assert_no_new_variables():
+ dummy_features = feature_transformer_wrap(batch)
+ regression_mod(dummy_features)
+ reg_w = regression_mod.w
+ reg_b = regression_mod.b
+
+ batch_test = dataset()
+ all_batch = utils.structure_map_multi(lambda x: tf.concat(x, 0), [batch, batch_test])
+ #all_batch = tf.concat([batch, batch_test], 0)
+ # Grab a new batch of data from the dataset.
+ features = feature_transformer_wrap(all_batch)
+ features, features_test = utils.structure_map_split(lambda x: tf.split(x, 2, axis=0), features)
+
+ def center_y(y):
+ y -= tf.reduce_mean(y)
+ y *= tf.rsqrt(tf.reduce_mean(tf.reduce_sum(y**2, axis=[1], keep_dims=True)))
+ return y
+ def get_y_vec(batch):
+ y_pieces = []
+ if hasattr(batch, "label_onehot"):
+ if self.zero_one_labels:
+ y_pieces += [batch.label_onehot]
+ else:
+ y_pieces += [2. * batch.label_onehot - 1.]
+ if hasattr(batch, "regression_target"):
+ y_pieces += [batch.regression_target]
+ y = tf.concat(y_pieces, 1)
+ if self.center_y:
+ y = center_y(y)
+ return y
+
+ y_train = get_y_vec(batch)
+
+ w = solve_ridge(features, y_train, self.ridge_factor)
+
+ # Generate features from another batch to evaluate loss on the validation
+      # set. This provides a less overfit signal to the learned optimizer.
+ y_test = get_y_vec(batch_test)
+
+ def compute_logit(features):
+        # We have updated the classifier mod in previous steps, so we need to
+        # substitute out those variables to get the new values.
+ replacement = collections.OrderedDict([(reg_w, w[:-1]), (reg_b, w[-1])])
+ with variable_replace.variable_replace(replacement):
+ logits = regression_mod(features)
+
+ return logits
+
+ batch_size = y_train.shape.as_list()[0]
+
+ logit_train = compute_logit(features)
+ logit_test_unnorm = compute_logit(features_test)
+ if self.normalize_y_hat:
+ logit_test = logit_test_unnorm / tf.sqrt(
+ tf.reduce_sum(logit_test_unnorm**2, axis=[1], keep_dims=True))
+ else:
+ logit_test = logit_test_unnorm
+
+ stats = {}
+
+ if self.hinge_loss:
+ # slightly closer to the true classification loss
+ # any distance smaller than 1 is guaranteed to map to the correct class
+ mse_test = tf.reduce_sum(tf.nn.relu(tf.reduce_sum(tf.square(logit_test - y_test), axis=1)-1.)) / batch_size
+ else:
+ mse_test = tf.reduce_sum(tf.square(logit_test - y_test)) / batch_size
+
+ stats["mse_test"] = mse_test
+
+ mse_train = tf.reduce_sum(tf.square(logit_train - y_train)) / batch_size
+ stats["mse_train"] = mse_train
+
+ is_correct_test = tf.equal(tf.argmax(logit_test, 1), tf.argmax(y_test, 1))
+ accuracy_test = tf.reduce_mean(tf.cast(is_correct_test, tf.float32))
+ stats["accuracy_test"] = accuracy_test
+
+ def test_confusion_fn():
+ test_confusion = tf.confusion_matrix(tf.argmax(y_test, 1), tf.argmax(logit_test, 1))
+ test_confusion = tf.to_float(test_confusion) / tf.constant((logit_test.shape.as_list()[0] / float(logit_test.shape.as_list()[1])), dtype=tf.float32)
+ test_confusion = tf.expand_dims(tf.expand_dims(test_confusion, 0), 3)
+ return test_confusion
+ tf.summary.image("test_confusion", test_confusion_fn())
+
+ def train_confusion_fn():
+ train_confusion = tf.confusion_matrix(tf.argmax(y_train, 1), tf.argmax(logit_train, 1))
+ train_confusion = tf.to_float(train_confusion) / tf.constant((logit_train.shape.as_list()[0] / float(logit_train.shape.as_list()[1])), dtype=tf.float32)
+ train_confusion = tf.expand_dims(tf.expand_dims(train_confusion, 0), 3)
+ return train_confusion
+ tf.summary.image("train_confusion", train_confusion_fn())
+
+ is_correct = tf.equal(tf.argmax(logit_train, 1), tf.argmax(y_train, 1))
+ accuracy_train = tf.reduce_mean(tf.cast(is_correct, tf.float32))
+ stats["accuracy_train"] = accuracy_train
+
+ reg = self.ridge_factor * tf.reduce_sum(tf.square(w[:-1])) / batch_size
+ stats["ridge_component"] = reg
+
+ stats["total_loss"] = mse_test + reg
+
+ loss_to_train_at = (reg+ mse_test) * self.test_train_scalar + (mse_train + reg)*(1 - self.test_train_scalar)
+
+ loss_to_train_at = tf.identity(loss_to_train_at)
+
+    # Minimizing the test loss should not require regularization because the
+    # meta-objective is solved for the training loss.
+ return loss_to_train_at, stats
+
+ def local_variables(self):
+ """List of variables that need to be updated for each evaluation.
+
+ These variables should not be stored on a parameter server and
+ should be reset every computation of a meta_objective loss.
+
+ Returns:
+ vars: list of tf.Variable
+ """
+ return list(
+ snt.get_variables_in_module(self, tf.GraphKeys.TRAINABLE_VARIABLES))
+
+ def remote_variables(self):
+ return []
diff --git a/models/research/learning_unsupervised_learning/meta_objective/sklearn.py b/models/research/learning_unsupervised_learning/meta_objective/sklearn.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f1f2d59102c511fd42ad323c32ab1709bd60c90
--- /dev/null
+++ b/models/research/learning_unsupervised_learning/meta_objective/sklearn.py
@@ -0,0 +1,167 @@
+# Copyright 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+"""Meta-objectives implemented with scikit-learn estimators.
+
+Can NOT be differentiated through.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import numpy as np
+import sonnet as snt
+import tensorflow as tf
+from tensorflow.python.framework import function
+
+from learning_unsupervised_learning import utils
+
+from learning_unsupervised_learning.meta_objective import utils as meta_obj_utils
+
+from sklearn import svm
+from sklearn import linear_model
+
+
+def build_fit(device, model_fn, num_classes, probs=True):
+
+ def _py_fit_predict(trX, trY, teX):
+ assert len(np.unique(trY)) == num_classes
+ model = model_fn()
+ model.fit(trX, trY)
+ trP = model.predict(trX)
+ teP = model.predict(teX)
+ if probs:
+ teP_probs = model.predict_log_proba(teX)
+ return trP.astype(np.int64), teP.astype(np.int64), teP_probs.astype(
+ np.float32)
+ else:
+ teP = model.predict(teX)
+ return trP.astype(np.int64), teP.astype(np.int64)
+
+ def return_fn(trX, trY, teX):
+ with tf.device(device):
+ with tf.device("/cpu:0"):
+ if probs:
+ return tf.py_func(
+ _py_fit_predict,
+ [tf.identity(trX),
+ tf.identity(trY),
+ tf.identity(teX)], [tf.int64, tf.int64, tf.float32])
+ else:
+ return tf.py_func(
+ _py_fit_predict,
+ [tf.identity(trX),
+ tf.identity(trY),
+ tf.identity(teX)], [tf.int64, tf.int64])
+
+ return return_fn
+
+
+class SKLearn(meta_obj_utils.MultiTrialMetaObjective):
+
+ def __init__(
+ self,
+ local_device=None,
+ remote_device=None,
+ averages=1,
+ samples_per_class=10,
+ probs=False,
+ stddev=0.01,
+ n_samples=10,
+ name="SKLearn",
+ ):
+ self._local_device = local_device
+ self._remote_device = remote_device
+ self.name = name
+ self.probs = probs
+ self.n_samples = n_samples
+ self.stddev = stddev
+
+ super(SKLearn, self).__init__(
+ name=name, samples_per_class=samples_per_class, averages=averages)
+
+ def _get_model(self):
+    raise NotImplementedError()
+
+ def _build_once(self, dataset, feature_transformer):
+ with tf.device(self._local_device):
+ tr_batch = dataset()
+ te_batch = dataset()
+ num_classes = tr_batch.label_onehot.shape.as_list()[1]
+ all_batch = utils.structure_map_multi(lambda x: tf.concat(x, 0),
+ [tr_batch, te_batch])
+ features = feature_transformer(all_batch)
+ trX, teX = utils.structure_map_split(lambda x: tf.split(x, 2, axis=0),
+ features)
+ trY = tf.to_int64(tr_batch.label)
+ trY_onehot = tf.to_int32(tr_batch.label_onehot)
+ teY = tf.to_int64(te_batch.label)
+ teY_shape = teY.shape.as_list()
+
+      def blackbox(args):
+        trX, trY, teX, teY = args
+ trY = tf.to_int32(tf.rint(trY))
+ teY = tf.to_int32(tf.rint(teY))
+ tf_fn = build_fit(
+ self._local_device,
+ self._get_model,
+ num_classes=num_classes,
+ probs=self.probs)
+ if self.probs:
+ trP, teP, teP_probs = tf_fn(trX, trY, teX)
+ else:
+ trP, teP = tf_fn(trX, trY, teX)
+
+ teY.set_shape(teY_shape)
+ if self.probs:
+ onehot = tf.one_hot(teY, num_classes)
+ crossent = -tf.reduce_sum(onehot * teP_probs, [1])
+ return tf.reduce_mean(crossent)
+ else:
+          # use error rate as the loss if no surrogate is available.
+ return 1 - tf.reduce_mean(
+ tf.to_float(tf.equal(teY, tf.to_int32(teP))))
+
+ test_loss = blackbox((trX, tf.to_float(trY), teX, tf.to_float(teY)))
+
+ stats = {}
+
+ tf_fn = build_fit(
+ self._local_device,
+ self._get_model,
+ num_classes=num_classes,
+ probs=self.probs)
+ if self.probs:
+ trP, teP, teP_probs = tf_fn(trX, trY, teX)
+ else:
+ trP, teP = tf_fn(trX, trY, teX)
+ stats["%s/accuracy_train" % self.name] = tf.reduce_mean(
+ tf.to_float(tf.equal(tf.to_int32(trY), tf.to_int32(trP))))
+ stats["%s/accuracy_test" % self.name] = tf.reduce_mean(
+ tf.to_float(tf.equal(tf.to_int32(teY), tf.to_int32(teP))))
+ stats["%s/test_loss" % self.name] = test_loss
+ return test_loss, stats
+
+
+class LogisticRegression(SKLearn):
+
+ def __init__(self, C=1.0, name="LogisticRegression", probs=True, **kwargs):
+ self.C = C
+ super(LogisticRegression, self).__init__(name=name, probs=probs, **kwargs)
+
+ def _get_model(self):
+ return linear_model.LogisticRegression(C=self.C)
diff --git a/models/research/learning_unsupervised_learning/meta_objective/utils.py b/models/research/learning_unsupervised_learning/meta_objective/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a29197d1d0cb7f0fdcebac3980027640651f185b
--- /dev/null
+++ b/models/research/learning_unsupervised_learning/meta_objective/utils.py
@@ -0,0 +1,78 @@
+# Copyright 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import numpy as np
+import sonnet as snt
+import tensorflow as tf
+
+from learning_unsupervised_learning import optimizers
+from learning_unsupervised_learning import utils
+from learning_unsupervised_learning import summary_utils
+from learning_unsupervised_learning import variable_replace
+
+class MultiTrialMetaObjective(snt.AbstractModule):
+ def __init__(self, samples_per_class, averages, **kwargs):
+ self.samples_per_class = samples_per_class
+ self.averages = averages
+ self.dataset_map = {}
+
+ super(MultiTrialMetaObjective,
+ self).__init__(**kwargs)
+
+ def _build(self, dataset, feature_transformer):
+ if self.samples_per_class is not None:
+ if dataset not in self.dataset_map:
+        # Dataset construction must happen outside of any while_loop frames.
+ with tf.control_dependencies(None):
+ self.dataset_map[dataset] = utils.sample_n_per_class(
+ dataset, self.samples_per_class)
+
+ dataset = self.dataset_map[dataset]
+
+ stats = collections.defaultdict(list)
+ losses = []
+ # TODO(lmetz) move this to ingraph control flow?
+ for _ in xrange(self.averages):
+ loss, stat = self._build_once(dataset, feature_transformer)
+ losses.append(loss)
+ for k, v in stat.items():
+ stats[k].append(v)
+ stats = {k: tf.add_n(v) / float(len(v)) for k, v in stats.items()}
+
+ for k, v in stats.items():
+ tf.summary.scalar(k, v)
+
+ return tf.add_n(losses) / float(len(losses))
+
+ def local_variables(self):
+ """List of variables that need to be updated for each evaluation.
+
+ These variables should not be stored on a parameter server and
+ should be reset every computation of a meta_objective loss.
+
+ Returns:
+ vars: list of tf.Variable
+ """
+ return list(
+ snt.get_variables_in_module(self, tf.GraphKeys.TRAINABLE_VARIABLES))
+
+ def remote_variables(self):
+ return []
diff --git a/models/research/learning_unsupervised_learning/optimizers.py b/models/research/learning_unsupervised_learning/optimizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..02c6106b19d1255907beb0ade07c46c5b065f701
--- /dev/null
+++ b/models/research/learning_unsupervised_learning/optimizers.py
@@ -0,0 +1,133 @@
+# Copyright 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+
+"""Optimizers for use in unrolled optimization.
+
+These optimizers contain a compute_updates function and its own ability to keep
+track of internal state.
+These functions can be used with a tf.while_loop to perform multiple training
+steps per sess.run.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import collections
+import tensorflow as tf
+import sonnet as snt
+
+from learning_unsupervised_learning import utils
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import training_ops
+
+
+class UnrollableOptimizer(snt.AbstractModule):
+ """Interface for optimizers that can be used in unrolled computation.
+  apply_gradients is derived from compute_updates and assign_state.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(UnrollableOptimizer, self).__init__(*args, **kwargs)
+ self()
+
+ @abc.abstractmethod
+ def compute_updates(self, xs, gs, state=None):
+ """Compute next step updates for a given variable list and state.
+
+ Args:
+ xs: list of tensors
+ The "variables" to perform an update on.
+ Note these must match the same order for which get_state was originally
+ called.
+ gs: list of tensors
+ Gradients of `xs` with respect to some loss.
+ state: Any
+ Optimizer specific state to keep track of accumulators such as momentum
+ terms
+ """
+ raise NotImplementedError()
+
+ def _build(self):
+ pass
+
+ @abc.abstractmethod
+ def get_state(self, var_list):
+ """Get the state value associated with a list of tf.Variables.
+
+ This state is commonly going to be a NamedTuple that contains some
+ mapping between variables and the state associated with those variables.
+ This state could be a moving momentum variable tracked by the optimizer.
+
+ Args:
+ var_list: list of tf.Variable
+ Returns:
+ state: Any
+ Optimizer specific state
+ """
+ raise NotImplementedError()
+
+ def assign_state(self, state):
+ """Assigns the state to the optimizers internal variables.
+
+ Args:
+ state: Any
+ Returns:
+ op: tf.Operation
+ The operation that performs the assignment.
+ """
+ raise NotImplementedError()
+
+ def apply_gradients(self, grad_vars):
+ gradients, variables = zip(*grad_vars)
+ state = self.get_state(variables)
+ new_vars, new_state = self.compute_updates(variables, gradients, state)
+ assign_op = self.assign_state(new_state)
+ op = utils.assign_variables(variables, new_vars)
+ return tf.group(assign_op, op, name="apply_gradients")
+
+
+class UnrollableGradientDescentRollingOptimizer(UnrollableOptimizer):
+
+ def __init__(self,
+ learning_rate,
+ name="UnrollableGradientDescentRollingOptimizer"):
+ self.learning_rate = learning_rate
+ super(UnrollableGradientDescentRollingOptimizer, self).__init__(name=name)
+
+
+ def compute_updates(self, xs, gs, learning_rates, state):
+ new_vars = []
+ for x, g, lr in utils.eqzip(xs, gs, learning_rates):
+ if lr is None:
+ lr = self.learning_rate
+ if g is not None:
+ new_vars.append((x * (1 - lr) - g * lr))
+ else:
+ new_vars.append(x)
+ return new_vars, state
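+  # Note that the update above folds a multiplicative shrinkage into the step,
+  # i.e. x_new = (1 - lr) * x - lr * g rather than the plain SGD step x - lr * g.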
+
+ def get_state(self, var_list):
+ return tf.constant(0.0)
+
+ def assign_state(self, state, var_list=None):
+ return tf.no_op()
diff --git a/models/research/learning_unsupervised_learning/run_eval.py b/models/research/learning_unsupervised_learning/run_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcb2529dd4cc5354012befd5790c8d402f4caafd
--- /dev/null
+++ b/models/research/learning_unsupervised_learning/run_eval.py
@@ -0,0 +1,122 @@
+# Copyright 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script that iteratively applies the unsupervised update rule and evaluates
+the meta-objective performance.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import flags
+from absl import app
+
+from learning_unsupervised_learning import evaluation
+from learning_unsupervised_learning import datasets
+from learning_unsupervised_learning import architectures
+from learning_unsupervised_learning import summary_utils
+from learning_unsupervised_learning import meta_objective
+
+import tensorflow as tf
+import sonnet as snt
+
+from tensorflow.contrib.framework.python.framework import checkpoint_utils
+
+flags.DEFINE_string("checkpoint_dir", None, "Dir to load pretrained update rule from")
+flags.DEFINE_string("train_log_dir", None, "Training log directory")
+
+FLAGS = flags.FLAGS
+
+
+def train(train_log_dir, checkpoint_dir, eval_every_n_steps=10, num_steps=3000):
+ dataset_fn = datasets.mnist.TinyMnist
+ w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner
+ theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess
+
+ meta_objectives = []
+ meta_objectives.append(
+ meta_objective.linear_regression.LinearRegressionMetaObjective)
+ meta_objectives.append(meta_objective.sklearn.LogisticRegression)
+
+ checkpoint_vars, train_one_step_op, (
+ base_model, dataset) = evaluation.construct_evaluation_graph(
+ theta_process_fn=theta_process_fn,
+ w_learner_fn=w_learner_fn,
+ dataset_fn=dataset_fn,
+ meta_objectives=meta_objectives)
+ batch = dataset()
+ pre_logit, outputs = base_model(batch)
+
+ global_step = tf.train.get_or_create_global_step()
+ var_list = list(
+ snt.get_variables_in_module(base_model, tf.GraphKeys.TRAINABLE_VARIABLES))
+
+ tf.logging.info("all vars")
+ for v in tf.all_variables():
+ tf.logging.info(" %s" % str(v))
+ global_step = tf.train.get_global_step()
+ accumulate_global_step = global_step.assign_add(1)
+ reset_global_step = global_step.assign(0)
+
+ train_op = tf.group(
+ train_one_step_op, accumulate_global_step, name="train_op")
+
+ summary_op = tf.summary.merge_all()
+
+ file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"])
+ if checkpoint_dir:
+ str_var_list = checkpoint_utils.list_variables(checkpoint_dir)
+ name_to_v_map = {v.op.name: v for v in tf.all_variables()}
+ var_list = [
+ name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
+ ]
+ saver = tf.train.Saver(var_list)
+ missed_variables = [
+ v.op.name for v in set(
+ snt.get_variables_in_scope("LocalWeightUpdateProcess",
+ tf.GraphKeys.GLOBAL_VARIABLES)) -
+ set(var_list)
+ ]
+ assert len(missed_variables) == 0, "Missed a theta variable."
+
+ hooks = []
+
+ with tf.train.SingularMonitoredSession(master="", hooks=hooks) as sess:
+
+    # The global step should be restored from the eval job's checkpoint, or be zero for a fresh run.
+ step = sess.run(global_step)
+
+ if step == 0 and checkpoint_dir:
+ tf.logging.info("force restore")
+ saver.restore(sess, checkpoint_dir)
+ tf.logging.info("force restore done")
+ sess.run(reset_global_step)
+ step = sess.run(global_step)
+
+ while step < num_steps:
+ if step % eval_every_n_steps == 0:
+ s, _, step = sess.run([summary_op, train_op, global_step])
+ file_writer.add_summary(s, step)
+ else:
+ _, step = sess.run([train_op, global_step])
+
+
+def main(argv):
+ train(FLAGS.train_log_dir, FLAGS.checkpoint_dir)
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/models/research/learning_unsupervised_learning/summary_utils.py b/models/research/learning_unsupervised_learning/summary_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5c0fdd9186bdef0b4e25ca10978e22ab910d276
--- /dev/null
+++ b/models/research/learning_unsupervised_learning/summary_utils.py
@@ -0,0 +1,181 @@
+# Copyright 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+
+import collections
+import functools
+import threading
+import tensorflow as tf
+import matplotlib
+import numpy as np
+import time
+import re
+import math
+matplotlib.use("Agg")
+
+import matplotlib.pyplot as plt
+import scipy.signal
+
+from tensorflow.python.util import tf_should_use
+from tensorflow.contrib.summary import summary_ops
+from tensorflow.python.ops import summary_op_util
+from tensorflow.contrib.summary import gen_summary_ops
+
+_DEBUG_DISABLE_SUMMARIES=False
+
+class LoggingFileWriter(tf.summary.FileWriter):
+ """A FileWriter that also logs things out.
+
+  This exists purely for ease of debugging, so that one does not have to open
+  up TensorBoard as often.
+ """
+
+ def __init__(self, logdir, regexes=[], **kwargs):
+ self.regexes = regexes
+ super(LoggingFileWriter, self).__init__(logdir, **kwargs)
+
+ def add_summary(self, summary, global_step):
+ if type(summary) != tf.Summary:
+ summary_p = tf.Summary()
+ summary_p.ParseFromString(summary)
+ summary = summary_p
+ for s in summary.value:
+ for exists in [re.match(p, s.tag) for p in self.regexes]:
+ if exists is not None:
+ tf.logging.info("%d ] %s : %f", global_step, s.tag, s.simple_value)
+ break
+ super(LoggingFileWriter, self).add_summary(summary, global_step)
+
+
+def image_grid(images, max_grid_size=4, border=1):
+ """Given images and N, return first N^2 images as an NxN image grid.
+
+ Args:
+ images: a `Tensor` of size [batch_size, height, width, channels]
+    max_grid_size: Maximum image grid height/width
+    border: number of padding pixels added below and to the right of each image
+
+ Returns:
+ Single image batch, of dim [1, h*n, w*n, c]
+ """
+ batch_size = images.shape.as_list()[0]
+ to_pad = int((np.ceil(np.sqrt(batch_size)))**2 - batch_size)
+ images = tf.pad(images, [[0, to_pad], [0, border], [0, border], [0, 0]])
+
+ batch_size = images.shape.as_list()[0]
+ grid_size = min(int(np.sqrt(batch_size)), max_grid_size)
+ assert images.shape.as_list()[0] >= grid_size * grid_size
+
+ # If we have a depth channel
+ if images.shape.as_list()[-1] == 4:
+ images = images[:grid_size * grid_size, :, :, 0:3]
+ depth = tf.image.grayscale_to_rgb(images[:grid_size * grid_size, :, :, 3:4])
+
+ images = tf.reshape(images, [-1, images.shape.as_list()[2], 3])
+ split = tf.split(images, grid_size, axis=0)
+ depth = tf.reshape(depth, [-1, images.shape.as_list()[2], 3])
+ depth_split = tf.split(depth, grid_size, axis=0)
+ grid = tf.concat(split + depth_split, 1)
+ return tf.expand_dims(grid, 0)
+ else:
+ images = images[:grid_size * grid_size, :, :, :]
+ images = tf.reshape(
+ images, [-1, images.shape.as_list()[2],
+ images.shape.as_list()[3]])
+ split = tf.split(value=images, num_or_size_splits=grid_size, axis=0)
+ grid = tf.concat(split, 1)
+ return tf.expand_dims(grid, 0)
+
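+# Shape example (a sketch): a [16, 28, 28, 1] batch with max_grid_size=4 and
+# border=1 is padded to 29x29 tiles and tiled into a single [1, 116, 116, 1]
+# image.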
+
+def first_layer_weight_image(weight, shape):
+ weight_image = tf.reshape(weight,
+ shape + [tf.identity(weight).shape.as_list()[1]])
+ # [winx, winy, wout]
+ mean, var = tf.nn.moments(weight_image, [0,1,2], keep_dims=True)
+ #mean, var = tf.nn.moments(weight_image, [0,1], keep_dims=True)
+ weight_image = (weight_image - mean) / tf.sqrt(var + 1e-5)
+ weight_image = (weight_image + 1.0) / 2.0
+ weight_image = tf.clip_by_value(weight_image, 0, 1)
+ weight_image = tf.transpose(weight_image, (3, 0, 1, 2))
+ grid = image_grid(weight_image, max_grid_size=10)
+ return grid
+
+def inner_layer_weight_image(weight):
+  """Visualize a weight matrix of an inner layer as a grayscale image.
+
+  Each column is normalized by its maximum absolute value.
+  """
+ weight = tf.identity(weight) # turn into a tensor
+ weight = weight / (tf.reduce_max(tf.abs(weight), [0], keep_dims=True))
+ weight = tf.reshape(weight, [1]+weight.shape.as_list() + [1])
+ return weight
+
+
+def activation_image(activations, label_onehot):
+ """Make a row sorted by class for each activation. Put a black line around the activations."""
+ labels = tf.argmax(label_onehot, axis=1)
+ _, n_classes = label_onehot.shape.as_list()
+ mean, var = tf.nn.moments(activations, [0, 1])
+ activations = (activations - mean)/tf.sqrt(var+1e-5)
+
+ activations = tf.clip_by_value(activations, -1, 1)
+ activations = (activations + 1.0) / 2.0 # shift to [0, 1]
+
+ canvas = []
+ for i in xrange(n_classes):
+ inds = tf.where(tf.equal(labels, i))
+
+ def _gather():
+ return tf.squeeze(tf.gather(activations, inds), 1)
+
+ def _empty():
+ return tf.zeros([0, activations.shape.as_list()[1]], dtype=tf.float32)
+
+ assert inds.shape.as_list()[0] is None
+ x = tf.cond(tf.equal(tf.shape(inds)[0], 0), _empty, _gather)
+ canvas.append(x)
+ canvas.append(tf.zeros([1, activations.shape.as_list()[1]]))
+ canvas = tf.concat(canvas, 0)
+ canvas = tf.reshape(canvas, [1, activations.shape.as_list()[0]+n_classes, canvas.shape.as_list()[1], 1])
+ return canvas
+
+
+def sorted_images(images, label_onehot):
+ # images is [bs, x, y, c]
+ labels = tf.argmax(label_onehot, axis=1)
+ _, n_classes = label_onehot.shape.as_list()
+ to_stack = []
+ for i in xrange(n_classes):
+ inds = tf.where(tf.equal(labels, i))
+
+ def _gather():
+ return tf.squeeze(tf.gather(images, inds), 1)
+
+ def _empty():
+ return tf.zeros([0] + images.shape.as_list()[1:], dtype=tf.float32)
+
+ assert inds.shape.as_list()[0] is None
+ x = tf.cond(tf.equal(tf.shape(inds)[0], 0), _empty, _gather)
+ to_stack.append(x)
+ # pad / trim all up to 10.
+ padded = []
+ for t in to_stack:
+ n_found = tf.shape(t)[0]
+ pad = tf.pad(t[0:10], tf.stack([tf.stack([0,tf.maximum(0, 10-n_found)]), [0,0], [0,0], [0,0]]))
+ padded.append(pad)
+
+ xs = [tf.concat(tf.split(p, 10), axis=1) for p in padded]
+ ys = tf.concat(xs, axis=2)
+ ys = tf.cast(tf.clip_by_value(ys, 0., 1.) * 255., tf.uint8)
+ return ys
diff --git a/models/research/learning_unsupervised_learning/utils.py b/models/research/learning_unsupervised_learning/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca56ca93181df1ed9c403fef79e8154c3c9515b4
--- /dev/null
+++ b/models/research/learning_unsupervised_learning/utils.py
@@ -0,0 +1,287 @@
+# Copyright 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import tensorflow as tf
+import sonnet as snt
+import itertools
+import functools
+
+from tensorflow.core.framework import node_def_pb2
+from tensorflow.python.framework import device as pydev
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import variable_scope as variable_scope_ops
+from sonnet.python.modules import util as snt_util
+
+from tensorflow.python.util import nest
+
+
+def eqzip(*args):
+ """Zip but raises error if lengths don't match.
+
+ Args:
+ *args: list of lists or tuples
+ Returns:
+ list: the result of zip
+ Raises:
+ ValueError: when the lengths don't match
+ """
+
+ sizes = [len(x) for x in args]
+ if not all([sizes[0] == x for x in sizes]):
+ raise ValueError("Lists are of different sizes. \n %s"%str(sizes))
+ return zip(*args)
+
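+# Example (a sketch): eqzip([1, 2], ["a", "b"]) -> [(1, "a"), (2, "b")], while
+# eqzip([1, 2], ["a"]) raises ValueError.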
+
+@contextlib.contextmanager
+def assert_no_new_variables():
+ """Ensure that no tf.Variables are constructed inside the context.
+
+ Yields:
+ None
+ Raises:
+ ValueError: if there is a variable created.
+ """
+ num_vars = len(tf.global_variables())
+ old_variables = tf.global_variables()
+ yield
+ if len(tf.global_variables()) != num_vars:
+ new_vars = set(tf.global_variables()) - set(old_variables)
+ tf.logging.error("NEW VARIABLES CREATED")
+ tf.logging.error(10*"=")
+ for v in new_vars:
+ tf.logging.error(v)
+
+ raise ValueError("Variables created inside an "
+ "assert_no_new_variables context")
+ if old_variables != tf.global_variables():
+ raise ValueError("Variables somehow changed inside an "
+                     "assert_no_new_variables context. This means something "
+                     "modified the tf.global_variables()")
+
+
+def get_variables_in_modules(module_list):
+ var_list = []
+ for m in module_list:
+ var_list.extend(snt.get_variables_in_module(m))
+ return var_list
+
+
+def state_barrier_context(state):
+ """Return a context manager that prevents interior ops from running
+ unless the whole state has been computed.
+
+ This is to prevent assign race conditions.
+ """
+ tensors = [x for x in nest.flatten(state) if type(x) == tf.Tensor]
+ tarray = [x.flow for x in nest.flatten(state) if hasattr(x, "flow")]
+ return tf.control_dependencies(tensors + tarray)
+
+
+def _identity_fn(tf_entity):
+ if hasattr(tf_entity, "identity"):
+ return tf_entity.identity()
+ else:
+ return tf.identity(tf_entity)
+
+
+def state_barrier_result(state):
+ """Return the same state, but with a control dependency to prevent it from
+ being partially computed
+ """
+ with state_barrier_context(state):
+ return nest.map_structure(_identity_fn, state)
+
+
+def train_iterator(num_iterations):
+ """Iterator that returns an index of the current step.
+ This iterator runs forever if num_iterations is None
+ otherwise it runs for some fixed amount of steps.
+ """
+ if num_iterations is None:
+ return itertools.count()
+ else:
+ return xrange(num_iterations)
+
+
+def print_op(op, msg):
+ """Print a string and return an op wrapped in a control dependency to make
+ sure it ran."""
+ print_op = tf.Print(tf.constant(0), [tf.constant(0)], msg)
+ return tf.group(op, print_op)
+
+
+class MultiQueueRunner(tf.train.QueueRunner):
+ """A QueueRunner with multiple queues """
+ def __init__(self, queues, enqueue_ops):
+ close_op = tf.group(* [q.close() for q in queues])
+ cancel_op = tf.group(
+ * [q.close(cancel_pending_enqueues=True) for q in queues])
+ queue_closed_exception_types = (errors.OutOfRangeError,)
+
+ enqueue_op = tf.group(*enqueue_ops, name="multi_enqueue")
+
+ super(MultiQueueRunner, self).__init__(
+ queues[0],
+ enqueue_ops=[enqueue_op],
+ close_op=close_op,
+ cancel_op=cancel_op,
+ queue_closed_exception_types=queue_closed_exception_types)
+
+
+# This function is not elegant, but I tried so many other ways to get this to
+# work and this is the only one that ended up not incurring significant overhead
+# or obscure tensorflow bugs.
+def sample_n_per_class(dataset, samples_per_class):
+  """Create a new callable / dataset object that returns batches containing
+  samples_per_class examples of each label.
+
+ Args:
+ dataset: fn
+ samples_per_class: int
+ Returns:
+ function, [] -> batch where batch is the same type as the return of
+ dataset().
+ """
+
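+  # For example (a sketch): with a 10-class dataset and samples_per_class=10,
+  # every dequeued batch holds 100 examples, exactly 10 per class, in shuffled
+  # order.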
+ with tf.control_dependencies(None), tf.name_scope(None):
+ with tf.name_scope("queue_runner/sample_n_per_class"):
+ batch = dataset()
+ num_classes = batch.label_onehot.shape.as_list()[1]
+ batch_size = num_classes * samples_per_class
+
+ flatten = nest.flatten(batch)
+ queues = []
+ enqueue_ops = []
+ capacity = samples_per_class * 20
+ for i in xrange(num_classes):
+ queue = tf.FIFOQueue(
+ capacity=capacity,
+ shapes=[f.shape.as_list()[1:] for f in flatten],
+ dtypes=[f.dtype for f in flatten])
+ queues.append(queue)
+
+ idx = tf.where(tf.equal(batch.label, i))
+ sub_batch = []
+ to_enqueue = []
+ for elem in batch:
+ new_e = tf.gather(elem, idx)
+ new_e = tf.squeeze(new_e, 1)
+ to_enqueue.append(new_e)
+
+ remaining = (capacity - queue.size())
+ to_add = tf.minimum(tf.shape(idx)[0], remaining)
+
+ def _enqueue():
+ return queue.enqueue_many([t[:to_add] for t in to_enqueue])
+
+ enqueue_op = tf.cond(
+ tf.equal(to_add, 0), tf.no_op, _enqueue)
+ enqueue_ops.append(enqueue_op)
+
+ # This has caused many deadlocks / issues. This is some logging to at least
+      # shed light on what is going on.
+ print_lam = lambda: tf.Print(tf.constant(0.0), [q.size() for q in queues], "MultiQueueRunner queues status. Has capacity %d"%capacity)
+ some_percent_of_time = tf.less(tf.random_uniform([]), 0.0005)
+ maybe_print = tf.cond(some_percent_of_time, print_lam, lambda: tf.constant(0.0))
+ with tf.control_dependencies([maybe_print]):
+ enqueue_ops = [tf.group(e) for e in enqueue_ops]
+ qr = MultiQueueRunner(queues=queues, enqueue_ops=enqueue_ops)
+ tf.train.add_queue_runner(qr)
+
+ def dequeue_batch():
+ with tf.name_scope("sample_n_per_batch/dequeue/"):
+ entries = []
+ for q in queues:
+ entries.append(q.dequeue_many(samples_per_class))
+
+ flat_batch = [tf.concat(x, 0) for x in zip(*entries)]
+ idx = tf.random_shuffle(tf.range(batch_size))
+ flat_batch = [tf.gather(f, idx, axis=0) for f in flat_batch]
+ return nest.pack_sequence_as(batch, flat_batch)
+
+ return dequeue_batch
+
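+# The two helpers below apply a function across nested structures:
+# structure_map_multi maps `func` over tuples of corresponding leaves taken
+# from several structures, while structure_map_split maps `func` over the
+# leaves of a single structure and regroups each returned piece into a
+# structure of the original shape (e.g. splitting every leaf tensor in two).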
+def structure_map_multi(func, values):
+ all_values = [nest.flatten(v) for v in values]
+ rets = []
+ for pair in zip(*all_values):
+ rets.append(func(pair))
+ return nest.pack_sequence_as(values[0], rets)
+
+def structure_map_split(func, value):
+ vv = nest.flatten(value)
+ rets = []
+ for v in vv:
+ rets.append(func(v))
+ return [nest.pack_sequence_as(value, r) for r in zip(*rets)]
+
+def assign_variables(targets, values):
+ return tf.group(*[t.assign(v) for t,v in eqzip(targets, values)],
+ name="assign_variables")
+
+
+def create_variables_in_class_scope(method):
+ """Force the variables constructed in this class to live in the sonnet module.
+ Wraps a method on a sonnet module.
+
+ For example the following will create two different variables.
+ ```
+ class Mod(snt.AbstractModule):
+ @create_variables_in_class_scope
+ def dynamic_thing(self, input, name):
+      return snt.Linear(output_size, name=name)(input)
+
+  mod = Mod()
+  mod.dynamic_thing(x, name="module_nameA")
+ mod.dynamic_thing(x, name="module_nameB")
+ # reuse
+ mod.dynamic_thing(y, name="module_nameA")
+ ```
+ """
+ @functools.wraps(method)
+ def wrapper(obj, *args, **kwargs):
+ def default_context_manager(reuse=None):
+ variable_scope = obj.variable_scope
+ return tf.variable_scope(variable_scope, reuse=reuse)
+
+ variable_scope_context_manager = getattr(obj, "_enter_variable_scope",
+ default_context_manager)
+ graph = tf.get_default_graph()
+
+ # Temporarily enter the variable scope to capture it
+ with variable_scope_context_manager() as tmp_variable_scope:
+ variable_scope = tmp_variable_scope
+
+ with variable_scope_ops._pure_variable_scope(
+ variable_scope, reuse=tf.AUTO_REUSE) as pure_variable_scope:
+
+ name_scope = variable_scope.original_name_scope
+ if name_scope[-1] != "/":
+ name_scope += "/"
+
+ with tf.name_scope(name_scope):
+ sub_scope = snt_util.to_snake_case(method.__name__)
+ with tf.name_scope(sub_scope) as scope:
+ out_ops = method(obj, *args, **kwargs)
+ return out_ops
+
+ return wrapper
+
diff --git a/models/research/learning_unsupervised_learning/variable_replace.py b/models/research/learning_unsupervised_learning/variable_replace.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebfbeadc8aba7f8a09e1392f1de8d7b33f10d43c
--- /dev/null
+++ b/models/research/learning_unsupervised_learning/variable_replace.py
@@ -0,0 +1,112 @@
+# Copyright 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from contextlib import contextmanager
+
+from tensorflow.python.ops import variable_scope
+
+# Global sanity-check state used to ensure this context manager is not entered recursively.
+_is_variable_replacing = [False]
+
+def in_variable_replace_scope():
+ return _is_variable_replacing[0]
+
+@contextmanager
+def variable_replace(replacements, no_new=True):
+ """ A context manager that replaces variables.
+
+  Inside this context, calls to get_variable for a variable listed in
+  replacements return the corresponding replacement instead.
+ This function does not support recursive application.
+
+ Args:
+ replacements: dict
+      dictionary mapping each variable to be replaced (the key) to the
+      variable or tensor that should be used in its place (the value).
+ no_new: bool
+ raise an error if variables were created.
+ This is for sanity checking.
+ Raises:
+    ValueError: if a new variable is created or not all of the replacements are used.
+ """
+ # TODO(lmetz) This function is a bit scary, as it relies on monkey patching
+ # the call to get_variable. Ideally this can be done with variable_scope's
+ # custom_getter attribute, but when initially writing this that was not
+  # available.
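+  # A minimal usage sketch (hypothetical names, not part of the original code):
+  #   new_w = ...  # a tensor shaped like linear_mod.w
+  #   with variable_replace({linear_mod.w: new_w}):
+  #     out = linear_mod(x)  # get_variable("w") inside resolves to new_w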
+
+ replacements = {k: v for k, v in replacements.items() if not k == v}
+
+ init_vars = tf.trainable_variables()
+ old_get_variable = variable_scope.get_variable
+ old_tf_get_variable = tf.get_variable
+
+ names_replace = {}
+ has_replaced_names = []
+ tf.logging.vlog(2, "Trying to replace")
+ for k, v in replacements.items():
+ tf.logging.vlog(2, k.name + " >> " + v.name)
+ tf.logging.vlog(2, "===")
+
+ for k, v in replacements.items():
+ strip_name = k.name.replace("/read:0", "")
+ strip_name = strip_name.replace(":0", "")
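+# A minimal usage sketch (assumed shapes, not part of the original code):
+#   fit_fn = build_fit("/cpu:0", lambda: linear_model.LogisticRegression(),
+#                      num_classes=10, probs=True)
+#   trP, teP, teP_log_probs = fit_fn(trX, trY, teX)  # ops backed by tf.py_func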
+ names_replace[strip_name] = v
+ # TODO(lmetz) is there a cleaner way to do this?
+ def new_get_variable(name, *args, **kwargs):
+ #print "Monkeypatch get variable run with name:", name
+ n = tf.get_variable_scope().name + "/" + name
+ #print "Monkeypatch get variable run with name:", n
+ if n in names_replace:
+ has_replaced_names.append(n)
+ return names_replace[n]
+ else:
+ return old_get_variable(name, *args, **kwargs)
+
+ # perform the monkey patch
+ if _is_variable_replacing[0] == True:
+ raise ValueError("No recursive calling to variable replace allowed.")
+
+ variable_scope.get_variable = new_get_variable
+ tf.get_variable = new_get_variable
+
+ _is_variable_replacing[0] = True
+
+ yield
+
+ if set(has_replaced_names) != set(names_replace.keys()):
+    print("Didn't use all replacements")
+    print("replaced variables that are not requested??")
+    print("===")
+    for n in list(set(has_replaced_names) - set(names_replace.keys())):
+      print(n)
+    print("Missed replacing variables")
+    print("===")
+    for n in list(set(names_replace.keys()) - set(has_replaced_names)):
+      print(n, "==>", names_replace[n].name)
+ raise ValueError("Fix this -- see stderr")
+
+ # undo the monkey patch
+ tf.get_variable = old_tf_get_variable
+ variable_scope.get_variable = old_get_variable
+
+ _is_variable_replacing[0] = False
+
+  if no_new:
+    final_vars = tf.trainable_variables()
+    assert set(init_vars) == set(final_vars), "trainable variables changed"
diff --git a/models/research/lexnet_nc/README.md b/models/research/lexnet_nc/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4ecb5d39867c2ebf7280b9d19bbabb41957b9465
--- /dev/null
+++ b/models/research/lexnet_nc/README.md
@@ -0,0 +1,215 @@
+
+# LexNET for Noun Compound Relation Classification
+
+This is a [TensorFlow](http://www.tensorflow.org/) implementation of the LexNET
+algorithm for classifying relations, applied here to the relationships that
+hold between the constituents of noun compounds:
+
+* *olive oil* is oil that is *made from* olives
+* *cooking oil* is oil that is *used for* cooking
+* *motor oil* is oil that is *contained in* a motor
+
+The model is a supervised classifier that predicts the relationship that holds
+between the constituents of a two-word noun compound using:
+
+1. A neural "paraphrase" of each syntactic dependency path that connects the
+ constituents in a large corpus. For example, given a sentence like *This fine
+ oil is made from first-press olives*, the dependency path is something like
+  `oil <PREP from POBJ> olive`.
+2. The distributional information provided by the individual words; i.e., the
+  word embeddings of the two constituents.
+3. The distributional signal provided by the compound itself; i.e., the
+ embedding of the noun compound in context.
+
+The model includes several variants: the *path-based model* uses (1) alone, the
+*distributional model* uses (2) alone, and the *integrated model* uses (1) and
+(2). The *distributional-nc model* and the *integrated-nc* model each add (3).
+
+Training a model requires the following:
+
+1. A collection of noun compounds that have been labeled using a *relation
+ inventory*. The inventory describes the specific relationships that you'd
+ like the model to differentiate (e.g. *part of* versus *composed of* versus
+ *purpose*), and generally may consist of tens of classes. You can download
+ the dataset used in the paper from
+ [here](https://vered1986.github.io/papers/Tratz2011_Dataset.tar.gz).
+2. A collection of word embeddings: the path-based model uses the word
+ embeddings as part of the path representation, and the distributional models
+ use the word embeddings directly as prediction features.
+3. The path-based model requires a collection of syntactic dependency parses
+ that connect the constituents for each noun compound. To generate these,
+ you'll need a corpus from which to train this data; we used Wikipedia and the
+ [LDC GigaWord5](https://catalog.ldc.upenn.edu/LDC2011T07) corpora.
+
+# Contents
+
+The following source code is included here:
+
+* `learn_path_embeddings.py` is a script that trains and evaluates a path-based
+ model to predict a noun-compound relationship given labeled noun-compounds and
+ dependency parse paths.
+* `learn_classifier.py` is a script that trains and evaluates a classifier based
+ on any combination of paths, word embeddings, and noun-compound embeddings.
+* `get_indicative_paths.py` is a script that generates the most indicative
+ syntactic dependency paths for a particular relationship.
+
+Also included are utilities for preparing data for training:
+
+* `text_embeddings_to_binary.py` converts a text file containing word embeddings
+ into a binary file that is quicker to load.
+* `extract_paths.py` finds all the dependency paths that connect words in a
+ corpus.
+* `sorted_paths_to_examples.py` processes the output of `extract_paths.py` to
+ produce summarized training data.
+
+This code (in particular, the utilities used to prepare the data) differs from
+the code that was used to prepare data for the paper. Notably, the paper's data
+was prepared with a proprietary dependency parser, whereas the code here uses
+spaCy.
+
+# Dependencies
+
+* [TensorFlow](http://www.tensorflow.org/): see detailed installation
+ instructions at that site.
+* [scikit-learn](http://scikit-learn.org/): you can probably just install this
+  with `pip install scikit-learn`.
+* [SpaCy](https://spacy.io/): `pip install spacy` ought to do the trick, along
+ with the English model.
+
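+The parser used by `extract_paths.py` loads spaCy's small English model
+(`en_core_web_sm`); if you don't already have it, it can typically be
+installed with:
+
+    python -m spacy download en_core_web_sm
+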
+# Creating the Model
+
+This section describes the steps necessary to create and evaluate the model
+described in the paper.
+
+## Generate Path Data
+
+To begin, you need three text files:
+
+1. **Corpus**. This file should contain natural language sentences, written with
+ one sentence per line. For purposes of exposition, we'll assume that you
+ have English Wikipedia serialized this way in `${HOME}/data/wiki.txt`.
+2. **Labeled Noun Compound Pairs**. This file contains (modifier, head, label)
+   tuples, tab-separated, with one per line. The *label* represents the
+   relationship between the head and the modifier; e.g., if `purpose` is one of
+   your labels, the file might include the line `tooth paste purpose`
+   (tab-separated). We'll assume you have this file as
+   `${HOME}/data/labeled-pairs.tsv`.
+3. **Word Embeddings**. We used the
+ [GloVe](https://nlp.stanford.edu/projects/glove/) word embeddings; in
+ particular the 6B token, 300d variant. We'll assume you have this file as
+ `${HOME}/data/glove.6B.300d.txt`.
+
+We first processed the embeddings from their text format into something that we
+can load a little bit more quickly:
+
+ ./text_embeddings_to_binary.py \
+ --input ${HOME}/data/glove.6B.300d.txt \
+ --output_vocab ${HOME}/data/vocab.txt \
+ --output_npy ${HOME}/data/glove.6B.300d.npy
+
+Next, we'll extract all the dependency parse paths connecting our labeled pairs
+from the corpus. This process takes a *looooong* time, but is trivially
+parallelized using map-reduce if you have access to that technology.
+
+ ./extract_paths.py \
+ --corpus ${HOME}/data/wiki.txt \
+ --labeled_pairs ${HOME}/data/labeled-pairs.tsv \
+ --output ${HOME}/data/paths.tsv
+
+The file it produces (`paths.tsv`) is a tab-separated file that contains the
+modifier, the head, the label, the encoded path, and the sentence from which the
+path was drawn. (This last is mostly for sanity checking.) A sample row might
+look something like this (where newlines would actually be tab characters):
+
+ navy
+ captain
+ owner_emp_use
+ /PROPN/dobj/>::enter/VERB/ROOT/^::follow/VERB/advcl/<::in/ADP/prep/<::footstep/NOUN/pobj/<::of/ADP/prep/<::father/NOUN/pobj/<::bover/PROPN/appos/<::/PROPN/compound/<
+ He entered the Royal Navy following in the footsteps of his father Captain John Bover and two of his elder brothers as volunteer aboard HMS Perseus
+
+This file must be sorted as follows:
+
+ sort -k1,3 -t$'\t' paths.tsv > sorted.paths.tsv
+
+In particular, rows with the same modifier, head, and label must appear
+contiguously.
+
+We next create a file that contains all the relation labels from our original
+labeled pairs:
+
+ awk 'BEGIN {FS="\t"} {print $3}' < ${HOME}/data/labeled-pairs.tsv \
+ | sort -u > ${HOME}/data/relations.txt
+
+With these in hand, we're ready to produce the train, validation, and test data:
+
+ ./sorted_paths_to_examples.py \
+ --input ${HOME}/data/sorted.paths.tsv \
+ --vocab ${HOME}/data/vocab.txt \
+ --relations ${HOME}/data/relations.txt \
+ --splits ${HOME}/data/splits.txt \
+ --output_dir ${HOME}/data
+
+Here, `splits.txt` is a file that indicates which "split" (train, test, or
+validation) you want the pair to appear in. It should be a tab-separated file
+which contains the modifier, the head, and the dataset (`train`, `test`, or `val`)
+into which the pair should be placed; e.g.,:
+
+ tooth paste train
+ banana seat test
+
+The program will produce a separate file for each dataset split in the directory
+specified by `--output_dir`. Each file contains `tf.train.Example` protocol
+buffers encoded using the `TFRecord` file format.
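+
+If you want to sanity-check the output, a short sketch like the following
+(assuming the produced files are GZIP-compressed `TFRecord`s, e.g.
+`train.tfrecs.gz` in your `--output_dir`) will print the first serialized
+example from the training split:
+
+    import tensorflow as tf
+
+    opts = tf.python_io.TFRecordOptions(
+        compression_type=tf.python_io.TFRecordCompressionType.GZIP)
+    for record in tf.python_io.tf_record_iterator('train.tfrecs.gz', opts):
+        print(tf.train.Example.FromString(record))
+        break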
+
+## Create Path Embeddings
+
+Now we're ready to train the path embeddings using `learn_path_embeddings.py`:
+
+ ./learn_path_embeddings.py \
+ --train ${HOME}/data/train.tfrecs.gz \
+ --val ${HOME}/data/val.tfrecs.gz \
+    --test ${HOME}/data/test.tfrecs.gz \
+    --embeddings ${HOME}/data/glove.6B.300d.npy \
+    --relations ${HOME}/data/relations.txt \
+    --output_dir ${HOME}/data/path-embeddings \
+ --logdir /tmp/learn_path_embeddings
+
+The path embeddings will be placed at the location specified by `--output_dir`.
+
+## Train classifiers
+
+Train classifiers and evaluate them on the validation and test data using the
+`learn_classifier.py` script. The following shell fragment will iterate through
+each dataset, split, corpus, and model type to train and evaluate classifiers.
+
+ LOGDIR=/tmp/learn_classifier
+ for DATASET in tratz/fine_grained tratz/coarse_grained ; do
+ for SPLIT in random lexical_head lexical_mod lexical_full ; do
+ for CORPUS in wiki_gigiawords ; do
+ for MODEL in dist dist-nc path integrated integrated-nc ; do
+ # Filename for the log that will contain the classifier results.
+ LOGFILE=$(echo "${DATASET}.${SPLIT}.${CORPUS}.${MODEL}.log" | sed -e "s,/,.,g")
+ python learn_classifier.py \
+ --dataset_dir ~/lexnet/datasets \
+ --dataset "${DATASET}" \
+ --corpus "${SPLIT}/${CORPUS}" \
+ --embeddings_base_path ~/lexnet/embeddings \
+ --logdir ${LOGDIR} \
+ --input "${MODEL}" > "${LOGDIR}/${LOGFILE}"
+ done
+ done
+ done
+ done
+
+The log file will contain the final performance (precision, recall, F1) on the
+train, dev, and test sets, and will include a confusion matrix for each.
+
+# Contact
+
+If you have any questions, issues, or suggestions, feel free to contact either
+@vered1986 or @waterson.
+
+If you use this code for any published research, please include the following citation:
+
+Olive Oil Is Made of Olives, Baby Oil Is Made for Babies: Interpreting Noun Compounds Using Paraphrases in a Neural Model.
+Vered Shwartz and Chris Waterson. NAACL 2018. [link](https://arxiv.org/pdf/1803.08073.pdf).
diff --git a/models/research/lexnet_nc/extract_paths.py b/models/research/lexnet_nc/extract_paths.py
new file mode 100644
index 0000000000000000000000000000000000000000..833eec2c1b8a176b487d4e663a737b9502b49eda
--- /dev/null
+++ b/models/research/lexnet_nc/extract_paths.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python
+# Copyright 2017, 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+import sys
+
+import spacy
+import tensorflow as tf
+
+tf.flags.DEFINE_string('corpus', '', 'Filename of corpus')
+tf.flags.DEFINE_string('labeled_pairs', '', 'Filename of labeled pairs')
+tf.flags.DEFINE_string('output', '', 'Filename of output file')
+FLAGS = tf.flags.FLAGS
+
+
+def get_path(mod_token, head_token):
+ """Returns the path between a modifier token and a head token."""
+ # Compute the path from the root to each token.
+ mod_ancestors = list(reversed(list(mod_token.ancestors)))
+ head_ancestors = list(reversed(list(head_token.ancestors)))
+
+ # If the paths don't start at the same place (odd!) then there is no path at
+ # all.
+ if (not mod_ancestors or not head_ancestors
+ or mod_ancestors[0] != head_ancestors[0]):
+ return None
+
+ # Eject elements from the common path until we reach the first differing
+ # ancestor.
+ ix = 1
+ while (ix < len(mod_ancestors) and ix < len(head_ancestors)
+ and mod_ancestors[ix] == head_ancestors[ix]):
+ ix += 1
+
+ # Construct the path. TODO: add "satellites", possibly honor sentence
+ # ordering between modifier and head rather than just always traversing from
+ # the modifier to the head?
+ path = ['/'.join(('', mod_token.pos_, mod_token.dep_, '>'))]
+
+ path += ['/'.join((tok.lemma_, tok.pos_, tok.dep_, '>'))
+ for tok in reversed(mod_ancestors[ix:])]
+
+ root_token = mod_ancestors[ix - 1]
+ path += ['/'.join((root_token.lemma_, root_token.pos_, root_token.dep_, '^'))]
+
+ path += ['/'.join((tok.lemma_, tok.pos_, tok.dep_, '<'))
+ for tok in head_ancestors[ix:]]
+
+ path += ['/'.join(('', head_token.pos_, head_token.dep_, '<'))]
+
+ return '::'.join(path)
+
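+# A rough usage sketch (not exercised by this script): given a spaCy-parsed
+# sentence, get_path encodes the dependency path between a modifier token and
+# a head token as a '::'-joined string of lemma/POS/dep/direction segments,
+# returning None if the two tokens share no common ancestor. For example:
+#
+#   nlp = spacy.load('en_core_web_sm')
+#   doc = nlp(u'This fine oil is made from first-press olives')
+#   mod = [t for t in doc if t.text == 'olives'][0]
+#   head = [t for t in doc if t.text == 'oil'][0]
+#   print(get_path(mod, head))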
+
+def main(_):
+ nlp = spacy.load('en_core_web_sm')
+
+ # Grab the set of labeled pairs for which we wish to collect paths.
+ with tf.gfile.GFile(FLAGS.labeled_pairs) as fh:
+ parts = (l.decode('utf-8').split('\t') for l in fh.read().splitlines())
+ labeled_pairs = {(mod, head): rel for mod, head, rel in parts}
+
+ # Create a mapping from each head to the modifiers that are used with it.
+ mods_for_head = {
+ head: set(hm[1] for hm in head_mods)
+ for head, head_mods in itertools.groupby(
+ sorted((head, mod) for (mod, head) in labeled_pairs.iterkeys()),
+ lambda (head, mod): head)}
+
+ # Collect all the heads that we know about.
+ heads = set(mods_for_head.keys())
+
+ # For each sentence that contains a (head, modifier) pair that's in our set,
+ # emit the dependency path that connects the pair.
+ out_fh = sys.stdout if not FLAGS.output else tf.gfile.GFile(FLAGS.output, 'w')
+ in_fh = sys.stdin if not FLAGS.corpus else tf.gfile.GFile(FLAGS.corpus)
+
+ num_paths = 0
+ for line, sen in enumerate(in_fh, start=1):
+ if line % 100 == 0:
+ print('\rProcessing line %d: %d paths' % (line, num_paths),
+ end='', file=sys.stderr)
+
+ sen = sen.decode('utf-8').strip()
+ doc = nlp(sen)
+
+ for head_token in doc:
+ head_text = head_token.text.lower()
+ if head_text in heads:
+ mods = mods_for_head[head_text]
+ for mod_token in doc:
+ mod_text = mod_token.text.lower()
+ if mod_text in mods:
+ path = get_path(mod_token, head_token)
+ if path:
+ label = labeled_pairs[(mod_text, head_text)]
+ line = '\t'.join((mod_text, head_text, label, path, sen))
+ print(line.encode('utf-8'), file=out_fh)
+ num_paths += 1
+
+ out_fh.close()
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/lexnet_nc/get_indicative_paths.py b/models/research/lexnet_nc/get_indicative_paths.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8b34cca221a07c0b633024b71f082b8f61b3a45
--- /dev/null
+++ b/models/research/lexnet_nc/get_indicative_paths.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python
+# Copyright 2017, 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Extracts paths that are indicative of each relation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import tensorflow as tf
+
+from . import path_model
+from . import lexnet_common
+
+tf.flags.DEFINE_string(
+ 'dataset_dir', 'datasets',
+ 'Dataset base directory')
+
+tf.flags.DEFINE_string(
+ 'dataset',
+ 'tratz/fine_grained',
+ 'Subdirectory containing the corpus directories: '
+ 'subdirectory of dataset_dir')
+
+tf.flags.DEFINE_string(
+ 'corpus', 'random/wiki',
+ 'Subdirectory containing the corpus and split: '
+ 'subdirectory of dataset_dir/dataset')
+
+tf.flags.DEFINE_string(
+ 'embeddings_base_path', 'embeddings',
+ 'Embeddings base directory')
+
+tf.flags.DEFINE_string(
+ 'logdir', 'logdir',
+ 'Directory of model output files')
+
+tf.flags.DEFINE_integer(
+ 'top_k', 20, 'Number of top paths to extract')
+
+tf.flags.DEFINE_float(
+ 'threshold', 0.8, 'Threshold above which to consider paths as indicative')
+
+FLAGS = tf.flags.FLAGS
+
+
+def main(_):
+ hparams = path_model.PathBasedModel.default_hparams()
+
+ # First things first. Load the path data.
+ path_embeddings_file = 'path_embeddings/{dataset}/{corpus}'.format(
+ dataset=FLAGS.dataset,
+ corpus=FLAGS.corpus)
+
+ path_dim = (hparams.lemma_dim + hparams.pos_dim +
+ hparams.dep_dim + hparams.dir_dim)
+
+ path_embeddings, path_to_index = path_model.load_path_embeddings(
+ os.path.join(FLAGS.embeddings_base_path, path_embeddings_file),
+ path_dim)
+
+ # Load and count the classes so we can correctly instantiate the model.
+ classes_filename = os.path.join(
+ FLAGS.dataset_dir, FLAGS.dataset, 'classes.txt')
+
+ with open(classes_filename) as f_in:
+ classes = f_in.read().splitlines()
+
+ hparams.num_classes = len(classes)
+
+ # We need the word embeddings to instantiate the model, too.
+ print('Loading word embeddings...')
+ lemma_embeddings = lexnet_common.load_word_embeddings(
+ FLAGS.embeddings_base_path, hparams.lemma_embeddings_file)
+
+ # Instantiate the model.
+ with tf.Graph().as_default():
+ with tf.variable_scope('lexnet'):
+ instance = tf.placeholder(dtype=tf.string)
+ model = path_model.PathBasedModel(
+ hparams, lemma_embeddings, instance)
+
+ with tf.Session() as session:
+ model_dir = '{logdir}/results/{dataset}/path/{corpus}'.format(
+ logdir=FLAGS.logdir,
+ dataset=FLAGS.dataset,
+ corpus=FLAGS.corpus)
+
+ saver = tf.train.Saver()
+ saver.restore(session, os.path.join(model_dir, 'best.ckpt'))
+
+ path_model.get_indicative_paths(
+ model, session, path_to_index, path_embeddings, classes,
+ model_dir, FLAGS.top_k, FLAGS.threshold)
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/lexnet_nc/learn_classifier.py b/models/research/lexnet_nc/learn_classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec284029535609ffd2cc0f2f5cddb9b87954aa81
--- /dev/null
+++ b/models/research/lexnet_nc/learn_classifier.py
@@ -0,0 +1,223 @@
+#!/usr/bin/env python
+# Copyright 2017, 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Trains the integrated LexNET classifier."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import lexnet_common
+import lexnet_model
+import path_model
+from sklearn import metrics
+import tensorflow as tf
+
+tf.flags.DEFINE_string(
+ 'dataset_dir', 'datasets',
+ 'Dataset base directory')
+
+tf.flags.DEFINE_string(
+ 'dataset', 'tratz/fine_grained',
+ 'Subdirectory containing the corpus directories: '
+ 'subdirectory of dataset_dir')
+
+tf.flags.DEFINE_string(
+ 'corpus', 'wiki/random',
+ 'Subdirectory containing the corpus and split: '
+ 'subdirectory of dataset_dir/dataset')
+
+tf.flags.DEFINE_string(
+ 'embeddings_base_path', 'embeddings',
+ 'Embeddings base directory')
+
+tf.flags.DEFINE_string(
+ 'logdir', 'logdir',
+ 'Directory of model output files')
+
+tf.flags.DEFINE_string('hparams', '', 'Hyper-parameters')
+
+tf.flags.DEFINE_string(
+ 'input', 'integrated',
+    'The model (dist/dist-nc/path/integrated/integrated-nc)')
+
+FLAGS = tf.flags.FLAGS
+
+
+def main(_):
+ # Pick up any one-off hyper-parameters.
+ hparams = lexnet_model.LexNETModel.default_hparams()
+ hparams.corpus = FLAGS.corpus
+ hparams.input = FLAGS.input
+ hparams.path_embeddings_file = 'path_embeddings/%s/%s' % (
+ FLAGS.dataset, FLAGS.corpus)
+
+ input_dir = hparams.input if hparams.input != 'path' else 'path_classifier'
+
+ # Set the number of classes
+ classes_filename = os.path.join(
+ FLAGS.dataset_dir, FLAGS.dataset, 'classes.txt')
+ with open(classes_filename) as f_in:
+ classes = f_in.read().splitlines()
+
+ hparams.num_classes = len(classes)
+ print('Model will predict into %d classes' % hparams.num_classes)
+
+ # Get the datasets
+ train_set, val_set, test_set = (
+ os.path.join(
+ FLAGS.dataset_dir, FLAGS.dataset, FLAGS.corpus,
+ filename + '.tfrecs.gz')
+ for filename in ['train', 'val', 'test'])
+
+ print('Running with hyper-parameters: {}'.format(hparams))
+
+ # Load the instances
+ print('Loading instances...')
+ opts = tf.python_io.TFRecordOptions(
+ compression_type=tf.python_io.TFRecordCompressionType.GZIP)
+ train_instances = list(tf.python_io.tf_record_iterator(train_set, opts))
+ val_instances = list(tf.python_io.tf_record_iterator(val_set, opts))
+ test_instances = list(tf.python_io.tf_record_iterator(test_set, opts))
+
+ # Load the word embeddings
+ print('Loading word embeddings...')
+ relata_embeddings, path_embeddings, nc_embeddings, path_to_index = (
+ None, None, None, None)
+ if hparams.input in ['dist', 'dist-nc', 'integrated', 'integrated-nc']:
+ relata_embeddings = lexnet_common.load_word_embeddings(
+ FLAGS.embeddings_base_path, hparams.relata_embeddings_file)
+
+ if hparams.input in ['path', 'integrated', 'integrated-nc']:
+ path_embeddings, path_to_index = path_model.load_path_embeddings(
+ os.path.join(FLAGS.embeddings_base_path, hparams.path_embeddings_file),
+ hparams.path_dim)
+
+ if hparams.input in ['dist-nc', 'integrated-nc']:
+ nc_embeddings = lexnet_common.load_word_embeddings(
+ FLAGS.embeddings_base_path, hparams.nc_embeddings_file)
+
+ # Define the graph and the model
+ with tf.Graph().as_default():
+ model = lexnet_model.LexNETModel(
+ hparams, relata_embeddings, path_embeddings,
+ nc_embeddings, path_to_index)
+
+ # Initialize a session and start training
+ session = tf.Session()
+ session.run(tf.global_variables_initializer())
+
+    # Initialize the path mapping
+ if hparams.input in ['path', 'integrated', 'integrated-nc']:
+ session.run(tf.tables_initializer())
+ session.run(model.initialize_path_op, {
+ model.path_initial_value_t: path_embeddings
+ })
+
+ # Initialize the NC embeddings
+ if hparams.input in ['dist-nc', 'integrated-nc']:
+ session.run(model.initialize_nc_op, {
+ model.nc_initial_value_t: nc_embeddings
+ })
+
+ # Load the labels
+ print('Loading labels...')
+ train_labels = model.load_labels(session, train_instances)
+ val_labels = model.load_labels(session, val_instances)
+ test_labels = model.load_labels(session, test_instances)
+
+ save_path = '{logdir}/results/{dataset}/{input}/{corpus}'.format(
+ logdir=FLAGS.logdir, dataset=FLAGS.dataset,
+ corpus=model.hparams.corpus, input=input_dir)
+
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+
+ # Train the model
+ print('Training the model...')
+ model.fit(session, train_instances, epoch_completed,
+ val_instances, val_labels, save_path)
+
+ # Print the best performance on the validation set
+ print('Best performance on the validation set: F1=%.3f' %
+ epoch_completed.best_f1)
+
+ # Evaluate on the train and validation sets
+ lexnet_common.full_evaluation(model, session, train_instances, train_labels,
+ 'Train', classes)
+ lexnet_common.full_evaluation(model, session, val_instances, val_labels,
+ 'Validation', classes)
+ test_predictions = lexnet_common.full_evaluation(
+ model, session, test_instances, test_labels, 'Test', classes)
+
+ # Write the test predictions to a file
+ predictions_file = os.path.join(save_path, 'test_predictions.tsv')
+  print('Saving test predictions to %s' % predictions_file)
+ test_pairs = model.load_pairs(session, test_instances)
+ lexnet_common.write_predictions(test_pairs, test_labels, test_predictions,
+ classes, predictions_file)
+
+
+def epoch_completed(model, session, epoch, epoch_loss,
+ val_instances, val_labels, save_path):
+ """Runs every time an epoch completes.
+
+  Prints the performance on the validation set, and updates the saved model if
+  its performance is better than the previous best. If the performance dropped
+  sharply, signals the training loop to stop.
+
+ Args:
+ model: The currently trained path-based model.
+ session: The current TensorFlow session.
+ epoch: The epoch number.
+ epoch_loss: The current epoch loss.
+ val_instances: The validation set instances (evaluation between epochs).
+ val_labels: The validation set labels (for evaluation between epochs).
+ save_path: Where to save the model.
+
+ Returns:
+ whether the training should stop.
+ """
+ stop_training = False
+
+ # Evaluate on the validation set
+ val_pred = model.predict(session, val_instances)
+ precision, recall, f1, _ = metrics.precision_recall_fscore_support(
+ val_labels, val_pred, average='weighted')
+ print(
+ 'Epoch: %d/%d, Loss: %f, validation set: P: %.3f, R: %.3f, F1: %.3f\n' % (
+ epoch + 1, model.hparams.num_epochs, epoch_loss,
+ precision, recall, f1))
+
+ # If the F1 is much smaller than the previous one, stop training. Else, if
+ # it's bigger, save the model.
+ if f1 < epoch_completed.best_f1 - 0.08:
+ stop_training = True
+
+ if f1 > epoch_completed.best_f1:
+ saver = tf.train.Saver()
+ checkpoint_filename = os.path.join(save_path, 'best.ckpt')
+ print('Saving model in: %s' % checkpoint_filename)
+ saver.save(session, checkpoint_filename)
+ print('Model saved in file: %s' % checkpoint_filename)
+ epoch_completed.best_f1 = f1
+
+ return stop_training
+
+epoch_completed.best_f1 = 0
+
+if __name__ == '__main__':
+ tf.app.run(main)
diff --git a/models/research/lexnet_nc/learn_path_embeddings.py b/models/research/lexnet_nc/learn_path_embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..480378f4aa010ee27f0387685bac488cedbb2ab9
--- /dev/null
+++ b/models/research/lexnet_nc/learn_path_embeddings.py
@@ -0,0 +1,186 @@
+#!/usr/bin/env python
+# Copyright 2017, 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Trains the LexNET path-based model."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import lexnet_common
+import path_model
+from sklearn import metrics
+import tensorflow as tf
+
+tf.flags.DEFINE_string('train', '', 'training dataset, tfrecs')
+tf.flags.DEFINE_string('val', '', 'validation dataset, tfrecs')
+tf.flags.DEFINE_string('test', '', 'test dataset, tfrecs')
+tf.flags.DEFINE_string('embeddings', '', 'embeddings, npy')
+tf.flags.DEFINE_string('relations', '', 'file containing relation labels')
+tf.flags.DEFINE_string('output_dir', '', 'output directory for path embeddings')
+tf.flags.DEFINE_string('logdir', '', 'directory for model training')
+FLAGS = tf.flags.FLAGS
+
+
+def main(_):
+ # Pick up any one-off hyper-parameters.
+ hparams = path_model.PathBasedModel.default_hparams()
+
+ with open(FLAGS.relations) as fh:
+ relations = fh.read().splitlines()
+
+ hparams.num_classes = len(relations)
+ print('Model will predict into %d classes' % hparams.num_classes)
+
+ print('Running with hyper-parameters: {}'.format(hparams))
+
+ # Load the instances
+ print('Loading instances...')
+ opts = tf.python_io.TFRecordOptions(
+ compression_type=tf.python_io.TFRecordCompressionType.GZIP)
+
+ train_instances = list(tf.python_io.tf_record_iterator(FLAGS.train, opts))
+ val_instances = list(tf.python_io.tf_record_iterator(FLAGS.val, opts))
+ test_instances = list(tf.python_io.tf_record_iterator(FLAGS.test, opts))
+
+ # Load the word embeddings
+ print('Loading word embeddings...')
+ lemma_embeddings = lexnet_common.load_word_embeddings(FLAGS.embeddings)
+
+ # Define the graph and the model
+ with tf.Graph().as_default():
+ with tf.variable_scope('lexnet'):
+ options = tf.python_io.TFRecordOptions(
+ compression_type=tf.python_io.TFRecordCompressionType.GZIP)
+ reader = tf.TFRecordReader(options=options)
+ _, train_instance = reader.read(
+ tf.train.string_input_producer([FLAGS.train]))
+ shuffled_train_instance = tf.train.shuffle_batch(
+ [train_instance],
+ batch_size=1,
+ num_threads=1,
+ capacity=len(train_instances),
+ min_after_dequeue=100,
+ )[0]
+
+ train_model = path_model.PathBasedModel(
+ hparams, lemma_embeddings, shuffled_train_instance)
+
+ with tf.variable_scope('lexnet', reuse=True):
+ val_instance = tf.placeholder(dtype=tf.string)
+ val_model = path_model.PathBasedModel(
+ hparams, lemma_embeddings, val_instance)
+
+ # Initialize a session and start training
+ best_model_saver = tf.train.Saver()
+ f1_t = tf.placeholder(tf.float32)
+ best_f1_t = tf.Variable(0.0, trainable=False, name='best_f1')
+ assign_best_f1_op = tf.assign(best_f1_t, f1_t)
+
+ supervisor = tf.train.Supervisor(
+ logdir=FLAGS.logdir,
+ global_step=train_model.global_step)
+
+ with supervisor.managed_session() as session:
+ # Load the labels
+ print('Loading labels...')
+ val_labels = train_model.load_labels(session, val_instances)
+
+ # Train the model
+ print('Training the model...')
+
+ while True:
+ step = session.run(train_model.global_step)
+ epoch = (step + len(train_instances) - 1) // len(train_instances)
+ if epoch > hparams.num_epochs:
+ break
+
+ print('Starting epoch %d (step %d)...' % (1 + epoch, step))
+
+ epoch_loss = train_model.run_one_epoch(session, len(train_instances))
+
+ best_f1 = session.run(best_f1_t)
+ f1 = epoch_completed(val_model, session, epoch, epoch_loss,
+ val_instances, val_labels, best_model_saver,
+ FLAGS.logdir, best_f1)
+
+ if f1 > best_f1:
+ session.run(assign_best_f1_op, {f1_t: f1})
+
+ if f1 < best_f1 - 0.08:
+ tf.logging.info('Stopping training after %d epochs.\n' % epoch)
+ break
+
+ # Print the best performance on the validation set
+ best_f1 = session.run(best_f1_t)
+ print('Best performance on the validation set: F1=%.3f' % best_f1)
+
+ # Save the path embeddings
+ print('Computing the path embeddings...')
+ instances = train_instances + val_instances + test_instances
+ path_index, path_vectors = path_model.compute_path_embeddings(
+ val_model, session, instances)
+
+  if not os.path.exists(FLAGS.output_dir):
+    os.makedirs(FLAGS.output_dir)
+
+ path_model.save_path_embeddings(
+ val_model, path_vectors, path_index, FLAGS.output_dir)
+
+
+def epoch_completed(model, session, epoch, epoch_loss,
+ val_instances, val_labels, saver, save_path, best_f1):
+ """Runs every time an epoch completes.
+
+  Prints the performance on the validation set, and saves the model if its
+  performance is better than the previous best.
+
+ Args:
+ model: The currently trained path-based model.
+ session: The current TensorFlow session.
+ epoch: The epoch number.
+ epoch_loss: The current epoch loss.
+ val_instances: The validation set instances (evaluation between epochs).
+ val_labels: The validation set labels (for evaluation between epochs).
+ saver: tf.Saver object
+ save_path: Where to save the model.
+ best_f1: the best F1 achieved so far.
+
+ Returns:
+    The F1 achieved on the validation set.
+ """
+ # Evaluate on the validation set
+ val_pred = model.predict(session, val_instances)
+ precision, recall, f1, _ = metrics.precision_recall_fscore_support(
+ val_labels, val_pred, average='weighted')
+ print(
+ 'Epoch: %d/%d, Loss: %f, validation set: P: %.3f, R: %.3f, F1: %.3f\n' % (
+ epoch + 1, model.hparams.num_epochs, epoch_loss,
+ precision, recall, f1))
+
+ if f1 > best_f1:
+ save_filename = os.path.join(save_path, 'best.ckpt')
+ print('Saving model in: %s' % save_filename)
+ saver.save(session, save_filename)
+ print('Model saved in file: %s' % save_filename)
+
+ return f1
+
+
+if __name__ == '__main__':
+ tf.app.run(main)
diff --git a/models/research/lexnet_nc/lexnet_common.py b/models/research/lexnet_nc/lexnet_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2e8a104d00c1c2f90731f4045c3c8e69e370dbf
--- /dev/null
+++ b/models/research/lexnet_nc/lexnet_common.py
@@ -0,0 +1,197 @@
+# Copyright 2017, 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Common stuff used with LexNET."""
+# pylint: disable=bad-whitespace
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import numpy as np
+from sklearn import metrics
+import tensorflow as tf
+
+# Part of speech tags used in the paths.
+POSTAGS = [
+ 'PAD', 'VERB', 'CONJ', 'NOUN', 'PUNCT',
+ 'ADP', 'ADJ', 'DET', 'ADV', 'PART',
+ 'NUM', 'X', 'INTJ', 'SYM',
+]
+
+POSTAG_TO_ID = {tag: tid for tid, tag in enumerate(POSTAGS)}
+
+# Dependency labels used in the paths.
+DEPLABELS = [
+ 'PAD', 'UNK', 'ROOT', 'abbrev', 'acomp', 'advcl',
+ 'advmod', 'agent', 'amod', 'appos', 'attr', 'aux',
+ 'auxpass', 'cc', 'ccomp', 'complm', 'conj', 'cop',
+ 'csubj', 'csubjpass', 'dep', 'det', 'dobj', 'expl',
+ 'infmod', 'iobj', 'mark', 'mwe', 'nc', 'neg',
+ 'nn', 'npadvmod', 'nsubj', 'nsubjpass', 'num', 'number',
+ 'p', 'parataxis', 'partmod', 'pcomp', 'pobj', 'poss',
+ 'preconj', 'predet', 'prep', 'prepc', 'prt', 'ps',
+ 'purpcl', 'quantmod', 'rcmod', 'ref', 'rel', 'suffix',
+ 'title', 'tmod', 'xcomp', 'xsubj',
+]
+
+DEPLABEL_TO_ID = {label: lid for lid, label in enumerate(DEPLABELS)}
+
+# Direction codes used in the paths.
+DIRS = '_^V<>'
+DIR_TO_ID = {dir: did for did, dir in enumerate(DIRS)}
+
+
+def load_word_embeddings(embedding_filename):
+ """Loads pretrained word embeddings from a binary file and returns the matrix.
+
+  Adds four randomly initialized special-token embeddings (padding, unknown,
+  and placeholder tokens) to the beginning of the vocab.
+
+ Args:
+ embedding_filename: filename of the binary NPY data
+
+ Returns:
+ The word embeddings matrix
+ """
+ embeddings = np.load(embedding_filename)
+ dim = embeddings.shape[1]
+
+  # Four initially random vectors for the special tokens.
+ special_embeddings = np.random.normal(0, 0.1, (4, dim))
+ embeddings = np.vstack((special_embeddings, embeddings))
+ embeddings = embeddings.astype(np.float32)
+
+ return embeddings
+
+
+def full_evaluation(model, session, instances, labels, set_name, classes):
+ """Prints a full evaluation on the current set.
+
+ Performance (recall, precision and F1), classification report (per
+  class performance), and confusion matrix.
+
+ Args:
+ model: The currently trained path-based model.
+ session: The current TensorFlow session.
+ instances: The current set instances.
+ labels: The current set labels.
+ set_name: The current set name (train/validation/test).
+ classes: The class label names.
+
+ Returns:
+ The model's prediction for the given instances.
+ """
+
+ # Predict the labels
+ pred = model.predict(session, instances)
+
+ # Print the performance
+ precision, recall, f1, _ = metrics.precision_recall_fscore_support(
+ labels, pred, average='weighted')
+
+ print('%s set: Precision: %.3f, Recall: %.3f, F1: %.3f' % (
+ set_name, precision, recall, f1))
+
+ # Print a classification report
+ print('%s classification report:' % set_name)
+ print(metrics.classification_report(labels, pred, target_names=classes))
+
+ # Print the confusion matrix
+ print('%s confusion matrix:' % set_name)
+ cm = metrics.confusion_matrix(labels, pred, labels=range(len(classes)))
+ cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100
+ print_cm(cm, labels=classes)
+ return pred
+
+
+def print_cm(cm, labels):
+ """Pretty print for confusion matrices.
+
+ From: https://gist.github.com/zachguo/10296432.
+
+ Args:
+ cm: The confusion matrix.
+ labels: The class names.
+ """
+ columnwidth = 10
+ empty_cell = ' ' * columnwidth
+ short_labels = [label[:12].rjust(10, ' ') for label in labels]
+
+ # Print header
+ header = empty_cell + ' '
+ header += ''.join([' %{0}s '.format(columnwidth) % label
+ for label in short_labels])
+
+ print(header)
+
+ # Print rows
+ for i, label1 in enumerate(short_labels):
+ row = '%{0}s '.format(columnwidth) % label1[:10]
+ for j in range(len(short_labels)):
+ value = int(cm[i, j]) if not np.isnan(cm[i, j]) else 0
+ cell = ' %{0}d '.format(10) % value
+ row += cell + ' '
+ print(row)
+
+
+def load_all_labels(records):
+ """Reads TensorFlow examples from a RecordReader and returns only the labels.
+
+ Args:
+ records: a record list with TensorFlow examples.
+
+ Returns:
+ The labels
+ """
+ curr_features = tf.parse_example(records, {
+ 'rel_id': tf.FixedLenFeature([1], dtype=tf.int64),
+ })
+
+ labels = tf.squeeze(curr_features['rel_id'], [-1])
+ return labels
+
+
+def load_all_pairs(records):
+ """Reads TensorFlow examples from a RecordReader and returns the word pairs.
+
+ Args:
+ records: a record list with TensorFlow examples.
+
+ Returns:
+ The word pairs
+ """
+ curr_features = tf.parse_example(records, {
+ 'pair': tf.FixedLenFeature([1], dtype=tf.string)
+ })
+
+ word_pairs = curr_features['pair']
+ return word_pairs
+
+
+def write_predictions(pairs, labels, predictions, classes, predictions_file):
+ """Write the predictions to a file.
+
+ Args:
+ pairs: the word pairs (list of tuple of two strings).
+ labels: the gold-standard labels for these pairs (array of rel ID).
+ predictions: the predicted labels for these pairs (array of rel ID).
+ classes: a list of relation names.
+ predictions_file: where to save the predictions.
+ """
+ with open(predictions_file, 'w') as f_out:
+ for pair, label, pred in zip(pairs, labels, predictions):
+ w1, w2 = pair
+ f_out.write('\t'.join([w1, w2, classes[label], classes[pred]]) + '\n')
diff --git a/models/research/lexnet_nc/lexnet_model.py b/models/research/lexnet_nc/lexnet_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0f16b030b3bb3fee68b91122bcd03226ffcfa4a
--- /dev/null
+++ b/models/research/lexnet_nc/lexnet_model.py
@@ -0,0 +1,438 @@
+# Copyright 2017, 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""The integrated LexNET model."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import lexnet_common
+import numpy as np
+import tensorflow as tf
+from six.moves import xrange
+
+
+class LexNETModel(object):
+ """The LexNET model for classifying relationships between noun compounds."""
+
+ @classmethod
+ def default_hparams(cls):
+ """Returns the default hyper-parameters."""
+ return tf.contrib.training.HParams(
+ batch_size=10,
+ num_classes=37,
+ num_epochs=30,
+ input_keep_prob=0.9,
+ input='integrated', # dist/ dist-nc/ path/ integrated/ integrated-nc
+ learn_relata=False,
+ corpus='wiki_gigawords',
+ random_seed=133, # zero means no random seed
+ relata_embeddings_file='glove/glove.6B.300d.bin',
+ nc_embeddings_file='nc_glove/vecs.6B.300d.bin',
+ path_embeddings_file='path_embeddings/tratz/fine_grained/wiki',
+ hidden_layers=1,
+ path_dim=60)
+
+ def __init__(self, hparams, relata_embeddings, path_embeddings, nc_embeddings,
+ path_to_index):
+ """Initialize the LexNET classifier.
+
+ Args:
+ hparams: the hyper-parameters.
+ relata_embeddings: word embeddings for the distributional component.
+ path_embeddings: embeddings for the paths.
+ nc_embeddings: noun compound embeddings.
+ path_to_index: a mapping from string path to an index in the path
+ embeddings matrix.
+ """
+ self.hparams = hparams
+
+ self.path_embeddings = path_embeddings
+ self.relata_embeddings = relata_embeddings
+ self.nc_embeddings = nc_embeddings
+
+ self.vocab_size, self.relata_dim = 0, 0
+ self.path_to_index = None
+ self.path_dim = 0
+
+ # Set the random seed
+ if hparams.random_seed > 0:
+ tf.set_random_seed(hparams.random_seed)
+
+ # Get the vocabulary size and relata dim
+ if self.hparams.input in ['dist', 'dist-nc', 'integrated', 'integrated-nc']:
+ self.vocab_size, self.relata_dim = self.relata_embeddings.shape
+
+ # Create the mapping from string path to an index in the embeddings matrix
+ if self.hparams.input in ['path', 'integrated', 'integrated-nc']:
+ self.path_to_index = tf.contrib.lookup.HashTable(
+ tf.contrib.lookup.KeyValueTensorInitializer(
+ tf.constant(path_to_index.keys()),
+ tf.constant(path_to_index.values()),
+ key_dtype=tf.string, value_dtype=tf.int32), 0)
+
+ self.path_dim = self.path_embeddings.shape[1]
+
+ # Create the network
+ self.__create_computation_graph__()
+
+ def __create_computation_graph__(self):
+ """Initialize the model and define the graph."""
+ network_input = 0
+
+ # Define the network inputs
+ # Distributional x and y
+ if self.hparams.input in ['dist', 'dist-nc', 'integrated', 'integrated-nc']:
+ network_input += 2 * self.relata_dim
+ self.relata_lookup = tf.get_variable(
+ 'relata_lookup',
+ initializer=self.relata_embeddings,
+ dtype=tf.float32,
+ trainable=self.hparams.learn_relata)
+
+ # Path-based
+ if self.hparams.input in ['path', 'integrated', 'integrated-nc']:
+ network_input += self.path_dim
+
+ self.path_initial_value_t = tf.placeholder(tf.float32, None)
+
+ self.path_lookup = tf.get_variable(
+ name='path_lookup',
+ dtype=tf.float32,
+ trainable=False,
+ shape=self.path_embeddings.shape)
+
+ self.initialize_path_op = tf.assign(
+ self.path_lookup, self.path_initial_value_t, validate_shape=False)
+
+ # Distributional noun compound
+ if self.hparams.input in ['dist-nc', 'integrated-nc']:
+ network_input += self.relata_dim
+
+ self.nc_initial_value_t = tf.placeholder(tf.float32, None)
+
+ self.nc_lookup = tf.get_variable(
+ name='nc_lookup',
+ dtype=tf.float32,
+ trainable=False,
+ shape=self.nc_embeddings.shape)
+
+ self.initialize_nc_op = tf.assign(
+ self.nc_lookup, self.nc_initial_value_t, validate_shape=False)
+
+ hidden_dim = network_input // 2
+
+ # Define the MLP
+ if self.hparams.hidden_layers == 0:
+ self.weights1 = tf.get_variable(
+ 'W1',
+ shape=[network_input, self.hparams.num_classes],
+ dtype=tf.float32)
+ self.bias1 = tf.get_variable(
+ 'b1',
+ shape=[self.hparams.num_classes],
+ dtype=tf.float32)
+
+ elif self.hparams.hidden_layers == 1:
+
+ self.weights1 = tf.get_variable(
+ 'W1',
+ shape=[network_input, hidden_dim],
+ dtype=tf.float32)
+ self.bias1 = tf.get_variable(
+ 'b1',
+ shape=[hidden_dim],
+ dtype=tf.float32)
+
+ self.weights2 = tf.get_variable(
+ 'W2',
+ shape=[hidden_dim, self.hparams.num_classes],
+ dtype=tf.float32)
+ self.bias2 = tf.get_variable(
+ 'b2',
+ shape=[self.hparams.num_classes],
+ dtype=tf.float32)
+
+ else:
+ raise ValueError('Only 0 or 1 hidden layers are supported')
+
+ # Define the variables
+ self.instances = tf.placeholder(dtype=tf.string,
+ shape=[self.hparams.batch_size])
+
+ (self.x_embedding_id,
+ self.y_embedding_id,
+ self.nc_embedding_id,
+ self.path_embedding_id,
+ self.path_counts,
+ self.labels) = parse_tensorflow_examples(
+ self.instances, self.hparams.batch_size, self.path_to_index)
+
+ # Create the MLP
+ self.__mlp__()
+
+ self.instances_to_load = tf.placeholder(dtype=tf.string, shape=[None])
+ self.labels_to_load = lexnet_common.load_all_labels(self.instances_to_load)
+ self.pairs_to_load = lexnet_common.load_all_pairs(self.instances_to_load)
+
+ def load_labels(self, session, instances):
+ """Loads the labels for these instances.
+
+ Args:
+ session: The current TensorFlow session,
+ instances: The instances for which to load the labels.
+
+ Returns:
+ the labels of these instances.
+ """
+ return session.run(self.labels_to_load,
+ feed_dict={self.instances_to_load: instances})
+
+ def load_pairs(self, session, instances):
+ """Loads the word pairs for these instances.
+
+ Args:
+ session: The current TensorFlow session,
+      instances: The instances for which to load the word pairs.
+
+ Returns:
+ the word pairs of these instances.
+ """
+ word_pairs = session.run(self.pairs_to_load,
+ feed_dict={self.instances_to_load: instances})
+ return [pair[0].split('::') for pair in word_pairs]
+
+ def __train_single_batch__(self, session, batch_instances):
+ """Train a single batch.
+
+ Args:
+ session: The current TensorFlow session.
+      batch_instances: TensorFlow examples containing the training instances.
+
+ Returns:
+ The cost for the current batch.
+ """
+ cost, _ = session.run([self.cost, self.train_op],
+ feed_dict={self.instances: batch_instances})
+
+ return cost
+
+ def fit(self, session, inputs, on_epoch_completed, val_instances, val_labels,
+ save_path):
+ """Train the model.
+
+ Args:
+ session: The current TensorFlow session.
+      inputs: the training instances (serialized TensorFlow examples).
+ on_epoch_completed: A method to call after each epoch.
+ val_instances: The validation set instances (evaluation between epochs).
+ val_labels: The validation set labels (for evaluation between epochs).
+ save_path: Where to save the model.
+ """
+ for epoch in range(self.hparams.num_epochs):
+
+ losses = []
+ epoch_indices = list(np.random.permutation(len(inputs)))
+
+ # If the number of instances doesn't divide by batch_size, enlarge it
+ # by duplicating training examples
+      mod = len(epoch_indices) % self.hparams.batch_size
+      if mod > 0:
+        pad = self.hparams.batch_size - mod
+        epoch_indices.extend(
+            [np.random.randint(0, high=len(inputs)) for _ in range(pad)])
+
+ # Define the batches
+ n_batches = len(epoch_indices) // self.hparams.batch_size
+
+ for minibatch in range(n_batches):
+
+ batch_indices = epoch_indices[minibatch * self.hparams.batch_size:(
+ minibatch + 1) * self.hparams.batch_size]
+ batch_instances = [inputs[i] for i in batch_indices]
+
+ loss = self.__train_single_batch__(session, batch_instances)
+ losses.append(loss)
+
+ epoch_loss = np.nanmean(losses)
+
+ if on_epoch_completed:
+ should_stop = on_epoch_completed(self, session, epoch, epoch_loss,
+ val_instances, val_labels, save_path)
+ if should_stop:
+ print('Stopping training after %d epochs.' % epoch)
+ return
+
+ def predict(self, session, inputs):
+ """Predict the classification of the test set.
+
+ Args:
+ session: The current TensorFlow session.
+ inputs: the train paths, x, y and/or nc vectors
+
+ Returns:
+ The test predictions.
+ """
+ predictions, _ = zip(*self.predict_with_score(session, inputs))
+ return np.array(predictions)
+
+ def predict_with_score(self, session, inputs):
+ """Predict the classification of the test set.
+
+ Args:
+ session: The current TensorFlow session.
+ inputs: the test paths, x, y and/or nc vectors
+
+ Returns:
+ The test predictions along with their scores.
+ """
+ test_pred = [0] * len(inputs)
+
+ for chunk in xrange(0, len(test_pred), self.hparams.batch_size):
+
+ # Initialize the variables with the current batch data
+ batch_indices = list(
+ range(chunk, min(chunk + self.hparams.batch_size, len(test_pred))))
+
+ # If the batch is too small, add a few other examples
+ if len(batch_indices) < self.hparams.batch_size:
+ batch_indices += [0] * (self.hparams.batch_size-len(batch_indices))
+
+ batch_instances = [inputs[i] for i in batch_indices]
+
+ predictions, scores = session.run(
+ [self.predictions, self.scores],
+ feed_dict={self.instances: batch_instances})
+
+ for index_in_batch, index_in_dataset in enumerate(batch_indices):
+ prediction = predictions[index_in_batch]
+ score = scores[index_in_batch][prediction]
+ test_pred[index_in_dataset] = (prediction, score)
+
+ return test_pred
+
+ def __mlp__(self):
+ """Performs the MLP operations.
+
+ Returns: the prediction object to be computed in a Session
+ """
+ # Define the operations
+
+ # Network input
+ vec_inputs = []
+
+ # Distributional component
+ if self.hparams.input in ['dist', 'dist-nc', 'integrated', 'integrated-nc']:
+ for emb_id in [self.x_embedding_id, self.y_embedding_id]:
+ vec_inputs.append(tf.nn.embedding_lookup(self.relata_lookup, emb_id))
+
+ # Noun compound component
+ if self.hparams.input in ['dist-nc', 'integrated-nc']:
+ vec = tf.nn.embedding_lookup(self.nc_lookup, self.nc_embedding_id)
+ vec_inputs.append(vec)
+
+ # Path-based component
+ if self.hparams.input in ['path', 'integrated', 'integrated-nc']:
+
+ # Get the current paths for each batch instance
+ self.path_embeddings = tf.nn.embedding_lookup(self.path_lookup,
+ self.path_embedding_id)
+
+ # self.path_embeddings is of shape
+ # [batch_size, max_path_per_instance, output_dim]
+ # We need to multiply it by path counts
+ # ([batch_size, max_path_per_instance]).
+ # Start by duplicating path_counts along the output_dim axis.
+ self.path_freq = tf.tile(tf.expand_dims(self.path_counts, -1),
+ [1, 1, self.path_dim])
+
+ # Compute the averaged path vector for each instance.
+ # First, multiply the path embeddings and frequencies element-wise.
+ self.weighted = tf.multiply(self.path_freq, self.path_embeddings)
+
+ # Second, take the sum to get a tensor of shape [batch_size, output_dim].
+ self.pair_path_embeddings = tf.reduce_sum(self.weighted, 1)
+
+ # Finally, divide by the total number of paths.
+ # The number of paths for each pair has a shape [batch_size, 1],
+ # We duplicate it output_dim times along the second axis.
+ self.num_paths = tf.clip_by_value(
+ tf.reduce_sum(self.path_counts, 1), 1, np.inf)
+ self.num_paths = tf.tile(tf.expand_dims(self.num_paths, -1),
+ [1, self.path_dim])
+
+ # And finally, divide pair_path_embeddings by num_paths element-wise.
+ self.pair_path_embeddings = tf.div(
+ self.pair_path_embeddings, self.num_paths)
+ vec_inputs.append(self.pair_path_embeddings)
+
+ # Concatenate the inputs and feed to the MLP
+ self.input_vec = tf.nn.dropout(
+ tf.concat(vec_inputs, 1),
+ keep_prob=self.hparams.input_keep_prob)
+
+ h = tf.matmul(self.input_vec, self.weights1)
+ self.output = h
+
+ if self.hparams.hidden_layers == 1:
+ self.output = tf.matmul(tf.nn.tanh(h), self.weights2)
+
+ self.scores = self.output
+ self.predictions = tf.argmax(self.scores, axis=1)
+
+ # Define the loss function and the optimization algorithm
+ self.cross_entropies = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ logits=self.scores, labels=self.labels)
+ self.cost = tf.reduce_sum(self.cross_entropies, name='cost')
+ self.global_step = tf.Variable(0, name='global_step', trainable=False)
+ self.optimizer = tf.train.AdamOptimizer()
+ self.train_op = self.optimizer.minimize(
+ self.cost, global_step=self.global_step)
+
+
+def parse_tensorflow_examples(record, batch_size, path_to_index):
+ """Reads TensorFlow examples from a RecordReader.
+
+ Args:
+ record: a record with TensorFlow examples.
+ batch_size: the number of instances in a minibatch
+ path_to_index: mapping from string path to index in the embeddings matrix.
+
+ Returns:
+ The word embeddings IDs, paths and counts
+ """
+ features = tf.parse_example(
+ record, {
+ 'x_embedding_id': tf.FixedLenFeature([1], dtype=tf.int64),
+ 'y_embedding_id': tf.FixedLenFeature([1], dtype=tf.int64),
+ 'nc_embedding_id': tf.FixedLenFeature([1], dtype=tf.int64),
+ 'reprs': tf.FixedLenSequenceFeature(
+ shape=(), dtype=tf.string, allow_missing=True),
+ 'counts': tf.FixedLenSequenceFeature(
+ shape=(), dtype=tf.int64, allow_missing=True),
+ 'rel_id': tf.FixedLenFeature([1], dtype=tf.int64)
+ })
+
+ x_embedding_id = tf.squeeze(features['x_embedding_id'], [-1])
+ y_embedding_id = tf.squeeze(features['y_embedding_id'], [-1])
+ nc_embedding_id = tf.squeeze(features['nc_embedding_id'], [-1])
+ labels = tf.squeeze(features['rel_id'], [-1])
+ path_counts = tf.to_float(tf.reshape(features['counts'], [batch_size, -1]))
+
+ path_embedding_id = None
+ if path_to_index:
+ path_embedding_id = path_to_index.lookup(features['reprs'])
+
+ return (
+ x_embedding_id, y_embedding_id, nc_embedding_id,
+ path_embedding_id, path_counts, labels)
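+
+# For reference, a record matching the schema parsed above could be assembled
+# roughly like this (a sketch; the ids, path strings and counts are assumed to
+# come from the data-preparation utilities described in the README):
+#
+#   example = tf.train.Example(features=tf.train.Features(feature={
+#       'x_embedding_id': tf.train.Feature(
+#           int64_list=tf.train.Int64List(value=[x_id])),
+#       'y_embedding_id': tf.train.Feature(
+#           int64_list=tf.train.Int64List(value=[y_id])),
+#       'nc_embedding_id': tf.train.Feature(
+#           int64_list=tf.train.Int64List(value=[nc_id])),
+#       'reprs': tf.train.Feature(
+#           bytes_list=tf.train.BytesList(value=path_strings)),
+#       'counts': tf.train.Feature(
+#           int64_list=tf.train.Int64List(value=path_counts)),
+#       'rel_id': tf.train.Feature(
+#           int64_list=tf.train.Int64List(value=[rel_id])),
+#   }))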
diff --git a/models/research/lexnet_nc/path_model.py b/models/research/lexnet_nc/path_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c283841775d673baa8a4bc8c438d65f288a2c555
--- /dev/null
+++ b/models/research/lexnet_nc/path_model.py
@@ -0,0 +1,547 @@
+# Copyright 2017, 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""LexNET Path-based Model."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import itertools
+import os
+
+import lexnet_common
+import numpy as np
+import tensorflow as tf
+
+
+class PathBasedModel(object):
+ """The LexNET path-based model for classifying semantic relations."""
+
+ @classmethod
+ def default_hparams(cls):
+ """Returns the default hyper-parameters."""
+ return tf.contrib.training.HParams(
+ max_path_len=8,
+ num_classes=37,
+ num_epochs=30,
+ input_keep_prob=0.9,
+ learning_rate=0.001,
+ learn_lemmas=False,
+ random_seed=133, # zero means no random seed
+ lemma_embeddings_file='glove/glove.6B.50d.bin',
+ num_pos=len(lexnet_common.POSTAGS),
+ num_dep=len(lexnet_common.DEPLABELS),
+ num_directions=len(lexnet_common.DIRS),
+ lemma_dim=50,
+ pos_dim=4,
+ dep_dim=5,
+ dir_dim=1)
+
+ def __init__(self, hparams, lemma_embeddings, instance):
+ """Initialize the LexNET classifier.
+
+ Args:
+ hparams: the hyper-parameters.
+ lemma_embeddings: word embeddings for the path-based component.
+ instance: string tensor containing the input instance
+ """
+ self.hparams = hparams
+ self.lemma_embeddings = lemma_embeddings
+ self.instance = instance
+ self.vocab_size, self.lemma_dim = self.lemma_embeddings.shape
+
+ # Set the random seed
+ if hparams.random_seed > 0:
+ tf.set_random_seed(hparams.random_seed)
+
+ # Create the network
+ self.__create_computation_graph__()
+
+ def __create_computation_graph__(self):
+ """Initialize the model and define the graph."""
+ self.lstm_input_dim = sum([self.hparams.lemma_dim, self.hparams.pos_dim,
+ self.hparams.dep_dim, self.hparams.dir_dim])
+ self.lstm_output_dim = self.lstm_input_dim
+
+ network_input = self.lstm_output_dim
+ self.lemma_lookup = tf.get_variable(
+ 'lemma_lookup',
+ initializer=self.lemma_embeddings,
+ dtype=tf.float32,
+ trainable=self.hparams.learn_lemmas)
+ self.pos_lookup = tf.get_variable(
+ 'pos_lookup',
+ shape=[self.hparams.num_pos, self.hparams.pos_dim],
+ dtype=tf.float32)
+ self.dep_lookup = tf.get_variable(
+ 'dep_lookup',
+ shape=[self.hparams.num_dep, self.hparams.dep_dim],
+ dtype=tf.float32)
+ self.dir_lookup = tf.get_variable(
+ 'dir_lookup',
+ shape=[self.hparams.num_directions, self.hparams.dir_dim],
+ dtype=tf.float32)
+
+ self.weights1 = tf.get_variable(
+ 'W1',
+ shape=[network_input, self.hparams.num_classes],
+ dtype=tf.float32)
+ self.bias1 = tf.get_variable(
+ 'b1',
+ shape=[self.hparams.num_classes],
+ dtype=tf.float32)
+
+ # Define the variables
+ (self.batch_paths,
+ self.path_counts,
+ self.seq_lengths,
+ self.path_strings,
+ self.batch_labels) = _parse_tensorflow_example(
+ self.instance, self.hparams.max_path_len, self.hparams.input_keep_prob)
+
+ # Create the LSTM
+ self.__lstm__()
+
+ # Create the MLP
+ self.__mlp__()
+
+ self.instances_to_load = tf.placeholder(dtype=tf.string, shape=[None])
+ self.labels_to_load = lexnet_common.load_all_labels(self.instances_to_load)
+
+ def load_labels(self, session, batch_instances):
+ """Loads the labels of the current instances.
+
+ Args:
+ session: the current TensorFlow session.
+ batch_instances: the dataset instances.
+
+ Returns:
+ the labels.
+ """
+ return session.run(self.labels_to_load,
+ feed_dict={self.instances_to_load: batch_instances})
+
+ def run_one_epoch(self, session, num_steps):
+ """Train the model.
+
+ Args:
+ session: The current TensorFlow session.
+ num_steps: The number of steps in each epoch.
+
+ Returns:
+ The mean loss for the epoch.
+
+ Raises:
+ ArithmeticError: if the loss becomes non-finite.
+ """
+ losses = []
+
+ for step in range(num_steps):
+ curr_loss, _ = session.run([self.cost, self.train_op])
+ if not np.isfinite(curr_loss):
+ raise ArithmeticError('nan loss at step %d' % step)
+
+ losses.append(curr_loss)
+
+ return np.mean(losses)
+
+ def predict(self, session, inputs):
+ """Predict the classification of the test set.
+
+ Args:
+ session: The current TensorFlow session.
+      inputs: the test paths, x, y and/or nc vectors
+
+ Returns:
+ The test predictions.
+ """
+ predictions, _ = zip(*self.predict_with_score(session, inputs))
+ return np.array(predictions)
+
+ def predict_with_score(self, session, inputs):
+ """Predict the classification of the test set.
+
+ Args:
+ session: The current TensorFlow session.
+ inputs: the test paths, x, y and/or nc vectors
+
+ Returns:
+ The test predictions along with their scores.
+ """
+ test_pred = [0] * len(inputs)
+
+ for index, instance in enumerate(inputs):
+
+ prediction, scores = session.run(
+ [self.predictions, self.scores],
+ feed_dict={self.instance: instance})
+
+ test_pred[index] = (prediction, scores[prediction])
+
+ return test_pred
+
+ def __mlp__(self):
+ """Performs the MLP operations.
+
+ Returns: the prediction object to be computed in a Session
+ """
+ # Feed the paths to the MLP: path_embeddings is
+ # [num_batch_paths, output_dim], and when we multiply it by W
+ # ([output_dim, num_classes]), we get a matrix of class distributions:
+ # [num_batch_paths, num_classes].
+ self.distributions = tf.matmul(self.path_embeddings, self.weights1)
+
+ # Now, compute weighted average on the class distributions, using the path
+ # frequency as weights.
+
+ # First, reshape path_freq to the same shape of distributions
+ self.path_freq = tf.tile(tf.expand_dims(self.path_counts, -1),
+ [1, self.hparams.num_classes])
+
+ # Second, multiply the distributions and frequencies element-wise.
+ self.weighted = tf.multiply(self.path_freq, self.distributions)
+
+ # Finally, take the average to get a tensor of shape [1, num_classes].
+ self.weighted_sum = tf.reduce_sum(self.weighted, 0)
+ self.num_paths = tf.clip_by_value(tf.reduce_sum(self.path_counts),
+ 1, np.inf)
+ self.num_paths = tf.tile(tf.expand_dims(self.num_paths, -1),
+ [self.hparams.num_classes])
+ self.scores = tf.div(self.weighted_sum, self.num_paths)
+ self.predictions = tf.argmax(self.scores)
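+
+    # Illustrative example (not from the original code): with two paths whose
+    # per-class distributions are [0.2, 0.8] and [0.6, 0.4] and path counts
+    # [3, 1], the weighted sum is [0.2*3 + 0.6*1, 0.8*3 + 0.4*1] = [1.2, 2.8];
+    # dividing by the total count 4 gives [0.3, 0.7], so class 1 is predicted.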
+
+ # Define the loss function and the optimization algorithm
+ self.cross_entropies = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ logits=self.scores, labels=tf.reduce_mean(self.batch_labels))
+ self.cost = tf.reduce_sum(self.cross_entropies, name='cost')
+ self.global_step = tf.Variable(0, name='global_step', trainable=False)
+ self.optimizer = tf.train.AdamOptimizer()
+ self.train_op = self.optimizer.minimize(self.cost,
+ global_step=self.global_step)
+
+ def __lstm__(self):
+ """Defines the LSTM operations.
+
+ Returns:
+ A matrix of path embeddings.
+ """
+ lookup_tables = [self.lemma_lookup, self.pos_lookup,
+ self.dep_lookup, self.dir_lookup]
+
+ # Split the edges to components: list of 4 tensors
+ # [num_batch_paths, max_path_len, 1]
+ self.edge_components = tf.split(self.batch_paths, 4, axis=2)
+
+ # Look up the components embeddings and concatenate them back together
+ self.path_matrix = tf.concat([
+ tf.squeeze(tf.nn.embedding_lookup(lookup_table, component), 2)
+ for lookup_table, component in
+ zip(lookup_tables, self.edge_components)
+ ], axis=2)
+
+ self.sequence_lengths = tf.reshape(self.seq_lengths, [-1])
+
+ # Define the LSTM.
+ # The input is [num_batch_paths, max_path_len, input_dim].
+ lstm_cell = tf.contrib.rnn.BasicLSTMCell(self.lstm_output_dim)
+
+ # The output is [num_batch_paths, max_path_len, output_dim].
+ self.lstm_outputs, _ = tf.nn.dynamic_rnn(
+ lstm_cell, self.path_matrix, dtype=tf.float32,
+ sequence_length=self.sequence_lengths)
+
+ # Slice the last *relevant* output for each instance ->
+ # [num_batch_paths, output_dim]
+ self.path_embeddings = _extract_last_relevant(self.lstm_outputs,
+ self.sequence_lengths)
+
+
+def _parse_tensorflow_example(record, max_path_len, input_keep_prob):
+ """Reads TensorFlow examples from a RecordReader.
+
+ Args:
+ record: a record with TensorFlow example.
+ max_path_len: the maximum path length.
+ input_keep_prob: 1 - the word dropout probability
+
+ Returns:
+ The paths and counts
+ """
+ features = tf.parse_single_example(record, {
+ 'lemmas':
+ tf.FixedLenSequenceFeature(
+ shape=(), dtype=tf.int64, allow_missing=True),
+ 'postags':
+ tf.FixedLenSequenceFeature(
+ shape=(), dtype=tf.int64, allow_missing=True),
+ 'deplabels':
+ tf.FixedLenSequenceFeature(
+ shape=(), dtype=tf.int64, allow_missing=True),
+ 'dirs':
+ tf.FixedLenSequenceFeature(
+ shape=(), dtype=tf.int64, allow_missing=True),
+ 'counts':
+ tf.FixedLenSequenceFeature(
+ shape=(), dtype=tf.int64, allow_missing=True),
+ 'pathlens':
+ tf.FixedLenSequenceFeature(
+ shape=(), dtype=tf.int64, allow_missing=True),
+ 'reprs':
+ tf.FixedLenSequenceFeature(
+ shape=(), dtype=tf.string, allow_missing=True),
+ 'rel_id':
+ tf.FixedLenFeature([], dtype=tf.int64)
+ })
+
+ path_counts = tf.to_float(features['counts'])
+ seq_lengths = features['pathlens']
+
+ # Concatenate the edge components to create a path tensor:
+ # [max_paths_per_ins, max_path_length, 4]
+ lemmas = _word_dropout(
+ tf.reshape(features['lemmas'], [-1, max_path_len]), input_keep_prob)
+
+ paths = tf.stack(
+ [lemmas] + [
+ tf.reshape(features[f], [-1, max_path_len])
+ for f in ('postags', 'deplabels', 'dirs')
+ ],
+ axis=-1)
+
+ path_strings = features['reprs']
+
+ # Add an empty path to pairs with no paths
+ paths = tf.cond(
+ tf.shape(paths)[0] > 0,
+ lambda: paths,
+ lambda: tf.zeros([1, max_path_len, 4], dtype=tf.int64))
+
+ # Paths are left-padded. We reverse them to make them right-padded.
+ #paths = tf.reverse(paths, axis=[1])
+
+ path_counts = tf.cond(
+ tf.shape(path_counts)[0] > 0,
+ lambda: path_counts,
+ lambda: tf.constant([1.0], dtype=tf.float32))
+
+ seq_lengths = tf.cond(
+ tf.shape(seq_lengths)[0] > 0,
+ lambda: seq_lengths,
+ lambda: tf.constant([1], dtype=tf.int64))
+
+ # Duplicate the label for each path
+ labels = tf.ones_like(path_counts, dtype=tf.int64) * features['rel_id']
+
+ return paths, path_counts, seq_lengths, path_strings, labels
+
+
+def _extract_last_relevant(output, seq_lengths):
+ """Get the last relevant LSTM output cell for each batch instance.
+
+ Args:
+ output: the LSTM outputs - a tensor with shape
+      [num_paths, max_path_len, output_dim]
+ seq_lengths: the sequences length per instance
+
+ Returns:
+ The last relevant LSTM output cell for each batch instance.
+ """
+ max_length = int(output.get_shape()[1])
+ path_lengths = tf.clip_by_value(seq_lengths - 1, 0, max_length)
+ relevant = tf.reduce_sum(tf.multiply(output, tf.expand_dims(
+ tf.one_hot(path_lengths, max_length), -1)), 1)
+ return relevant
+
+
+def _word_dropout(words, input_keep_prob):
+ """Drops words with probability 1 - input_keep_prob.
+
+ Args:
+ words: a list of lemmas from the paths.
+ input_keep_prob: the probability to keep the word.
+
+ Returns:
+    The revised list where some of the words are dropped (replaced by the
+    unknown-word id).
+ """
+ # Create the mask: (-1) to drop, 1 to keep
+ prob = tf.random_uniform(tf.shape(words), 0, 1)
+ condition = tf.less(prob, (1 - input_keep_prob))
+ mask = tf.where(condition,
+ tf.negative(tf.ones_like(words)), tf.ones_like(words))
+
+  # We need to keep zeros (the <PAD> id), and change other numbers to 1 (<UNK>)
+ # if their mask is -1. First, we multiply the mask and the words.
+ # Zeros will stay zeros, and words to drop will become negative.
+ # Then, we change negative values to 1.
+ masked_words = tf.multiply(mask, words)
+ condition = tf.less(masked_words, 0)
+ dropped_words = tf.where(condition, tf.ones_like(words), words)
+ return dropped_words
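+
+# A worked sketch of the masking trick above (not part of the original code):
+# with words = [0, 5, 7] and mask = [1, -1, 1], the element-wise product is
+# [0, -5, 7]; replacing negatives with 1 gives [0, 1, 7], i.e. the padding id 0
+# is preserved and word 5 is dropped to the unknown id 1.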
+
+
+def compute_path_embeddings(model, session, instances):
+ """Compute the path embeddings for all the distinct paths.
+
+ Args:
+ model: The trained path-based model.
+ session: The current TensorFlow session.
+ instances: All the train, test and validation instances.
+
+ Returns:
+ The path to ID index and the path embeddings.
+ """
+ # Get an index for each distinct path
+ path_index = collections.defaultdict(itertools.count(0).next)
+ path_vectors = {}
+
+ for instance in instances:
+ curr_path_embeddings, curr_path_strings = session.run(
+ [model.path_embeddings, model.path_strings],
+ feed_dict={model.instance: instance})
+
+ for i, path in enumerate(curr_path_strings):
+ if not path:
+ continue
+
+ # Set a new/existing index for the path
+ index = path_index[path]
+
+ # Save its vector
+ path_vectors[index] = curr_path_embeddings[i, :]
+
+ print('Number of distinct paths: %d' % len(path_index))
+ return path_index, path_vectors
+
+
+def save_path_embeddings(model, path_vectors, path_index, embeddings_base_path):
+ """Saves the path embeddings.
+
+ Args:
+ model: The trained path-based model.
+ path_vectors: The path embeddings.
+ path_index: A map from path to ID.
+ embeddings_base_path: The base directory where the embeddings are.
+ """
+ index_range = range(max(path_index.values()) + 1)
+ path_matrix = [path_vectors[i] for i in index_range]
+ path_matrix = np.vstack(path_matrix)
+
+ # Save the path embeddings
+ path_vector_filename = os.path.join(
+ embeddings_base_path, '%d_path_vectors' % model.lstm_output_dim)
+ with open(path_vector_filename, 'w') as f_out:
+ np.save(f_out, path_matrix)
+
+ index_to_path = {i: p for p, i in path_index.iteritems()}
+ path_vocab = [index_to_path[i] for i in index_range]
+
+ # Save the path vocabulary
+ path_vocab_filename = os.path.join(
+ embeddings_base_path, '%d_path_vocab' % model.lstm_output_dim)
+ with open(path_vocab_filename, 'w') as f_out:
+ f_out.write('\n'.join(path_vocab))
+ f_out.write('\n')
+
+ print('Saved path embeddings.')
+
+
+def load_path_embeddings(path_embeddings_dir, path_dim):
+ """Loads pretrained path embeddings from a binary file and returns the matrix.
+
+ Args:
+ path_embeddings_dir: The directory for the path embeddings.
+ path_dim: The dimension of the path embeddings, used as prefix to the
+ path_vocab and path_vectors files.
+
+ Returns:
+ The path embeddings matrix and the path_to_index dictionary.
+ """
+ prefix = path_embeddings_dir + '/%d' % path_dim + '_'
+ with open(prefix + 'path_vocab') as f_in:
+ vocab = f_in.read().splitlines()
+
+ vocab_size = len(vocab)
+ embedding_file = prefix + 'path_vectors'
+
+ print('Embedding file "%s" has %d paths' % (embedding_file, vocab_size))
+
+ with open(embedding_file) as f_in:
+ embeddings = np.load(f_in)
+
+ path_to_index = {p: i for i, p in enumerate(vocab)}
+ return embeddings, path_to_index
+
+
+def get_indicative_paths(model, session, path_index, path_vectors, classes,
+ save_dir, k=20, threshold=0.8):
+ """Gets the most indicative paths for each class.
+
+ Args:
+ model: The trained path-based model.
+ session: The current TensorFlow session.
+ path_index: A map from path to ID.
+ path_vectors: The path embeddings.
+ classes: The class label names.
+ save_dir: Where to save the paths.
+ k: The k for top-k paths.
+ threshold: The threshold above which to consider paths as indicative.
+ """
+ # Define graph variables for this operation
+ p_path_embedding = tf.placeholder(dtype=tf.float32,
+ shape=[1, model.lstm_output_dim])
+ p_distributions = tf.nn.softmax(tf.matmul(p_path_embedding, model.weights1))
+
+ # Treat each path as a pair instance with a single path, and get the
+ # relation distribution for it. Then, take the top paths for each relation.
+
+ # This dictionary contains a relation as a key, and the value is a list of
+ # tuples of path index and score. A relation r will contain (p, s) if the
+ # path p is classified to r with a confidence of s.
+ prediction_per_relation = collections.defaultdict(list)
+
+ index_to_path = {i: p for p, i in path_index.iteritems()}
+
+ # Predict all the paths
+ for index in range(len(path_index)):
+ curr_path_vector = path_vectors[index]
+
+ distribution = session.run(p_distributions,
+ feed_dict={
+ p_path_embedding: np.reshape(
+ curr_path_vector,
+ [1, model.lstm_output_dim])})
+
+ distribution = distribution[0, :]
+ prediction = np.argmax(distribution)
+ prediction_per_relation[prediction].append(
+ (index, distribution[prediction]))
+
+ if index % 10000 == 0:
+ print('Classified %d/%d (%3.2f%%) of the paths' % (
+ index, len(path_index), 100 * index / len(path_index)))
+
+ # Retrieve k-best scoring paths for each relation
+ for relation_index, relation in enumerate(classes):
+ curr_paths = sorted(prediction_per_relation[relation_index],
+ key=lambda item: item[1], reverse=True)
+ above_t = [(p, s) for (p, s) in curr_paths if s >= threshold]
+    top_k = curr_paths[:k+1]
+ relation_paths = above_t if len(above_t) > len(top_k) else top_k
+
+ paths_filename = os.path.join(save_dir, '%s.paths' % relation)
+ with open(paths_filename, 'w') as f_out:
+ for index, score in relation_paths:
+ print('\t'.join([index_to_path[index], str(score)]), file=f_out)
diff --git a/models/research/lexnet_nc/sorted_paths_to_examples.py b/models/research/lexnet_nc/sorted_paths_to_examples.py
new file mode 100644
index 0000000000000000000000000000000000000000..c21d25d710ae793f6eefd889b98414c923e4fbe6
--- /dev/null
+++ b/models/research/lexnet_nc/sorted_paths_to_examples.py
@@ -0,0 +1,202 @@
+#!/usr/bin/env python
+# Copyright 2017, 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Takes as input a sorted, tab-separated of paths to produce tf.Examples."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import itertools
+import os
+import sys
+import tensorflow as tf
+
+import lexnet_common
+
+tf.flags.DEFINE_string('input', '', 'tab-separated input data')
+tf.flags.DEFINE_string('vocab', '', 'a text file containing lemma vocabulary')
+tf.flags.DEFINE_string('relations', '', 'a text file containing the relations')
+tf.flags.DEFINE_string('output_dir', '', 'output directory')
+tf.flags.DEFINE_string('splits', '', 'text file enumerating splits')
+tf.flags.DEFINE_string('default_split', '', 'default split for unlabeled pairs')
+tf.flags.DEFINE_string('compression', 'GZIP', 'compression for output records')
+tf.flags.DEFINE_integer('max_paths', 100, 'maximum number of paths per record')
+tf.flags.DEFINE_integer('max_pathlen', 8, 'maximum path length')
+FLAGS = tf.flags.FLAGS
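+
+# Example invocation (illustrative paths only; see the flags above):
+#   python sorted_paths_to_examples.py \
+#     --input=paths.sorted.tsv --vocab=vocab.txt --relations=relations.txt \
+#     --output_dir=/tmp/examples --splits=splits.tsv --default_split=train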
+
+
+def _int64_features(value):
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
+
+
+def _bytes_features(value):
+ value = [v.encode('utf-8') if isinstance(v, unicode) else v for v in value]
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
+
+
+class CreateExampleFn(object):
+
+ def __init__(self):
+    # Read the vocabulary. N.B. that 0 = <PAD>, 1 = <UNK>, 2 = <X>, 3 = <Y>, hence
+ # the enumeration starting at 4.
+ with tf.gfile.GFile(FLAGS.vocab) as fh:
+ self.vocab = {w: ix for ix, w in enumerate(fh.read().splitlines(), start=4)}
+
+    self.vocab.update({'<PAD>': 0, '<UNK>': 1, '<X>': 2, '<Y>': 3})
+
+ # Read the relations.
+ with tf.gfile.GFile(FLAGS.relations) as fh:
+ self.relations = {r: ix for ix, r in enumerate(fh.read().splitlines())}
+
+ # Some hackery to map from SpaCy postags to Google's.
+ lexnet_common.POSTAG_TO_ID['PROPN'] = lexnet_common.POSTAG_TO_ID['NOUN']
+ lexnet_common.POSTAG_TO_ID['PRON'] = lexnet_common.POSTAG_TO_ID['NOUN']
+ lexnet_common.POSTAG_TO_ID['CCONJ'] = lexnet_common.POSTAG_TO_ID['CONJ']
+ #lexnet_common.DEPLABEL_TO_ID['relcl'] = lexnet_common.DEPLABEL_TO_ID['rel']
+ #lexnet_common.DEPLABEL_TO_ID['compound'] = lexnet_common.DEPLABEL_TO_ID['xcomp']
+ #lexnet_common.DEPLABEL_TO_ID['oprd'] = lexnet_common.DEPLABEL_TO_ID['UNK']
+
+ def __call__(self, mod, head, rel, raw_paths):
+ # Drop any really long paths.
+ paths = []
+ counts = []
+ for raw, count in raw_paths.most_common(FLAGS.max_paths):
+ path = raw.split('::')
+ if len(path) <= FLAGS.max_pathlen:
+ paths.append(path)
+ counts.append(count)
+
+ if not paths:
+ return None
+
+ # Compute the true length.
+ pathlens = [len(path) for path in paths]
+
+ # Pad each path out to max_pathlen so the LSTM can eat it.
+ paths = (
+ itertools.islice(
+            itertools.chain(path, itertools.repeat('<PAD>/PAD/PAD/_')),
+ FLAGS.max_pathlen)
+ for path in paths)
+
+ # Split the lemma, POS, dependency label, and direction each into a
+ # separate feature.
+ lemmas, postags, deplabels, dirs = zip(
+ *(part.split('/') for part in itertools.chain(*paths)))
+
+ lemmas = [self.vocab.get(lemma, 1) for lemma in lemmas]
+ postags = [lexnet_common.POSTAG_TO_ID[pos] for pos in postags]
+ deplabels = [lexnet_common.DEPLABEL_TO_ID.get(dep, 1) for dep in deplabels]
+ dirs = [lexnet_common.DIR_TO_ID.get(d, 0) for d in dirs]
+
+ return tf.train.Example(features=tf.train.Features(feature={
+ 'pair': _bytes_features(['::'.join((mod, head))]),
+ 'rel': _bytes_features([rel]),
+ 'rel_id': _int64_features([self.relations[rel]]),
+ 'reprs': _bytes_features(raw_paths),
+ 'pathlens': _int64_features(pathlens),
+ 'counts': _int64_features(counts),
+ 'lemmas': _int64_features(lemmas),
+ 'dirs': _int64_features(dirs),
+ 'deplabels': _int64_features(deplabels),
+ 'postags': _int64_features(postags),
+ 'x_embedding_id': _int64_features([self.vocab[mod]]),
+ 'y_embedding_id': _int64_features([self.vocab[head]]),
+ }))
+
+
+def main(_):
+ # Read the splits file, if there is one.
+ assignments = {}
+ if FLAGS.splits:
+ with tf.gfile.GFile(FLAGS.splits) as fh:
+ parts = (line.split('\t') for line in fh.read().splitlines())
+ assignments = {(mod, head): split for mod, head, split in parts}
+
+ splits = set(assignments.itervalues())
+ if FLAGS.default_split:
+ default_split = FLAGS.default_split
+ splits.add(FLAGS.default_split)
+ elif splits:
+ default_split = iter(splits).next()
+ else:
+ print('Please specify --splits, --default_split, or both', file=sys.stderr)
+ return 1
+
+ last_mod, last_head, last_label = None, None, None
+ raw_paths = collections.Counter()
+
+ # Keep track of pairs we've seen to ensure that we don't get unsorted data.
+ seen_labeled_pairs = set()
+
+ # Set up output compression
+ compression_type = getattr(
+ tf.python_io.TFRecordCompressionType, FLAGS.compression)
+ options = tf.python_io.TFRecordOptions(compression_type=compression_type)
+
+ writers = {
+ split: tf.python_io.TFRecordWriter(
+ os.path.join(FLAGS.output_dir, '%s.tfrecs.gz' % split),
+ options=options)
+ for split in splits}
+
+ create_example = CreateExampleFn()
+
+ in_fh = sys.stdin if not FLAGS.input else tf.gfile.GFile(FLAGS.input)
+ for lineno, line in enumerate(in_fh, start=1):
+ if lineno % 100 == 0:
+ print('\rProcessed %d lines...' % lineno, end='', file=sys.stderr)
+
+ parts = line.decode('utf-8').strip().split('\t')
+ if len(parts) != 5:
+ print('Skipping line %d: %d columns (expected 5)' % (
+ lineno, len(parts)), file=sys.stderr)
+
+ continue
+
+ mod, head, label, raw_path, source = parts
+ if mod == last_mod and head == last_head and label == last_label:
+ raw_paths.update([raw_path])
+ continue
+
+ if last_mod and last_head and last_label and raw_paths:
+ if (last_mod, last_head, last_label) in seen_labeled_pairs:
+ print('It looks like the input data is not sorted; ignoring extra '
+ 'record for (%s::%s, %s) at line %d' % (
+ last_mod, last_head, last_label, lineno))
+ else:
+ ex = create_example(last_mod, last_head, last_label, raw_paths)
+ if ex:
+ split = assignments.get((last_mod, last_head), default_split)
+ writers[split].write(ex.SerializeToString())
+
+ seen_labeled_pairs.add((last_mod, last_head, last_label))
+
+ last_mod, last_head, last_label = mod, head, label
+ raw_paths = collections.Counter()
+
+ if last_mod and last_head and last_label and raw_paths:
+ ex = create_example(last_mod, last_head, last_label, raw_paths)
+ if ex:
+ split = assignments.get((last_mod, last_head), default_split)
+ writers[split].write(ex.SerializeToString())
+
+ for writer in writers.itervalues():
+ writer.close()
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/lexnet_nc/text_embeddings_to_binary.py b/models/research/lexnet_nc/text_embeddings_to_binary.py
new file mode 100644
index 0000000000000000000000000000000000000000..8226a7654e6da733ba1e8c46810a8ec8afd7a2c0
--- /dev/null
+++ b/models/research/lexnet_nc/text_embeddings_to_binary.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python
+# Copyright 2017, 2018 Google, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converts a text embedding file into a binary format for quicker loading."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+tf.flags.DEFINE_string('input', '', 'text file containing embeddings')
+tf.flags.DEFINE_string('output_vocab', '', 'output file for vocabulary')
+tf.flags.DEFINE_string('output_npy', '', 'output file for binary')
+FLAGS = tf.flags.FLAGS
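+
+# Example invocation (illustrative paths only; see the flags above):
+#   python text_embeddings_to_binary.py \
+#     --input=glove.6B.50d.txt \
+#     --output_vocab=glove.6B.50d.vocab \
+#     --output_npy=glove.6B.50d.npy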
+
+def main(_):
+ vecs = []
+ vocab = []
+ with tf.gfile.GFile(FLAGS.input) as fh:
+ for line in fh:
+ parts = line.strip().split()
+ vocab.append(parts[0])
+ vecs.append([float(x) for x in parts[1:]])
+
+ with tf.gfile.GFile(FLAGS.output_vocab, 'w') as fh:
+ fh.write('\n'.join(vocab))
+ fh.write('\n')
+
+ vecs = np.array(vecs, dtype=np.float32)
+ np.save(FLAGS.output_npy, vecs, allow_pickle=False)
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/lfads/README.md b/models/research/lfads/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c75b656e4746894c42251e29a530271bb6484e4f
--- /dev/null
+++ b/models/research/lfads/README.md
@@ -0,0 +1,224 @@
+
+
+
+# LFADS - Latent Factor Analysis via Dynamical Systems
+
+This code implements the model from the paper "[LFADS - Latent Factor Analysis via Dynamical Systems](http://biorxiv.org/content/early/2017/06/20/152884)". It is a sequential variational auto-encoder designed specifically for investigating neuroscience data, but can be applied widely to any time series data. In an unsupervised setting, LFADS is able to decompose time series data into various factors, such as an initial condition, a generative dynamical system, control inputs to that generator, and a low dimensional description of the observed data, called the factors. Additionally, the observation model is a loss on a probability distribution, so when LFADS processes a dataset, a denoised version of the dataset is also created. For example, if the dataset is raw spike counts, then under the negative log-likelihood loss under a Poisson distribution, the denoised data would be the inferred Poisson rates.
+
+
+## Prerequisites
+
+The code is written in Python 2.7.6. You will also need:
+
+* **TensorFlow** version 1.5 ([install](https://www.tensorflow.org/install/))
+* **NumPy, SciPy, Matplotlib** ([install SciPy stack](https://www.scipy.org/install.html), contains all of them)
+* **h5py** ([install](https://pypi.python.org/pypi/h5py))
+
+
+## Getting started
+
+Before starting, run the following:
+
+
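+Assuming a bash-like shell, this is the usual `PYTHONPATH` export:
+
+```sh
+$ export PYTHONPATH=$PYTHONPATH:/path/to/your/directory
+```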
+
+where "path/to/your/directory" is replaced with the path to the LFADS repository (you can get this path by using the `pwd` command). This allows the nested directories to access modules from their parent directory.
+
+## Generate synthetic data
+
+To generate the synthetic datasets, run the following from the top-level lfads directory:
+
+```sh
+$ cd synth_data
+$ ./run_generate_synth_data.sh
+$ cd ..
+```
+
+These synthetic datasets are provided 1. to gain insight into how the LFADS algorithm operates, and 2. to give reasonable starting points for analyses you might be interested in running on your own data.
+
+## Train an LFADS model
+
+Now that we have our example datasets, we can train some models! To spin up an LFADS model on the synthetic data, run any of the following commands. For the examples that are in the paper, the important hyperparameters are roughly replicated. Most hyperparameters are insensitive to small changes or won't ever be changed unless you want a very fine level of control. In the first example, all hyperparameter flags are enumerated for easy copy-pasting, but for the rest of the examples only the most important flags (~the first 9) are specified for brevity. For a full list of flags, their descriptions, and their default values, refer to the top of `run_lfads.py`. Please see Table 1 in the Online Methods of the associated paper for definitions of the most important hyperparameters.
+
+```sh
+# Run LFADS on chaotic rnn data with no input pulses (g = 1.5) with spiking noise
+$ python run_lfads.py --kind=train \
+--data_dir=/tmp/rnn_synth_data_v1.0/ \
+--data_filename_stem=chaotic_rnn_no_inputs \
+--lfads_save_dir=/tmp/lfads_chaotic_rnn_no_inputs \
+--co_dim=0 \
+--factors_dim=20 \
+--ext_input_dim=0 \
+--controller_input_lag=1 \
+--output_dist=poisson \
+--do_causal_controller=false \
+--batch_size=128 \
+--learning_rate_init=0.01 \
+--learning_rate_stop=1e-05 \
+--learning_rate_decay_factor=0.95 \
+--learning_rate_n_to_compare=6 \
+--do_reset_learning_rate=false \
+--keep_prob=0.95 \
+--con_dim=128 \
+--gen_dim=200 \
+--ci_enc_dim=128 \
+--ic_dim=64 \
+--ic_enc_dim=128 \
+--ic_prior_var_min=0.1 \
+--gen_cell_input_weight_scale=1.0 \
+--cell_weight_scale=1.0 \
+--do_feed_factors_to_controller=true \
+--kl_start_step=0 \
+--kl_increase_steps=2000 \
+--kl_ic_weight=1.0 \
+--l2_con_scale=0.0 \
+--l2_gen_scale=2000.0 \
+--l2_start_step=0 \
+--l2_increase_steps=2000 \
+--ic_prior_var_scale=0.1 \
+--ic_post_var_min=0.0001 \
+--kl_co_weight=1.0 \
+--prior_ar_nvar=0.1 \
+--cell_clip_value=5.0 \
+--max_ckpt_to_keep_lve=5 \
+--do_train_prior_ar_atau=true \
+--co_prior_var_scale=0.1 \
+--csv_log=fitlog \
+--feedback_factors_or_rates=factors \
+--do_train_prior_ar_nvar=true \
+--max_grad_norm=200.0 \
+--device=gpu:0 \
+--num_steps_for_gen_ic=100000000 \
+--ps_nexamples_to_process=100000000 \
+--checkpoint_name=lfads_vae \
+--temporal_spike_jitter_width=0 \
+--checkpoint_pb_load_name=checkpoint \
+--inject_ext_input_to_gen=false \
+--co_mean_corr_scale=0.0 \
+--gen_cell_rec_weight_scale=1.0 \
+--max_ckpt_to_keep=5 \
+--output_filename_stem="" \
+--ic_prior_var_max=0.1 \
+--prior_ar_atau=10.0 \
+--do_train_io_only=false \
+--do_train_encoder_only=false
+
+# Run LFADS on chaotic rnn data with no input pulses (g = 1.5) with Gaussian noise
+$ python run_lfads.py --kind=train \
+--data_dir=/tmp/rnn_synth_data_v1.0/ \
+--data_filename_stem=gaussian_chaotic_rnn_no_inputs \
+--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
+--co_dim=1 \
+--factors_dim=20 \
+--output_dist=gaussian
+
+
+# Run LFADS on chaotic rnn data with input pulses (g = 2.5)
+$ python run_lfads.py --kind=train \
+--data_dir=/tmp/rnn_synth_data_v1.0/ \
+--data_filename_stem=chaotic_rnn_inputs_g2p5 \
+--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
+--co_dim=1 \
+--factors_dim=20 \
+--output_dist=poisson
+
+# Run LFADS on multi-session RNN data
+$ python run_lfads.py --kind=train \
+--data_dir=/tmp/rnn_synth_data_v1.0/ \
+--data_filename_stem=chaotic_rnn_multisession \
+--lfads_save_dir=/tmp/lfads_chaotic_rnn_multisession \
+--factors_dim=10 \
+--output_dist=poisson
+
+# Run LFADS on integration to bound model data
+$ python run_lfads.py --kind=train \
+--data_dir=/tmp/rnn_synth_data_v1.0/ \
+--data_filename_stem=itb_rnn \
+--lfads_save_dir=/tmp/lfads_itb_rnn \
+--co_dim=1 \
+--factors_dim=20 \
+--controller_input_lag=0 \
+--output_dist=poisson
+
+# Run LFADS on chaotic RNN data with labels
+$ python run_lfads.py --kind=train \
+--data_dir=/tmp/rnn_synth_data_v1.0/ \
+--data_filename_stem=chaotic_rnns_labeled \
+--lfads_save_dir=/tmp/lfads_chaotic_rnns_labeled \
+--co_dim=0 \
+--factors_dim=20 \
+--controller_input_lag=0 \
+--ext_input_dim=1 \
+--output_dist=poisson
+
+# Run LFADS on chaotic rnn data with no input pulses (g = 1.5) with Gaussian noise
+$ python run_lfads.py --kind=train \
+--data_dir=/tmp/rnn_synth_data_v1.0/ \
+--data_filename_stem=chaotic_rnn_no_inputs \
+--lfads_save_dir=/tmp/lfads_chaotic_rnn_no_inputs \
+--co_dim=0 \
+--factors_dim=20 \
+--ext_input_dim=0 \
+--controller_input_lag=1 \
+--output_dist=gaussian
+
+
+```
+
+**Tip**: If you are running LFADS on a GPU and would like to run more than one model concurrently, set the `--allow_gpu_growth=True` flag on each job; otherwise one model will take up the entire GPU for performance purposes. You also need the TensorFlow libraries installed with GPU support.
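+
+For example, a minimal sketch (the g = 2.5 training command from above with the flag appended):
+
+```sh
+$ python run_lfads.py --kind=train \
+--data_dir=/tmp/rnn_synth_data_v1.0/ \
+--data_filename_stem=chaotic_rnn_inputs_g2p5 \
+--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
+--co_dim=1 \
+--factors_dim=20 \
+--output_dist=poisson \
+--allow_gpu_growth=True
+```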
+
+
+## Visualize a training model
+
+To visualize training curves and various other metrics while training an LFADS model, run TensorBoard on your model directory. For example, to launch TensorBoard on the chaotic RNN data with input pulses:
+
+```sh
+tensorboard --logdir=/tmp/lfads_chaotic_rnn_inputs_g2p5
+```
+
+## Evaluate a trained model
+
+Once your model is finished training, there are multiple ways you can evaluate
+it. Below are some sample commands to evaluate an LFADS model trained on the
+chaotic rnn data with input pulses (g = 2.5). The key differences here are
+setting the `--kind` flag to the appropriate mode, as well as the
+`--checkpoint_pb_load_name` flag to `checkpoint_lve` and the `--batch_size` flag
+(if you'd like to make it larger or smaller). All other flags should be the
+same as used in training, so that the same model architecture is built.
+
+```sh
+# Take samples from posterior then average (denoising operation)
+$ python run_lfads.py --kind=posterior_sample_and_average \
+--data_dir=/tmp/rnn_synth_data_v1.0/ \
+--data_filename_stem=chaotic_rnn_inputs_g2p5 \
+--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
+--co_dim=1 \
+--factors_dim=20 \
+--batch_size=1024 \
+--checkpoint_pb_load_name=checkpoint_lve
+
+# Sample from prior (generation of completely new samples)
+$ python run_lfads.py --kind=prior_sample \
+--data_dir=/tmp/rnn_synth_data_v1.0/ \
+--data_filename_stem=chaotic_rnn_inputs_g2p5 \
+--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
+--co_dim=1 \
+--factors_dim=20 \
+--batch_size=50 \
+--checkpoint_pb_load_name=checkpoint_lve
+
+# Write down model parameters
+$ python run_lfads.py --kind=write_model_params \
+--data_dir=/tmp/rnn_synth_data_v1.0/ \
+--data_filename_stem=chaotic_rnn_inputs_g2p5 \
+--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
+--co_dim=1 \
+--factors_dim=20 \
+--checkpoint_pb_load_name=checkpoint_lve
+```
+
+## Contact
+
+File any issues with the [issue tracker](https://github.com/tensorflow/models/issues). For any questions or problems, this code is maintained by [@sussillo](https://github.com/sussillo) and [@jazcollins](https://github.com/jazcollins).
+
diff --git a/models/research/lfads/distributions.py b/models/research/lfads/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..351d019af2b16117eb329b6ef1812aa006834b62
--- /dev/null
+++ b/models/research/lfads/distributions.py
@@ -0,0 +1,493 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+import numpy as np
+import tensorflow as tf
+from utils import linear, log_sum_exp
+
+class Poisson(object):
+ """Poisson distributon
+
+ Computes the log probability under the model.
+
+ """
+ def __init__(self, log_rates):
+ """ Create Poisson distributions with log_rates parameters.
+
+ Args:
+ log_rates: a tensor-like list of log rates underlying the Poisson dist.
+ """
+ self.logr = log_rates
+
+ def logp(self, bin_counts):
+ """Compute the log probability for the counts in the bin, under the model.
+
+ Args:
+ bin_counts: array-like integer counts
+
+ Returns:
+ The log-probability under the Poisson models for each element of
+ bin_counts.
+ """
+ k = tf.to_float(bin_counts)
+ # log poisson(k, r) = log(r^k * e^(-r) / k!) = k log(r) - r - log k!
+ # log poisson(k, r=exp(x)) = k * x - exp(x) - lgamma(k + 1)
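+    # Illustrative check (not part of the original code): for log rate x = 0
+    # (rate 1) and count k = 2 this gives 2*0 - exp(0) - lgamma(3)
+    # = -1 - ln(2) ~= -1.69, matching ln(e^-1 / 2!).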
+ return k * self.logr - tf.exp(self.logr) - tf.lgamma(k + 1)
+
+
+def diag_gaussian_log_likelihood(z, mu=0.0, logvar=0.0):
+ """Log-likelihood under a Gaussian distribution with diagonal covariance.
+ Returns the log-likelihood for each dimension. One should sum the
+ results for the log-likelihood under the full multidimensional model.
+
+ Args:
+ z: The value to compute the log-likelihood.
+ mu: The mean of the Gaussian
+ logvar: The log variance of the Gaussian.
+
+ Returns:
+ The log-likelihood under the Gaussian model.
+ """
+
+ return -0.5 * (logvar + np.log(2*np.pi) + \
+ tf.square((z-mu)/tf.exp(0.5*logvar)))
+
+
+def gaussian_pos_log_likelihood(unused_mean, logvar, noise):
+ """Gaussian log-likelihood function for a posterior in VAE
+
+ Note: This function is specialized for a posterior distribution, that has the
+ form of z = mean + sigma * noise.
+
+ Args:
+ unused_mean: ignore
+ logvar: The log variance of the distribution
+ noise: The noise used in the sampling of the posterior.
+
+ Returns:
+ The log-likelihood under the Gaussian model.
+ """
+ # ln N(z; mean, sigma) = - ln(sigma) - 0.5 ln 2pi - noise^2 / 2
+ return - 0.5 * (logvar + np.log(2 * np.pi) + tf.square(noise))
+
+
+class Gaussian(object):
+ """Base class for Gaussian distribution classes."""
+ pass
+
+
+class DiagonalGaussian(Gaussian):
+ """Diagonal Gaussian with different constant mean and variances in each
+ dimension.
+ """
+
+ def __init__(self, batch_size, z_size, mean, logvar):
+ """Create a diagonal gaussian distribution.
+
+ Args:
+ batch_size: The size of the batch, i.e. 0th dim in 2D tensor of samples.
+ z_size: The dimension of the distribution, i.e. 1st dim in 2D tensor.
+ mean: The N-D mean of the distribution.
+ logvar: The N-D log variance of the diagonal distribution.
+ """
+ size__xz = [None, z_size]
+ self.mean = mean # bxn already
+ self.logvar = logvar # bxn already
+ self.noise = noise = tf.random_normal(tf.shape(logvar))
+ self.sample = mean + tf.exp(0.5 * logvar) * noise
+ mean.set_shape(size__xz)
+ logvar.set_shape(size__xz)
+ self.sample.set_shape(size__xz)
+
+ def logp(self, z=None):
+ """Compute the log-likelihood under the distribution.
+
+ Args:
+ z (optional): value to compute likelihood for, if None, use sample.
+
+ Returns:
+ The likelihood of z under the model.
+ """
+ if z is None:
+ z = self.sample
+
+ # This is needed to make sure that the gradients are simple.
+ # The value of the function shouldn't change.
+ if z == self.sample:
+ return gaussian_pos_log_likelihood(self.mean, self.logvar, self.noise)
+
+ return diag_gaussian_log_likelihood(z, self.mean, self.logvar)
+
+
+class LearnableDiagonalGaussian(Gaussian):
+ """Diagonal Gaussian whose mean and variance are learned parameters."""
+
+ def __init__(self, batch_size, z_size, name, mean_init=0.0,
+ var_init=1.0, var_min=0.0, var_max=1000000.0):
+ """Create a learnable diagonal gaussian distribution.
+
+ Args:
+ batch_size: The size of the batch, i.e. 0th dim in 2D tensor of samples.
+ z_size: The dimension of the distribution, i.e. 1st dim in 2D tensor.
+      name: prefix name for the mean and log-variance TF variables.
+ mean_init (optional): The N-D mean initialization of the distribution.
+ var_init (optional): The N-D variance initialization of the diagonal
+ distribution.
+ var_min (optional): The minimum value the learned variance can take in any
+ dimension.
+ var_max (optional): The maximum value the learned variance can take in any
+ dimension.
+ """
+
+ size_1xn = [1, z_size]
+ size__xn = [None, z_size]
+ size_bx1 = tf.stack([batch_size, 1])
+ assert var_init > 0.0, "Problems"
+ assert var_max >= var_min, "Problems"
+ assert var_init >= var_min, "Problems"
+ assert var_max >= var_init, "Problems"
+
+
+ z_mean_1xn = tf.get_variable(name=name+"/mean", shape=size_1xn,
+ initializer=tf.constant_initializer(mean_init))
+ self.mean_bxn = mean_bxn = tf.tile(z_mean_1xn, size_bx1)
+ mean_bxn.set_shape(size__xn) # tile loses shape
+
+ log_var_init = np.log(var_init)
+ if var_max > var_min:
+ var_is_trainable = True
+ else:
+ var_is_trainable = False
+
+ z_logvar_1xn = \
+ tf.get_variable(name=(name+"/logvar"), shape=size_1xn,
+ initializer=tf.constant_initializer(log_var_init),
+ trainable=var_is_trainable)
+
+ if var_is_trainable:
+ z_logit_var_1xn = tf.exp(z_logvar_1xn)
+ z_var_1xn = tf.nn.sigmoid(z_logit_var_1xn)*(var_max-var_min) + var_min
+ z_logvar_1xn = tf.log(z_var_1xn)
+
+ logvar_bxn = tf.tile(z_logvar_1xn, size_bx1)
+ self.logvar_bxn = logvar_bxn
+ self.noise_bxn = noise_bxn = tf.random_normal(tf.shape(logvar_bxn))
+ self.sample_bxn = mean_bxn + tf.exp(0.5 * logvar_bxn) * noise_bxn
+
+ def logp(self, z=None):
+ """Compute the log-likelihood under the distribution.
+
+ Args:
+ z (optional): value to compute likelihood for, if None, use sample.
+
+ Returns:
+ The likelihood of z under the model.
+ """
+ if z is None:
+ z = self.sample
+
+ # This is needed to make sure that the gradients are simple.
+ # The value of the function shouldn't change.
+ if z == self.sample_bxn:
+ return gaussian_pos_log_likelihood(self.mean_bxn, self.logvar_bxn,
+ self.noise_bxn)
+
+ return diag_gaussian_log_likelihood(z, self.mean_bxn, self.logvar_bxn)
+
+ @property
+ def mean(self):
+ return self.mean_bxn
+
+ @property
+ def logvar(self):
+ return self.logvar_bxn
+
+ @property
+ def sample(self):
+ return self.sample_bxn
+
+
+class DiagonalGaussianFromInput(Gaussian):
+ """Diagonal Gaussian whose mean and variance are conditioned on other
+ variables.
+
+ Note: the parameters to convert from input to the learned mean and log
+ variance are held in this class.
+ """
+
+ def __init__(self, x_bxu, z_size, name, var_min=0.0):
+ """Create an input dependent diagonal Gaussian distribution.
+
+ Args:
+ x: The input tensor from which the mean and variance are computed,
+ via a linear transformation of x. I.e.
+ mu = Wx + b, log(var) = Mx + c
+ z_size: The size of the distribution.
+ name: The name to prefix to learned variables.
+ var_min (optional): Minimal variance allowed. This is an additional
+ way to control the amount of information getting through the stochastic
+ layer.
+ """
+ size_bxn = tf.stack([tf.shape(x_bxu)[0], z_size])
+ self.mean_bxn = mean_bxn = linear(x_bxu, z_size, name=(name+"/mean"))
+ logvar_bxn = linear(x_bxu, z_size, name=(name+"/logvar"))
+ if var_min > 0.0:
+ logvar_bxn = tf.log(tf.exp(logvar_bxn) + var_min)
+ self.logvar_bxn = logvar_bxn
+
+ self.noise_bxn = noise_bxn = tf.random_normal(size_bxn)
+ self.noise_bxn.set_shape([None, z_size])
+ self.sample_bxn = mean_bxn + tf.exp(0.5 * logvar_bxn) * noise_bxn
+
+ def logp(self, z=None):
+ """Compute the log-likelihood under the distribution.
+
+ Args:
+ z (optional): value to compute likelihood for, if None, use sample.
+
+ Returns:
+ The likelihood of z under the model.
+ """
+
+ if z is None:
+ z = self.sample
+
+ # This is needed to make sure that the gradients are simple.
+ # The value of the function shouldn't change.
+ if z == self.sample_bxn:
+ return gaussian_pos_log_likelihood(self.mean_bxn,
+ self.logvar_bxn, self.noise_bxn)
+
+ return diag_gaussian_log_likelihood(z, self.mean_bxn, self.logvar_bxn)
+
+ @property
+ def mean(self):
+ return self.mean_bxn
+
+ @property
+ def logvar(self):
+ return self.logvar_bxn
+
+ @property
+ def sample(self):
+ return self.sample_bxn
+
+
+class GaussianProcess:
+ """Base class for Gaussian processes."""
+ pass
+
+
+class LearnableAutoRegressive1Prior(GaussianProcess):
+ """AR(1) model where autocorrelation and process variance are learned
+ parameters. Assumed zero mean.
+
+ """
+
+ def __init__(self, batch_size, z_size,
+ autocorrelation_taus, noise_variances,
+ do_train_prior_ar_atau, do_train_prior_ar_nvar,
+ num_steps, name):
+ """Create a learnable autoregressive (1) process.
+
+ Args:
+ batch_size: The size of the batch, i.e. 0th dim in 2D tensor of samples.
+ z_size: The dimension of the distribution, i.e. 1st dim in 2D tensor.
+ autocorrelation_taus: The auto correlation time constant of the AR(1)
+ process.
+ A value of 0 is uncorrelated gaussian noise.
+ noise_variances: The variance of the additive noise, *not* the process
+ variance.
+ do_train_prior_ar_atau: Train or leave as constant, the autocorrelation?
+ do_train_prior_ar_nvar: Train or leave as constant, the noise variance?
+ num_steps: Number of steps to run the process.
+ name: The name to prefix to learned TF variables.
+ """
+
+ # Note the use of the plural in all of these quantities. This is intended
+    # to mark that even though a sample z_t from the posterior is thought of as a
+ # single sample of a multidimensional gaussian, the prior is actually
+ # thought of as U AR(1) processes, where U is the dimension of the inferred
+ # input.
+ size_bx1 = tf.stack([batch_size, 1])
+ size__xu = [None, z_size]
+ # process variance, the variance at time t over all instantiations of AR(1)
+ # with these parameters.
+ log_evar_inits_1xu = tf.expand_dims(tf.log(noise_variances), 0)
+ self.logevars_1xu = logevars_1xu = \
+ tf.Variable(log_evar_inits_1xu, name=name+"/logevars", dtype=tf.float32,
+ trainable=do_train_prior_ar_nvar)
+ self.logevars_bxu = logevars_bxu = tf.tile(logevars_1xu, size_bx1)
+ logevars_bxu.set_shape(size__xu) # tile loses shape
+
+ # \tau, which is the autocorrelation time constant of the AR(1) process
+ log_atau_inits_1xu = tf.expand_dims(tf.log(autocorrelation_taus), 0)
+ self.logataus_1xu = logataus_1xu = \
+ tf.Variable(log_atau_inits_1xu, name=name+"/logatau", dtype=tf.float32,
+ trainable=do_train_prior_ar_atau)
+
+ # phi in x_t = \mu + phi x_tm1 + \eps
+ # phi = exp(-1/tau)
+ # phi = exp(-1/exp(logtau))
+ # phi = exp(-exp(-logtau))
+ phis_1xu = tf.exp(-tf.exp(-logataus_1xu))
+ self.phis_bxu = phis_bxu = tf.tile(phis_1xu, size_bx1)
+ phis_bxu.set_shape(size__xu)
+
+ # process noise
+ # pvar = evar / (1- phi^2)
+ # logpvar = log ( exp(logevar) / (1 - phi^2) )
+ # logpvar = logevar - log(1-phi^2)
+ # logpvar = logevar - (log(1-phi) + log(1+phi))
+ self.logpvars_1xu = \
+ logevars_1xu - tf.log(1.0-phis_1xu) - tf.log(1.0+phis_1xu)
+ self.logpvars_bxu = logpvars_bxu = tf.tile(self.logpvars_1xu, size_bx1)
+ logpvars_bxu.set_shape(size__xu)
+
+    # process mean (zero, but included for completeness)
+ self.pmeans_bxu = pmeans_bxu = tf.zeros_like(phis_bxu)
+
+ # For sampling from the prior during de-novo generation.
+ self.means_t = means_t = [None] * num_steps
+ self.logvars_t = logvars_t = [None] * num_steps
+ self.samples_t = samples_t = [None] * num_steps
+ self.gaussians_t = gaussians_t = [None] * num_steps
+ sample_bxu = tf.zeros_like(phis_bxu)
+ for t in range(num_steps):
+ # process variance used here to make process completely stationary
+ if t == 0:
+ logvar_pt_bxu = self.logpvars_bxu
+ else:
+ logvar_pt_bxu = self.logevars_bxu
+
+ z_mean_pt_bxu = pmeans_bxu + phis_bxu * sample_bxu
+ gaussians_t[t] = DiagonalGaussian(batch_size, z_size,
+ mean=z_mean_pt_bxu,
+ logvar=logvar_pt_bxu)
+ sample_bxu = gaussians_t[t].sample
+ samples_t[t] = sample_bxu
+ logvars_t[t] = logvar_pt_bxu
+ means_t[t] = z_mean_pt_bxu
+
+ def logp_t(self, z_t_bxu, z_tm1_bxu=None):
+ """Compute the log-likelihood under the distribution for a given time t,
+ not the whole sequence.
+
+ Args:
+ z_t_bxu: sample to compute likelihood for at time t.
+ z_tm1_bxu (optional): sample condition probability of z_t upon.
+
+ Returns:
+ The likelihood of p_t under the model at time t. i.e.
+ p(z_t|z_tm1_bxu) = N(z_tm1_bxu * phis, eps^2)
+
+ """
+ if z_tm1_bxu is None:
+ return diag_gaussian_log_likelihood(z_t_bxu, self.pmeans_bxu,
+ self.logpvars_bxu)
+ else:
+ means_t_bxu = self.pmeans_bxu + self.phis_bxu * z_tm1_bxu
+ logp_tgtm1_bxu = diag_gaussian_log_likelihood(z_t_bxu,
+ means_t_bxu,
+ self.logevars_bxu)
+ return logp_tgtm1_bxu
+
+
+class KLCost_GaussianGaussian(object):
+ """log p(x|z) + KL(q||p) terms for Gaussian posterior and Gaussian prior. See
+ eqn 10 and Appendix B in VAE for latter term,
+ http://arxiv.org/abs/1312.6114
+
+ The log p(x|z) term is the reconstruction error under the model.
+ The KL term represents the penalty for passing information from the encoder
+ to the decoder.
+ To sample KL(q||p), we simply sample
+ ln q - ln p
+ by drawing samples from q and averaging.
+ """
+
+ def __init__(self, zs, prior_zs):
+ """Create a lower bound in three parts, normalized reconstruction
+ cost, normalized KL divergence cost, and their sum.
+
+    E_q[ln p(z_i | z_{i+1}) / q(z_i | x)]
+ \int q(z) ln p(z) dz = - 0.5 ln(2pi) - 0.5 \sum (ln(sigma_p^2) + \
+ sigma_q^2 / sigma_p^2 + (mean_p - mean_q)^2 / sigma_p^2)
+
+ \int q(z) ln q(z) dz = - 0.5 ln(2pi) - 0.5 \sum (ln(sigma_q^2) + 1)
+
+ Args:
+ zs: posterior z ~ q(z|x)
+ prior_zs: prior zs
+ """
+ # L = -KL + log p(x|z), to maximize bound on likelihood
+ # -L = KL - log p(x|z), to minimize bound on NLL
+    # so 'KL cost' is positive KL divergence
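+    # Illustrative check (not part of the original code): for scalar Gaussians
+    # with prior N(0, 1) and posterior N(1, 1), the bracketed sum below is
+    # 0 - 0 + exp(0) + 1^2/1 - 1 = 1, giving a KL divergence of 0.5.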
+ kl_b = 0.0
+ for z, prior_z in zip(zs, prior_zs):
+ assert isinstance(z, Gaussian)
+ assert isinstance(prior_z, Gaussian)
+ # ln(2pi) terms cancel
+ kl_b += 0.5 * tf.reduce_sum(
+ prior_z.logvar - z.logvar
+ + tf.exp(z.logvar - prior_z.logvar)
+ + tf.square((z.mean - prior_z.mean) / tf.exp(0.5 * prior_z.logvar))
+ - 1.0, [1])
+
+ self.kl_cost_b = kl_b
+ self.kl_cost = tf.reduce_mean(kl_b)
+
+
+class KLCost_GaussianGaussianProcessSampled(object):
+ """ log p(x|z) + KL(q||p) terms for Gaussian posterior and Gaussian process
+ prior via sampling.
+
+ The log p(x|z) term is the reconstruction error under the model.
+ The KL term represents the penalty for passing information from the encoder
+ to the decoder.
+ To sample KL(q||p), we simply sample
+ ln q - ln p
+ by drawing samples from q and averaging.
+ """
+
+ def __init__(self, post_zs, prior_z_process):
+ """Create a lower bound in three parts, normalized reconstruction
+ cost, normalized KL divergence cost, and their sum.
+
+ Args:
+ post_zs: posterior z ~ q(z|x)
+ prior_z_process: prior AR(1) process
+ """
+ assert len(post_zs) > 1, "GP is for time, need more than 1 time step."
+ assert isinstance(prior_z_process, GaussianProcess), "Must use GP."
+
+ # L = -KL + log p(x|z), to maximize bound on likelihood
+ # -L = KL - log p(x|z), to minimize bound on NLL
+    # so 'KL cost' is positive KL divergence
+ z0_bxu = post_zs[0].sample
+ logq_bxu = post_zs[0].logp(z0_bxu)
+ logp_bxu = prior_z_process.logp_t(z0_bxu)
+ z_tm1_bxu = z0_bxu
+ for z_t in post_zs[1:]:
+ # posterior is independent in time, prior is not
+ z_t_bxu = z_t.sample
+ logq_bxu += z_t.logp(z_t_bxu)
+ logp_bxu += prior_z_process.logp_t(z_t_bxu, z_tm1_bxu)
+ z_tm1_bxu = z_t_bxu
+
+ kl_bxu = logq_bxu - logp_bxu
+ kl_b = tf.reduce_sum(kl_bxu, [1])
+ self.kl_cost_b = kl_b
+ self.kl_cost = tf.reduce_mean(kl_b)
diff --git a/models/research/lfads/lfads.py b/models/research/lfads/lfads.py
new file mode 100644
index 0000000000000000000000000000000000000000..308ebabe90fbbb90701ac0585e7c1eaeaf6e3649
--- /dev/null
+++ b/models/research/lfads/lfads.py
@@ -0,0 +1,2170 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+"""
+LFADS - Latent Factor Analysis via Dynamical Systems.
+
+LFADS is an unsupervised method to decompose time series data into
+various factors, such as an initial condition, a generative
+dynamical system, control inputs to that generator, and a low
+dimensional description of the observed data, called the factors.
+Additionally, the observations have a noise model (in this case
+Poisson), so a denoised version of the observations is also created
+(e.g. underlying rates of a Poisson distribution given the observed
+event counts).
+
+The main data structure being passed around is a dataset. This is a dictionary
+of data dictionaries.
+
+DATASET: The top level dictionary is simply name (string -> dictionary).
+The nested dictionary is the DATA DICTIONARY, which has the following keys:
+ 'train_data' and 'valid_data', whose values are the corresponding training
+ and validation data with shape
+ ExTxD, E - # examples, T - # time steps, D - # dimensions in data.
+ The data dictionary also has a few more keys:
+  'train_ext_input' and 'valid_ext_input', if there are known external inputs
+ to the system being modeled, these take on dimensions:
+ ExTxI, E - # examples, T - # time steps, I = # dimensions in input.
+ 'alignment_matrix_cxf' - If you are using multiple days data, it's possible
+ that one can align the channels (see manuscript). If so each dataset will
+ contain this matrix, which will be used for both the input adapter and the
+ output adapter for each dataset. These matrices, if provided, must be of
+ size [data_dim x factors] where data_dim is the number of neurons recorded
+ on that day, and factors is chosen and set through the '--factors' flag.
+  'alignment_bias_c' - See alignment_matrix_cxf. This bias will be used as
+ the offset for the alignment transformation. It will *subtract* off the
+ bias from the data, so pca style inits can align factors across sessions.
+
+
+  If one runs LFADS on data where the true rates are known for some trials
+  (say, simulated testing data, as in the example shipped with the paper), then
+ one can add three more fields for plotting purposes. These are 'train_truth'
+ and 'valid_truth', and 'conversion_factor'. These have the same dimensions as
+ 'train_data', and 'valid_data' but represent the underlying rates of the
+ observations. Finally, if one needs to convert scale for plotting the true
+ underlying firing rates, there is the 'conversion_factor' key.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import numpy as np
+import os
+import tensorflow as tf
+from distributions import LearnableDiagonalGaussian, DiagonalGaussianFromInput
+from distributions import diag_gaussian_log_likelihood
+from distributions import KLCost_GaussianGaussian, Poisson
+from distributions import LearnableAutoRegressive1Prior
+from distributions import KLCost_GaussianGaussianProcessSampled
+
+from utils import init_linear, linear, list_t_bxn_to_tensor_bxtxn, write_data
+from utils import log_sum_exp, flatten
+from plot_lfads import plot_lfads
+
+
+class GRU(object):
+ """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).
+
+ """
+ def __init__(self, num_units, forget_bias=1.0, weight_scale=1.0,
+ clip_value=np.inf, collections=None):
+ """Create a GRU object.
+
+ Args:
+ num_units: Number of units in the GRU
+ forget_bias (optional): Hack to help learning.
+ weight_scale (optional): weights are scaled by ws/sqrt(#inputs), with
+ ws being the weight scale.
+ clip_value (optional): if the recurrent values grow above this value,
+ clip them.
+      collections (optional): List of additional collections that the
+        variables should belong to.
+ """
+ self._num_units = num_units
+ self._forget_bias = forget_bias
+ self._weight_scale = weight_scale
+ self._clip_value = clip_value
+ self._collections = collections
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ @property
+ def state_multiplier(self):
+ return 1
+
+ def output_from_state(self, state):
+ """Return the output portion of the state."""
+ return state
+
+ def __call__(self, inputs, state, scope=None):
+ """Gated recurrent unit (GRU) function.
+
+ Args:
+ inputs: A 2D batch x input_dim tensor of inputs.
+ state: The previous state from the last time step.
+ scope (optional): TF variable scope for defined GRU variables.
+
+ Returns:
+ A tuple (state, state), where state is the newly computed state at time t.
+ It is returned twice to respect an interface that works for LSTMs.
+ """
+
+ x = inputs
+ h = state
+ if inputs is not None:
+ xh = tf.concat(axis=1, values=[x, h])
+ else:
+ xh = h
+
+ with tf.variable_scope(scope or type(self).__name__): # "GRU"
+ with tf.variable_scope("Gates"): # Reset gate and update gate.
+ # We start with bias of 1.0 to not reset and not update.
+ r, u = tf.split(axis=1, num_or_size_splits=2, value=linear(xh,
+ 2 * self._num_units,
+ alpha=self._weight_scale,
+ name="xh_2_ru",
+ collections=self._collections))
+ r, u = tf.sigmoid(r), tf.sigmoid(u + self._forget_bias)
+ with tf.variable_scope("Candidate"):
+ xrh = tf.concat(axis=1, values=[x, r * h])
+ c = tf.tanh(linear(xrh, self._num_units, name="xrh_2_c",
+ collections=self._collections))
+ new_h = u * h + (1 - u) * c
+ new_h = tf.clip_by_value(new_h, -self._clip_value, self._clip_value)
+
+ return new_h, new_h
+
+
+class GenGRU(object):
+ """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).
+
+ This version is specialized for the generator, but isn't as fast, so
+ we have two. Note this allows for l2 regularization on the recurrent
+ weights, but also implicitly rescales the inputs via the 1/sqrt(input)
+ scaling in the linear helper routine to be large magnitude, if there are
+ fewer inputs than recurrent state.
+
+ """
+ def __init__(self, num_units, forget_bias=1.0,
+ input_weight_scale=1.0, rec_weight_scale=1.0, clip_value=np.inf,
+ input_collections=None, recurrent_collections=None):
+ """Create a GRU object.
+
+ Args:
+ num_units: Number of units in the GRU
+ forget_bias (optional): Hack to help learning.
+ input_weight_scale (optional): weights are scaled ws/sqrt(#inputs), with
+ ws being the weight scale.
+ rec_weight_scale (optional): weights are scaled ws/sqrt(#inputs),
+ with ws being the weight scale.
+ clip_value (optional): if the recurrent values grow above this value,
+ clip them.
+      input_collections (optional): List of additional collections variables
+ that input->rec weights should belong to.
+ recurrent_collections (optional): List of additonal collections variables
+ that rec->rec weights should belong to.
+ """
+ self._num_units = num_units
+ self._forget_bias = forget_bias
+ self._input_weight_scale = input_weight_scale
+ self._rec_weight_scale = rec_weight_scale
+ self._clip_value = clip_value
+ self._input_collections = input_collections
+ self._rec_collections = recurrent_collections
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ @property
+ def state_multiplier(self):
+ return 1
+
+ def output_from_state(self, state):
+ """Return the output portion of the state."""
+ return state
+
+ def __call__(self, inputs, state, scope=None):
+ """Gated recurrent unit (GRU) function.
+
+ Args:
+ inputs: A 2D batch x input_dim tensor of inputs.
+ state: The previous state from the last time step.
+ scope (optional): TF variable scope for defined GRU variables.
+
+ Returns:
+ A tuple (state, state), where state is the newly computed state at time t.
+ It is returned twice to respect an interface that works for LSTMs.
+ """
+
+ x = inputs
+ h = state
+ with tf.variable_scope(scope or type(self).__name__): # "GRU"
+ with tf.variable_scope("Gates"): # Reset gate and update gate.
+ # We start with bias of 1.0 to not reset and not update.
+ r_x = u_x = 0.0
+ if x is not None:
+ r_x, u_x = tf.split(axis=1, num_or_size_splits=2, value=linear(x,
+ 2 * self._num_units,
+ alpha=self._input_weight_scale,
+ do_bias=False,
+ name="x_2_ru",
+ normalized=False,
+ collections=self._input_collections))
+
+ r_h, u_h = tf.split(axis=1, num_or_size_splits=2, value=linear(h,
+ 2 * self._num_units,
+ do_bias=True,
+ alpha=self._rec_weight_scale,
+ name="h_2_ru",
+ collections=self._rec_collections))
+ r = r_x + r_h
+ u = u_x + u_h
+ r, u = tf.sigmoid(r), tf.sigmoid(u + self._forget_bias)
+
+ with tf.variable_scope("Candidate"):
+ c_x = 0.0
+ if x is not None:
+ c_x = linear(x, self._num_units, name="x_2_c", do_bias=False,
+ alpha=self._input_weight_scale,
+ normalized=False,
+ collections=self._input_collections)
+ c_rh = linear(r*h, self._num_units, name="rh_2_c", do_bias=True,
+ alpha=self._rec_weight_scale,
+ collections=self._rec_collections)
+ c = tf.tanh(c_x + c_rh)
+
+ new_h = u * h + (1 - u) * c
+ new_h = tf.clip_by_value(new_h, -self._clip_value, self._clip_value)
+
+ return new_h, new_h
+
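+# Note on GenGRU above: the input->hidden and hidden->hidden weights are built
+# by separate `linear` calls so they can be placed in separate collections
+# (input_collections vs. recurrent_collections). That split is what lets the
+# LFADS model apply L2 regularization to the recurrent weights only (e.g. the
+# 'l2_gen_reg' and 'l2_con_reg' collections used further down).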
+
+class LFADS(object):
+ """LFADS - Latent Factor Analysis via Dynamical Systems.
+
+ LFADS is an unsupervised method to decompose time series data into
+ various factors, such as an initial condition, a generative
+ dynamical system, inferred inputs to that generator, and a low
+ dimensional description of the observed data, called the factors.
+  Additionally, the observations have a noise model (in this case
+ Poisson), so a denoised version of the observations is also created
+ (e.g. underlying rates of a Poisson distribution given the observed
+ event counts).
+ """
+
+ def __init__(self, hps, kind="train", datasets=None):
+ """Create an LFADS model.
+
+ train - a model for training, sampling of posteriors is used
+ posterior_sample_and_average - sample from the posterior, this is used
+ for evaluating the expected value of the outputs of LFADS, given a
+ specific input, by averaging over multiple samples from the approx
+ posterior. Also used for the lower bound on the negative
+      log-likelihood using IWAE error (Importance Weighted Autoencoder).
+      This is the denoising operation.
+    posterior_push_mean - like posterior_sample_and_average, but the means of
+      the approximate posteriors are pushed through the generator rather than
+      samples from them.
+    prior_sample - a model for generation - sampling from priors is used
+
+ Args:
+ hps: The dictionary of hyper parameters.
+ kind: the type of model to build (see above).
+ datasets: a dictionary of named data_dictionaries, see top of lfads.py
+ """
+ print("Building graph...")
+ all_kinds = ['train', 'posterior_sample_and_average', 'posterior_push_mean',
+ 'prior_sample']
+ assert kind in all_kinds, 'Wrong kind'
+ if hps.feedback_factors_or_rates == "rates":
+ assert len(hps.dataset_names) == 1, \
+ "Multiple datasets not supported for rate feedback."
+ num_steps = hps.num_steps
+ ic_dim = hps.ic_dim
+ co_dim = hps.co_dim
+ ext_input_dim = hps.ext_input_dim
+ cell_class = GRU
+ gen_cell_class = GenGRU
+
+ def makelambda(v): # Used with tf.case
+ return lambda: v
+
+ # Define the data placeholder, and deal with all parts of the graph
+ # that are dataset dependent.
+ self.dataName = tf.placeholder(tf.string, shape=())
+ # The batch_size to be inferred from data, as normal.
+ # Additionally, the data_dim will be inferred as well, allowing for a
+ # single placeholder for all datasets, regardless of data dimension.
+ if hps.output_dist == 'poisson':
+ # Enforce correct dtype
+ assert np.issubdtype(
+ datasets[hps.dataset_names[0]]['train_data'].dtype, int), \
+ "Data dtype must be int for poisson output distribution"
+ data_dtype = tf.int32
+ elif hps.output_dist == 'gaussian':
+ assert np.issubdtype(
+ datasets[hps.dataset_names[0]]['train_data'].dtype, float), \
+ "Data dtype must be float for gaussian output dsitribution"
+ data_dtype = tf.float32
+ else:
+ assert False, "NIY"
+ self.dataset_ph = dataset_ph = tf.placeholder(data_dtype,
+ [None, num_steps, None],
+ name="data")
+ self.train_step = tf.get_variable("global_step", [], tf.int64,
+ tf.zeros_initializer(),
+ trainable=False)
+ self.hps = hps
+ ndatasets = hps.ndatasets
+ factors_dim = hps.factors_dim
+ self.preds = preds = [None] * ndatasets
+ self.fns_in_fac_Ws = fns_in_fac_Ws = [None] * ndatasets
+    self.fns_in_fac_bs = fns_in_fac_bs = [None] * ndatasets
+ self.fns_out_fac_Ws = fns_out_fac_Ws = [None] * ndatasets
+ self.fns_out_fac_bs = fns_out_fac_bs = [None] * ndatasets
+ self.datasetNames = dataset_names = hps.dataset_names
+ self.ext_inputs = ext_inputs = None
+
+ if len(dataset_names) == 1: # single session
+ if 'alignment_matrix_cxf' in datasets[dataset_names[0]].keys():
+ used_in_factors_dim = factors_dim
+ in_identity_if_poss = False
+ else:
+ used_in_factors_dim = hps.dataset_dims[dataset_names[0]]
+ in_identity_if_poss = True
+ else: # multisession
+ used_in_factors_dim = factors_dim
+ in_identity_if_poss = False
+
+ for d, name in enumerate(dataset_names):
+ data_dim = hps.dataset_dims[name]
+ in_mat_cxf = None
+ in_bias_1xf = None
+ align_bias_1xc = None
+
+ if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
+ dataset = datasets[name]
+ if hps.do_train_readin:
+ print("Initializing trainable readin matrix with alignment matrix" \
+ " provided for dataset:", name)
+ else:
+ print("Setting non-trainable readin matrix to alignment matrix" \
+ " provided for dataset:", name)
+ in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
+ if in_mat_cxf.shape != (data_dim, factors_dim):
+ raise ValueError("""Alignment matrix must have dimensions %d x %d
+ (data_dim x factors_dim), but currently has %d x %d."""%
+ (data_dim, factors_dim, in_mat_cxf.shape[0],
+ in_mat_cxf.shape[1]))
+ if datasets and 'alignment_bias_c' in datasets[name].keys():
+ dataset = datasets[name]
+ if hps.do_train_readin:
+ print("Initializing trainable readin bias with alignment bias " \
+ "provided for dataset:", name)
+ else:
+ print("Setting non-trainable readin bias to alignment bias " \
+ "provided for dataset:", name)
+ align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
+ align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
+ if align_bias_1xc.shape[1] != data_dim:
+          raise ValueError("""Alignment bias must have dimensions %d
+          (data_dim), but currently has %d."""%
+                           (data_dim, align_bias_1xc.shape[1]))
+ if in_mat_cxf is not None and align_bias_1xc is not None:
+ # (data - alignment_bias) * W_in
+ # data * W_in - alignment_bias * W_in
+ # So b = -alignment_bias * W_in to accommodate PCA style offset.
+ in_bias_1xf = -np.dot(align_bias_1xc, in_mat_cxf)
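+        # Equivalently (a small numpy identity, with data_bxc standing in for
+        # a batch of raw data):
+        #   (data_bxc - align_bias_1xc) @ in_mat_cxf
+        #       == data_bxc @ in_mat_cxf + in_bias_1xf
+        # so the PCA-style centering is folded exactly into the readin bias.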
+
+ if hps.do_train_readin:
+        # Add to the IO_transformations collection only if we want it to be
+        # learnable, because the IO_transformations collection is what gets
+        # trained when do_train_io_only is set.
+ collections_readin=['IO_transformations']
+ else:
+ collections_readin=None
+
+ in_fac_lin = init_linear(data_dim, used_in_factors_dim,
+ do_bias=True,
+ mat_init_value=in_mat_cxf,
+ bias_init_value=in_bias_1xf,
+ identity_if_possible=in_identity_if_poss,
+ normalized=False, name="x_2_infac_"+name,
+ collections=collections_readin,
+ trainable=hps.do_train_readin)
+ in_fac_W, in_fac_b = in_fac_lin
+ fns_in_fac_Ws[d] = makelambda(in_fac_W)
+ fns_in_fac_bs[d] = makelambda(in_fac_b)
+
+ with tf.variable_scope("glm"):
+ out_identity_if_poss = False
+ if len(dataset_names) == 1 and \
+ factors_dim == hps.dataset_dims[dataset_names[0]]:
+ out_identity_if_poss = True
+ for d, name in enumerate(dataset_names):
+ data_dim = hps.dataset_dims[name]
+ in_mat_cxf = None
+ if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
+ dataset = datasets[name]
+ in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
+
+ if datasets and 'alignment_bias_c' in datasets[name].keys():
+ dataset = datasets[name]
+ align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
+ align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
+
+ out_mat_fxc = None
+ out_bias_1xc = None
+ if in_mat_cxf is not None:
+ out_mat_fxc = in_mat_cxf.T
+ if align_bias_1xc is not None:
+ out_bias_1xc = align_bias_1xc
+
+ if hps.output_dist == 'poisson':
+ out_fac_lin = init_linear(factors_dim, data_dim, do_bias=True,
+ mat_init_value=out_mat_fxc,
+ bias_init_value=out_bias_1xc,
+ identity_if_possible=out_identity_if_poss,
+ normalized=False,
+ name="fac_2_logrates_"+name,
+ collections=['IO_transformations'])
+ out_fac_W, out_fac_b = out_fac_lin
+
+ elif hps.output_dist == 'gaussian':
+ out_fac_lin_mean = \
+ init_linear(factors_dim, data_dim, do_bias=True,
+ mat_init_value=out_mat_fxc,
+ bias_init_value=out_bias_1xc,
+ normalized=False,
+ name="fac_2_means_"+name,
+ collections=['IO_transformations'])
+ out_fac_W_mean, out_fac_b_mean = out_fac_lin_mean
+
+ mat_init_value = np.zeros([factors_dim, data_dim]).astype(np.float32)
+ bias_init_value = np.ones([1, data_dim]).astype(np.float32)
+ out_fac_lin_logvar = \
+ init_linear(factors_dim, data_dim, do_bias=True,
+ mat_init_value=mat_init_value,
+ bias_init_value=bias_init_value,
+ normalized=False,
+ name="fac_2_logvars_"+name,
+ collections=['IO_transformations'])
+ out_fac_W_mean, out_fac_b_mean = out_fac_lin_mean
+ out_fac_W_logvar, out_fac_b_logvar = out_fac_lin_logvar
+ out_fac_W = tf.concat(
+ axis=1, values=[out_fac_W_mean, out_fac_W_logvar])
+ out_fac_b = tf.concat(
+ axis=1, values=[out_fac_b_mean, out_fac_b_logvar])
+ else:
+ assert False, "NIY"
+
+ preds[d] = tf.equal(tf.constant(name), self.dataName)
+ data_dim = hps.dataset_dims[name]
+ fns_out_fac_Ws[d] = makelambda(out_fac_W)
+ fns_out_fac_bs[d] = makelambda(out_fac_b)
+
+    # tf.case expects a list of (pred, fn) pairs, so materialize the zips.
+    pf_pairs_in_fac_Ws = list(zip(preds, fns_in_fac_Ws))
+    pf_pairs_in_fac_bs = list(zip(preds, fns_in_fac_bs))
+    pf_pairs_out_fac_Ws = list(zip(preds, fns_out_fac_Ws))
+    pf_pairs_out_fac_bs = list(zip(preds, fns_out_fac_bs))
+
+ this_in_fac_W = tf.case(pf_pairs_in_fac_Ws, exclusive=True)
+ this_in_fac_b = tf.case(pf_pairs_in_fac_bs, exclusive=True)
+ this_out_fac_W = tf.case(pf_pairs_out_fac_Ws, exclusive=True)
+ this_out_fac_b = tf.case(pf_pairs_out_fac_bs, exclusive=True)
+
+ # External inputs (not changing by dataset, by definition).
+ if hps.ext_input_dim > 0:
+ self.ext_input = tf.placeholder(tf.float32,
+ [None, num_steps, ext_input_dim],
+ name="ext_input")
+ else:
+ self.ext_input = None
+ ext_input_bxtxi = self.ext_input
+
+ self.keep_prob = keep_prob = tf.placeholder(tf.float32, [], "keep_prob")
+ self.batch_size = batch_size = int(hps.batch_size)
+ self.learning_rate = tf.Variable(float(hps.learning_rate_init),
+ trainable=False, name="learning_rate")
+ self.learning_rate_decay_op = self.learning_rate.assign(
+ self.learning_rate * hps.learning_rate_decay_factor)
+
+ # Dropout the data.
+ dataset_do_bxtxd = tf.nn.dropout(tf.to_float(dataset_ph), keep_prob)
+ if hps.ext_input_dim > 0:
+ ext_input_do_bxtxi = tf.nn.dropout(ext_input_bxtxi, keep_prob)
+ else:
+ ext_input_do_bxtxi = None
+
+ # ENCODERS
+ def encode_data(dataset_bxtxd, enc_cell, name, forward_or_reverse,
+ num_steps_to_encode):
+ """Encode data for LFADS
+ Args:
+        dataset_bxtxd - the data to encode, as a 3 tensor, with dims
+          batch x time x data dims.
+ enc_cell: encoder cell
+ name: name of encoder
+ forward_or_reverse: string, encode in forward or reverse direction
+ num_steps_to_encode: number of steps to encode, 0:num_steps_to_encode
+ Returns:
+ encoded data as a list with num_steps_to_encode items, in order
+ """
+ if forward_or_reverse == "forward":
+ dstr = "_fwd"
+ time_fwd_or_rev = range(num_steps_to_encode)
+ else:
+ dstr = "_rev"
+ time_fwd_or_rev = reversed(range(num_steps_to_encode))
+
+ with tf.variable_scope(name+"_enc"+dstr, reuse=False):
+ enc_state = tf.tile(
+ tf.Variable(tf.zeros([1, enc_cell.state_size]),
+ name=name+"_enc_t0"+dstr), tf.stack([batch_size, 1]))
+ enc_state.set_shape([None, enc_cell.state_size]) # tile loses shape
+
+ enc_outs = [None] * num_steps_to_encode
+ for i, t in enumerate(time_fwd_or_rev):
+ with tf.variable_scope(name+"_enc"+dstr, reuse=True if i > 0 else None):
+ dataset_t_bxd = dataset_bxtxd[:,t,:]
+ in_fac_t_bxf = tf.matmul(dataset_t_bxd, this_in_fac_W) + this_in_fac_b
+ in_fac_t_bxf.set_shape([None, used_in_factors_dim])
+ if ext_input_dim > 0 and not hps.inject_ext_input_to_gen:
+ ext_input_t_bxi = ext_input_do_bxtxi[:,t,:]
+ enc_input_t_bxfpe = tf.concat(
+ axis=1, values=[in_fac_t_bxf, ext_input_t_bxi])
+ else:
+ enc_input_t_bxfpe = in_fac_t_bxf
+ enc_out, enc_state = enc_cell(enc_input_t_bxfpe, enc_state)
+ enc_outs[t] = enc_out
+
+ return enc_outs
+
+ # Encode initial condition means and variances
+ # ([x_T, x_T-1, ... x_0] and [x_0, x_1, ... x_T] -> g0/c0)
+ self.ic_enc_fwd = [None] * num_steps
+ self.ic_enc_rev = [None] * num_steps
+ if ic_dim > 0:
+ enc_ic_cell = cell_class(hps.ic_enc_dim,
+ weight_scale=hps.cell_weight_scale,
+ clip_value=hps.cell_clip_value)
+ ic_enc_fwd = encode_data(dataset_do_bxtxd, enc_ic_cell,
+ "ic", "forward",
+ hps.num_steps_for_gen_ic)
+ ic_enc_rev = encode_data(dataset_do_bxtxd, enc_ic_cell,
+ "ic", "reverse",
+ hps.num_steps_for_gen_ic)
+ self.ic_enc_fwd = ic_enc_fwd
+ self.ic_enc_rev = ic_enc_rev
+
+ # Encoder control input means and variances, bi-directional encoding so:
+ # ([x_T, x_T-1, ..., x_0] and [x_0, x_1 ... x_T] -> u_t)
+ self.ci_enc_fwd = [None] * num_steps
+ self.ci_enc_rev = [None] * num_steps
+ if co_dim > 0:
+ enc_ci_cell = cell_class(hps.ci_enc_dim,
+ weight_scale=hps.cell_weight_scale,
+ clip_value=hps.cell_clip_value)
+ ci_enc_fwd = encode_data(dataset_do_bxtxd, enc_ci_cell,
+ "ci", "forward",
+ hps.num_steps)
+ if hps.do_causal_controller:
+ ci_enc_rev = None
+ else:
+ ci_enc_rev = encode_data(dataset_do_bxtxd, enc_ci_cell,
+ "ci", "reverse",
+ hps.num_steps)
+ self.ci_enc_fwd = ci_enc_fwd
+ self.ci_enc_rev = ci_enc_rev
+
+ # STOCHASTIC LATENT VARIABLES, priors and posteriors
+ # (initial conditions g0, and control inputs, u_t)
+ # Note that zs represent all the stochastic latent variables.
+ with tf.variable_scope("z", reuse=False):
+ self.prior_zs_g0 = None
+ self.posterior_zs_g0 = None
+ self.g0s_val = None
+ if ic_dim > 0:
+ self.prior_zs_g0 = \
+ LearnableDiagonalGaussian(batch_size, ic_dim, name="prior_g0",
+ mean_init=0.0,
+ var_min=hps.ic_prior_var_min,
+ var_init=hps.ic_prior_var_scale,
+ var_max=hps.ic_prior_var_max)
+ ic_enc = tf.concat(axis=1, values=[ic_enc_fwd[-1], ic_enc_rev[0]])
+ ic_enc = tf.nn.dropout(ic_enc, keep_prob)
+ self.posterior_zs_g0 = \
+ DiagonalGaussianFromInput(ic_enc, ic_dim, "ic_enc_2_post_g0",
+ var_min=hps.ic_post_var_min)
+ if kind in ["train", "posterior_sample_and_average",
+ "posterior_push_mean"]:
+ zs_g0 = self.posterior_zs_g0
+ else:
+ zs_g0 = self.prior_zs_g0
+ if kind in ["train", "posterior_sample_and_average", "prior_sample"]:
+ self.g0s_val = zs_g0.sample
+ else:
+ self.g0s_val = zs_g0.mean
+
+ # Priors for controller, 'co' for controller output
+ self.prior_zs_co = prior_zs_co = [None] * num_steps
+ self.posterior_zs_co = posterior_zs_co = [None] * num_steps
+ self.zs_co = zs_co = [None] * num_steps
+ self.prior_zs_ar_con = None
+ if co_dim > 0:
+ # Controller outputs
+ autocorrelation_taus = [hps.prior_ar_atau for x in range(hps.co_dim)]
+ noise_variances = [hps.prior_ar_nvar for x in range(hps.co_dim)]
+ self.prior_zs_ar_con = prior_zs_ar_con = \
+ LearnableAutoRegressive1Prior(batch_size, hps.co_dim,
+ autocorrelation_taus,
+ noise_variances,
+ hps.do_train_prior_ar_atau,
+ hps.do_train_prior_ar_nvar,
+ num_steps, "u_prior_ar1")
+
+ # CONTROLLER -> GENERATOR -> RATES
+ # (u(t) -> gen(t) -> factors(t) -> rates(t) -> p(x_t|z_t) )
+ self.controller_outputs = u_t = [None] * num_steps
+ self.con_ics = con_state = None
+ self.con_states = con_states = [None] * num_steps
+ self.con_outs = con_outs = [None] * num_steps
+ self.gen_inputs = gen_inputs = [None] * num_steps
+ if co_dim > 0:
+ # gen_cell_class here for l2 penalty recurrent weights
+ # didn't split the cell_weight scale here, because I doubt it matters
+ con_cell = gen_cell_class(hps.con_dim,
+ input_weight_scale=hps.cell_weight_scale,
+ rec_weight_scale=hps.cell_weight_scale,
+ clip_value=hps.cell_clip_value,
+ recurrent_collections=['l2_con_reg'])
+ with tf.variable_scope("con", reuse=False):
+ self.con_ics = tf.tile(
+ tf.Variable(tf.zeros([1, hps.con_dim*con_cell.state_multiplier]),
+ name="c0"),
+ tf.stack([batch_size, 1]))
+ self.con_ics.set_shape([None, con_cell.state_size]) # tile loses shape
+ con_states[-1] = self.con_ics
+
+ gen_cell = gen_cell_class(hps.gen_dim,
+ input_weight_scale=hps.gen_cell_input_weight_scale,
+ rec_weight_scale=hps.gen_cell_rec_weight_scale,
+ clip_value=hps.cell_clip_value,
+ recurrent_collections=['l2_gen_reg'])
+ with tf.variable_scope("gen", reuse=False):
+ if ic_dim == 0:
+ self.gen_ics = tf.tile(
+ tf.Variable(tf.zeros([1, gen_cell.state_size]), name="g0"),
+ tf.stack([batch_size, 1]))
+ else:
+ self.gen_ics = linear(self.g0s_val, gen_cell.state_size,
+ identity_if_possible=True,
+ name="g0_2_gen_ic")
+
+ self.gen_states = gen_states = [None] * num_steps
+ self.gen_outs = gen_outs = [None] * num_steps
+ gen_states[-1] = self.gen_ics
+ gen_outs[-1] = gen_cell.output_from_state(gen_states[-1])
+ self.factors = factors = [None] * num_steps
+ factors[-1] = linear(gen_outs[-1], factors_dim, do_bias=False,
+ normalized=True, name="gen_2_fac")
+
+ self.rates = rates = [None] * num_steps
+ # rates[-1] is collected to potentially feed back to controller
+ with tf.variable_scope("glm", reuse=False):
+ if hps.output_dist == 'poisson':
+ log_rates_t0 = tf.matmul(factors[-1], this_out_fac_W) + this_out_fac_b
+ log_rates_t0.set_shape([None, None])
+ rates[-1] = tf.exp(log_rates_t0) # rate
+ rates[-1].set_shape([None, hps.dataset_dims[hps.dataset_names[0]]])
+ elif hps.output_dist == 'gaussian':
+ mean_n_logvars = tf.matmul(factors[-1],this_out_fac_W) + this_out_fac_b
+ mean_n_logvars.set_shape([None, None])
+ means_t_bxd, logvars_t_bxd = tf.split(axis=1, num_or_size_splits=2,
+ value=mean_n_logvars)
+ rates[-1] = means_t_bxd
+ else:
+ assert False, "NIY"
+
+    # We support multiple output distributions, for example Poisson, and also
+    # Gaussian. In these two cases respectively, there are one and two
+    # parameters (rates vs. mean and variance). So the output_dist_params
+    # tensor will have a variable size via tf.concat and tf.split, along the
+    # 1st dimension. In the gaussian case, for example, it'll be
+    # batch x (D+D), where the first D dims are the means and the next D dims
+    # are the variances. For a distribution with 3 parameters, it would be
+    # batch x (D+D+D).
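+    # Illustrative sketch (names hypothetical): a consumer of the gaussian
+    # parameters for a single time step (a batch x 2D array) would recover the
+    # means and variances with something like
+    #   means_bxd, vars_bxd = np.split(dist_params_bx2d, 2, axis=1)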
+ self.output_dist_params = dist_params = [None] * num_steps
+ self.log_p_xgz_b = log_p_xgz_b = 0.0 # log P(x|z)
+ for t in range(num_steps):
+ # Controller
+ if co_dim > 0:
+ # Build inputs for controller
+ tlag = t - hps.controller_input_lag
+ if tlag < 0:
+ con_in_f_t = tf.zeros_like(ci_enc_fwd[0])
+ else:
+ con_in_f_t = ci_enc_fwd[tlag]
+ if hps.do_causal_controller:
+ # If controller is causal (wrt to data generation process), then it
+ # cannot see future data. Thus, excluding ci_enc_rev[t] is obvious.
+ # Less obvious is the need to exclude factors[t-1]. This arises
+ # because information flows from g0 through factors to the controller
+ # input. The g0 encoding is backwards, so we must necessarily exclude
+ # the factors in order to keep the controller input purely from a
+ # forward encoding (however unlikely it is that
+ # g0->factors->controller channel might actually be used in this way).
+ con_in_list_t = [con_in_f_t]
+ else:
+ tlag_rev = t + hps.controller_input_lag
+ if tlag_rev >= num_steps:
+ # better than zeros
+ con_in_r_t = tf.zeros_like(ci_enc_rev[0])
+ else:
+ con_in_r_t = ci_enc_rev[tlag_rev]
+ con_in_list_t = [con_in_f_t, con_in_r_t]
+
+ if hps.do_feed_factors_to_controller:
+ if hps.feedback_factors_or_rates == "factors":
+ con_in_list_t.append(factors[t-1])
+ elif hps.feedback_factors_or_rates == "rates":
+ con_in_list_t.append(rates[t-1])
+ else:
+ assert False, "NIY"
+
+ con_in_t = tf.concat(axis=1, values=con_in_list_t)
+ con_in_t = tf.nn.dropout(con_in_t, keep_prob)
+ with tf.variable_scope("con", reuse=True if t > 0 else None):
+ con_outs[t], con_states[t] = con_cell(con_in_t, con_states[t-1])
+ posterior_zs_co[t] = \
+ DiagonalGaussianFromInput(con_outs[t], co_dim,
+ name="con_to_post_co")
+ if kind == "train":
+ u_t[t] = posterior_zs_co[t].sample
+ elif kind == "posterior_sample_and_average":
+ u_t[t] = posterior_zs_co[t].sample
+ elif kind == "posterior_push_mean":
+ u_t[t] = posterior_zs_co[t].mean
+ else:
+ u_t[t] = prior_zs_ar_con.samples_t[t]
+
+ # Inputs to the generator (controller output + external input)
+ if ext_input_dim > 0 and hps.inject_ext_input_to_gen:
+ ext_input_t_bxi = ext_input_do_bxtxi[:,t,:]
+ if co_dim > 0:
+ gen_inputs[t] = tf.concat(axis=1, values=[u_t[t], ext_input_t_bxi])
+ else:
+ gen_inputs[t] = ext_input_t_bxi
+ else:
+ gen_inputs[t] = u_t[t]
+
+ # Generator
+ data_t_bxd = dataset_ph[:,t,:]
+ with tf.variable_scope("gen", reuse=True if t > 0 else None):
+ gen_outs[t], gen_states[t] = gen_cell(gen_inputs[t], gen_states[t-1])
+ gen_outs[t] = tf.nn.dropout(gen_outs[t], keep_prob)
+ with tf.variable_scope("gen", reuse=True): # ic defined it above
+ factors[t] = linear(gen_outs[t], factors_dim, do_bias=False,
+ normalized=True, name="gen_2_fac")
+ with tf.variable_scope("glm", reuse=True if t > 0 else None):
+ if hps.output_dist == 'poisson':
+ log_rates_t = tf.matmul(factors[t], this_out_fac_W) + this_out_fac_b
+ log_rates_t.set_shape([None, None])
+ rates[t] = dist_params[t] = tf.exp(tf.clip_by_value(log_rates_t, -hps._clip_value, hps._clip_value)) # rates feed back
+ rates[t].set_shape([None, hps.dataset_dims[hps.dataset_names[0]]])
+ loglikelihood_t = Poisson(log_rates_t).logp(data_t_bxd)
+
+ elif hps.output_dist == 'gaussian':
+ mean_n_logvars = tf.matmul(factors[t],this_out_fac_W) + this_out_fac_b
+ mean_n_logvars.set_shape([None, None])
+ means_t_bxd, logvars_t_bxd = tf.split(axis=1, num_or_size_splits=2,
+ value=mean_n_logvars)
+ rates[t] = means_t_bxd # rates feed back to controller
+ dist_params[t] = tf.concat(
+ axis=1, values=[means_t_bxd, tf.exp(tf.clip_by_value(logvars_t_bxd, -hps._clip_value, hps._clip_value))])
+ loglikelihood_t = \
+ diag_gaussian_log_likelihood(data_t_bxd,
+ means_t_bxd, logvars_t_bxd)
+ else:
+ assert False, "NIY"
+
+ log_p_xgz_b += tf.reduce_sum(loglikelihood_t, [1])
+
+ # Correlation of inferred inputs cost.
+ self.corr_cost = tf.constant(0.0)
+ if hps.co_mean_corr_scale > 0.0:
+ all_sum_corr = []
+ for i in range(hps.co_dim):
+ for j in range(i+1, hps.co_dim):
+ sum_corr_ij = tf.constant(0.0)
+ for t in range(num_steps):
+ u_mean_t = posterior_zs_co[t].mean
+ sum_corr_ij += u_mean_t[:,i]*u_mean_t[:,j]
+ all_sum_corr.append(0.5 * tf.square(sum_corr_ij))
+ self.corr_cost = tf.reduce_mean(all_sum_corr) # div by batch and by n*(n-1)/2 pairs
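+      # In words: for each pair (i, j) of controller dimensions, this
+      # penalizes the squared sum over time of the products of their posterior
+      # means, encouraging the inferred inputs to be decorrelated across
+      # dimensions.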
+
+ # Variational Lower Bound on posterior, p(z|x), plus reconstruction cost.
+ # KL and reconstruction costs are normalized only by batch size, not by
+ # dimension, or by time steps.
+ kl_cost_g0_b = tf.zeros_like(batch_size, dtype=tf.float32)
+ kl_cost_co_b = tf.zeros_like(batch_size, dtype=tf.float32)
+ self.kl_cost = tf.constant(0.0) # VAE KL cost
+ self.recon_cost = tf.constant(0.0) # VAE reconstruction cost
+ self.nll_bound_vae = tf.constant(0.0)
+ self.nll_bound_iwae = tf.constant(0.0) # for eval with IWAE cost.
+ if kind in ["train", "posterior_sample_and_average", "posterior_push_mean"]:
+ kl_cost_g0_b = 0.0
+ kl_cost_co_b = 0.0
+ if ic_dim > 0:
+ g0_priors = [self.prior_zs_g0]
+ g0_posts = [self.posterior_zs_g0]
+ kl_cost_g0_b = KLCost_GaussianGaussian(g0_posts, g0_priors).kl_cost_b
+ kl_cost_g0_b = hps.kl_ic_weight * kl_cost_g0_b
+ if co_dim > 0:
+ kl_cost_co_b = \
+ KLCost_GaussianGaussianProcessSampled(
+ posterior_zs_co, prior_zs_ar_con).kl_cost_b
+ kl_cost_co_b = hps.kl_co_weight * kl_cost_co_b
+
+ # L = -KL + log p(x|z), to maximize bound on likelihood
+ # -L = KL - log p(x|z), to minimize bound on NLL
+ # so 'reconstruction cost' is negative log likelihood
+ self.recon_cost = - tf.reduce_mean(log_p_xgz_b)
+ self.kl_cost = tf.reduce_mean(kl_cost_g0_b + kl_cost_co_b)
+
+ lb_on_ll_b = log_p_xgz_b - kl_cost_g0_b - kl_cost_co_b
+
+ # VAE error averages outside the log
+ self.nll_bound_vae = -tf.reduce_mean(lb_on_ll_b)
+
+ # IWAE error averages inside the log
+ k = tf.cast(tf.shape(log_p_xgz_b)[0], tf.float32)
+ iwae_lb_on_ll = -tf.log(k) + log_sum_exp(lb_on_ll_b)
+ self.nll_bound_iwae = -iwae_lb_on_ll
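+      # The two bounds differ only in where the average over the batch is
+      # taken: the VAE bound averages the per-example lower bounds outside the
+      # log, while the IWAE bound treats the batch as k importance samples:
+      #   nll_bound_iwae = -log((1/k) * sum_i exp(lb_i))
+      #                  = log(k) - log_sum_exp(lb),
+      # which is what the two lines above compute.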
+
+ # L2 regularization on the generator, normalized by number of parameters.
+ self.l2_cost = tf.constant(0.0)
+ if self.hps.l2_gen_scale > 0.0 or self.hps.l2_con_scale > 0.0:
+ l2_costs = []
+ l2_numels = []
+ l2_reg_var_lists = [tf.get_collection('l2_gen_reg'),
+ tf.get_collection('l2_con_reg')]
+ l2_reg_scales = [self.hps.l2_gen_scale, self.hps.l2_con_scale]
+ for l2_reg_vars, l2_scale in zip(l2_reg_var_lists, l2_reg_scales):
+ for v in l2_reg_vars:
+ numel = tf.reduce_prod(tf.concat(axis=0, values=tf.shape(v)))
+ numel_f = tf.cast(numel, tf.float32)
+ l2_numels.append(numel_f)
+ v_l2 = tf.reduce_sum(v*v)
+ l2_costs.append(0.5 * l2_scale * v_l2)
+ self.l2_cost = tf.add_n(l2_costs) / tf.add_n(l2_numels)
+
+ # Compute the cost for training, part of the graph regardless.
+ # The KL cost can be problematic at the beginning of optimization,
+ # so we allow an exponential increase in weighting the KL from 0
+ # to 1.
+ self.kl_decay_step = tf.maximum(self.train_step - hps.kl_start_step, 0)
+ self.l2_decay_step = tf.maximum(self.train_step - hps.l2_start_step, 0)
+ kl_decay_step_f = tf.cast(self.kl_decay_step, tf.float32)
+ l2_decay_step_f = tf.cast(self.l2_decay_step, tf.float32)
+ kl_increase_steps_f = tf.cast(hps.kl_increase_steps, tf.float32)
+ l2_increase_steps_f = tf.cast(hps.l2_increase_steps, tf.float32)
+ self.kl_weight = kl_weight = \
+ tf.minimum(kl_decay_step_f / kl_increase_steps_f, 1.0)
+ self.l2_weight = l2_weight = \
+ tf.minimum(l2_decay_step_f / l2_increase_steps_f, 1.0)
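+    # Worked example with illustrative numbers: if kl_start_step = 0 and
+    # kl_increase_steps = 2000, then at train_step 500 the KL weight is
+    # min(500 / 2000, 1.0) = 0.25, and it saturates at 1.0 from step 2000 on.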
+
+ self.timed_kl_cost = kl_weight * self.kl_cost
+ self.timed_l2_cost = l2_weight * self.l2_cost
+ self.weight_corr_cost = hps.co_mean_corr_scale * self.corr_cost
+ self.cost = self.recon_cost + self.timed_kl_cost + \
+ self.timed_l2_cost + self.weight_corr_cost
+
+ if kind != "train":
+ # save every so often
+ self.seso_saver = tf.train.Saver(tf.global_variables(),
+ max_to_keep=hps.max_ckpt_to_keep)
+ # lowest validation error
+ self.lve_saver = tf.train.Saver(tf.global_variables(),
+ max_to_keep=hps.max_ckpt_to_keep_lve)
+
+ return
+
+ # OPTIMIZATION
+ # train the io matrices only
+ if self.hps.do_train_io_only:
+ self.train_vars = tvars = \
+ tf.get_collection('IO_transformations',
+ scope=tf.get_variable_scope().name)
+ # train the encoder only
+ elif self.hps.do_train_encoder_only:
+ tvars1 = \
+ tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
+ scope='LFADS/ic_enc_*')
+ tvars2 = \
+ tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
+ scope='LFADS/z/ic_enc_*')
+
+ self.train_vars = tvars = tvars1 + tvars2
+ # train all variables
+ else:
+ self.train_vars = tvars = \
+ tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
+ scope=tf.get_variable_scope().name)
+ print("done.")
+ print("Model Variables (to be optimized): ")
+ total_params = 0
+ for i in range(len(tvars)):
+ shape = tvars[i].get_shape().as_list()
+ print(" ", i, tvars[i].name, shape)
+ total_params += np.prod(shape)
+ print("Total model parameters: ", total_params)
+
+ grads = tf.gradients(self.cost, tvars)
+ grads, grad_global_norm = tf.clip_by_global_norm(grads, hps.max_grad_norm)
+ opt = tf.train.AdamOptimizer(self.learning_rate, beta1=0.9, beta2=0.999,
+ epsilon=1e-01)
+ self.grads = grads
+ self.grad_global_norm = grad_global_norm
+ self.train_op = opt.apply_gradients(
+ zip(grads, tvars), global_step=self.train_step)
+
+ self.seso_saver = tf.train.Saver(tf.global_variables(),
+ max_to_keep=hps.max_ckpt_to_keep)
+
+ # lowest validation error
+ self.lve_saver = tf.train.Saver(tf.global_variables(),
+ max_to_keep=hps.max_ckpt_to_keep)
+
+ # SUMMARIES, used only during training.
+ # example summary
+ self.example_image = tf.placeholder(tf.float32, shape=[1,None,None,3],
+ name='image_tensor')
+ self.example_summ = tf.summary.image("LFADS example", self.example_image,
+ collections=["example_summaries"])
+
+ # general training summaries
+ self.lr_summ = tf.summary.scalar("Learning rate", self.learning_rate)
+ self.kl_weight_summ = tf.summary.scalar("KL weight", self.kl_weight)
+ self.l2_weight_summ = tf.summary.scalar("L2 weight", self.l2_weight)
+ self.corr_cost_summ = tf.summary.scalar("Corr cost", self.weight_corr_cost)
+ self.grad_global_norm_summ = tf.summary.scalar("Gradient global norm",
+ self.grad_global_norm)
+ if hps.co_dim > 0:
+ self.atau_summ = [None] * hps.co_dim
+ self.pvar_summ = [None] * hps.co_dim
+ for c in range(hps.co_dim):
+ self.atau_summ[c] = \
+ tf.summary.scalar("AR Autocorrelation taus " + str(c),
+ tf.exp(self.prior_zs_ar_con.logataus_1xu[0,c]))
+ self.pvar_summ[c] = \
+ tf.summary.scalar("AR Variances " + str(c),
+ tf.exp(self.prior_zs_ar_con.logpvars_1xu[0,c]))
+
+ # cost summaries, separated into different collections for
+ # training vs validation. We make placeholders for these, because
+ # even though the graph computes these costs on a per-batch basis,
+ # we want to report the more reliable metric of per-epoch cost.
+ kl_cost_ph = tf.placeholder(tf.float32, shape=[], name='kl_cost_ph')
+ self.kl_t_cost_summ = tf.summary.scalar("KL cost (train)", kl_cost_ph,
+ collections=["train_summaries"])
+ self.kl_v_cost_summ = tf.summary.scalar("KL cost (valid)", kl_cost_ph,
+ collections=["valid_summaries"])
+ l2_cost_ph = tf.placeholder(tf.float32, shape=[], name='l2_cost_ph')
+ self.l2_cost_summ = tf.summary.scalar("L2 cost", l2_cost_ph,
+ collections=["train_summaries"])
+
+ recon_cost_ph = tf.placeholder(tf.float32, shape=[], name='recon_cost_ph')
+ self.recon_t_cost_summ = tf.summary.scalar("Reconstruction cost (train)",
+ recon_cost_ph,
+ collections=["train_summaries"])
+ self.recon_v_cost_summ = tf.summary.scalar("Reconstruction cost (valid)",
+ recon_cost_ph,
+ collections=["valid_summaries"])
+
+ total_cost_ph = tf.placeholder(tf.float32, shape=[], name='total_cost_ph')
+ self.cost_t_summ = tf.summary.scalar("Total cost (train)", total_cost_ph,
+ collections=["train_summaries"])
+ self.cost_v_summ = tf.summary.scalar("Total cost (valid)", total_cost_ph,
+ collections=["valid_summaries"])
+
+ self.kl_cost_ph = kl_cost_ph
+ self.l2_cost_ph = l2_cost_ph
+ self.recon_cost_ph = recon_cost_ph
+ self.total_cost_ph = total_cost_ph
+
+ # Merged summaries, for easy coding later.
+ self.merged_examples = tf.summary.merge_all(key="example_summaries")
+ self.merged_generic = tf.summary.merge_all() # default key is 'summaries'
+ self.merged_train = tf.summary.merge_all(key="train_summaries")
+ self.merged_valid = tf.summary.merge_all(key="valid_summaries")
+
+ session = tf.get_default_session()
+ self.logfile = os.path.join(hps.lfads_save_dir, "lfads_log")
+ self.writer = tf.summary.FileWriter(self.logfile)
+
+ def build_feed_dict(self, train_name, data_bxtxd, ext_input_bxtxi=None,
+ keep_prob=None):
+ """Build the feed dictionary, handles cases where there is no value defined.
+
+ Args:
+ train_name: The key into the datasets, to set the tf.case statement for
+ the proper readin / readout matrices.
+ data_bxtxd: The data tensor
+ ext_input_bxtxi (optional): The external input tensor
+ keep_prob: The drop out keep probability.
+
+ Returns:
+ The feed dictionary with TF tensors as keys and data as values, for use
+ with tf.Session.run()
+
+ """
+ feed_dict = {}
+ B, T, _ = data_bxtxd.shape
+ feed_dict[self.dataName] = train_name
+ feed_dict[self.dataset_ph] = data_bxtxd
+
+ if self.ext_input is not None and ext_input_bxtxi is not None:
+ feed_dict[self.ext_input] = ext_input_bxtxi
+
+ if keep_prob is None:
+ feed_dict[self.keep_prob] = self.hps.keep_prob
+ else:
+ feed_dict[self.keep_prob] = keep_prob
+
+ return feed_dict
+
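+  # Usage sketch for build_feed_dict above (names illustrative, assuming a
+  # default session and a dataset registered under the name 'dataset_0'):
+  #   fd = model.build_feed_dict('dataset_0', data_bxtxd, keep_prob=1.0)
+  #   cost_val = tf.get_default_session().run(model.cost, feed_dict=fd)
+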
+ @staticmethod
+ def get_batch(data_extxd, ext_input_extxi=None, batch_size=None,
+ example_idxs=None):
+ """Get a batch of data, either randomly chosen, or specified directly.
+
+ Args:
+ data_extxd: The data to model, numpy tensors with shape:
+ # examples x # time steps x # dimensions
+ ext_input_extxi (optional): The external inputs, numpy tensor with shape:
+ # examples x # time steps x # external input dimensions
+ batch_size: The size of the batch to return
+ example_idxs (optional): The example indices used to select examples.
+
+ Returns:
+ A tuple with two parts:
+ 1. Batched data numpy tensor with shape:
+ batch_size x # time steps x # dimensions
+ 2. Batched external input numpy tensor with shape:
+ batch_size x # time steps x # external input dims
+ """
+ assert batch_size is not None or example_idxs is not None, "Problems"
+ E, T, D = data_extxd.shape
+ if example_idxs is None:
+ example_idxs = np.random.choice(E, batch_size)
+
+ ext_input_bxtxi = None
+ if ext_input_extxi is not None:
+ ext_input_bxtxi = ext_input_extxi[example_idxs,:,:]
+
+ return data_extxd[example_idxs,:,:], ext_input_bxtxi
+
+ @staticmethod
+ def example_idxs_mod_batch_size(nexamples, batch_size):
+ """Given a number of examples, E, and a batch_size, B, generate indices
+ [0, 1, 2, ... B-1;
+ [B, B+1, ... 2*B-1;
+ ...
+ ]
+ returning those indices as a 2-dim tensor shaped like E/B x B. Note that
+ shape is only correct if E % B == 0. If not, then an extra row is generated
+    so that the remainder of examples is included. The extra slots in that row
+    are filled with randomly drawn example indices (see
+    randomize_example_idxs_mod_batch_size for fully randomized behavior).
+
+ Args:
+ nexamples: The number of examples to batch up.
+ batch_size: The size of the batch.
+ Returns:
+ 2-dim tensor as described above.
+ """
+ bmrem = batch_size - (nexamples % batch_size)
+ bmrem_examples = []
+ if bmrem < batch_size:
+ #bmrem_examples = np.zeros(bmrem, dtype=np.int32)
+ ridxs = np.random.permutation(nexamples)[0:bmrem].astype(np.int32)
+ bmrem_examples = np.sort(ridxs)
+    example_idxs = list(range(nexamples)) + list(bmrem_examples)
+ example_idxs_e_x_edivb = np.reshape(example_idxs, [-1, batch_size])
+ return example_idxs_e_x_edivb, bmrem
+
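+  # Worked example for example_idxs_mod_batch_size above (illustrative
+  # numbers): with nexamples = 10 and batch_size = 4, bmrem = 4 - (10 % 4) = 2,
+  # so two randomly chosen example indices are appended to range(10) and the
+  # resulting 12 indices are reshaped to a 3 x 4 array; the method returns
+  # (that array, 2).
+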
+ @staticmethod
+ def randomize_example_idxs_mod_batch_size(nexamples, batch_size):
+ """Indices 1:nexamples, randomized, in 2D form of
+ shape = (nexamples / batch_size) x batch_size. The remainder
+ is managed by drawing randomly from 1:nexamples.
+
+ Args:
+ nexamples: number of examples to randomize
+ batch_size: number of elements in batch
+
+ Returns:
+      The randomized, properly shaped indices.
+ """
+ assert nexamples > batch_size, "Problems"
+ bmrem = batch_size - nexamples % batch_size
+ bmrem_examples = []
+ if bmrem < batch_size:
+ bmrem_examples = np.random.choice(range(nexamples),
+ size=bmrem, replace=False)
+    example_idxs = list(range(nexamples)) + list(bmrem_examples)
+ mixed_example_idxs = np.random.permutation(example_idxs)
+ example_idxs_e_x_edivb = np.reshape(mixed_example_idxs, [-1, batch_size])
+ return example_idxs_e_x_edivb, bmrem
+
+ def shuffle_spikes_in_time(self, data_bxtxd):
+ """Shuffle the spikes in the temporal dimension. This is useful to
+ help the LFADS system avoid overfitting to individual spikes or fast
+ oscillations found in the data that are irrelevant to behavior. A
+ pure 'tabula rasa' approach would avoid this, but LFADS is sensitive
+ enough to pick up dynamics that you may not want.
+
+ Args:
+ data_bxtxd: numpy array of spike count data to be shuffled.
+ Returns:
+ S_bxtxd, a numpy array with the same dimensions and contents as
+ data_bxtxd, but shuffled appropriately.
+
+ """
+
+ B, T, N = data_bxtxd.shape
+ w = self.hps.temporal_spike_jitter_width
+
+ if w == 0:
+ return data_bxtxd
+
+ max_counts = np.max(data_bxtxd)
+ S_bxtxd = np.zeros([B,T,N])
+
+    # Intuitively, shuffle spike occurrences, 0 or 1, but since we have counts,
+    # do it over and over again up to the max count.
+ for mc in range(1,max_counts+1):
+ idxs = np.nonzero(data_bxtxd >= mc)
+
+ data_ones = np.zeros_like(data_bxtxd)
+ data_ones[data_bxtxd >= mc] = 1
+
+ nfound = len(idxs[0])
+ shuffles_incrs_in_time = np.random.randint(-w, w, size=nfound)
+
+ shuffle_tidxs = idxs[1].copy()
+ shuffle_tidxs += shuffles_incrs_in_time
+
+ # Reflect on the boundaries to not lose mass.
+ shuffle_tidxs[shuffle_tidxs < 0] = -shuffle_tidxs[shuffle_tidxs < 0]
+ shuffle_tidxs[shuffle_tidxs > T-1] = \
+ (T-1)-(shuffle_tidxs[shuffle_tidxs > T-1] -(T-1))
+
+ for iii in zip(idxs[0], shuffle_tidxs, idxs[2]):
+ S_bxtxd[iii] += 1
+
+ return S_bxtxd
+
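+  # Example of the jitter above (illustrative): with
+  # temporal_spike_jitter_width w = 2, each spike at time t is moved to a
+  # uniformly random bin in [t - 2, t + 2), and moves past either edge are
+  # reflected back into the valid range [0, T - 1].
+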
+ def shuffle_and_flatten_datasets(self, datasets, kind='train'):
+ """Since LFADS supports multiple datasets in the same dynamical model,
+ we have to be careful to use all the data in a single training epoch. But
+    since the datasets may have different data dimensionality, we cannot batch
+ examples from data dictionaries together. Instead, we generate random
+ batches within each data dictionary, and then randomize these batches
+ while holding onto the dataname, so that when it's time to feed
+ the graph, the correct in/out matrices can be selected, per batch.
+
+ Args:
+ datasets: A dict of data dicts. The dataset dict is simply a
+ name(string)-> data dictionary mapping (See top of lfads.py).
+ kind: 'train' or 'valid'
+
+ Returns:
+ A flat list, in which each element is a pair ('name', indices).
+ """
+ batch_size = self.hps.batch_size
+ ndatasets = len(datasets)
+ random_example_idxs = {}
+ epoch_idxs = {}
+ all_name_example_idx_pairs = []
+ kind_data = kind + '_data'
+ for name, data_dict in datasets.items():
+ nexamples, ntime, data_dim = data_dict[kind_data].shape
+ epoch_idxs[name] = 0
+ random_example_idxs, _ = \
+ self.randomize_example_idxs_mod_batch_size(nexamples, batch_size)
+
+ epoch_size = random_example_idxs.shape[0]
+ names = [name] * epoch_size
+ all_name_example_idx_pairs += zip(names, random_example_idxs)
+
+ np.random.shuffle(all_name_example_idx_pairs) # shuffle in place
+
+ return all_name_example_idx_pairs
+
+ def train_epoch(self, datasets, batch_size=None, do_save_ckpt=True):
+ """Train the model through the entire dataset once.
+
+ Args:
+ datasets: A dict of data dicts. The dataset dict is simply a
+ name(string)-> data dictionary mapping (See top of lfads.py).
+ batch_size (optional): The batch_size to use
+ do_save_ckpt (optional): Should the routine save a checkpoint on this
+ training epoch?
+
+ Returns:
+ A tuple with 6 float values:
+ (total cost of the epoch, epoch reconstruction cost,
+ epoch kl cost, KL weight used this training epoch,
+ total l2 cost on generator, and the corresponding weight).
+ """
+ ops_to_eval = [self.cost, self.recon_cost,
+ self.kl_cost, self.kl_weight,
+ self.l2_cost, self.l2_weight,
+ self.train_op]
+ collected_op_values = self.run_epoch(datasets, ops_to_eval, kind="train")
+
+ total_cost = total_recon_cost = total_kl_cost = 0.0
+ # normalizing by batch done in distributions.py
+ epoch_size = len(collected_op_values)
+ for op_values in collected_op_values:
+ total_cost += op_values[0]
+ total_recon_cost += op_values[1]
+ total_kl_cost += op_values[2]
+
+ kl_weight = collected_op_values[-1][3]
+ l2_cost = collected_op_values[-1][4]
+ l2_weight = collected_op_values[-1][5]
+
+ epoch_total_cost = total_cost / epoch_size
+ epoch_recon_cost = total_recon_cost / epoch_size
+ epoch_kl_cost = total_kl_cost / epoch_size
+
+ if do_save_ckpt:
+ session = tf.get_default_session()
+ checkpoint_path = os.path.join(self.hps.lfads_save_dir,
+ self.hps.checkpoint_name + '.ckpt')
+ self.seso_saver.save(session, checkpoint_path,
+ global_step=self.train_step)
+
+ return epoch_total_cost, epoch_recon_cost, epoch_kl_cost, \
+ kl_weight, l2_cost, l2_weight
+
+
+ def run_epoch(self, datasets, ops_to_eval, kind="train", batch_size=None,
+ do_collect=True, keep_prob=None):
+ """Run the model through the entire dataset once.
+
+ Args:
+ datasets: A dict of data dicts. The dataset dict is simply a
+ name(string)-> data dictionary mapping (See top of lfads.py).
+ ops_to_eval: A list of tensorflow operations that will be evaluated in
+ the tf.session.run() call.
+ batch_size (optional): The batch_size to use
+ do_collect (optional): Should the routine collect all session.run
+ output as a list, and return it?
+ keep_prob (optional): The dropout keep probability.
+
+ Returns:
+ A list of lists, the internal list is the return for the ops for each
+ session.run() call. The outer list collects over the epoch.
+ """
+ hps = self.hps
+ all_name_example_idx_pairs = \
+ self.shuffle_and_flatten_datasets(datasets, kind)
+
+ kind_data = kind + '_data'
+ kind_ext_input = kind + '_ext_input'
+
+ total_cost = total_recon_cost = total_kl_cost = 0.0
+ session = tf.get_default_session()
+ epoch_size = len(all_name_example_idx_pairs)
+ evaled_ops_list = []
+ for name, example_idxs in all_name_example_idx_pairs:
+ data_dict = datasets[name]
+ data_extxd = data_dict[kind_data]
+ if hps.output_dist == 'poisson' and hps.temporal_spike_jitter_width > 0:
+ data_extxd = self.shuffle_spikes_in_time(data_extxd)
+
+ ext_input_extxi = data_dict[kind_ext_input]
+ data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd, ext_input_extxi,
+ example_idxs=example_idxs)
+
+ feed_dict = self.build_feed_dict(name, data_bxtxd, ext_input_bxtxi,
+ keep_prob=keep_prob)
+ evaled_ops_np = session.run(ops_to_eval, feed_dict=feed_dict)
+ if do_collect:
+ evaled_ops_list.append(evaled_ops_np)
+
+ return evaled_ops_list
+
+ def summarize_all(self, datasets, summary_values):
+ """Plot and summarize stuff in tensorboard.
+
+ Note that everything done in the current function is otherwise done on
+ a single, randomly selected dataset (except for summary_values, which are
+ passed in.)
+
+ Args:
+      datasets: the dictionary of datasets used in the study.
+ summary_values: These summary values are created from the training loop,
+ and so summarize the entire set of datasets.
+ """
+ hps = self.hps
+ tr_kl_cost = summary_values['tr_kl_cost']
+ tr_recon_cost = summary_values['tr_recon_cost']
+ tr_total_cost = summary_values['tr_total_cost']
+ kl_weight = summary_values['kl_weight']
+ l2_weight = summary_values['l2_weight']
+ l2_cost = summary_values['l2_cost']
+ has_any_valid_set = summary_values['has_any_valid_set']
+ i = summary_values['nepochs']
+
+ session = tf.get_default_session()
+ train_summ, train_step = session.run([self.merged_train,
+ self.train_step],
+ feed_dict={self.l2_cost_ph:l2_cost,
+ self.kl_cost_ph:tr_kl_cost,
+ self.recon_cost_ph:tr_recon_cost,
+ self.total_cost_ph:tr_total_cost})
+ self.writer.add_summary(train_summ, train_step)
+ if has_any_valid_set:
+ ev_kl_cost = summary_values['ev_kl_cost']
+ ev_recon_cost = summary_values['ev_recon_cost']
+ ev_total_cost = summary_values['ev_total_cost']
+ eval_summ = session.run(self.merged_valid,
+ feed_dict={self.kl_cost_ph:ev_kl_cost,
+ self.recon_cost_ph:ev_recon_cost,
+ self.total_cost_ph:ev_total_cost})
+ self.writer.add_summary(eval_summ, train_step)
+ print("Epoch:%d, step:%d (TRAIN, VALID): total: %.2f, %.2f\
+ recon: %.2f, %.2f, kl: %.2f, %.2f, l2: %.5f,\
+ kl weight: %.2f, l2 weight: %.2f" % \
+ (i, train_step, tr_total_cost, ev_total_cost,
+ tr_recon_cost, ev_recon_cost, tr_kl_cost, ev_kl_cost,
+ l2_cost, kl_weight, l2_weight))
+
+ csv_outstr = "epoch,%d, step,%d, total,%.2f,%.2f, \
+ recon,%.2f,%.2f, kl,%.2f,%.2f, l2,%.5f, \
+ klweight,%.2f, l2weight,%.2f\n"% \
+ (i, train_step, tr_total_cost, ev_total_cost,
+ tr_recon_cost, ev_recon_cost, tr_kl_cost, ev_kl_cost,
+ l2_cost, kl_weight, l2_weight)
+
+ else:
+ print("Epoch:%d, step:%d TRAIN: total: %.2f recon: %.2f, kl: %.2f,\
+ l2: %.5f, kl weight: %.2f, l2 weight: %.2f" % \
+ (i, train_step, tr_total_cost, tr_recon_cost, tr_kl_cost,
+ l2_cost, kl_weight, l2_weight))
+ csv_outstr = "epoch,%d, step,%d, total,%.2f, recon,%.2f, kl,%.2f, \
+ l2,%.5f, klweight,%.2f, l2weight,%.2f\n"% \
+ (i, train_step, tr_total_cost, tr_recon_cost,
+ tr_kl_cost, l2_cost, kl_weight, l2_weight)
+
+ if self.hps.csv_log:
+ csv_file = os.path.join(self.hps.lfads_save_dir, self.hps.csv_log+'.csv')
+ with open(csv_file, "a") as myfile:
+ myfile.write(csv_outstr)
+
+
+ def plot_single_example(self, datasets):
+ """Plot an image relating to a randomly chosen, specific example. We use
+ posterior sample and average by taking one example, and filling a whole
+ batch with that example, sample from the posterior, and then average the
+ quantities.
+
+ """
+ hps = self.hps
+    all_data_names = list(datasets.keys())
+ data_name = np.random.permutation(all_data_names)[0]
+ data_dict = datasets[data_name]
+ has_valid_set = True if data_dict['valid_data'] is not None else False
+ cf = 1.0 # plotting concern
+
+ # posterior sample and average here
+ E, _, _ = data_dict['train_data'].shape
+ eidx = np.random.choice(E)
+ example_idxs = eidx * np.ones(hps.batch_size, dtype=np.int32)
+
+ train_data_bxtxd, train_ext_input_bxtxi = \
+ self.get_batch(data_dict['train_data'], data_dict['train_ext_input'],
+ example_idxs=example_idxs)
+
+ truth_train_data_bxtxd = None
+ if 'train_truth' in data_dict and data_dict['train_truth'] is not None:
+ truth_train_data_bxtxd, _ = self.get_batch(data_dict['train_truth'],
+ example_idxs=example_idxs)
+ cf = data_dict['conversion_factor']
+
+ # plotter does averaging
+ train_model_values = self.eval_model_runs_batch(data_name,
+ train_data_bxtxd,
+ train_ext_input_bxtxi,
+ do_average_batch=False)
+
+ train_step = train_model_values['train_steps']
+ feed_dict = self.build_feed_dict(data_name, train_data_bxtxd,
+ train_ext_input_bxtxi, keep_prob=1.0)
+
+ session = tf.get_default_session()
+ generic_summ = session.run(self.merged_generic, feed_dict=feed_dict)
+ self.writer.add_summary(generic_summ, train_step)
+
+ valid_data_bxtxd = valid_model_values = valid_ext_input_bxtxi = None
+ truth_valid_data_bxtxd = None
+ if has_valid_set:
+ E, _, _ = data_dict['valid_data'].shape
+ eidx = np.random.choice(E)
+ example_idxs = eidx * np.ones(hps.batch_size, dtype=np.int32)
+ valid_data_bxtxd, valid_ext_input_bxtxi = \
+ self.get_batch(data_dict['valid_data'],
+ data_dict['valid_ext_input'],
+ example_idxs=example_idxs)
+ if 'valid_truth' in data_dict and data_dict['valid_truth'] is not None:
+ truth_valid_data_bxtxd, _ = self.get_batch(data_dict['valid_truth'],
+ example_idxs=example_idxs)
+ else:
+ truth_valid_data_bxtxd = None
+
+ # plotter does averaging
+ valid_model_values = self.eval_model_runs_batch(data_name,
+ valid_data_bxtxd,
+ valid_ext_input_bxtxi,
+ do_average_batch=False)
+
+ example_image = plot_lfads(train_bxtxd=train_data_bxtxd,
+ train_model_vals=train_model_values,
+ train_ext_input_bxtxi=train_ext_input_bxtxi,
+ train_truth_bxtxd=truth_train_data_bxtxd,
+ valid_bxtxd=valid_data_bxtxd,
+ valid_model_vals=valid_model_values,
+ valid_ext_input_bxtxi=valid_ext_input_bxtxi,
+ valid_truth_bxtxd=truth_valid_data_bxtxd,
+ bidx=None, cf=cf, output_dist=hps.output_dist)
+ example_image = np.expand_dims(example_image, axis=0)
+ example_summ = session.run(self.merged_examples,
+ feed_dict={self.example_image : example_image})
+ self.writer.add_summary(example_summ)
+
+ def train_model(self, datasets):
+ """Train the model, print per-epoch information, and save checkpoints.
+
+ Loop over training epochs. The function that actually does the
+ training is train_epoch. This function iterates over the training
+ data, one epoch at a time. The learning rate schedule is such
+ that it will stay the same until the cost goes up in comparison to
+ the last few values, then it will drop.
+
+ Args:
+ datasets: A dict of data dicts. The dataset dict is simply a
+ name(string)-> data dictionary mapping (See top of lfads.py).
+ """
+ hps = self.hps
+ has_any_valid_set = False
+ for data_dict in datasets.values():
+ if data_dict['valid_data'] is not None:
+ has_any_valid_set = True
+ break
+
+ session = tf.get_default_session()
+ lr = session.run(self.learning_rate)
+ lr_stop = hps.learning_rate_stop
+ i = -1
+ train_costs = []
+ valid_costs = []
+ ev_total_cost = ev_recon_cost = ev_kl_cost = 0.0
+ lowest_ev_cost = np.Inf
+ while True:
+ i += 1
+ do_save_ckpt = True if i % 10 ==0 else False
+ tr_total_cost, tr_recon_cost, tr_kl_cost, kl_weight, l2_cost, l2_weight = \
+ self.train_epoch(datasets, do_save_ckpt=do_save_ckpt)
+
+ # Evaluate the validation cost, and potentially save. Note that this
+ # routine will not save a validation checkpoint until the kl weight and
+ # l2 weights are equal to 1.0.
+ if has_any_valid_set:
+ ev_total_cost, ev_recon_cost, ev_kl_cost = \
+ self.eval_cost_epoch(datasets, kind='valid')
+ valid_costs.append(ev_total_cost)
+
+ # > 1 may give more consistent results, but not the actual lowest vae.
+ # == 1 gives the lowest vae seen so far.
+ n_lve = 1
+ run_avg_lve = np.mean(valid_costs[-n_lve:])
+
+ # conditions for saving checkpoints:
+ # KL weight must have finished stepping (>=1.0), AND
+ # L2 weight must have finished stepping OR L2 is not being used, AND
+ # the current run has a lower LVE than previous runs AND
+        #   len(valid_costs) > n_lve (i.e. enough validation costs have been
+        #   collected to form the running average)
+ if kl_weight >= 1.0 and \
+ (l2_weight >= 1.0 or \
+ (self.hps.l2_gen_scale == 0.0 and self.hps.l2_con_scale == 0.0)) \
+ and (len(valid_costs) > n_lve and run_avg_lve < lowest_ev_cost):
+
+ lowest_ev_cost = run_avg_lve
+ checkpoint_path = os.path.join(self.hps.lfads_save_dir,
+ self.hps.checkpoint_name + '_lve.ckpt')
+ self.lve_saver.save(session, checkpoint_path,
+ global_step=self.train_step,
+ latest_filename='checkpoint_lve')
+
+ # Plot and summarize.
+ values = {'nepochs':i, 'has_any_valid_set': has_any_valid_set,
+ 'tr_total_cost':tr_total_cost, 'ev_total_cost':ev_total_cost,
+ 'tr_recon_cost':tr_recon_cost, 'ev_recon_cost':ev_recon_cost,
+ 'tr_kl_cost':tr_kl_cost, 'ev_kl_cost':ev_kl_cost,
+ 'l2_weight':l2_weight, 'kl_weight':kl_weight,
+ 'l2_cost':l2_cost}
+ self.summarize_all(datasets, values)
+ self.plot_single_example(datasets)
+
+ # Manage learning rate.
+ train_res = tr_total_cost
+ n_lr = hps.learning_rate_n_to_compare
+ if len(train_costs) > n_lr and train_res > np.max(train_costs[-n_lr:]):
+ _ = session.run(self.learning_rate_decay_op)
+ lr = session.run(self.learning_rate)
+ print(" Decreasing learning rate to %f." % lr)
+ # Force the system to run n_lr times while at this lr.
+ train_costs.append(np.inf)
+ else:
+ train_costs.append(train_res)
+
+ if lr < lr_stop:
+ print("Stopping optimization based on learning rate criteria.")
+ break
+
+ def eval_cost_epoch(self, datasets, kind='train', ext_input_extxi=None,
+ batch_size=None):
+ """Evaluate the cost of the epoch.
+
+ Args:
+      datasets: The dictionary of datasets (training and validation) used for
+        training and evaluation of the model, respectively.
+
+ Returns:
+ a 3 tuple of costs:
+ (epoch total cost, epoch reconstruction cost, epoch KL cost)
+ """
+ ops_to_eval = [self.cost, self.recon_cost, self.kl_cost]
+ collected_op_values = self.run_epoch(datasets, ops_to_eval, kind=kind,
+ keep_prob=1.0)
+
+ total_cost = total_recon_cost = total_kl_cost = 0.0
+ # normalizing by batch done in distributions.py
+ epoch_size = len(collected_op_values)
+ for op_values in collected_op_values:
+ total_cost += op_values[0]
+ total_recon_cost += op_values[1]
+ total_kl_cost += op_values[2]
+
+ epoch_total_cost = total_cost / epoch_size
+ epoch_recon_cost = total_recon_cost / epoch_size
+ epoch_kl_cost = total_kl_cost / epoch_size
+
+ return epoch_total_cost, epoch_recon_cost, epoch_kl_cost
+
+ def eval_model_runs_batch(self, data_name, data_bxtxd, ext_input_bxtxi=None,
+ do_eval_cost=False, do_average_batch=False):
+ """Returns all the goodies for the entire model, per batch.
+
+    data_bxtxd and ext_input_bxtxi can have fewer than batch_size examples
+    along dim 0, in which case this routine handles the padding and truncating
+    automatically.
+
+ Args:
+ data_name: The name of the data dict, to select which in/out matrices
+ to use.
+ data_bxtxd: Numpy array training data with shape:
+ batch_size x # time steps x # dimensions
+ ext_input_bxtxi: Numpy array training external input with shape:
+ batch_size x # time steps x # external input dims
+      do_eval_cost (optional): If true, evaluate the IWAE (Importance Weighted
+        Autoencoder) log likelihood bound instead of the VAE version.
+ do_average_batch (optional): average over the batch, useful for getting
+ good IWAE costs, and model outputs for a single data point.
+
+ Returns:
+ A dictionary with the outputs of the model decoder, namely:
+      prior g0 mean, prior g0 variance, approx. posterior mean, approx.
+      posterior variance, the generator initial conditions, the control inputs (if
+ enabled), the state of the generator, the factors, and the rates.
+ """
+ session = tf.get_default_session()
+
+ # if fewer than batch_size provided, pad to batch_size
+ hps = self.hps
+ batch_size = hps.batch_size
+ E, _, _ = data_bxtxd.shape
+ if E < hps.batch_size:
+ data_bxtxd = np.pad(data_bxtxd, ((0, hps.batch_size-E), (0, 0), (0, 0)),
+ mode='constant', constant_values=0)
+ if ext_input_bxtxi is not None:
+ ext_input_bxtxi = np.pad(ext_input_bxtxi,
+ ((0, hps.batch_size-E), (0, 0), (0, 0)),
+ mode='constant', constant_values=0)
+
+ feed_dict = self.build_feed_dict(data_name, data_bxtxd,
+ ext_input_bxtxi, keep_prob=1.0)
+
+ # Non-temporal signals will be batch x dim.
+ # Temporal signals are list length T with elements batch x dim.
+ tf_vals = [self.gen_ics, self.gen_states, self.factors,
+ self.output_dist_params]
+ tf_vals.append(self.cost)
+ tf_vals.append(self.nll_bound_vae)
+ tf_vals.append(self.nll_bound_iwae)
+ tf_vals.append(self.train_step) # not train_op!
+ if self.hps.ic_dim > 0:
+ tf_vals += [self.prior_zs_g0.mean, self.prior_zs_g0.logvar,
+ self.posterior_zs_g0.mean, self.posterior_zs_g0.logvar]
+ if self.hps.co_dim > 0:
+ tf_vals.append(self.controller_outputs)
+ tf_vals_flat, fidxs = flatten(tf_vals)
+
+ np_vals_flat = session.run(tf_vals_flat, feed_dict=feed_dict)
+
+ ff = 0
+ gen_ics = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
+ gen_states = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
+ factors = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
+ out_dist_params = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
+ costs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
+ nll_bound_vaes = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
+ nll_bound_iwaes = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1
+ train_steps = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1
+ if self.hps.ic_dim > 0:
+ prior_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1
+ prior_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
+ post_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
+ post_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
+ if self.hps.co_dim > 0:
+ controller_outputs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
+
+ # [0] are to take out the non-temporal items from lists
+ gen_ics = gen_ics[0]
+ costs = costs[0]
+ nll_bound_vaes = nll_bound_vaes[0]
+ nll_bound_iwaes = nll_bound_iwaes[0]
+ train_steps = train_steps[0]
+
+ # Convert to full tensors, not lists of tensors in time dim.
+ gen_states = list_t_bxn_to_tensor_bxtxn(gen_states)
+ factors = list_t_bxn_to_tensor_bxtxn(factors)
+ out_dist_params = list_t_bxn_to_tensor_bxtxn(out_dist_params)
+ if self.hps.ic_dim > 0:
+ # select first time point
+ prior_g0_mean = prior_g0_mean[0]
+ prior_g0_logvar = prior_g0_logvar[0]
+ post_g0_mean = post_g0_mean[0]
+ post_g0_logvar = post_g0_logvar[0]
+ if self.hps.co_dim > 0:
+ controller_outputs = list_t_bxn_to_tensor_bxtxn(controller_outputs)
+
+ # slice out the trials in case < batch_size provided
+ if E < hps.batch_size:
+ idx = np.arange(E)
+ gen_ics = gen_ics[idx, :]
+ gen_states = gen_states[idx, :]
+ factors = factors[idx, :, :]
+ out_dist_params = out_dist_params[idx, :, :]
+ if self.hps.ic_dim > 0:
+ prior_g0_mean = prior_g0_mean[idx, :]
+ prior_g0_logvar = prior_g0_logvar[idx, :]
+ post_g0_mean = post_g0_mean[idx, :]
+ post_g0_logvar = post_g0_logvar[idx, :]
+ if self.hps.co_dim > 0:
+ controller_outputs = controller_outputs[idx, :, :]
+
+ if do_average_batch:
+ gen_ics = np.mean(gen_ics, axis=0)
+ gen_states = np.mean(gen_states, axis=0)
+ factors = np.mean(factors, axis=0)
+ out_dist_params = np.mean(out_dist_params, axis=0)
+ if self.hps.ic_dim > 0:
+ prior_g0_mean = np.mean(prior_g0_mean, axis=0)
+ prior_g0_logvar = np.mean(prior_g0_logvar, axis=0)
+ post_g0_mean = np.mean(post_g0_mean, axis=0)
+ post_g0_logvar = np.mean(post_g0_logvar, axis=0)
+ if self.hps.co_dim > 0:
+ controller_outputs = np.mean(controller_outputs, axis=0)
+
+ model_vals = {}
+ model_vals['gen_ics'] = gen_ics
+ model_vals['gen_states'] = gen_states
+ model_vals['factors'] = factors
+ model_vals['output_dist_params'] = out_dist_params
+ model_vals['costs'] = costs
+ model_vals['nll_bound_vaes'] = nll_bound_vaes
+ model_vals['nll_bound_iwaes'] = nll_bound_iwaes
+ model_vals['train_steps'] = train_steps
+ if self.hps.ic_dim > 0:
+ model_vals['prior_g0_mean'] = prior_g0_mean
+ model_vals['prior_g0_logvar'] = prior_g0_logvar
+ model_vals['post_g0_mean'] = post_g0_mean
+ model_vals['post_g0_logvar'] = post_g0_logvar
+ if self.hps.co_dim > 0:
+ model_vals['controller_outputs'] = controller_outputs
+
+ return model_vals
+
+ def eval_model_runs_avg_epoch(self, data_name, data_extxd,
+ ext_input_extxi=None):
+ """Returns all the expected value for goodies for the entire model.
+
+ The expected value is taken over hidden (z) variables, namely the initial
+ conditions and the control inputs. The expected value is approximate, and
+ accomplished via sampling (batch_size) samples for every examples.
+
+ Args:
+ data_name: The name of the data dict, to select which in/out matrices
+ to use.
+ data_extxd: Numpy array training data with shape:
+ # examples x # time steps x # dimensions
+ ext_input_extxi (optional): Numpy array training external input with
+ shape: # examples x # time steps x # external input dims
+
+ Returns:
+ A dictionary with the averaged outputs of the model decoder, namely:
+ prior g0 mean, prior g0 variance, approx. posterior mean, approx.
+ posterior variance, the generator initial conditions, the control inputs (if
+ enabled), the state of the generator, the factors, and the output
+ distribution parameters, e.g. (rates or mean and variances).
+ """
+ hps = self.hps
+ batch_size = hps.batch_size
+ E, T, D = data_extxd.shape
+ E_to_process = hps.ps_nexamples_to_process
+ if E_to_process > E:
+ E_to_process = E
+
+ if hps.ic_dim > 0:
+ prior_g0_mean = np.zeros([E_to_process, hps.ic_dim])
+ prior_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
+ post_g0_mean = np.zeros([E_to_process, hps.ic_dim])
+ post_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
+
+ if hps.co_dim > 0:
+ controller_outputs = np.zeros([E_to_process, T, hps.co_dim])
+ gen_ics = np.zeros([E_to_process, hps.gen_dim])
+ gen_states = np.zeros([E_to_process, T, hps.gen_dim])
+ factors = np.zeros([E_to_process, T, hps.factors_dim])
+
+ if hps.output_dist == 'poisson':
+ out_dist_params = np.zeros([E_to_process, T, D])
+ elif hps.output_dist == 'gaussian':
+ out_dist_params = np.zeros([E_to_process, T, D+D])
+ else:
+ assert False, "NIY"
+
+ costs = np.zeros(E_to_process)
+ nll_bound_vaes = np.zeros(E_to_process)
+ nll_bound_iwaes = np.zeros(E_to_process)
+ train_steps = np.zeros(E_to_process)
+ for es_idx in range(E_to_process):
+ print("Running %d of %d." % (es_idx+1, E_to_process))
+ example_idxs = es_idx * np.ones(batch_size, dtype=np.int32)
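+ # The same example index is repeated batch_size times, so eval_model_runs_batch
+ # (with do_average_batch=True) averages over batch_size posterior samples of
+ # this single trial.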
+ data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd,
+ ext_input_extxi,
+ batch_size=batch_size,
+ example_idxs=example_idxs)
+ model_values = self.eval_model_runs_batch(data_name, data_bxtxd,
+ ext_input_bxtxi,
+ do_eval_cost=True,
+ do_average_batch=True)
+
+ if self.hps.ic_dim > 0:
+ prior_g0_mean[es_idx,:] = model_values['prior_g0_mean']
+ prior_g0_logvar[es_idx,:] = model_values['prior_g0_logvar']
+ post_g0_mean[es_idx,:] = model_values['post_g0_mean']
+ post_g0_logvar[es_idx,:] = model_values['post_g0_logvar']
+ gen_ics[es_idx,:] = model_values['gen_ics']
+
+ if self.hps.co_dim > 0:
+ controller_outputs[es_idx,:,:] = model_values['controller_outputs']
+ gen_states[es_idx,:,:] = model_values['gen_states']
+ factors[es_idx,:,:] = model_values['factors']
+ out_dist_params[es_idx,:,:] = model_values['output_dist_params']
+ costs[es_idx] = model_values['costs']
+ nll_bound_vaes[es_idx] = model_values['nll_bound_vaes']
+ nll_bound_iwaes[es_idx] = model_values['nll_bound_iwaes']
+ train_steps[es_idx] = model_values['train_steps']
+ print('bound nll(vae): %.3f, bound nll(iwae): %.3f' \
+ % (nll_bound_vaes[es_idx], nll_bound_iwaes[es_idx]))
+
+ model_runs = {}
+ if self.hps.ic_dim > 0:
+ model_runs['prior_g0_mean'] = prior_g0_mean
+ model_runs['prior_g0_logvar'] = prior_g0_logvar
+ model_runs['post_g0_mean'] = post_g0_mean
+ model_runs['post_g0_logvar'] = post_g0_logvar
+ model_runs['gen_ics'] = gen_ics
+
+ if self.hps.co_dim > 0:
+ model_runs['controller_outputs'] = controller_outputs
+ model_runs['gen_states'] = gen_states
+ model_runs['factors'] = factors
+ model_runs['output_dist_params'] = out_dist_params
+ model_runs['costs'] = costs
+ model_runs['nll_bound_vaes'] = nll_bound_vaes
+ model_runs['nll_bound_iwaes'] = nll_bound_iwaes
+ model_runs['train_steps'] = train_steps
+ return model_runs
+
+ def eval_model_runs_push_mean(self, data_name, data_extxd,
+ ext_input_extxi=None):
+ """Returns values of interest for the model by pushing the means through
+
+ The mean values for both initial conditions and the control inputs are
+ pushed through the model instead of sampling (as is done in
+ eval_model_runs_avg_epoch).
+ This is a quick and approximate version of estimating these values instead
+ of sampling from the posterior many times and then averaging those values of
+ interest.
+
+ Internally, a total of batch_size trials are run through the model at once.
+
+ Args:
+ data_name: The name of the data dict, to select which in/out matrices
+ to use.
+ data_extxd: Numpy array training data with shape:
+ # examples x # time steps x # dimensions
+ ext_input_extxi (optional): Numpy array training external input with
+ shape: # examples x # time steps x # external input dims
+
+ Returns:
+ A dictionary with the estimated outputs of the model decoder, namely:
+ prior g0 mean, prior g0 variance, approx. posterior mean, approx.
+ posterior variance, the generator initial conditions, the control inputs (if
+ enabled), the state of the generator, the factors, and the output
+ distribution parameters, e.g. (rates or mean and variances).
+ """
+ hps = self.hps
+ batch_size = hps.batch_size
+ E, T, D = data_extxd.shape
+ E_to_process = hps.ps_nexamples_to_process
+ if E_to_process > E:
+ print("Setting number of posterior samples to process to : ", E)
+ E_to_process = E
+
+ if hps.ic_dim > 0:
+ prior_g0_mean = np.zeros([E_to_process, hps.ic_dim])
+ prior_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
+ post_g0_mean = np.zeros([E_to_process, hps.ic_dim])
+ post_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
+
+ if hps.co_dim > 0:
+ controller_outputs = np.zeros([E_to_process, T, hps.co_dim])
+ gen_ics = np.zeros([E_to_process, hps.gen_dim])
+ gen_states = np.zeros([E_to_process, T, hps.gen_dim])
+ factors = np.zeros([E_to_process, T, hps.factors_dim])
+
+ if hps.output_dist == 'poisson':
+ out_dist_params = np.zeros([E_to_process, T, D])
+ elif hps.output_dist == 'gaussian':
+ out_dist_params = np.zeros([E_to_process, T, D+D])
+ else:
+ assert False, "NIY"
+
+ costs = np.zeros(E_to_process)
+ nll_bound_vaes = np.zeros(E_to_process)
+ nll_bound_iwaes = np.zeros(E_to_process)
+ train_steps = np.zeros(E_to_process)
+
+ # Generator that yields 0:N in groups of `per` items, e.g.
+ # (0:per-1), (per:2*per-1), ..., with the last group containing <= per items.
+ # This is used to feed per=batch_size trials into the model at a time.
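+ # For example, trial_batches(5, 2) yields array([0, 1]), array([2, 3]),
+ # and array([4]).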
+ def trial_batches(N, per):
+ for i in range(0, N, per):
+ yield np.arange(i, min(i+per, N), dtype=np.int32)
+
+ for batch_idx, es_idx in enumerate(trial_batches(E_to_process,
+ hps.batch_size)):
+ print("Running trial batch %d with %d trials" % (batch_idx+1,
+ len(es_idx)))
+ data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd,
+ ext_input_extxi,
+ batch_size=batch_size,
+ example_idxs=es_idx)
+ model_values = self.eval_model_runs_batch(data_name, data_bxtxd,
+ ext_input_bxtxi,
+ do_eval_cost=True,
+ do_average_batch=False)
+
+ if self.hps.ic_dim > 0:
+ prior_g0_mean[es_idx,:] = model_values['prior_g0_mean']
+ prior_g0_logvar[es_idx,:] = model_values['prior_g0_logvar']
+ post_g0_mean[es_idx,:] = model_values['post_g0_mean']
+ post_g0_logvar[es_idx,:] = model_values['post_g0_logvar']
+ gen_ics[es_idx,:] = model_values['gen_ics']
+
+ if self.hps.co_dim > 0:
+ controller_outputs[es_idx,:,:] = model_values['controller_outputs']
+ gen_states[es_idx,:,:] = model_values['gen_states']
+ factors[es_idx,:,:] = model_values['factors']
+ out_dist_params[es_idx,:,:] = model_values['output_dist_params']
+
+ # TODO: model_values['costs'] and the other costs come out as scalars, summed
+ # over all the trials in the batch; what we want is the per-trial costs.
+ costs[es_idx] = model_values['costs']
+ nll_bound_vaes[es_idx] = model_values['nll_bound_vaes']
+ nll_bound_iwaes[es_idx] = model_values['nll_bound_iwaes']
+
+ train_steps[es_idx] = model_values['train_steps']
+
+ model_runs = {}
+ if self.hps.ic_dim > 0:
+ model_runs['prior_g0_mean'] = prior_g0_mean
+ model_runs['prior_g0_logvar'] = prior_g0_logvar
+ model_runs['post_g0_mean'] = post_g0_mean
+ model_runs['post_g0_logvar'] = post_g0_logvar
+ model_runs['gen_ics'] = gen_ics
+
+ if self.hps.co_dim > 0:
+ model_runs['controller_outputs'] = controller_outputs
+ model_runs['gen_states'] = gen_states
+ model_runs['factors'] = factors
+ model_runs['output_dist_params'] = out_dist_params
+
+ # You probably do not want the LL associated values when pushing the mean
+ # instead of sampling.
+ model_runs['costs'] = costs
+ model_runs['nll_bound_vaes'] = nll_bound_vaes
+ model_runs['nll_bound_iwaes'] = nll_bound_iwaes
+ model_runs['train_steps'] = train_steps
+ return model_runs
+
+ def write_model_runs(self, datasets, output_fname=None, push_mean=False):
+ """Run the model on the data in data_dict, and save the computed values.
+
+ LFADS generates a number of outputs for each example, and these are all
+ saved. They are:
+ The mean and variance of the prior of g0.
+ The mean and variance of approximate posterior of g0.
+ The control inputs (if enabled)
+ The initial conditions, g0, for all examples.
+ The generator states for all time.
+ The factors for all time.
+ The output distribution parameters (e.g. rates) for all time.
+
+ Args:
+ datasets: a dictionary of named data_dictionaries, see top of lfads.py
+ output_fname: a file name stem for the output files.
+ push_mean: if False (default), generates batch_size samples for each trial
+ and averages the results. if True, runs each trial once without noise,
+ pushing the posterior mean initial conditions and control inputs through
+ the trained model. False is used for posterior_sample_and_average, True
+ is used for posterior_push_mean.
+ """
+ hps = self.hps
+ kind = hps.kind
+
+ for data_name, data_dict in datasets.items():
+ data_tuple = [('train', data_dict['train_data'],
+ data_dict['train_ext_input']),
+ ('valid', data_dict['valid_data'],
+ data_dict['valid_ext_input'])]
+ for data_kind, data_extxd, ext_input_extxi in data_tuple:
+ if not output_fname:
+ fname = "model_runs_" + data_name + '_' + data_kind + '_' + kind
+ else:
+ fname = output_fname + data_name + '_' + data_kind + '_' + kind
+
+ print("Writing data for %s data and kind %s." % (data_name, data_kind))
+ if push_mean:
+ model_runs = self.eval_model_runs_push_mean(data_name, data_extxd,
+ ext_input_extxi)
+ else:
+ model_runs = self.eval_model_runs_avg_epoch(data_name, data_extxd,
+ ext_input_extxi)
+ full_fname = os.path.join(hps.lfads_save_dir, fname)
+ write_data(full_fname, model_runs, compression='gzip')
+ print("Done.")
+
+ def write_model_samples(self, dataset_name, output_fname=None):
+ """Use the prior distribution to generate batch_size number of samples
+ from the model.
+
+ LFADS generates a number of outputs for each sample, and these are all
+ saved. They are:
+ The mean and variance of the prior of g0.
+ The control inputs (if enabled)
+ The initial conditions, g0, for all examples.
+ The generator states for all time.
+ The factors for all time.
+ The output distribution parameters (e.g. rates) for all time.
+
+ Args:
+ dataset_name: The name of the dataset to grab the factors -> rates
+ alignment matrices from.
+ output_fname: The name of the file in which to save the generated
+ samples.
+ """
+ hps = self.hps
+ batch_size = hps.batch_size
+
+ print("Generating %d samples" % (batch_size))
+ tf_vals = [self.factors, self.gen_states, self.gen_ics,
+ self.cost, self.output_dist_params]
+ if hps.ic_dim > 0:
+ tf_vals += [self.prior_zs_g0.mean, self.prior_zs_g0.logvar]
+ if hps.co_dim > 0:
+ tf_vals += [self.prior_zs_ar_con.samples_t]
+ tf_vals_flat, fidxs = flatten(tf_vals)
+
+ session = tf.get_default_session()
+ feed_dict = {}
+ feed_dict[self.dataName] = dataset_name
+ feed_dict[self.keep_prob] = 1.0
+
+ np_vals_flat = session.run(tf_vals_flat, feed_dict=feed_dict)
+
+ ff = 0
+ factors = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
+ gen_states = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
+ gen_ics = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
+ costs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
+ output_dist_params = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
+ if hps.ic_dim > 0:
+ prior_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
+ prior_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
+ if hps.co_dim > 0:
+ prior_zs_ar_con = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
+
+ # [0] are to take out the non-temporal items from lists
+ gen_ics = gen_ics[0]
+ costs = costs[0]
+
+ # Convert to full tensors, not lists of tensors in time dim.
+ gen_states = list_t_bxn_to_tensor_bxtxn(gen_states)
+ factors = list_t_bxn_to_tensor_bxtxn(factors)
+ output_dist_params = list_t_bxn_to_tensor_bxtxn(output_dist_params)
+ if hps.ic_dim > 0:
+ prior_g0_mean = prior_g0_mean[0]
+ prior_g0_logvar = prior_g0_logvar[0]
+ if hps.co_dim > 0:
+ prior_zs_ar_con = list_t_bxn_to_tensor_bxtxn(prior_zs_ar_con)
+
+ model_vals = {}
+ model_vals['gen_ics'] = gen_ics
+ model_vals['gen_states'] = gen_states
+ model_vals['factors'] = factors
+ model_vals['output_dist_params'] = output_dist_params
+ model_vals['costs'] = costs.reshape(1)
+ if hps.ic_dim > 0:
+ model_vals['prior_g0_mean'] = prior_g0_mean
+ model_vals['prior_g0_logvar'] = prior_g0_logvar
+ if hps.co_dim > 0:
+ model_vals['prior_zs_ar_con'] = prior_zs_ar_con
+
+ full_fname = os.path.join(hps.lfads_save_dir, output_fname)
+ write_data(full_fname, model_vals, compression='gzip')
+ print("Done.")
+
+ @staticmethod
+ def eval_model_parameters(use_nested=True, include_strs=None):
+ """Evaluate and return all of the TF variables in the model.
+
+ Args:
+ use_nested (optional): For returning values, use a nested dictionary, based
+ on variable scoping, or return all variables in a flat dictionary.
+ include_strs (optional): A list of strings to use as a filter, to reduce the
+ number of variables returned. A variable name must contain at least one
+ string in include_strs as a sub-string in order to be returned.
+
+ Returns:
+ The parameters of the model. This can be in a flat
+ dictionary, or a nested dictionary, where the nesting is by variable
+ scope.
+ """
+ all_tf_vars = tf.global_variables()
+ session = tf.get_default_session()
+ all_tf_vars_eval = session.run(all_tf_vars)
+ vars_dict = {}
+ strs = ["LFADS"]
+ if include_strs:
+ strs += include_strs
+
+ for i, (var, var_eval) in enumerate(zip(all_tf_vars, all_tf_vars_eval)):
+ if any(s in var.name for s in strs):
+ if not isinstance(var_eval, np.ndarray): # for H5PY
+ print(var.name, """ is not numpy array, saving as numpy array
+ with value: """, var_eval, type(var_eval))
+ e = np.array(var_eval)
+ print(e, type(e))
+ else:
+ e = var_eval
+ vars_dict[var.name] = e
+
+ if not use_nested:
+ return vars_dict
+
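+ # Otherwise, nest by variable scope: e.g. a variable named 'LFADS/gen/W:0'
+ # ends up at nested_vars_dict['LFADS']['gen']['W:0'].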
+ var_names = vars_dict.keys()
+ nested_vars_dict = {}
+ current_dict = nested_vars_dict
+ for v, var_name in enumerate(var_names):
+ var_split_name_list = var_name.split('/')
+ split_name_list_len = len(var_split_name_list)
+ current_dict = nested_vars_dict
+ for p, part in enumerate(var_split_name_list):
+ if p < split_name_list_len - 1:
+ if part in current_dict:
+ current_dict = current_dict[part]
+ else:
+ current_dict[part] = {}
+ current_dict = current_dict[part]
+ else:
+ current_dict[part] = vars_dict[var_name]
+
+ return nested_vars_dict
+
+ @staticmethod
+ def spikify_rates(rates_bxtxd):
+ """Randomly spikify underlying rates according a Poisson distribution
+
+ Args:
+ rates_bxtxd: a numpy tensor with shape:
+
+ Returns:
+ A numpy array with the same shape as rates_bxtxd, but with the event
+ counts.
+ """
+
+ B,T,N = rates_bxtxd.shape
+ assert all([B > 0, N > 0]), "rates_bxtxd must have nonzero batch and dimension sizes"
+
+ # Because the rate changes at every (batch, time, dim) entry, draw each
+ # Poisson count individually in a nested loop.
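+ # (A statistically equivalent vectorized sketch, assuming plain NumPy, would be
+ #  spikes_bxtxd = np.random.poisson(rates_bxtxd).astype(np.int32); the explicit
+ #  loop below is kept as written.)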
+ spikes_bxtxd = np.zeros([B,T,N], dtype=np.int32)
+ for b in range(B):
+ for t in range(T):
+ for n in range(N):
+ rate = rates_bxtxd[b,t,n]
+ count = np.random.poisson(rate)
+ spikes_bxtxd[b,t,n] = count
+
+ return spikes_bxtxd
diff --git a/models/research/lfads/plot_lfads.py b/models/research/lfads/plot_lfads.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4e1a0332ef2affeae147edda4779cc4a7e9a0ef
--- /dev/null
+++ b/models/research/lfads/plot_lfads.py
@@ -0,0 +1,181 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import matplotlib
+matplotlib.use('Agg')
+from matplotlib import pyplot as plt
+import numpy as np
+import tensorflow as tf
+
+def _plot_item(W, name, full_name, nspaces):
+ plt.figure()
+ if W.shape == ():
+ print(name, ": ", W)
+ elif W.shape[0] == 1:
+ plt.stem(W.T)
+ plt.title(full_name)
+ elif W.shape[1] == 1:
+ plt.stem(W)
+ plt.title(full_name)
+ else:
+ plt.imshow(np.abs(W), interpolation='nearest', cmap='jet');
+ plt.colorbar()
+ plt.title(full_name)
+
+
+def all_plot(d, full_name="", exclude="", nspaces=0):
+ """Recursively plot all the LFADS model parameters in the nested
+ dictionary."""
+ for k, v in d.items():
+ this_name = full_name+"/"+k
+ if isinstance(v, dict):
+ all_plot(v, full_name=this_name, exclude=exclude, nspaces=nspaces+4)
+ else:
+ if exclude == "" or exclude not in this_name:
+ _plot_item(v, name=k, full_name=full_name+"/"+k, nspaces=nspaces+4)
+
+
+
+def plot_time_series(vals_bxtxn, bidx=None, n_to_plot=np.inf, scale=1.0,
+ color='r', title=None):
+
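+ # Plot up to n_to_plot columns of a (batch x time x n) array (batch-averaged
+ # when bidx is None), with each series vertically offset by scale * its index.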
+ if bidx is None:
+ vals_txn = np.mean(vals_bxtxn, axis=0)
+ else:
+ vals_txn = vals_bxtxn[bidx,:,:]
+
+ T, N = vals_txn.shape
+ if n_to_plot > N:
+ n_to_plot = N
+
+ plt.plot(vals_txn[:,0:n_to_plot] + scale*np.array(range(n_to_plot)),
+ color=color, lw=1.0)
+ plt.axis('tight')
+ if title:
+ plt.title(title)
+
+
+def plot_lfads_timeseries(data_bxtxn, model_vals, ext_input_bxtxi=None,
+ truth_bxtxn=None, bidx=None, output_dist="poisson",
+ conversion_factor=1.0, subplot_cidx=0,
+ col_title=None):
+
+ n_to_plot = 10
+ scale = 1.0
+ nrows = 7
+ plt.subplot(nrows,2,1+subplot_cidx)
+
+ if output_dist == 'poisson':
+ rates = means = conversion_factor * model_vals['output_dist_params']
+ plot_time_series(rates, bidx, n_to_plot=n_to_plot, scale=scale,
+ title=col_title + " rates (LFADS - red, Truth - black)")
+ elif output_dist == 'gaussian':
+ means_vars = model_vals['output_dist_params']
+ means, vars = np.split(means_vars,2, axis=2) # bxtxn
+ stds = np.sqrt(vars)
+ plot_time_series(means, bidx, n_to_plot=n_to_plot, scale=scale,
+ title=col_title + " means (LFADS - red, Truth - black)")
+ plot_time_series(means+stds, bidx, n_to_plot=n_to_plot, scale=scale,
+ color='c')
+ plot_time_series(means-stds, bidx, n_to_plot=n_to_plot, scale=scale,
+ color='c')
+ else:
+ assert False, 'NIY'
+
+
+ if truth_bxtxn is not None:
+ plot_time_series(truth_bxtxn, bidx, n_to_plot=n_to_plot, color='k',
+ scale=scale)
+
+ input_title = ""
+ if "controller_outputs" in model_vals.keys():
+ input_title += " Controller Output"
+ plt.subplot(nrows,2,3+subplot_cidx)
+ u_t = model_vals['controller_outputs'][0:-1]
+ plot_time_series(u_t, bidx, n_to_plot=n_to_plot, color='c', scale=1.0,
+ title=col_title + input_title)
+
+ if ext_input_bxtxi is not None:
+ input_title += " External Input"
+ plot_time_series(ext_input_bxtxi, n_to_plot=n_to_plot, color='b',
+ scale=scale, title=col_title + input_title)
+
+ plt.subplot(nrows,2,5+subplot_cidx)
+ plot_time_series(means, bidx,
+ n_to_plot=n_to_plot, scale=1.0,
+ title=col_title + " Spikes (LFADS - red, Spikes - black)")
+ plot_time_series(data_bxtxn, bidx, n_to_plot=n_to_plot, color='k', scale=1.0)
+
+ plt.subplot(nrows,2,7+subplot_cidx)
+ plot_time_series(model_vals['factors'], bidx, n_to_plot=n_to_plot, color='b',
+ scale=2.0, title=col_title + " Factors")
+
+ plt.subplot(nrows,2,9+subplot_cidx)
+ plot_time_series(model_vals['gen_states'], bidx, n_to_plot=n_to_plot,
+ color='g', scale=1.0, title=col_title + " Generator State")
+
+ if bidx is not None:
+ data_nxt = data_bxtxn[bidx,:,:].T
+ params_nxt = model_vals['output_dist_params'][bidx,:,:].T
+ else:
+ data_nxt = np.mean(data_bxtxn, axis=0).T
+ params_nxt = np.mean(model_vals['output_dist_params'], axis=0).T
+ if output_dist == 'poisson':
+ means_nxt = params_nxt
+ elif output_dist == 'gaussian': # (means+vars) x time
+ means_nxt = np.vsplit(params_nxt,2)[0] # get means
+ else:
+ assert "NIY"
+
+ plt.subplot(nrows,2,11+subplot_cidx)
+ plt.imshow(data_nxt, aspect='auto', interpolation='nearest')
+ plt.title(col_title + ' Data')
+
+ plt.subplot(nrows,2,13+subplot_cidx)
+ plt.imshow(means_nxt, aspect='auto', interpolation='nearest')
+ plt.title(col_title + ' Means')
+
+
+def plot_lfads(train_bxtxd, train_model_vals,
+ train_ext_input_bxtxi=None, train_truth_bxtxd=None,
+ valid_bxtxd=None, valid_model_vals=None,
+ valid_ext_input_bxtxi=None, valid_truth_bxtxd=None,
+ bidx=None, cf=1.0, output_dist='poisson'):
+
+ # Plotting
+ f = plt.figure(figsize=(18,20), tight_layout=True)
+ plot_lfads_timeseries(train_bxtxd, train_model_vals,
+ train_ext_input_bxtxi,
+ truth_bxtxn=train_truth_bxtxd,
+ conversion_factor=cf, bidx=bidx,
+ output_dist=output_dist, col_title='Train')
+ plot_lfads_timeseries(valid_bxtxd, valid_model_vals,
+ valid_ext_input_bxtxi,
+ truth_bxtxn=valid_truth_bxtxd,
+ conversion_factor=cf, bidx=bidx,
+ output_dist=output_dist,
+ subplot_cidx=1, col_title='Valid')
+
+ # Convert from figure to a numpy array width x height x 3 (last for RGB)
+ f.canvas.draw()
+ data = np.frombuffer(f.canvas.tostring_rgb(), dtype=np.uint8)
+ data_wxhx3 = data.reshape(f.canvas.get_width_height()[::-1] + (3,))
+ plt.close()
+
+ return data_wxhx3
diff --git a/models/research/lfads/run_lfads.py b/models/research/lfads/run_lfads.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd1c0d5e4deab50481cd32efdd044c61707204cc
--- /dev/null
+++ b/models/research/lfads/run_lfads.py
@@ -0,0 +1,815 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from lfads import LFADS
+import numpy as np
+import os
+import tensorflow as tf
+import re
+import utils
+import sys
+MAX_INT = sys.maxsize
+
+# Lots of hyperparameters, but most are pretty insensitive. The
+# explanation of these hyperparameters is found below, in the flags
+# section.
+
+CHECKPOINT_PB_LOAD_NAME = "checkpoint"
+CHECKPOINT_NAME = "lfads_vae"
+CSV_LOG = "fitlog"
+OUTPUT_FILENAME_STEM = ""
+DEVICE = "gpu:0" # "cpu:0", or other gpus, e.g. "gpu:1"
+MAX_CKPT_TO_KEEP = 5
+MAX_CKPT_TO_KEEP_LVE = 5
+PS_NEXAMPLES_TO_PROCESS = MAX_INT # if larger than number of examples, process all
+EXT_INPUT_DIM = 0
+IC_DIM = 64
+FACTORS_DIM = 50
+IC_ENC_DIM = 128
+GEN_DIM = 200
+GEN_CELL_INPUT_WEIGHT_SCALE = 1.0
+GEN_CELL_REC_WEIGHT_SCALE = 1.0
+CELL_WEIGHT_SCALE = 1.0
+BATCH_SIZE = 128
+LEARNING_RATE_INIT = 0.01
+LEARNING_RATE_DECAY_FACTOR = 0.95
+LEARNING_RATE_STOP = 0.00001
+LEARNING_RATE_N_TO_COMPARE = 6
+INJECT_EXT_INPUT_TO_GEN = False
+DO_TRAIN_IO_ONLY = False
+DO_TRAIN_ENCODER_ONLY = False
+DO_RESET_LEARNING_RATE = False
+FEEDBACK_FACTORS_OR_RATES = "factors"
+DO_TRAIN_READIN = True
+
+# Calibrated just above the average value for the rnn synthetic data.
+MAX_GRAD_NORM = 200.0
+CELL_CLIP_VALUE = 5.0
+KEEP_PROB = 0.95
+TEMPORAL_SPIKE_JITTER_WIDTH = 0
+OUTPUT_DISTRIBUTION = 'poisson' # 'poisson' or 'gaussian'
+NUM_STEPS_FOR_GEN_IC = MAX_INT # set to num_steps if greater than num_steps
+
+DATA_DIR = "/tmp/rnn_synth_data_v1.0/"
+DATA_FILENAME_STEM = "chaotic_rnn_inputs_g1p5"
+LFADS_SAVE_DIR = "/tmp/lfads_chaotic_rnn_inputs_g1p5/"
+CO_DIM = 1
+DO_CAUSAL_CONTROLLER = False
+DO_FEED_FACTORS_TO_CONTROLLER = True
+CONTROLLER_INPUT_LAG = 1
+PRIOR_AR_AUTOCORRELATION = 10.0
+PRIOR_AR_PROCESS_VAR = 0.1
+DO_TRAIN_PRIOR_AR_ATAU = True
+DO_TRAIN_PRIOR_AR_NVAR = True
+CI_ENC_DIM = 128
+CON_DIM = 128
+CO_PRIOR_VAR_SCALE = 0.1
+KL_INCREASE_STEPS = 2000
+L2_INCREASE_STEPS = 2000
+L2_GEN_SCALE = 2000.0
+L2_CON_SCALE = 0.0
+# scale of regularizer on time correlation of inferred inputs
+CO_MEAN_CORR_SCALE = 0.0
+KL_IC_WEIGHT = 1.0
+KL_CO_WEIGHT = 1.0
+KL_START_STEP = 0
+L2_START_STEP = 0
+IC_PRIOR_VAR_MIN = 0.1
+IC_PRIOR_VAR_SCALE = 0.1
+IC_PRIOR_VAR_MAX = 0.1
+IC_POST_VAR_MIN = 0.0001 # protection from KL blowing up
+
+flags = tf.app.flags
+flags.DEFINE_string("kind", "train",
+ "Type of model to build {train, \
+ posterior_sample_and_average, \
+ posterior_push_mean, \
+ prior_sample, write_model_params}")
+flags.DEFINE_string("output_dist", OUTPUT_DISTRIBUTION,
+ "Type of output distribution, 'poisson' or 'gaussian'")
+flags.DEFINE_boolean("allow_gpu_growth", False,
+ "If true, only allocate amount of memory needed for \
+ Session. Otherwise, use full GPU memory.")
+
+# DATA
+flags.DEFINE_string("data_dir", DATA_DIR, "Data for training")
+flags.DEFINE_string("data_filename_stem", DATA_FILENAME_STEM,
+ "Filename stem for data dictionaries.")
+flags.DEFINE_string("lfads_save_dir", LFADS_SAVE_DIR, "model save dir")
+flags.DEFINE_string("checkpoint_pb_load_name", CHECKPOINT_PB_LOAD_NAME,
+ "Name of checkpoint files, use 'checkpoint_lve' for best \
+ error")
+flags.DEFINE_string("checkpoint_name", CHECKPOINT_NAME,
+ "Name of checkpoint files (.ckpt appended)")
+flags.DEFINE_string("output_filename_stem", OUTPUT_FILENAME_STEM,
+ "Name of output file (postfix will be added)")
+flags.DEFINE_string("device", DEVICE,
+ "Which device to use (default: \"gpu:0\", can also be \
+ \"cpu:0\", \"gpu:1\", etc)")
+flags.DEFINE_string("csv_log", CSV_LOG,
+ "Name of file to keep running log of fit likelihoods, \
+ etc (.csv appended)")
+flags.DEFINE_integer("max_ckpt_to_keep", MAX_CKPT_TO_KEEP,
+ "Max # of checkpoints to keep (rolling)")
+flags.DEFINE_integer("ps_nexamples_to_process", PS_NEXAMPLES_TO_PROCESS,
+ "Number of examples to process for posterior sample and \
+ average (not number of samples to average over).")
+flags.DEFINE_integer("max_ckpt_to_keep_lve", MAX_CKPT_TO_KEEP_LVE,
+ "Max # of checkpoints to keep for lowest validation error \
+ models (rolling)")
+flags.DEFINE_integer("ext_input_dim", EXT_INPUT_DIM, "Dimension of external \
+inputs")
+flags.DEFINE_integer("num_steps_for_gen_ic", NUM_STEPS_FOR_GEN_IC,
+ "Number of steps to train the generator initial conditon.")
+
+
+# If there are observed inputs, there are two ways to add that observed
+# input to the model. The first is by treating as something to be
+# inferred, and thus encoding the observed input via the encoders, and then
+# input to the generator via the "inferred inputs" channel. Second, one
+# can input the input directly into the generator. This has the downside
+# of making the generation process strictly dependent on knowing the
+# observed input for any generated trial.
+flags.DEFINE_boolean("inject_ext_input_to_gen",
+ INJECT_EXT_INPUT_TO_GEN,
+ "Should observed inputs be input to model via encoders, \
+ or injected directly into generator?")
+
+# CELL
+
+# The combined recurrent and input weights of the encoder and
+# controller cells are by default set to scale at ws/sqrt(#inputs),
+# with ws=1.0. You can change this scaling with this parameter.
+flags.DEFINE_float("cell_weight_scale", CELL_WEIGHT_SCALE,
+ "Input scaling for input weights in generator.")
+
+
+# GENERATION
+
+# Note that the dimension of the initial conditions is separated from the
+# dimensions of the generator initial conditions (and a linear matrix will
+# adapt the shapes if necessary). This is just another way to control
+# complexity. In all likelihood, setting the ic dims to the size of the
+# generator hidden state is just fine.
+flags.DEFINE_integer("ic_dim", IC_DIM, "Dimension of h0")
+# Setting the dimensions of the factors to something smaller than the data
+# dimension is a way to get a reduced dimensionality representation of your
+# data.
+flags.DEFINE_integer("factors_dim", FACTORS_DIM,
+ "Number of factors from generator")
+flags.DEFINE_integer("ic_enc_dim", IC_ENC_DIM,
+ "Cell hidden size, encoder of h0")
+
+# Controlling the size of the generator is one way to control complexity of
+# the dynamics (there is also l2, which will squeeze out unnecessary
+# dynamics also). The modern deep learning approach is to make these cells
+# as large as tolerable (from a waiting perspective), and then regularize
+# them to death with drop out or whatever. I don't know if this is correct
+# for the LFADS application or not.
+flags.DEFINE_integer("gen_dim", GEN_DIM,
+ "Cell hidden size, generator.")
+# The weights of the generator cell by default set to scale at
+# ws/sqrt(#inputs), with ws=1.0. You can change ws for
+# the input weights or the recurrent weights with these hyperparameters.
+flags.DEFINE_float("gen_cell_input_weight_scale", GEN_CELL_INPUT_WEIGHT_SCALE,
+ "Input scaling for input weights in generator.")
+flags.DEFINE_float("gen_cell_rec_weight_scale", GEN_CELL_REC_WEIGHT_SCALE,
+ "Input scaling for rec weights in generator.")
+
+# KL DISTRIBUTIONS
+# If you don't know what you are doing here, please leave these alone; the
+# defaults should be fine for most cases, regardless of other parameters.
+#
+# If you don't want the prior variance to be learned, set the
+# following values to the same thing: ic_prior_var_min,
+# ic_prior_var_scale, ic_prior_var_max. The prior mean will be
+# learned regardless.
+flags.DEFINE_float("ic_prior_var_min", IC_PRIOR_VAR_MIN,
+ "Minimum variance in posterior h0 codes.")
+flags.DEFINE_float("ic_prior_var_scale", IC_PRIOR_VAR_SCALE,
+ "Variance of ic prior distribution")
+flags.DEFINE_float("ic_prior_var_max", IC_PRIOR_VAR_MAX,
+ "Maximum variance of IC prior distribution.")
+# If you really want to limit the information from encoder to decoder,
+# Increase ic_post_var_min above 0.0.
+flags.DEFINE_float("ic_post_var_min", IC_POST_VAR_MIN,
+ "Minimum variance of IC posterior distribution.")
+flags.DEFINE_float("co_prior_var_scale", CO_PRIOR_VAR_SCALE,
+ "Variance of control input prior distribution.")
+
+
+flags.DEFINE_float("prior_ar_atau", PRIOR_AR_AUTOCORRELATION,
+ "Initial autocorrelation of AR(1) priors.")
+flags.DEFINE_float("prior_ar_nvar", PRIOR_AR_PROCESS_VAR,
+ "Initial noise variance for AR(1) priors.")
+flags.DEFINE_boolean("do_train_prior_ar_atau", DO_TRAIN_PRIOR_AR_ATAU,
+ "Is the value for atau an init, or the constant value?")
+flags.DEFINE_boolean("do_train_prior_ar_nvar", DO_TRAIN_PRIOR_AR_NVAR,
+ "Is the value for noise variance an init, or the constant \
+ value?")
+
+# CONTROLLER
+# This parameter critically controls whether or not there is a controller
+# (along with controller encoders) placed into the LFADS graph. If CO_DIM > 0,
+# the controller produces CO_DIM-dimensional inferred inputs; if it is 0,
+# there is no controller.
+flags.DEFINE_integer("co_dim", CO_DIM,
+ "Number of control net outputs (>0 builds that graph).")
+
+# The controller will be more powerful if it can see the encoding of the entire
+# trial. However, this allows the controller to create inferred inputs that are
+# acausal with respect to the actual data generation process. E.g. the data
+# generator could have an input at time t, but the controller, after seeing the
+# entirety of the trial could infer that the input is coming a little before
+# time t, because there are no restrictions on the data the controller sees.
+# One can force the controller to be causal (with respect to perturbations in
+# the data generator) so that it only sees forward encodings of the data at time
+# t that originate at times before or at time t. One can also control the data
+# the controller sees by using an input lag (forward encoding at time [t-tlag]
+# for controller input at time t). The same can be done in the reverse direction
+# (controller input at time t from reverse encoding at time [t+tlag], in the
+# case of an acausal controller). Setting this lag > 0 (even lag=1) can be a
+# powerful way of avoiding very spiky decodes. Finally, one can manually control
+# whether the factors at time t-1 are fed to the controller at time t.
+#
+# If you don't care about any of this, and just want to smooth your data, set
+# do_causal_controller = False
+# do_feed_factors_to_controller = True
+# controller_input_lag = 0
+flags.DEFINE_boolean("do_causal_controller",
+ DO_CAUSAL_CONTROLLER,
+ "Restrict the controller create only causal inferred \
+ inputs?")
+# Strictly speaking, feeding either the factors or the rates to the controller
+# violates causality, since the g0 gets to see all the data. This may or may not
+# be only a theoretical concern.
+flags.DEFINE_boolean("do_feed_factors_to_controller",
+ DO_FEED_FACTORS_TO_CONTROLLER,
+ "Should factors[t-1] be input to controller at time t?")
+flags.DEFINE_string("feedback_factors_or_rates", FEEDBACK_FACTORS_OR_RATES,
+ "Feedback the factors or the rates to the controller? \
+ Acceptable values: 'factors' or 'rates'.")
+flags.DEFINE_integer("controller_input_lag", CONTROLLER_INPUT_LAG,
+ "Time lag on the encoding to controller t-lag for \
+ forward, t+lag for reverse.")
+
+flags.DEFINE_integer("ci_enc_dim", CI_ENC_DIM,
+ "Cell hidden size, encoder of control inputs")
+flags.DEFINE_integer("con_dim", CON_DIM,
+ "Cell hidden size, controller")
+
+
+# OPTIMIZATION
+flags.DEFINE_integer("batch_size", BATCH_SIZE,
+ "Batch size to use during training.")
+flags.DEFINE_float("learning_rate_init", LEARNING_RATE_INIT,
+ "Learning rate initial value")
+flags.DEFINE_float("learning_rate_decay_factor", LEARNING_RATE_DECAY_FACTOR,
+ "Learning rate decay, decay by this fraction every so \
+ often.")
+flags.DEFINE_float("learning_rate_stop", LEARNING_RATE_STOP,
+ "The lr is adaptively reduced, stop training at this value.")
+# Rather than putting the learning rate on an exponentially decreasing
+# schedule, the current algorithm monitors the training cost, and if it
+# isn't regularly decreasing, it lowers the learning rate. So far,
+# this works fine, though it is not perfect.
+flags.DEFINE_integer("learning_rate_n_to_compare", LEARNING_RATE_N_TO_COMPARE,
+ "Number of previous costs current cost has to be worse \
+ than, to lower learning rate.")
+
+# This sets a value, above which, the gradients will be clipped. This hp
+# is extremely useful to avoid an infrequent, but highly pathological
+# problem whereby the gradient is so large that it destroys the
+# optimziation by setting parameters too large, leading to a vicious cycle
+# that ends in NaNs. If it's too large, it's useless, if it's too small,
+# it essentially becomes the learning rate. It's pretty insensitive, though.
+flags.DEFINE_float("max_grad_norm", MAX_GRAD_NORM,
+ "Max norm of gradient before clipping.")
+
+# If your optimizations start "NaN-ing out", reduce this value so that
+# the values of the network don't grow out of control. Typically, once
+# this parameter is set to a reasonable value, one stops having numerical
+# problems.
+flags.DEFINE_float("cell_clip_value", CELL_CLIP_VALUE,
+ "Max value recurrent cell can take before being clipped.")
+
+# This flag is used for an experiment where one sees if training a model with
+# many days' data can be used to learn the dynamics of a held-out day's data.
+# If you don't care about that particular experiment, this flag should always be
+# false.
+flags.DEFINE_boolean("do_train_io_only", DO_TRAIN_IO_ONLY,
+ "Train only the input (readin) and output (readout) \
+ affine functions.")
+
+# This flag is used for an experiment where one wants to know if the dynamics
+# learned by the generator generalize across conditions. In that case, you might
+# train up a model on one set of data, and then only further train the encoder
+# on another set of data (the conditions to be tested) so that the model is
+# forced to use the same dynamics to describe that data. If you don't care about
+# that particular experiment, this flag should always be false.
+flags.DEFINE_boolean("do_train_encoder_only", DO_TRAIN_ENCODER_ONLY,
+ "Train only the encoder weights.")
+
+flags.DEFINE_boolean("do_reset_learning_rate", DO_RESET_LEARNING_RATE,
+ "Reset the learning rate to initial value.")
+
+
+# for multi-session "stitching" models, the per-session readin matrices map from
+# neurons to input factors which are fed into the shared encoder. These are
+# initialized by alignment_matrix_cxf and alignment_bias_c in the input .h5
+# files. They can be fixed or made trainable.
+flags.DEFINE_boolean("do_train_readin", DO_TRAIN_READIN, "Whether to train the \
+ readin matrices and bias vectors. False leaves them fixed \
+ at their initial values specified by the alignment \
+ matrices and vectors.")
+
+
+# OVERFITTING
+# Dropout is done on the input data, on controller inputs (from
+# encoder), on outputs from generator to factors.
+flags.DEFINE_float("keep_prob", KEEP_PROB, "Dropout keep probability.")
+# It appears that the system will happily fit spikes (blessing or
+# curse, depending). You may not want this. Jittering the spikes a
+# bit will help (-/+ bin size, as specified here).
+flags.DEFINE_integer("temporal_spike_jitter_width",
+ TEMPORAL_SPIKE_JITTER_WIDTH,
+ "Shuffle spikes around this window.")
+
+# General note about helping ascribe controller inputs vs dynamics:
+#
+# If controller is heavily penalized, then it won't have any output.
+# If dynamics are heavily penalized, then generator won't make
+# dynamics. Note this l2 penalty is only on the recurrent portion of
+# the RNNs, as dropout is also available, penalizing the feed-forward
+# connections.
+flags.DEFINE_float("l2_gen_scale", L2_GEN_SCALE,
+ "L2 regularization cost for the generator only.")
+flags.DEFINE_float("l2_con_scale", L2_CON_SCALE,
+ "L2 regularization cost for the controller only.")
+flags.DEFINE_float("co_mean_corr_scale", CO_MEAN_CORR_SCALE,
+ "Cost of correlation (thru time)in the means of \
+ controller output.")
+
+# UNDERFITTING
+# If the primary task of LFADS is "filtering" of data and not
+# generation, then it is possible that the KL penalty is too strong.
+# Empirically, we have found this to be the case. So we add a
+# hyperparameter in front of the two KL terms (one for the initial
+# conditions to the generator, the other for the controller outputs).
+# You should always think of the default values as 1.0, and that
+# leads to a standard VAE formulation whereby the numbers that are
+# optimized are a lower-bound on the log-likelihood of the data. When
+# these 2 HPs deviate from 1.0, one cannot make any statement about
+# what those LL lower bounds mean anymore, and they cannot be compared
+# (AFAIK).
+flags.DEFINE_float("kl_ic_weight", KL_IC_WEIGHT,
+ "Strength of KL weight on initial conditions KL penatly.")
+flags.DEFINE_float("kl_co_weight", KL_CO_WEIGHT,
+ "Strength of KL weight on controller output KL penalty.")
+
+# Sometimes the task can be sufficiently hard to learn that the
+# optimizer takes the 'easy route', and simply minimizes the KL
+# divergence, setting it to near zero, and the optimization gets
+# stuck. These two parameters help avoid that by getting the
+# optimization to 'latch' on to the main objective first, and only
+# turning on the regularizers later.
+flags.DEFINE_integer("kl_start_step", KL_START_STEP,
+ "Start increasing weight after this many steps.")
+# training passes, not epochs, increase by 0.5 every kl_increase_steps
+flags.DEFINE_integer("kl_increase_steps", KL_INCREASE_STEPS,
+ "Increase weight of kl cost to avoid local minimum.")
+# Same story for l2 regularizer. One wants a simple generator, for scientific
+# reasons, but not at the expense of hosing the optimization.
+flags.DEFINE_integer("l2_start_step", L2_START_STEP,
+ "Start increasing l2 weight after this many steps.")
+flags.DEFINE_integer("l2_increase_steps", L2_INCREASE_STEPS,
+ "Increase weight of l2 cost to avoid local minimum.")
+
+FLAGS = flags.FLAGS
+
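+# A minimal usage sketch, assuming only the flags defined above (the paths are
+# the defaults at the top of this file and may need to be changed):
+#
+#   python run_lfads.py --kind=train \
+#     --data_dir=/tmp/rnn_synth_data_v1.0/ \
+#     --data_filename_stem=chaotic_rnn_inputs_g1p5 \
+#     --lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g1p5
+#
+# and then, e.g., --kind=posterior_sample_and_average with the same
+# lfads_save_dir to write out the posterior-averaged model runs.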
+
+def build_model(hps, kind="train", datasets=None):
+ """Builds a model from either random initialization, or saved parameters.
+
+ Args:
+ hps: The hyper parameters for the model.
+ kind: (optional) The kind of model to build. Training vs inference require
+ different graphs.
+ datasets: The datasets structure (see top of lfads.py).
+
+ Returns:
+ an LFADS model.
+ """
+
+ build_kind = kind
+ if build_kind == "write_model_params":
+ build_kind = "train"
+ with tf.variable_scope("LFADS", reuse=None):
+ model = LFADS(hps, kind=build_kind, datasets=datasets)
+
+ if not os.path.exists(hps.lfads_save_dir):
+ print("Save directory %s does not exist, creating it." % hps.lfads_save_dir)
+ os.makedirs(hps.lfads_save_dir)
+
+ cp_pb_ln = hps.checkpoint_pb_load_name
+ cp_pb_ln = 'checkpoint' if cp_pb_ln == "" else cp_pb_ln
+ if cp_pb_ln == 'checkpoint':
+ print("Loading latest training checkpoint in: ", hps.lfads_save_dir)
+ saver = model.seso_saver
+ elif cp_pb_ln == 'checkpoint_lve':
+ print("Loading lowest validation checkpoint in: ", hps.lfads_save_dir)
+ saver = model.lve_saver
+ else:
+ print("Loading checkpoint: ", cp_pb_ln, ", in: ", hps.lfads_save_dir)
+ saver = model.seso_saver
+
+ ckpt = tf.train.get_checkpoint_state(hps.lfads_save_dir,
+ latest_filename=cp_pb_ln)
+
+ session = tf.get_default_session()
+ print("ckpt: ", ckpt)
+ if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
+ print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
+ saver.restore(session, ckpt.model_checkpoint_path)
+ else:
+ print("Created model with fresh parameters.")
+ if kind in ["posterior_sample_and_average", "posterior_push_mean",
+ "prior_sample", "write_model_params"]:
+ print("Possible error!!! You are running ", kind, " on a newly \
+ initialized model!")
+ # cannot print ckpt.model_checkpoint_path if there is no ckpt
+ print("Are you sure a checkpoint in ", hps.lfads_save_dir,
+ " exists?")
+
+ tf.global_variables_initializer().run()
+
+ if ckpt:
+ train_step_str = re.search('-[0-9]+$', ckpt.model_checkpoint_path).group()
+ else:
+ train_step_str = '-0'
+
+ fname = 'hyperparameters' + train_step_str + '.txt'
+ hp_fname = os.path.join(hps.lfads_save_dir, fname)
+ hps_for_saving = jsonify_dict(hps)
+ utils.write_data(hp_fname, hps_for_saving, use_json=True)
+
+ return model
+
+
+def jsonify_dict(d):
+ """Turns python booleans into strings so hps dict can be written in json.
+ Creates a shallow-copied dictionary first, then accomplishes string
+ conversion.
+
+ Args:
+ d: hyperparameter dictionary
+
+ Returns: hyperparameter dictionary with bool's as strings
+ """
+
+ d2 = d.copy() # shallow copy is fine by assumption of d being shallow
+ def jsonify_bool(boolean_value):
+ if boolean_value:
+ return "true"
+ else:
+ return "false"
+
+ for key in d2.keys():
+ if isinstance(d2[key], bool):
+ d2[key] = jsonify_bool(d2[key])
+ return d2
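+
+# For example, jsonify_dict({'do_train_readin': True, 'co_dim': 1}) returns
+# {'do_train_readin': 'true', 'co_dim': 1}.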
+
+
+def build_hyperparameter_dict(flags):
+ """Simple script for saving hyper parameters. Under the hood the
+ flags structure isn't a dictionary, so it has to be simplified since we
+ want to be able to view file as text.
+
+ Args:
+ flags: From tf.app.flags
+
+ Returns:
+ dictionary of hyper parameters (ignoring other flag types).
+ """
+ d = {}
+ # Data
+ d['output_dist'] = flags.output_dist
+ d['data_dir'] = flags.data_dir
+ d['lfads_save_dir'] = flags.lfads_save_dir
+ d['checkpoint_pb_load_name'] = flags.checkpoint_pb_load_name
+ d['checkpoint_name'] = flags.checkpoint_name
+ d['output_filename_stem'] = flags.output_filename_stem
+ d['max_ckpt_to_keep'] = flags.max_ckpt_to_keep
+ d['max_ckpt_to_keep_lve'] = flags.max_ckpt_to_keep_lve
+ d['ps_nexamples_to_process'] = flags.ps_nexamples_to_process
+ d['ext_input_dim'] = flags.ext_input_dim
+ d['data_filename_stem'] = flags.data_filename_stem
+ d['device'] = flags.device
+ d['csv_log'] = flags.csv_log
+ d['num_steps_for_gen_ic'] = flags.num_steps_for_gen_ic
+ d['inject_ext_input_to_gen'] = flags.inject_ext_input_to_gen
+ # Cell
+ d['cell_weight_scale'] = flags.cell_weight_scale
+ # Generation
+ d['ic_dim'] = flags.ic_dim
+ d['factors_dim'] = flags.factors_dim
+ d['ic_enc_dim'] = flags.ic_enc_dim
+ d['gen_dim'] = flags.gen_dim
+ d['gen_cell_input_weight_scale'] = flags.gen_cell_input_weight_scale
+ d['gen_cell_rec_weight_scale'] = flags.gen_cell_rec_weight_scale
+ # KL distributions
+ d['ic_prior_var_min'] = flags.ic_prior_var_min
+ d['ic_prior_var_scale'] = flags.ic_prior_var_scale
+ d['ic_prior_var_max'] = flags.ic_prior_var_max
+ d['ic_post_var_min'] = flags.ic_post_var_min
+ d['co_prior_var_scale'] = flags.co_prior_var_scale
+ d['prior_ar_atau'] = flags.prior_ar_atau
+ d['prior_ar_nvar'] = flags.prior_ar_nvar
+ d['do_train_prior_ar_atau'] = flags.do_train_prior_ar_atau
+ d['do_train_prior_ar_nvar'] = flags.do_train_prior_ar_nvar
+ # Controller
+ d['do_causal_controller'] = flags.do_causal_controller
+ d['controller_input_lag'] = flags.controller_input_lag
+ d['do_feed_factors_to_controller'] = flags.do_feed_factors_to_controller
+ d['feedback_factors_or_rates'] = flags.feedback_factors_or_rates
+ d['co_dim'] = flags.co_dim
+ d['ci_enc_dim'] = flags.ci_enc_dim
+ d['con_dim'] = flags.con_dim
+ d['co_mean_corr_scale'] = flags.co_mean_corr_scale
+ # Optimization
+ d['batch_size'] = flags.batch_size
+ d['learning_rate_init'] = flags.learning_rate_init
+ d['learning_rate_decay_factor'] = flags.learning_rate_decay_factor
+ d['learning_rate_stop'] = flags.learning_rate_stop
+ d['learning_rate_n_to_compare'] = flags.learning_rate_n_to_compare
+ d['max_grad_norm'] = flags.max_grad_norm
+ d['cell_clip_value'] = flags.cell_clip_value
+ d['do_train_io_only'] = flags.do_train_io_only
+ d['do_train_encoder_only'] = flags.do_train_encoder_only
+ d['do_reset_learning_rate'] = flags.do_reset_learning_rate
+ d['do_train_readin'] = flags.do_train_readin
+
+ # Overfitting
+ d['keep_prob'] = flags.keep_prob
+ d['temporal_spike_jitter_width'] = flags.temporal_spike_jitter_width
+ d['l2_gen_scale'] = flags.l2_gen_scale
+ d['l2_con_scale'] = flags.l2_con_scale
+ # Underfitting
+ d['kl_ic_weight'] = flags.kl_ic_weight
+ d['kl_co_weight'] = flags.kl_co_weight
+ d['kl_start_step'] = flags.kl_start_step
+ d['kl_increase_steps'] = flags.kl_increase_steps
+ d['l2_start_step'] = flags.l2_start_step
+ d['l2_increase_steps'] = flags.l2_increase_steps
+ d['_clip_value'] = 80 # bounds the tf.exp to avoid INF
+
+ return d
+
+
+class hps_dict_to_obj(dict):
+ """Helper class allowing us to access hps dictionary more easily."""
+
+ def __getattr__(self, key):
+ if key in self:
+ return self[key]
+ else:
+ assert False, ("%s does not exist." % key)
+ def __setattr__(self, key, value):
+ self[key] = value
+
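+# For example, hps = hps_dict_to_obj({'batch_size': 128}) allows both
+# hps['batch_size'] and hps.batch_size; attribute access to a missing key
+# fails the assertion above.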
+
+def train(hps, datasets):
+ """Train the LFADS model.
+
+ Args:
+ hps: The dictionary of hyperparameters.
+ datasets: A dictionary of data dictionaries. The dataset dict is simply a
+ name(string)-> data dictionary mapping (See top of lfads.py).
+ """
+ model = build_model(hps, kind="train", datasets=datasets)
+ if hps.do_reset_learning_rate:
+ sess = tf.get_default_session()
+ sess.run(model.learning_rate.initializer)
+
+ model.train_model(datasets)
+
+
+def write_model_runs(hps, datasets, output_fname=None, push_mean=False):
+ """Run the model on the data in data_dict, and save the computed values.
+
+ LFADS generates a number of outputs for each example, and these are all
+ saved. They are:
+ The mean and variance of the prior of g0.
+ The mean and variance of approximate posterior of g0.
+ The control inputs (if enabled)
+ The initial conditions, g0, for all examples.
+ The generator states for all time.
+ The factors for all time.
+ The rates for all time.
+
+ Args:
+ hps: The dictionary of hyperparameters.
+ datasets: A dictionary of data dictionaries. The dataset dict is simply a
+ name(string)-> data dictionary mapping (See top of lfads.py).
+ output_fname (optional): output filename stem to write the model runs.
+ push_mean: if False (default), generates batch_size samples for each trial
+ and averages the results. if True, runs each trial once without noise,
+ pushing the posterior mean initial conditions and control inputs through
+ the trained model. False is used for posterior_sample_and_average, True
+ is used for posterior_push_mean.
+ """
+ model = build_model(hps, kind=hps.kind, datasets=datasets)
+ model.write_model_runs(datasets, output_fname, push_mean)
+
+
+def write_model_samples(hps, datasets, dataset_name=None, output_fname=None):
+ """Use the prior distribution to generate samples from the model.
+ Generates batch_size number of samples (set through FLAGS).
+
+ LFADS generates a number of outputs for each example, and these are all
+ saved. They are:
+ The mean and variance of the prior of g0.
+ The control inputs (if enabled)
+ The initial conditions, g0, for all examples.
+ The generator states for all time.
+ The factors for all time.
+ The output distribution parameters (e.g. rates) for all time.
+
+ Args:
+ hps: The dictionary of hyperparameters.
+ datasets: A dictionary of data dictionaries. The dataset dict is simply a
+ name(string)-> data dictionary mapping (See top of lfads.py).
+ dataset_name: The name of the dataset to grab the factors -> rates
+ alignment matrices from. Only a concern with models trained on
+ multi-session data. By default, uses the first dataset in the data dict.
+ output_fname: The name prefix of the file in which to save the generated
+ samples.
+ """
+ if not output_fname:
+ output_fname = "model_runs_" + hps.kind
+ else:
+ output_fname = output_fname + "model_runs_" + hps.kind
+ if not dataset_name:
+ dataset_name = list(datasets.keys())[0]
+ else:
+ if dataset_name not in datasets.keys():
+ raise ValueError("Invalid dataset name '%s'."%(dataset_name))
+ model = build_model(hps, kind=hps.kind, datasets=datasets)
+ model.write_model_samples(dataset_name, output_fname)
+
+
+def write_model_parameters(hps, output_fname=None, datasets=None):
+ """Save all the model parameters
+
+ Save all the parameters to hps.lfads_save_dir.
+
+ Args:
+ hps: The dictionary of hyperparameters.
+ output_fname: The prefix of the file in which to save the generated
+ samples.
+ datasets: A dictionary of data dictionaries. The dataset dict is simply a
+ name(string)-> data dictionary mapping (See top of lfads.py).
+ """
+ if not output_fname:
+ output_fname = "model_params"
+ else:
+ output_fname = output_fname + "_model_params"
+ fname = os.path.join(hps.lfads_save_dir, output_fname)
+ print("Writing model parameters to: ", fname)
+ # save the optimizer params as well
+ model = build_model(hps, kind="write_model_params", datasets=datasets)
+ model_params = model.eval_model_parameters(use_nested=False,
+ include_strs="LFADS")
+ utils.write_data(fname, model_params, compression=None)
+ print("Done.")
+
+
+def clean_data_dict(data_dict):
+ """Add some key/value pairs to the data dict, if they are missing.
+ Args:
+ data_dict - dictionary containing data for LFADS
+ Returns:
+ data_dict with some keys filled in, if they are absent.
+ """
+
+ keys = ['train_truth', 'train_ext_input', 'valid_data',
+ 'valid_truth', 'valid_ext_input', 'valid_train']
+ for k in keys:
+ if k not in data_dict:
+ data_dict[k] = None
+
+ return data_dict
+
+
+def load_datasets(data_dir, data_filename_stem):
+ """Load the datasets from a specified directory.
+
+ Example files look like
+ >data_dir/my_dataset_first_day
+ >data_dir/my_dataset_second_day
+
+ If the filename stem my_dataset is in the directory, the read routine will
+ try to load it. The datasets dictionary will then look like
+ dataset['first_day'] -> (first day data dictionary)
+ dataset['second_day'] -> (second day data dictionary)
+
+ Args:
+ data_dir: The directory from which to load the datasets.
+ data_filename_stem: The stem of the filename for the datasets.
+
+ Returns:
+ datasets: a dataset dictionary, with one name->data dictionary pair for
+ each dataset file.
+ """
+ print("Reading data from ", data_dir)
+ datasets = utils.read_datasets(data_dir, data_filename_stem)
+ for k, data_dict in datasets.items():
+ datasets[k] = clean_data_dict(data_dict)
+
+ train_total_size = len(data_dict['train_data'])
+ if train_total_size == 0:
+ print("Did not load training set.")
+ else:
+ print("Found training set with number examples: ", train_total_size)
+
+ valid_total_size = len(data_dict['valid_data'])
+ if valid_total_size == 0:
+ print("Did not load validation set.")
+ else:
+ print("Found validation set with number examples: ", valid_total_size)
+
+ return datasets
+
+
+def main(_):
+ """Get this whole shindig off the ground."""
+ d = build_hyperparameter_dict(FLAGS)
+ hps = hps_dict_to_obj(d) # hyper parameters
+ kind = FLAGS.kind
+
+ # Read the data, if necessary.
+ train_set = valid_set = None
+ if kind in ["train", "posterior_sample_and_average", "posterior_push_mean",
+ "prior_sample", "write_model_params"]:
+ datasets = load_datasets(hps.data_dir, hps.data_filename_stem)
+ else:
+ raise ValueError('Kind {} is not supported.'.format(kind))
+
+ # infer the dataset names and dataset dimensions from the loaded files
+ hps.kind = kind     # needs to be added here, because it is not saved as a hyperparam
+ hps.dataset_names = []
+ hps.dataset_dims = {}
+ for key in datasets:
+ hps.dataset_names.append(key)
+ hps.dataset_dims[key] = datasets[key]['data_dim']
+
+ # also store down the dimensionality of the data
+ # - just pull from one set, required to be same for all sets
+ hps.num_steps = list(datasets.values())[0]['num_steps']
+ hps.ndatasets = len(hps.dataset_names)
+
+ if hps.num_steps_for_gen_ic > hps.num_steps:
+ hps.num_steps_for_gen_ic = hps.num_steps
+
+ # Build and run the model, for varying purposes.
+ config = tf.ConfigProto(allow_soft_placement=True,
+ log_device_placement=False)
+ if FLAGS.allow_gpu_growth:
+ config.gpu_options.allow_growth = True
+ sess = tf.Session(config=config)
+ with sess.as_default():
+ with tf.device(hps.device):
+ if kind == "train":
+ train(hps, datasets)
+ elif kind == "posterior_sample_and_average":
+ write_model_runs(hps, datasets, hps.output_filename_stem,
+ push_mean=False)
+ elif kind == "posterior_push_mean":
+ write_model_runs(hps, datasets, hps.output_filename_stem,
+ push_mean=True)
+ elif kind == "prior_sample":
+ write_model_samples(hps, datasets, hps.output_filename_stem)
+ elif kind == "write_model_params":
+ write_model_parameters(hps, hps.output_filename_stem, datasets)
+ else:
+ assert False, ("Kind %s is not implemented. " % kind)
+
+
+if __name__ == "__main__":
+ tf.app.run()
diff --git a/models/research/lfads/synth_data/generate_chaotic_rnn_data.py b/models/research/lfads/synth_data/generate_chaotic_rnn_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..3de72e58b2208eacf508e6048d3fb6d66bf2e167
--- /dev/null
+++ b/models/research/lfads/synth_data/generate_chaotic_rnn_data.py
@@ -0,0 +1,200 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+from __future__ import print_function
+
+import h5py
+import numpy as np
+import os
+import tensorflow as tf # used for flags here
+
+from utils import write_datasets
+from synthetic_data_utils import add_alignment_projections, generate_data
+from synthetic_data_utils import generate_rnn, get_train_n_valid_inds
+from synthetic_data_utils import nparray_and_transpose
+from synthetic_data_utils import spikify_data, gaussify_data, split_list_by_inds
+import matplotlib
+import matplotlib.pyplot as plt
+import scipy.signal
+
+matplotlib.rcParams['image.interpolation'] = 'nearest'
+DATA_DIR = "rnn_synth_data_v1.0"
+
+flags = tf.app.flags
+flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
+ "Directory for saving data.")
+flags.DEFINE_string("datafile_name", "thits_data",
+ "Name of data file for input case.")
+flags.DEFINE_string("noise_type", "poisson", "Noise type for data.")
+flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
+flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
+flags.DEFINE_integer("C", 100, "Number of conditions")
+flags.DEFINE_integer("N", 50, "Number of units for the RNN")
+flags.DEFINE_integer("S", 50, "Number of sampled units from RNN")
+flags.DEFINE_integer("npcs", 10, "Number of PCS for multi-session case.")
+flags.DEFINE_float("train_percentage", 4.0/5.0,
+ "Percentage of train vs validation trials")
+flags.DEFINE_integer("nreplications", 40,
+ "Number of noise replications of the same underlying rates.")
+flags.DEFINE_float("g", 1.5, "Complexity of dynamics")
+flags.DEFINE_float("x0_std", 1.0,
+ "Volume from which to pull initial conditions (affects diversity of dynamics.")
+flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
+flags.DEFINE_float("dt", 0.010, "Time bin")
+flags.DEFINE_float("input_magnitude", 20.0,
+ "For the input case, what is the value of the input?")
+flags.DEFINE_float("max_firing_rate", 30.0, "Map 1.0 of RNN to a spikes per second")
+FLAGS = flags.FLAGS
+
+
+# Note that with N small (the default above is 50), finite size effects
+# will have pretty dramatic effects on the dynamics of the random RNN.
+# If you want more complex dynamics, you'll have to run the script a
+# lot, or increase N (or g).
+
+# Getting hard vs. easy data can be a little stochastic, so we set the seed.
+
+# Pull out some commonly used parameters.
+# These are user parameters (configuration)
+rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
+T = FLAGS.T
+C = FLAGS.C
+N = FLAGS.N
+S = FLAGS.S
+input_magnitude = FLAGS.input_magnitude
+nreplications = FLAGS.nreplications
+E = nreplications * C # total number of trials
+# S is the number of measurements in each dataset, w/ each
+# dataset having a different set of observations.
+ndatasets = N // S # integer division; ok if rounded down
+train_percentage = FLAGS.train_percentage
+ntime_steps = int(T / FLAGS.dt)
+# End of user parameters
+
+rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate)
+
+# Check to make sure the RNN is the one we used in the paper.
+if N == 50:
+ assert abs(rnn['W'][0,0] - 0.06239899) < 1e-8, 'Error in random seed?'
+ rem_check = nreplications * train_percentage
+ assert abs(rem_check - int(rem_check)) < 1e-8, \
+ 'Train percentage * nreplications should be integral number.'
+
+
+# Initial condition generation, and condition label generation. This
+# happens outside of the dataset loop, so that all datasets have the
+# same conditions, which is similar to a neurophys setup.
+condition_number = 0
+x0s = []
+condition_labels = []
+for c in range(C):
+ x0 = FLAGS.x0_std * rng.randn(N, 1)
+ x0s.append(np.tile(x0, nreplications)) # replicate x0 nreplications times
+ # replicate the condition label nreplications times
+ for ns in range(nreplications):
+ condition_labels.append(condition_number)
+ condition_number += 1
+x0s = np.concatenate(x0s, axis=1)
+
+# Containers for storing data across datasets.
+datasets = {}
+for n in range(ndatasets):
+ print(n+1, " of ", ndatasets)
+
+ # First generate all firing rates. In the next loop, generate all
+ # replications; this allows the random state for rate generation to be
+ # independent of n_replications.
+ dataset_name = 'dataset_N' + str(N) + '_S' + str(S)
+ if S < N:
+ dataset_name += '_n' + str(n+1)
+
+ # Sample neuron subsets. The assumption is the PC axes of the RNN
+ # are not unit aligned, so sampling units is adequate to sample all
+ # the high-variance PCs.
+ P_sxn = np.eye(S,N)
+ for m in range(n):
+ P_sxn = np.roll(P_sxn, S, axis=1)
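+ # Added note: after rolling n times by S columns, P_sxn picks out the n-th
+ # block of S consecutive units of the N-unit RNN, so each dataset observes
+ # a disjoint subset of the network (assuming N is a multiple of S).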
+
+ if input_magnitude > 0.0:
+ # time of "hits" randomly chosen between [1/4 and 3/4] of total time
+ input_times = rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4)
+ else:
+ input_times = None
+
+ rates, x0s, inputs = \
+ generate_data(rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn,
+ input_magnitude=input_magnitude,
+ input_times=input_times)
+
+ if FLAGS.noise_type == "poisson":
+ noisy_data = spikify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
+ elif FLAGS.noise_type == "gaussian":
+ noisy_data = gaussify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
+ else:
+ raise ValueError("Only noise types supported are poisson or gaussian")
+
+ # split into train and validation sets
+ train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
+ nreplications)
+
+ # Split the data, inputs, labels and times into train vs. validation.
+ rates_train, rates_valid = \
+ split_list_by_inds(rates, train_inds, valid_inds)
+ noisy_data_train, noisy_data_valid = \
+ split_list_by_inds(noisy_data, train_inds, valid_inds)
+ input_train, inputs_valid = \
+ split_list_by_inds(inputs, train_inds, valid_inds)
+ condition_labels_train, condition_labels_valid = \
+ split_list_by_inds(condition_labels, train_inds, valid_inds)
+ input_times_train, input_times_valid = \
+ split_list_by_inds(input_times, train_inds, valid_inds)
+
+ # Turn rates, noisy_data, and input into numpy arrays.
+ rates_train = nparray_and_transpose(rates_train)
+ rates_valid = nparray_and_transpose(rates_valid)
+ noisy_data_train = nparray_and_transpose(noisy_data_train)
+ noisy_data_valid = nparray_and_transpose(noisy_data_valid)
+ input_train = nparray_and_transpose(input_train)
+ inputs_valid = nparray_and_transpose(inputs_valid)
+
+ # Note that we put these 'truth' rates and inputs into this
+ # structure, but the only data used by LFADS are the noisy
+ # data, e.g. spike trains. The rest is either for printing or posterity.
+ data = {'train_truth': rates_train,
+ 'valid_truth': rates_valid,
+ 'input_train_truth' : input_train,
+ 'input_valid_truth' : inputs_valid,
+ 'train_data' : noisy_data_train,
+ 'valid_data' : noisy_data_valid,
+ 'train_percentage' : train_percentage,
+ 'nreplications' : nreplications,
+ 'dt' : rnn['dt'],
+ 'input_magnitude' : input_magnitude,
+ 'input_times_train' : input_times_train,
+ 'input_times_valid' : input_times_valid,
+ 'P_sxn' : P_sxn,
+ 'condition_labels_train' : condition_labels_train,
+ 'condition_labels_valid' : condition_labels_valid,
+ 'conversion_factor': 1.0 / rnn['conversion_factor']}
+ datasets[dataset_name] = data
+
+if S < N:
+ # Note that this isn't necessary for this synthetic example, but
+ # it's useful to see how the input factor matrices were initialized
+ # for actual neurophysiology data.
+ datasets = add_alignment_projections(datasets, npcs=FLAGS.npcs)
+
+# Write out the datasets.
+write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
diff --git a/models/research/lfads/synth_data/generate_itb_data.py b/models/research/lfads/synth_data/generate_itb_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..66bc45d02e962915eb4be09d41da3162763ad40c
--- /dev/null
+++ b/models/research/lfads/synth_data/generate_itb_data.py
@@ -0,0 +1,209 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+from __future__ import print_function
+
+import h5py
+import numpy as np
+import os
+from six.moves import xrange
+import tensorflow as tf
+
+from utils import write_datasets
+from synthetic_data_utils import normalize_rates
+from synthetic_data_utils import get_train_n_valid_inds, nparray_and_transpose
+from synthetic_data_utils import spikify_data, split_list_by_inds
+
+DATA_DIR = "rnn_synth_data_v1.0"
+
+flags = tf.app.flags
+flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
+ "Directory for saving data.")
+flags.DEFINE_string("datafile_name", "itb_rnn",
+ "Name of data file for input case.")
+flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
+flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
+flags.DEFINE_integer("C", 800, "Number of conditions")
+flags.DEFINE_integer("N", 50, "Number of units for the RNN")
+flags.DEFINE_float("train_percentage", 4.0/5.0,
+ "Percentage of train vs validation trials")
+flags.DEFINE_integer("nreplications", 5,
+ "Number of spikifications of the same underlying rates.")
+flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
+flags.DEFINE_float("dt", 0.010, "Time bin")
+flags.DEFINE_float("max_firing_rate", 30.0,
+ "Map 1.0 of RNN to a spikes per second")
+flags.DEFINE_float("u_std", 0.25,
+ "Std dev of input to integration to bound model")
+flags.DEFINE_string("checkpoint_path", "SAMPLE_CHECKPOINT",
+ """Path to directory with checkpoints of model
+ trained on integration to bound task. Currently this
+ is a placeholder which tells the code to grab the
+ checkpoint that is provided with the code
+ (in /trained_itb/..). If you have your own checkpoint
+ you would like to restore, you would point it to
+ that path.""")
+FLAGS = flags.FLAGS
+
+
+class IntegrationToBoundModel:
+ def __init__(self, N):
+ scale = 0.8 / float(N**0.5)
+ self.N = N
+ self.Wh_nxn = tf.Variable(tf.random_normal([N, N], stddev=scale))
+ self.b_1xn = tf.Variable(tf.zeros([1, N]))
+ self.Bu_1xn = tf.Variable(tf.zeros([1, N]))
+ self.Wro_nxo = tf.Variable(tf.random_normal([N, 1], stddev=scale))
+ self.bro_o = tf.Variable(tf.zeros([1]))
+
+ def call(self, h_tm1_bxn, u_bx1):
+ act_t_bxn = tf.matmul(h_tm1_bxn, self.Wh_nxn) + self.b_1xn + u_bx1 * self.Bu_1xn
+ h_t_bxn = tf.nn.tanh(act_t_bxn)
+ z_t = tf.nn.xw_plus_b(h_t_bxn, self.Wro_nxo, self.bro_o)
+ return z_t, h_t_bxn
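+# Added note: name suffixes encode tensor shapes, e.g. h_tm1_bxn is the
+# previous hidden state with shape [batch, N]; z_t is the scalar readout that
+# the trained network is expected to keep close to the integrated, bounded input.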
+
+def get_data_batch(batch_size, T, rng, u_std):
+ u_bxt = rng.randn(batch_size, T) * u_std
+ running_sum_b = np.zeros([batch_size])
+ labels_bxt = np.zeros([batch_size, T])
+ for t in xrange(T):
+ running_sum_b += u_bxt[:, t]
+ labels_bxt[:, t] += running_sum_b
+ labels_bxt = np.clip(labels_bxt, -1, 1)
+ return u_bxt, labels_bxt
+
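+# Quick sanity check (illustrative values, not part of the upstream file):
+# the labels are the running sum of the inputs, clipped to the bound [-1, 1].
+# For a single trial with inputs 0.4, 0.4, 0.4, -0.1 the running sums are
+# 0.4, 0.8, 1.2, 1.1, which clip to labels 0.4, 0.8, 1.0, 1.0.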
+
+rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
+u_rng = np.random.RandomState(seed=FLAGS.synth_data_seed+1)
+T = FLAGS.T
+C = FLAGS.C
+N = FLAGS.N # must be same N as in trained model (provided example is N = 50)
+nreplications = FLAGS.nreplications
+E = nreplications * C # total number of trials
+train_percentage = FLAGS.train_percentage
+ntimesteps = int(T / FLAGS.dt)
+batch_size = 1 # gives one example per ntrial
+
+model = IntegrationToBoundModel(N)
+inputs_ph_t = [tf.placeholder(tf.float32,
+ shape=[None, 1]) for _ in range(ntimesteps)]
+state = tf.zeros([batch_size, N])
+saver = tf.train.Saver()
+
+P_nxn = rng.randn(N,N) / np.sqrt(N) # random projections
+
+# unroll RNN for T timesteps
+outputs_t = []
+states_t = []
+
+for inp in inputs_ph_t:
+ output, state = model.call(state, inp)
+ outputs_t.append(output)
+ states_t.append(state)
+
+with tf.Session() as sess:
+ # restore the latest model ckpt
+ if FLAGS.checkpoint_path == "SAMPLE_CHECKPOINT":
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ model_checkpoint_path = os.path.join(dir_path, "trained_itb/model-65000")
+ else:
+ model_checkpoint_path = FLAGS.checkpoint_path
+ try:
+ saver.restore(sess, model_checkpoint_path)
+ print ('Model restored from', model_checkpoint_path)
+ except:
+ assert False, ("No checkpoints to restore from, is the path %s correct?"
+ %model_checkpoint_path)
+
+ # generate data for trials
+ data_e = []
+ u_e = []
+ outs_e = []
+ for c in range(C):
+ u_1xt, outs_1xt = get_data_batch(batch_size, ntimesteps, u_rng, FLAGS.u_std)
+
+ feed_dict = {}
+ for t in xrange(ntimesteps):
+ feed_dict[inputs_ph_t[t]] = np.reshape(u_1xt[:,t], (batch_size,-1))
+
+ states_t_bxn, outputs_t_bxn = sess.run([states_t, outputs_t],
+ feed_dict=feed_dict)
+ states_nxt = np.transpose(np.squeeze(np.asarray(states_t_bxn)))
+ outputs_t_bxn = np.squeeze(np.asarray(outputs_t_bxn))
+ r_sxt = np.dot(P_nxn, states_nxt)
+
+ for s in xrange(nreplications):
+ data_e.append(r_sxt)
+ u_e.append(u_1xt)
+ outs_e.append(outputs_t_bxn)
+
+ truth_data_e = normalize_rates(data_e, E, N)
+
+spiking_data_e = spikify_data(truth_data_e, rng, dt=FLAGS.dt,
+ max_firing_rate=FLAGS.max_firing_rate)
+train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
+ nreplications)
+
+data_train_truth, data_valid_truth = split_list_by_inds(truth_data_e,
+ train_inds,
+ valid_inds)
+data_train_spiking, data_valid_spiking = split_list_by_inds(spiking_data_e,
+ train_inds,
+ valid_inds)
+
+data_train_truth = nparray_and_transpose(data_train_truth)
+data_valid_truth = nparray_and_transpose(data_valid_truth)
+data_train_spiking = nparray_and_transpose(data_train_spiking)
+data_valid_spiking = nparray_and_transpose(data_valid_spiking)
+
+# save down the inputs used to generate this data
+train_inputs_u, valid_inputs_u = split_list_by_inds(u_e,
+ train_inds,
+ valid_inds)
+train_inputs_u = nparray_and_transpose(train_inputs_u)
+valid_inputs_u = nparray_and_transpose(valid_inputs_u)
+
+# save down the network outputs (may be useful later)
+train_outputs_u, valid_outputs_u = split_list_by_inds(outs_e,
+ train_inds,
+ valid_inds)
+train_outputs_u = np.array(train_outputs_u)
+valid_outputs_u = np.array(valid_outputs_u)
+
+
+data = { 'train_truth': data_train_truth,
+ 'valid_truth': data_valid_truth,
+ 'train_data' : data_train_spiking,
+ 'valid_data' : data_valid_spiking,
+ 'train_percentage' : train_percentage,
+ 'nreplications' : nreplications,
+ 'dt' : FLAGS.dt,
+ 'u_std' : FLAGS.u_std,
+ 'max_firing_rate': FLAGS.max_firing_rate,
+ 'train_inputs_u': train_inputs_u,
+ 'valid_inputs_u': valid_inputs_u,
+ 'train_outputs_u': train_outputs_u,
+ 'valid_outputs_u': valid_outputs_u,
+ 'conversion_factor' : FLAGS.max_firing_rate/(1.0/FLAGS.dt) }
+
+# just one dataset here
+datasets = {}
+dataset_name = 'dataset_N' + str(N)
+datasets[dataset_name] = data
+
+# write out the dataset
+write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
+print ('Saved to ', os.path.join(FLAGS.save_dir,
+ FLAGS.datafile_name + '_' + dataset_name))
diff --git a/models/research/lfads/synth_data/generate_labeled_rnn_data.py b/models/research/lfads/synth_data/generate_labeled_rnn_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..0695585486534428c77e328e7ee1de755292d6c0
--- /dev/null
+++ b/models/research/lfads/synth_data/generate_labeled_rnn_data.py
@@ -0,0 +1,147 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+from __future__ import print_function
+
+import os
+import h5py
+import numpy as np
+from six.moves import xrange
+
+from synthetic_data_utils import generate_data, generate_rnn
+from synthetic_data_utils import get_train_n_valid_inds
+from synthetic_data_utils import nparray_and_transpose
+from synthetic_data_utils import spikify_data, split_list_by_inds
+import tensorflow as tf
+from utils import write_datasets
+
+DATA_DIR = "rnn_synth_data_v1.0"
+
+flags = tf.app.flags
+flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
+ "Directory for saving data.")
+flags.DEFINE_string("datafile_name", "conditioned_rnn_data",
+ "Name of data file for input case.")
+flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
+flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
+flags.DEFINE_integer("C", 400, "Number of conditions")
+flags.DEFINE_integer("N", 50, "Number of units for the RNN")
+flags.DEFINE_float("train_percentage", 4.0/5.0,
+ "Percentage of train vs validation trials")
+flags.DEFINE_integer("nreplications", 10,
+ "Number of spikifications of the same underlying rates.")
+flags.DEFINE_float("g", 1.5, "Complexity of dynamics")
+flags.DEFINE_float("x0_std", 1.0,
+ "Volume from which to pull initial conditions (affects diversity of dynamics.")
+flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
+flags.DEFINE_float("dt", 0.010, "Time bin")
+flags.DEFINE_float("max_firing_rate", 30.0, "Map 1.0 of RNN to a spikes per second")
+FLAGS = flags.FLAGS
+
+rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
+rnn_rngs = [np.random.RandomState(seed=FLAGS.synth_data_seed+1),
+ np.random.RandomState(seed=FLAGS.synth_data_seed+2)]
+T = FLAGS.T
+C = FLAGS.C
+N = FLAGS.N
+nreplications = FLAGS.nreplications
+E = nreplications * C
+train_percentage = FLAGS.train_percentage
+ntimesteps = int(T / FLAGS.dt)
+
+rnn_a = generate_rnn(rnn_rngs[0], N, FLAGS.g, FLAGS.tau, FLAGS.dt,
+ FLAGS.max_firing_rate)
+rnn_b = generate_rnn(rnn_rngs[1], N, FLAGS.g, FLAGS.tau, FLAGS.dt,
+ FLAGS.max_firing_rate)
+rnns = [rnn_a, rnn_b]
+
+# pick which RNN is used on each trial
+rnn_to_use = rng.randint(2, size=E)
+ext_input = np.repeat(np.expand_dims(rnn_to_use, axis=1), ntimesteps, axis=1)
+ext_input = np.expand_dims(ext_input, axis=2) # these are "a's" in the paper
+
+x0s = []
+condition_labels = []
+condition_number = 0
+for c in range(C):
+ x0 = FLAGS.x0_std * rng.randn(N, 1)
+ x0s.append(np.tile(x0, nreplications))
+ for ns in range(nreplications):
+ condition_labels.append(condition_number)
+ condition_number += 1
+x0s = np.concatenate(x0s, axis=1)
+
+P_nxn = rng.randn(N, N) / np.sqrt(N)
+
+# generate trials for both RNNs
+rates_a, x0s_a, _ = generate_data(rnn_a, T=T, E=E, x0s=x0s, P_sxn=P_nxn,
+ input_magnitude=0.0, input_times=None)
+spikes_a = spikify_data(rates_a, rng, rnn_a['dt'], rnn_a['max_firing_rate'])
+
+rates_b, x0s_b, _ = generate_data(rnn_b, T=T, E=E, x0s=x0s, P_sxn=P_nxn,
+ input_magnitude=0.0, input_times=None)
+spikes_b = spikify_data(rates_b, rng, rnn_b['dt'], rnn_b['max_firing_rate'])
+
+# not the best way to do this but E is small enough
+rates = []
+spikes = []
+for trial in xrange(E):
+ if rnn_to_use[trial] == 0:
+ rates.append(rates_a[trial])
+ spikes.append(spikes_a[trial])
+ else:
+ rates.append(rates_b[trial])
+ spikes.append(spikes_b[trial])
+
+# split into train and validation sets
+train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
+ nreplications)
+
+rates_train, rates_valid = split_list_by_inds(rates, train_inds, valid_inds)
+spikes_train, spikes_valid = split_list_by_inds(spikes, train_inds, valid_inds)
+condition_labels_train, condition_labels_valid = split_list_by_inds(
+ condition_labels, train_inds, valid_inds)
+ext_input_train, ext_input_valid = split_list_by_inds(
+ ext_input, train_inds, valid_inds)
+
+rates_train = nparray_and_transpose(rates_train)
+rates_valid = nparray_and_transpose(rates_valid)
+spikes_train = nparray_and_transpose(spikes_train)
+spikes_valid = nparray_and_transpose(spikes_valid)
+
+# add train_ext_input and valid_ext_input
+data = {'train_truth': rates_train,
+ 'valid_truth': rates_valid,
+ 'train_data' : spikes_train,
+ 'valid_data' : spikes_valid,
+ 'train_ext_input' : np.array(ext_input_train),
+ 'valid_ext_input': np.array(ext_input_valid),
+ 'train_percentage' : train_percentage,
+ 'nreplications' : nreplications,
+ 'dt' : FLAGS.dt,
+ 'P_sxn' : P_nxn,
+ 'condition_labels_train' : condition_labels_train,
+ 'condition_labels_valid' : condition_labels_valid,
+ 'conversion_factor': 1.0 / rnn_a['conversion_factor']}
+
+# just one dataset here
+datasets = {}
+dataset_name = 'dataset_N' + str(N)
+datasets[dataset_name] = data
+
+# write out the dataset
+write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
+print ('Saved to ', os.path.join(FLAGS.save_dir,
+ FLAGS.datafile_name + '_' + dataset_name))
diff --git a/models/research/lfads/synth_data/run_generate_synth_data.sh b/models/research/lfads/synth_data/run_generate_synth_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9ebc8ce2e5eec1e21fd839db18f247b38ebfde38
--- /dev/null
+++ b/models/research/lfads/synth_data/run_generate_synth_data.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+SYNTH_PATH=/tmp/rnn_synth_data_v1.0/
+
+ echo "Generating chaotic rnn data with no input pulses (g=1.5) with spiking noise"
+ python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson'
+
+echo "Generating chaotic rnn data with no input pulses (g=1.5) with Gaussian noise"
+python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=gaussian_chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='gaussian'
+
+ echo "Generating chaotic rnn data with input pulses (g=1.5)"
+ python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g1p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson'
+
+ echo "Generating chaotic rnn data with input pulses (g=2.5)"
+ python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g2p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=2.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson'
+
+ echo "Generate the multi-session RNN data (no multi-session synth example in paper)"
+ python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_multisession --synth_data_seed=5 --T=1.0 --C=150 --N=100 --S=20 --npcs=10 --train_percentage=0.8 --nreplications=40 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson'
+
+ echo "Generating Integration-to-bound RNN data"
+ python generate_itb_data.py --save_dir=$SYNTH_PATH --datafile_name=itb_rnn --u_std=0.25 --checkpoint_path=SAMPLE_CHECKPOINT --synth_data_seed=5 --T=1.0 --C=800 --N=50 --train_percentage=0.8 --nreplications=5 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
+
+ echo "Generating chaotic rnn data with external input labels (no external input labels example in paper)"
+ python generate_labeled_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnns_labeled --synth_data_seed=5 --T=1.0 --C=400 --N=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
diff --git a/models/research/lfads/synth_data/synthetic_data_utils.py b/models/research/lfads/synth_data/synthetic_data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc264ee49fdc7fbb53f17d52ca4ced64addefb27
--- /dev/null
+++ b/models/research/lfads/synth_data/synthetic_data_utils.py
@@ -0,0 +1,348 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+from __future__ import print_function
+
+import h5py
+import numpy as np
+import os
+
+from utils import write_datasets
+import matplotlib
+import matplotlib.pyplot as plt
+import scipy.signal
+
+
+def generate_rnn(rng, N, g, tau, dt, max_firing_rate):
+ """Create a (vanilla) RNN with a bunch of hyper parameters for generating
+chaotic data.
+ Args:
+ rng: numpy random number generator
+ N: number of hidden units
+ g: scaling of recurrent weight matrix in g W, with W ~ N(0,1/N)
+ tau: time scale of individual unit dynamics
+ dt: time step for equation updates
+ max_firing_rate: how to rescale the -1,1 firing rates
+ Returns:
+ the dictionary of these parameters, plus some others.
+"""
+ rnn = {}
+ rnn['N'] = N
+ rnn['W'] = rng.randn(N,N)/np.sqrt(N)
+ rnn['Bin'] = rng.randn(N)/np.sqrt(1.0)
+ rnn['Bin2'] = rng.randn(N)/np.sqrt(1.0)
+ rnn['b'] = np.zeros(N)
+ rnn['g'] = g
+ rnn['tau'] = tau
+ rnn['dt'] = dt
+ rnn['max_firing_rate'] = max_firing_rate
+ mfr = rnn['max_firing_rate'] # spikes / sec
+ nbins_per_sec = 1.0/rnn['dt'] # bins / sec
+ # Used for plotting in LFADS
+ rnn['conversion_factor'] = mfr / nbins_per_sec # spikes / bin
+ return rnn
+
+
+def generate_data(rnn, T, E, x0s=None, P_sxn=None, input_magnitude=0.0,
+ input_times=None):
+ """ Generates data from an randomly initialized RNN.
+ Args:
+ rnn: the rnn
+ T: Time in seconds to run (divided by rnn['dt'] to get steps, rounded down.
+ E: total number of examples
+ S: number of samples (subsampling N)
+ Returns:
+ A list of length E of NxT tensors of the network being run.
+ """
+ N = rnn['N']
+ def run_rnn(rnn, x0, ntime_steps, input_time=None):
+ rs = np.zeros([N,ntime_steps])
+ x_tm1 = x0
+ r_tm1 = np.tanh(x0)
+ tau = rnn['tau']
+ dt = rnn['dt']
+ alpha = (1.0-dt/tau)
+ W = dt/tau*rnn['W']*rnn['g']
+ Bin = dt/tau*rnn['Bin']
+ Bin2 = dt/tau*rnn['Bin2']
+ b = dt/tau*rnn['b']
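+ # Added note: the loop below is a forward-Euler discretization of
+ #   tau * dx/dt = -x + g*W*tanh(x) + Bin*u + b,
+ # i.e. x_t = (1 - dt/tau)*x_{t-1} + (dt/tau)*(g*W*r_{t-1} + Bin*u_t + b).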
+
+ us = np.zeros([1, ntime_steps])
+ for t in range(ntime_steps):
+ x_t = alpha*x_tm1 + np.dot(W,r_tm1) + b
+ if input_time is not None and t == input_time:
+ us[0,t] = input_magnitude
+ x_t += Bin * us[0,t] # DCS is this what was used?
+ r_t = np.tanh(x_t)
+ x_tm1 = x_t
+ r_tm1 = r_t
+ rs[:,t] = r_t
+ return rs, us
+
+ if P_sxn is None:
+ P_sxn = np.eye(N)
+ ntime_steps = int(T / rnn['dt'])
+ data_e = []
+ inputs_e = []
+ for e in range(E):
+ input_time = input_times[e] if input_times is not None else None
+ r_nxt, u_uxt = run_rnn(rnn, x0s[:,e], ntime_steps, input_time)
+ r_sxt = np.dot(P_sxn, r_nxt)
+ inputs_e.append(u_uxt)
+ data_e.append(r_sxt)
+
+ S = P_sxn.shape[0]
+ data_e = normalize_rates(data_e, E, S)
+
+ return data_e, x0s, inputs_e
+
+
+def normalize_rates(data_e, E, S):
+ # Normalization, made more complex because of the P matrices.
+ # Normalize by min and max in each channel. This normalization will
+ # cause offset differences between identical rnn runs, but different
+ # t hits.
+ for e in range(E):
+ r_sxt = data_e[e]
+ for i in range(S):
+ rmin = np.min(r_sxt[i,:])
+ rmax = np.max(r_sxt[i,:])
+ assert rmax - rmin != 0, 'Something wrong'
+ r_sxt[i,:] = (r_sxt[i,:] - rmin)/(rmax-rmin)
+ data_e[e] = r_sxt
+ return data_e
+
+
+def spikify_data(data_e, rng, dt=1.0, max_firing_rate=100):
+ """ Apply spikes to a continuous dataset whose values are between 0.0 and 1.0
+ Args:
+ data_e: nexamples length list of NxT trials
+ dt: how often the data are sampled
+ max_firing_rate: the firing rate that is associated with a value of 1.0
+ Returns:
+ spikified_e: a list of length E of the data represented as spikes,
+ sampled from the underlying poisson process.
+ """
+
+ E = len(data_e)
+ spikes_e = []
+ for e in range(E):
+ data = data_e[e]
+ N,T = data.shape
+ data_s = np.zeros([N,T]).astype(np.int)
+ for n in range(N):
+ f = data[n,:]
+ s = rng.poisson(f*max_firing_rate*dt, size=T)
+ data_s[n,:] = s
+ spikes_e.append(data_s)
+
+ return spikes_e
+
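+# Added note (illustrative numbers): the Poisson mean per bin is
+# rate * max_firing_rate * dt, so a normalized rate of 1.0 with
+# max_firing_rate=30 and dt=0.01 gives 0.3 expected spikes per bin,
+# i.e. roughly 30 spikes over a 1-second (100-bin) trial.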
+
+def gaussify_data(data_e, rng, dt=1.0, max_firing_rate=100):
+ """ Apply gaussian noise to a continuous dataset whose values are between
+ 0.0 and 1.0
+
+ Args:
+ data_e: nexamples length list of NxT trials
+ dt: how often the data are sampled
+ max_firing_rate: the firing rate that is associated with a value of 1.0
+ Returns:
+ gauss_e: a list of length E of the data with noise.
+ """
+
+ E = len(data_e)
+ mfr = max_firing_rate
+ gauss_e = []
+ for e in range(E):
+ data = data_e[e]
+ N,T = data.shape
+ noisy_data = data * mfr + rng.randn(N,T) * (5.0*mfr) * np.sqrt(dt)  # use the passed-in rng for reproducibility
+ gauss_e.append(noisy_data)
+
+ return gauss_e
+
+
+
+def get_train_n_valid_inds(num_trials, train_fraction, nreplications):
+ """Split the numbers between 0 and num_trials-1 into two portions for
+ training and validation, based on the train fraction.
+ Args:
+ num_trials: the number of trials
+ train_fraction: (e.g. .80)
+ nreplications: the number of spiking trials per initial condition
+ Returns:
+ a 2-tuple of two lists: the training indices and validation indices
+ """
+ train_inds = []
+ valid_inds = []
+ for i in range(num_trials):
+ # This line divides up the trials so that within one initial condition,
+ # the randomness of spikifying the condition is shared among both
+ # training and validation data splits.
+ if (i % nreplications)+1 > train_fraction * nreplications:
+ valid_inds.append(i)
+ else:
+ train_inds.append(i)
+
+ return train_inds, valid_inds
+
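+# Worked example (assumed values, not part of the upstream file): with
+# num_trials=10, train_fraction=0.8 and nreplications=5, an index i goes to
+# validation when (i % 5) + 1 > 4, i.e. i in {4, 9}; the remaining eight
+# trials are training, so every initial condition appears in both splits.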
+
+def split_list_by_inds(data, inds1, inds2):
+ """Take the data, a list, and split it up based on the indices in inds1 and
+ inds2.
+ Args:
+ data: the list of data to split
+ inds1, the first list of indices
+ inds2, the second list of indices
+ Returns: a 2-tuple of two lists.
+ """
+ if data is None or len(data) == 0:
+ return [], []
+ else:
+ dout1 = [data[i] for i in inds1]
+ dout2 = [data[i] for i in inds2]
+ return dout1, dout2
+
+
+def nparray_and_transpose(data_a_b_c):
+ """Convert the list of items in data to a numpy array, and transpose it
+ Args:
+ data_a_b_c: a nested list of length a, with sublists of length
+ b, each with sublist length c.
+ Returns:
+ a numpy 3-tensor with dimensions a x c x b
+"""
+ data_axbxc = np.array([datum_b_c for datum_b_c in data_a_b_c])
+ data_axcxb = np.transpose(data_axbxc, axes=[0,2,1])
+ return data_axcxb
+
+
+def add_alignment_projections(datasets, npcs, ntime=None, nsamples=None):
+ """Create a matrix that aligns the datasets a bit, under
+ the assumption that each dataset is observing the same underlying dynamical
+ system.
+
+ Args:
+ datasets: The dictionary of dataset structures.
+ npcs: The number of pcs for each, basically like lfads factors.
+ nsamples (optional): Number of samples to take for each dataset.
+ ntime (optional): Number of time steps to take in each sample.
+
+ Returns:
+ The dataset structures, with the field alignment_matrix_cxf added.
+ This is # channels x npcs dimension
+"""
+ nchannels_all = 0
+ channel_idxs = {}
+ conditions_all = {}
+ nconditions_all = 0
+ for name, dataset in datasets.items():
+ cidxs = np.where(dataset['P_sxn'])[1] # non-zero entries in columns
+ channel_idxs[name] = [cidxs[0], cidxs[-1]+1]
+ nchannels_all += cidxs[-1]+1 - cidxs[0]
+ conditions_all[name] = np.unique(dataset['condition_labels_train'])
+
+ all_conditions_list = \
+ np.unique(np.ndarray.flatten(np.array(list(conditions_all.values()))))  # list() for Python 3
+ nconditions_all = all_conditions_list.shape[0]
+
+ if ntime is None:
+ ntime = dataset['train_data'].shape[1]
+ if nsamples is None:
+ nsamples = dataset['train_data'].shape[0]
+
+ # In the data workup in the paper, Chethan did intra condition
+ # averaging, so let's do that here.
+ avg_data_all = {}
+ for name, conditions in conditions_all.items():
+ dataset = datasets[name]
+ avg_data_all[name] = {}
+ for cname in conditions:
+ td_idxs = np.argwhere(np.array(dataset['condition_labels_train'])==cname)
+ data = np.squeeze(dataset['train_data'][td_idxs,:,:], axis=1)
+ avg_data = np.mean(data, axis=0)
+ avg_data_all[name][cname] = avg_data
+
+ # Visualize this in the morning.
+ all_data_nxtc = np.zeros([nchannels_all, ntime * nconditions_all])
+ for name, dataset in datasets.items():
+ cidx_s = channel_idxs[name][0]
+ cidx_f = channel_idxs[name][1]
+ for cname in conditions_all[name]:
+ cidxs = np.argwhere(all_conditions_list == cname)
+ if cidxs.shape[0] > 0:
+ cidx = cidxs[0][0]
+ all_tidxs = np.arange(0, ntime+1) + cidx*ntime
+ all_data_nxtc[cidx_s:cidx_f, all_tidxs[0]:all_tidxs[-1]] = \
+ avg_data_all[name][cname].T
+
+ # A bit of filtering. We don't care about spectral properties, or
+ # filtering artifacts, simply correlate time steps a bit.
+ filt_len = 6
+ bc_filt = np.ones([filt_len])/float(filt_len)
+ for c in range(nchannels_all):
+ all_data_nxtc[c,:] = scipy.signal.filtfilt(bc_filt, [1.0], all_data_nxtc[c,:])
+
+ # Compute the PCs.
+ all_data_mean_nx1 = np.mean(all_data_nxtc, axis=1, keepdims=True)
+ all_data_zm_nxtc = all_data_nxtc - all_data_mean_nx1
+ corr_mat_nxn = np.dot(all_data_zm_nxtc, all_data_zm_nxtc.T)
+ evals_n, evecs_nxn = np.linalg.eigh(corr_mat_nxn)
+ sidxs = np.flipud(np.argsort(evals_n)) # sort such that 0th is highest
+ evals_n = evals_n[sidxs]
+ evecs_nxn = evecs_nxn[:,sidxs]
+
+ # Project all the channels data onto the low-D PCA basis, where
+ # low-d is the npcs parameter.
+ all_data_pca_pxtc = np.dot(evecs_nxn[:, 0:npcs].T, all_data_zm_nxtc)
+
+ # Now for each dataset, we regress the channel data onto the top
+ # pcs, and this will be our alignment matrix for that dataset.
+ # |B - A*W|^2
+ for name, dataset in datasets.items():
+ cidx_s = channel_idxs[name][0]
+ cidx_f = channel_idxs[name][1]
+ all_data_zm_chxtc = all_data_zm_nxtc[cidx_s:cidx_f,:] # ch for channel
+ W_chxp, _, _, _ = \
+ np.linalg.lstsq(all_data_zm_chxtc.T, all_data_pca_pxtc.T)
+ dataset['alignment_matrix_cxf'] = W_chxp
+ alignment_bias_cx1 = all_data_mean_nx1[cidx_s:cidx_f]
+ dataset['alignment_bias_c'] = np.squeeze(alignment_bias_cx1, axis=1)
+
+ do_debug_plot = False
+ if do_debug_plot:
+ pc_vecs = evecs_nxn[:,0:npcs]
+ ntoplot = 400
+
+ plt.figure()
+ plt.plot(np.log10(evals_n), '-x')
+ plt.figure()
+ plt.subplot(311)
+ plt.imshow(all_data_pca_pxtc)
+ plt.colorbar()
+
+ plt.subplot(312)
+ plt.imshow(np.dot(W_chxp.T, all_data_zm_chxtc))
+ plt.colorbar()
+
+ plt.subplot(313)
+ plt.imshow(np.dot(all_data_zm_chxtc.T, W_chxp).T - all_data_pca_pxtc)
+ plt.colorbar()
+
+ import pdb
+ pdb.set_trace()
+
+ return datasets
diff --git a/models/research/lfads/synth_data/trained_itb/model-65000.data-00000-of-00001 b/models/research/lfads/synth_data/trained_itb/model-65000.data-00000-of-00001
new file mode 100644
index 0000000000000000000000000000000000000000..9459a2a1b72f56dc16b3eca210911f14081e7fd5
Binary files /dev/null and b/models/research/lfads/synth_data/trained_itb/model-65000.data-00000-of-00001 differ
diff --git a/models/research/lfads/synth_data/trained_itb/model-65000.index b/models/research/lfads/synth_data/trained_itb/model-65000.index
new file mode 100644
index 0000000000000000000000000000000000000000..dd9c793acf8dc79e07833d1c0edc8a2fa86d806a
Binary files /dev/null and b/models/research/lfads/synth_data/trained_itb/model-65000.index differ
diff --git a/models/research/lfads/synth_data/trained_itb/model-65000.meta b/models/research/lfads/synth_data/trained_itb/model-65000.meta
new file mode 100644
index 0000000000000000000000000000000000000000..2b97380f5373b6ffc29629a6653cb229df629e27
--- /dev/null
+++ b/models/research/lfads/synth_data/trained_itb/model-65000.meta
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc43cc94549c645387862920487167dfa19e8ea26e0978ce03f286fa96f7a462
+size 1053549
diff --git a/models/research/lfads/utils.py b/models/research/lfads/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e64825ffc1d423de1d9fe85bc1c00a19e5f4ad7e
--- /dev/null
+++ b/models/research/lfads/utils.py
@@ -0,0 +1,367 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+from __future__ import print_function
+
+import os
+import h5py
+import json
+
+import numpy as np
+import tensorflow as tf
+
+
+def log_sum_exp(x_k):
+ """Computes log \sum exp in a numerically stable way.
+ log ( sum_i exp(x_i) )
+ log ( sum_i exp(x_i - m + m) ), with m = max(x_i)
+ log ( sum_i exp(x_i - m)*exp(m) )
+ log ( sum_i exp(x_i - m) ) + m
+
+ Args:
+ x_k: k-dimensional list of arguments to log_sum_exp.
+
+ Returns:
+ log_sum_exp of the arguments.
+ """
+ m = tf.reduce_max(x_k)
+ x1_k = x_k - m
+ u_k = tf.exp(x1_k)
+ z = tf.reduce_sum(u_k)
+ return tf.log(z) + m
+
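+# Minimal usage sketch (illustrative, TF1 graph mode; the values are assumptions):
+#   x = tf.constant([1000.0, 1000.0])
+#   # A naive tf.log(tf.reduce_sum(tf.exp(x))) overflows to inf, while
+#   # log_sum_exp(x) evaluates to ~1000.693 (= 1000 + log 2) by subtracting the max.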
+
+def linear(x, out_size, do_bias=True, alpha=1.0, identity_if_possible=False,
+ normalized=False, name=None, collections=None):
+ """Linear (affine) transformation, y = x W + b, for a variety of
+ configurations.
+
+ Args:
+ x: The input tensor to the transformation.
+ out_size: The integer size of non-batch output dimension.
+ do_bias (optional): Add a learnable bias vector to the operation.
+ alpha (optional): A multiplicative scaling for the weight initialization
+ of the matrix, in the form \alpha * 1/\sqrt{x.shape[1]}.
+ identity_if_possible (optional): just return identity,
+ if x.shape[1] == out_size.
+ normalized (optional): Option to divide out by the norms of the rows of W.
+ name (optional): The name prefix to add to variables.
+ collections (optional): List of additional collections. (Placed in
+ tf.GraphKeys.GLOBAL_VARIABLES already, so no need for that.)
+
+ Returns:
+ In the equation, y = x W + b, returns the tensorflow op that yields y.
+ """
+ in_size = int(x.get_shape()[1]) # from Dimension(10) -> 10
+ stddev = alpha/np.sqrt(float(in_size))
+ mat_init = tf.random_normal_initializer(0.0, stddev)
+ wname = (name + "/W") if name else "/W"
+
+ if identity_if_possible and in_size == out_size:
+ # Sometimes linear layers are nothing more than size adapters.
+ return tf.identity(x, name=(wname+'_ident'))
+
+ W,b = init_linear(in_size, out_size, do_bias=do_bias, alpha=alpha,
+ normalized=normalized, name=name, collections=collections)
+
+ if do_bias:
+ return tf.matmul(x, W) + b
+ else:
+ return tf.matmul(x, W)
+
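+# Usage sketch (illustrative, TF1 graph mode; names and shapes are assumptions):
+#   x = tf.placeholder(tf.float32, [None, 128])
+#   y = linear(x, 64, name="readout")   # y has shape [None, 64]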
+
+def init_linear(in_size, out_size, do_bias=True, mat_init_value=None,
+ bias_init_value=None, alpha=1.0, identity_if_possible=False,
+ normalized=False, name=None, collections=None, trainable=True):
+ """Linear (affine) transformation, y = x W + b, for a variety of
+ configurations.
+
+ Args:
+ in_size: The integer size of the non-batch input dimension. [(x),y]
+ out_size: The integer size of non-batch output dimension. [x,(y)]
+ do_bias (optional): Add a (learnable) bias vector to the operation,
+ if false, b will be None
+ mat_init_value (optional): numpy constant for matrix initialization; if None,
+ the matrix is randomly initialized using the parameters below.
+ alpha (optional): A multiplicative scaling for the weight initialization
+ of the matrix, in the form \alpha * 1/\sqrt{x.shape[1]}.
+ identity_if_possible (optional): just return identity,
+ if x.shape[1] == out_size.
+ normalized (optional): Option to divide out by the norms of the rows of W.
+ name (optional): The name prefix to add to variables.
+ collections (optional): List of additional collections. (Placed in
+ tf.GraphKeys.GLOBAL_VARIABLES already, so no need for that.)
+
+ Returns:
+ In the equation, y = x W + b, returns the pair (W, b).
+ """
+
+ if mat_init_value is not None and mat_init_value.shape != (in_size, out_size):
+ raise ValueError(
+ 'Provided mat_init_value must have shape [%d, %d].'%(in_size, out_size))
+ if bias_init_value is not None and bias_init_value.shape != (1,out_size):
+ raise ValueError(
+ 'Provided bias_init_value must have shape [1,%d].'%(out_size,))
+
+ if mat_init_value is None:
+ stddev = alpha/np.sqrt(float(in_size))
+ mat_init = tf.random_normal_initializer(0.0, stddev)
+
+ wname = (name + "/W") if name else "/W"
+
+ if identity_if_possible and in_size == out_size:
+ return (tf.constant(np.eye(in_size).astype(np.float32)),
+ tf.zeros(in_size))
+
+ # Note the use of get_variable vs. tf.Variable. This is because get_variable
+ # does not allow the initialization of the variable with a value.
+ if normalized:
+ w_collections = [tf.GraphKeys.GLOBAL_VARIABLES, "norm-variables"]
+ if collections:
+ w_collections += collections
+ if mat_init_value is not None:
+ w = tf.Variable(mat_init_value, name=wname, collections=w_collections,
+ trainable=trainable)
+ else:
+ w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init,
+ collections=w_collections, trainable=trainable)
+ w = tf.nn.l2_normalize(w, dim=0) # x W, so xW_j = \sum_i x_bi W_ij
+ else:
+ w_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
+ if collections:
+ w_collections += collections
+ if mat_init_value is not None:
+ w = tf.Variable(mat_init_value, name=wname, collections=w_collections,
+ trainable=trainable)
+ else:
+ w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init,
+ collections=w_collections, trainable=trainable)
+ b = None
+ if do_bias:
+ b_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
+ if collections:
+ b_collections += collections
+ bname = (name + "/b") if name else "/b"
+ if bias_init_value is None:
+ b = tf.get_variable(bname, [1, out_size],
+ initializer=tf.zeros_initializer(),
+ collections=b_collections,
+ trainable=trainable)
+ else:
+ b = tf.Variable(bias_init_value, name=bname,
+ collections=b_collections,
+ trainable=trainable)
+
+ return (w, b)
+
+
+def write_data(data_fname, data_dict, use_json=False, compression=None):
+ """Write data in HD5F format.
+
+ Args:
+ data_fname: The filename of the file in which to write the data.
+ data_dict: The dictionary of data to write. The keys are strings
+ and the values are numpy arrays.
+ use_json (optional): human readable format for simple items
+ compression (optional): The compression to use for h5py (disabled by
+ default because the library borks on scalars, otherwise try 'gzip').
+ """
+
+ dir_name = os.path.dirname(data_fname)
+ if not os.path.exists(dir_name):
+ os.makedirs(dir_name)
+
+ if use_json:
+ the_file = open(data_fname,'w')  # text mode, since json.dump writes str
+ json.dump(data_dict, the_file)
+ the_file.close()
+ else:
+ try:
+ with h5py.File(data_fname, 'w') as hf:
+ for k, v in data_dict.items():
+ clean_k = k.replace('/', '_')
+ if clean_k != k:
+ print('Warning: saving variable with name: ', k, ' as ', clean_k)
+ else:
+ print('Saving variable with name: ', clean_k)
+ hf.create_dataset(clean_k, data=v, compression=compression)
+ except IOError:
+ print("Cannot open %s for writing.", data_fname)
+ raise
+
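+# Round-trip sketch (illustrative; the path and array below are assumptions):
+#   write_data('/tmp/lfads_demo/example.h5', {'train_data': np.ones((2, 3))})
+#   d = read_data('/tmp/lfads_demo/example.h5')  # d['train_data'].shape == (2, 3)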
+
+def read_data(data_fname):
+ """ Read saved data in HDF5 format.
+
+ Args:
+ data_fname: The filename of the file from which to read the data.
+ Returns:
+ A dictionary whose keys will vary depending on dataset (but should
+ always contain the keys 'train_data' and 'valid_data') and whose
+ values are numpy arrays.
+ """
+
+ try:
+ with h5py.File(data_fname, 'r') as hf:
+ data_dict = {k: np.array(v) for k, v in hf.items()}
+ return data_dict
+ except IOError:
+ print("Cannot open %s for reading." % data_fname)
+ raise
+
+
+def write_datasets(data_path, data_fname_stem, dataset_dict, compression=None):
+ """Write datasets in HD5F format.
+
+ This function assumes the dataset_dict is a mapping ( string ->
+ to data_dict ). It calls write_data for each data dictionary,
+ post-fixing the data filename with the key of the dataset.
+
+ Args:
+ data_path: The path to the save directory.
+ data_fname_stem: The filename stem of the file in which to write the data.
+ dataset_dict: The dictionary of datasets. The keys are strings
+ and the values data dictionaries (str -> numpy arrays) associations.
+ compression (optional): The compression to use for h5py (disabled by
+ default because the library borks on scalars, otherwise try 'gzip').
+ """
+
+ full_name_stem = os.path.join(data_path, data_fname_stem)
+ for s, data_dict in dataset_dict.items():
+ write_data(full_name_stem + "_" + s, data_dict, compression=compression)
+
+
+def read_datasets(data_path, data_fname_stem):
+ """Read dataset sin HD5F format.
+
+ This function assumes the dataset_dict is a mapping ( string ->
+ to data_dict ). It calls write_data for each data dictionary,
+ post-fixing the data filename with the key of the dataset.
+
+ Args:
+ data_path: The path to the save directory.
+ data_fname_stem: The filename stem of the file in which to write the data.
+ """
+
+ dataset_dict = {}
+ fnames = os.listdir(data_path)
+
+ print ('loading data from ' + data_path + ' with stem ' + data_fname_stem)
+ for fname in fnames:
+ if fname.startswith(data_fname_stem):
+ data_dict = read_data(os.path.join(data_path,fname))
+ idx = len(data_fname_stem) + 1
+ key = fname[idx:]
+ data_dict['data_dim'] = data_dict['train_data'].shape[2]
+ data_dict['num_steps'] = data_dict['train_data'].shape[1]
+ dataset_dict[key] = data_dict
+
+ if len(dataset_dict) == 0:
+ raise ValueError("Failed to load any datasets, are you sure that the "
+ "'--data_dir' and '--data_filename_stem' flag values "
+ "are correct?")
+
+ print (str(len(dataset_dict)) + ' datasets loaded')
+ return dataset_dict
+
+
+# NUMPY utility functions
+def list_t_bxn_to_list_b_txn(values_t_bxn):
+ """Convert a length T list of BxN numpy tensors of length B list of TxN numpy
+ tensors.
+
+ Args:
+ values_t_bxn: The length T list of BxN numpy tensors.
+
+ Returns:
+ The length B list of TxN numpy tensors.
+ """
+ T = len(values_t_bxn)
+ B, N = values_t_bxn[0].shape
+ values_b_txn = []
+ for b in range(B):
+ values_pb_txn = np.zeros([T,N])
+ for t in range(T):
+ values_pb_txn[t,:] = values_t_bxn[t][b,:]
+ values_b_txn.append(values_pb_txn)
+
+ return values_b_txn
+
+
+def list_t_bxn_to_tensor_bxtxn(values_t_bxn):
+ """Convert a length T list of BxN numpy tensors to single numpy tensor with
+ shape BxTxN.
+
+ Args:
+ values_t_bxn: The length T list of BxN numpy tensors.
+
+ Returns:
+ values_bxtxn: The BxTxN numpy tensor.
+ """
+
+ T = len(values_t_bxn)
+ B, N = values_t_bxn[0].shape
+ values_bxtxn = np.zeros([B,T,N])
+ for t in range(T):
+ values_bxtxn[:,t,:] = values_t_bxn[t]
+
+ return values_bxtxn
+
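+# Shape sanity check (illustrative): for a length-T list of BxN arrays with
+# T=5, B=2, N=3,
+#   list_t_bxn_to_tensor_bxtxn(vals)  ->  array of shape (2, 5, 3)
+#   list_t_bxn_to_list_b_txn(vals)    ->  list of 2 arrays, each of shape (5, 3)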
+
+def tensor_bxtxn_to_list_t_bxn(tensor_bxtxn):
+ """Convert a numpy tensor with shape BxTxN to a length T list of numpy tensors
+ with shape BxT.
+
+ Args:
+ tensor_bxtxn: The BxTxN numpy tensor.
+
+ Returns:
+ A length T list of numpy tensors with shape BxN.
+ """
+
+ values_t_bxn = []
+ B, T, N = tensor_bxtxn.shape
+ for t in range(T):
+ values_t_bxn.append(np.squeeze(tensor_bxtxn[:,t,:]))
+
+ return values_t_bxn
+
+
+def flatten(list_of_lists):
+ """Takes a list of lists and returns a list of the elements.
+
+ Args:
+ list_of_lists: List of lists.
+
+ Returns:
+ flat_list: Flattened list.
+ flat_list_idxs: Flattened list indices.
+ """
+ flat_list = []
+ flat_list_idxs = []
+ start_idx = 0
+ for item in list_of_lists:
+ if isinstance(item, list):
+ flat_list += item
+ l = len(item)
+ idxs = range(start_idx, start_idx+l)
+ start_idx = start_idx+l
+ else: # a value
+ flat_list.append(item)
+ idxs = [start_idx]
+ start_idx += 1
+ flat_list_idxs.append(idxs)
+
+ return flat_list, flat_list_idxs
diff --git a/models/research/lm_1b/BUILD b/models/research/lm_1b/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..ca5bc1f6ce4347a3b5f18d1bb59284aa9d07a567
--- /dev/null
+++ b/models/research/lm_1b/BUILD
@@ -0,0 +1,27 @@
+package(default_visibility = [":internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package_group(
+ name = "internal",
+ packages = [
+ "//lm_1b/...",
+ ],
+)
+
+py_library(
+ name = "data_utils",
+ srcs = ["data_utils.py"],
+)
+
+py_binary(
+ name = "lm_1b_eval",
+ srcs = [
+ "lm_1b_eval.py",
+ ],
+ deps = [
+ ":data_utils",
+ ],
+)
diff --git a/models/research/lm_1b/README.md b/models/research/lm_1b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f48afbfe23aff6681e641296e73b2c6b0e5a9b48
--- /dev/null
+++ b/models/research/lm_1b/README.md
@@ -0,0 +1,198 @@
+
+
+
+
+Language Model on One Billion Word Benchmark
+
+Authors:
+
+Oriol Vinyals (vinyals@google.com, github: OriolVinyals),
+Xin Pan
+
+Paper Authors:
+
+Rafal Jozefowicz, Oriol Vinyals, Mike Schuster, Noam Shazeer, Yonghui Wu
+
+TL;DR
+
+This is a pretrained model on the One Billion Word Benchmark.
+If you use this model in your publication, please cite the original paper:
+
+@article{jozefowicz2016exploring,
+ title={Exploring the Limits of Language Modeling},
+ author={Jozefowicz, Rafal and Vinyals, Oriol and Schuster, Mike
+ and Shazeer, Noam and Wu, Yonghui},
+ journal={arXiv preprint arXiv:1602.02410},
+ year={2016}
+}
+
+Introduction
+
+In this release, we open source a model trained on the One Billion Word
+Benchmark (http://arxiv.org/abs/1312.3005), a large language corpus in English
+which was released in 2013. This dataset contains about one billion words, and
+has a vocabulary size of about 800K words. It contains mostly news data. Since
+sentences in the training set are shuffled, models can ignore the context and
+focus on sentence level language modeling.
+
+In the original release and subsequent work, people have used the same test set
+to train models on this dataset as a standard benchmark for language modeling.
+Recently, we wrote an article (http://arxiv.org/abs/1602.02410) describing a
+model hybrid between character CNN, a large and deep LSTM, and a specific
+Softmax architecture which allowed us to train the best model on this dataset
+thus far, almost halving the best perplexity previously obtained by others.
+
+Code Release
+
+The open-sourced components include:
+
+* TensorFlow GraphDef proto buffer text file.
+* TensorFlow pre-trained checkpoint shards.
+* Code used to evaluate the pre-trained model.
+* Vocabulary file.
+* Test set from LM-1B evaluation.
+
+The code supports 4 evaluation modes:
+
+* Given provided dataset, calculate the model's perplexity.
+* Given a prefix sentence, predict the next words.
+* Dump the softmax embedding and character-level CNN word embeddings.
+* Given a sentence, dump the embedding from the LSTM state.
+
+Results
+
+Model | Test Perplexity | Number of Params [billions]
+------|-----------------|----------------------------
+Sigmoid-RNN-2048 [Blackout] | 68.3 | 4.1
+Interpolated KN 5-gram, 1.1B n-grams [chelba2013one] | 67.6 | 1.76
+Sparse Non-Negative Matrix LM [shazeer2015sparse] | 52.9 | 33
+RNN-1024 + MaxEnt 9-gram features [chelba2013one] | 51.3 | 20
+LSTM-512-512 | 54.1 | 0.82
+LSTM-1024-512 | 48.2 | 0.82
+LSTM-2048-512 | 43.7 | 0.83
+LSTM-8192-2048 (No Dropout) | 37.9 | 3.3
+LSTM-8192-2048 (50% Dropout) | 32.2 | 3.3
+2-Layer LSTM-8192-1024 (BIG LSTM) | 30.6 | 1.8
+(THIS RELEASE) BIG LSTM+CNN Inputs | 30.0 | 1.04
+
+How To Run
+
+Prerequisites:
+
+* Install TensorFlow.
+* Install Bazel.
+* Download the data files:
+ * Model GraphDef file:
+ [link](http://download.tensorflow.org/models/LM_LSTM_CNN/graph-2016-09-10.pbtxt)
+ * Model Checkpoint sharded file:
+ [1](http://download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-base)
+ [2](http://download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-char-embedding)
+ [3](http://download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-lstm)
+ [4](http://download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax0)
+ [5](http://download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax1)
+ [6](http://download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax2)
+ [7](http://download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax3)
+ [8](http://download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax4)
+ [9](http://download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax5)
+ [10](http://download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax6)
+ [11](http://download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax7)
+ [12](http://download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax8)
+ * Vocabulary file:
+ [link](http://download.tensorflow.org/models/LM_LSTM_CNN/vocab-2016-09-10.txt)
+ * Test dataset:
+ [link](http://download.tensorflow.org/models/LM_LSTM_CNN/test/news.en.heldout-00000-of-00050)
+* It is recommended to run on a modern desktop instead of a laptop.
+
+```shell
+# 1. Clone the code to your workspace.
+# 2. Download the data to your workspace.
+# 3. Create an empty WORKSPACE file in your workspace.
+# 4. Create an empty output directory in your workspace.
+# Example directory structure below:
+$ ls -R
+.:
+data lm_1b output WORKSPACE
+
+./data:
+ckpt-base ckpt-lstm ckpt-softmax1 ckpt-softmax3 ckpt-softmax5
+ckpt-softmax7 graph-2016-09-10.pbtxt vocab-2016-09-10.txt
+ckpt-char-embedding ckpt-softmax0 ckpt-softmax2 ckpt-softmax4 ckpt-softmax6
+ckpt-softmax8 news.en.heldout-00000-of-00050
+
+./lm_1b:
+BUILD data_utils.py lm_1b_eval.py README.md
+
+./output:
+
+# Build the codes.
+$ bazel build -c opt lm_1b/...
+# Run sample mode:
+$ bazel-bin/lm_1b/lm_1b_eval --mode sample \
+ --prefix "I love that I" \
+ --pbtxt data/graph-2016-09-10.pbtxt \
+ --vocab_file data/vocab-2016-09-10.txt \
+ --ckpt 'data/ckpt-*'
+...(omitted some TensorFlow output)
+I love
+I love that
+I love that I
+I love that I find
+I love that I find that
+I love that I find that amazing
+...(omitted)
+
+# Run eval mode:
+$ bazel-bin/lm_1b/lm_1b_eval --mode eval \
+ --pbtxt data/graph-2016-09-10.pbtxt \
+ --vocab_file data/vocab-2016-09-10.txt \
+ --input_data data/news.en.heldout-00000-of-00050 \
+ --ckpt 'data/ckpt-*'
+...(omitted some TensorFlow output)
+Loaded step 14108582.
+# perplexity is high initially because words without context are harder to
+# predict.
+Eval Step: 0, Average Perplexity: 2045.512297.
+Eval Step: 1, Average Perplexity: 229.478699.
+Eval Step: 2, Average Perplexity: 208.116787.
+Eval Step: 3, Average Perplexity: 338.870601.
+Eval Step: 4, Average Perplexity: 228.950107.
+Eval Step: 5, Average Perplexity: 197.685857.
+Eval Step: 6, Average Perplexity: 156.287063.
+Eval Step: 7, Average Perplexity: 124.866189.
+Eval Step: 8, Average Perplexity: 147.204975.
+Eval Step: 9, Average Perplexity: 90.124864.
+Eval Step: 10, Average Perplexity: 59.897914.
+Eval Step: 11, Average Perplexity: 42.591137.
+...(omitted)
+Eval Step: 4529, Average Perplexity: 29.243668.
+Eval Step: 4530, Average Perplexity: 29.302362.
+Eval Step: 4531, Average Perplexity: 29.285674.
+...(omitted. At convergence, it should be around 30.)
+
+# Run dump_emb mode:
+$ bazel-bin/lm_1b/lm_1b_eval --mode dump_emb \
+ --pbtxt data/graph-2016-09-10.pbtxt \
+ --vocab_file data/vocab-2016-09-10.txt \
+ --ckpt 'data/ckpt-*' \
+ --save_dir output
+...(omitted some TensorFlow output)
+Finished softmax weights
+Finished word embedding 0/793471
+Finished word embedding 1/793471
+Finished word embedding 2/793471
+...(omitted)
+$ ls output/
+embeddings_softmax.npy ...
+
+# Run dump_lstm_emb mode:
+$ bazel-bin/lm_1b/lm_1b_eval --mode dump_lstm_emb \
+ --pbtxt data/graph-2016-09-10.pbtxt \
+ --vocab_file data/vocab-2016-09-10.txt \
+ --ckpt 'data/ckpt-*' \
+ --sentence "I love who I am ." \
+ --save_dir output
+$ ls output/
+lstm_emb_step_0.npy lstm_emb_step_2.npy lstm_emb_step_4.npy
+lstm_emb_step_6.npy lstm_emb_step_1.npy lstm_emb_step_3.npy
+lstm_emb_step_5.npy
+```
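+
+The dumped `.npy` files are ordinary NumPy arrays written with `np.save`. As a
+small illustrative sketch (not part of the released scripts), they can be
+inspected like this:
+
+```python
+import numpy as np
+
+# Softmax output embedding matrix written by dump_emb mode.
+softmax_emb = np.load('output/embeddings_softmax.npy')
+# Per-step LSTM state embedding written by dump_lstm_emb mode.
+step0 = np.load('output/lstm_emb_step_0.npy')
+print(softmax_emb.shape, step0.shape)
+```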
diff --git a/models/research/lm_1b/data_utils.py b/models/research/lm_1b/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad8d3391ef6db07c1d6c234450a6d23a8e19a178
--- /dev/null
+++ b/models/research/lm_1b/data_utils.py
@@ -0,0 +1,279 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A library for loading 1B word benchmark dataset."""
+
+import random
+
+import numpy as np
+import tensorflow as tf
+
+
+class Vocabulary(object):
+ """Class that holds a vocabulary for the dataset."""
+
+ def __init__(self, filename):
+ """Initialize vocabulary.
+
+ Args:
+ filename: Vocabulary file name.
+ """
+
+ self._id_to_word = []
+ self._word_to_id = {}
+ self._unk = -1
+ self._bos = -1
+ self._eos = -1
+
+ with tf.gfile.Open(filename) as f:
+ idx = 0
+ for line in f:
+ word_name = line.strip()
+        if word_name == '<S>':
+          self._bos = idx
+        elif word_name == '</S>':
+          self._eos = idx
+        elif word_name == '<UNK>':
+          self._unk = idx
+        if word_name == '!!!MAXTERMID':
+ continue
+
+ self._id_to_word.append(word_name)
+ self._word_to_id[word_name] = idx
+ idx += 1
+
+ @property
+ def bos(self):
+ return self._bos
+
+ @property
+ def eos(self):
+ return self._eos
+
+ @property
+ def unk(self):
+ return self._unk
+
+ @property
+ def size(self):
+ return len(self._id_to_word)
+
+ def word_to_id(self, word):
+ if word in self._word_to_id:
+ return self._word_to_id[word]
+ return self.unk
+
+ def id_to_word(self, cur_id):
+ if cur_id < self.size:
+ return self._id_to_word[cur_id]
+ return 'ERROR'
+
+ def decode(self, cur_ids):
+ """Convert a list of ids to a sentence, with space inserted."""
+ return ' '.join([self.id_to_word(cur_id) for cur_id in cur_ids])
+
+ def encode(self, sentence):
+ """Convert a sentence to a list of ids, with special tokens added."""
+ word_ids = [self.word_to_id(cur_word) for cur_word in sentence.split()]
+ return np.array([self.bos] + word_ids + [self.eos], dtype=np.int32)
+
+
+class CharsVocabulary(Vocabulary):
+ """Vocabulary containing character-level information."""
+
+ def __init__(self, filename, max_word_length):
+ super(CharsVocabulary, self).__init__(filename)
+ self._max_word_length = max_word_length
+ chars_set = set()
+
+ for word in self._id_to_word:
+ chars_set |= set(word)
+
+ free_ids = []
+ for i in range(256):
+ if chr(i) in chars_set:
+ continue
+ free_ids.append(chr(i))
+
+ if len(free_ids) < 5:
+ raise ValueError('Not enough free char ids: %d' % len(free_ids))
+
+    self.bos_char = free_ids[0]  # <begin sentence>
+    self.eos_char = free_ids[1]  # <end sentence>
+    self.bow_char = free_ids[2]  # <begin word>
+    self.eow_char = free_ids[3]  # <end word>
+    self.pad_char = free_ids[4]  # <padding>
+
+ chars_set |= {self.bos_char, self.eos_char, self.bow_char, self.eow_char,
+ self.pad_char}
+
+ self._char_set = chars_set
+ num_words = len(self._id_to_word)
+
+ self._word_char_ids = np.zeros([num_words, max_word_length], dtype=np.int32)
+
+ self.bos_chars = self._convert_word_to_char_ids(self.bos_char)
+ self.eos_chars = self._convert_word_to_char_ids(self.eos_char)
+
+ for i, word in enumerate(self._id_to_word):
+ self._word_char_ids[i] = self._convert_word_to_char_ids(word)
+
+ @property
+ def word_char_ids(self):
+ return self._word_char_ids
+
+ @property
+ def max_word_length(self):
+ return self._max_word_length
+
+ def _convert_word_to_char_ids(self, word):
+ code = np.zeros([self.max_word_length], dtype=np.int32)
+ code[:] = ord(self.pad_char)
+
+ if len(word) > self.max_word_length - 2:
+ word = word[:self.max_word_length-2]
+ cur_word = self.bow_char + word + self.eow_char
+ for j in range(len(cur_word)):
+ code[j] = ord(cur_word[j])
+ return code
+
+ def word_to_char_ids(self, word):
+ if word in self._word_to_id:
+ return self._word_char_ids[self._word_to_id[word]]
+ else:
+ return self._convert_word_to_char_ids(word)
+
+ def encode_chars(self, sentence):
+ chars_ids = [self.word_to_char_ids(cur_word)
+ for cur_word in sentence.split()]
+ return np.vstack([self.bos_chars] + chars_ids + [self.eos_chars])
+
+
+def get_batch(generator, batch_size, num_steps, max_word_length, pad=False):
+ """Read batches of input."""
+ cur_stream = [None] * batch_size
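+  # Each row of the batch consumes tokens from its own sentence stream; when a
+  # stream runs out, the next sentence from `generator` refills it.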
+
+ inputs = np.zeros([batch_size, num_steps], np.int32)
+ char_inputs = np.zeros([batch_size, num_steps, max_word_length], np.int32)
+ global_word_ids = np.zeros([batch_size, num_steps], np.int32)
+ targets = np.zeros([batch_size, num_steps], np.int32)
+ weights = np.ones([batch_size, num_steps], np.float32)
+
+ no_more_data = False
+ while True:
+ inputs[:] = 0
+ char_inputs[:] = 0
+ global_word_ids[:] = 0
+ targets[:] = 0
+ weights[:] = 0.0
+
+ for i in range(batch_size):
+ cur_pos = 0
+
+ while cur_pos < num_steps:
+ if cur_stream[i] is None or len(cur_stream[i][0]) <= 1:
+ try:
+ cur_stream[i] = list(generator.next())
+ except StopIteration:
+ # No more data, exhaust current streams and quit
+ no_more_data = True
+ break
+
+ how_many = min(len(cur_stream[i][0]) - 1, num_steps - cur_pos)
+ next_pos = cur_pos + how_many
+
+ inputs[i, cur_pos:next_pos] = cur_stream[i][0][:how_many]
+ char_inputs[i, cur_pos:next_pos] = cur_stream[i][1][:how_many]
+ global_word_ids[i, cur_pos:next_pos] = cur_stream[i][2][:how_many]
+ targets[i, cur_pos:next_pos] = cur_stream[i][0][1:how_many+1]
+ weights[i, cur_pos:next_pos] = 1.0
+
+ cur_pos = next_pos
+ cur_stream[i][0] = cur_stream[i][0][how_many:]
+ cur_stream[i][1] = cur_stream[i][1][how_many:]
+ cur_stream[i][2] = cur_stream[i][2][how_many:]
+
+ if pad:
+ break
+
+ if no_more_data and np.sum(weights) == 0:
+ # There is no more data and this is an empty batch. Done!
+ break
+ yield inputs, char_inputs, global_word_ids, targets, weights
+
+
+class LM1BDataset(object):
+ """Utility class for 1B word benchmark dataset.
+
+ The current implementation reads the data from the tokenized text files.
+ """
+
+ def __init__(self, filepattern, vocab):
+ """Initialize LM1BDataset reader.
+
+ Args:
+ filepattern: Dataset file pattern.
+ vocab: Vocabulary.
+ """
+ self._vocab = vocab
+ self._all_shards = tf.gfile.Glob(filepattern)
+ tf.logging.info('Found %d shards at %s', len(self._all_shards), filepattern)
+
+ def _load_random_shard(self):
+ """Randomly select a file and read it."""
+ return self._load_shard(random.choice(self._all_shards))
+
+ def _load_shard(self, shard_name):
+ """Read one file and convert to ids.
+
+ Args:
+ shard_name: file path.
+
+ Returns:
+ list of (id, char_id, global_word_id) tuples.
+ """
+ tf.logging.info('Loading data from: %s', shard_name)
+ with tf.gfile.Open(shard_name) as f:
+ sentences = f.readlines()
+ chars_ids = [self.vocab.encode_chars(sentence) for sentence in sentences]
+ ids = [self.vocab.encode(sentence) for sentence in sentences]
+
+ global_word_ids = []
+ current_idx = 0
+ for word_ids in ids:
+      current_size = len(word_ids) - 1  # without the end-of-sentence symbol
+ cur_ids = np.arange(current_idx, current_idx + current_size)
+ global_word_ids.append(cur_ids)
+ current_idx += current_size
+
+ tf.logging.info('Loaded %d words.', current_idx)
+ tf.logging.info('Finished loading')
+ return zip(ids, chars_ids, global_word_ids)
+
+ def _get_sentence(self, forever=True):
+ while True:
+ ids = self._load_random_shard()
+ for current_ids in ids:
+ yield current_ids
+ if not forever:
+ break
+
+ def get_batch(self, batch_size, num_steps, pad=False, forever=True):
+ return get_batch(self._get_sentence(forever), batch_size, num_steps,
+ self.vocab.max_word_length, pad=pad)
+
+ @property
+ def vocab(self):
+ return self._vocab
diff --git a/models/research/lm_1b/lm_1b_eval.py b/models/research/lm_1b/lm_1b_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce8634757558c135ba137a9b9e09a733977adc3a
--- /dev/null
+++ b/models/research/lm_1b/lm_1b_eval.py
@@ -0,0 +1,308 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Eval pre-trained 1 billion word language model.
+"""
+import os
+import sys
+
+import numpy as np
+from six.moves import xrange
+import tensorflow as tf
+
+from google.protobuf import text_format
+import data_utils
+
+FLAGS = tf.flags.FLAGS
+# General flags.
+tf.flags.DEFINE_string('mode', 'eval',
+ 'One of [sample, eval, dump_emb, dump_lstm_emb]. '
+ '"sample" mode samples future word predictions, using '
+ 'FLAGS.prefix as prefix (prefix could be left empty). '
+ '"eval" mode calculates perplexity of the '
+ 'FLAGS.input_data. '
+ '"dump_emb" mode dumps word and softmax embeddings to '
+ 'FLAGS.save_dir. embeddings are dumped in the same '
+ 'order as words in vocabulary. All words in vocabulary '
+                       'are dumped. '
+ 'dump_lstm_emb dumps lstm embeddings of FLAGS.sentence '
+ 'to FLAGS.save_dir.')
+tf.flags.DEFINE_string('pbtxt', '',
+ 'GraphDef proto text file used to construct model '
+ 'structure.')
+tf.flags.DEFINE_string('ckpt', '',
+ 'Checkpoint directory used to fill model values.')
+tf.flags.DEFINE_string('vocab_file', '', 'Vocabulary file.')
+tf.flags.DEFINE_string('save_dir', '',
+ 'Used for "dump_emb" mode to save word embeddings.')
+# sample mode flags.
+tf.flags.DEFINE_string('prefix', '',
+ 'Used for "sample" mode to predict next words.')
+tf.flags.DEFINE_integer('max_sample_words', 100,
+                        'Sampling stops either when </S> is met or this number '
+ 'of steps has passed.')
+tf.flags.DEFINE_integer('num_samples', 3,
+ 'Number of samples to generate for the prefix.')
+# dump_lstm_emb mode flags.
+tf.flags.DEFINE_string('sentence', '',
+ 'Used as input for "dump_lstm_emb" mode.')
+# eval mode flags.
+tf.flags.DEFINE_string('input_data', '',
+ 'Input data files for eval model.')
+tf.flags.DEFINE_integer('max_eval_steps', 1000000,
+                        'Maximum number of steps to run "eval" mode.')
+
+
+# For saving demo resources, use batch size 1 and step 1.
+BATCH_SIZE = 1
+NUM_TIMESTEPS = 1
+MAX_WORD_LEN = 50
+
+
+def _LoadModel(gd_file, ckpt_file):
+ """Load the model from GraphDef and Checkpoint.
+
+ Args:
+ gd_file: GraphDef proto text file.
+ ckpt_file: TensorFlow Checkpoint file.
+
+ Returns:
+ TensorFlow session and tensors dict.
+ """
+ with tf.Graph().as_default():
+ sys.stderr.write('Recovering graph.\n')
+ with tf.gfile.FastGFile(gd_file, 'r') as f:
+ s = f.read().decode()
+ gd = tf.GraphDef()
+ text_format.Merge(s, gd)
+
+ tf.logging.info('Recovering Graph %s', gd_file)
+ t = {}
+ [t['states_init'], t['lstm/lstm_0/control_dependency'],
+ t['lstm/lstm_1/control_dependency'], t['softmax_out'], t['class_ids_out'],
+ t['class_weights_out'], t['log_perplexity_out'], t['inputs_in'],
+ t['targets_in'], t['target_weights_in'], t['char_inputs_in'],
+ t['all_embs'], t['softmax_weights'], t['global_step']
+ ] = tf.import_graph_def(gd, {}, ['states_init',
+ 'lstm/lstm_0/control_dependency:0',
+ 'lstm/lstm_1/control_dependency:0',
+ 'softmax_out:0',
+ 'class_ids_out:0',
+ 'class_weights_out:0',
+ 'log_perplexity_out:0',
+ 'inputs_in:0',
+ 'targets_in:0',
+ 'target_weights_in:0',
+ 'char_inputs_in:0',
+ 'all_embs_out:0',
+ 'Reshape_3:0',
+ 'global_step:0'], name='')
+
+ sys.stderr.write('Recovering checkpoint %s\n' % ckpt_file)
+ sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+ sess.run('save/restore_all', {'save/Const:0': ckpt_file})
+ sess.run(t['states_init'])
+
+ return sess, t
+
+
+def _EvalModel(dataset):
+ """Evaluate model perplexity using provided dataset.
+
+ Args:
+ dataset: LM1BDataset object.
+ """
+ sess, t = _LoadModel(FLAGS.pbtxt, FLAGS.ckpt)
+
+ current_step = t['global_step'].eval(session=sess)
+ sys.stderr.write('Loaded step %d.\n' % current_step)
+
+ data_gen = dataset.get_batch(BATCH_SIZE, NUM_TIMESTEPS, forever=False)
+ sum_num = 0.0
+ sum_den = 0.0
+ perplexity = 0.0
+ for i, (inputs, char_inputs, _, targets, weights) in enumerate(data_gen):
+ input_dict = {t['inputs_in']: inputs,
+ t['targets_in']: targets,
+ t['target_weights_in']: weights}
+ if 'char_inputs_in' in t:
+ input_dict[t['char_inputs_in']] = char_inputs
+ log_perp = sess.run(t['log_perplexity_out'], feed_dict=input_dict)
+
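+    # Keep a running weighted average of per-batch log-perplexity; the
+    # reported corpus perplexity is exp() of that weighted mean.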
+ if np.isnan(log_perp):
+      sys.stderr.write('log_perplexity is Nan.\n')
+ else:
+ sum_num += log_perp * weights.mean()
+ sum_den += weights.mean()
+ if sum_den > 0:
+ perplexity = np.exp(sum_num / sum_den)
+
+ sys.stderr.write('Eval Step: %d, Average Perplexity: %f.\n' %
+ (i, perplexity))
+
+ if i > FLAGS.max_eval_steps:
+ break
+
+
+def _SampleSoftmax(softmax):
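+  # Draw one word id from the softmax distribution by inverse-CDF sampling,
+  # clamped to the last valid index.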
+ return min(np.sum(np.cumsum(softmax) < np.random.rand()), len(softmax) - 1)
+
+
+def _SampleModel(prefix_words, vocab):
+ """Predict next words using the given prefix words.
+
+ Args:
+ prefix_words: Prefix words.
+    vocab: Vocabulary. Contains max word char id length and converts between
+ words and ids.
+ """
+ targets = np.zeros([BATCH_SIZE, NUM_TIMESTEPS], np.int32)
+ weights = np.ones([BATCH_SIZE, NUM_TIMESTEPS], np.float32)
+
+ sess, t = _LoadModel(FLAGS.pbtxt, FLAGS.ckpt)
+
+  if prefix_words.find('<S>') != 0:
+    prefix_words = '<S> ' + prefix_words
+
+ prefix = [vocab.word_to_id(w) for w in prefix_words.split()]
+ prefix_char_ids = [vocab.word_to_char_ids(w) for w in prefix_words.split()]
+ for _ in xrange(FLAGS.num_samples):
+ inputs = np.zeros([BATCH_SIZE, NUM_TIMESTEPS], np.int32)
+ char_ids_inputs = np.zeros(
+ [BATCH_SIZE, NUM_TIMESTEPS, vocab.max_word_length], np.int32)
+ samples = prefix[:]
+ char_ids_samples = prefix_char_ids[:]
+ sent = ''
+ while True:
+ inputs[0, 0] = samples[0]
+ char_ids_inputs[0, 0, :] = char_ids_samples[0]
+ samples = samples[1:]
+ char_ids_samples = char_ids_samples[1:]
+
+ softmax = sess.run(t['softmax_out'],
+ feed_dict={t['char_inputs_in']: char_ids_inputs,
+ t['inputs_in']: inputs,
+ t['targets_in']: targets,
+ t['target_weights_in']: weights})
+
+ sample = _SampleSoftmax(softmax[0])
+ sample_char_ids = vocab.word_to_char_ids(vocab.id_to_word(sample))
+
+ if not samples:
+ samples = [sample]
+ char_ids_samples = [sample_char_ids]
+ sent += vocab.id_to_word(samples[0]) + ' '
+ sys.stderr.write('%s\n' % sent)
+
+      if (vocab.id_to_word(samples[0]) == '</S>' or
+ len(sent) > FLAGS.max_sample_words):
+ break
+
+
+def _DumpEmb(vocab):
+ """Dump the softmax weights and word embeddings to files.
+
+ Args:
+ vocab: Vocabulary. Contains vocabulary size and converts word to ids.
+ """
+ assert FLAGS.save_dir, 'Must specify FLAGS.save_dir for dump_emb.'
+ inputs = np.zeros([BATCH_SIZE, NUM_TIMESTEPS], np.int32)
+ targets = np.zeros([BATCH_SIZE, NUM_TIMESTEPS], np.int32)
+ weights = np.ones([BATCH_SIZE, NUM_TIMESTEPS], np.float32)
+
+ sess, t = _LoadModel(FLAGS.pbtxt, FLAGS.ckpt)
+
+ softmax_weights = sess.run(t['softmax_weights'])
+ fname = FLAGS.save_dir + '/embeddings_softmax.npy'
+ with tf.gfile.Open(fname, mode='w') as f:
+ np.save(f, softmax_weights)
+ sys.stderr.write('Finished softmax weights\n')
+
+ all_embs = np.zeros([vocab.size, 1024])
+ for i in xrange(vocab.size):
+ input_dict = {t['inputs_in']: inputs,
+ t['targets_in']: targets,
+ t['target_weights_in']: weights}
+ if 'char_inputs_in' in t:
+ input_dict[t['char_inputs_in']] = (
+ vocab.word_char_ids[i].reshape([-1, 1, MAX_WORD_LEN]))
+ embs = sess.run(t['all_embs'], input_dict)
+ all_embs[i, :] = embs
+ sys.stderr.write('Finished word embedding %d/%d\n' % (i, vocab.size))
+
+ fname = FLAGS.save_dir + '/embeddings_char_cnn.npy'
+ with tf.gfile.Open(fname, mode='w') as f:
+ np.save(f, all_embs)
+ sys.stderr.write('Embedding file saved\n')
+
+
+def _DumpSentenceEmbedding(sentence, vocab):
+ """Predict next words using the given prefix words.
+
+ Args:
+ sentence: Sentence words.
+    vocab: Vocabulary. Contains max word char id length and converts between
+ words and ids.
+ """
+ targets = np.zeros([BATCH_SIZE, NUM_TIMESTEPS], np.int32)
+ weights = np.ones([BATCH_SIZE, NUM_TIMESTEPS], np.float32)
+
+ sess, t = _LoadModel(FLAGS.pbtxt, FLAGS.ckpt)
+
+  if sentence.find('<S>') != 0:
+    sentence = '<S> ' + sentence
+
+ word_ids = [vocab.word_to_id(w) for w in sentence.split()]
+ char_ids = [vocab.word_to_char_ids(w) for w in sentence.split()]
+
+ inputs = np.zeros([BATCH_SIZE, NUM_TIMESTEPS], np.int32)
+ char_ids_inputs = np.zeros(
+ [BATCH_SIZE, NUM_TIMESTEPS, vocab.max_word_length], np.int32)
+ for i in xrange(len(word_ids)):
+ inputs[0, 0] = word_ids[i]
+ char_ids_inputs[0, 0, :] = char_ids[i]
+
+ # Add 'lstm/lstm_0/control_dependency' if you want to dump previous layer
+ # LSTM.
+ lstm_emb = sess.run(t['lstm/lstm_1/control_dependency'],
+ feed_dict={t['char_inputs_in']: char_ids_inputs,
+ t['inputs_in']: inputs,
+ t['targets_in']: targets,
+ t['target_weights_in']: weights})
+
+ fname = os.path.join(FLAGS.save_dir, 'lstm_emb_step_%d.npy' % i)
+ with tf.gfile.Open(fname, mode='w') as f:
+ np.save(f, lstm_emb)
+ sys.stderr.write('LSTM embedding step %d file saved\n' % i)
+
+
+def main(unused_argv):
+ vocab = data_utils.CharsVocabulary(FLAGS.vocab_file, MAX_WORD_LEN)
+
+ if FLAGS.mode == 'eval':
+ dataset = data_utils.LM1BDataset(FLAGS.input_data, vocab)
+ _EvalModel(dataset)
+ elif FLAGS.mode == 'sample':
+ _SampleModel(FLAGS.prefix, vocab)
+ elif FLAGS.mode == 'dump_emb':
+ _DumpEmb(vocab)
+ elif FLAGS.mode == 'dump_lstm_emb':
+ _DumpSentenceEmbedding(FLAGS.sentence, vocab)
+ else:
+ raise Exception('Mode not supported.')
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/lm_commonsense/README.md b/models/research/lm_commonsense/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..78c8f53ca226f09c4b185490d6966f98bf584889
--- /dev/null
+++ b/models/research/lm_commonsense/README.md
@@ -0,0 +1,170 @@
+
+
+
+
+# A Simple Method for Commonsense Reasoning
+
+This repository contains code to reproduce results from [*A Simple Method for Commonsense Reasoning*](https://arxiv.org/abs/1806.02847).
+
+Authors and contact:
+
+* Trieu H. Trinh (thtrieu@google.com, github: thtrieu)
+* Quoc V. Le (qvl@google.com)
+
+## TL;DR
+
+Commonsense reasoning is a long-standing challenge for deep learning. For example,
+it is difficult to use neural networks to tackle the Winograd Schema dataset - a difficult subset of Pronoun Disambiguation problems. In this work, we use language models to score substituted sentences to decide the correct reference of the ambiguous pronoun (see Figure below for an example).
+
+
+
+This simple unsupervised method achieves new state-of-the-art (*as of June 1st, 2018*) results on both benchmarks PDP-60 and WSC-273 (see Table below), without using rule-based reasoning or expensive annotated knowledge bases.
+
+| Commonsense-reasoning test | Previous best result | Ours |
+| ----------------------------|:----------------------:|:-----:|
+| Pronoun Disambiguation | 66.7% | 70% |
+| Winograd Schema Challenge | 52.8% | 63.7% |
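+
+As a rough illustration of the scoring idea, the sketch below assumes a
+hypothetical `lm_word_probs(tokens)` helper that returns per-token
+probabilities under a language model; the actual pipeline lives in `eval.py`
+and `utils.py`.
+
+```python
+import numpy as np
+
+def score(tokens, lm_word_probs):
+  # Joint probability of the substituted sentence under the LM.
+  return np.prod(lm_word_probs(tokens), dtype=np.float64)
+
+def resolve(candidate_a, candidate_b, lm_word_probs):
+  # Keep the substitution that the language model finds more probable.
+  if score(candidate_a, lm_word_probs) > score(candidate_b, lm_word_probs):
+    return candidate_a
+  return candidate_b
+```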
+
+
+
+## Citation
+
+If you use our released models below in your publication, please cite the original paper:
+
+@article{TBD}
+
+
+## Requirements
+* Python >=2.6
+* Tensorflow >= v1.4
+* Numpy >= 1.12.1
+
+## Details of this release
+
+The open-sourced components include:
+
+* Test sets from Pronoun Disambiguation Problem (PDP-60) and Winograd Schema Challenges (WSC-273).
+* Tensorflow metagraph and checkpoints of 14 language models (See Appendix A in the paper).
+* A vocabulary file.
+* Code to reproduce results from the original paper.
+
+## How to run
+
+### 1. Download data files
+
+Download all files from the [Google Cloud Storage of this project](https://console.cloud.google.com/storage/browser/commonsense-reasoning/). The easiest way is to install and use `gsutil cp` command-line tool (See [install gsutil](https://cloud.google.com/storage/docs/gsutil_install)).
+
+
+```shell
+# Download everything from the project gs://commonsense-reasoning
+$ gsutil cp -R gs://commonsense-reasoning/* .
+Copying gs://commonsense-reasoning/reproduce/vocab.txt...
+Copying gs://commonsense-reasoning/reproduce/commonsense_test/pdp60.json...
+Copying gs://commonsense-reasoning/reproduce/commonsense_test/wsc273.json...
+
+...(omitted)
+```
+
+All downloaded content should be in `./reproduce/`. This includes the two test sets `pdp60.json` and `wsc273.json`, a vocabulary file `vocab.txt`, and checkpoints for all 14 language models, each consisting of three files (`.data`, `.index` and `.meta`). All checkpoint names start with `ckpt-best` since they were saved at the best perplexity on a held-out text corpus.
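+
+For reference, each per-model checkpoint can be restored from its metagraph in
+the usual TF1 way (a minimal sketch mirroring what `eval.py` does; the path is
+relative to the download directory):
+
+```python
+import tensorflow as tf
+
+tf.reset_default_graph()
+sess = tf.Session()
+# Rebuild the graph from the metagraph, then load the trained weights.
+saver = tf.train.import_meta_graph('reproduce/lm01/ckpt-best.meta')
+saver.restore(sess, 'reproduce/lm01/ckpt-best')
+```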
+
+```shell
+# Check for the content
+$ ls reproduce/*
+reproduce/vocab.txt
+
+reproduce/commonsense_test:
+pdp60.json wsc273.json
+
+reproduce/lm01:
+ckpt-best.data-00000-of-00001 ckpt-best.index ckpt-best.meta
+
+reproduce/lm02:
+ckpt-best.data-00000-of-00001 ckpt-best.index ckpt-best.meta
+
+reproduce/lm03:
+ckpt-best.data-00000-of-00001 ckpt-best.index ckpt-best.meta
+
+reproduce/lm04:
+ckpt-best.data-00000-of-00001 ckpt-best.index ckpt-best.meta
+
+reproduce/lm05:
+ckpt-best.data-00000-of-00001 ckpt-best.index ckpt-best.meta
+
+reproduce/lm06:
+ckpt-best.data-00000-of-00001 ckpt-best.index ckpt-best.meta
+
+reproduce/lm07:
+ckpt-best.data-00000-of-00001 ckpt-best.index ckpt-best.meta
+
+reproduce/lm08:
+ckpt-best.data-00000-of-00001 ckpt-best.index ckpt-best.meta
+
+reproduce/lm09:
+ckpt-best.data-00000-of-00001 ckpt-best.index ckpt-best.meta
+
+reproduce/lm10:
+ckpt-best.data-00000-of-00001 ckpt-best.index ckpt-best.meta
+
+reproduce/lm11:
+ckpt-best.data-00000-of-00001 ckpt-best.index ckpt-best.meta
+
+reproduce/lm12:
+ckpt-best.data-00000-of-00001 ckpt-best.index ckpt-best.meta
+
+reproduce/lm13:
+ckpt-best.data-00000-of-00001 ckpt-best.index ckpt-best.meta
+
+reproduce/lm14:
+ckpt-best.data-00000-of-00001 ckpt-best.index ckpt-best.meta
+```
+
+### 2. Run evaluation code
+
+To reproduce the results from the paper, simply run the `eval.py` script.
+
+```shell
+$ python eval.py --data_dir=reproduce
+
+Restored from ./reproduce/lm01
+Reset RNN states.
+Processing patch (1, 1) / (2, 4)
+Probs for
+[['Then' 'Dad' 'figured' ..., 'man' "'s" 'board-bill']
+ ['Then' 'Dad' 'figured' ..., 'man' "'s" 'board-bill']
+ ['Always' 'before' ',' ..., 'now' ',' 'for']
+ ...,
+ ['Mark' 'was' 'close' ..., 'promising' 'him' ',']
+ ['Mark' 'was' 'close' ..., 'promising' 'him' ',']
+ ['Mark' 'was' 'close' ..., 'promising' 'him' ',']]
+=
+[[ 1.64250596e-05 1.77780055e-06 4.14267970e-06 ..., 1.87315454e-03
+ 1.57723188e-01 6.31845817e-02]
+ [ 1.64250596e-05 1.77780055e-06 4.14267970e-06 ..., 1.87315454e-03
+ 1.57723188e-01 6.31845817e-02]
+ [ 1.28243030e-07 3.80435935e-03 1.12383246e-01 ..., 9.67682712e-03
+ 2.17407525e-01 1.08243264e-01]
+ ...,
+ [ 1.15557734e-04 2.92792241e-03 3.46455898e-04 ..., 2.72328052e-05
+ 3.37066874e-02 7.89367408e-02]
+ [ 1.15557734e-04 2.92792241e-03 3.46455898e-04 ..., 2.72328052e-05
+ 3.37066874e-02 7.89367408e-02]
+ [ 1.15557734e-04 2.92792241e-03 3.46455898e-04 ..., 2.72328052e-05
+ 3.37066874e-02 7.89367408e-02]]
+Processing patch (1, 2) / (2, 4)
+
+...(omitted)
+
+Accuracy of 1 LM(s) on pdp60 = 0.6
+
+...(omitted)
+
+Accuracy of 5 LM(s) on pdp60 = 0.7
+
+...(omitted)
+
+Accuracy of 10 LM(s) on wsc273 = 0.615
+
+...(omitted)
+
+Accuracy of 14 LM(s) on wsc273 = 0.637
+```
diff --git a/models/research/lm_commonsense/eval.py b/models/research/lm_commonsense/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5b7ff98b50a5af4e066d3d9f82c1acae81c3e93
--- /dev/null
+++ b/models/research/lm_commonsense/eval.py
@@ -0,0 +1,190 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import pickle as pkl
+import numpy as np
+import tensorflow as tf
+import utils
+
+tf.app.flags.DEFINE_string(
+ 'data_dir', 'reproduce',
+ 'Path to directory containing data and model checkpoints.')
+
+
+FLAGS = tf.app.flags.FLAGS
+
+
+class EnsembleLM(object):
+ """Ensemble of language models."""
+
+ def __init__(self, test_data_name='wsc273'):
+ vocab_file = os.path.join(FLAGS.data_dir, 'vocab.txt')
+ self.vocab = utils.CharsVocabulary(vocab_file, 50)
+ assert test_data_name in ['pdp60', 'wsc273'], (
+ 'Test data must be pdp60 or wsc273, got {}'.format(test_data_name))
+ self.test_data_name = test_data_name
+
+ test_data = utils.parse_commonsense_reasoning_test(test_data_name)
+ self.question_ids, self.sentences, self.labels = test_data
+ self.all_probs = [] # aggregate single-model prediction here.
+
+ def add_single_model(self, model_name='lm1'):
+ """Add a single model into the current ensemble."""
+ # Create single LM
+ single_lm = SingleRecurrentLanguageModel(self.vocab, model_name)
+
+ # Add the single LM prediction.
+ probs = single_lm.assign_probs(self.sentences, self.test_data_name)
+ self.all_probs.append(probs)
+ print('Done adding {}'.format(model_name))
+
+ def evaluate(self):
+ """Evaluate the current ensemble."""
+ # Attach word probabilities and correctness label to each substitution
+ ensembled_probs = sum(self.all_probs) / len(self.all_probs)
+ scorings = []
+ for i, sentence in enumerate(self.sentences):
+ correctness = self.labels[i]
+ word_probs = ensembled_probs[i, :len(sentence)]
+ joint_prob = np.prod(word_probs, dtype=np.float64)
+
+ scorings.append(dict(
+ correctness=correctness,
+ sentence=sentence,
+ joint_prob=joint_prob,
+ word_probs=word_probs))
+ scoring_mode = 'full' if self.test_data_name == 'pdp60' else 'partial'
+ return utils.compare_substitutions(
+ self.question_ids, scorings, scoring_mode)
+
+
+class SingleRecurrentLanguageModel(object):
+ """Single Recurrent Language Model."""
+
+ def __init__(self, vocab, model_name='lm01'):
+ self.vocab = vocab
+ self.log_dir = os.path.join(FLAGS.data_dir, model_name)
+
+ def reset(self):
+ self.sess.run(self.tensors['states_init'])
+
+ def _score(self, word_patch):
+ """Score a matrix of shape (batch_size, num_timesteps+1) str tokens."""
+ word_ids = np.array(
+ [[self.vocab.word_to_id(word) for word in row]
+ for row in word_patch])
+ char_ids = np.array(
+ [[self.vocab.word_to_char_ids(word) for word in row]
+ for row in word_patch])
+ print('Probs for \n{}\n='.format(np.array(word_patch)[:, 1:]))
+
+ input_ids, target_ids = word_ids[:, :-1], word_ids[:, 1:]
+ input_char_ids = char_ids[:, :-1, :]
+
+ softmax = self.sess.run(self.tensors['softmax_out'], feed_dict={
+ self.tensors['inputs_in']: input_ids,
+ self.tensors['char_inputs_in']: input_char_ids
+ })
+
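+    # The softmax comes back flattened as (num_timesteps * batch_size, vocab);
+    # reshape/transpose it to (batch_size, num_timesteps, vocab) and gather the
+    # probability assigned to each target token.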
+ batch_size, num_timesteps = self.shape
+ softmax = softmax.reshape((num_timesteps, batch_size, -1))
+ softmax = np.transpose(softmax, [1, 0, 2])
+ probs = np.array([[softmax[row, col, target_ids[row, col]]
+ for col in range(num_timesteps)]
+ for row in range(batch_size)])
+ print(probs)
+ return probs
+
+ def _score_patches(self, word_patches):
+ """Score a 2D matrix of word_patches and stitch results together."""
+ batch_size, num_timesteps = self.shape
+ nrow, ncol = len(word_patches), len(word_patches[0])
+ max_len = num_timesteps * ncol
+ probs = np.zeros([0, max_len]) # accumulate results into this.
+
+ # Loop through the 2D matrix of word_patches and score each.
+ for i, row in enumerate(word_patches):
+ print('Reset RNN states.')
+ self.reset() # reset states before processing each row.
+ row_probs = np.zeros([batch_size, 0])
+ for j, word_patch in enumerate(row):
+ print('Processing patch '
+ '({}, {}) / ({}, {})'.format(i+1, j+1, nrow, ncol))
+ patch_probs = (self._score(word_patch) if word_patch else
+ np.zeros([batch_size, num_timesteps]))
+ row_probs = np.concatenate([row_probs, patch_probs], 1)
+ probs = np.concatenate([probs, row_probs], 0)
+ return probs
+
+ def assign_probs(self, sentences, test_data_name='wsc273'):
+ """Return prediction accuracy using this LM for a test."""
+
+ probs_cache = os.path.join(self.log_dir, '{}.probs'.format(test_data_name))
+ if os.path.exists(probs_cache):
+ print('Reading cached result from {}'.format(probs_cache))
+ with tf.gfile.Open(probs_cache, 'r') as f:
+ probs = pkl.load(f)
+ else:
+ tf.reset_default_graph()
+ self.sess = tf.Session()
+ # Build the graph.
+ saver = tf.train.import_meta_graph(
+ os.path.join(self.log_dir, 'ckpt-best.meta'))
+ saver.restore(self.sess, os.path.join(self.log_dir, 'ckpt-best'))
+ print('Restored from {}'.format(self.log_dir))
+ graph = tf.get_default_graph()
+ self.tensors = dict(
+ inputs_in=graph.get_tensor_by_name('test_inputs_in:0'),
+ char_inputs_in=graph.get_tensor_by_name('test_char_inputs_in:0'),
+ softmax_out=graph.get_tensor_by_name('SotaRNN_1/softmax_out:0'),
+ states_init=graph.get_operation_by_name('SotaRNN_1/states_init'))
+ self.shape = self.tensors['inputs_in'].shape.as_list()
+
+ # Cut sentences into patches of shape processable by the LM.
+ batch_size, num_timesteps = self.shape
+ word_patches = utils.cut_to_patches(sentences, batch_size, num_timesteps)
+ probs = self._score_patches(word_patches)
+
+ # Cache the probs since they are expensive to evaluate
+ with tf.gfile.Open(probs_cache, 'w') as f:
+ pkl.dump(probs, f)
+ return probs
+
+
+def evaluate_ensemble(test_data_name, number_of_lms):
+ ensemble = EnsembleLM(test_data_name)
+ model_list = ['lm{:02d}'.format(i+1) for i in range(number_of_lms)]
+ for model_name in model_list:
+ ensemble.add_single_model(model_name)
+ accuracy = ensemble.evaluate()
+ print('Accuracy of {} LM(s) on {} = {}'.format(
+ number_of_lms, test_data_name, accuracy))
+
+
+def main(_):
+ evaluate_ensemble('pdp60', 1) # 60%
+ evaluate_ensemble('pdp60', 5) # 70%
+ evaluate_ensemble('wsc273', 10) # 61.5%
+ evaluate_ensemble('wsc273', 14) # 63.7%
+
+
+if __name__ == '__main__':
+ tf.app.run(main)
diff --git a/models/research/lm_commonsense/method.jpg b/models/research/lm_commonsense/method.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ee8a5506fccca3cbb67f7bda0ccef78303cb228b
Binary files /dev/null and b/models/research/lm_commonsense/method.jpg differ
diff --git a/models/research/lm_commonsense/utils.py b/models/research/lm_commonsense/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d75f2b0fb72716860ea6d438e6b8ca2732d13c84
--- /dev/null
+++ b/models/research/lm_commonsense/utils.py
@@ -0,0 +1,368 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+import numpy as np
+import tensorflow as tf
+
+FLAGS = tf.flags.FLAGS
+
+
+class Vocabulary(object):
+ """Class that holds a vocabulary for the dataset."""
+
+ def __init__(self, filename):
+
+ self._id_to_word = []
+ self._word_to_id = {}
+ self._unk = -1
+ self._bos = -1
+ self._eos = -1
+
+ with tf.gfile.Open(filename) as f:
+ idx = 0
+ for line in f:
+ word_name = line.strip()
+        if word_name == '<S>':
+          self._bos = idx
+        elif word_name == '</S>':
+          self._eos = idx
+        elif word_name == '<UNK>':
+          self._unk = idx
+        if word_name == '!!!MAXTERMID':
+ continue
+
+ self._id_to_word.append(word_name)
+ self._word_to_id[word_name] = idx
+ idx += 1
+
+ @property
+ def bos(self):
+ return self._bos
+
+ @property
+ def eos(self):
+ return self._eos
+
+ @property
+ def unk(self):
+ return self._unk
+
+ @property
+ def size(self):
+ return len(self._id_to_word)
+
+ def word_to_id(self, word):
+ if word in self._word_to_id:
+ return self._word_to_id[word]
+ else:
+ if word.lower() in self._word_to_id:
+ return self._word_to_id[word.lower()]
+ return self.unk
+
+ def id_to_word(self, cur_id):
+ if cur_id < self.size:
+ return self._id_to_word[int(cur_id)]
+ return ''
+
+ def decode(self, cur_ids):
+ return ' '.join([self.id_to_word(cur_id) for cur_id in cur_ids])
+
+ def encode(self, sentence):
+ word_ids = [self.word_to_id(cur_word) for cur_word in sentence.split()]
+ return np.array([self.bos] + word_ids + [self.eos], dtype=np.int32)
+
+
+class CharsVocabulary(Vocabulary):
+ """Vocabulary containing character-level information."""
+
+ def __init__(self, filename, max_word_length):
+ super(CharsVocabulary, self).__init__(filename)
+
+ self._max_word_length = max_word_length
+ chars_set = set()
+
+ for word in self._id_to_word:
+ chars_set |= set(word)
+
+ free_ids = []
+ for i in range(256):
+ if chr(i) in chars_set:
+ continue
+ free_ids.append(chr(i))
+
+ if len(free_ids) < 5:
+ raise ValueError('Not enough free char ids: %d' % len(free_ids))
+
+    self.bos_char = free_ids[0]  # <begin sentence>
+    self.eos_char = free_ids[1]  # <end sentence>
+    self.bow_char = free_ids[2]  # <begin word>
+    self.eow_char = free_ids[3]  # <end word>
+    self.pad_char = free_ids[4]  # <padding>
+
+ chars_set |= {self.bos_char, self.eos_char, self.bow_char, self.eow_char,
+ self.pad_char}
+
+ self._char_set = chars_set
+ num_words = len(self._id_to_word)
+
+ self._word_char_ids = np.zeros([num_words, max_word_length], dtype=np.int32)
+
+ self.bos_chars = self._convert_word_to_char_ids(self.bos_char)
+ self.eos_chars = self._convert_word_to_char_ids(self.eos_char)
+
+ for i, word in enumerate(self._id_to_word):
+ if i == self.bos:
+ self._word_char_ids[i] = self.bos_chars
+ elif i == self.eos:
+ self._word_char_ids[i] = self.eos_chars
+ else:
+ self._word_char_ids[i] = self._convert_word_to_char_ids(word)
+
+ @property
+ def max_word_length(self):
+ return self._max_word_length
+
+ def _convert_word_to_char_ids(self, word):
+ code = np.zeros([self.max_word_length], dtype=np.int32)
+ code[:] = ord(self.pad_char)
+
+ if len(word) > self.max_word_length - 2:
+ word = word[:self.max_word_length-2]
+ cur_word = self.bow_char + word + self.eow_char
+ for j in range(len(cur_word)):
+ code[j] = ord(cur_word[j])
+ return code
+
+ def word_to_char_ids(self, word):
+ if word in self._word_to_id:
+ return self._word_char_ids[self._word_to_id[word]]
+ else:
+ return self._convert_word_to_char_ids(word)
+
+ def encode_chars(self, sentence):
+ chars_ids = [self.word_to_char_ids(cur_word)
+ for cur_word in sentence.split()]
+ return np.vstack([self.bos_chars] + chars_ids + [self.eos_chars])
+
+
+_SPECIAL_CHAR_MAP = {
+ '\xe2\x80\x98': '\'',
+ '\xe2\x80\x99': '\'',
+ '\xe2\x80\x9c': '"',
+ '\xe2\x80\x9d': '"',
+ '\xe2\x80\x93': '-',
+ '\xe2\x80\x94': '-',
+ '\xe2\x88\x92': '-',
+ '\xce\x84': '\'',
+ '\xc2\xb4': '\'',
+ '`': '\''
+}
+
+_START_SPECIAL_CHARS = ['.', ',', '?', '!', ';', ':', '[', ']', '\'', '+', '/',
+ '\xc2\xa3', '$', '~', '*', '%', '{', '}', '#', '&', '-',
+ '"', '(', ')', '='] + list(_SPECIAL_CHAR_MAP.keys())
+_SPECIAL_CHARS = _START_SPECIAL_CHARS + [
+ '\'s', '\'m', '\'t', '\'re', '\'d', '\'ve', '\'ll']
+
+
+def tokenize(sentence):
+ """Tokenize a sentence."""
+ sentence = str(sentence)
+ words = sentence.strip().split()
+ tokenized = [] # return this
+
+ for word in words:
+ if word.lower() in ['mr.', 'ms.']:
+ tokenized.append(word)
+ continue
+
+ # Split special chars at the start of word
+ will_split = True
+ while will_split:
+ will_split = False
+ for char in _START_SPECIAL_CHARS:
+ if word.startswith(char):
+ tokenized.append(char)
+ word = word[len(char):]
+ will_split = True
+
+ # Split special chars at the end of word
+ special_end_tokens = []
+ will_split = True
+ while will_split:
+ will_split = False
+ for char in _SPECIAL_CHARS:
+ if word.endswith(char):
+ special_end_tokens = [char] + special_end_tokens
+ word = word[:-len(char)]
+ will_split = True
+
+ if word:
+ tokenized.append(word)
+ tokenized += special_end_tokens
+
+ # Add necessary end of sentence token.
+ if tokenized[-1] not in ['.', '!', '?']:
+ tokenized += ['.']
+ return tokenized
+
+
+def parse_commonsense_reasoning_test(test_data_name):
+ """Read JSON test data."""
+ with tf.gfile.Open(os.path.join(
+ FLAGS.data_dir, 'commonsense_test',
+ '{}.json'.format(test_data_name)), 'r') as f:
+ data = json.load(f)
+
+ question_ids = [d['question_id'] for d in data]
+ sentences = [tokenize(d['substitution']) for d in data]
+ labels = [d['correctness'] for d in data]
+
+ return question_ids, sentences, labels
+
+
+PAD = ''
+
+
+def cut_to_patches(sentences, batch_size, num_timesteps):
+ """Cut sentences into patches of shape (batch_size, num_timesteps).
+
+ Args:
+ sentences: a list of sentences, each sentence is a list of str token.
+ batch_size: batch size
+    num_timesteps: number of backprop steps
+
+ Returns:
+ patches: A 2D matrix,
+ each entry is a matrix of shape (batch_size, num_timesteps).
+ """
+  preprocessed = [['<S>'] + sentence + ['</S>'] for sentence in sentences]
+ max_len = max([len(sent) for sent in preprocessed])
+
+ # Pad to shape [height, width]
+ # where height is a multiple of batch_size
+ # and width is a multiple of num_timesteps
+ nrow = int(np.ceil(len(preprocessed) * 1.0 / batch_size))
+ ncol = int(np.ceil(max_len * 1.0 / num_timesteps))
+ height, width = nrow * batch_size, ncol * num_timesteps + 1
+ preprocessed = [sent + [PAD] * (width - len(sent)) for sent in preprocessed]
+ preprocessed += [[PAD] * width] * (height - len(preprocessed))
+
+ # Cut preprocessed into patches of shape [batch_size, num_timesteps]
+ patches = []
+ for row in range(nrow):
+ patches.append([])
+ for col in range(ncol):
+ patch = [sent[col * num_timesteps:
+ (col+1) * num_timesteps + 1]
+ for sent in preprocessed[row * batch_size:
+ (row+1) * batch_size]]
+ if np.all(np.array(patch)[:, 1:] == PAD):
+ patch = None # no need to process this patch.
+ patches[-1].append(patch)
+ return patches
+
+
+def _substitution_mask(sent1, sent2):
+ """Binary mask identifying substituted part in two sentences.
+
+ Example sentence and their mask:
+ First sentence = "I like the cat 's color"
+ 0 0 0 1 0 0
+ Second sentence = "I like the yellow dog 's color"
+ 0 0 0 1 1 0 0
+
+ Args:
+ sent1: first sentence
+ sent2: second sentence
+
+ Returns:
+ mask1: mask for first sentence
+ mask2: mask for second sentence
+ """
+ mask1_start, mask2_start = [], []
+ while sent1[0] == sent2[0]:
+ sent1 = sent1[1:]
+ sent2 = sent2[1:]
+ mask1_start.append(0.)
+ mask2_start.append(0.)
+
+ mask1_end, mask2_end = [], []
+ while sent1[-1] == sent2[-1]:
+ if (len(sent1) == 1) or (len(sent2) == 1):
+ break
+ sent1 = sent1[:-1]
+ sent2 = sent2[:-1]
+ mask1_end = [0.] + mask1_end
+ mask2_end = [0.] + mask2_end
+
+ assert sent1 or sent2, 'Two sentences are identical.'
+ return (mask1_start + [1.] * len(sent1) + mask1_end,
+ mask2_start + [1.] * len(sent2) + mask2_end)
+
+
+def _convert_to_partial(scoring1, scoring2):
+ """Convert full scoring into partial scoring."""
+ mask1, mask2 = _substitution_mask(
+ scoring1['sentence'], scoring2['sentence'])
+
+ def _partial_score(scoring, mask):
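+    # Positions inside the substituted span (mask == 1) are clamped to
+    # probability 1, so the joint score reflects only how well the shared
+    # context is predicted under each candidate substitution.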
+ word_probs = [max(_) for _ in zip(scoring['word_probs'], mask)]
+ scoring.update(word_probs=word_probs,
+ joint_prob=np.prod(word_probs))
+
+ _partial_score(scoring1, mask1)
+ _partial_score(scoring2, mask2)
+
+
+def compare_substitutions(question_ids, scorings, mode='full'):
+ """Return accuracy by comparing two consecutive scorings."""
+ prediction_correctness = []
+ # Compare two consecutive substitutions
+ for i in range(len(scorings) // 2):
+ scoring1, scoring2 = scorings[2*i: 2*i+2]
+ if mode == 'partial': # fix joint prob into partial prob
+ _convert_to_partial(scoring1, scoring2)
+
+ prediction_correctness.append(
+ (scoring2['joint_prob'] > scoring1['joint_prob']) ==
+ scoring2['correctness'])
+
+ # Two consecutive substitutions always belong to the same question
+ question_ids = [qid for i, qid in enumerate(question_ids) if i % 2 == 0]
+ assert len(question_ids) == len(prediction_correctness)
+ num_questions = len(set(question_ids))
+
+ # Question is correctly answered only if
+ # all predictions of the same question_id is correct
+ num_correct_answer = 0
+ previous_qid = None
+ correctly_answered = False
+ for predict, qid in zip(prediction_correctness, question_ids):
+ if qid != previous_qid:
+ previous_qid = qid
+ num_correct_answer += int(correctly_answered)
+ correctly_answered = True
+ correctly_answered = correctly_answered and predict
+ num_correct_answer += int(correctly_answered)
+
+ return num_correct_answer / num_questions
diff --git a/models/research/lstm_object_detection/README.md b/models/research/lstm_object_detection/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a696ba3df306768cfa28223ad957ef564667c7dd
--- /dev/null
+++ b/models/research/lstm_object_detection/README.md
@@ -0,0 +1,40 @@
+# Tensorflow Mobile Video Object Detection
+
+Tensorflow mobile video object detection implementation proposed in the
+following papers:
+
+
+
+
+
+```
+"Mobile Video Object Detection with Temporally-Aware Feature Maps",
+Liu, Mason and Zhu, Menglong, CVPR 2018.
+```
+\[[link](http://openaccess.thecvf.com/content_cvpr_2018/papers/Liu_Mobile_Video_Object_CVPR_2018_paper.pdf)\]\[[bibtex](
+https://scholar.googleusercontent.com/scholar.bib?q=info:hq5rcMUUXysJ:scholar.google.com/&output=citation&scisig=AAGBfm0AAAAAXLdwXcU5g_wiMQ40EvbHQ9kTyvfUxffh&scisf=4&ct=citation&cd=-1&hl=en)\]
+
+
+
+
+
+
+```
+"Looking Fast and Slow: Memory-Guided Mobile Video Object Detection",
+Liu, Mason and Zhu, Menglong and White, Marie and Li, Yinxiao and Kalenichenko, Dmitry
+```
+\[[link](https://arxiv.org/abs/1903.10172)\]\[[bibtex](
+https://scholar.googleusercontent.com/scholar.bib?q=info:rLqvkztmWYgJ:scholar.google.com/&output=citation&scisig=AAGBfm0AAAAAXLdwNf-LJlm2M1ymQHbq2wYA995MHpJu&scisf=4&ct=citation&cd=-1&hl=en)\]
+
+
+## Maintainers
+* masonliuw@gmail.com
+* yinxiao@google.com
+* menglong@google.com
+* yongzhe@google.com
+* lzyuan@google.com
+
+
+## Table of Contents
+
+ * Exporting a trained model
diff --git a/models/research/lstm_object_detection/__init__.py b/models/research/lstm_object_detection/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/lstm_object_detection/builders/__init__.py b/models/research/lstm_object_detection/builders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/lstm_object_detection/builders/graph_rewriter_builder.py b/models/research/lstm_object_detection/builders/graph_rewriter_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..accced2f0fccec190894348d5518bd991332fc71
--- /dev/null
+++ b/models/research/lstm_object_detection/builders/graph_rewriter_builder.py
@@ -0,0 +1,147 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Custom version for quantized training and evaluation functions.
+
+The main difference between this and the third_party graph_rewriter_builder.py
+is that this version uses experimental_create_training_graph which allows the
+customization of freeze_bn_delay.
+"""
+
+import re
+import tensorflow.compat.v1 as tf
+from tensorflow.contrib import layers as contrib_layers
+from tensorflow.contrib import quantize as contrib_quantize
+from tensorflow.contrib.quantize.python import common
+from tensorflow.contrib.quantize.python import input_to_ops
+from tensorflow.contrib.quantize.python import quant_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+
+
+def build(graph_rewriter_config,
+ quant_overrides_config=None,
+ is_training=True,
+ is_export=False):
+ """Returns a function that modifies default graph based on options.
+
+ Args:
+ graph_rewriter_config: graph_rewriter_pb2.GraphRewriter proto.
+ quant_overrides_config: quant_overrides_pb2.QuantOverrides proto.
+ is_training: whether in training or eval mode.
+ is_export: whether exporting the graph.
+ """
+ def graph_rewrite_fn():
+ """Function to quantize weights and activation of the default graph."""
+ if (graph_rewriter_config.quantization.weight_bits != 8 or
+ graph_rewriter_config.quantization.activation_bits != 8):
+ raise ValueError('Only 8bit quantization is supported')
+
+ graph = tf.get_default_graph()
+
+ # Insert custom quant ops.
+ if quant_overrides_config is not None:
+ input_to_ops_map = input_to_ops.InputToOps(graph)
+ for q in quant_overrides_config.quant_configs:
+ producer = graph.get_operation_by_name(q.op_name)
+ if producer is None:
+ raise ValueError('Op name does not exist in graph.')
+ context = _get_context_from_op(producer)
+ consumers = input_to_ops_map.ConsumerOperations(producer)
+ if q.fixed_range:
+ _insert_fixed_quant_op(
+ context,
+ q.quant_op_name,
+ producer,
+ consumers,
+ init_min=q.min,
+ init_max=q.max,
+ quant_delay=q.delay if is_training else 0)
+ else:
+ raise ValueError('Learned ranges are not yet supported.')
+
+ # Quantize the graph by inserting quantize ops for weights and activations
+ if is_training:
+ contrib_quantize.experimental_create_training_graph(
+ input_graph=graph,
+ quant_delay=graph_rewriter_config.quantization.delay,
+ freeze_bn_delay=graph_rewriter_config.quantization.delay)
+ else:
+ contrib_quantize.experimental_create_eval_graph(
+ input_graph=graph,
+ quant_delay=graph_rewriter_config.quantization.delay
+ if not is_export else 0)
+
+ contrib_layers.summarize_collection('quant_vars')
+
+ return graph_rewrite_fn
+
+
+def _get_context_from_op(op):
+ """Gets the root context name from the op name."""
+ context_re = re.search(r'^(.*)/([^/]+)', op.name)
+ if context_re:
+ return context_re.group(1)
+ return ''
+
+
+def _insert_fixed_quant_op(context,
+ name,
+ producer,
+ consumers,
+ init_min=-6.0,
+ init_max=6.0,
+ quant_delay=None):
+ """Adds a fake quant op with fixed ranges.
+
+ Args:
+ context: The parent scope of the op to be quantized.
+ name: The name of the fake quant op.
+ producer: The producer op to be quantized.
+ consumers: The consumer ops to the producer op.
+ init_min: The minimum range for the fake quant op.
+ init_max: The maximum range for the fake quant op.
+ quant_delay: Number of steps to wait before activating the fake quant op.
+
+ Raises:
+ ValueError: When producer operation is not directly connected to the
+ consumer operation.
+ """
+ name_prefix = name if not context else context + '/' + name
+ inputs = producer.outputs[0]
+ quant = quant_ops.FixedQuantize(
+ inputs, init_min=init_min, init_max=init_max, scope=name_prefix)
+
+ if quant_delay and quant_delay > 0:
+ activate_quant = math_ops.greater_equal(
+ common.CreateOrGetQuantizationStep(),
+ quant_delay,
+ name=name_prefix + '/activate_quant')
+ quant = control_flow_ops.cond(
+ activate_quant,
+ lambda: quant,
+ lambda: inputs,
+ name=name_prefix + '/delayed_quant')
+
+ if consumers:
+ tensors_modified_count = common.RerouteTensor(
+ quant, inputs, can_modify=consumers)
+ # Some operations can have multiple output tensors going to the same
+ # consumer. Since consumers is a set, we need to ensure that
+ # tensors_modified_count is greater than or equal to the length of the set
+ # of consumers.
+ if tensors_modified_count < len(consumers):
+ raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
+ [consumer.name for consumer in consumers]))
diff --git a/models/research/lstm_object_detection/builders/graph_rewriter_builder_test.py b/models/research/lstm_object_detection/builders/graph_rewriter_builder_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e06a9f5a3d729fe122bc00e74e2d158b3d06482e
--- /dev/null
+++ b/models/research/lstm_object_detection/builders/graph_rewriter_builder_test.py
@@ -0,0 +1,117 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for graph_rewriter_builder."""
+import mock
+import tensorflow.compat.v1 as tf
+from tensorflow.contrib import layers as contrib_layers
+from tensorflow.contrib import quantize as contrib_quantize
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from lstm_object_detection.builders import graph_rewriter_builder
+from lstm_object_detection.protos import quant_overrides_pb2
+from object_detection.protos import graph_rewriter_pb2
+
+
+class QuantizationBuilderTest(tf.test.TestCase):
+
+ def testQuantizationBuilderSetsUpCorrectTrainArguments(self):
+ with mock.patch.object(
+ contrib_quantize,
+ 'experimental_create_training_graph') as mock_quant_fn:
+ with mock.patch.object(contrib_layers,
+ 'summarize_collection') as mock_summarize_col:
+ graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter()
+ graph_rewriter_proto.quantization.delay = 10
+ graph_rewriter_proto.quantization.weight_bits = 8
+ graph_rewriter_proto.quantization.activation_bits = 8
+ graph_rewrite_fn = graph_rewriter_builder.build(
+ graph_rewriter_proto, is_training=True)
+ graph_rewrite_fn()
+ _, kwargs = mock_quant_fn.call_args
+ self.assertEqual(kwargs['input_graph'], tf.get_default_graph())
+ self.assertEqual(kwargs['quant_delay'], 10)
+ mock_summarize_col.assert_called_with('quant_vars')
+
+ def testQuantizationBuilderSetsUpCorrectEvalArguments(self):
+ with mock.patch.object(contrib_quantize,
+ 'experimental_create_eval_graph') as mock_quant_fn:
+ with mock.patch.object(contrib_layers,
+ 'summarize_collection') as mock_summarize_col:
+ graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter()
+ graph_rewriter_proto.quantization.delay = 10
+ graph_rewrite_fn = graph_rewriter_builder.build(
+ graph_rewriter_proto, is_training=False)
+ graph_rewrite_fn()
+ _, kwargs = mock_quant_fn.call_args
+ self.assertEqual(kwargs['input_graph'], tf.get_default_graph())
+ mock_summarize_col.assert_called_with('quant_vars')
+
+ def testQuantizationBuilderAddsQuantOverride(self):
+ graph = ops.Graph()
+ with graph.as_default():
+ self._buildGraph()
+
+ quant_overrides_proto = quant_overrides_pb2.QuantOverrides()
+ quant_config = quant_overrides_proto.quant_configs.add()
+ quant_config.op_name = 'test_graph/add_ab'
+ quant_config.quant_op_name = 'act_quant'
+ quant_config.fixed_range = True
+ quant_config.min = 0
+ quant_config.max = 6
+ quant_config.delay = 100
+
+ graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter()
+ graph_rewriter_proto.quantization.delay = 10
+ graph_rewriter_proto.quantization.weight_bits = 8
+ graph_rewriter_proto.quantization.activation_bits = 8
+
+ graph_rewrite_fn = graph_rewriter_builder.build(
+ graph_rewriter_proto,
+ quant_overrides_config=quant_overrides_proto,
+ is_training=True)
+ graph_rewrite_fn()
+
+ act_quant_found = False
+ quant_delay_found = False
+ for op in graph.get_operations():
+ if (quant_config.quant_op_name in op.name and
+ op.type == 'FakeQuantWithMinMaxArgs'):
+ act_quant_found = True
+ min_val = op.get_attr('min')
+ max_val = op.get_attr('max')
+ self.assertEqual(min_val, quant_config.min)
+ self.assertEqual(max_val, quant_config.max)
+ if ('activate_quant' in op.name and
+ quant_config.quant_op_name in op.name and op.type == 'Const'):
+ tensor = op.get_attr('value')
+ if tensor.int64_val[0] == quant_config.delay:
+ quant_delay_found = True
+
+ self.assertTrue(act_quant_found)
+ self.assertTrue(quant_delay_found)
+
+ def _buildGraph(self, scope='test_graph'):
+ with ops.name_scope(scope):
+ a = tf.constant(10, dtype=dtypes.float32, name='input_a')
+ b = tf.constant(20, dtype=dtypes.float32, name='input_b')
+ ab = tf.add(a, b, name='add_ab')
+ c = tf.constant(30, dtype=dtypes.float32, name='input_c')
+ abc = tf.multiply(ab, c, name='mul_ab_c')
+ return abc
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/lstm_object_detection/configs/lstm_ssd_interleaved_mobilenet_v2_imagenet.config b/models/research/lstm_object_detection/configs/lstm_ssd_interleaved_mobilenet_v2_imagenet.config
new file mode 100644
index 0000000000000000000000000000000000000000..536d7d5327114efa159475433f051c627043e64f
--- /dev/null
+++ b/models/research/lstm_object_detection/configs/lstm_ssd_interleaved_mobilenet_v2_imagenet.config
@@ -0,0 +1,239 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# For training on Imagenet Video with LSTM Interleaved Mobilenet V2
+
+[lstm_object_detection.protos.lstm_model] {
+ train_unroll_length: 4
+ eval_unroll_length: 4
+ lstm_state_depth: 320
+ depth_multipliers: 1.4
+ depth_multipliers: 0.35
+ pre_bottleneck: true
+ low_res: true
+ train_interleave_method: 'RANDOM_SKIP_SMALL'
+ eval_interleave_method: 'SKIP3'
+}
+model {
+ ssd {
+ num_classes: 30 # Number of classes for the ImageNet VID dataset.
+ box_coder {
+ faster_rcnn_box_coder {
+ y_scale: 10.0
+ x_scale: 10.0
+ height_scale: 5.0
+ width_scale: 5.0
+ }
+ }
+ matcher {
+ argmax_matcher {
+ matched_threshold: 0.5
+ unmatched_threshold: 0.5
+ ignore_thresholds: false
+ negatives_lower_than_unmatched: true
+ force_match_for_each_row: true
+ }
+ }
+ similarity_calculator {
+ iou_similarity {
+ }
+ }
+ anchor_generator {
+ ssd_anchor_generator {
+ num_layers: 5
+ min_scale: 0.2
+ max_scale: 0.95
+ aspect_ratios: 1.0
+ aspect_ratios: 2.0
+ aspect_ratios: 0.5
+ aspect_ratios: 3.0
+ aspect_ratios: 0.3333
+ }
+ }
+ image_resizer {
+ fixed_shape_resizer {
+ height: 320
+ width: 320
+ }
+ }
+ box_predictor {
+ convolutional_box_predictor {
+ min_depth: 0
+ max_depth: 0
+ num_layers_before_predictor: 3
+ use_dropout: false
+ dropout_keep_probability: 0.8
+ kernel_size: 3
+ box_code_size: 4
+ apply_sigmoid_to_scores: false
+ use_depthwise: true
+ conv_hyperparams {
+ activation: RELU_6,
+ regularizer {
+ l2_regularizer {
+ weight: 0.00004
+ }
+ }
+ initializer {
+ truncated_normal_initializer {
+ stddev: 0.03
+ mean: 0.0
+ }
+ }
+ batch_norm {
+ train: true,
+ scale: true,
+ center: true,
+ decay: 0.9997,
+ epsilon: 0.001,
+ }
+ }
+ }
+ }
+ feature_extractor {
+ type: 'lstm_ssd_interleaved_mobilenet_v2'
+ conv_hyperparams {
+ activation: RELU_6,
+ regularizer {
+ l2_regularizer {
+ weight: 0.00004
+ }
+ }
+ initializer {
+ truncated_normal_initializer {
+ stddev: 0.03
+ mean: 0.0
+ }
+ }
+ batch_norm {
+ train: true,
+ scale: true,
+ center: true,
+ decay: 0.9997,
+ epsilon: 0.001,
+ }
+ }
+ }
+ loss {
+ classification_loss {
+ weighted_sigmoid {
+ }
+ }
+ localization_loss {
+ weighted_smooth_l1 {
+ }
+ }
+ hard_example_miner {
+ num_hard_examples: 3000
+ iou_threshold: 0.99
+ loss_type: CLASSIFICATION
+ max_negatives_per_positive: 3
+ min_negatives_per_image: 0
+ }
+ classification_weight: 1.0
+ localization_weight: 4.0
+ }
+ normalize_loss_by_num_matches: true
+ post_processing {
+ batch_non_max_suppression {
+ score_threshold: -20.0
+ iou_threshold: 0.5
+ max_detections_per_class: 100
+ max_total_detections: 100
+ }
+ score_converter: SIGMOID
+ }
+ }
+}
+
+train_config: {
+ batch_size: 8
+ optimizer {
+ use_moving_average: false
+ rms_prop_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.002
+ decay_steps: 200000
+ decay_factor: 0.95
+ }
+ }
+ momentum_optimizer_value: 0.9
+ decay: 0.9
+ epsilon: 1.0
+ }
+ }
+ gradient_clipping_by_norm: 10.0
+ batch_queue_capacity: 12
+ prefetch_queue_capacity: 4
+}
+
+train_input_reader: {
+ shuffle_buffer_size: 32
+ queue_capacity: 12
+ prefetch_size: 12
+ min_after_dequeue: 4
+ label_map_path: "path/to/label_map"
+ external_input_reader {
+ [lstm_object_detection.protos.GoogleInputReader.google_input_reader] {
+ tf_record_video_input_reader: {
+ input_path: '/data/lstm_detection/tfrecords/test.tfrecord'
+ data_type: TF_SEQUENCE_EXAMPLE
+ video_length: 4
+ }
+ }
+ }
+}
+
+eval_config: {
+ metrics_set: "coco_evaluation_all_frames"
+ use_moving_averages: true
+ min_score_threshold: 0.5
+ max_num_boxes_to_visualize: 300
+ visualize_groundtruth_boxes: true
+ groundtruth_box_visualization_color: "red"
+}
+
+eval_input_reader {
+ label_map_path: "path/to/label_map"
+ shuffle: true
+ num_epochs: 1
+ num_parallel_batches: 1
+ num_readers: 1
+ external_input_reader {
+ [lstm_object_detection.protos.GoogleInputReader.google_input_reader] {
+ tf_record_video_input_reader: {
+ input_path: "path/to/sequence_example/data"
+ data_type: TF_SEQUENCE_EXAMPLE
+ video_length: 10
+ }
+ }
+ }
+}
+
+eval_input_reader: {
+ label_map_path: "path/to/label_map"
+ external_input_reader {
+ [lstm_object_detection.protos.GoogleInputReader.google_input_reader] {
+ tf_record_video_input_reader: {
+ input_path: "path/to/sequence_example/data"
+ data_type: TF_SEQUENCE_EXAMPLE
+ video_length: 4
+ }
+ }
+ }
+ shuffle: true
+ num_readers: 1
+}
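+
+# A minimal, illustrative way to evaluate a checkpoint against this config,
+# using the flags defined in lstm_object_detection/eval.py (all paths below
+# are placeholders, not real files):
+#
+#   python lstm_object_detection/eval.py \
+#     --pipeline_config_path=path/to/this_config.config \
+#     --checkpoint_dir=path/to/train_dir \
+#     --eval_dir=path/to/eval_dir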
diff --git a/models/research/lstm_object_detection/configs/lstm_ssd_mobilenet_v1_imagenet.config b/models/research/lstm_object_detection/configs/lstm_ssd_mobilenet_v1_imagenet.config
new file mode 100644
index 0000000000000000000000000000000000000000..cb357ec17eeb80795d48a5aea50f98f3934ff1ad
--- /dev/null
+++ b/models/research/lstm_object_detection/configs/lstm_ssd_mobilenet_v1_imagenet.config
@@ -0,0 +1,232 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# For training on Imagenet Video with LSTM Mobilenet V1
+
+[lstm_object_detection.protos.lstm_model] {
+ train_unroll_length: 4
+ eval_unroll_length: 4
+}
+
+model {
+ ssd {
+ num_classes: 30 # Number of classes for the ImageNet VID dataset.
+ box_coder {
+ faster_rcnn_box_coder {
+ y_scale: 10.0
+ x_scale: 10.0
+ height_scale: 5.0
+ width_scale: 5.0
+ }
+ }
+ matcher {
+ argmax_matcher {
+ matched_threshold: 0.5
+ unmatched_threshold: 0.5
+ ignore_thresholds: false
+ negatives_lower_than_unmatched: true
+ force_match_for_each_row: true
+ }
+ }
+ similarity_calculator {
+ iou_similarity {
+ }
+ }
+ anchor_generator {
+ ssd_anchor_generator {
+ num_layers: 5
+ min_scale: 0.2
+ max_scale: 0.95
+ aspect_ratios: 1.0
+ aspect_ratios: 2.0
+ aspect_ratios: 0.5
+ aspect_ratios: 3.0
+ aspect_ratios: 0.3333
+ }
+ }
+ image_resizer {
+ fixed_shape_resizer {
+ height: 256
+ width: 256
+ }
+ }
+ box_predictor {
+ convolutional_box_predictor {
+ min_depth: 0
+ max_depth: 0
+ num_layers_before_predictor: 3
+ use_dropout: false
+ dropout_keep_probability: 0.8
+ kernel_size: 3
+ box_code_size: 4
+ apply_sigmoid_to_scores: false
+ use_depthwise: true
+ conv_hyperparams {
+ activation: RELU_6,
+ regularizer {
+ l2_regularizer {
+ weight: 0.00004
+ }
+ }
+ initializer {
+ truncated_normal_initializer {
+ stddev: 0.03
+ mean: 0.0
+ }
+ }
+ batch_norm {
+ train: true,
+ scale: true,
+ center: true,
+ decay: 0.9997,
+ epsilon: 0.001,
+ }
+ }
+ }
+ }
+ feature_extractor {
+ type: 'lstm_mobilenet_v1'
+ min_depth: 16
+ depth_multiplier: 1.0
+ use_depthwise: true
+ conv_hyperparams {
+ activation: RELU_6,
+ regularizer {
+ l2_regularizer {
+ weight: 0.00004
+ }
+ }
+ initializer {
+ truncated_normal_initializer {
+ stddev: 0.03
+ mean: 0.0
+ }
+ }
+ batch_norm {
+ train: true,
+ scale: true,
+ center: true,
+ decay: 0.9997,
+ epsilon: 0.001,
+ }
+ }
+ }
+ loss {
+ classification_loss {
+ weighted_sigmoid {
+ }
+ }
+ localization_loss {
+ weighted_smooth_l1 {
+ }
+ }
+ hard_example_miner {
+ num_hard_examples: 3000
+ iou_threshold: 0.99
+ loss_type: CLASSIFICATION
+ max_negatives_per_positive: 3
+ min_negatives_per_image: 0
+ }
+ classification_weight: 1.0
+ localization_weight: 4.0
+ }
+ normalize_loss_by_num_matches: true
+ post_processing {
+ batch_non_max_suppression {
+ score_threshold: -20.0
+ iou_threshold: 0.5
+ max_detections_per_class: 100
+ max_total_detections: 100
+ }
+ score_converter: SIGMOID
+ }
+ }
+}
+
+train_config: {
+ batch_size: 8
+ data_augmentation_options {
+ random_horizontal_flip {
+ }
+ }
+ data_augmentation_options {
+ ssd_random_crop {
+ }
+ }
+ optimizer {
+ use_moving_average: false
+ rms_prop_optimizer: {
+ learning_rate: {
+ exponential_decay_learning_rate {
+ initial_learning_rate: 0.002
+ decay_steps: 200000
+ decay_factor: 0.95
+ }
+ }
+ momentum_optimizer_value: 0.9
+ decay: 0.9
+ epsilon: 1.0
+ }
+ }
+
+ from_detection_checkpoint: true
+ gradient_clipping_by_norm: 10.0
+ batch_queue_capacity: 12
+ prefetch_queue_capacity: 4
+ fine_tune_checkpoint: "/path/to/checkpoint/"
+ fine_tune_checkpoint_type: "detection"
+}
+
+
+train_input_reader: {
+ shuffle_buffer_size: 32
+ queue_capacity: 12
+ prefetch_size: 12
+ min_after_dequeue: 4
+ label_map_path: "path/to/label_map"
+ external_input_reader {
+ [lstm_object_detection.protos.GoogleInputReader.google_input_reader] {
+ tf_record_video_input_reader: {
+ input_path: "path/to/sequence_example/data"
+ data_type: TF_SEQUENCE_EXAMPLE
+ video_length: 4
+ }
+ }
+ }
+}
+
+eval_config: {
+ metrics_set: "coco_evaluation_all_frames"
+ use_moving_averages: true
+ min_score_threshold: 0.5
+ max_num_boxes_to_visualize: 300
+ visualize_groundtruth_boxes: true
+ groundtruth_box_visualization_color: "red"
+}
+
+eval_input_reader: {
+ label_map_path: "path/to/label_map"
+ external_input_reader {
+ [lstm_object_detection.protos.GoogleInputReader.google_input_reader] {
+ tf_record_video_input_reader: {
+ input_path: "path/to/sequence_example/data"
+ data_type: TF_SEQUENCE_EXAMPLE
+ video_length: 4
+ }
+ }
+ }
+ shuffle: true
+ num_readers: 1
+}
diff --git a/models/research/lstm_object_detection/eval.py b/models/research/lstm_object_detection/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..aac25c1182bd354b710a7bb83c7bd68365f14fed
--- /dev/null
+++ b/models/research/lstm_object_detection/eval.py
@@ -0,0 +1,108 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+r"""Evaluation executable for detection models.
+
+This executable is used to evaluate DetectionModels. Example usage:
+ ./eval \
+ --logtostderr \
+ --checkpoint_dir=path/to/checkpoint_dir \
+ --eval_dir=path/to/eval_dir \
+ --pipeline_config_path=pipeline_config.pbtxt
+"""
+
+import functools
+import os
+import tensorflow.compat.v1 as tf
+from google.protobuf import text_format
+from lstm_object_detection import evaluator
+from lstm_object_detection import model_builder
+from lstm_object_detection.inputs import seq_dataset_builder
+from lstm_object_detection.utils import config_util
+from object_detection.utils import label_map_util
+
+tf.logging.set_verbosity(tf.logging.INFO)
+flags = tf.app.flags
+flags.DEFINE_boolean('eval_training_data', False,
+ 'If training data should be evaluated for this job.')
+flags.DEFINE_string('checkpoint_dir', '',
+ 'Directory containing checkpoints to evaluate, typically '
+ 'set to `train_dir` used in the training job.')
+flags.DEFINE_string('eval_dir', '', 'Directory to write eval summaries to.')
+flags.DEFINE_string('pipeline_config_path', '',
+ 'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
+ 'file. If provided, other configs are ignored')
+flags.DEFINE_boolean('run_once', False, 'Option to only run a single pass of '
+ 'evaluation. Overrides the `max_evals` parameter in the '
+ 'provided config.')
+FLAGS = flags.FLAGS
+
+
+def main(unused_argv):
+ assert FLAGS.checkpoint_dir, '`checkpoint_dir` is missing.'
+ assert FLAGS.eval_dir, '`eval_dir` is missing.'
+ if FLAGS.pipeline_config_path:
+ configs = config_util.get_configs_from_pipeline_file(
+ FLAGS.pipeline_config_path)
+ else:
+ configs = config_util.get_configs_from_multiple_files(
+ model_config_path=FLAGS.model_config_path,
+ eval_config_path=FLAGS.eval_config_path,
+ eval_input_config_path=FLAGS.input_config_path)
+
+ pipeline_proto = config_util.create_pipeline_proto_from_configs(configs)
+ config_text = text_format.MessageToString(pipeline_proto)
+ tf.gfile.MakeDirs(FLAGS.eval_dir)
+ with tf.gfile.Open(os.path.join(FLAGS.eval_dir, 'pipeline.config'),
+ 'wb') as f:
+ f.write(config_text)
+
+ model_config = configs['model']
+ lstm_config = configs['lstm_model']
+ eval_config = configs['eval_config']
+ input_config = configs['eval_input_config']
+
+ if FLAGS.eval_training_data:
+ input_config.external_input_reader.CopyFrom(
+ configs['train_input_config'].external_input_reader)
+ lstm_config.eval_unroll_length = lstm_config.train_unroll_length
+
+ model_fn = functools.partial(
+ model_builder.build,
+ model_config=model_config,
+ lstm_config=lstm_config,
+ is_training=False)
+
+ def get_next(config, model_config, lstm_config, unroll_length):
+ return seq_dataset_builder.build(config, model_config, lstm_config,
+ unroll_length)
+
+ create_input_dict_fn = functools.partial(get_next, input_config, model_config,
+ lstm_config,
+ lstm_config.eval_unroll_length)
+
+ label_map = label_map_util.load_labelmap(input_config.label_map_path)
+ max_num_classes = max([item.id for item in label_map.item])
+ categories = label_map_util.convert_label_map_to_categories(
+ label_map, max_num_classes)
+
+ if FLAGS.run_once:
+ eval_config.max_evals = 1
+
+ evaluator.evaluate(create_input_dict_fn, model_fn, eval_config, categories,
+ FLAGS.checkpoint_dir, FLAGS.eval_dir)
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/lstm_object_detection/evaluator.py b/models/research/lstm_object_detection/evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ed3e476e8e9bfd9c0d4cfe71925ccb7ff5f6b07
--- /dev/null
+++ b/models/research/lstm_object_detection/evaluator.py
@@ -0,0 +1,337 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Detection model evaluator.
+
+This file provides a generic evaluation method that can be used to evaluate a
+DetectionModel.
+
+"""
+
+import tensorflow.compat.v1 as tf
+from tensorflow.contrib import tfprof as contrib_tfprof
+from lstm_object_detection.metrics import coco_evaluation_all_frames
+from object_detection import eval_util
+from object_detection.core import prefetcher
+from object_detection.core import standard_fields as fields
+from object_detection.metrics import coco_evaluation
+from object_detection.utils import object_detection_evaluation
+
+
+# A dictionary of metric names to classes that implement the metric. The classes
+# in the dictionary must implement the
+# utils.object_detection_evaluation.DetectionEvaluator interface.
+EVAL_METRICS_CLASS_DICT = {
+ 'pascal_voc_detection_metrics':
+ object_detection_evaluation.PascalDetectionEvaluator,
+ 'weighted_pascal_voc_detection_metrics':
+ object_detection_evaluation.WeightedPascalDetectionEvaluator,
+ 'pascal_voc_instance_segmentation_metrics':
+ object_detection_evaluation.PascalInstanceSegmentationEvaluator,
+ 'weighted_pascal_voc_instance_segmentation_metrics':
+ object_detection_evaluation.WeightedPascalInstanceSegmentationEvaluator,
+ 'open_images_detection_metrics':
+ object_detection_evaluation.OpenImagesDetectionEvaluator,
+ 'coco_detection_metrics':
+ coco_evaluation.CocoDetectionEvaluator,
+ 'coco_mask_metrics':
+ coco_evaluation.CocoMaskEvaluator,
+ 'coco_evaluation_all_frames':
+ coco_evaluation_all_frames.CocoEvaluationAllFrames,
+}
+
+EVAL_DEFAULT_METRIC = 'pascal_voc_detection_metrics'
+
+
+def _create_detection_op(model, input_dict, batch):
+ """Create detection ops.
+
+ Args:
+ model: model to perform predictions with.
+ input_dict: A dict holding the input data.
+ batch: batch size for evaluation.
+
+ Returns:
+ Detection tensor ops.
+ """
+ video_tensor = tf.stack(list(input_dict[fields.InputDataFields.image]))
+ preprocessed_video, true_image_shapes = model.preprocess(
+ tf.to_float(video_tensor))
+ if batch is not None:
+ prediction_dict = model.predict(preprocessed_video, true_image_shapes,
+ batch)
+ else:
+ prediction_dict = model.predict(preprocessed_video, true_image_shapes)
+
+ return model.postprocess(prediction_dict, true_image_shapes)
+
+
+def _extract_prediction_tensors(model,
+ create_input_dict_fn,
+ ignore_groundtruth=False):
+ """Restores the model in a tensorflow session.
+
+ Args:
+ model: model to perform predictions with.
+ create_input_dict_fn: function to create input tensor dictionaries.
+ ignore_groundtruth: whether groundtruth should be ignored.
+
+
+ Returns:
+ tensor_dict: A tensor dictionary with evaluations.
+ """
+ input_dict = create_input_dict_fn()
+ batch = None
+ if 'batch' in input_dict:
+ batch = input_dict.pop('batch')
+ else:
+ prefetch_queue = prefetcher.prefetch(input_dict, capacity=500)
+ input_dict = prefetch_queue.dequeue()
+ # Ensure a consistent format for images and videos.
+ for key, value in input_dict.items():
+ input_dict[key] = (value,)
+
+ detections = _create_detection_op(model, input_dict, batch)
+
+ # Print out analysis of the model.
+ contrib_tfprof.model_analyzer.print_model_analysis(
+ tf.get_default_graph(),
+ tfprof_options=contrib_tfprof.model_analyzer
+ .TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
+ contrib_tfprof.model_analyzer.print_model_analysis(
+ tf.get_default_graph(),
+ tfprof_options=contrib_tfprof.model_analyzer.FLOAT_OPS_OPTIONS)
+
+ num_frames = len(input_dict[fields.InputDataFields.image])
+ ret = []
+ for i in range(num_frames):
+ original_image = tf.expand_dims(input_dict[fields.InputDataFields.image][i],
+ 0)
+ groundtruth = None
+ if not ignore_groundtruth:
+ groundtruth = {
+ fields.InputDataFields.groundtruth_boxes:
+ input_dict[fields.InputDataFields.groundtruth_boxes][i],
+ fields.InputDataFields.groundtruth_classes:
+ input_dict[fields.InputDataFields.groundtruth_classes][i],
+ }
+ optional_keys = (
+ fields.InputDataFields.groundtruth_area,
+ fields.InputDataFields.groundtruth_is_crowd,
+ fields.InputDataFields.groundtruth_difficult,
+ fields.InputDataFields.groundtruth_group_of,
+ )
+ for opt_key in optional_keys:
+ if opt_key in input_dict:
+ groundtruth[opt_key] = input_dict[opt_key][i]
+ if fields.DetectionResultFields.detection_masks in detections:
+ groundtruth[fields.InputDataFields.groundtruth_instance_masks] = (
+ input_dict[fields.InputDataFields.groundtruth_instance_masks][i])
+
+ detections_frame = {
+ key: tf.expand_dims(value[i], 0)
+ for key, value in detections.items()
+ }
+
+ source_id = (
+ batch.key[0] if batch is not None else
+ input_dict[fields.InputDataFields.source_id][i])
+ ret.append(
+ eval_util.result_dict_for_single_example(
+ original_image,
+ source_id,
+ detections_frame,
+ groundtruth,
+ class_agnostic=(fields.DetectionResultFields.detection_classes
+ not in detections),
+ scale_to_absolute=True))
+ return ret
+
+
+def get_evaluators(eval_config, categories):
+ """Returns the evaluator class according to eval_config, valid for categories.
+
+ Args:
+ eval_config: evaluation configurations.
+ categories: a list of categories to evaluate.
+ Returns:
+ A list of instances of DetectionEvaluator.
+
+ Raises:
+ ValueError: if metric is not in the metric class dictionary.
+ """
+ eval_metric_fn_keys = eval_config.metrics_set
+ if not eval_metric_fn_keys:
+ eval_metric_fn_keys = [EVAL_DEFAULT_METRIC]
+ evaluators_list = []
+ for eval_metric_fn_key in eval_metric_fn_keys:
+ if eval_metric_fn_key not in EVAL_METRICS_CLASS_DICT:
+ raise ValueError('Metric not found: {}'.format(eval_metric_fn_key))
+ else:
+ evaluators_list.append(
+ EVAL_METRICS_CLASS_DICT[eval_metric_fn_key](categories=categories))
+ return evaluators_list
+
+
+def evaluate(create_input_dict_fn,
+ create_model_fn,
+ eval_config,
+ categories,
+ checkpoint_dir,
+ eval_dir,
+ graph_hook_fn=None):
+ """Evaluation function for detection models.
+
+ Args:
+ create_input_dict_fn: a function to create a tensor input dictionary.
+ create_model_fn: a function that creates a DetectionModel.
+ eval_config: an eval_pb2.EvalConfig protobuf.
+ categories: a list of category dictionaries. Each dict in the list should
+ have an integer 'id' field and string 'name' field.
+ checkpoint_dir: directory to load the checkpoints to evaluate from.
+ eval_dir: directory to write evaluation metrics summary to.
+ graph_hook_fn: Optional function that is called after the training graph is
+ completely built. This is helpful to perform additional changes to the
+ training graph such as optimizing batchnorm. The function should modify
+ the default graph.
+
+ Returns:
+ metrics: A dictionary containing metric names and values from the latest
+ run.
+ """
+
+ model = create_model_fn()
+
+ if eval_config.ignore_groundtruth and not eval_config.export_path:
+ tf.logging.fatal('If ignore_groundtruth=True then an export_path is '
+ 'required. Aborting!!!')
+
+ tensor_dicts = _extract_prediction_tensors(
+ model=model,
+ create_input_dict_fn=create_input_dict_fn,
+ ignore_groundtruth=eval_config.ignore_groundtruth)
+
+ def _process_batch(tensor_dicts,
+ sess,
+ batch_index,
+ counters,
+ losses_dict=None):
+ """Evaluates tensors in tensor_dicts, visualizing the first K examples.
+
+ This function calls sess.run on tensor_dicts, evaluating the original_image
+ tensor only on the first K examples and visualizing detections overlaid
+ on this original_image.
+
+ Args:
+ tensor_dicts: a dictionary of tensors
+ sess: tensorflow session
+ batch_index: the index of the batch amongst all batches in the run.
+ counters: a dictionary holding 'success' and 'skipped' fields which can
+ be updated to keep track of number of successful and failed runs,
+ respectively. If these fields are not updated, then the success/skipped
+ counter values shown at the end of evaluation will be incorrect.
+ losses_dict: Optional dictionary of scalar loss tensors. Necessary only
+ for matching the function signature in third_party eval_util.py.
+
+ Returns:
+ result_dict: a dictionary of numpy arrays
+ result_losses_dict: a dictionary of scalar losses. This is empty if input
+ losses_dict is None. Necessary only for matching the function signature in
+ third_party eval_util.py.
+ """
+ if batch_index % 10 == 0:
+ tf.logging.info('Running eval ops batch %d', batch_index)
+ if not losses_dict:
+ losses_dict = {}
+ try:
+ result_dicts, result_losses_dict = sess.run([tensor_dicts, losses_dict])
+ counters['success'] += 1
+ except tf.errors.InvalidArgumentError:
+ tf.logging.info('Skipping image')
+ counters['skipped'] += 1
+ return {}
+ num_images = len(tensor_dicts)
+ for i in range(num_images):
+ result_dict = result_dicts[i]
+ global_step = tf.train.global_step(sess, tf.train.get_global_step())
+ tag = 'image-%d' % (batch_index * num_images + i)
+ if batch_index < eval_config.num_visualizations / num_images:
+ eval_util.visualize_detection_results(
+ result_dict,
+ tag,
+ global_step,
+ categories=categories,
+ summary_dir=eval_dir,
+ export_dir=eval_config.visualization_export_dir,
+ show_groundtruth=eval_config.visualize_groundtruth_boxes,
+ groundtruth_box_visualization_color=eval_config.
+ groundtruth_box_visualization_color,
+ min_score_thresh=eval_config.min_score_threshold,
+ max_num_predictions=eval_config.max_num_boxes_to_visualize,
+ skip_scores=eval_config.skip_scores,
+ skip_labels=eval_config.skip_labels,
+ keep_image_id_for_visualization_export=eval_config.
+ keep_image_id_for_visualization_export)
+ if num_images > 1:
+ return result_dicts, result_losses_dict
+ else:
+ return result_dicts[0], result_losses_dict
+
+ variables_to_restore = tf.global_variables()
+ global_step = tf.train.get_or_create_global_step()
+ variables_to_restore.append(global_step)
+
+ if graph_hook_fn:
+ graph_hook_fn()
+
+ if eval_config.use_moving_averages:
+ variable_averages = tf.train.ExponentialMovingAverage(0.0)
+ variables_to_restore = variable_averages.variables_to_restore()
+ for key in list(variables_to_restore.keys()):
+ if 'moving_mean' in key:
+ variables_to_restore[key.replace(
+ 'moving_mean', 'moving_mean/ExponentialMovingAverage')] = (
+ variables_to_restore[key])
+ del variables_to_restore[key]
+ if 'moving_variance' in key:
+ variables_to_restore[key.replace(
+ 'moving_variance', 'moving_variance/ExponentialMovingAverage')] = (
+ variables_to_restore[key])
+ del variables_to_restore[key]
+
+ saver = tf.train.Saver(variables_to_restore)
+
+ def _restore_latest_checkpoint(sess):
+ latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
+ saver.restore(sess, latest_checkpoint)
+
+ metrics = eval_util.repeated_checkpoint_run(
+ tensor_dict=tensor_dicts,
+ summary_dir=eval_dir,
+ evaluators=get_evaluators(eval_config, categories),
+ batch_processor=_process_batch,
+ checkpoint_dirs=[checkpoint_dir],
+ variables_to_restore=None,
+ restore_fn=_restore_latest_checkpoint,
+ num_batches=eval_config.num_examples,
+ eval_interval_secs=eval_config.eval_interval_secs,
+ max_number_of_evaluations=(1 if eval_config.ignore_groundtruth else
+ eval_config.max_evals
+ if eval_config.max_evals else None),
+ master=eval_config.eval_master,
+ save_graph=eval_config.save_graph,
+ save_graph_dir=(eval_dir if eval_config.save_graph else ''))
+
+ return metrics
diff --git a/models/research/lstm_object_detection/export_tflite_lstd_graph.py b/models/research/lstm_object_detection/export_tflite_lstd_graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e933fb480d04aefa66ec6c4c8ec38f91dee9cb6
--- /dev/null
+++ b/models/research/lstm_object_detection/export_tflite_lstd_graph.py
@@ -0,0 +1,138 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Exports an LSTM detection model to use with tf-lite.
+
+Outputs file:
+* A tflite compatible frozen graph - $output_directory/tflite_graph.pb
+
+The exported graph has the following input and output nodes.
+
+Inputs:
+'input_video_tensor': a float32 tensor of shape
+[unroll_length, height, width, 3] containing the normalized input image.
+Note that the height and width must be compatible with the height and
+width configured in the fixed_shape_image resizer options in the pipeline
+config proto.
+
+Outputs:
+If add_postprocessing_op is true, the frozen graph adds a
+ TFLite_Detection_PostProcess custom op node with four outputs:
+ detection_boxes: a float32 tensor of shape [1, num_boxes, 4] with box
+ locations
+ detection_classes: a float32 tensor of shape [1, num_boxes]
+ with class indices
+ detection_scores: a float32 tensor of shape [1, num_boxes]
+ with class scores
+ num_boxes: a float32 tensor of size 1 containing the number of detected boxes
+else:
+ the graph has three outputs:
+ 'raw_outputs/box_encodings': a float32 tensor of shape [1, num_anchors, 4]
+ containing the encoded box predictions.
+ 'raw_outputs/class_predictions': a float32 tensor of shape
+ [1, num_anchors, num_classes] containing the class scores for each anchor
+ after applying score conversion.
+ 'anchors': a float32 constant tensor of shape [num_anchors, 4]
+ containing the anchor boxes.
+
+Example Usage:
+--------------
+python lstm_object_detection/export_tflite_lstd_graph.py \
+ --pipeline_config_path path/to/lstm_pipeline.config \
+ --trained_checkpoint_prefix path/to/model.ckpt \
+ --output_directory path/to/exported_model_directory
+
+The expected output would be in the directory
+path/to/exported_model_directory (which is created if it does not exist)
+with contents:
+ - tflite_graph.pbtxt
+ - tflite_graph.pb
+Config overrides (see the `config_override` flag) are text protobufs
+(also of type pipeline_pb2.TrainEvalPipelineConfig) which are used to override
+certain fields in the provided pipeline_config_path. These are useful for
+making small changes to the inference graph that differ from the training or
+eval config.
+
+Example Usage (in which we change the NMS iou_threshold to be 0.5 and
+NMS score_threshold to be 0.0):
+python lstm_object_detection/export_tflite_lstd_graph.py \
+ --pipeline_config_path path/to/lstm_pipeline.config \
+ --trained_checkpoint_prefix path/to/model.ckpt \
+ --output_directory path/to/exported_model_directory
+ --config_override " \
+ model{ \
+ ssd{ \
+ post_processing { \
+ batch_non_max_suppression { \
+ score_threshold: 0.0 \
+ iou_threshold: 0.5 \
+ } \
+ } \
+ } \
+ } \
+ "
+"""
+
+import tensorflow.compat.v1 as tf
+
+from lstm_object_detection import export_tflite_lstd_graph_lib
+from lstm_object_detection.utils import config_util
+
+flags = tf.app.flags
+flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
+flags.DEFINE_string(
+ 'pipeline_config_path', None,
+ 'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
+ 'file.')
+flags.DEFINE_string('trained_checkpoint_prefix', None, 'Checkpoint prefix.')
+flags.DEFINE_integer('max_detections', 10,
+ 'Maximum number of detections (boxes) to show.')
+flags.DEFINE_integer('max_classes_per_detection', 1,
+ 'Maximum number of classes to output per detection box.')
+flags.DEFINE_integer(
+ 'detections_per_class', 100,
+ 'Number of anchors used per class in Regular Non-Max-Suppression.')
+flags.DEFINE_bool('add_postprocessing_op', True,
+ 'Add TFLite custom op for postprocessing to the graph.')
+flags.DEFINE_bool(
+ 'use_regular_nms', False,
+ 'Flag to set postprocessing op to use Regular NMS instead of Fast NMS.')
+flags.DEFINE_string(
+ 'config_override', '', 'pipeline_pb2.TrainEvalPipelineConfig '
+ 'text proto to override pipeline_config_path.')
+
+FLAGS = flags.FLAGS
+
+
+def main(argv):
+ del argv # Unused.
+ flags.mark_flag_as_required('output_directory')
+ flags.mark_flag_as_required('pipeline_config_path')
+ flags.mark_flag_as_required('trained_checkpoint_prefix')
+
+ pipeline_config = config_util.get_configs_from_pipeline_file(
+ FLAGS.pipeline_config_path)
+
+ export_tflite_lstd_graph_lib.export_tflite_graph(
+ pipeline_config,
+ FLAGS.trained_checkpoint_prefix,
+ FLAGS.output_directory,
+ FLAGS.add_postprocessing_op,
+ FLAGS.max_detections,
+ FLAGS.max_classes_per_detection,
+ use_regular_nms=FLAGS.use_regular_nms)
+
+
+if __name__ == '__main__':
+ tf.app.run(main)
diff --git a/models/research/lstm_object_detection/export_tflite_lstd_graph_lib.py b/models/research/lstm_object_detection/export_tflite_lstd_graph_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..e066f11b45f2bd4608b08656040abba2632b4aa2
--- /dev/null
+++ b/models/research/lstm_object_detection/export_tflite_lstd_graph_lib.py
@@ -0,0 +1,327 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Exports detection models to use with tf-lite.
+
+See export_tflite_lstd_graph.py for usage.
+"""
+import os
+import tempfile
+
+import numpy as np
+import tensorflow.compat.v1 as tf
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.core.framework import types_pb2
+from tensorflow.core.protobuf import saver_pb2
+from tensorflow.tools.graph_transforms import TransformGraph
+from lstm_object_detection import model_builder
+from object_detection import exporter
+from object_detection.builders import graph_rewriter_builder
+from object_detection.builders import post_processing_builder
+from object_detection.core import box_list
+
+_DEFAULT_NUM_CHANNELS = 3
+_DEFAULT_NUM_COORD_BOX = 4
+
+
+def get_const_center_size_encoded_anchors(anchors):
+ """Exports center-size encoded anchors as a constant tensor.
+
+ Args:
+ anchors: a float32 tensor of shape [num_anchors, 4] containing the anchor
+ boxes
+
+ Returns:
+ encoded_anchors: a float32 constant tensor of shape [num_anchors, 4]
+ containing the anchor boxes.
+ """
+ anchor_boxlist = box_list.BoxList(anchors)
+ y, x, h, w = anchor_boxlist.get_center_coordinates_and_sizes()
+ num_anchors = y.get_shape().as_list()
+
+ with tf.Session() as sess:
+ y_out, x_out, h_out, w_out = sess.run([y, x, h, w])
+ encoded_anchors = tf.constant(
+ np.transpose(np.stack((y_out, x_out, h_out, w_out))),
+ dtype=tf.float32,
+ shape=[num_anchors[0], _DEFAULT_NUM_COORD_BOX],
+ name='anchors')
+ return encoded_anchors
+
+
+def append_postprocessing_op(frozen_graph_def,
+ max_detections,
+ max_classes_per_detection,
+ nms_score_threshold,
+ nms_iou_threshold,
+ num_classes,
+ scale_values,
+ detections_per_class=100,
+ use_regular_nms=False):
+ """Appends postprocessing custom op.
+
+ Args:
+ frozen_graph_def: Frozen GraphDef for SSD model after freezing the
+ checkpoint
+ max_detections: Maximum number of detections (boxes) to show
+ max_classes_per_detection: Number of classes to display per detection
+ nms_score_threshold: Score threshold used in Non-maximal suppression in
+ post-processing
+ nms_iou_threshold: Intersection-over-union threshold used in Non-maximal
+ suppression in post-processing
+ num_classes: number of classes in SSD detector
+ scale_values: a dict with the following key-value pairs
+ {y_scale: 10, x_scale: 10, h_scale: 5, w_scale: 5} that are used to decode
+ center-size boxes
+ detections_per_class: In regular NonMaxSuppression, number of anchors used
+ for NonMaxSuppression per class
+ use_regular_nms: Flag to set postprocessing op to use Regular NMS instead of
+ Fast NMS.
+
+ Returns:
+ transformed_graph_def: Frozen GraphDef with postprocessing custom op
+ appended. The TFLite_Detection_PostProcess custom op node has four
+ outputs:
+ detection_boxes: a float32 tensor of shape [1, num_boxes, 4] with box
+ locations
+ detection_classes: a float32 tensor of shape [1, num_boxes]
+ with class indices
+ detection_scores: a float32 tensor of shape [1, num_boxes]
+ with class scores
+ num_boxes: a float32 tensor of size 1 containing the number of detected
+ boxes
+ """
+ new_output = frozen_graph_def.node.add()
+ new_output.op = 'TFLite_Detection_PostProcess'
+ new_output.name = 'TFLite_Detection_PostProcess'
+ new_output.attr['_output_quantized'].CopyFrom(
+ attr_value_pb2.AttrValue(b=True))
+ new_output.attr['_output_types'].list.type.extend([
+ types_pb2.DT_FLOAT, types_pb2.DT_FLOAT, types_pb2.DT_FLOAT,
+ types_pb2.DT_FLOAT
+ ])
+ new_output.attr['_support_output_type_float_in_quantized_op'].CopyFrom(
+ attr_value_pb2.AttrValue(b=True))
+ new_output.attr['max_detections'].CopyFrom(
+ attr_value_pb2.AttrValue(i=max_detections))
+ new_output.attr['max_classes_per_detection'].CopyFrom(
+ attr_value_pb2.AttrValue(i=max_classes_per_detection))
+ new_output.attr['nms_score_threshold'].CopyFrom(
+ attr_value_pb2.AttrValue(f=nms_score_threshold.pop()))
+ new_output.attr['nms_iou_threshold'].CopyFrom(
+ attr_value_pb2.AttrValue(f=nms_iou_threshold.pop()))
+ new_output.attr['num_classes'].CopyFrom(
+ attr_value_pb2.AttrValue(i=num_classes))
+
+ new_output.attr['y_scale'].CopyFrom(
+ attr_value_pb2.AttrValue(f=scale_values['y_scale'].pop()))
+ new_output.attr['x_scale'].CopyFrom(
+ attr_value_pb2.AttrValue(f=scale_values['x_scale'].pop()))
+ new_output.attr['h_scale'].CopyFrom(
+ attr_value_pb2.AttrValue(f=scale_values['h_scale'].pop()))
+ new_output.attr['w_scale'].CopyFrom(
+ attr_value_pb2.AttrValue(f=scale_values['w_scale'].pop()))
+ new_output.attr['detections_per_class'].CopyFrom(
+ attr_value_pb2.AttrValue(i=detections_per_class))
+ new_output.attr['use_regular_nms'].CopyFrom(
+ attr_value_pb2.AttrValue(b=use_regular_nms))
+
+ new_output.input.extend(
+ ['raw_outputs/box_encodings', 'raw_outputs/class_predictions', 'anchors'])
+ # Transform the graph to append new postprocessing op
+ input_names = []
+ output_names = ['TFLite_Detection_PostProcess']
+ transforms = ['strip_unused_nodes']
+ transformed_graph_def = TransformGraph(frozen_graph_def, input_names,
+ output_names, transforms)
+ return transformed_graph_def
+
+
+def export_tflite_graph(pipeline_config,
+ trained_checkpoint_prefix,
+ output_dir,
+ add_postprocessing_op,
+ max_detections,
+ max_classes_per_detection,
+ detections_per_class=100,
+ use_regular_nms=False,
+ binary_graph_name='tflite_graph.pb',
+ txt_graph_name='tflite_graph.pbtxt'):
+ """Exports a tflite compatible graph and anchors for ssd detection model.
+
+ Anchors are written to a constant tensor and a tflite-compatible graph
+ is written to output_dir/tflite_graph.pb.
+
+ Args:
+ pipeline_config: Dictionary of configuration objects. Keys are `model`,
+ `train_config`, `train_input_config`, `eval_config`, `eval_input_config`,
+ `lstm_model`. Value are the corresponding config objects.
+ trained_checkpoint_prefix: a file prefix for the checkpoint containing the
+ trained parameters of the SSD model.
+ output_dir: A directory to write the tflite graph and anchor file to.
+ add_postprocessing_op: If true, the frozen graph adds a
+ TFLite_Detection_PostProcess custom op node.
+ max_detections: Maximum number of detections (boxes) to show
+ max_classes_per_detection: Number of classes to display per detection
+ detections_per_class: In regular NonMaxSuppression, number of anchors used
+ for NonMaxSuppression per class
+ use_regular_nms: Flag to set postprocessing op to use Regular NMS instead of
+ Fast NMS.
+ binary_graph_name: Name of the exported graph file in binary format.
+ txt_graph_name: Name of the exported graph file in text format.
+
+ Raises:
+ ValueError: if the pipeline config contains a model other than ssd or an
+ image resizer other than fixed_shape_resizer.
+ """
+ model_config = pipeline_config['model']
+ lstm_config = pipeline_config['lstm_model']
+ eval_config = pipeline_config['eval_config']
+ tf.gfile.MakeDirs(output_dir)
+ if model_config.WhichOneof('model') != 'ssd':
+ raise ValueError('Only ssd models are supported in tflite. '
+ 'Found {} in config'.format(
+ model_config.WhichOneof('model')))
+
+ num_classes = model_config.ssd.num_classes
+ nms_score_threshold = {
+ model_config.ssd.post_processing.batch_non_max_suppression.score_threshold
+ }
+ nms_iou_threshold = {
+ model_config.ssd.post_processing.batch_non_max_suppression.iou_threshold
+ }
+ scale_values = {}
+ scale_values['y_scale'] = {
+ model_config.ssd.box_coder.faster_rcnn_box_coder.y_scale
+ }
+ scale_values['x_scale'] = {
+ model_config.ssd.box_coder.faster_rcnn_box_coder.x_scale
+ }
+ scale_values['h_scale'] = {
+ model_config.ssd.box_coder.faster_rcnn_box_coder.height_scale
+ }
+ scale_values['w_scale'] = {
+ model_config.ssd.box_coder.faster_rcnn_box_coder.width_scale
+ }
+
+ image_resizer_config = model_config.ssd.image_resizer
+ image_resizer = image_resizer_config.WhichOneof('image_resizer_oneof')
+ num_channels = _DEFAULT_NUM_CHANNELS
+ if image_resizer == 'fixed_shape_resizer':
+ height = image_resizer_config.fixed_shape_resizer.height
+ width = image_resizer_config.fixed_shape_resizer.width
+ if image_resizer_config.fixed_shape_resizer.convert_to_grayscale:
+ num_channels = 1
+
+ shape = [lstm_config.eval_unroll_length, height, width, num_channels]
+ else:
+ raise ValueError(
+ 'Only fixed_shape_resizer '
+ 'is supported with tflite. Found {}'.format(
+ image_resizer_config.WhichOneof('image_resizer_oneof')))
+
+ video_tensor = tf.placeholder(
+ tf.float32, shape=shape, name='input_video_tensor')
+
+ detection_model = model_builder.build(
+ model_config, lstm_config, is_training=False)
+ preprocessed_video, true_image_shapes = detection_model.preprocess(
+ tf.to_float(video_tensor))
+ predicted_tensors = detection_model.predict(preprocessed_video,
+ true_image_shapes)
+ # predicted_tensors = detection_model.postprocess(predicted_tensors,
+ # true_image_shapes)
+ # The score conversion occurs before the post-processing custom op
+ _, score_conversion_fn = post_processing_builder.build(
+ model_config.ssd.post_processing)
+ class_predictions = score_conversion_fn(
+ predicted_tensors['class_predictions_with_background'])
+
+ with tf.name_scope('raw_outputs'):
+ # 'raw_outputs/box_encodings': a float32 tensor of shape [1, num_anchors, 4]
+ # containing the encoded box predictions. Note that these are raw
+ # predictions and no Non-Max suppression is applied on them and
+ # no decode center size boxes is applied to them.
+ tf.identity(predicted_tensors['box_encodings'], name='box_encodings')
+ # 'raw_outputs/class_predictions': a float32 tensor of shape
+ # [1, num_anchors, num_classes] containing the class scores for each anchor
+ # after applying score conversion.
+ tf.identity(class_predictions, name='class_predictions')
+ # 'anchors': a float32 tensor of shape
+ # [num_anchors, 4] containing the anchors as a constant node.
+ tf.identity(
+ get_const_center_size_encoded_anchors(predicted_tensors['anchors']),
+ name='anchors')
+
+ # Add global step to the graph, so we know the training step number when we
+ # evaluate the model.
+ tf.train.get_or_create_global_step()
+
+ # graph rewriter
+ is_quantized = ('graph_rewriter' in pipeline_config)
+ if is_quantized:
+ graph_rewriter_config = pipeline_config['graph_rewriter']
+ graph_rewriter_fn = graph_rewriter_builder.build(
+ graph_rewriter_config, is_training=False, is_export=True)
+ graph_rewriter_fn()
+
+ if model_config.ssd.feature_extractor.HasField('fpn'):
+ exporter.rewrite_nn_resize_op(is_quantized)
+
+ # freeze the graph
+ saver_kwargs = {}
+ if eval_config.use_moving_averages:
+ saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
+ moving_average_checkpoint = tempfile.NamedTemporaryFile()
+ exporter.replace_variable_values_with_moving_averages(
+ tf.get_default_graph(), trained_checkpoint_prefix,
+ moving_average_checkpoint.name)
+ checkpoint_to_use = moving_average_checkpoint.name
+ else:
+ checkpoint_to_use = trained_checkpoint_prefix
+
+ saver = tf.train.Saver(**saver_kwargs)
+ input_saver_def = saver.as_saver_def()
+ frozen_graph_def = exporter.freeze_graph_with_def_protos(
+ input_graph_def=tf.get_default_graph().as_graph_def(),
+ input_saver_def=input_saver_def,
+ input_checkpoint=checkpoint_to_use,
+ output_node_names=','.join([
+ 'raw_outputs/box_encodings', 'raw_outputs/class_predictions',
+ 'anchors'
+ ]),
+ restore_op_name='save/restore_all',
+ filename_tensor_name='save/Const:0',
+ clear_devices=True,
+ output_graph='',
+ initializer_nodes='')
+
+ # Add new operation to do post processing in a custom op (TF Lite only)
+
+ if add_postprocessing_op:
+ transformed_graph_def = append_postprocessing_op(
+ frozen_graph_def, max_detections, max_classes_per_detection,
+ nms_score_threshold, nms_iou_threshold, num_classes, scale_values,
+ detections_per_class, use_regular_nms)
+ else:
+ # Return the frozen graph without adding the post-processing custom op.
+ transformed_graph_def = frozen_graph_def
+
+ binary_graph = os.path.join(output_dir, binary_graph_name)
+ with tf.gfile.GFile(binary_graph, 'wb') as f:
+ f.write(transformed_graph_def.SerializeToString())
+ txt_graph = os.path.join(output_dir, txt_graph_name)
+ with tf.gfile.GFile(txt_graph, 'w') as f:
+ f.write(str(transformed_graph_def))
diff --git a/models/research/lstm_object_detection/export_tflite_lstd_model.py b/models/research/lstm_object_detection/export_tflite_lstd_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..58c674728b5b0e274ae112d66abe3ff72f63b86e
--- /dev/null
+++ b/models/research/lstm_object_detection/export_tflite_lstd_model.py
@@ -0,0 +1,65 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Export a LSTD model in tflite format."""
+
+import os
+from absl import flags
+import tensorflow.compat.v1 as tf
+
+from lstm_object_detection.utils import config_util
+
+flags.DEFINE_string('export_path', None, 'Path to export model.')
+flags.DEFINE_string('frozen_graph_path', None, 'Path to frozen graph.')
+flags.DEFINE_string(
+ 'pipeline_config_path', '',
+ 'Path to a pipeline_pb2.TrainEvalPipelineConfig config file.')
+
+FLAGS = flags.FLAGS
+
+
+def main(_):
+ flags.mark_flag_as_required('export_path')
+ flags.mark_flag_as_required('frozen_graph_path')
+ flags.mark_flag_as_required('pipeline_config_path')
+
+ configs = config_util.get_configs_from_pipeline_file(
+ FLAGS.pipeline_config_path)
+ lstm_config = configs['lstm_model']
+
+ input_arrays = ['input_video_tensor']
+ output_arrays = [
+ 'TFLite_Detection_PostProcess',
+ 'TFLite_Detection_PostProcess:1',
+ 'TFLite_Detection_PostProcess:2',
+ 'TFLite_Detection_PostProcess:3',
+ ]
+ input_shapes = {
+ 'input_video_tensor': [lstm_config.eval_unroll_length, 320, 320, 3],
+ }
+
+ converter = tf.lite.TFLiteConverter.from_frozen_graph(
+ FLAGS.frozen_graph_path,
+ input_arrays,
+ output_arrays,
+ input_shapes=input_shapes)
+ converter.allow_custom_ops = True
+ tflite_model = converter.convert()
+ ofilename = os.path.join(FLAGS.export_path)
+ with open(ofilename, 'wb') as f:
+   f.write(tflite_model)
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/models/research/lstm_object_detection/g3doc/Interleaved_Intro.png b/models/research/lstm_object_detection/g3doc/Interleaved_Intro.png
new file mode 100644
index 0000000000000000000000000000000000000000..2b829c997bc75e807c0982b1d71334966452b122
Binary files /dev/null and b/models/research/lstm_object_detection/g3doc/Interleaved_Intro.png differ
diff --git a/models/research/lstm_object_detection/g3doc/exporting_models.md b/models/research/lstm_object_detection/g3doc/exporting_models.md
new file mode 100644
index 0000000000000000000000000000000000000000..7d501d97efdfb8d259e867164aa04f275b56a036
--- /dev/null
+++ b/models/research/lstm_object_detection/g3doc/exporting_models.md
@@ -0,0 +1,49 @@
+# Exporting a tflite model from a checkpoint
+
+Starting from a trained model checkpoint, creating a tflite model requires 2
+steps:
+
+* exporting a tflite frozen graph from a checkpoint
+* exporting a tflite model from a frozen graph
+
+## Exporting a tflite frozen graph from a checkpoint
+
+With a candidate checkpoint to export, run the following command from
+tensorflow/models/research:
+
+```bash
+# from tensorflow/models/research
+PIPELINE_CONFIG_PATH={path to pipeline config}
+TRAINED_CKPT_PREFIX=/{path to model.ckpt}
+EXPORT_DIR={path to folder that will be used for export}
+python lstm_object_detection/export_tflite_lstd_graph.py \
+ --pipeline_config_path ${PIPELINE_CONFIG_PATH} \
+ --trained_checkpoint_prefix ${TRAINED_CKPT_PREFIX} \
+ --output_directory ${EXPORT_DIR} \
+ --add_postprocessing_op
+```
+
+After export, you should see the directory ${EXPORT_DIR} containing the
+following files:
+
+* `tflite_graph.pb`
+* `tflite_graph.pbtxt`
+
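+To sanity-check the frozen graph before converting it, you can list its nodes
+with a few lines of Python. This is only an illustrative sketch; it assumes the
+default raw output node names written by `export_tflite_lstd_graph.py` and that
+`--add_postprocessing_op` was left at its default value of true.
+
+```python
+# Sketch: confirm the expected output nodes exist in tflite_graph.pb.
+import tensorflow.compat.v1 as tf
+
+graph_def = tf.GraphDef()
+with tf.gfile.GFile('path/to/export_dir/tflite_graph.pb', 'rb') as f:
+  graph_def.ParseFromString(f.read())
+
+node_names = {node.name for node in graph_def.node}
+for name in ('raw_outputs/box_encodings', 'raw_outputs/class_predictions',
+             'anchors', 'TFLite_Detection_PostProcess'):
+  print(name, name in node_names)
+```
+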
+## Exporting a tflite model from a frozen graph
+
+We then take the exported tflite-compatible frozen graph and convert it to a
+TFLite FlatBuffer file by running the following:
+
+```bash
+# from tensorflow/models/research
+FROZEN_GRAPH_PATH={path to exported tflite_graph.pb}
+EXPORT_PATH={path to filename that will be used for export}
+PIPELINE_CONFIG_PATH={path to pipeline config}
+python lstm_object_detection/export_tflite_lstd_model.py \
+ --export_path ${EXPORT_PATH} \
+ --frozen_graph_path ${FROZEN_GRAPH_PATH} \
+ --pipeline_config_path ${PIPELINE_CONFIG_PATH}
+```
+
+After export, you should see the file ${EXPORT_PATH} containing the FlatBuffer
+model to be used by an application.
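+
+As a quick smoke test, the FlatBuffer can be loaded with the TFLite Python
+interpreter and run on random data. The sketch below is illustrative only: it
+assumes the default input name `input_video_tensor`, the input shape
+`[unroll_length, 320, 320, 3]` with the `unroll_length` of 4 used in the sample
+configs, and a TFLite runtime that includes the `TFLite_Detection_PostProcess`
+custom op.
+
+```python
+# Sketch: run one inference on random input with the exported .tflite model.
+import numpy as np
+import tensorflow as tf
+
+interpreter = tf.lite.Interpreter(model_path='path/to/exported_model.tflite')
+interpreter.allocate_tensors()
+input_details = interpreter.get_input_details()
+output_details = interpreter.get_output_details()
+
+video = np.random.rand(4, 320, 320, 3).astype(np.float32)
+interpreter.set_tensor(input_details[0]['index'], video)
+interpreter.invoke()
+
+# Output order is typically boxes, classes, scores, num_detections.
+for detail in output_details:
+  print(detail['name'], interpreter.get_tensor(detail['index']).shape)
+```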
diff --git a/models/research/lstm_object_detection/g3doc/lstm_ssd_intro.png b/models/research/lstm_object_detection/g3doc/lstm_ssd_intro.png
new file mode 100644
index 0000000000000000000000000000000000000000..fa62eb533b9190bcf05094d12781808dc85f1107
Binary files /dev/null and b/models/research/lstm_object_detection/g3doc/lstm_ssd_intro.png differ
diff --git a/models/research/lstm_object_detection/inputs/__init__.py b/models/research/lstm_object_detection/inputs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/lstm_object_detection/inputs/seq_dataset_builder.py b/models/research/lstm_object_detection/inputs/seq_dataset_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..55e24820f60d24d14db64f2aea21e462ee278ff2
--- /dev/null
+++ b/models/research/lstm_object_detection/inputs/seq_dataset_builder.py
@@ -0,0 +1,242 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+r"""tf.data.Dataset builder.
+
+Creates data sources for DetectionModels from an InputReader config. See
+input_reader.proto for options.
+
+Note: If users wish to also use their own InputReaders with the Object
+Detection configuration framework, they should define their own builder function
+that wraps the build function.
+"""
+import tensorflow.compat.v1 as tf
+import tf_slim as slim
+
+from tensorflow.contrib.training.python.training import sequence_queueing_state_saver as sqss
+from lstm_object_detection.inputs import tf_sequence_example_decoder
+from lstm_object_detection.protos import input_reader_google_pb2
+from object_detection.core import preprocessor
+from object_detection.core import preprocessor_cache
+from object_detection.core import standard_fields as fields
+from object_detection.protos import input_reader_pb2
+from object_detection.utils import ops as util_ops
+
+parallel_reader = slim.parallel_reader
+# TODO(yinxiao): Make the following variable into a configurable proto.
+# Padding size for the labeled objects in each frame. Here we assume each
+# frame has a total number of objects less than _PADDING_SIZE.
+_PADDING_SIZE = 30
+
+
+def _build_training_batch_dict(batch_sequences_with_states, unroll_length,
+ batch_size):
+ """Builds training batch samples.
+
+ Args:
+ batch_sequences_with_states: A batch_sequences_with_states object.
+ unroll_length: Unrolled length for LSTM training.
+ batch_size: Batch size for queue outputs.
+
+ Returns:
+ A dictionary of tensors based on items in input_reader_config.
+ """
+ seq_tensors_dict = {
+ fields.InputDataFields.image: [],
+ fields.InputDataFields.groundtruth_boxes: [],
+ fields.InputDataFields.groundtruth_classes: [],
+ 'batch': batch_sequences_with_states,
+ }
+ for i in range(unroll_length):
+ for j in range(batch_size):
+ filtered_dict = util_ops.filter_groundtruth_with_nan_box_coordinates({
+ fields.InputDataFields.groundtruth_boxes: (
+ batch_sequences_with_states.sequences['groundtruth_boxes'][j][i]),
+ fields.InputDataFields.groundtruth_classes: (
+ batch_sequences_with_states.sequences['groundtruth_classes'][j][i]
+ ),
+ })
+ filtered_dict = util_ops.retain_groundtruth_with_positive_classes(
+ filtered_dict)
+ seq_tensors_dict[fields.InputDataFields.image].append(
+ batch_sequences_with_states.sequences['image'][j][i])
+ seq_tensors_dict[fields.InputDataFields.groundtruth_boxes].append(
+ filtered_dict[fields.InputDataFields.groundtruth_boxes])
+ seq_tensors_dict[fields.InputDataFields.groundtruth_classes].append(
+ filtered_dict[fields.InputDataFields.groundtruth_classes])
+ seq_tensors_dict[fields.InputDataFields.image] = tuple(
+ seq_tensors_dict[fields.InputDataFields.image])
+ seq_tensors_dict[fields.InputDataFields.groundtruth_boxes] = tuple(
+ seq_tensors_dict[fields.InputDataFields.groundtruth_boxes])
+ seq_tensors_dict[fields.InputDataFields.groundtruth_classes] = tuple(
+ seq_tensors_dict[fields.InputDataFields.groundtruth_classes])
+
+ return seq_tensors_dict
+
+
+def build(input_reader_config,
+ model_config,
+ lstm_config,
+ unroll_length,
+ data_augmentation_options=None,
+ batch_size=1):
+ """Builds a tensor dictionary based on the InputReader config.
+
+ Args:
+ input_reader_config: An input_reader_builder.InputReader object.
+ model_config: A model.proto object containing the config for the desired
+ DetectionModel.
+ lstm_config: LSTM specific configs.
+ unroll_length: Unrolled length for LSTM training.
+ data_augmentation_options: A list of tuples, where each tuple contains a
+ data augmentation function and a dictionary containing arguments and their
+ values (see preprocessor.py).
+ batch_size: Batch size for queue outputs.
+
+ Returns:
+ A dictionary of tensors based on items in the input_reader_config.
+
+ Raises:
+ ValueError: On invalid input reader proto.
+ ValueError: If no input paths are specified.
+ """
+ if not isinstance(input_reader_config, input_reader_pb2.InputReader):
+ raise ValueError('input_reader_config not of type '
+ 'input_reader_pb2.InputReader.')
+
+ external_reader_config = input_reader_config.external_input_reader
+ external_input_reader_config = external_reader_config.Extensions[
+ input_reader_google_pb2.GoogleInputReader.google_input_reader]
+ input_reader_type = external_input_reader_config.WhichOneof('input_reader')
+
+ if input_reader_type == 'tf_record_video_input_reader':
+ config = external_input_reader_config.tf_record_video_input_reader
+ reader_type_class = tf.TFRecordReader
+ else:
+ raise ValueError(
+ 'Unsupported reader in input_reader_config: %s' % input_reader_type)
+
+ if not config.input_path:
+ raise ValueError('At least one input path must be specified in '
+ '`input_reader_config`.')
+ key, value = parallel_reader.parallel_read(
+ config.input_path[:], # Convert `RepeatedScalarContainer` to list.
+ reader_class=reader_type_class,
+ num_epochs=(input_reader_config.num_epochs
+ if input_reader_config.num_epochs else None),
+ num_readers=input_reader_config.num_readers,
+ shuffle=input_reader_config.shuffle,
+ dtypes=[tf.string, tf.string],
+ capacity=input_reader_config.queue_capacity,
+ min_after_dequeue=input_reader_config.min_after_dequeue)
+
+ # TODO(yinxiao): Add loading instance mask option.
+ decoder = tf_sequence_example_decoder.TFSequenceExampleDecoder()
+
+ keys_to_decode = [
+ fields.InputDataFields.image, fields.InputDataFields.groundtruth_boxes,
+ fields.InputDataFields.groundtruth_classes
+ ]
+ tensor_dict = decoder.decode(value, items=keys_to_decode)
+
+ tensor_dict['image'].set_shape([None, None, None, 3])
+ tensor_dict['groundtruth_boxes'].set_shape([None, None, 4])
+
+ height = model_config.ssd.image_resizer.fixed_shape_resizer.height
+ width = model_config.ssd.image_resizer.fixed_shape_resizer.width
+
+ # If data augmentation is specified in the config file, the preprocessor
+ # will be called here to augment the data as specified. Most common
+ # augmentations include horizontal flip and cropping.
+ if data_augmentation_options:
+ images_pre = tf.split(tensor_dict['image'], config.video_length, axis=0)
+ bboxes_pre = tf.split(
+ tensor_dict['groundtruth_boxes'], config.video_length, axis=0)
+ labels_pre = tf.split(
+ tensor_dict['groundtruth_classes'], config.video_length, axis=0)
+ images_proc, bboxes_proc, labels_proc = [], [], []
+ cache = preprocessor_cache.PreprocessorCache()
+
+ for i, _ in enumerate(images_pre):
+ image_dict = {
+ fields.InputDataFields.image:
+ images_pre[i],
+ fields.InputDataFields.groundtruth_boxes:
+ tf.squeeze(bboxes_pre[i], axis=0),
+ fields.InputDataFields.groundtruth_classes:
+ tf.squeeze(labels_pre[i], axis=0),
+ }
+ image_dict = preprocessor.preprocess(
+ image_dict,
+ data_augmentation_options,
+ func_arg_map=preprocessor.get_default_func_arg_map(),
+ preprocess_vars_cache=cache)
+ # Pads detection count to _PADDING_SIZE.
+ image_dict[fields.InputDataFields.groundtruth_boxes] = tf.pad(
+ image_dict[fields.InputDataFields.groundtruth_boxes],
+ [[0, _PADDING_SIZE], [0, 0]])
+ image_dict[fields.InputDataFields.groundtruth_boxes] = tf.slice(
+ image_dict[fields.InputDataFields.groundtruth_boxes], [0, 0],
+ [_PADDING_SIZE, -1])
+ image_dict[fields.InputDataFields.groundtruth_classes] = tf.pad(
+ image_dict[fields.InputDataFields.groundtruth_classes],
+ [[0, _PADDING_SIZE]])
+ image_dict[fields.InputDataFields.groundtruth_classes] = tf.slice(
+ image_dict[fields.InputDataFields.groundtruth_classes], [0],
+ [_PADDING_SIZE])
+ images_proc.append(image_dict[fields.InputDataFields.image])
+ bboxes_proc.append(image_dict[fields.InputDataFields.groundtruth_boxes])
+ labels_proc.append(image_dict[fields.InputDataFields.groundtruth_classes])
+ tensor_dict['image'] = tf.concat(images_proc, axis=0)
+ tensor_dict['groundtruth_boxes'] = tf.stack(bboxes_proc, axis=0)
+ tensor_dict['groundtruth_classes'] = tf.stack(labels_proc, axis=0)
+ else:
+ # Pads detection count to _PADDING_SIZE per frame.
+ tensor_dict['groundtruth_boxes'] = tf.pad(
+ tensor_dict['groundtruth_boxes'], [[0, 0], [0, _PADDING_SIZE], [0, 0]])
+ tensor_dict['groundtruth_boxes'] = tf.slice(
+ tensor_dict['groundtruth_boxes'], [0, 0, 0], [-1, _PADDING_SIZE, -1])
+ tensor_dict['groundtruth_classes'] = tf.pad(
+ tensor_dict['groundtruth_classes'], [[0, 0], [0, _PADDING_SIZE]])
+ tensor_dict['groundtruth_classes'] = tf.slice(
+ tensor_dict['groundtruth_classes'], [0, 0], [-1, _PADDING_SIZE])
+
+ tensor_dict['image'], _ = preprocessor.resize_image(
+ tensor_dict['image'], new_height=height, new_width=width)
+
+  num_steps = config.video_length // unroll_length
+
+ init_states = {
+ 'lstm_state_c':
+          tf.zeros([height // 32, width // 32, lstm_config.lstm_state_depth]),
+ 'lstm_state_h':
+          tf.zeros([height // 32, width // 32, lstm_config.lstm_state_depth]),
+ 'lstm_state_step':
+ tf.constant(num_steps, shape=[]),
+ }
+
+ batch = sqss.batch_sequences_with_states(
+ input_key=key,
+ input_sequences=tensor_dict,
+ input_context={},
+ input_length=None,
+ initial_states=init_states,
+ num_unroll=unroll_length,
+ batch_size=batch_size,
+ num_threads=batch_size,
+ make_keys_unique=True,
+ capacity=batch_size * batch_size)
+
+ return _build_training_batch_dict(batch, unroll_length, batch_size)
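+
+
+# A minimal sketch (illustrative only; `_pad_boxes_to_fixed_size_sketch` is a
+# hypothetical helper not referenced by `build`) of the pad-then-slice pattern
+# used above: padding with _PADDING_SIZE extra rows and slicing back to exactly
+# _PADDING_SIZE rows gives a fixed-size groundtruth tensor no matter how many
+# boxes a frame actually contains.
+def _pad_boxes_to_fixed_size_sketch(boxes):
+  """Pads a [num_boxes, 4] box tensor and slices it to [_PADDING_SIZE, 4]."""
+  padded = tf.pad(boxes, [[0, _PADDING_SIZE], [0, 0]])
+  return tf.slice(padded, [0, 0], [_PADDING_SIZE, -1])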
diff --git a/models/research/lstm_object_detection/inputs/seq_dataset_builder_test.py b/models/research/lstm_object_detection/inputs/seq_dataset_builder_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b894d24f71fea1c5c372ec0ead9141af6d5ef6f
--- /dev/null
+++ b/models/research/lstm_object_detection/inputs/seq_dataset_builder_test.py
@@ -0,0 +1,282 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for dataset_builder."""
+
+import os
+import numpy as np
+import tensorflow.compat.v1 as tf
+
+from google.protobuf import text_format
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from lstm_object_detection.inputs import seq_dataset_builder
+from lstm_object_detection.protos import pipeline_pb2 as internal_pipeline_pb2
+from object_detection.builders import preprocessor_builder
+from object_detection.core import standard_fields as fields
+from object_detection.protos import input_reader_pb2
+from object_detection.protos import pipeline_pb2
+from object_detection.protos import preprocessor_pb2
+
+
+class DatasetBuilderTest(tf.test.TestCase):
+
+ def _create_tf_record(self):
+ path = os.path.join(self.get_temp_dir(), 'tfrecord')
+ writer = tf.python_io.TFRecordWriter(path)
+
+ image_tensor = np.random.randint(255, size=(16, 16, 3)).astype(np.uint8)
+ with self.test_session():
+ encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).eval()
+
+ sequence_example = example_pb2.SequenceExample(
+ context=feature_pb2.Features(
+ feature={
+ 'image/format':
+ feature_pb2.Feature(
+ bytes_list=feature_pb2.BytesList(
+ value=['jpeg'.encode('utf-8')])),
+ 'image/height':
+ feature_pb2.Feature(
+ int64_list=feature_pb2.Int64List(value=[16])),
+ 'image/width':
+ feature_pb2.Feature(
+ int64_list=feature_pb2.Int64List(value=[16])),
+ }),
+ feature_lists=feature_pb2.FeatureLists(
+ feature_list={
+ 'image/encoded':
+ feature_pb2.FeatureList(feature=[
+ feature_pb2.Feature(
+ bytes_list=feature_pb2.BytesList(
+ value=[encoded_jpeg])),
+ ]),
+ 'image/object/bbox/xmin':
+ feature_pb2.FeatureList(feature=[
+ feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=[0.0])),
+ ]),
+ 'image/object/bbox/xmax':
+ feature_pb2.FeatureList(feature=[
+ feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=[1.0]))
+ ]),
+ 'image/object/bbox/ymin':
+ feature_pb2.FeatureList(feature=[
+ feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=[0.0])),
+ ]),
+ 'image/object/bbox/ymax':
+ feature_pb2.FeatureList(feature=[
+ feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=[1.0]))
+ ]),
+ 'image/object/class/label':
+ feature_pb2.FeatureList(feature=[
+ feature_pb2.Feature(
+ int64_list=feature_pb2.Int64List(value=[2]))
+ ]),
+ }))
+
+ writer.write(sequence_example.SerializeToString())
+ writer.close()
+
+ return path
+
+ def _get_model_configs_from_proto(self):
+ """Creates a model text proto for testing.
+
+ Returns:
+ A dictionary of model configs.
+ """
+
+ model_text_proto = """
+ [lstm_object_detection.protos.lstm_model] {
+ train_unroll_length: 4
+ eval_unroll_length: 4
+ }
+ model {
+ ssd {
+ feature_extractor {
+ type: 'lstm_mobilenet_v1_fpn'
+ conv_hyperparams {
+ regularizer {
+ l2_regularizer {
+ }
+ }
+ initializer {
+ truncated_normal_initializer {
+ }
+ }
+ }
+ }
+ negative_class_weight: 2.0
+ box_coder {
+ faster_rcnn_box_coder {
+ }
+ }
+ matcher {
+ argmax_matcher {
+ }
+ }
+ similarity_calculator {
+ iou_similarity {
+ }
+ }
+ anchor_generator {
+ ssd_anchor_generator {
+ aspect_ratios: 1.0
+ }
+ }
+ image_resizer {
+ fixed_shape_resizer {
+ height: 32
+ width: 32
+ }
+ }
+ box_predictor {
+ convolutional_box_predictor {
+ conv_hyperparams {
+ regularizer {
+ l2_regularizer {
+ }
+ }
+ initializer {
+ truncated_normal_initializer {
+ }
+ }
+ }
+ }
+ }
+ normalize_loc_loss_by_codesize: true
+ loss {
+ classification_loss {
+ weighted_softmax {
+ }
+ }
+ localization_loss {
+ weighted_smooth_l1 {
+ }
+ }
+ }
+ }
+ }"""
+
+ pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
+ text_format.Merge(model_text_proto, pipeline_config)
+ configs = {}
+ configs['model'] = pipeline_config.model
+ configs['lstm_model'] = pipeline_config.Extensions[
+ internal_pipeline_pb2.lstm_model]
+
+ return configs
+
+ def _get_data_augmentation_preprocessor_proto(self):
+ preprocessor_text_proto = """
+ random_horizontal_flip {
+ }
+ """
+ preprocessor_proto = preprocessor_pb2.PreprocessingStep()
+ text_format.Merge(preprocessor_text_proto, preprocessor_proto)
+ return preprocessor_proto
+
+ def _create_training_dict(self, tensor_dict):
+ image_dict = {}
+ all_dict = {}
+ all_dict['batch'] = tensor_dict.pop('batch')
+ for i, _ in enumerate(tensor_dict[fields.InputDataFields.image]):
+ for key, val in tensor_dict.items():
+ image_dict[key] = val[i]
+
+ image_dict[fields.InputDataFields.image] = tf.to_float(
+ tf.expand_dims(image_dict[fields.InputDataFields.image], 0))
+ suffix = str(i)
+ for key, val in image_dict.items():
+ all_dict[key + suffix] = val
+ return all_dict
+
+ def _get_input_proto(self, input_reader):
+ return """
+ external_input_reader {
+ [lstm_object_detection.protos.GoogleInputReader.google_input_reader] {
+ %s: {
+ input_path: '{0}'
+ data_type: TF_SEQUENCE_EXAMPLE
+ video_length: 4
+ }
+ }
+ }
+ """ % input_reader
+
+ def test_video_input_reader(self):
+ input_reader_proto = input_reader_pb2.InputReader()
+ text_format.Merge(
+ self._get_input_proto('tf_record_video_input_reader'),
+ input_reader_proto)
+
+ configs = self._get_model_configs_from_proto()
+ tensor_dict = seq_dataset_builder.build(
+ input_reader_proto,
+ configs['model'],
+ configs['lstm_model'],
+ unroll_length=1)
+
+ all_dict = self._create_training_dict(tensor_dict)
+
+ self.assertEqual((1, 32, 32, 3), all_dict['image0'].shape)
+ self.assertEqual(4, all_dict['groundtruth_boxes0'].shape[1])
+
+ def test_build_with_data_augmentation(self):
+ input_reader_proto = input_reader_pb2.InputReader()
+ text_format.Merge(
+ self._get_input_proto('tf_record_video_input_reader'),
+ input_reader_proto)
+
+ configs = self._get_model_configs_from_proto()
+ data_augmentation_options = [
+ preprocessor_builder.build(
+ self._get_data_augmentation_preprocessor_proto())
+ ]
+ tensor_dict = seq_dataset_builder.build(
+ input_reader_proto,
+ configs['model'],
+ configs['lstm_model'],
+ unroll_length=1,
+ data_augmentation_options=data_augmentation_options)
+
+ all_dict = self._create_training_dict(tensor_dict)
+ self.assertEqual((1, 32, 32, 3), all_dict['image0'].shape)
+ self.assertEqual(4, all_dict['groundtruth_boxes0'].shape[1])
+
+ def test_raises_error_without_input_paths(self):
+ input_reader_text_proto = """
+ shuffle: false
+ num_readers: 1
+ load_instance_masks: true
+ """
+ input_reader_proto = input_reader_pb2.InputReader()
+ text_format.Merge(input_reader_text_proto, input_reader_proto)
+
+ configs = self._get_model_configs_from_proto()
+ with self.assertRaises(ValueError):
+ _ = seq_dataset_builder.build(
+ input_reader_proto,
+ configs['model'],
+ configs['lstm_model'],
+ unroll_length=1)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/lstm_object_detection/inputs/tf_sequence_example_decoder.py b/models/research/lstm_object_detection/inputs/tf_sequence_example_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..def945b3f07d5c0ef35c454c495405971e04574a
--- /dev/null
+++ b/models/research/lstm_object_detection/inputs/tf_sequence_example_decoder.py
@@ -0,0 +1,263 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tensorflow Sequence Example proto decoder.
+
+A decoder to decode string tensors containing serialized
+tensorflow.SequenceExample protos.
+"""
+import tensorflow.compat.v1 as tf
+import tf_slim as slim
+from object_detection.core import data_decoder
+from object_detection.core import standard_fields as fields
+
+tfexample_decoder = slim.tfexample_decoder
+
+
+class BoundingBoxSequence(tfexample_decoder.ItemHandler):
+ """An ItemHandler that concatenates SparseTensors to Bounding Boxes.
+ """
+
+ def __init__(self, keys=None, prefix=None, return_dense=True,
+ default_value=-1.0):
+ """Initialize the bounding box handler.
+
+ Args:
+ keys: A list of four key names representing the ymin, xmin, ymax, xmax
+ in the Example or SequenceExample.
+ prefix: An optional prefix for each of the bounding box keys in the
+ Example or SequenceExample. If provided, `prefix` is prepended to each
+ key in `keys`.
+      return_dense: if True, returns a dense tensor; if False, returns a
+        sparse tensor.
+ default_value: The value used when the `tensor_key` is not found in a
+ particular `TFExample`.
+
+ Raises:
+ ValueError: if keys is not `None` and also not a list of exactly 4 keys
+ """
+ if keys is None:
+ keys = ['ymin', 'xmin', 'ymax', 'xmax']
+ elif len(keys) != 4:
+ raise ValueError('BoundingBoxSequence expects 4 keys but got {}'.format(
+ len(keys)))
+ self._prefix = prefix
+ self._keys = keys
+ self._full_keys = [prefix + k for k in keys]
+ self._return_dense = return_dense
+ self._default_value = default_value
+ super(BoundingBoxSequence, self).__init__(self._full_keys)
+
+ def tensors_to_item(self, keys_to_tensors):
+ """Maps the given dictionary of tensors to a concatenated list of bboxes.
+
+ Args:
+ keys_to_tensors: a mapping of TF-Example keys to parsed tensors.
+
+ Returns:
+ [time, num_boxes, 4] tensor of bounding box coordinates, in order
+ [y_min, x_min, y_max, x_max]. Whether the tensor is a SparseTensor
+ or a dense Tensor is determined by the return_dense parameter. Empty
+ positions in the sparse tensor are filled with -1.0 values.
+ """
+ sides = []
+ for key in self._full_keys:
+ value = keys_to_tensors[key]
+ expanded_dims = tf.concat(
+ [tf.to_int64(tf.shape(value)),
+ tf.constant([1], dtype=tf.int64)], 0)
+ side = tf.sparse_reshape(value, expanded_dims)
+ sides.append(side)
+ bounding_boxes = tf.sparse_concat(2, sides)
+ if self._return_dense:
+ bounding_boxes = tf.sparse_tensor_to_dense(
+ bounding_boxes, default_value=self._default_value)
+ return bounding_boxes
+
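+# Layout sketch (illustrative values): for a sequence with 2 frames and 1 box
+# per frame, with ymin = [[0.1], [0.2]], xmin = [[0.3], [0.4]],
+# ymax = [[0.5], [0.6]] and xmax = [[0.7], [0.8]], tensors_to_item above
+# returns
+#   [[[0.1, 0.3, 0.5, 0.7]],
+#    [[0.2, 0.4, 0.6, 0.8]]]   # shape [time=2, num_boxes=1, 4].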
+
+class TFSequenceExampleDecoder(data_decoder.DataDecoder):
+ """Tensorflow Sequence Example proto decoder."""
+
+ def __init__(self):
+ """Constructor sets keys_to_features and items_to_handlers."""
+ self.keys_to_context_features = {
+ 'image/format':
+ tf.FixedLenFeature((), tf.string, default_value='jpeg'),
+ 'image/filename':
+ tf.FixedLenFeature((), tf.string, default_value=''),
+ 'image/key/sha256':
+ tf.FixedLenFeature((), tf.string, default_value=''),
+ 'image/source_id':
+ tf.FixedLenFeature((), tf.string, default_value=''),
+ 'image/height':
+ tf.FixedLenFeature((), tf.int64, 1),
+ 'image/width':
+ tf.FixedLenFeature((), tf.int64, 1),
+ }
+ self.keys_to_features = {
+ 'image/encoded': tf.FixedLenSequenceFeature((), tf.string),
+ 'bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
+ 'bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
+ 'bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
+ 'bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
+ 'bbox/label/index': tf.VarLenFeature(dtype=tf.int64),
+ 'bbox/label/string': tf.VarLenFeature(tf.string),
+ 'area': tf.VarLenFeature(tf.float32),
+ 'is_crowd': tf.VarLenFeature(tf.int64),
+ 'difficult': tf.VarLenFeature(tf.int64),
+ 'group_of': tf.VarLenFeature(tf.int64),
+ }
+ self.items_to_handlers = {
+ fields.InputDataFields.image:
+ tfexample_decoder.Image(
+ image_key='image/encoded',
+ format_key='image/format',
+ channels=3,
+ repeated=True),
+ fields.InputDataFields.source_id: (
+ tfexample_decoder.Tensor('image/source_id')),
+ fields.InputDataFields.key: (
+ tfexample_decoder.Tensor('image/key/sha256')),
+ fields.InputDataFields.filename: (
+ tfexample_decoder.Tensor('image/filename')),
+ # Object boxes and classes.
+ fields.InputDataFields.groundtruth_boxes:
+ BoundingBoxSequence(prefix='bbox/'),
+ fields.InputDataFields.groundtruth_classes: (
+ tfexample_decoder.Tensor('bbox/label/index')),
+ fields.InputDataFields.groundtruth_area:
+ tfexample_decoder.Tensor('area'),
+ fields.InputDataFields.groundtruth_is_crowd: (
+ tfexample_decoder.Tensor('is_crowd')),
+ fields.InputDataFields.groundtruth_difficult: (
+ tfexample_decoder.Tensor('difficult')),
+ fields.InputDataFields.groundtruth_group_of: (
+ tfexample_decoder.Tensor('group_of'))
+ }
+
+ def decode(self, tf_seq_example_string_tensor, items=None):
+ """Decodes serialized tf.SequenceExample and returns a tensor dictionary.
+
+ Args:
+ tf_seq_example_string_tensor: A string tensor holding a serialized
+        tensorflow.SequenceExample proto.
+ items: The list of items to decode. These must be a subset of the item
+ keys in self._items_to_handlers. If `items` is left as None, then all
+ of the items in self._items_to_handlers are decoded.
+
+ Returns:
+ A dictionary of the following tensors.
+      fields.InputDataFields.image - 4D uint8 tensor of shape
+        [seq, None, None, 3] containing the decoded image sequence.
+ fields.InputDataFields.source_id - string tensor containing original
+ image id.
+ fields.InputDataFields.key - string tensor with unique sha256 hash key.
+ fields.InputDataFields.filename - string tensor with original dataset
+ filename.
+ fields.InputDataFields.groundtruth_boxes - 2D float32 tensor of shape
+ [None, 4] containing box corners.
+ fields.InputDataFields.groundtruth_classes - 1D int64 tensor of shape
+ [None] containing classes for the boxes.
+ fields.InputDataFields.groundtruth_area - 1D float32 tensor of shape
+ [None] containing object mask area in pixel squared.
+ fields.InputDataFields.groundtruth_is_crowd - 1D bool tensor of shape
+ [None] indicating if the boxes enclose a crowd.
+ fields.InputDataFields.groundtruth_difficult - 1D bool tensor of shape
+ [None] indicating if the boxes represent `difficult` instances.
+ """
+ serialized_example = tf.reshape(tf_seq_example_string_tensor, shape=[])
+ decoder = TFSequenceExampleDecoderHelper(self.keys_to_context_features,
+ self.keys_to_features,
+ self.items_to_handlers)
+ if not items:
+ items = decoder.list_items()
+ tensors = decoder.decode(serialized_example, items=items)
+ tensor_dict = dict(zip(items, tensors))
+
+ return tensor_dict
+
+
+class TFSequenceExampleDecoderHelper(data_decoder.DataDecoder):
+ """A decoder helper class for TensorFlow SequenceExamples.
+
+ To perform this decoding operation, a SequenceExampleDecoder is given a list
+  of ItemHandlers. Each ItemHandler indicates the set of features it needs to
+  produce its item.
+ """
+
+ def __init__(self, keys_to_context_features, keys_to_sequence_features,
+ items_to_handlers):
+ """Constructs the decoder.
+
+ Args:
+ keys_to_context_features: A dictionary from TF-SequenceExample context
+ keys to either tf.VarLenFeature or tf.FixedLenFeature instances.
+ See tensorflow's parsing_ops.py.
+ keys_to_sequence_features: A dictionary from TF-SequenceExample sequence
+ keys to either tf.VarLenFeature or tf.FixedLenSequenceFeature instances.
+ items_to_handlers: A dictionary from items (strings) to ItemHandler
+ instances. Note that the ItemHandler's are provided the keys that they
+ use to return the final item Tensors.
+ Raises:
+ ValueError: If the same key is present for context features and sequence
+ features.
+ """
+ unique_keys = set()
+ unique_keys.update(keys_to_context_features)
+ unique_keys.update(keys_to_sequence_features)
+ if len(unique_keys) != (
+ len(keys_to_context_features) + len(keys_to_sequence_features)):
+ # This situation is ambiguous in the decoder's keys_to_tensors variable.
+ raise ValueError('Context and sequence keys are not unique. \n'
+ ' Context keys: %s \n Sequence keys: %s' %
+ (list(keys_to_context_features.keys()),
+ list(keys_to_sequence_features.keys())))
+ self._keys_to_context_features = keys_to_context_features
+ self._keys_to_sequence_features = keys_to_sequence_features
+ self._items_to_handlers = items_to_handlers
+
+ def list_items(self):
+ """Returns keys of items."""
+ return self._items_to_handlers.keys()
+
+ def decode(self, serialized_example, items=None):
+ """Decodes the given serialized TF-SequenceExample.
+
+ Args:
+ serialized_example: A serialized TF-SequenceExample tensor.
+ items: The list of items to decode. These must be a subset of the item
+ keys in self._items_to_handlers. If `items` is left as None, then all
+ of the items in self._items_to_handlers are decoded.
+ Returns:
+      The decoded items, a list of tensors.
+ """
+ context, feature_list = tf.parse_single_sequence_example(
+ serialized_example, self._keys_to_context_features,
+ self._keys_to_sequence_features)
+ # Reshape non-sparse elements just once:
+ for k in self._keys_to_context_features:
+ v = self._keys_to_context_features[k]
+ if isinstance(v, tf.FixedLenFeature):
+ context[k] = tf.reshape(context[k], v.shape)
+ if not items:
+ items = self._items_to_handlers.keys()
+ outputs = []
+ for item in items:
+ handler = self._items_to_handlers[item]
+ keys_to_tensors = {
+ key: context[key] if key in context else feature_list[key]
+ for key in handler.keys
+ }
+ outputs.append(handler.tensors_to_item(keys_to_tensors))
+ return outputs
diff --git a/models/research/lstm_object_detection/inputs/tf_sequence_example_decoder_test.py b/models/research/lstm_object_detection/inputs/tf_sequence_example_decoder_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbbb8d3c7443dabcfc0df08638e2a381eca2cc31
--- /dev/null
+++ b/models/research/lstm_object_detection/inputs/tf_sequence_example_decoder_test.py
@@ -0,0 +1,113 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for lstm_object_detection.tf_sequence_example_decoder."""
+
+import numpy as np
+import tensorflow.compat.v1 as tf
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import parsing_ops
+from lstm_object_detection.inputs import tf_sequence_example_decoder
+from object_detection.core import standard_fields as fields
+
+
+class TFSequenceExampleDecoderTest(tf.test.TestCase):
+ """Tests for sequence example decoder."""
+
+ def _EncodeImage(self, image_tensor, encoding_type='jpeg'):
+ with self.test_session():
+ if encoding_type == 'jpeg':
+ image_encoded = tf.image.encode_jpeg(tf.constant(image_tensor)).eval()
+ else:
+ raise ValueError('Invalid encoding type.')
+ return image_encoded
+
+ def _DecodeImage(self, image_encoded, encoding_type='jpeg'):
+ with self.test_session():
+ if encoding_type == 'jpeg':
+ image_decoded = tf.image.decode_jpeg(tf.constant(image_encoded)).eval()
+ else:
+ raise ValueError('Invalid encoding type.')
+ return image_decoded
+
+ def testDecodeJpegImageAndBoundingBox(self):
+ """Test if the decoder can correctly decode the image and bounding box.
+
+    A random image (represented as an image tensor) is first decoded as the
+    groundtruth image. The same image tensor is then encoded, passed through a
+    sequence example, and decoded again. The groundtruth image and the decoded
+    image are expected to be equal. Similar checks are applied to labels such
+    as the bounding box.
+ """
+ image_tensor = np.random.randint(256, size=(256, 256, 3)).astype(np.uint8)
+ encoded_jpeg = self._EncodeImage(image_tensor)
+ decoded_jpeg = self._DecodeImage(encoded_jpeg)
+
+ sequence_example = example_pb2.SequenceExample(
+ feature_lists=feature_pb2.FeatureLists(
+ feature_list={
+ 'image/encoded':
+ feature_pb2.FeatureList(feature=[
+ feature_pb2.Feature(
+ bytes_list=feature_pb2.BytesList(
+ value=[encoded_jpeg])),
+ ]),
+ 'bbox/xmin':
+ feature_pb2.FeatureList(feature=[
+ feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=[0.0])),
+ ]),
+ 'bbox/xmax':
+ feature_pb2.FeatureList(feature=[
+ feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=[1.0]))
+ ]),
+ 'bbox/ymin':
+ feature_pb2.FeatureList(feature=[
+ feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=[0.0])),
+ ]),
+ 'bbox/ymax':
+ feature_pb2.FeatureList(feature=[
+ feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=[1.0]))
+ ]),
+ })).SerializeToString()
+
+ example_decoder = tf_sequence_example_decoder.TFSequenceExampleDecoder()
+ tensor_dict = example_decoder.decode(tf.convert_to_tensor(sequence_example))
+
+ # Test tensor dict image dimension.
+ self.assertAllEqual(
+ (tensor_dict[fields.InputDataFields.image].get_shape().as_list()),
+ [None, None, None, 3])
+ with self.test_session() as sess:
+ tensor_dict[fields.InputDataFields.image] = tf.squeeze(
+ tensor_dict[fields.InputDataFields.image])
+ tensor_dict[fields.InputDataFields.groundtruth_boxes] = tf.squeeze(
+ tensor_dict[fields.InputDataFields.groundtruth_boxes])
+ tensor_dict = sess.run(tensor_dict)
+
+ # Test decoded image.
+ self.assertAllEqual(decoded_jpeg, tensor_dict[fields.InputDataFields.image])
+ # Test decoded bounding box.
+ self.assertAllEqual([0.0, 0.0, 1.0, 1.0],
+ tensor_dict[fields.InputDataFields.groundtruth_boxes])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/lstm_object_detection/lstm/__init__.py b/models/research/lstm_object_detection/lstm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/lstm_object_detection/lstm/lstm_cells.py b/models/research/lstm_object_detection/lstm/lstm_cells.py
new file mode 100644
index 0000000000000000000000000000000000000000..a553073d978b4b61e6f550fa65e2a2ccc7bfe92d
--- /dev/null
+++ b/models/research/lstm_object_detection/lstm/lstm_cells.py
@@ -0,0 +1,734 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""BottleneckConvLSTMCell implementation."""
+import functools
+
+import tensorflow.compat.v1 as tf
+import tf_slim as slim
+
+from tensorflow.contrib import rnn as contrib_rnn
+from tensorflow.contrib.framework.python.ops import variables as contrib_variables
+import lstm_object_detection.lstm.utils as lstm_utils
+
+
+class BottleneckConvLSTMCell(contrib_rnn.RNNCell):
+ """Basic LSTM recurrent network cell using separable convolutions.
+
+ The implementation is based on:
+ Mobile Video Object Detection with Temporally-Aware Feature Maps
+ https://arxiv.org/abs/1711.06368.
+
+ We add forget_bias (default: 1) to the biases of the forget gate in order to
+  reduce the scale of forgetting at the beginning of training.
+
+ This LSTM first projects inputs to the size of the output before doing gate
+ computations. This saves params unless the input is less than a third of the
+ state size channel-wise.
+ """
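+
+  # A rough count behind the claim above (ignoring the small depthwise terms
+  # of the separable convolutions): with input depth d_in and state depth n,
+  # the bottlenecked cell costs about (d_in + n) * n weights for the
+  # projection plus 4 * n * n for the gates, versus 4 * n * (d_in + n) when
+  # the gates read the concatenated input and state directly. The bottleneck
+  # is smaller whenever (d_in + n) * n + 4 * n * n < 4 * n * (d_in + n),
+  # i.e. whenever d_in > n / 3.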
+
+ def __init__(self,
+ filter_size,
+ output_size,
+ num_units,
+ forget_bias=1.0,
+ activation=tf.tanh,
+ flatten_state=False,
+ clip_state=False,
+ output_bottleneck=False,
+ pre_bottleneck=False,
+ visualize_gates=False):
+ """Initializes the basic LSTM cell.
+
+ Args:
+ filter_size: collection, conv filter size.
+ output_size: collection, the width/height dimensions of the cell/output.
+ num_units: int, The number of channels in the LSTM cell.
+ forget_bias: float, The bias added to forget gates (see above).
+ activation: Activation function of the inner states.
+ flatten_state: if True, state tensor will be flattened and stored as a 2-d
+ tensor. Use for exporting the model to tfmini.
+ clip_state: if True, clip state between [-6, 6].
+ output_bottleneck: if True, the cell bottleneck will be concatenated to
+ the cell output.
+      pre_bottleneck: if True, the cell assumes that bottlenecking was
+        performed before the function was called.
+ visualize_gates: if True, add histogram summaries of all gates and outputs
+ to tensorboard.
+ """
+ self._filter_size = list(filter_size)
+ self._output_size = list(output_size)
+ self._num_units = num_units
+ self._forget_bias = forget_bias
+ self._activation = activation
+ self._viz_gates = visualize_gates
+ self._flatten_state = flatten_state
+ self._clip_state = clip_state
+ self._output_bottleneck = output_bottleneck
+ self._pre_bottleneck = pre_bottleneck
+ self._param_count = self._num_units
+ for dim in self._output_size:
+ self._param_count *= dim
+
+ @property
+ def state_size(self):
+ return contrib_rnn.LSTMStateTuple(self._output_size + [self._num_units],
+ self._output_size + [self._num_units])
+
+ @property
+ def state_size_flat(self):
+ return contrib_rnn.LSTMStateTuple([self._param_count], [self._param_count])
+
+ @property
+ def output_size(self):
+ return self._output_size + [self._num_units]
+
+ def __call__(self, inputs, state, scope=None):
+ """Long short-term memory cell (LSTM) with bottlenecking.
+
+ Args:
+ inputs: Input tensor at the current timestep.
+ state: Tuple of tensors, the state and output at the previous timestep.
+ scope: Optional scope.
+
+ Returns:
+ A tuple where the first element is the LSTM output and the second is
+ a LSTMStateTuple of the state at the current timestep.
+ """
+ scope = scope or 'conv_lstm_cell'
+ with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
+ c, h = state
+
+ # unflatten state if necessary
+ if self._flatten_state:
+ c = tf.reshape(c, [-1] + self.output_size)
+ h = tf.reshape(h, [-1] + self.output_size)
+
+ # summary of input passed into cell
+ if self._viz_gates:
+ slim.summaries.add_histogram_summary(inputs, 'cell_input')
+ if self._pre_bottleneck:
+ bottleneck = inputs
+ else:
+ bottleneck = slim.separable_conv2d(
+ tf.concat([inputs, h], 3),
+ self._num_units,
+ self._filter_size,
+ depth_multiplier=1,
+ activation_fn=self._activation,
+ normalizer_fn=None,
+ scope='bottleneck')
+
+ if self._viz_gates:
+ slim.summaries.add_histogram_summary(bottleneck, 'bottleneck')
+
+ concat = slim.separable_conv2d(
+ bottleneck,
+ 4 * self._num_units,
+ self._filter_size,
+ depth_multiplier=1,
+ activation_fn=None,
+ normalizer_fn=None,
+ scope='gates')
+
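+      # i = input_gate, j = new_input, f = forget_gate, o = output_gate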
+ i, j, f, o = tf.split(concat, 4, 3)
+
+ new_c = (
+ c * tf.sigmoid(f + self._forget_bias) +
+ tf.sigmoid(i) * self._activation(j))
+ if self._clip_state:
+ new_c = tf.clip_by_value(new_c, -6, 6)
+ new_h = self._activation(new_c) * tf.sigmoid(o)
+ # summary of cell output and new state
+ if self._viz_gates:
+ slim.summaries.add_histogram_summary(new_h, 'cell_output')
+ slim.summaries.add_histogram_summary(new_c, 'cell_state')
+
+ output = new_h
+ if self._output_bottleneck:
+ output = tf.concat([new_h, bottleneck], axis=3)
+
+ # reflatten state to store it
+ if self._flatten_state:
+ new_c = tf.reshape(new_c, [-1, self._param_count])
+ new_h = tf.reshape(new_h, [-1, self._param_count])
+
+ return output, contrib_rnn.LSTMStateTuple(new_c, new_h)
+
+ def init_state(self, state_name, batch_size, dtype, learned_state=False):
+ """Creates an initial state compatible with this cell.
+
+ Args:
+ state_name: name of the state tensor
+ batch_size: model batch size
+ dtype: dtype for the tensor values i.e. tf.float32
+ learned_state: whether the initial state should be learnable. If false,
+ the initial state is set to all 0's
+
+ Returns:
+ The created initial state.
+ """
+ state_size = (
+ self.state_size_flat if self._flatten_state else self.state_size)
+    # List of 2 zero tensors or learnable variable tensors, depending on
+    # whether learned_state is true.
+ # pylint: disable=g-long-ternary,g-complex-comprehension
+ ret_flat = [(contrib_variables.model_variable(
+ state_name + str(i),
+ shape=s,
+ dtype=dtype,
+ initializer=tf.truncated_normal_initializer(stddev=0.03))
+ if learned_state else tf.zeros(
+ [batch_size] + s, dtype=dtype, name=state_name))
+ for i, s in enumerate(state_size)]
+
+ # duplicates initial state across the batch axis if it's learned
+ if learned_state:
+ ret_flat = [
+ tf.stack([tensor
+ for i in range(int(batch_size))])
+ for tensor in ret_flat
+ ]
+ for s, r in zip(state_size, ret_flat):
+ r.set_shape([None] + s)
+ return tf.nest.pack_sequence_as(structure=[1, 1], flat_sequence=ret_flat)
+
+ def pre_bottleneck(self, inputs, state, input_index):
+ """Apply pre-bottleneck projection to inputs.
+
+ Pre-bottleneck operation maps features of different channels into the same
+ dimension. The purpose of this op is to share the features from both large
+ and small models in the same LSTM cell.
+
+ Args:
+ inputs: 4D Tensor with shape [batch_size x width x height x input_size].
+ state: 4D Tensor with shape [batch_size x width x height x state_size].
+      input_index: integer index indicating which base features the inputs
+        correspond to.
+
+ Returns:
+ inputs: pre-bottlenecked inputs.
+ Raises:
+ ValueError: If pre_bottleneck is not set or inputs is not rank 4.
+ """
+ # Sometimes state is a tuple, in which case it cannot be modified, e.g.
+ # during training, tf.contrib.training.SequenceQueueingStateSaver
+ # returns the state as a tuple. This should not be an issue since we
+ # only need to modify state[1] during export, when state should be a
+ # list.
+ if len(inputs.shape) != 4:
+ raise ValueError('Expect rank 4 feature tensor.')
+ if not self._flatten_state and len(state.shape) != 4:
+ raise ValueError('Expect rank 4 state tensor.')
+ if self._flatten_state and len(state.shape) != 2:
+ raise ValueError('Expect rank 2 state tensor when flatten_state is set.')
+
+ with tf.name_scope(None):
+ state = tf.identity(state, name='raw_inputs/init_lstm_h')
+ if self._flatten_state:
+ batch_size = inputs.shape[0]
+ height = inputs.shape[1]
+ width = inputs.shape[2]
+ state = tf.reshape(state, [batch_size, height, width, -1])
+ with tf.variable_scope('conv_lstm_cell', reuse=tf.AUTO_REUSE):
+ scope_name = 'bottleneck_%d' % input_index
+ inputs = slim.separable_conv2d(
+ tf.concat([inputs, state], 3),
+ self.output_size[-1],
+ self._filter_size,
+ depth_multiplier=1,
+ activation_fn=tf.nn.relu6,
+ normalizer_fn=None,
+ scope=scope_name)
+ # For exporting inference graph, we only mark the first timestep.
+ with tf.name_scope(None):
+ inputs = tf.identity(
+ inputs, name='raw_outputs/base_endpoint_%d' % (input_index + 1))
+ return inputs
+
+
+class GroupedConvLSTMCell(contrib_rnn.RNNCell):
+ """Basic LSTM recurrent network cell using separable convolutions.
+
+ The implementation is based on: https://arxiv.org/abs/1903.10172.
+
+ We add forget_bias (default: 1) to the biases of the forget gate in order to
+  reduce the scale of forgetting at the beginning of training.
+
+ This LSTM first projects inputs to the size of the output before doing gate
+ computations. This saves params unless the input is less than a third of the
+ state size channel-wise. Computation of bottlenecks and gates is divided
+ into independent groups for further savings.
+ """
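+
+  # Grouping sketch (illustrative numbers): with num_units=64 and groups=4,
+  # the cell state is split along the channel axis into four 16-channel
+  # chunks, and each chunk gets its own bottleneck and gate convolutions, so
+  # every per-group convolution sees only a quarter of the channels.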
+
+ def __init__(self,
+ filter_size,
+ output_size,
+ num_units,
+ is_training,
+ forget_bias=1.0,
+ activation=tf.tanh,
+ use_batch_norm=False,
+ flatten_state=False,
+ groups=4,
+ clip_state=False,
+ scale_state=False,
+ output_bottleneck=False,
+ pre_bottleneck=False,
+ is_quantized=False,
+ visualize_gates=False,
+ conv_op_overrides=None):
+ """Initialize the basic LSTM cell.
+
+ Args:
+ filter_size: collection, conv filter size
+ output_size: collection, the width/height dimensions of the cell/output
+ num_units: int, The number of channels in the LSTM cell.
+ is_training: Whether the LSTM is in training mode.
+ forget_bias: float, The bias added to forget gates (see above).
+ activation: Activation function of the inner states.
+ use_batch_norm: if True, use batch norm after convolution
+ flatten_state: if True, state tensor will be flattened and stored as a 2-d
+        tensor. Use for exporting the model to tfmini.
+ groups: Number of groups to split the state into. Must evenly divide
+ num_units.
+ clip_state: if True, clips state between [-6, 6].
+ scale_state: if True, scales state so that all values are under 6 at all
+ times.
+ output_bottleneck: if True, the cell bottleneck will be concatenated to
+ the cell output.
+ pre_bottleneck: if True, cell assumes that bottlenecking was performing
+ before the function was called.
+ is_quantized: if True, the model is in quantize mode, which requires
+ quantization friendly concat and separable_conv2d ops.
+ visualize_gates: if True, add histogram summaries of all gates and outputs
+        to tensorboard.
+      conv_op_overrides: A list of convolutional operations that override the
+        'bottleneck' and 'convolution' layers before the LSTM gates. If None,
+        the original implementation of separable_conv2d will be used. The
+        length of the list should be two.
+
+ Raises:
+ ValueError: when both clip_state and scale_state are enabled.
+ """
+ if clip_state and scale_state:
+ raise ValueError('clip_state and scale_state cannot both be enabled.')
+
+ self._filter_size = list(filter_size)
+ self._output_size = list(output_size)
+ self._num_units = num_units
+ self._is_training = is_training
+ self._forget_bias = forget_bias
+ self._activation = activation
+ self._use_batch_norm = use_batch_norm
+ self._viz_gates = visualize_gates
+ self._flatten_state = flatten_state
+ self._param_count = self._num_units
+ self._groups = groups
+ self._scale_state = scale_state
+ self._clip_state = clip_state
+ self._output_bottleneck = output_bottleneck
+ self._pre_bottleneck = pre_bottleneck
+ self._is_quantized = is_quantized
+ for dim in self._output_size:
+ self._param_count *= dim
+ self._conv_op_overrides = conv_op_overrides
+ if self._conv_op_overrides and len(self._conv_op_overrides) != 2:
+      raise ValueError('Bottleneck and convolutional layers should be '
+                       'overridden together.')
+
+ @property
+ def state_size(self):
+ return contrib_rnn.LSTMStateTuple(self._output_size + [self._num_units],
+ self._output_size + [self._num_units])
+
+ @property
+ def state_size_flat(self):
+ return contrib_rnn.LSTMStateTuple([self._param_count], [self._param_count])
+
+ @property
+ def output_size(self):
+ return self._output_size + [self._num_units]
+
+ @property
+ def filter_size(self):
+ return self._filter_size
+
+ @property
+ def num_groups(self):
+ return self._groups
+
+ def __call__(self, inputs, state, scope=None):
+ """Long short-term memory cell (LSTM) with bottlenecking.
+
+ Includes logic for quantization-aware training. Note that all concats and
+ activations use fixed ranges unless stated otherwise.
+
+ Args:
+ inputs: Input tensor at the current timestep.
+ state: Tuple of tensors, the state at the previous timestep.
+ scope: Optional scope.
+
+ Returns:
+ A tuple where the first element is the LSTM output and the second is
+ a LSTMStateTuple of the state at the current timestep.
+ """
+ scope = scope or 'conv_lstm_cell'
+ with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
+ c, h = state
+
+ # Set nodes to be under raw_inputs/ name scope for tfmini export.
+ with tf.name_scope(None):
+ c = tf.identity(c, name='raw_inputs/init_lstm_c')
+ # When pre_bottleneck is enabled, input h handle is in rnn_decoder.py
+ if not self._pre_bottleneck:
+ h = tf.identity(h, name='raw_inputs/init_lstm_h')
+
+ # unflatten state if necessary
+ if self._flatten_state:
+ c = tf.reshape(c, [-1] + self.output_size)
+ h = tf.reshape(h, [-1] + self.output_size)
+
+ c_list = tf.split(c, self._groups, axis=3)
+ if self._pre_bottleneck:
+ inputs_list = tf.split(inputs, self._groups, axis=3)
+ else:
+ h_list = tf.split(h, self._groups, axis=3)
+ out_bottleneck = []
+ out_c = []
+ out_h = []
+ # summary of input passed into cell
+ if self._viz_gates:
+ slim.summaries.add_histogram_summary(inputs, 'cell_input')
+
+ for k in range(self._groups):
+ if self._pre_bottleneck:
+ bottleneck = inputs_list[k]
+ else:
+ if self._conv_op_overrides:
+ bottleneck_fn = self._conv_op_overrides[0]
+ else:
+ bottleneck_fn = functools.partial(
+ lstm_utils.quantizable_separable_conv2d,
+ kernel_size=self._filter_size,
+ activation_fn=self._activation)
+ if self._use_batch_norm:
+ b_x = bottleneck_fn(
+ inputs=inputs,
+ num_outputs=self._num_units // self._groups,
+ is_quantized=self._is_quantized,
+ depth_multiplier=1,
+ normalizer_fn=None,
+ scope='bottleneck_%d_x' % k)
+ b_h = bottleneck_fn(
+ inputs=h_list[k],
+ num_outputs=self._num_units // self._groups,
+ is_quantized=self._is_quantized,
+ depth_multiplier=1,
+ normalizer_fn=None,
+ scope='bottleneck_%d_h' % k)
+ b_x = slim.batch_norm(
+ b_x,
+ scale=True,
+ is_training=self._is_training,
+ scope='BatchNorm_%d_X' % k)
+ b_h = slim.batch_norm(
+ b_h,
+ scale=True,
+ is_training=self._is_training,
+ scope='BatchNorm_%d_H' % k)
+ bottleneck = b_x + b_h
+ else:
+ # All concats use fixed quantization ranges to prevent rescaling
+ # at inference. Both |inputs| and |h_list| are tensors resulting
+ # from Relu6 operations so we fix the ranges to [0, 6].
+ bottleneck_concat = lstm_utils.quantizable_concat(
+ [inputs, h_list[k]],
+ axis=3,
+ is_training=False,
+ is_quantized=self._is_quantized,
+ scope='bottleneck_%d/quantized_concat' % k)
+ bottleneck = bottleneck_fn(
+ inputs=bottleneck_concat,
+ num_outputs=self._num_units // self._groups,
+ is_quantized=self._is_quantized,
+ depth_multiplier=1,
+ normalizer_fn=None,
+ scope='bottleneck_%d' % k)
+
+ if self._conv_op_overrides:
+ conv_fn = self._conv_op_overrides[1]
+ else:
+ conv_fn = functools.partial(
+ lstm_utils.quantizable_separable_conv2d,
+ kernel_size=self._filter_size,
+ activation_fn=None)
+ concat = conv_fn(
+ inputs=bottleneck,
+ num_outputs=4 * self._num_units // self._groups,
+ is_quantized=self._is_quantized,
+ depth_multiplier=1,
+ normalizer_fn=None,
+ scope='concat_conv_%d' % k)
+
+ # Since there is no activation in the previous separable conv, we
+ # quantize here. A starting range of [-6, 6] is used because the
+ # tensors are input to a Sigmoid function that saturates at these
+ # ranges.
+ concat = lstm_utils.quantize_op(
+ concat,
+ is_training=self._is_training,
+ default_min=-6,
+ default_max=6,
+ is_quantized=self._is_quantized,
+ scope='gates_%d/act_quant' % k)
+
+ # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+ i, j, f, o = tf.split(concat, 4, 3)
+
+ f_add = f + self._forget_bias
+ f_add = lstm_utils.quantize_op(
+ f_add,
+ is_training=self._is_training,
+ default_min=-6,
+ default_max=6,
+ is_quantized=self._is_quantized,
+ scope='forget_gate_%d/add_quant' % k)
+ f_act = tf.sigmoid(f_add)
+
+ a = c_list[k] * f_act
+ a = lstm_utils.quantize_op(
+ a,
+ is_training=self._is_training,
+ is_quantized=self._is_quantized,
+ scope='forget_gate_%d/mul_quant' % k)
+
+ i_act = tf.sigmoid(i)
+
+ j_act = self._activation(j)
+ # The quantization range is fixed for the relu6 to ensure that zero
+ # is exactly representable.
+ j_act = lstm_utils.fixed_quantize_op(
+ j_act,
+ fixed_min=0.0,
+ fixed_max=6.0,
+ is_quantized=self._is_quantized,
+ scope='new_input_%d/act_quant' % k)
+
+ b = i_act * j_act
+ b = lstm_utils.quantize_op(
+ b,
+ is_training=self._is_training,
+ is_quantized=self._is_quantized,
+ scope='input_gate_%d/mul_quant' % k)
+
+ new_c = a + b
+ # The quantization range is fixed to [0, 6] due to an optimization in
+        # TFLite. The order of operations is as follows:
+ # Add -> FakeQuant -> Relu6 -> FakeQuant -> Concat.
+ # The fakequant ranges to the concat must be fixed to ensure all inputs
+ # to the concat have the same range, removing the need for rescaling.
+ # The quantization ranges input to the relu6 are propagated to its
+ # output. Any mismatch between these two ranges will cause an error.
+ new_c = lstm_utils.fixed_quantize_op(
+ new_c,
+ fixed_min=0.0,
+ fixed_max=6.0,
+ is_quantized=self._is_quantized,
+ scope='new_c_%d/add_quant' % k)
+
+ if not self._is_quantized:
+ if self._scale_state:
+ normalizer = tf.maximum(1.0,
+ tf.reduce_max(new_c, axis=(1, 2, 3)) / 6)
+ new_c /= tf.reshape(normalizer, [tf.shape(new_c)[0], 1, 1, 1])
+ elif self._clip_state:
+ new_c = tf.clip_by_value(new_c, -6, 6)
+
+ new_c_act = self._activation(new_c)
+ # The quantization range is fixed for the relu6 to ensure that zero
+ # is exactly representable.
+ new_c_act = lstm_utils.fixed_quantize_op(
+ new_c_act,
+ fixed_min=0.0,
+ fixed_max=6.0,
+ is_quantized=self._is_quantized,
+ scope='new_c_%d/act_quant' % k)
+
+ o_act = tf.sigmoid(o)
+
+ new_h = new_c_act * o_act
+ # The quantization range is fixed since it is input to a concat.
+ # A range of [0, 6] is used since |new_h| is a product of ranges [0, 6]
+ # and [0, 1].
+ new_h_act = lstm_utils.fixed_quantize_op(
+ new_h,
+ fixed_min=0.0,
+ fixed_max=6.0,
+ is_quantized=self._is_quantized,
+ scope='new_h_%d/act_quant' % k)
+
+ out_bottleneck.append(bottleneck)
+ out_c.append(new_c_act)
+ out_h.append(new_h_act)
+
+ # Since all inputs to the below concats are already quantized, we can use
+ # a regular concat operation.
+ new_c = tf.concat(out_c, axis=3)
+ new_h = tf.concat(out_h, axis=3)
+
+ # |bottleneck| is input to a concat with |new_h|. We must use
+ # quantizable_concat() with a fixed range that matches |new_h|.
+ bottleneck = lstm_utils.quantizable_concat(
+ out_bottleneck,
+ axis=3,
+ is_training=False,
+ is_quantized=self._is_quantized,
+ scope='out_bottleneck/quantized_concat')
+
+ # summary of cell output and new state
+ if self._viz_gates:
+ slim.summaries.add_histogram_summary(new_h, 'cell_output')
+ slim.summaries.add_histogram_summary(new_c, 'cell_state')
+
+ output = new_h
+ if self._output_bottleneck:
+ output = lstm_utils.quantizable_concat(
+ [new_h, bottleneck],
+ axis=3,
+ is_training=False,
+ is_quantized=self._is_quantized,
+ scope='new_output/quantized_concat')
+
+ # reflatten state to store it
+ if self._flatten_state:
+ new_c = tf.reshape(new_c, [-1, self._param_count], name='lstm_c')
+ new_h = tf.reshape(new_h, [-1, self._param_count], name='lstm_h')
+
+ # Set nodes to be under raw_outputs/ name scope for tfmini export.
+ with tf.name_scope(None):
+ new_c = tf.identity(new_c, name='raw_outputs/lstm_c')
+ new_h = tf.identity(new_h, name='raw_outputs/lstm_h')
+ states_and_output = contrib_rnn.LSTMStateTuple(new_c, new_h)
+
+ return output, states_and_output
+
+ def init_state(self, state_name, batch_size, dtype, learned_state=False):
+ """Creates an initial state compatible with this cell.
+
+ Args:
+ state_name: name of the state tensor
+ batch_size: model batch size
+ dtype: dtype for the tensor values i.e. tf.float32
+ learned_state: whether the initial state should be learnable. If false,
+ the initial state is set to all 0's
+
+ Returns:
+ ret: the created initial state
+ """
+ state_size = (
+ self.state_size_flat if self._flatten_state else self.state_size)
+    # List of 2 zero tensors or learnable variable tensors, depending on
+    # whether learned_state is true.
+ # pylint: disable=g-long-ternary,g-complex-comprehension
+ ret_flat = [(contrib_variables.model_variable(
+ state_name + str(i),
+ shape=s,
+ dtype=dtype,
+ initializer=tf.truncated_normal_initializer(stddev=0.03))
+ if learned_state else tf.zeros(
+ [batch_size] + s, dtype=dtype, name=state_name))
+ for i, s in enumerate(state_size)]
+
+ # duplicates initial state across the batch axis if it's learned
+ if learned_state:
+ ret_flat = [tf.stack([tensor for i in range(int(batch_size))])
+ for tensor in ret_flat]
+ for s, r in zip(state_size, ret_flat):
+ r = tf.reshape(r, [-1] + s)
+ ret = tf.nest.pack_sequence_as(structure=[1, 1], flat_sequence=ret_flat)
+ return ret
+
+ def pre_bottleneck(self, inputs, state, input_index):
+ """Apply pre-bottleneck projection to inputs.
+
+ Pre-bottleneck operation maps features of different channels into the same
+ dimension. The purpose of this op is to share the features from both large
+ and small models in the same LSTM cell.
+
+ Args:
+ inputs: 4D Tensor with shape [batch_size x width x height x input_size].
+ state: 4D Tensor with shape [batch_size x width x height x state_size].
+      input_index: integer index indicating which base features the inputs
+        correspond to.
+
+ Returns:
+ inputs: pre-bottlenecked inputs.
+ Raises:
+ ValueError: If pre_bottleneck is not set or inputs is not rank 4.
+ """
+ # Sometimes state is a tuple, in which case it cannot be modified, e.g.
+ # during training, tf.contrib.training.SequenceQueueingStateSaver
+ # returns the state as a tuple. This should not be an issue since we
+ # only need to modify state[1] during export, when state should be a
+ # list.
+ if not self._pre_bottleneck:
+ raise ValueError('Only applied when pre_bottleneck is set to true.')
+ if len(inputs.shape) != 4:
+ raise ValueError('Expect a rank 4 feature tensor.')
+ if not self._flatten_state and len(state.shape) != 4:
+ raise ValueError('Expect rank 4 state tensor.')
+ if self._flatten_state and len(state.shape) != 2:
+ raise ValueError('Expect rank 2 state tensor when flatten_state is set.')
+
+ with tf.name_scope(None):
+ state = tf.identity(
+ state, name='raw_inputs/init_lstm_h_%d' % (input_index + 1))
+ if self._flatten_state:
+ batch_size = inputs.shape[0]
+ height = inputs.shape[1]
+ width = inputs.shape[2]
+ state = tf.reshape(state, [batch_size, height, width, -1])
+ with tf.variable_scope('conv_lstm_cell', reuse=tf.AUTO_REUSE):
+ state_split = tf.split(state, self._groups, axis=3)
+ with tf.variable_scope('bottleneck_%d' % input_index):
+ bottleneck_out = []
+ for k in range(self._groups):
+ with tf.variable_scope('group_%d' % k):
+ bottleneck_out.append(
+ lstm_utils.quantizable_separable_conv2d(
+ lstm_utils.quantizable_concat(
+ [inputs, state_split[k]],
+ axis=3,
+ is_training=self._is_training,
+ is_quantized=self._is_quantized,
+ scope='quantized_concat'),
+ self.output_size[-1] / self._groups,
+ self._filter_size,
+ is_quantized=self._is_quantized,
+ depth_multiplier=1,
+ activation_fn=tf.nn.relu6,
+ normalizer_fn=None,
+ scope='project'))
+ inputs = lstm_utils.quantizable_concat(
+ bottleneck_out,
+ axis=3,
+ is_training=self._is_training,
+ is_quantized=self._is_quantized,
+ scope='bottleneck_out/quantized_concat')
+ # For exporting inference graph, we only mark the first timestep.
+ with tf.name_scope(None):
+ inputs = tf.identity(
+ inputs, name='raw_outputs/base_endpoint_%d' % (input_index + 1))
+ return inputs
diff --git a/models/research/lstm_object_detection/lstm/lstm_cells_test.py b/models/research/lstm_object_detection/lstm/lstm_cells_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..b296310194dde2a10249c0af266d50ff762ec745
--- /dev/null
+++ b/models/research/lstm_object_detection/lstm/lstm_cells_test.py
@@ -0,0 +1,412 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for lstm_object_detection.lstm.lstm_cells."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow.compat.v1 as tf
+
+from lstm_object_detection.lstm import lstm_cells
+
+
+class BottleneckConvLstmCellsTest(tf.test.TestCase):
+
+ def test_run_lstm_cell(self):
+ filter_size = [3, 3]
+ output_size = [10, 10]
+ num_units = 15
+ state_name = 'lstm_state'
+ batch_size = 4
+ dtype = tf.float32
+ learned_state = False
+
+ inputs = tf.zeros([4, 10, 10, 3], dtype=tf.float32)
+ cell = lstm_cells.BottleneckConvLSTMCell(
+ filter_size=filter_size,
+ output_size=output_size,
+ num_units=num_units)
+ init_state = cell.init_state(
+ state_name, batch_size, dtype, learned_state)
+ output, state_tuple = cell(inputs, init_state)
+ self.assertAllEqual([4, 10, 10, 15], output.shape.as_list())
+ self.assertAllEqual([4, 10, 10, 15], state_tuple[0].shape.as_list())
+ self.assertAllEqual([4, 10, 10, 15], state_tuple[1].shape.as_list())
+
+ def test_run_lstm_cell_with_flattened_state(self):
+ filter_size = [3, 3]
+ output_dim = 10
+ output_size = [output_dim] * 2
+ num_units = 15
+ state_name = 'lstm_state'
+ batch_size = 4
+ dtype = tf.float32
+ learned_state = False
+
+ inputs = tf.zeros([batch_size, output_dim, output_dim, 3], dtype=tf.float32)
+ cell = lstm_cells.BottleneckConvLSTMCell(
+ filter_size=filter_size,
+ output_size=output_size,
+ num_units=num_units,
+ flatten_state=True)
+ init_state = cell.init_state(
+ state_name, batch_size, dtype, learned_state)
+ output, state_tuple = cell(inputs, init_state)
+ self.assertAllEqual([4, 10, 10, 15], output.shape.as_list())
+ self.assertAllEqual([4, 1500], state_tuple[0].shape.as_list())
+ self.assertAllEqual([4, 1500], state_tuple[1].shape.as_list())
+
+ def test_run_lstm_cell_with_output_bottleneck(self):
+ filter_size = [3, 3]
+ output_dim = 10
+ output_size = [output_dim] * 2
+ num_units = 15
+ state_name = 'lstm_state'
+ batch_size = 4
+ dtype = tf.float32
+ learned_state = False
+
+ inputs = tf.zeros([batch_size, output_dim, output_dim, 3], dtype=tf.float32)
+ cell = lstm_cells.BottleneckConvLSTMCell(
+ filter_size=filter_size,
+ output_size=output_size,
+ num_units=num_units,
+ output_bottleneck=True)
+ init_state = cell.init_state(
+ state_name, batch_size, dtype, learned_state)
+ output, state_tuple = cell(inputs, init_state)
+ self.assertAllEqual([4, 10, 10, 30], output.shape.as_list())
+ self.assertAllEqual([4, 10, 10, 15], state_tuple[0].shape.as_list())
+ self.assertAllEqual([4, 10, 10, 15], state_tuple[1].shape.as_list())
+
+ def test_get_init_state(self):
+ filter_size = [3, 3]
+ output_dim = 10
+ output_size = [output_dim] * 2
+ num_units = 15
+ state_name = 'lstm_state'
+ batch_size = 4
+ dtype = tf.float32
+ learned_state = False
+
+ cell = lstm_cells.BottleneckConvLSTMCell(
+ filter_size=filter_size,
+ output_size=output_size,
+ num_units=num_units)
+ init_c, init_h = cell.init_state(
+ state_name, batch_size, dtype, learned_state)
+
+ self.assertEqual(tf.float32, init_c.dtype)
+ self.assertEqual(tf.float32, init_h.dtype)
+ with self.test_session() as sess:
+ init_c_res, init_h_res = sess.run([init_c, init_h])
+ self.assertAllClose(np.zeros((4, 10, 10, 15)), init_c_res)
+ self.assertAllClose(np.zeros((4, 10, 10, 15)), init_h_res)
+
+ def test_get_init_learned_state(self):
+ filter_size = [3, 3]
+ output_size = [10, 10]
+ num_units = 15
+ state_name = 'lstm_state'
+ batch_size = 4
+ dtype = tf.float32
+ learned_state = True
+
+ cell = lstm_cells.BottleneckConvLSTMCell(
+ filter_size=filter_size,
+ output_size=output_size,
+ num_units=num_units)
+ init_c, init_h = cell.init_state(
+ state_name, batch_size, dtype, learned_state)
+
+ self.assertEqual(tf.float32, init_c.dtype)
+ self.assertEqual(tf.float32, init_h.dtype)
+ self.assertAllEqual([4, 10, 10, 15], init_c.shape.as_list())
+ self.assertAllEqual([4, 10, 10, 15], init_h.shape.as_list())
+
+ def test_unroll(self):
+ filter_size = [3, 3]
+ output_size = [10, 10]
+ num_units = 15
+ state_name = 'lstm_state'
+ batch_size = 4
+ dtype = tf.float32
+ unroll = 10
+ learned_state = False
+
+ inputs = tf.zeros([4, 10, 10, 3], dtype=tf.float32)
+ cell = lstm_cells.BottleneckConvLSTMCell(
+ filter_size=filter_size,
+ output_size=output_size,
+ num_units=num_units)
+ state = cell.init_state(
+ state_name, batch_size, dtype, learned_state)
+ for step in range(unroll):
+ output, state = cell(inputs, state)
+ self.assertAllEqual([4, 10, 10, 15], output.shape.as_list())
+ self.assertAllEqual([4, 10, 10, 15], state[0].shape.as_list())
+ self.assertAllEqual([4, 10, 10, 15], state[1].shape.as_list())
+
+ def test_prebottleneck(self):
+ filter_size = [3, 3]
+ output_size = [10, 10]
+ num_units = 15
+ state_name = 'lstm_state'
+ batch_size = 4
+ dtype = tf.float32
+ unroll = 10
+ learned_state = False
+
+ inputs_large = tf.zeros([4, 10, 10, 5], dtype=tf.float32)
+ inputs_small = tf.zeros([4, 10, 10, 3], dtype=tf.float32)
+ cell = lstm_cells.BottleneckConvLSTMCell(
+ filter_size=filter_size,
+ output_size=output_size,
+ num_units=num_units,
+ pre_bottleneck=True)
+ state = cell.init_state(
+ state_name, batch_size, dtype, learned_state)
+ for step in range(unroll):
+ if step % 2 == 0:
+ inputs = cell.pre_bottleneck(inputs_large, state[1], 0)
+ else:
+ inputs = cell.pre_bottleneck(inputs_small, state[1], 1)
+ output, state = cell(inputs, state)
+ self.assertAllEqual([4, 10, 10, 15], output.shape.as_list())
+ self.assertAllEqual([4, 10, 10, 15], state[0].shape.as_list())
+ self.assertAllEqual([4, 10, 10, 15], state[1].shape.as_list())
+
+ def test_flatten_state(self):
+ filter_size = [3, 3]
+ output_size = [10, 10]
+ num_units = 15
+ state_name = 'lstm_state'
+ batch_size = 4
+ dtype = tf.float32
+ unroll = 10
+ learned_state = False
+
+ inputs_large = tf.zeros([4, 10, 10, 5], dtype=tf.float32)
+ inputs_small = tf.zeros([4, 10, 10, 3], dtype=tf.float32)
+ cell = lstm_cells.BottleneckConvLSTMCell(
+ filter_size=filter_size,
+ output_size=output_size,
+ num_units=num_units,
+ pre_bottleneck=True,
+ flatten_state=True)
+ state = cell.init_state(
+ state_name, batch_size, dtype, learned_state)
+ for step in range(unroll):
+ if step % 2 == 0:
+ inputs = cell.pre_bottleneck(inputs_large, state[1], 0)
+ else:
+ inputs = cell.pre_bottleneck(inputs_small, state[1], 1)
+ output, state = cell(inputs, state)
+ with self.test_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ output_result, state_result = sess.run([output, state])
+ self.assertAllEqual((4, 10, 10, 15), output_result.shape)
+ self.assertAllEqual((4, 10*10*15), state_result[0].shape)
+ self.assertAllEqual((4, 10*10*15), state_result[1].shape)
+
+
+class GroupedConvLstmCellsTest(tf.test.TestCase):
+
+ def test_run_lstm_cell(self):
+ filter_size = [3, 3]
+ output_size = [10, 10]
+ num_units = 16
+ state_name = 'lstm_state'
+ batch_size = 4
+ dtype = tf.float32
+ learned_state = False
+
+ inputs = tf.zeros([4, 10, 10, 3], dtype=tf.float32)
+ cell = lstm_cells.GroupedConvLSTMCell(
+ filter_size=filter_size,
+ output_size=output_size,
+ num_units=num_units,
+ is_training=True)
+ init_state = cell.init_state(
+ state_name, batch_size, dtype, learned_state)
+ output, state_tuple = cell(inputs, init_state)
+ self.assertAllEqual([4, 10, 10, 16], output.shape.as_list())
+ self.assertAllEqual([4, 10, 10, 16], state_tuple[0].shape.as_list())
+ self.assertAllEqual([4, 10, 10, 16], state_tuple[1].shape.as_list())
+
+ def test_run_lstm_cell_with_output_bottleneck(self):
+ filter_size = [3, 3]
+ output_dim = 10
+ output_size = [output_dim] * 2
+ num_units = 16
+ state_name = 'lstm_state'
+ batch_size = 4
+ dtype = tf.float32
+ learned_state = False
+
+ inputs = tf.zeros([batch_size, output_dim, output_dim, 3], dtype=tf.float32)
+ cell = lstm_cells.GroupedConvLSTMCell(
+ filter_size=filter_size,
+ output_size=output_size,
+ num_units=num_units,
+ is_training=True,
+ output_bottleneck=True)
+ init_state = cell.init_state(
+ state_name, batch_size, dtype, learned_state)
+ output, state_tuple = cell(inputs, init_state)
+ self.assertAllEqual([4, 10, 10, 32], output.shape.as_list())
+ self.assertAllEqual([4, 10, 10, 16], state_tuple[0].shape.as_list())
+ self.assertAllEqual([4, 10, 10, 16], state_tuple[1].shape.as_list())
+
+ def test_get_init_state(self):
+ filter_size = [3, 3]
+ output_dim = 10
+ output_size = [output_dim] * 2
+ num_units = 16
+ state_name = 'lstm_state'
+ batch_size = 4
+ dtype = tf.float32
+ learned_state = False
+
+ cell = lstm_cells.GroupedConvLSTMCell(
+ filter_size=filter_size,
+ output_size=output_size,
+ num_units=num_units,
+ is_training=True)
+ init_c, init_h = cell.init_state(
+ state_name, batch_size, dtype, learned_state)
+
+ self.assertEqual(tf.float32, init_c.dtype)
+ self.assertEqual(tf.float32, init_h.dtype)
+ with self.test_session() as sess:
+ init_c_res, init_h_res = sess.run([init_c, init_h])
+ self.assertAllClose(np.zeros((4, 10, 10, 16)), init_c_res)
+ self.assertAllClose(np.zeros((4, 10, 10, 16)), init_h_res)
+
+ def test_get_init_learned_state(self):
+ filter_size = [3, 3]
+ output_size = [10, 10]
+ num_units = 16
+ state_name = 'lstm_state'
+ batch_size = 4
+ dtype = tf.float32
+ learned_state = True
+
+ cell = lstm_cells.GroupedConvLSTMCell(
+ filter_size=filter_size,
+ output_size=output_size,
+ num_units=num_units,
+ is_training=True)
+ init_c, init_h = cell.init_state(
+ state_name, batch_size, dtype, learned_state)
+
+ self.assertEqual(tf.float32, init_c.dtype)
+ self.assertEqual(tf.float32, init_h.dtype)
+ self.assertAllEqual([4, 10, 10, 16], init_c.shape.as_list())
+ self.assertAllEqual([4, 10, 10, 16], init_h.shape.as_list())
+
+ def test_unroll(self):
+ filter_size = [3, 3]
+ output_size = [10, 10]
+ num_units = 16
+ state_name = 'lstm_state'
+ batch_size = 4
+ dtype = tf.float32
+ unroll = 10
+ learned_state = False
+
+ inputs = tf.zeros([4, 10, 10, 3], dtype=tf.float32)
+ cell = lstm_cells.GroupedConvLSTMCell(
+ filter_size=filter_size,
+ output_size=output_size,
+ num_units=num_units,
+ is_training=True)
+ state = cell.init_state(
+ state_name, batch_size, dtype, learned_state)
+ for step in range(unroll):
+ output, state = cell(inputs, state)
+ self.assertAllEqual([4, 10, 10, 16], output.shape.as_list())
+ self.assertAllEqual([4, 10, 10, 16], state[0].shape.as_list())
+ self.assertAllEqual([4, 10, 10, 16], state[1].shape.as_list())
+
+ def test_prebottleneck(self):
+ filter_size = [3, 3]
+ output_size = [10, 10]
+ num_units = 16
+ state_name = 'lstm_state'
+ batch_size = 4
+ dtype = tf.float32
+ unroll = 10
+ learned_state = False
+
+ inputs_large = tf.zeros([4, 10, 10, 5], dtype=tf.float32)
+ inputs_small = tf.zeros([4, 10, 10, 3], dtype=tf.float32)
+ cell = lstm_cells.GroupedConvLSTMCell(
+ filter_size=filter_size,
+ output_size=output_size,
+ num_units=num_units,
+ is_training=True,
+ pre_bottleneck=True)
+ state = cell.init_state(
+ state_name, batch_size, dtype, learned_state)
+ for step in range(unroll):
+ if step % 2 == 0:
+ inputs = cell.pre_bottleneck(inputs_large, state[1], 0)
+ else:
+ inputs = cell.pre_bottleneck(inputs_small, state[1], 1)
+ output, state = cell(inputs, state)
+ self.assertAllEqual([4, 10, 10, 16], output.shape.as_list())
+ self.assertAllEqual([4, 10, 10, 16], state[0].shape.as_list())
+ self.assertAllEqual([4, 10, 10, 16], state[1].shape.as_list())
+
+ def test_flatten_state(self):
+ filter_size = [3, 3]
+ output_size = [10, 10]
+ num_units = 16
+ state_name = 'lstm_state'
+ batch_size = 4
+ dtype = tf.float32
+ unroll = 10
+ learned_state = False
+
+ inputs_large = tf.zeros([4, 10, 10, 5], dtype=tf.float32)
+ inputs_small = tf.zeros([4, 10, 10, 3], dtype=tf.float32)
+ cell = lstm_cells.GroupedConvLSTMCell(
+ filter_size=filter_size,
+ output_size=output_size,
+ num_units=num_units,
+ is_training=True,
+ pre_bottleneck=True,
+ flatten_state=True)
+ state = cell.init_state(
+ state_name, batch_size, dtype, learned_state)
+ for step in range(unroll):
+ if step % 2 == 0:
+ inputs = cell.pre_bottleneck(inputs_large, state[1], 0)
+ else:
+ inputs = cell.pre_bottleneck(inputs_small, state[1], 1)
+ output, state = cell(inputs, state)
+ with self.test_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ output_result, state_result = sess.run([output, state])
+ self.assertAllEqual((4, 10, 10, 16), output_result.shape)
+ self.assertAllEqual((4, 10*10*16), state_result[0].shape)
+ self.assertAllEqual((4, 10*10*16), state_result[1].shape)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/lstm_object_detection/lstm/rnn_decoder.py b/models/research/lstm_object_detection/lstm/rnn_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..185ca130396fa8687ba9359f91366b64d16d0255
--- /dev/null
+++ b/models/research/lstm_object_detection/lstm/rnn_decoder.py
@@ -0,0 +1,269 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Custom RNN decoder."""
+
+import tensorflow.compat.v1 as tf
+import lstm_object_detection.lstm.utils as lstm_utils
+
+
+class _NoVariableScope(object):
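+  """No-op context manager used when no variable scope name is supplied."""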
+
+ def __enter__(self):
+ return
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ return False
+
+
+def rnn_decoder(decoder_inputs,
+ initial_state,
+ cell,
+ loop_function=None,
+ scope=None):
+ """RNN decoder for the LSTM-SSD model.
+
+ This decoder returns a list of all states, rather than only the final state.
+ Args:
+    decoder_inputs: A list of 4D Tensors with shape
+      [batch_size, height, width, input_size].
+ initial_state: 2D Tensor with shape [batch_size x cell.state_size].
+ cell: rnn_cell.RNNCell defining the cell function and size.
+ loop_function: If not None, this function will be applied to the i-th output
+ in order to generate the i+1-st input, and decoder_inputs will be ignored,
+ except for the first element ("GO" symbol). This can be used for decoding,
+ but also for training to emulate http://arxiv.org/abs/1506.03099.
+ Signature -- loop_function(prev, i) = next
+ * prev is a 2D Tensor of shape [batch_size x output_size],
+ * i is an integer, the step number (when advanced control is needed),
+ * next is a 2D Tensor of shape [batch_size x input_size].
+ scope: optional VariableScope for the created subgraph.
+ Returns:
+ A tuple of the form (outputs, state), where:
+      outputs: A list of the same length as decoder_inputs of 4D Tensors with
+        shape [batch_size, height, width, output_size] containing generated
+        outputs.
+ states: A list of the same length as decoder_inputs of the state of each
+ cell at each time-step. It is a 2D Tensor of shape
+ [batch_size x cell.state_size].
+ """
+ with tf.variable_scope(scope) if scope else _NoVariableScope():
+ state_tuple = initial_state
+ outputs = []
+ states = []
+ prev = None
+ for local_step, decoder_input in enumerate(decoder_inputs):
+ if loop_function is not None and prev is not None:
+ with tf.variable_scope('loop_function', reuse=True):
+ decoder_input = loop_function(prev, local_step)
+ output, state_tuple = cell(decoder_input, state_tuple)
+ outputs.append(output)
+ states.append(state_tuple)
+ if loop_function is not None:
+ prev = output
+ return outputs, states
+
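+
+# --- Illustrative sketch (not part of the upstream module). ----------------
+# A minimal, hypothetical usage of `rnn_decoder`, assuming the
+# BottleneckConvLSTMCell configuration used in lstm_cells_test.py: the same
+# feature map is fed at every timestep and per-step outputs/states collected.
+def _example_rnn_decoder_unroll(unroll_length=3):
+  """Sketch: unroll a convolutional LSTM cell with rnn_decoder."""
+  from lstm_object_detection.lstm import lstm_cells
+  cell = lstm_cells.BottleneckConvLSTMCell(
+      filter_size=[3, 3], output_size=[10, 10], num_units=15)
+  init_state = cell.init_state('lstm_state', 4, tf.float32, False)
+  inputs = tf.zeros([4, 10, 10, 3], dtype=tf.float32)
+  # Both returned lists have one entry per timestep.
+  return rnn_decoder([inputs] * unroll_length, init_state, cell)
+
+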
+def multi_input_rnn_decoder(decoder_inputs,
+ initial_state,
+ cell,
+ sequence_step,
+ selection_strategy='RANDOM',
+ is_training=None,
+ is_quantized=False,
+ preprocess_fn_list=None,
+ pre_bottleneck=False,
+ flatten_state=False,
+ scope=None):
+ """RNN decoder for the Interleaved LSTM-SSD model.
+
+ This decoder takes multiple sequences of inputs and selects the input to feed
+ to the rnn at each timestep using its selection_strategy, which can be random,
+ learned, or deterministic.
+ This decoder returns a list of all states, rather than only the final state.
+ Args:
+ decoder_inputs: A list of lists of 2D Tensors [batch_size x input_size].
+ initial_state: 2D Tensor with shape [batch_size x cell.state_size].
+ cell: rnn_cell.RNNCell defining the cell function and size.
+ sequence_step: Tensor [batch_size] of the step number of the first elements
+ in the sequence.
+ selection_strategy: Method for picking the decoder_input to use at each
+      timestep. Must be 'RANDOM', or 'SKIPX' for an integer X, where X is the
+      number of times to use the second input before using the first.
+ is_training: boolean, whether the network is training. When using learned
+ selection, attempts exploration if training.
+ is_quantized: flag to enable/disable quantization mode.
+ preprocess_fn_list: List of functions accepting two tensor arguments: one
+ timestep of decoder_inputs and the lstm state. If not None,
+ decoder_inputs[i] will be updated with preprocess_fn[i] at the start of
+ each timestep.
+ pre_bottleneck: if True, use separate bottleneck weights for each sequence.
+ Useful when input sequences have differing numbers of channels. Final
+ bottlenecks will have the same dimension.
+ flatten_state: Whether the LSTM state is flattened.
+ scope: optional VariableScope for the created subgraph.
+ Returns:
+ A tuple of the form (outputs, state), where:
+ outputs: A list of the same length as decoder_inputs of 2D Tensors with
+ shape [batch_size x output_size] containing generated outputs.
+ states: A list of the same length as decoder_inputs of the state of each
+ cell at each time-step. It is a 2D Tensor of shape
+ [batch_size x cell.state_size].
+ Raises:
+ ValueError: If selection_strategy is not recognized or unexpected unroll
+ length.
+ """
+ if flatten_state and len(decoder_inputs[0]) > 1:
+ raise ValueError('In export mode, unroll length should not be more than 1')
+ with tf.variable_scope(scope) if scope else _NoVariableScope():
+ state_tuple = initial_state
+ outputs = []
+ states = []
+ batch_size = decoder_inputs[0][0].shape[0].value
+ num_sequences = len(decoder_inputs)
+ sequence_length = len(decoder_inputs[0])
+
+ for local_step in range(sequence_length):
+ for sequence_index in range(num_sequences):
+ if preprocess_fn_list is not None:
+ decoder_inputs[sequence_index][local_step] = (
+ preprocess_fn_list[sequence_index](
+ decoder_inputs[sequence_index][local_step], state_tuple[0]))
+ if pre_bottleneck:
+ decoder_inputs[sequence_index][local_step] = cell.pre_bottleneck(
+ inputs=decoder_inputs[sequence_index][local_step],
+ state=state_tuple[1],
+ input_index=sequence_index)
+
+ action = generate_action(selection_strategy, local_step, sequence_step,
+ [batch_size, 1, 1, 1])
+ inputs, _ = (
+ select_inputs(decoder_inputs, action, local_step, is_training,
+ is_quantized))
+ # Mark base network endpoints under raw_inputs/
+ with tf.name_scope(None):
+ inputs = tf.identity(inputs, 'raw_inputs/base_endpoint')
+ output, state_tuple_out = cell(inputs, state_tuple)
+ state_tuple = select_state(state_tuple, state_tuple_out, action)
+
+ outputs.append(output)
+ states.append(state_tuple)
+ return outputs, states
+
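+# Illustrative note on `multi_input_rnn_decoder` above (not part of the
+# upstream module): `decoder_inputs` is indexed as
+# decoder_inputs[sequence_index][timestep]; e.g. with two feature extractors
+# and an unroll length of 3 it looks like
+#   [[large_0, large_1, large_2], [small_0, small_1, small_2]],
+# and `selection_strategy` decides which sequence feeds the LSTM at each step.
+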
+
+def generate_action(selection_strategy, local_step, sequence_step,
+ action_shape):
+ """Generate current (binary) action based on selection strategy.
+
+ Args:
+ selection_strategy: Method for picking the decoder_input to use at each
+      timestep. Must be 'RANDOM', or 'SKIPX' for an integer X, where X is the
+      number of times to use the second input before using the first.
+ local_step: Tensor [batch_size] of the step number within the current
+ unrolled batch.
+ sequence_step: Tensor [batch_size] of the step number of the first elements
+ in the sequence.
+ action_shape: The shape of action tensor to be generated.
+
+ Returns:
+ A tensor of shape action_shape, each element is an individual action.
+
+ Raises:
+ ValueError: if selection_strategy is not supported or if 'SKIP' is not
+ followed by numerics.
+ """
+ if selection_strategy.startswith('RANDOM'):
+ action = tf.random.uniform(action_shape, maxval=2, dtype=tf.int32)
+ action = tf.minimum(action, 1)
+
+ # First step always runs large network.
+ if local_step == 0 and sequence_step is not None:
+ action *= tf.minimum(
+ tf.reshape(tf.cast(sequence_step, tf.int32), action_shape), 1)
+ elif selection_strategy.startswith('SKIP'):
+ inter_count = int(selection_strategy[4:])
+ if local_step % (inter_count + 1) == 0:
+ action = tf.zeros(action_shape)
+ else:
+ action = tf.ones(action_shape)
+ else:
+ raise ValueError('Selection strategy %s not recognized' %
+ selection_strategy)
+ return tf.cast(action, tf.int32)
+
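+
+# --- Illustrative sketch (not part of the upstream module). ----------------
+# Under a 'SKIP2' strategy the large model runs on every third frame (the key
+# frames) and the small model on the frames in between; for an unroll length
+# of 6 the per-step actions are [0, 1, 1, 0, 1, 1].
+def _example_skip_schedule(unroll_length=6, skip=2):
+  """Sketch: per-step actions produced by a 'SKIP<skip>' strategy."""
+  # Action 0 selects the large model (and updates the LSTM state);
+  # action 1 selects the small model (and keeps the previous state).
+  return [0 if step % (skip + 1) == 0 else 1 for step in range(unroll_length)]
+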
+
+def select_inputs(decoder_inputs, action, local_step, is_training, is_quantized,
+ get_alt_inputs=False):
+ """Selects sequence from decoder_inputs based on 1D actions.
+
+ Given multiple input batches, creates a single output batch by
+ selecting from the action[i]-ith input for the i-th batch element.
+
+ Args:
+ decoder_inputs: A 2-D list of tensor inputs.
+ action: A tensor of shape [batch_size]. Each element corresponds to an index
+ of decoder_inputs to choose.
+ local_step: The current timestep.
+ is_training: boolean, whether the network is training. When using learned
+ selection, attempts exploration if training.
+ is_quantized: flag to enable/disable quantization mode.
+ get_alt_inputs: Whether the non-chosen inputs should also be returned.
+
+ Returns:
+ The constructed output. Also outputs the elements that were not chosen
+ if get_alt_inputs is True, otherwise None.
+
+ Raises:
+ ValueError: if the decoder inputs contains other than two sequences.
+ """
+ num_seqs = len(decoder_inputs)
+  if num_seqs != 2:
+ raise ValueError('Currently only supports two sets of inputs.')
+ stacked_inputs = tf.stack(
+ [decoder_inputs[seq_index][local_step] for seq_index in range(num_seqs)],
+ axis=-1)
+ action_index = tf.one_hot(action, num_seqs)
+ selected_inputs = (
+ lstm_utils.quantize_op(stacked_inputs * action_index, is_training,
+ is_quantized, scope='quant_selected_inputs'))
+ inputs = tf.reduce_sum(selected_inputs, axis=-1)
+ inputs_alt = None
+ # Only works for 2 models.
+ if get_alt_inputs:
+ # Reverse of action_index.
+ action_index_alt = tf.one_hot(action, num_seqs, on_value=0.0, off_value=1.0)
+ selected_inputs = (
+ lstm_utils.quantize_op(stacked_inputs * action_index_alt, is_training,
+ is_quantized, scope='quant_selected_inputs_alt'))
+ inputs_alt = tf.reduce_sum(selected_inputs, axis=-1)
+ return inputs, inputs_alt
+
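+
+# --- Illustrative sketch (not part of the upstream module). ----------------
+# `select_inputs` realises a per-example switch: stacking the two candidate
+# batches on a trailing axis, multiplying by a one-hot encoding of the action
+# and summing that axis away picks inputs[action] for every batch element
+# while staying differentiable and quantization friendly.
+def _example_select_inputs_switch():
+  """Sketch: one-hot selection between two candidate input batches."""
+  large = tf.ones([2, 1, 1, 3])    # stand-in features from the large model
+  small = tf.zeros([2, 1, 1, 3])   # stand-in features from the small model
+  action = tf.constant([[[[0]]], [[[1]]]])  # example 0 -> large, 1 -> small
+  inputs, _ = select_inputs([[large], [small]], action, local_step=0,
+                            is_training=False, is_quantized=False)
+  return inputs  # row 0 is all ones, row 1 is all zeros
+
+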
+def select_state(previous_state, new_state, action):
+ """Select state given action.
+
+ Currently only supports binary action. If action is 0, it means the state is
+ generated from the large model, and thus we will update the state. Otherwise,
+  if the action is 1, it means the state is generated from the small model, and
+  in the interleaved model we skip this state update.
+
+ Args:
+ previous_state: A state tuple representing state from previous step.
+ new_state: A state tuple representing newly computed state.
+ action: A tensor the same shape as state.
+
+ Returns:
+ A state tuple selected based on the given action.
+ """
+ action = tf.cast(action, tf.float32)
+ state_c = previous_state[0] * action + new_state[0] * (1 - action)
+ state_h = previous_state[1] * action + new_state[1] * (1 - action)
+ return (state_c, state_h)
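+
+
+# --- Illustrative sketch (not part of the upstream module). ----------------
+# With a binary action, select_state is a simple gate: action 0 adopts the
+# freshly computed state, action 1 carries the previous state forward, which
+# is how the interleaved model skips LSTM updates on small-model frames.
+def _example_select_state_gate():
+  """Sketch: action 1 keeps the previous state, action 0 adopts the new one."""
+  previous = (tf.zeros([2, 1, 1, 4]), tf.zeros([2, 1, 1, 4]))
+  new = (tf.ones([2, 1, 1, 4]), tf.ones([2, 1, 1, 4]))
+  keep = tf.ones([2, 1, 1, 1], dtype=tf.int32)  # action 1 on every example
+  return select_state(previous, new, keep)      # both states stay all zeros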
diff --git a/models/research/lstm_object_detection/lstm/rnn_decoder_test.py b/models/research/lstm_object_detection/lstm/rnn_decoder_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..480694f6fde57332b2f72357d5d6903ec7a12f87
--- /dev/null
+++ b/models/research/lstm_object_detection/lstm/rnn_decoder_test.py
@@ -0,0 +1,306 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for lstm_object_detection.lstm.rnn_decoder."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow.compat.v1 as tf
+
+from tensorflow.contrib import layers as contrib_layers
+from tensorflow.contrib import rnn as contrib_rnn
+from lstm_object_detection.lstm import rnn_decoder
+
+
+class MockRnnCell(contrib_rnn.RNNCell):
+
+ def __init__(self, input_size, num_units):
+ self._input_size = input_size
+ self._num_units = num_units
+ self._filter_size = [3, 3]
+
+ def __call__(self, inputs, state_tuple):
+ outputs = tf.concat([inputs, state_tuple[0]], axis=3)
+ new_state_tuple = (tf.multiply(state_tuple[0], 2), state_tuple[1])
+ return outputs, new_state_tuple
+
+ def state_size(self):
+ return self._num_units
+
+ def output_size(self):
+ return self._input_size + self._num_units
+
+ def pre_bottleneck(self, inputs, state, input_index):
+ with tf.variable_scope('bottleneck_%d' % input_index, reuse=tf.AUTO_REUSE):
+ inputs = contrib_layers.separable_conv2d(
+ tf.concat([inputs, state], 3),
+ self._input_size,
+ self._filter_size,
+ depth_multiplier=1,
+ activation_fn=tf.nn.relu6,
+ normalizer_fn=None)
+ return inputs
+
+
+class RnnDecoderTest(tf.test.TestCase):
+
+ def test_rnn_decoder_single_unroll(self):
+ batch_size = 2
+ num_unroll = 1
+ num_units = 64
+ width = 8
+ height = 10
+ input_channels = 128
+
+ initial_state = tf.random_normal((batch_size, width, height, num_units))
+ inputs = tf.random_normal([batch_size, width, height, input_channels])
+
+ rnn_cell = MockRnnCell(input_channels, num_units)
+ outputs, states = rnn_decoder.rnn_decoder(
+ decoder_inputs=[inputs] * num_unroll,
+ initial_state=(initial_state, initial_state),
+ cell=rnn_cell)
+
+ self.assertEqual(len(outputs), num_unroll)
+ self.assertEqual(len(states), num_unroll)
+ with tf.Session() as sess:
+ sess.run(tf.global_variables_initializer())
+ results = sess.run((outputs, states, inputs, initial_state))
+ outputs_results = results[0]
+ states_results = results[1]
+ inputs_results = results[2]
+ initial_states_results = results[3]
+ self.assertEqual(outputs_results[0].shape,
+ (batch_size, width, height, input_channels + num_units))
+ self.assertAllEqual(
+ outputs_results[0],
+ np.concatenate((inputs_results, initial_states_results), axis=3))
+ self.assertEqual(states_results[0][0].shape,
+ (batch_size, width, height, num_units))
+ self.assertEqual(states_results[0][1].shape,
+ (batch_size, width, height, num_units))
+ self.assertAllEqual(states_results[0][0],
+ np.multiply(initial_states_results, 2.0))
+ self.assertAllEqual(states_results[0][1], initial_states_results)
+
+ def test_rnn_decoder_multiple_unroll(self):
+ batch_size = 2
+ num_unroll = 3
+ num_units = 64
+ width = 8
+ height = 10
+ input_channels = 128
+
+ initial_state = tf.random_normal((batch_size, width, height, num_units))
+ inputs = tf.random_normal([batch_size, width, height, input_channels])
+
+ rnn_cell = MockRnnCell(input_channels, num_units)
+ outputs, states = rnn_decoder.rnn_decoder(
+ decoder_inputs=[inputs] * num_unroll,
+ initial_state=(initial_state, initial_state),
+ cell=rnn_cell)
+
+ self.assertEqual(len(outputs), num_unroll)
+ self.assertEqual(len(states), num_unroll)
+ with tf.Session() as sess:
+ sess.run(tf.global_variables_initializer())
+ results = sess.run((outputs, states, inputs, initial_state))
+ outputs_results = results[0]
+ states_results = results[1]
+ inputs_results = results[2]
+ initial_states_results = results[3]
+ for i in range(num_unroll):
+ previous_state = ([initial_states_results, initial_states_results]
+ if i == 0 else states_results[i - 1])
+ self.assertEqual(
+ outputs_results[i].shape,
+ (batch_size, width, height, input_channels + num_units))
+ self.assertAllEqual(
+ outputs_results[i],
+ np.concatenate((inputs_results, previous_state[0]), axis=3))
+ self.assertEqual(states_results[i][0].shape,
+ (batch_size, width, height, num_units))
+ self.assertEqual(states_results[i][1].shape,
+ (batch_size, width, height, num_units))
+ self.assertAllEqual(states_results[i][0],
+ np.multiply(previous_state[0], 2.0))
+ self.assertAllEqual(states_results[i][1], previous_state[1])
+
+
+class MultiInputRnnDecoderTest(tf.test.TestCase):
+
+ def test_rnn_decoder_single_unroll(self):
+ batch_size = 2
+ num_unroll = 1
+ num_units = 12
+ width = 8
+ height = 10
+ input_channels_large = 24
+ input_channels_small = 12
+ bottleneck_channels = 20
+
+ initial_state_c = tf.random_normal((batch_size, width, height, num_units))
+ initial_state_h = tf.random_normal((batch_size, width, height, num_units))
+ initial_state = (initial_state_c, initial_state_h)
+ inputs_large = tf.random_normal(
+ [batch_size, width, height, input_channels_large])
+ inputs_small = tf.random_normal(
+ [batch_size, width, height, input_channels_small])
+
+ rnn_cell = MockRnnCell(bottleneck_channels, num_units)
+ outputs, states = rnn_decoder.multi_input_rnn_decoder(
+ decoder_inputs=[[inputs_large] * num_unroll,
+ [inputs_small] * num_unroll],
+ initial_state=initial_state,
+ cell=rnn_cell,
+ sequence_step=tf.zeros([batch_size]),
+ pre_bottleneck=True)
+
+ self.assertEqual(len(outputs), num_unroll)
+ self.assertEqual(len(states), num_unroll)
+ with tf.Session() as sess:
+ sess.run(tf.global_variables_initializer())
+ results = sess.run(
+ (outputs, states, inputs_large, inputs_small, initial_state))
+ outputs_results = results[0]
+ states_results = results[1]
+ initial_states_results = results[4]
+ self.assertEqual(
+ outputs_results[0].shape,
+ (batch_size, width, height, bottleneck_channels + num_units))
+ self.assertEqual(states_results[0][0].shape,
+ (batch_size, width, height, num_units))
+ self.assertEqual(states_results[0][1].shape,
+ (batch_size, width, height, num_units))
+ # The first step should always update state.
+ self.assertAllEqual(states_results[0][0],
+ np.multiply(initial_states_results[0], 2))
+ self.assertAllEqual(states_results[0][1], initial_states_results[1])
+
+ def test_rnn_decoder_multiple_unroll(self):
+ batch_size = 2
+ num_unroll = 3
+ num_units = 12
+ width = 8
+ height = 10
+ input_channels_large = 24
+ input_channels_small = 12
+ bottleneck_channels = 20
+
+ initial_state_c = tf.random_normal((batch_size, width, height, num_units))
+ initial_state_h = tf.random_normal((batch_size, width, height, num_units))
+ initial_state = (initial_state_c, initial_state_h)
+ inputs_large = tf.random_normal(
+ [batch_size, width, height, input_channels_large])
+ inputs_small = tf.random_normal(
+ [batch_size, width, height, input_channels_small])
+
+ rnn_cell = MockRnnCell(bottleneck_channels, num_units)
+ outputs, states = rnn_decoder.multi_input_rnn_decoder(
+ decoder_inputs=[[inputs_large] * num_unroll,
+ [inputs_small] * num_unroll],
+ initial_state=initial_state,
+ cell=rnn_cell,
+ sequence_step=tf.zeros([batch_size]),
+ pre_bottleneck=True)
+
+ self.assertEqual(len(outputs), num_unroll)
+ self.assertEqual(len(states), num_unroll)
+ with tf.Session() as sess:
+ sess.run(tf.global_variables_initializer())
+ results = sess.run(
+ (outputs, states, inputs_large, inputs_small, initial_state))
+ outputs_results = results[0]
+ states_results = results[1]
+ initial_states_results = results[4]
+
+ # The first step should always update state.
+ self.assertAllEqual(states_results[0][0],
+ np.multiply(initial_states_results[0], 2))
+ self.assertAllEqual(states_results[0][1], initial_states_results[1])
+ for i in range(num_unroll):
+ self.assertEqual(
+ outputs_results[i].shape,
+ (batch_size, width, height, bottleneck_channels + num_units))
+ self.assertEqual(states_results[i][0].shape,
+ (batch_size, width, height, num_units))
+ self.assertEqual(states_results[i][1].shape,
+ (batch_size, width, height, num_units))
+
+ def test_rnn_decoder_multiple_unroll_with_skip(self):
+ batch_size = 2
+ num_unroll = 5
+ num_units = 12
+ width = 8
+ height = 10
+ input_channels_large = 24
+ input_channels_small = 12
+ bottleneck_channels = 20
+ skip = 2
+
+ initial_state_c = tf.random_normal((batch_size, width, height, num_units))
+ initial_state_h = tf.random_normal((batch_size, width, height, num_units))
+ initial_state = (initial_state_c, initial_state_h)
+ inputs_large = tf.random_normal(
+ [batch_size, width, height, input_channels_large])
+ inputs_small = tf.random_normal(
+ [batch_size, width, height, input_channels_small])
+
+ rnn_cell = MockRnnCell(bottleneck_channels, num_units)
+ outputs, states = rnn_decoder.multi_input_rnn_decoder(
+ decoder_inputs=[[inputs_large] * num_unroll,
+ [inputs_small] * num_unroll],
+ initial_state=initial_state,
+ cell=rnn_cell,
+ sequence_step=tf.zeros([batch_size]),
+ pre_bottleneck=True,
+ selection_strategy='SKIP%d' % skip)
+
+ self.assertEqual(len(outputs), num_unroll)
+ self.assertEqual(len(states), num_unroll)
+ with tf.Session() as sess:
+ sess.run(tf.global_variables_initializer())
+ results = sess.run(
+ (outputs, states, inputs_large, inputs_small, initial_state))
+ outputs_results = results[0]
+ states_results = results[1]
+ initial_states_results = results[4]
+
+ for i in range(num_unroll):
+ self.assertEqual(
+ outputs_results[i].shape,
+ (batch_size, width, height, bottleneck_channels + num_units))
+ self.assertEqual(states_results[i][0].shape,
+ (batch_size, width, height, num_units))
+ self.assertEqual(states_results[i][1].shape,
+ (batch_size, width, height, num_units))
+
+ previous_state = (
+ initial_states_results if i == 0 else states_results[i - 1])
+ # State only updates during key frames
+ if i % (skip + 1) == 0:
+ self.assertAllEqual(states_results[i][0],
+ np.multiply(previous_state[0], 2))
+ self.assertAllEqual(states_results[i][1], previous_state[1])
+ else:
+ self.assertAllEqual(states_results[i][0], previous_state[0])
+ self.assertAllEqual(states_results[i][1], previous_state[1])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/lstm_object_detection/lstm/utils.py b/models/research/lstm_object_detection/lstm/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c87db4bb208ece5102df327e5487fbffb2fe2ce
--- /dev/null
+++ b/models/research/lstm_object_detection/lstm/utils.py
@@ -0,0 +1,257 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Quantization related ops for LSTM."""
+
+from __future__ import absolute_import
+from __future__ import division
+
+import tensorflow.compat.v1 as tf
+from tensorflow.contrib import framework as contrib_framework
+from tensorflow.contrib import layers as contrib_layers
+from tensorflow.python.training import moving_averages
+
+
+def _quant_var(
+ name,
+ initializer_val,
+ vars_collection=tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
+):
+ """Create an var for storing the min/max quantization range."""
+ return contrib_framework.model_variable(
+ name,
+ shape=[],
+ initializer=tf.constant_initializer(initializer_val),
+ collections=[vars_collection],
+ trainable=False)
+
+
+def quantizable_concat(inputs,
+ axis,
+ is_training,
+ is_quantized=True,
+ default_min=0,
+ default_max=6,
+ ema_decay=0.999,
+ scope='quantized_concat'):
+ """Concat replacement with quantization option.
+
+ Allows concat inputs to share the same min max ranges,
+ from experimental/gazelle/synthetic/model/tpu/utils.py.
+
+ Args:
+ inputs: list of tensors to concatenate.
+ axis: dimension along which to concatenate.
+ is_training: true if the graph is a training graph.
+ is_quantized: flag to enable/disable quantization.
+ default_min: default min value for fake quant op.
+ default_max: default max value for fake quant op.
+ ema_decay: the moving average decay for the quantization variables.
+ scope: Optional scope for variable_scope.
+
+ Returns:
+ Tensor resulting from concatenation of input tensors
+ """
+ if is_quantized:
+ with tf.variable_scope(scope):
+ tf.logging.info('inputs: {}'.format(inputs))
+ for t in inputs:
+ tf.logging.info(t)
+
+ min_var = _quant_var('min', default_min)
+ max_var = _quant_var('max', default_max)
+ if not is_training:
+ # If we are building an eval graph just use the values in the variables.
+ quant_inputs = [
+ tf.fake_quant_with_min_max_vars(t, min_var, max_var) for t in inputs
+ ]
+ tf.logging.info('min_val: {}'.format(min_var))
+ tf.logging.info('max_val: {}'.format(max_var))
+ else:
+ concat_tensors = tf.concat(inputs, axis=axis)
+ tf.logging.info('concat_tensors: {}'.format(concat_tensors))
+ # TFLite requires that 0.0 is always in the [min; max] range.
+ range_min = tf.minimum(
+ tf.reduce_min(concat_tensors), 0.0, name='SafeQuantRangeMin')
+ range_max = tf.maximum(
+ tf.reduce_max(concat_tensors), 0.0, name='SafeQuantRangeMax')
+        # Otherwise we need to keep track of the moving averages of the min
+        # and max of the elements of the input tensor.
+ min_val = moving_averages.assign_moving_average(
+ min_var,
+ range_min,
+ ema_decay,
+ name='AssignMinEma')
+ max_val = moving_averages.assign_moving_average(
+ max_var,
+ range_max,
+ ema_decay,
+ name='AssignMaxEma')
+ tf.logging.info('min_val: {}'.format(min_val))
+ tf.logging.info('max_val: {}'.format(max_val))
+ quant_inputs = [
+ tf.fake_quant_with_min_max_vars(t, min_val, max_val) for t in inputs
+ ]
+ tf.logging.info('quant_inputs: {}'.format(quant_inputs))
+ outputs = tf.concat(quant_inputs, axis=axis)
+ tf.logging.info('outputs: {}'.format(outputs))
+ else:
+ outputs = tf.concat(inputs, axis=axis)
+ return outputs
+
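+
+# --- Illustrative sketch (not part of the upstream module). ----------------
+# Typical use: concatenating two feature maps so both share one fake-quant
+# range. In a training graph, EMAs of the observed min/max are maintained; in
+# an eval graph (is_training=False) the stored range variables are used as-is.
+def _example_quantizable_concat():
+  """Sketch: quantization-friendly concat along the channel axis."""
+  features_a = tf.zeros([1, 8, 8, 16])
+  features_b = tf.ones([1, 8, 8, 16])
+  return quantizable_concat([features_a, features_b], axis=3,
+                            is_training=True, scope='example_concat')
+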
+
+def quantizable_separable_conv2d(inputs,
+ num_outputs,
+ kernel_size,
+ is_quantized=True,
+ depth_multiplier=1,
+ stride=1,
+ activation_fn=tf.nn.relu6,
+ normalizer_fn=None,
+ weights_initializer=None,
+ pointwise_initializer=None,
+ scope=None):
+ """Quantization friendly backward compatible separable conv2d.
+
+  This op has the same API as separable_conv2d. The main difference is that an
+  additional BiasAdd is manually inserted after the depthwise conv, so that
+  the depthwise bias does not have a name conflict with the pointwise bias.
+  The motivation for this op is that the quantization script needs a BiasAdd
+  in order to recognize the op, which a native call to separable_conv2d does
+  not create for the depthwise conv.
+
+ Args:
+ inputs: A tensor of size [batch_size, height, width, channels].
+    num_outputs: The number of pointwise convolution output filters. If it is
+      None, then we skip the pointwise convolution stage.
+ kernel_size: A list of length 2: [kernel_height, kernel_width] of the
+ filters. Can be an int if both values are the same.
+ is_quantized: flag to enable/disable quantization.
+ depth_multiplier: The number of depthwise convolution output channels for
+ each input channel. The total number of depthwise convolution output
+ channels will be equal to num_filters_in * depth_multiplier.
+ stride: A list of length 2: [stride_height, stride_width], specifying the
+ depthwise convolution stride. Can be an int if both strides are the same.
+ activation_fn: Activation function. The default value is a ReLU function.
+ Explicitly set it to None to skip it and maintain a linear activation.
+ normalizer_fn: Normalization function to use instead of biases.
+ weights_initializer: An initializer for the depthwise weights.
+ pointwise_initializer: An initializer for the pointwise weights.
+ scope: Optional scope for variable_scope.
+
+ Returns:
+    Tensor resulting from the depthwise separable convolution.
+ """
+ if is_quantized:
+ outputs = contrib_layers.separable_conv2d(
+ inputs,
+ None,
+ kernel_size,
+ depth_multiplier=depth_multiplier,
+ stride=1,
+ activation_fn=None,
+ normalizer_fn=None,
+ biases_initializer=None,
+ weights_initializer=weights_initializer,
+ pointwise_initializer=None,
+ scope=scope)
+ outputs = contrib_layers.bias_add(
+ outputs, trainable=True, scope='%s_bias' % scope)
+ outputs = contrib_layers.conv2d(
+ outputs,
+ num_outputs, [1, 1],
+ activation_fn=activation_fn,
+ stride=stride,
+ normalizer_fn=normalizer_fn,
+ weights_initializer=pointwise_initializer,
+ scope=scope)
+ else:
+ outputs = contrib_layers.separable_conv2d(
+ inputs,
+ num_outputs,
+ kernel_size,
+ depth_multiplier=depth_multiplier,
+ stride=stride,
+ activation_fn=activation_fn,
+ normalizer_fn=normalizer_fn,
+ weights_initializer=weights_initializer,
+ pointwise_initializer=pointwise_initializer,
+ scope=scope)
+ return outputs
+
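+
+# --- Illustrative sketch (not part of the upstream module). ----------------
+# With is_quantized=True the separable conv is split into a bias-free
+# depthwise conv, an explicit BiasAdd, and a 1x1 pointwise conv, so the
+# quantization rewriter can match a BiasAdd for the depthwise stage.
+def _example_quantizable_separable_conv2d():
+  """Sketch: quantization-friendly separable conv over a dummy feature map."""
+  features = tf.zeros([1, 16, 16, 32])
+  return quantizable_separable_conv2d(
+      features, num_outputs=64, kernel_size=[3, 3], scope='ExampleSeparable')
+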
+
+def quantize_op(inputs,
+ is_training=True,
+ is_quantized=True,
+ default_min=0,
+ default_max=6,
+ ema_decay=0.999,
+ scope='quant'):
+ """Inserts a fake quantization op after inputs.
+
+ Args:
+ inputs: A tensor of size [batch_size, height, width, channels].
+ is_training: true if the graph is a training graph.
+ is_quantized: flag to enable/disable quantization.
+ default_min: default min value for fake quant op.
+ default_max: default max value for fake quant op.
+ ema_decay: the moving average decay for the quantization variables.
+ scope: Optional scope for variable_scope.
+
+ Returns:
+ Tensor resulting from quantizing the input tensors.
+ """
+ if not is_quantized:
+ return inputs
+
+ with tf.variable_scope(scope):
+ min_var = _quant_var('min', default_min)
+ max_var = _quant_var('max', default_max)
+ if not is_training:
+ # Just use variables in the checkpoint.
+ return tf.fake_quant_with_min_max_vars(inputs, min_var, max_var)
+
+ # While training, collect EMAs of ranges seen, store in min_var, max_var.
+ # TFLite requires that 0.0 is always in the [min; max] range.
+ range_min = tf.minimum(tf.reduce_min(inputs), 0.0, 'SafeQuantRangeMin')
+ # We set the lower_bound of max_range to prevent range collapse.
+ range_max = tf.maximum(tf.reduce_max(inputs), 1e-5, 'SafeQuantRangeMax')
+ min_val = moving_averages.assign_moving_average(
+ min_var, range_min, ema_decay, name='AssignMinEma')
+ max_val = moving_averages.assign_moving_average(
+ max_var, range_max, ema_decay, name='AssignMaxEma')
+ return tf.fake_quant_with_min_max_vars(inputs, min_val, max_val)
+
+
+def fixed_quantize_op(inputs, is_quantized=True,
+ fixed_min=0.0, fixed_max=6.0, scope='quant'):
+ """Inserts a fake quantization op with fixed range after inputs.
+
+ Args:
+ inputs: A tensor of size [batch_size, height, width, channels].
+ is_quantized: flag to enable/disable quantization.
+ fixed_min: fixed min value for fake quant op.
+ fixed_max: fixed max value for fake quant op.
+ scope: Optional scope for variable_scope.
+
+ Returns:
+ Tensor resulting from quantizing the input tensors.
+ """
+ if not is_quantized:
+ return inputs
+
+ with tf.variable_scope(scope):
+ # Just use fixed quantization range.
+ return tf.fake_quant_with_min_max_args(inputs, fixed_min, fixed_max)
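+
+
+# --- Illustrative sketch (not part of the upstream module). ----------------
+# quantize_op learns its clipping range with moving averages during training,
+# while fixed_quantize_op pins the range to constants (e.g. [0, 6] for ReLU6
+# activations), which is what export-time graphs typically want.
+def _example_quantize_ops():
+  """Sketch: learned-range vs fixed-range fake quantization."""
+  features = tf.zeros([1, 8, 8, 16])
+  learned = quantize_op(features, is_training=True, scope='example_learned')
+  fixed = fixed_quantize_op(features, fixed_min=0.0, fixed_max=6.0,
+                            scope='example_fixed')
+  return learned, fixed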
diff --git a/models/research/lstm_object_detection/lstm/utils_test.py b/models/research/lstm_object_detection/lstm/utils_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5f5bc75db8f7e7be44fc15898598e5179e51236
--- /dev/null
+++ b/models/research/lstm_object_detection/lstm/utils_test.py
@@ -0,0 +1,149 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for lstm_object_detection.lstm.utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.compat.v1 as tf
+from lstm_object_detection.lstm import utils
+
+
+class QuantizableUtilsTest(tf.test.TestCase):
+
+ def test_quantizable_concat_is_training(self):
+ inputs_1 = tf.zeros([4, 10, 10, 1], dtype=tf.float32)
+ inputs_2 = tf.ones([4, 10, 10, 2], dtype=tf.float32)
+ concat_in_train = utils.quantizable_concat([inputs_1, inputs_2],
+ axis=3,
+ is_training=True)
+ self.assertAllEqual([4, 10, 10, 3], concat_in_train.shape.as_list())
+ self._check_min_max_ema(tf.get_default_graph())
+ self._check_min_max_vars(tf.get_default_graph())
+
+ def test_quantizable_concat_inference(self):
+ inputs_1 = tf.zeros([4, 10, 10, 1], dtype=tf.float32)
+ inputs_2 = tf.ones([4, 10, 10, 2], dtype=tf.float32)
+ concat_in_train = utils.quantizable_concat([inputs_1, inputs_2],
+ axis=3,
+ is_training=False)
+ self.assertAllEqual([4, 10, 10, 3], concat_in_train.shape.as_list())
+ self._check_no_min_max_ema(tf.get_default_graph())
+ self._check_min_max_vars(tf.get_default_graph())
+
+ def test_quantizable_concat_not_quantized_is_training(self):
+ inputs_1 = tf.zeros([4, 10, 10, 1], dtype=tf.float32)
+ inputs_2 = tf.ones([4, 10, 10, 2], dtype=tf.float32)
+ concat_in_train = utils.quantizable_concat([inputs_1, inputs_2],
+ axis=3,
+ is_training=True,
+ is_quantized=False)
+ self.assertAllEqual([4, 10, 10, 3], concat_in_train.shape.as_list())
+ self._check_no_min_max_ema(tf.get_default_graph())
+ self._check_no_min_max_vars(tf.get_default_graph())
+
+ def test_quantizable_concat_not_quantized_inference(self):
+ inputs_1 = tf.zeros([4, 10, 10, 1], dtype=tf.float32)
+ inputs_2 = tf.ones([4, 10, 10, 2], dtype=tf.float32)
+ concat_in_train = utils.quantizable_concat([inputs_1, inputs_2],
+ axis=3,
+ is_training=False,
+ is_quantized=False)
+ self.assertAllEqual([4, 10, 10, 3], concat_in_train.shape.as_list())
+ self._check_no_min_max_ema(tf.get_default_graph())
+ self._check_no_min_max_vars(tf.get_default_graph())
+
+ def test_quantize_op_is_training(self):
+ inputs = tf.zeros([4, 10, 10, 128], dtype=tf.float32)
+ outputs = utils.quantize_op(inputs)
+ self.assertAllEqual(inputs.shape.as_list(), outputs.shape.as_list())
+ self._check_min_max_ema(tf.get_default_graph())
+ self._check_min_max_vars(tf.get_default_graph())
+
+ def test_quantize_op_inference(self):
+ inputs = tf.zeros([4, 10, 10, 128], dtype=tf.float32)
+ outputs = utils.quantize_op(inputs, is_training=False)
+ self.assertAllEqual(inputs.shape.as_list(), outputs.shape.as_list())
+ self._check_no_min_max_ema(tf.get_default_graph())
+ self._check_min_max_vars(tf.get_default_graph())
+
+ def test_fixed_quantize_op(self):
+ inputs = tf.zeros([4, 10, 10, 128], dtype=tf.float32)
+ outputs = utils.fixed_quantize_op(inputs)
+ self.assertAllEqual(inputs.shape.as_list(), outputs.shape.as_list())
+ self._check_no_min_max_ema(tf.get_default_graph())
+ self._check_no_min_max_vars(tf.get_default_graph())
+
+ def _check_min_max_vars(self, graph):
+ op_types = [op.type for op in graph.get_operations()]
+ self.assertTrue(
+ any('FakeQuantWithMinMaxVars' in op_type for op_type in op_types))
+
+ def _check_min_max_ema(self, graph):
+ op_names = [op.name for op in graph.get_operations()]
+ self.assertTrue(any('AssignMinEma' in name for name in op_names))
+ self.assertTrue(any('AssignMaxEma' in name for name in op_names))
+ self.assertTrue(any('SafeQuantRangeMin' in name for name in op_names))
+ self.assertTrue(any('SafeQuantRangeMax' in name for name in op_names))
+
+ def _check_no_min_max_vars(self, graph):
+ op_types = [op.type for op in graph.get_operations()]
+ self.assertFalse(
+ any('FakeQuantWithMinMaxVars' in op_type for op_type in op_types))
+
+ def _check_no_min_max_ema(self, graph):
+ op_names = [op.name for op in graph.get_operations()]
+ self.assertFalse(any('AssignMinEma' in name for name in op_names))
+ self.assertFalse(any('AssignMaxEma' in name for name in op_names))
+ self.assertFalse(any('SafeQuantRangeMin' in name for name in op_names))
+ self.assertFalse(any('SafeQuantRangeMax' in name for name in op_names))
+
+
+class QuantizableSeparableConv2dTest(tf.test.TestCase):
+
+ def test_quantizable_separable_conv2d(self):
+ inputs = tf.zeros([4, 10, 10, 128], dtype=tf.float32)
+ num_outputs = 64
+ kernel_size = [3, 3]
+ scope = 'QuantSeparable'
+ outputs = utils.quantizable_separable_conv2d(
+ inputs, num_outputs, kernel_size, scope=scope)
+ self.assertAllEqual([4, 10, 10, num_outputs], outputs.shape.as_list())
+ self._check_depthwise_bias_add(tf.get_default_graph(), scope)
+
+ def test_quantizable_separable_conv2d_not_quantized(self):
+ inputs = tf.zeros([4, 10, 10, 128], dtype=tf.float32)
+ num_outputs = 64
+ kernel_size = [3, 3]
+ scope = 'QuantSeparable'
+ outputs = utils.quantizable_separable_conv2d(
+ inputs, num_outputs, kernel_size, is_quantized=False, scope=scope)
+ self.assertAllEqual([4, 10, 10, num_outputs], outputs.shape.as_list())
+ self._check_no_depthwise_bias_add(tf.get_default_graph(), scope)
+
+ def _check_depthwise_bias_add(self, graph, scope):
+ op_names = [op.name for op in graph.get_operations()]
+ self.assertTrue(
+ any('%s_bias/BiasAdd' % scope in name for name in op_names))
+
+ def _check_no_depthwise_bias_add(self, graph, scope):
+ op_names = [op.name for op in graph.get_operations()]
+ self.assertFalse(
+ any('%s_bias/BiasAdd' % scope in name for name in op_names))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/lstm_object_detection/meta_architectures/__init__.py b/models/research/lstm_object_detection/meta_architectures/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/lstm_object_detection/meta_architectures/lstm_ssd_meta_arch.py b/models/research/lstm_object_detection/meta_architectures/lstm_ssd_meta_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..22edc97ee348df8a4a4ce8b885a4df6a6b891072
--- /dev/null
+++ b/models/research/lstm_object_detection/meta_architectures/lstm_ssd_meta_arch.py
@@ -0,0 +1,463 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""LSTM SSD Meta-architecture definition.
+
+General tensorflow implementation of convolutional Multibox/SSD detection
+models with LSTM states, for use on video data. This implementation supports
+both the regular LSTM-SSD and the interleaved LSTM-SSD frameworks.
+
+See https://arxiv.org/abs/1711.06368 and https://arxiv.org/abs/1903.10172
+for details.
+"""
+import abc
+import re
+import tensorflow.compat.v1 as tf
+
+from object_detection.core import box_list_ops
+from object_detection.core import matcher
+from object_detection.core import standard_fields as fields
+from object_detection.meta_architectures import ssd_meta_arch
+from object_detection.utils import ops
+from object_detection.utils import shape_utils
+
+
+class LSTMSSDMetaArch(ssd_meta_arch.SSDMetaArch):
+ """LSTM Meta-architecture definition."""
+
+ def __init__(self,
+ is_training,
+ anchor_generator,
+ box_predictor,
+ box_coder,
+ feature_extractor,
+ encode_background_as_zeros,
+ image_resizer_fn,
+ non_max_suppression_fn,
+ score_conversion_fn,
+ classification_loss,
+ localization_loss,
+ classification_loss_weight,
+ localization_loss_weight,
+ normalize_loss_by_num_matches,
+ hard_example_miner,
+ unroll_length,
+ target_assigner_instance,
+ add_summaries=True):
+ super(LSTMSSDMetaArch, self).__init__(
+ is_training=is_training,
+ anchor_generator=anchor_generator,
+ box_predictor=box_predictor,
+ box_coder=box_coder,
+ feature_extractor=feature_extractor,
+ encode_background_as_zeros=encode_background_as_zeros,
+ image_resizer_fn=image_resizer_fn,
+ non_max_suppression_fn=non_max_suppression_fn,
+ score_conversion_fn=score_conversion_fn,
+ classification_loss=classification_loss,
+ localization_loss=localization_loss,
+ classification_loss_weight=classification_loss_weight,
+ localization_loss_weight=localization_loss_weight,
+ normalize_loss_by_num_matches=normalize_loss_by_num_matches,
+ hard_example_miner=hard_example_miner,
+ target_assigner_instance=target_assigner_instance,
+ add_summaries=add_summaries)
+ self._unroll_length = unroll_length
+
+ @property
+ def unroll_length(self):
+ return self._unroll_length
+
+ @unroll_length.setter
+ def unroll_length(self, unroll_length):
+ self._unroll_length = unroll_length
+
+ def predict(self, preprocessed_inputs, true_image_shapes, states=None,
+ state_name='lstm_state', feature_scope=None):
+ with tf.variable_scope(self._extract_features_scope,
+ values=[preprocessed_inputs], reuse=tf.AUTO_REUSE):
+ feature_maps = self._feature_extractor.extract_features(
+ preprocessed_inputs, states, state_name,
+ unroll_length=self._unroll_length, scope=feature_scope)
+ feature_map_spatial_dims = self._get_feature_map_spatial_dims(feature_maps)
+ image_shape = shape_utils.combined_static_and_dynamic_shape(
+ preprocessed_inputs)
+ self._batch_size = preprocessed_inputs.shape[0].value / self._unroll_length
+ self._states = states
+ anchors = self._anchor_generator.generate(feature_map_spatial_dims,
+ im_height=image_shape[1],
+ im_width=image_shape[2])
+ with tf.variable_scope('MultipleGridAnchorGenerator', reuse=tf.AUTO_REUSE):
+ self._anchors = box_list_ops.concatenate(anchors)
+ prediction_dict = self._box_predictor.predict(
+ feature_maps, self._anchor_generator.num_anchors_per_location())
+ with tf.variable_scope('Loss', reuse=tf.AUTO_REUSE):
+ box_encodings = tf.concat(prediction_dict['box_encodings'], axis=1)
+ if box_encodings.shape.ndims == 4 and box_encodings.shape[2] == 1:
+ box_encodings = tf.squeeze(box_encodings, axis=2)
+ class_predictions_with_background = tf.concat(
+ prediction_dict['class_predictions_with_background'], axis=1)
+ predictions_dict = {
+ 'preprocessed_inputs': preprocessed_inputs,
+ 'box_encodings': box_encodings,
+ 'class_predictions_with_background': class_predictions_with_background,
+ 'feature_maps': feature_maps,
+ 'anchors': self._anchors.get(),
+ 'states_and_outputs': self._feature_extractor.states_and_outputs,
+ }
+    # In cases such as exporting the model, the states are always zero. Thus
+    # the step should be ignored.
+ if states is not None:
+ predictions_dict['step'] = self._feature_extractor.step
+ return predictions_dict
+
+ def loss(self, prediction_dict, true_image_shapes, scope=None):
+ """Computes scalar loss tensors with respect to provided groundtruth.
+
+ Calling this function requires that groundtruth tensors have been
+ provided via the provide_groundtruth function.
+
+ Args:
+ prediction_dict: a dictionary holding prediction tensors with
+ 1) box_encodings: 3-D float tensor of shape [batch_size, num_anchors,
+ box_code_dimension] containing predicted boxes.
+ 2) class_predictions_with_background: 3-D float tensor of shape
+ [batch_size, num_anchors, num_classes+1] containing class predictions
+ (logits) for each of the anchors. Note that this tensor *includes*
+ background class predictions.
+ true_image_shapes: int32 tensor of shape [batch, 3] where each row is
+ of the form [height, width, channels] indicating the shapes
+ of true images in the resized images, as resized images can be padded
+ with zeros.
+ scope: Optional scope name.
+
+ Returns:
+ a dictionary mapping loss keys (`localization_loss` and
+ `classification_loss`) to scalar tensors representing corresponding loss
+ values.
+ """
+ with tf.name_scope(scope, 'Loss', prediction_dict.values()):
+ keypoints = None
+ if self.groundtruth_has_field(fields.BoxListFields.keypoints):
+ keypoints = self.groundtruth_lists(fields.BoxListFields.keypoints)
+ weights = None
+ if self.groundtruth_has_field(fields.BoxListFields.weights):
+ weights = self.groundtruth_lists(fields.BoxListFields.weights)
+ (batch_cls_targets, batch_cls_weights, batch_reg_targets,
+ batch_reg_weights, batch_match) = self._assign_targets(
+ self.groundtruth_lists(fields.BoxListFields.boxes),
+ self.groundtruth_lists(fields.BoxListFields.classes),
+ keypoints, weights)
+ match_list = [matcher.Match(match) for match in tf.unstack(batch_match)]
+ if self._add_summaries:
+ self._summarize_target_assignment(
+ self.groundtruth_lists(fields.BoxListFields.boxes), match_list)
+ location_losses = self._localization_loss(
+ prediction_dict['box_encodings'],
+ batch_reg_targets,
+ ignore_nan_targets=True,
+ weights=batch_reg_weights)
+ cls_losses = ops.reduce_sum_trailing_dimensions(
+ self._classification_loss(
+ prediction_dict['class_predictions_with_background'],
+ batch_cls_targets,
+ weights=batch_cls_weights),
+ ndims=2)
+
+ if self._hard_example_miner:
+ (loc_loss_list, cls_loss_list) = self._apply_hard_mining(
+ location_losses, cls_losses, prediction_dict, match_list)
+ localization_loss = tf.reduce_sum(tf.stack(loc_loss_list))
+ classification_loss = tf.reduce_sum(tf.stack(cls_loss_list))
+
+ if self._add_summaries:
+ self._hard_example_miner.summarize()
+ else:
+ if self._add_summaries:
+ class_ids = tf.argmax(batch_cls_targets, axis=2)
+ flattened_class_ids = tf.reshape(class_ids, [-1])
+ flattened_classification_losses = tf.reshape(cls_losses, [-1])
+ self._summarize_anchor_classification_loss(
+ flattened_class_ids, flattened_classification_losses)
+ localization_loss = tf.reduce_sum(location_losses)
+ classification_loss = tf.reduce_sum(cls_losses)
+
+ # Optionally normalize by number of positive matches
+ normalizer = tf.constant(1.0, dtype=tf.float32)
+ if self._normalize_loss_by_num_matches:
+ normalizer = tf.maximum(tf.to_float(tf.reduce_sum(batch_reg_weights)),
+ 1.0)
+
+ with tf.name_scope('localization_loss'):
+ localization_loss_normalizer = normalizer
+ if self._normalize_loc_loss_by_codesize:
+ localization_loss_normalizer *= self._box_coder.code_size
+ localization_loss = ((self._localization_loss_weight / (
+ localization_loss_normalizer)) * localization_loss)
+ with tf.name_scope('classification_loss'):
+ classification_loss = ((self._classification_loss_weight / normalizer) *
+ classification_loss)
+
+ loss_dict = {
+ 'localization_loss': localization_loss,
+ 'classification_loss': classification_loss
+ }
+ return loss_dict
+
+ def restore_map(self, fine_tune_checkpoint_type='lstm'):
+ """Returns a map of variables to load from a foreign checkpoint.
+
+ See parent class for details.
+
+ Args:
+      fine_tune_checkpoint_type: the type of checkpoint to restore from: either
+        an SSD/LSTM detection checkpoint (with compatible variable names) or a
+        classification checkpoint for initialization prior to training.
+ Available options: `classification`, `detection`, `interleaved`,
+ and `lstm`.
+
+ Returns:
+ A dict mapping variable names (to load from a checkpoint) to variables in
+ the model graph.
+ Raises:
+ ValueError: if fine_tune_checkpoint_type is not among
+ `classification`/`detection`/`interleaved`/`lstm`.
+ """
+ if fine_tune_checkpoint_type not in [
+ 'classification', 'detection', 'interleaved', 'lstm',
+ 'interleaved_pretrain'
+ ]:
+ raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format(
+ fine_tune_checkpoint_type))
+
+ self._restored_networks += 1
+ base_network_scope = self.get_base_network_scope()
+ if base_network_scope:
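+      # scope_to_replace is e.g. 'MobilenetV2_2' when this is the second
+      # network being restored.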
+ scope_to_replace = '{0}_{1}'.format(base_network_scope,
+ self._restored_networks)
+
+ interleaved_model = False
+ for variable in tf.global_variables():
+ if scope_to_replace in variable.op.name:
+ interleaved_model = True
+ break
+
+ variables_to_restore = {}
+ for variable in tf.global_variables():
+ var_name = variable.op.name
+ if 'global_step' in var_name:
+ continue
+
+ # Remove FeatureExtractor prefix for classification checkpoints.
+ if (fine_tune_checkpoint_type == 'classification' or
+ fine_tune_checkpoint_type == 'interleaved_pretrain'):
+ var_name = (
+ re.split('^' + self._extract_features_scope + '/', var_name)[-1])
+
+ # When loading from single frame detection checkpoints, we need to
+ # remap FeatureMaps variable names.
+ if ('FeatureMaps' in var_name and
+ fine_tune_checkpoint_type == 'detection'):
+ var_name = var_name.replace('FeatureMaps',
+ self.get_base_network_scope())
+
+ # Load interleaved checkpoint specifically.
+ if interleaved_model: # Interleaved LSTD.
+ if 'interleaved' in fine_tune_checkpoint_type:
+ variables_to_restore[var_name] = variable
+ else:
+ # Restore non-base layers from the first checkpoint only.
+ if self._restored_networks == 1:
+ if base_network_scope + '_' not in var_name: # LSTM and FeatureMap
+ variables_to_restore[var_name] = variable
+ if scope_to_replace in var_name:
+ var_name = var_name.replace(scope_to_replace, base_network_scope)
+ variables_to_restore[var_name] = variable
+ else:
+ # Restore from the first model of interleaved checkpoints
+ if 'interleaved' in fine_tune_checkpoint_type:
+ var_name = var_name.replace(self.get_base_network_scope(),
+ self.get_base_network_scope() + '_1', 1)
+
+ variables_to_restore[var_name] = variable
+
+ return variables_to_restore
+
+ def get_base_network_scope(self):
+ """Returns the variable scope of the base network.
+
+ Returns:
+ The variable scope of the feature extractor base network, e.g. MobilenetV1
+ """
+ return self._feature_extractor.get_base_network_scope()
+
+
+class LSTMSSDFeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
+ """LSTM SSD Meta-architecture Feature Extractor definition."""
+
+ __metaclass__ = abc.ABCMeta
+
+ @property
+ def clip_state(self):
+ return self._clip_state
+
+ @clip_state.setter
+ def clip_state(self, clip_state):
+ self._clip_state = clip_state
+
+ @property
+ def depth_multipliers(self):
+ return self._depth_multipliers
+
+ @depth_multipliers.setter
+ def depth_multipliers(self, depth_multipliers):
+ self._depth_multipliers = depth_multipliers
+
+ @property
+ def lstm_state_depth(self):
+ return self._lstm_state_depth
+
+ @lstm_state_depth.setter
+ def lstm_state_depth(self, lstm_state_depth):
+ self._lstm_state_depth = lstm_state_depth
+
+ @property
+ def is_quantized(self):
+ return self._is_quantized
+
+ @is_quantized.setter
+ def is_quantized(self, is_quantized):
+ self._is_quantized = is_quantized
+
+ @property
+ def interleaved(self):
+ return False
+
+ @property
+ def states_and_outputs(self):
+ """LSTM states and outputs.
+
+ This variable includes both LSTM states {C_t} and outputs {h_t}.
+
+ Returns:
+ states_and_outputs: A list of 4-D float tensors, including the lstm state
+ and output at each timestep.
+ """
+ return self._states_out
+
+ @property
+ def step(self):
+ return self._step
+
+ def preprocess(self, resized_inputs):
+ """SSD preprocessing.
+
+ Maps pixel values to the range [-1, 1].
+
+ Args:
+ resized_inputs: a [batch, height, width, channels] float tensor
+ representing a batch of images.
+
+ Returns:
+ preprocessed_inputs: a [batch, height, width, channels] float tensor
+ representing a batch of images.
+ """
+ return (2.0 / 255.0) * resized_inputs - 1.0
+
+ def get_base_network_scope(self):
+ """Returns the variable scope of the base network.
+
+ Returns:
+ The variable scope of the base network, e.g. MobilenetV1
+ """
+ return self._base_network_scope
+
+ @abc.abstractmethod
+ def create_lstm_cell(self, batch_size, output_size, state_saver, state_name):
+ """Create the LSTM cell, and initialize state if necessary.
+
+ Args:
+ batch_size: input batch size.
+ output_size: output size of the lstm cell, [width, height].
+ state_saver: a state saver object with methods `state` and `save_state`.
+ state_name: string, the name to use with the state_saver.
+ Returns:
+ lstm_cell: the lstm cell unit.
+ init_state: initial state representations.
+      step: the step tensor tracking the current frame, or None if no state
+        saver is used.
+ """
+ pass
+
+
+class LSTMSSDInterleavedFeatureExtractor(LSTMSSDFeatureExtractor):
+ """LSTM SSD Meta-architecture Interleaved Feature Extractor definition."""
+
+ __metaclass__ = abc.ABCMeta
+
+ @property
+ def pre_bottleneck(self):
+ return self._pre_bottleneck
+
+ @pre_bottleneck.setter
+ def pre_bottleneck(self, pre_bottleneck):
+ self._pre_bottleneck = pre_bottleneck
+
+ @property
+ def low_res(self):
+ return self._low_res
+
+ @low_res.setter
+ def low_res(self, low_res):
+ self._low_res = low_res
+
+ @property
+ def interleaved(self):
+ return True
+
+ @property
+ def interleave_method(self):
+ return self._interleave_method
+
+ @interleave_method.setter
+ def interleave_method(self, interleave_method):
+ self._interleave_method = interleave_method
+
+ @abc.abstractmethod
+ def extract_base_features_large(self, preprocessed_inputs):
+ """Extract the large base model features.
+
+ Args:
+ preprocessed_inputs: preprocessed input images of shape:
+ [batch, width, height, depth].
+
+ Returns:
+ net: the last feature map created from the base feature extractor.
+ end_points: a dictionary of feature maps created.
+ """
+ pass
+
+ @abc.abstractmethod
+ def extract_base_features_small(self, preprocessed_inputs):
+ """Extract the small base model features.
+
+ Args:
+ preprocessed_inputs: preprocessed input images of shape:
+ [batch, width, height, depth].
+
+ Returns:
+ net: the last feature map created from the base feature extractor.
+ end_points: a dictionary of feature maps created.
+ """
+ pass
diff --git a/models/research/lstm_object_detection/meta_architectures/lstm_ssd_meta_arch_test.py b/models/research/lstm_object_detection/meta_architectures/lstm_ssd_meta_arch_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..03e8a1274603806c19bc36ad09022c9b4d6ca91b
--- /dev/null
+++ b/models/research/lstm_object_detection/meta_architectures/lstm_ssd_meta_arch_test.py
@@ -0,0 +1,320 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for meta_architectures.lstm_ssd_meta_arch."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import numpy as np
+import tensorflow.compat.v1 as tf
+import tf_slim as slim
+
+from lstm_object_detection.lstm import lstm_cells
+from lstm_object_detection.meta_architectures import lstm_ssd_meta_arch
+from object_detection.core import anchor_generator
+from object_detection.core import box_list
+from object_detection.core import losses
+from object_detection.core import post_processing
+from object_detection.core import region_similarity_calculator as sim_calc
+from object_detection.core import standard_fields as fields
+from object_detection.core import target_assigner
+from object_detection.models import feature_map_generators
+from object_detection.utils import test_case
+from object_detection.utils import test_utils
+
+
+MAX_TOTAL_NUM_BOXES = 5
+NUM_CLASSES = 1
+
+
+class FakeLSTMFeatureExtractor(
+ lstm_ssd_meta_arch.LSTMSSDFeatureExtractor):
+
+ def __init__(self):
+ super(FakeLSTMFeatureExtractor, self).__init__(
+ is_training=True,
+ depth_multiplier=1.0,
+ min_depth=0,
+ pad_to_multiple=1,
+ conv_hyperparams_fn=self.scope_fn)
+ self._lstm_state_depth = 256
+
+ def scope_fn(self):
+ with slim.arg_scope([slim.conv2d], activation_fn=tf.nn.relu6) as sc:
+ return sc
+
+ def create_lstm_cell(self):
+ pass
+
+ def extract_features(self, preprocessed_inputs, state_saver=None,
+ state_name='lstm_state', unroll_length=5, scope=None):
+ with tf.variable_scope('mock_model'):
+ net = slim.conv2d(inputs=preprocessed_inputs, num_outputs=32,
+ kernel_size=1, scope='layer1')
+ image_features = {'last_layer': net}
+
+ self._states_out = {}
+ feature_map_layout = {
+ 'from_layer': ['last_layer'],
+ 'layer_depth': [-1],
+ 'use_explicit_padding': self._use_explicit_padding,
+ 'use_depthwise': self._use_depthwise,
+ }
+ feature_maps = feature_map_generators.multi_resolution_feature_maps(
+ feature_map_layout=feature_map_layout,
+ depth_multiplier=(self._depth_multiplier),
+ min_depth=self._min_depth,
+ insert_1x1_conv=True,
+ image_features=image_features)
+ return list(feature_maps.values())
+
+
+class FakeLSTMInterleavedFeatureExtractor(
+ lstm_ssd_meta_arch.LSTMSSDInterleavedFeatureExtractor):
+
+ def __init__(self):
+ super(FakeLSTMInterleavedFeatureExtractor, self).__init__(
+ is_training=True,
+ depth_multiplier=1.0,
+ min_depth=0,
+ pad_to_multiple=1,
+ conv_hyperparams_fn=self.scope_fn)
+ self._lstm_state_depth = 256
+
+ def scope_fn(self):
+ with slim.arg_scope([slim.conv2d], activation_fn=tf.nn.relu6) as sc:
+ return sc
+
+ def create_lstm_cell(self):
+ pass
+
+ def extract_base_features_large(self, preprocessed_inputs):
+ with tf.variable_scope('base_large'):
+ net = slim.conv2d(inputs=preprocessed_inputs, num_outputs=32,
+ kernel_size=1, scope='layer1')
+ return net
+
+ def extract_base_features_small(self, preprocessed_inputs):
+ with tf.variable_scope('base_small'):
+ net = slim.conv2d(inputs=preprocessed_inputs, num_outputs=32,
+ kernel_size=1, scope='layer1')
+ return net
+
+ def extract_features(self, preprocessed_inputs, state_saver=None,
+ state_name='lstm_state', unroll_length=5, scope=None):
+ with tf.variable_scope('mock_model'):
+ net_large = self.extract_base_features_large(preprocessed_inputs)
+ net_small = self.extract_base_features_small(preprocessed_inputs)
+ net = slim.conv2d(
+ inputs=tf.concat([net_large, net_small], axis=3),
+ num_outputs=32,
+ kernel_size=1,
+ scope='layer1')
+ image_features = {'last_layer': net}
+
+ self._states_out = {}
+ feature_map_layout = {
+ 'from_layer': ['last_layer'],
+ 'layer_depth': [-1],
+ 'use_explicit_padding': self._use_explicit_padding,
+ 'use_depthwise': self._use_depthwise,
+ }
+ feature_maps = feature_map_generators.multi_resolution_feature_maps(
+ feature_map_layout=feature_map_layout,
+ depth_multiplier=(self._depth_multiplier),
+ min_depth=self._min_depth,
+ insert_1x1_conv=True,
+ image_features=image_features)
+ return list(feature_maps.values())
+
+
+class MockAnchorGenerator2x2(anchor_generator.AnchorGenerator):
+ """Sets up a simple 2x2 anchor grid on the unit square."""
+
+ def name_scope(self):
+ return 'MockAnchorGenerator'
+
+ def num_anchors_per_location(self):
+ return [1]
+
+ def _generate(self, feature_map_shape_list, im_height, im_width):
+ return [box_list.BoxList(
+ tf.constant([[0, 0, .5, .5],
+ [0, .5, .5, 1],
+ [.5, 0, 1, .5],
+ [1., 1., 1.5, 1.5] # Anchor that is outside clip_window.
+ ], tf.float32))]
+
+ def num_anchors(self):
+ return 4
+
+
+class LSTMSSDMetaArchTest(test_case.TestCase):
+
+ def _create_model(self,
+ interleaved=False,
+ apply_hard_mining=True,
+ normalize_loc_loss_by_codesize=False,
+ add_background_class=True,
+ random_example_sampling=False,
+ use_expected_classification_loss_under_sampling=False,
+ min_num_negative_samples=1,
+ desired_negative_sampling_ratio=3,
+ unroll_length=1):
+ num_classes = NUM_CLASSES
+ is_training = False
+ mock_anchor_generator = MockAnchorGenerator2x2()
+ mock_box_predictor = test_utils.MockBoxPredictor(is_training, num_classes)
+ mock_box_coder = test_utils.MockBoxCoder()
+ if interleaved:
+ fake_feature_extractor = FakeLSTMInterleavedFeatureExtractor()
+ else:
+ fake_feature_extractor = FakeLSTMFeatureExtractor()
+ mock_matcher = test_utils.MockMatcher()
+ region_similarity_calculator = sim_calc.IouSimilarity()
+ encode_background_as_zeros = False
+ def image_resizer_fn(image):
+ return [tf.identity(image), tf.shape(image)]
+
+ classification_loss = losses.WeightedSigmoidClassificationLoss()
+ localization_loss = losses.WeightedSmoothL1LocalizationLoss()
+ non_max_suppression_fn = functools.partial(
+ post_processing.batch_multiclass_non_max_suppression,
+ score_thresh=-20.0,
+ iou_thresh=1.0,
+ max_size_per_class=5,
+ max_total_size=MAX_TOTAL_NUM_BOXES)
+ classification_loss_weight = 1.0
+ localization_loss_weight = 1.0
+ negative_class_weight = 1.0
+ normalize_loss_by_num_matches = False
+
+ hard_example_miner = None
+ if apply_hard_mining:
+ # This hard example miner is expected to be a no-op.
+ hard_example_miner = losses.HardExampleMiner(
+ num_hard_examples=None,
+ iou_threshold=1.0)
+
+ target_assigner_instance = target_assigner.TargetAssigner(
+ region_similarity_calculator,
+ mock_matcher,
+ mock_box_coder,
+ negative_class_weight=negative_class_weight)
+
+ code_size = 4
+ model = lstm_ssd_meta_arch.LSTMSSDMetaArch(
+ is_training=is_training,
+ anchor_generator=mock_anchor_generator,
+ box_predictor=mock_box_predictor,
+ box_coder=mock_box_coder,
+ feature_extractor=fake_feature_extractor,
+ encode_background_as_zeros=encode_background_as_zeros,
+ image_resizer_fn=image_resizer_fn,
+ non_max_suppression_fn=non_max_suppression_fn,
+ score_conversion_fn=tf.identity,
+ classification_loss=classification_loss,
+ localization_loss=localization_loss,
+ classification_loss_weight=classification_loss_weight,
+ localization_loss_weight=localization_loss_weight,
+ normalize_loss_by_num_matches=normalize_loss_by_num_matches,
+ hard_example_miner=hard_example_miner,
+ unroll_length=unroll_length,
+ target_assigner_instance=target_assigner_instance,
+ add_summaries=False)
+ return model, num_classes, mock_anchor_generator.num_anchors(), code_size
+
+ def _get_value_for_matching_key(self, dictionary, suffix):
+ for key in dictionary.keys():
+ if key.endswith(suffix):
+ return dictionary[key]
+ raise ValueError('key not found {}'.format(suffix))
+
+ def test_predict_returns_correct_items_and_sizes(self):
+ batch_size = 3
+ height = width = 2
+ num_unroll = 1
+
+ graph = tf.Graph()
+ with graph.as_default():
+ model, num_classes, num_anchors, code_size = self._create_model()
+ preprocessed_images = tf.random_uniform(
+ [batch_size * num_unroll, height, width, 3],
+ minval=-1.,
+ maxval=1.)
+ true_image_shapes = tf.tile(
+ [[height, width, 3]], [batch_size, 1])
+ prediction_dict = model.predict(preprocessed_images, true_image_shapes)
+
+ self.assertIn('preprocessed_inputs', prediction_dict)
+ self.assertIn('box_encodings', prediction_dict)
+ self.assertIn('class_predictions_with_background', prediction_dict)
+ self.assertIn('feature_maps', prediction_dict)
+ self.assertIn('anchors', prediction_dict)
+ self.assertAllEqual(
+ [batch_size * num_unroll, height, width, 3],
+ prediction_dict['preprocessed_inputs'].shape.as_list())
+ self.assertAllEqual(
+ [batch_size * num_unroll, num_anchors, code_size],
+ prediction_dict['box_encodings'].shape.as_list())
+ self.assertAllEqual(
+ [batch_size * num_unroll, num_anchors, num_classes + 1],
+ prediction_dict['class_predictions_with_background'].shape.as_list())
+ self.assertAllEqual(
+ [num_anchors, code_size],
+ prediction_dict['anchors'].shape.as_list())
+
+ def test_interleaved_predict_returns_correct_items_and_sizes(self):
+ batch_size = 3
+ height = width = 2
+ num_unroll = 1
+
+ graph = tf.Graph()
+ with graph.as_default():
+ model, num_classes, num_anchors, code_size = self._create_model(
+ interleaved=True)
+ preprocessed_images = tf.random_uniform(
+ [batch_size * num_unroll, height, width, 3],
+ minval=-1.,
+ maxval=1.)
+ true_image_shapes = tf.tile(
+ [[height, width, 3]], [batch_size, 1])
+ prediction_dict = model.predict(preprocessed_images, true_image_shapes)
+
+ self.assertIn('preprocessed_inputs', prediction_dict)
+ self.assertIn('box_encodings', prediction_dict)
+ self.assertIn('class_predictions_with_background', prediction_dict)
+ self.assertIn('feature_maps', prediction_dict)
+ self.assertIn('anchors', prediction_dict)
+ self.assertAllEqual(
+ [batch_size * num_unroll, height, width, 3],
+ prediction_dict['preprocessed_inputs'].shape.as_list())
+ self.assertAllEqual(
+ [batch_size * num_unroll, num_anchors, code_size],
+ prediction_dict['box_encodings'].shape.as_list())
+ self.assertAllEqual(
+ [batch_size * num_unroll, num_anchors, num_classes + 1],
+ prediction_dict['class_predictions_with_background'].shape.as_list())
+ self.assertAllEqual(
+ [num_anchors, code_size],
+ prediction_dict['anchors'].shape.as_list())
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/lstm_object_detection/metrics/__init__.py b/models/research/lstm_object_detection/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/lstm_object_detection/metrics/coco_evaluation_all_frames.py b/models/research/lstm_object_detection/metrics/coco_evaluation_all_frames.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e6d336cbf71ecfdf5f438b6f74e078db1a6fb17
--- /dev/null
+++ b/models/research/lstm_object_detection/metrics/coco_evaluation_all_frames.py
@@ -0,0 +1,124 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Class for evaluating video object detections with COCO metrics."""
+
+import tensorflow.compat.v1 as tf
+
+from object_detection.core import standard_fields
+from object_detection.metrics import coco_evaluation
+from object_detection.metrics import coco_tools
+
+
+class CocoEvaluationAllFrames(coco_evaluation.CocoDetectionEvaluator):
+ """Class to evaluate COCO detection metrics for frame sequences.
+
+ The class overrides two functions: add_single_ground_truth_image_info and
+ add_single_detected_image_info.
+
+  For the evaluation of sequence video detection, the entire groundtruth_dict
+  is iterated over, so every unrolled frame in one LSTM training sample is
+  considered. Therefore, both groundtruth and detection results of all frames
+  are added for the evaluation. This is used when all the frames are labeled
+  in the video object detection training job.
+ """
+
+ def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):
+ """Add groundtruth results of all frames to the eval pipeline.
+
+ This method overrides the function defined in the base class.
+
+ Args:
+ image_id: A unique string/integer identifier for the image.
+      groundtruth_dict: A list of per-frame dictionaries, each containing -
+ InputDataFields.groundtruth_boxes: float32 numpy array of shape
+ [num_boxes, 4] containing `num_boxes` groundtruth boxes of the format
+ [ymin, xmin, ymax, xmax] in absolute image coordinates.
+ InputDataFields.groundtruth_classes: integer numpy array of shape
+ [num_boxes] containing 1-indexed groundtruth classes for the boxes.
+ InputDataFields.groundtruth_is_crowd (optional): integer numpy array of
+ shape [num_boxes] containing iscrowd flag for groundtruth boxes.
+ """
+ for idx, gt in enumerate(groundtruth_dict):
+ if not gt:
+ continue
+
+ image_frame_id = '{}_{}'.format(image_id, idx)
+ if image_frame_id in self._image_ids:
+ tf.logging.warning(
+ 'Ignoring ground truth with image id %s since it was '
+ 'previously added', image_frame_id)
+ continue
+
+ self._groundtruth_list.extend(
+ coco_tools.ExportSingleImageGroundtruthToCoco(
+ image_id=image_frame_id,
+ next_annotation_id=self._annotation_id,
+ category_id_set=self._category_id_set,
+ groundtruth_boxes=gt[
+ standard_fields.InputDataFields.groundtruth_boxes],
+ groundtruth_classes=gt[
+ standard_fields.InputDataFields.groundtruth_classes]))
+ self._annotation_id += (
+ gt[standard_fields.InputDataFields.groundtruth_boxes].shape[0])
+
+ # Boolean to indicate whether a detection has been added for this image.
+ self._image_ids[image_frame_id] = False
+
+ def add_single_detected_image_info(self, image_id, detections_dict):
+ """Add detection results of all frames to the eval pipeline.
+
+ This method overrides the function defined in the base class.
+
+ Args:
+ image_id: A unique string/integer identifier for the image.
+      detections_dict: A list of per-frame dictionaries, each containing -
+ DetectionResultFields.detection_boxes: float32 numpy array of shape
+ [num_boxes, 4] containing `num_boxes` detection boxes of the format
+ [ymin, xmin, ymax, xmax] in absolute image coordinates.
+ DetectionResultFields.detection_scores: float32 numpy array of shape
+ [num_boxes] containing detection scores for the boxes.
+ DetectionResultFields.detection_classes: integer numpy array of shape
+ [num_boxes] containing 1-indexed detection classes for the boxes.
+
+ Raises:
+ ValueError: If groundtruth for the image_id is not available.
+ """
+ for idx, det in enumerate(detections_dict):
+ if not det:
+ continue
+
+ image_frame_id = '{}_{}'.format(image_id, idx)
+ if image_frame_id not in self._image_ids:
+ raise ValueError(
+ 'Missing groundtruth for image-frame id: {}'.format(image_frame_id))
+
+ if self._image_ids[image_frame_id]:
+ tf.logging.warning(
+ 'Ignoring detection with image id %s since it was '
+ 'previously added', image_frame_id)
+ continue
+
+ self._detection_boxes_list.extend(
+ coco_tools.ExportSingleImageDetectionBoxesToCoco(
+ image_id=image_frame_id,
+ category_id_set=self._category_id_set,
+ detection_boxes=det[
+ standard_fields.DetectionResultFields.detection_boxes],
+ detection_scores=det[
+ standard_fields.DetectionResultFields.detection_scores],
+ detection_classes=det[
+ standard_fields.DetectionResultFields.detection_classes]))
+ self._image_ids[image_frame_id] = True
diff --git a/models/research/lstm_object_detection/metrics/coco_evaluation_all_frames_test.py b/models/research/lstm_object_detection/metrics/coco_evaluation_all_frames_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c1e7b7546b037d974bde9e3dadef94d7535235b
--- /dev/null
+++ b/models/research/lstm_object_detection/metrics/coco_evaluation_all_frames_test.py
@@ -0,0 +1,156 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for video_object_detection.metrics.coco_video_evaluation."""
+
+import numpy as np
+import tensorflow.compat.v1 as tf
+from lstm_object_detection.metrics import coco_evaluation_all_frames
+from object_detection.core import standard_fields
+
+
+class CocoEvaluationAllFramesTest(tf.test.TestCase):
+
+ def testGroundtruthAndDetectionsDisagreeOnAllFrames(self):
+ """Tests that mAP is calculated on several different frame results."""
+ category_list = [{'id': 0, 'name': 'dog'}, {'id': 1, 'name': 'cat'}]
+ video_evaluator = coco_evaluation_all_frames.CocoEvaluationAllFrames(
+ category_list)
+ video_evaluator.add_single_ground_truth_image_info(
+ image_id='image1',
+ groundtruth_dict=[{
+ standard_fields.InputDataFields.groundtruth_boxes:
+ np.array([[50., 50., 200., 200.]]),
+ standard_fields.InputDataFields.groundtruth_classes:
+ np.array([1])
+ }, {
+ standard_fields.InputDataFields.groundtruth_boxes:
+ np.array([[50., 50., 100., 100.]]),
+ standard_fields.InputDataFields.groundtruth_classes:
+ np.array([1])
+ }])
+ video_evaluator.add_single_detected_image_info(
+ image_id='image1',
+        # The detections below disagree with the groundtruth on every frame
+        # except the last one.
+ detections_dict=[{
+ standard_fields.DetectionResultFields.detection_boxes:
+ np.array([[100., 100., 200., 200.]]),
+ standard_fields.DetectionResultFields.detection_scores:
+ np.array([.8]),
+ standard_fields.DetectionResultFields.detection_classes:
+ np.array([1])
+ }, {
+ standard_fields.DetectionResultFields.detection_boxes:
+ np.array([[50., 50., 100., 100.]]),
+ standard_fields.DetectionResultFields.detection_scores:
+ np.array([.8]),
+ standard_fields.DetectionResultFields.detection_classes:
+ np.array([1])
+ }])
+
+ metrics = video_evaluator.evaluate()
+ self.assertNotEqual(metrics['DetectionBoxes_Precision/mAP'], 1.0)
+
+ def testGroundtruthAndDetections(self):
+ """Tests that mAP is calculated correctly on GT and Detections."""
+ category_list = [{'id': 0, 'name': 'dog'}, {'id': 1, 'name': 'cat'}]
+ video_evaluator = coco_evaluation_all_frames.CocoEvaluationAllFrames(
+ category_list)
+ video_evaluator.add_single_ground_truth_image_info(
+ image_id='image1',
+ groundtruth_dict=[{
+ standard_fields.InputDataFields.groundtruth_boxes:
+ np.array([[100., 100., 200., 200.]]),
+ standard_fields.InputDataFields.groundtruth_classes:
+ np.array([1])
+ }])
+ video_evaluator.add_single_ground_truth_image_info(
+ image_id='image2',
+ groundtruth_dict=[{
+ standard_fields.InputDataFields.groundtruth_boxes:
+ np.array([[50., 50., 100., 100.]]),
+ standard_fields.InputDataFields.groundtruth_classes:
+ np.array([1])
+ }])
+ video_evaluator.add_single_ground_truth_image_info(
+ image_id='image3',
+ groundtruth_dict=[{
+ standard_fields.InputDataFields.groundtruth_boxes:
+ np.array([[50., 100., 100., 120.]]),
+ standard_fields.InputDataFields.groundtruth_classes:
+ np.array([1])
+ }])
+ video_evaluator.add_single_detected_image_info(
+ image_id='image1',
+ detections_dict=[{
+ standard_fields.DetectionResultFields.detection_boxes:
+ np.array([[100., 100., 200., 200.]]),
+ standard_fields.DetectionResultFields.detection_scores:
+ np.array([.8]),
+ standard_fields.DetectionResultFields.detection_classes:
+ np.array([1])
+ }])
+ video_evaluator.add_single_detected_image_info(
+ image_id='image2',
+ detections_dict=[{
+ standard_fields.DetectionResultFields.detection_boxes:
+ np.array([[50., 50., 100., 100.]]),
+ standard_fields.DetectionResultFields.detection_scores:
+ np.array([.8]),
+ standard_fields.DetectionResultFields.detection_classes:
+ np.array([1])
+ }])
+ video_evaluator.add_single_detected_image_info(
+ image_id='image3',
+ detections_dict=[{
+ standard_fields.DetectionResultFields.detection_boxes:
+ np.array([[50., 100., 100., 120.]]),
+ standard_fields.DetectionResultFields.detection_scores:
+ np.array([.8]),
+ standard_fields.DetectionResultFields.detection_classes:
+ np.array([1])
+ }])
+ metrics = video_evaluator.evaluate()
+ self.assertAlmostEqual(metrics['DetectionBoxes_Precision/mAP'], 1.0)
+
+ def testMissingDetectionResults(self):
+ """Tests if groundtrue is missing, raises ValueError."""
+ category_list = [{'id': 0, 'name': 'dog'}]
+ video_evaluator = coco_evaluation_all_frames.CocoEvaluationAllFrames(
+ category_list)
+ video_evaluator.add_single_ground_truth_image_info(
+ image_id='image1',
+ groundtruth_dict=[{
+ standard_fields.InputDataFields.groundtruth_boxes:
+ np.array([[100., 100., 200., 200.]]),
+ standard_fields.InputDataFields.groundtruth_classes:
+ np.array([1])
+ }])
+ with self.assertRaisesRegexp(ValueError,
+ r'Missing groundtruth for image-frame id:.*'):
+ video_evaluator.add_single_detected_image_info(
+ image_id='image3',
+ detections_dict=[{
+ standard_fields.DetectionResultFields.detection_boxes:
+ np.array([[100., 100., 200., 200.]]),
+ standard_fields.DetectionResultFields.detection_scores:
+ np.array([.8]),
+ standard_fields.DetectionResultFields.detection_classes:
+ np.array([1])
+ }])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/lstm_object_detection/model_builder.py b/models/research/lstm_object_detection/model_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d622558cf75f6664f9a1b075e3ed690caf457f68
--- /dev/null
+++ b/models/research/lstm_object_detection/model_builder.py
@@ -0,0 +1,192 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A function to build a DetectionModel from configuration."""
+from lstm_object_detection.meta_architectures import lstm_ssd_meta_arch
+from lstm_object_detection.models import lstm_ssd_interleaved_mobilenet_v2_feature_extractor
+from lstm_object_detection.models import lstm_ssd_mobilenet_v1_feature_extractor
+from object_detection.builders import anchor_generator_builder
+from object_detection.builders import box_coder_builder
+from object_detection.builders import box_predictor_builder
+from object_detection.builders import hyperparams_builder
+from object_detection.builders import image_resizer_builder
+from object_detection.builders import losses_builder
+from object_detection.builders import matcher_builder
+from object_detection.builders import model_builder
+from object_detection.builders import post_processing_builder
+from object_detection.builders import region_similarity_calculator_builder as sim_calc
+from object_detection.core import target_assigner
+
+model_builder.SSD_FEATURE_EXTRACTOR_CLASS_MAP.update({
+ 'lstm_ssd_mobilenet_v1':
+ lstm_ssd_mobilenet_v1_feature_extractor
+ .LSTMSSDMobileNetV1FeatureExtractor,
+ 'lstm_ssd_interleaved_mobilenet_v2':
+ lstm_ssd_interleaved_mobilenet_v2_feature_extractor
+ .LSTMSSDInterleavedMobilenetV2FeatureExtractor,
+})
+SSD_FEATURE_EXTRACTOR_CLASS_MAP = model_builder.SSD_FEATURE_EXTRACTOR_CLASS_MAP
+
+
+def build(model_config, lstm_config, is_training):
+ """Builds a DetectionModel based on the model config.
+
+ Args:
+ model_config: A model.proto object containing the config for the desired
+ DetectionModel.
+ lstm_config: LstmModel config proto that specifies LSTM train/eval configs.
+ is_training: True if this model is being built for training purposes.
+
+ Returns:
+ DetectionModel based on the config.
+
+ Raises:
+ ValueError: On invalid meta architecture or model.
+ """
+ return _build_lstm_model(model_config.ssd, lstm_config, is_training)
+
+
+def _build_lstm_feature_extractor(feature_extractor_config,
+ is_training,
+ lstm_config,
+ reuse_weights=None):
+ """Builds a ssd_meta_arch.SSDFeatureExtractor based on config.
+
+ Args:
+ feature_extractor_config: A SSDFeatureExtractor proto config from ssd.proto.
+ is_training: True if this feature extractor is being built for training.
+ lstm_config: LSTM-SSD specific configs.
+ reuse_weights: If the feature extractor should reuse weights.
+
+ Returns:
+ ssd_meta_arch.SSDFeatureExtractor based on config.
+
+ Raises:
+ ValueError: On invalid feature extractor type.
+ """
+
+ feature_type = feature_extractor_config.type
+ depth_multiplier = feature_extractor_config.depth_multiplier
+ min_depth = feature_extractor_config.min_depth
+ pad_to_multiple = feature_extractor_config.pad_to_multiple
+ use_explicit_padding = feature_extractor_config.use_explicit_padding
+ use_depthwise = feature_extractor_config.use_depthwise
+ conv_hyperparams = hyperparams_builder.build(
+ feature_extractor_config.conv_hyperparams, is_training)
+ override_base_feature_extractor_hyperparams = (
+ feature_extractor_config.override_base_feature_extractor_hyperparams)
+
+ if feature_type not in SSD_FEATURE_EXTRACTOR_CLASS_MAP:
+ raise ValueError('Unknown ssd feature_extractor: {}'.format(feature_type))
+
+ feature_extractor_class = SSD_FEATURE_EXTRACTOR_CLASS_MAP[feature_type]
+ feature_extractor = feature_extractor_class(
+ is_training, depth_multiplier, min_depth, pad_to_multiple,
+ conv_hyperparams, reuse_weights, use_explicit_padding, use_depthwise,
+ override_base_feature_extractor_hyperparams)
+
+ # Extra configs for LSTM-SSD.
+ feature_extractor.lstm_state_depth = lstm_config.lstm_state_depth
+ feature_extractor.flatten_state = lstm_config.flatten_state
+ feature_extractor.clip_state = lstm_config.clip_state
+ feature_extractor.scale_state = lstm_config.scale_state
+ feature_extractor.is_quantized = lstm_config.is_quantized
+ feature_extractor.low_res = lstm_config.low_res
+ # Extra configs for interleaved LSTM-SSD.
+ if 'interleaved' in feature_extractor_config.type:
+ feature_extractor.pre_bottleneck = lstm_config.pre_bottleneck
+ feature_extractor.depth_multipliers = lstm_config.depth_multipliers
+ if is_training:
+ feature_extractor.interleave_method = lstm_config.train_interleave_method
+ else:
+ feature_extractor.interleave_method = lstm_config.eval_interleave_method
+ return feature_extractor
+
+
+def _build_lstm_model(ssd_config, lstm_config, is_training):
+ """Builds an LSTM detection model based on the model config.
+
+ Args:
+ ssd_config: A ssd.proto object containing the config for the desired
+ LSTMSSDMetaArch.
+ lstm_config: LstmModel config proto that specifies LSTM train/eval configs.
+ is_training: True if this model is being built for training purposes.
+
+ Returns:
+ LSTMSSDMetaArch based on the config.
+ Raises:
+ ValueError: If ssd_config.type is not recognized (i.e. not registered in
+ model_class_map), or if lstm_config.interleave_strategy is not recognized.
+ ValueError: If unroll_length is not specified in the config file.
+ """
+ feature_extractor = _build_lstm_feature_extractor(
+ ssd_config.feature_extractor, is_training, lstm_config)
+
+ box_coder = box_coder_builder.build(ssd_config.box_coder)
+ matcher = matcher_builder.build(ssd_config.matcher)
+ region_similarity_calculator = sim_calc.build(
+ ssd_config.similarity_calculator)
+
+ num_classes = ssd_config.num_classes
+ ssd_box_predictor = box_predictor_builder.build(hyperparams_builder.build,
+ ssd_config.box_predictor,
+ is_training, num_classes)
+ anchor_generator = anchor_generator_builder.build(ssd_config.anchor_generator)
+ image_resizer_fn = image_resizer_builder.build(ssd_config.image_resizer)
+ non_max_suppression_fn, score_conversion_fn = post_processing_builder.build(
+ ssd_config.post_processing)
+ (classification_loss, localization_loss, classification_weight,
+ localization_weight, miner, _, _) = losses_builder.build(ssd_config.loss)
+
+ normalize_loss_by_num_matches = ssd_config.normalize_loss_by_num_matches
+ encode_background_as_zeros = ssd_config.encode_background_as_zeros
+ negative_class_weight = ssd_config.negative_class_weight
+
+ # Extra configs for lstm unroll length.
+ unroll_length = None
+ if 'lstm' in ssd_config.feature_extractor.type:
+ if is_training:
+ unroll_length = lstm_config.train_unroll_length
+ else:
+ unroll_length = lstm_config.eval_unroll_length
+ if unroll_length is None:
+ raise ValueError('No unroll length found in the config file')
+
+ target_assigner_instance = target_assigner.TargetAssigner(
+ region_similarity_calculator,
+ matcher,
+ box_coder,
+ negative_class_weight=negative_class_weight)
+
+ lstm_model = lstm_ssd_meta_arch.LSTMSSDMetaArch(
+ is_training=is_training,
+ anchor_generator=anchor_generator,
+ box_predictor=ssd_box_predictor,
+ box_coder=box_coder,
+ feature_extractor=feature_extractor,
+ encode_background_as_zeros=encode_background_as_zeros,
+ image_resizer_fn=image_resizer_fn,
+ non_max_suppression_fn=non_max_suppression_fn,
+ score_conversion_fn=score_conversion_fn,
+ classification_loss=classification_loss,
+ localization_loss=localization_loss,
+ classification_loss_weight=classification_weight,
+ localization_loss_weight=localization_weight,
+ normalize_loss_by_num_matches=normalize_loss_by_num_matches,
+ hard_example_miner=miner,
+ unroll_length=unroll_length,
+ target_assigner_instance=target_assigner_instance)
+
+ return lstm_model
diff --git a/models/research/lstm_object_detection/model_builder_test.py b/models/research/lstm_object_detection/model_builder_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d64b537cdc4044d5302845c53a1a3e4ac700f39
--- /dev/null
+++ b/models/research/lstm_object_detection/model_builder_test.py
@@ -0,0 +1,302 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for lstm_object_detection.tensorflow.model_builder."""
+
+import tensorflow.compat.v1 as tf
+from google.protobuf import text_format
+from lstm_object_detection import model_builder
+from lstm_object_detection.meta_architectures import lstm_ssd_meta_arch
+from lstm_object_detection.protos import pipeline_pb2 as internal_pipeline_pb2
+from object_detection.protos import pipeline_pb2
+
+
+class ModelBuilderTest(tf.test.TestCase):
+
+ def create_train_model(self, model_config, lstm_config):
+ """Builds a DetectionModel based on the model config.
+
+ Args:
+ model_config: A model.proto object containing the config for the desired
+ DetectionModel.
+ lstm_config: LstmModel config proto that specifies LSTM train/eval
+ configs.
+
+ Returns:
+ DetectionModel based on the config.
+ """
+ return model_builder.build(model_config, lstm_config, is_training=True)
+
+ def create_eval_model(self, model_config, lstm_config):
+ """Builds a DetectionModel based on the model config.
+
+ Args:
+ model_config: A model.proto object containing the config for the desired
+ DetectionModel.
+ lstm_config: LstmModel config proto that specifies LSTM train/eval
+ configs.
+
+ Returns:
+ DetectionModel based on the config.
+ """
+ return model_builder.build(model_config, lstm_config, is_training=False)
+
+ def get_model_configs_from_proto(self):
+ """Creates a model text proto for testing.
+
+ Returns:
+ A dictionary of model configs.
+ """
+
+ model_text_proto = """
+ [lstm_object_detection.protos.lstm_model] {
+ train_unroll_length: 4
+ eval_unroll_length: 4
+ }
+ model {
+ ssd {
+ feature_extractor {
+ type: 'lstm_ssd_mobilenet_v1'
+ conv_hyperparams {
+ regularizer {
+ l2_regularizer {
+ }
+ }
+ initializer {
+ truncated_normal_initializer {
+ }
+ }
+ }
+ }
+ negative_class_weight: 2.0
+ box_coder {
+ faster_rcnn_box_coder {
+ }
+ }
+ matcher {
+ argmax_matcher {
+ }
+ }
+ similarity_calculator {
+ iou_similarity {
+ }
+ }
+ anchor_generator {
+ ssd_anchor_generator {
+ aspect_ratios: 1.0
+ }
+ }
+ image_resizer {
+ fixed_shape_resizer {
+ height: 320
+ width: 320
+ }
+ }
+ box_predictor {
+ convolutional_box_predictor {
+ conv_hyperparams {
+ regularizer {
+ l2_regularizer {
+ }
+ }
+ initializer {
+ truncated_normal_initializer {
+ }
+ }
+ }
+ }
+ }
+ normalize_loc_loss_by_codesize: true
+ loss {
+ classification_loss {
+ weighted_softmax {
+ }
+ }
+ localization_loss {
+ weighted_smooth_l1 {
+ }
+ }
+ }
+ }
+ }"""
+
+ pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
+ text_format.Merge(model_text_proto, pipeline_config)
+
+ configs = {}
+ configs['model'] = pipeline_config.model
+ configs['lstm_model'] = pipeline_config.Extensions[
+ internal_pipeline_pb2.lstm_model]
+
+ return configs
+
+ def get_interleaved_model_configs_from_proto(self):
+ """Creates an interleaved model text proto for testing.
+
+ Returns:
+ A dictionary of model configs.
+ """
+
+ model_text_proto = """
+ [lstm_object_detection.protos.lstm_model] {
+ train_unroll_length: 4
+ eval_unroll_length: 10
+ lstm_state_depth: 320
+ depth_multipliers: 1.4
+ depth_multipliers: 0.35
+ pre_bottleneck: true
+ low_res: true
+ train_interleave_method: 'RANDOM_SKIP_SMALL'
+ eval_interleave_method: 'SKIP3'
+ }
+ model {
+ ssd {
+ feature_extractor {
+ type: 'lstm_ssd_interleaved_mobilenet_v2'
+ conv_hyperparams {
+ regularizer {
+ l2_regularizer {
+ }
+ }
+ initializer {
+ truncated_normal_initializer {
+ }
+ }
+ }
+ }
+ negative_class_weight: 2.0
+ box_coder {
+ faster_rcnn_box_coder {
+ }
+ }
+ matcher {
+ argmax_matcher {
+ }
+ }
+ similarity_calculator {
+ iou_similarity {
+ }
+ }
+ anchor_generator {
+ ssd_anchor_generator {
+ aspect_ratios: 1.0
+ }
+ }
+ image_resizer {
+ fixed_shape_resizer {
+ height: 320
+ width: 320
+ }
+ }
+ box_predictor {
+ convolutional_box_predictor {
+ conv_hyperparams {
+ regularizer {
+ l2_regularizer {
+ }
+ }
+ initializer {
+ truncated_normal_initializer {
+ }
+ }
+ }
+ }
+ }
+ normalize_loc_loss_by_codesize: true
+ loss {
+ classification_loss {
+ weighted_softmax {
+ }
+ }
+ localization_loss {
+ weighted_smooth_l1 {
+ }
+ }
+ }
+ }
+ }"""
+
+ pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
+ text_format.Merge(model_text_proto, pipeline_config)
+
+ configs = {}
+ configs['model'] = pipeline_config.model
+ configs['lstm_model'] = pipeline_config.Extensions[
+ internal_pipeline_pb2.lstm_model]
+
+ return configs
+
+ def test_model_creation_from_valid_configs(self):
+ configs = self.get_model_configs_from_proto()
+ # Test model properties.
+ self.assertEqual(configs['model'].ssd.negative_class_weight, 2.0)
+ self.assertTrue(configs['model'].ssd.normalize_loc_loss_by_codesize)
+ self.assertEqual(configs['model'].ssd.feature_extractor.type,
+ 'lstm_ssd_mobilenet_v1')
+
+ model = self.create_train_model(configs['model'], configs['lstm_model'])
+    # Test architecture type.
+ self.assertIsInstance(model, lstm_ssd_meta_arch.LSTMSSDMetaArch)
+ # Test LSTM unroll length.
+ self.assertEqual(model.unroll_length, 4)
+
+ model = self.create_eval_model(configs['model'], configs['lstm_model'])
+    # Test architecture type.
+ self.assertIsInstance(model, lstm_ssd_meta_arch.LSTMSSDMetaArch)
+ # Test LSTM configs.
+ self.assertEqual(model.unroll_length, 4)
+
+ def test_interleaved_model_creation_from_valid_configs(self):
+ configs = self.get_interleaved_model_configs_from_proto()
+ # Test model properties.
+ self.assertEqual(configs['model'].ssd.negative_class_weight, 2.0)
+ self.assertTrue(configs['model'].ssd.normalize_loc_loss_by_codesize)
+ self.assertEqual(configs['model'].ssd.feature_extractor.type,
+ 'lstm_ssd_interleaved_mobilenet_v2')
+
+ model = self.create_train_model(configs['model'], configs['lstm_model'])
+    # Test architecture type.
+ self.assertIsInstance(model, lstm_ssd_meta_arch.LSTMSSDMetaArch)
+ # Test LSTM configs.
+ self.assertEqual(model.unroll_length, 4)
+ self.assertEqual(model._feature_extractor.lstm_state_depth, 320)
+ self.assertAllClose(model._feature_extractor.depth_multipliers, (1.4, 0.35))
+ self.assertTrue(model._feature_extractor.pre_bottleneck)
+ self.assertTrue(model._feature_extractor.low_res)
+ self.assertEqual(model._feature_extractor.interleave_method,
+ 'RANDOM_SKIP_SMALL')
+
+ model = self.create_eval_model(configs['model'], configs['lstm_model'])
+    # Test architecture type.
+ self.assertIsInstance(model, lstm_ssd_meta_arch.LSTMSSDMetaArch)
+ # Test LSTM configs.
+ self.assertEqual(model.unroll_length, 10)
+ self.assertEqual(model._feature_extractor.lstm_state_depth, 320)
+ self.assertAllClose(model._feature_extractor.depth_multipliers, (1.4, 0.35))
+ self.assertTrue(model._feature_extractor.pre_bottleneck)
+ self.assertTrue(model._feature_extractor.low_res)
+ self.assertEqual(model._feature_extractor.interleave_method, 'SKIP3')
+
+ def test_model_creation_from_invalid_configs(self):
+ configs = self.get_model_configs_from_proto()
+ # Test model build failure with wrong input configs.
+ with self.assertRaises(AttributeError):
+ _ = self.create_train_model(configs['model'], configs['model'])
+ with self.assertRaises(AttributeError):
+ _ = self.create_eval_model(configs['model'], configs['model'])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/lstm_object_detection/models/__init__.py b/models/research/lstm_object_detection/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/lstm_object_detection/models/lstm_ssd_interleaved_mobilenet_v2_feature_extractor.py b/models/research/lstm_object_detection/models/lstm_ssd_interleaved_mobilenet_v2_feature_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a2d4bd0bdceb39801b46b864f512273ae10f8bc
--- /dev/null
+++ b/models/research/lstm_object_detection/models/lstm_ssd_interleaved_mobilenet_v2_feature_extractor.py
@@ -0,0 +1,298 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""LSTDInterleavedFeatureExtractor which interleaves multiple MobileNet V2."""
+
+import tensorflow.compat.v1 as tf
+import tf_slim as slim
+
+from tensorflow.python.framework import ops as tf_ops
+from lstm_object_detection.lstm import lstm_cells
+from lstm_object_detection.lstm import rnn_decoder
+from lstm_object_detection.meta_architectures import lstm_ssd_meta_arch
+from lstm_object_detection.models import mobilenet_defs
+from object_detection.models import feature_map_generators
+from object_detection.utils import ops
+from object_detection.utils import shape_utils
+from nets.mobilenet import mobilenet
+from nets.mobilenet import mobilenet_v2
+
+
+class LSTMSSDInterleavedMobilenetV2FeatureExtractor(
+ lstm_ssd_meta_arch.LSTMSSDInterleavedFeatureExtractor):
+ """LSTM-SSD Interleaved Feature Extractor using MobilenetV2 features."""
+
+ def __init__(self,
+ is_training,
+ depth_multiplier,
+ min_depth,
+ pad_to_multiple,
+ conv_hyperparams_fn,
+ reuse_weights=None,
+ use_explicit_padding=False,
+ use_depthwise=True,
+ override_base_feature_extractor_hyperparams=False):
+ """Interleaved Feature Extractor for LSTD Models with MobileNet v2.
+
+ Args:
+ is_training: whether the network is in training mode.
+ depth_multiplier: float depth multiplier for feature extractor.
+ min_depth: minimum feature extractor depth.
+ pad_to_multiple: the nearest multiple to zero pad the input height and
+ width dimensions to.
+ conv_hyperparams_fn: A function to construct tf slim arg_scope for conv2d
+ and separable_conv2d ops in the layers that are added on top of the
+ base feature extractor.
+ reuse_weights: Whether to reuse variables. Default is None.
+ use_explicit_padding: Whether to use explicit padding when extracting
+ features. Default is False.
+ use_depthwise: Whether to use depthwise convolutions. Default is True.
+ override_base_feature_extractor_hyperparams: Whether to override
+ hyperparameters of the base feature extractor with the one from
+ `conv_hyperparams_fn`.
+ """
+ super(LSTMSSDInterleavedMobilenetV2FeatureExtractor, self).__init__(
+ is_training=is_training,
+ depth_multiplier=depth_multiplier,
+ min_depth=min_depth,
+ pad_to_multiple=pad_to_multiple,
+ conv_hyperparams_fn=conv_hyperparams_fn,
+ reuse_weights=reuse_weights,
+ use_explicit_padding=use_explicit_padding,
+ use_depthwise=use_depthwise,
+ override_base_feature_extractor_hyperparams=
+ override_base_feature_extractor_hyperparams)
+ # RANDOM_SKIP_SMALL means the training policy is random and the small model
+ # does not update state during training.
+ if self._is_training:
+ self._interleave_method = 'RANDOM_SKIP_SMALL'
+ else:
+ self._interleave_method = 'SKIP9'
+
+ self._flatten_state = False
+ self._scale_state = False
+ self._clip_state = True
+ self._pre_bottleneck = True
+ self._feature_map_layout = {
+ 'from_layer': ['layer_19', '', '', '', ''],
+ 'layer_depth': [-1, 256, 256, 256, 256],
+ 'use_depthwise': self._use_depthwise,
+ 'use_explicit_padding': self._use_explicit_padding,
+ }
+ self._low_res = True
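+    # With low_res enabled, the small network runs on inputs downsampled by a
+    # factor of two (see extract_base_features_small below).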
+ self._base_network_scope = 'MobilenetV2'
+
+ def extract_base_features_large(self, preprocessed_inputs):
+ """Extract the large base model features.
+
+ Variables are created under the scope of /MobilenetV2_1/
+
+ Args:
+ preprocessed_inputs: preprocessed input images of shape:
+ [batch, width, height, depth].
+
+ Returns:
+ net: the last feature map created from the base feature extractor.
+ end_points: a dictionary of feature maps created.
+ """
+ scope_name = self._base_network_scope + '_1'
+ with tf.variable_scope(scope_name, reuse=self._reuse_weights) as base_scope:
+ net, end_points = mobilenet_v2.mobilenet_base(
+ preprocessed_inputs,
+ depth_multiplier=self._depth_multipliers[0],
+ conv_defs=mobilenet_defs.mobilenet_v2_lite_def(
+ is_quantized=self._is_quantized),
+ use_explicit_padding=self._use_explicit_padding,
+ scope=base_scope)
+ return net, end_points
+
+ def extract_base_features_small(self, preprocessed_inputs):
+ """Extract the small base model features.
+
+ Variables are created under the scope of /MobilenetV2_2/
+
+ Args:
+ preprocessed_inputs: preprocessed input images of shape:
+ [batch, width, height, depth].
+
+ Returns:
+ net: the last feature map created from the base feature extractor.
+ end_points: a dictionary of feature maps created.
+ """
+ scope_name = self._base_network_scope + '_2'
+ with tf.variable_scope(scope_name, reuse=self._reuse_weights) as base_scope:
+ if self._low_res:
+ height_small = preprocessed_inputs.get_shape().as_list()[1] // 2
+ width_small = preprocessed_inputs.get_shape().as_list()[2] // 2
+ inputs_small = tf.image.resize_images(preprocessed_inputs,
+ [height_small, width_small])
+ # Create end point handle for tflite deployment.
+ with tf.name_scope(None):
+ inputs_small = tf.identity(
+ inputs_small, name='normalized_input_image_tensor_small')
+ else:
+ inputs_small = preprocessed_inputs
+ net, end_points = mobilenet_v2.mobilenet_base(
+ inputs_small,
+ depth_multiplier=self._depth_multipliers[1],
+ conv_defs=mobilenet_defs.mobilenet_v2_lite_def(
+ is_quantized=self._is_quantized, low_res=self._low_res),
+ use_explicit_padding=self._use_explicit_padding,
+ scope=base_scope)
+ return net, end_points
+
+ def create_lstm_cell(self, batch_size, output_size, state_saver, state_name,
+ dtype=tf.float32):
+ """Create the LSTM cell, and initialize state if necessary.
+
+ Args:
+ batch_size: input batch size.
+ output_size: output size of the lstm cell, [width, height].
+ state_saver: a state saver object with methods `state` and `save_state`.
+ state_name: string, the name to use with the state_saver.
+ dtype: dtype to initialize lstm state.
+
+ Returns:
+ lstm_cell: the lstm cell unit.
+ init_state: initial state representations.
+      step: the step tensor tracking the current frame, or None if no state
+        saver is used.
+ """
+ lstm_cell = lstm_cells.GroupedConvLSTMCell(
+ filter_size=(3, 3),
+ output_size=output_size,
+ num_units=max(self._min_depth, self._lstm_state_depth),
+ is_training=self._is_training,
+ activation=tf.nn.relu6,
+ flatten_state=self._flatten_state,
+ scale_state=self._scale_state,
+ clip_state=self._clip_state,
+ output_bottleneck=True,
+ pre_bottleneck=self._pre_bottleneck,
+ is_quantized=self._is_quantized,
+ visualize_gates=False)
+
+ if state_saver is None:
+ init_state = lstm_cell.init_state('lstm_state', batch_size, dtype)
+ step = None
+ else:
+ step = state_saver.state(state_name + '_step')
+ c = state_saver.state(state_name + '_c')
+ h = state_saver.state(state_name + '_h')
+ c.set_shape([batch_size] + c.get_shape().as_list()[1:])
+ h.set_shape([batch_size] + h.get_shape().as_list()[1:])
+ init_state = (c, h)
+ return lstm_cell, init_state, step
+
+ def extract_features(self, preprocessed_inputs, state_saver=None,
+ state_name='lstm_state', unroll_length=10, scope=None):
+ """Extract features from preprocessed inputs.
+
+ The features include the base network features, lstm features and SSD
+ features, organized in the following name scope:
+
+ /MobilenetV2_1/...
+ /MobilenetV2_2/...
+ /LSTM/...
+ /FeatureMap/...
+
+ Args:
+ preprocessed_inputs: a [batch, height, width, channels] float tensor
+ representing a batch of consecutive frames from video clips.
+ state_saver: A state saver object with methods `state` and `save_state`.
+ state_name: Python string, the name to use with the state_saver.
+ unroll_length: number of steps to unroll the lstm.
+ scope: Scope for the base network of the feature extractor.
+
+ Returns:
+ feature_maps: a list of tensors where the ith tensor has shape
+ [batch, height_i, width_i, depth_i]
+ Raises:
+ ValueError: if interleave_method not recognized or large and small base
+ network output feature maps of different sizes.
+ """
+ preprocessed_inputs = shape_utils.check_min_image_dim(
+ 33, preprocessed_inputs)
+ preprocessed_inputs = ops.pad_to_multiple(
+ preprocessed_inputs, self._pad_to_multiple)
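+    # Inputs pack unroll_length consecutive frames per clip along the batch
+    # axis, so the per-clip batch size is the total batch size divided by
+    # unroll_length.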
+ batch_size = preprocessed_inputs.shape[0].value // unroll_length
+ batch_axis = 0
+ nets = []
+
+ # Batch processing of mobilenet features.
+ with slim.arg_scope(mobilenet_v2.training_scope(
+ is_training=self._is_training,
+ bn_decay=0.9997)), \
+ slim.arg_scope([mobilenet.depth_multiplier],
+ min_depth=self._min_depth, divisible_by=8):
+ # Big model.
+ net, _ = self.extract_base_features_large(preprocessed_inputs)
+ nets.append(net)
+ large_base_feature_shape = net.shape
+
+      # Small model.
+ net, _ = self.extract_base_features_small(preprocessed_inputs)
+ nets.append(net)
+ small_base_feature_shape = net.shape
+ if not (large_base_feature_shape[1] == small_base_feature_shape[1] and
+ large_base_feature_shape[2] == small_base_feature_shape[2]):
+ raise ValueError('Large and Small base network feature map dimension '
+ 'not equal!')
+
+ with slim.arg_scope(self._conv_hyperparams_fn()):
+ with tf.variable_scope('LSTM', reuse=self._reuse_weights):
+ output_size = (large_base_feature_shape[1], large_base_feature_shape[2])
+ lstm_cell, init_state, step = self.create_lstm_cell(
+ batch_size, output_size, state_saver, state_name,
+ dtype=preprocessed_inputs.dtype)
+
+ nets_seq = [
+ tf.split(net, unroll_length, axis=batch_axis) for net in nets
+ ]
+
+ net_seq, states_out = rnn_decoder.multi_input_rnn_decoder(
+ nets_seq,
+ init_state,
+ lstm_cell,
+ step,
+ selection_strategy=self._interleave_method,
+ is_training=self._is_training,
+ is_quantized=self._is_quantized,
+ pre_bottleneck=self._pre_bottleneck,
+ flatten_state=self._flatten_state,
+ scope=None)
+ self._states_out = states_out
+
+ image_features = {}
+ if state_saver is not None:
+ self._step = state_saver.state(state_name + '_step')
+ batcher_ops = [
+ state_saver.save_state(state_name + '_c', states_out[-1][0]),
+ state_saver.save_state(state_name + '_h', states_out[-1][1]),
+ state_saver.save_state(state_name + '_step', self._step + 1)]
+ with tf_ops.control_dependencies(batcher_ops):
+ image_features['layer_19'] = tf.concat(net_seq, 0)
+ else:
+ image_features['layer_19'] = tf.concat(net_seq, 0)
+
+ # SSD layers.
+ with tf.variable_scope('FeatureMap'):
+ feature_maps = feature_map_generators.multi_resolution_feature_maps(
+ feature_map_layout=self._feature_map_layout,
+ depth_multiplier=self._depth_multiplier,
+ min_depth=self._min_depth,
+ insert_1x1_conv=True,
+ image_features=image_features,
+ pool_residual=True)
+ return list(feature_maps.values())
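
A minimal sketch of the input layout this extractor expects: frames from each video clip are stacked along the batch axis, and extract_features recovers the clip count as shape[0] // unroll_length, exactly as in the code above. The values below (num_clips, image size) are hypothetical and only illustrate the shape bookkeeping.

import numpy as np

num_clips = 2       # hypothetical number of clips processed per step
unroll_length = 10  # matches the default unroll_length of this extractor
height = width = 256

# The extractor receives [num_clips * unroll_length, H, W, 3]; frames from the
# same clip are contiguous along axis 0 and are re-split per LSTM step.
frames = np.random.rand(num_clips * unroll_length, height, width, 3)
frames = frames.astype(np.float32)
assert frames.shape[0] // unroll_length == num_clips
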
diff --git a/models/research/lstm_object_detection/models/lstm_ssd_interleaved_mobilenet_v2_feature_extractor_test.py b/models/research/lstm_object_detection/models/lstm_ssd_interleaved_mobilenet_v2_feature_extractor_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..b285f0e44417a309f54973327c16b55c1169260f
--- /dev/null
+++ b/models/research/lstm_object_detection/models/lstm_ssd_interleaved_mobilenet_v2_feature_extractor_test.py
@@ -0,0 +1,352 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for lstm_ssd_interleaved_mobilenet_v2_feature_extractor."""
+
+import numpy as np
+import tensorflow.compat.v1 as tf
+import tf_slim as slim
+from tensorflow.contrib import training as contrib_training
+
+from lstm_object_detection.models import lstm_ssd_interleaved_mobilenet_v2_feature_extractor
+from object_detection.models import ssd_feature_extractor_test
+
+
+class LSTMSSDInterleavedMobilenetV2FeatureExtractorTest(
+ ssd_feature_extractor_test.SsdFeatureExtractorTestBase):
+
+ def _create_feature_extractor(self,
+ depth_multiplier,
+ pad_to_multiple,
+ is_quantized=False):
+ """Constructs a new feature extractor.
+
+ Args:
+ depth_multiplier: float depth multiplier for feature extractor
+ pad_to_multiple: the nearest multiple to zero pad the input height and
+ width dimensions to.
+ is_quantized: whether to quantize the graph.
+ Returns:
+ an ssd_meta_arch.SSDFeatureExtractor object.
+ """
+ min_depth = 32
+ def conv_hyperparams_fn():
+ with slim.arg_scope([slim.conv2d], normalizer_fn=slim.batch_norm), \
+ slim.arg_scope([slim.batch_norm], is_training=False) as sc:
+ return sc
+ feature_extractor = (
+ lstm_ssd_interleaved_mobilenet_v2_feature_extractor
+ .LSTMSSDInterleavedMobilenetV2FeatureExtractor(False, depth_multiplier,
+ min_depth,
+ pad_to_multiple,
+ conv_hyperparams_fn))
+ feature_extractor.lstm_state_depth = int(320 * depth_multiplier)
+ feature_extractor.depth_multipliers = [
+ depth_multiplier, depth_multiplier / 4.0
+ ]
+ feature_extractor.is_quantized = is_quantized
+ return feature_extractor
+
+ def test_feature_extractor_construct_with_expected_params(self):
+ def conv_hyperparams_fn():
+ with (slim.arg_scope([slim.conv2d], normalizer_fn=slim.batch_norm) and
+ slim.arg_scope([slim.batch_norm], decay=0.97, epsilon=1e-3)) as sc:
+ return sc
+
+ params = {
+ 'is_training': True,
+ 'depth_multiplier': .55,
+ 'min_depth': 9,
+ 'pad_to_multiple': 3,
+ 'conv_hyperparams_fn': conv_hyperparams_fn,
+ 'reuse_weights': False,
+ 'use_explicit_padding': True,
+ 'use_depthwise': False,
+ 'override_base_feature_extractor_hyperparams': True}
+
+ feature_extractor = (
+ lstm_ssd_interleaved_mobilenet_v2_feature_extractor
+ .LSTMSSDInterleavedMobilenetV2FeatureExtractor(**params))
+
+ self.assertEqual(params['is_training'],
+ feature_extractor._is_training)
+ self.assertEqual(params['depth_multiplier'],
+ feature_extractor._depth_multiplier)
+ self.assertEqual(params['min_depth'],
+ feature_extractor._min_depth)
+ self.assertEqual(params['pad_to_multiple'],
+ feature_extractor._pad_to_multiple)
+ self.assertEqual(params['conv_hyperparams_fn'],
+ feature_extractor._conv_hyperparams_fn)
+ self.assertEqual(params['reuse_weights'],
+ feature_extractor._reuse_weights)
+ self.assertEqual(params['use_explicit_padding'],
+ feature_extractor._use_explicit_padding)
+ self.assertEqual(params['use_depthwise'],
+ feature_extractor._use_depthwise)
+ self.assertEqual(params['override_base_feature_extractor_hyperparams'],
+ (feature_extractor.
+ _override_base_feature_extractor_hyperparams))
+
+ def test_extract_features_returns_correct_shapes_128(self):
+ image_height = 128
+ image_width = 128
+ depth_multiplier = 1.0
+ pad_to_multiple = 1
+ expected_feature_map_shape = [(2, 4, 4, 640),
+ (2, 2, 2, 256), (2, 1, 1, 256),
+ (2, 1, 1, 256), (2, 1, 1, 256)]
+ self.check_extract_features_returns_correct_shape(
+ 2, image_height, image_width, depth_multiplier, pad_to_multiple,
+ expected_feature_map_shape)
+
+ def test_extract_features_returns_correct_shapes_unroll10(self):
+ image_height = 128
+ image_width = 128
+ depth_multiplier = 1.0
+ pad_to_multiple = 1
+ expected_feature_map_shape = [(10, 4, 4, 640),
+ (10, 2, 2, 256), (10, 1, 1, 256),
+ (10, 1, 1, 256), (10, 1, 1, 256)]
+ self.check_extract_features_returns_correct_shape(
+ 10, image_height, image_width, depth_multiplier, pad_to_multiple,
+ expected_feature_map_shape, unroll_length=10)
+
+ def test_extract_features_returns_correct_shapes_320(self):
+ image_height = 320
+ image_width = 320
+ depth_multiplier = 1.0
+ pad_to_multiple = 1
+ expected_feature_map_shape = [(2, 10, 10, 640),
+ (2, 5, 5, 256), (2, 3, 3, 256),
+ (2, 2, 2, 256), (2, 1, 1, 256)]
+ self.check_extract_features_returns_correct_shape(
+ 2, image_height, image_width, depth_multiplier, pad_to_multiple,
+ expected_feature_map_shape)
+
+ def test_extract_features_returns_correct_shapes_enforcing_min_depth(self):
+ image_height = 320
+ image_width = 320
+ depth_multiplier = 0.5**12
+ pad_to_multiple = 1
+ expected_feature_map_shape = [(2, 10, 10, 64),
+ (2, 5, 5, 32), (2, 3, 3, 32),
+ (2, 2, 2, 32), (2, 1, 1, 32)]
+ self.check_extract_features_returns_correct_shape(
+ 2, image_height, image_width, depth_multiplier, pad_to_multiple,
+ expected_feature_map_shape)
+
+ def test_extract_features_returns_correct_shapes_with_pad_to_multiple(self):
+ image_height = 299
+ image_width = 299
+ depth_multiplier = 1.0
+ pad_to_multiple = 32
+ expected_feature_map_shape = [(2, 10, 10, 640),
+ (2, 5, 5, 256), (2, 3, 3, 256),
+ (2, 2, 2, 256), (2, 1, 1, 256)]
+ self.check_extract_features_returns_correct_shape(
+ 2, image_height, image_width, depth_multiplier, pad_to_multiple,
+ expected_feature_map_shape)
+
+ def test_preprocess_returns_correct_value_range(self):
+ image_height = 128
+ image_width = 128
+ depth_multiplier = 1
+ pad_to_multiple = 1
+ test_image = np.random.rand(4, image_height, image_width, 3)
+ feature_extractor = self._create_feature_extractor(depth_multiplier,
+ pad_to_multiple)
+ preprocessed_image = feature_extractor.preprocess(test_image)
+ self.assertTrue(np.all(np.less_equal(np.abs(preprocessed_image), 1.0)))
+
+ def test_variables_only_created_in_scope(self):
+ depth_multiplier = 1
+ pad_to_multiple = 1
+ scope_names = ['MobilenetV2', 'LSTM', 'FeatureMap']
+ self.check_feature_extractor_variables_under_scopes(
+ depth_multiplier, pad_to_multiple, scope_names)
+
+ def test_has_fused_batchnorm(self):
+ image_height = 40
+ image_width = 40
+ depth_multiplier = 1
+ pad_to_multiple = 32
+ image_placeholder = tf.placeholder(tf.float32,
+ [1, image_height, image_width, 3])
+ feature_extractor = self._create_feature_extractor(depth_multiplier,
+ pad_to_multiple)
+ preprocessed_image = feature_extractor.preprocess(image_placeholder)
+ _ = feature_extractor.extract_features(preprocessed_image, unroll_length=1)
+ self.assertTrue(any(op.type.startswith('FusedBatchNorm')
+ for op in tf.get_default_graph().get_operations()))
+
+ def test_variables_for_tflite(self):
+ image_height = 40
+ image_width = 40
+ depth_multiplier = 1
+ pad_to_multiple = 32
+ image_placeholder = tf.placeholder(tf.float32,
+ [1, image_height, image_width, 3])
+ feature_extractor = self._create_feature_extractor(depth_multiplier,
+ pad_to_multiple)
+ preprocessed_image = feature_extractor.preprocess(image_placeholder)
+ tflite_unsupported = ['SquaredDifference']
+ _ = feature_extractor.extract_features(preprocessed_image, unroll_length=1)
+ self.assertFalse(any(op.type in tflite_unsupported
+ for op in tf.get_default_graph().get_operations()))
+
+ def test_output_nodes_for_tflite(self):
+ image_height = 64
+ image_width = 64
+ depth_multiplier = 1.0
+ pad_to_multiple = 1
+ image_placeholder = tf.placeholder(tf.float32,
+ [1, image_height, image_width, 3])
+ feature_extractor = self._create_feature_extractor(depth_multiplier,
+ pad_to_multiple)
+ preprocessed_image = feature_extractor.preprocess(image_placeholder)
+ _ = feature_extractor.extract_features(preprocessed_image, unroll_length=1)
+
+ tflite_nodes = [
+ 'raw_inputs/init_lstm_c',
+ 'raw_inputs/init_lstm_h',
+ 'raw_inputs/base_endpoint',
+ 'raw_outputs/lstm_c',
+ 'raw_outputs/lstm_h',
+ 'raw_outputs/base_endpoint_1',
+ 'raw_outputs/base_endpoint_2'
+ ]
+ ops_names = [op.name for op in tf.get_default_graph().get_operations()]
+ for node in tflite_nodes:
+ self.assertTrue(any(node in s for s in ops_names))
+
+ def test_fixed_concat_nodes(self):
+ image_height = 64
+ image_width = 64
+ depth_multiplier = 1.0
+ pad_to_multiple = 1
+ image_placeholder = tf.placeholder(tf.float32,
+ [1, image_height, image_width, 3])
+ feature_extractor = self._create_feature_extractor(
+ depth_multiplier, pad_to_multiple, is_quantized=True)
+ preprocessed_image = feature_extractor.preprocess(image_placeholder)
+ _ = feature_extractor.extract_features(preprocessed_image, unroll_length=1)
+
+ concat_nodes = [
+ 'MobilenetV2_1/expanded_conv_16/project/Relu6',
+ 'MobilenetV2_2/expanded_conv_16/project/Relu6'
+ ]
+ ops_names = [op.name for op in tf.get_default_graph().get_operations()]
+ for node in concat_nodes:
+ self.assertTrue(any(node in s for s in ops_names))
+
+ def test_lstm_states(self):
+ image_height = 256
+ image_width = 256
+ depth_multiplier = 1
+ pad_to_multiple = 1
+ state_channel = 320
+ init_state1 = {
+ 'lstm_state_c': tf.zeros(
+ [image_height // 32, image_width // 32, state_channel]),
+ 'lstm_state_h': tf.zeros(
+ [image_height // 32, image_width // 32, state_channel]),
+ 'lstm_state_step': tf.zeros([1])
+ }
+ init_state2 = {
+ 'lstm_state_c': tf.random_uniform(
+ [image_height // 32, image_width // 32, state_channel]),
+ 'lstm_state_h': tf.random_uniform(
+ [image_height // 32, image_width // 32, state_channel]),
+ 'lstm_state_step': tf.zeros([1])
+ }
+ seq = {'dummy': tf.random_uniform([2, 1, 1, 1])}
+ stateful_reader1 = contrib_training.SequenceQueueingStateSaver(
+ batch_size=1,
+ num_unroll=1,
+ input_length=2,
+ input_key='',
+ input_sequences=seq,
+ input_context={},
+ initial_states=init_state1,
+ capacity=1)
+ stateful_reader2 = contrib_training.SequenceQueueingStateSaver(
+ batch_size=1,
+ num_unroll=1,
+ input_length=2,
+ input_key='',
+ input_sequences=seq,
+ input_context={},
+ initial_states=init_state2,
+ capacity=1)
+ image = tf.random_uniform([1, image_height, image_width, 3])
+ feature_extractor = self._create_feature_extractor(depth_multiplier,
+ pad_to_multiple)
+ with tf.variable_scope('zero_state'):
+ feature_maps1 = feature_extractor.extract_features(
+ image, stateful_reader1.next_batch, unroll_length=1)
+ with tf.variable_scope('random_state'):
+ feature_maps2 = feature_extractor.extract_features(
+ image, stateful_reader2.next_batch, unroll_length=1)
+ with tf.Session() as sess:
+ sess.run(tf.global_variables_initializer())
+ sess.run(tf.local_variables_initializer())
+ sess.run(tf.get_collection(tf.GraphKeys.TABLE_INITIALIZERS))
+ sess.run([stateful_reader1.prefetch_op, stateful_reader2.prefetch_op])
+ maps1, maps2 = sess.run([feature_maps1, feature_maps2])
+ state = sess.run(stateful_reader1.next_batch.state('lstm_state_c'))
+ # feature maps should be different because states are different
+ self.assertFalse(np.all(np.equal(maps1[0], maps2[0])))
+ # state should no longer be zero after update
+ self.assertTrue(state.any())
+
+ def check_extract_features_returns_correct_shape(
+ self, batch_size, image_height, image_width, depth_multiplier,
+ pad_to_multiple, expected_feature_map_shapes, unroll_length=1):
+ def graph_fn(image_tensor):
+ feature_extractor = self._create_feature_extractor(depth_multiplier,
+ pad_to_multiple)
+ feature_maps = feature_extractor.extract_features(
+ image_tensor, unroll_length=unroll_length)
+ return feature_maps
+
+ image_tensor = np.random.rand(batch_size, image_height, image_width,
+ 3).astype(np.float32)
+ feature_maps = self.execute(graph_fn, [image_tensor])
+ for feature_map, expected_shape in zip(
+ feature_maps, expected_feature_map_shapes):
+ self.assertAllEqual(feature_map.shape, expected_shape)
+
+ def check_feature_extractor_variables_under_scopes(
+ self, depth_multiplier, pad_to_multiple, scope_names):
+ g = tf.Graph()
+ with g.as_default():
+ feature_extractor = self._create_feature_extractor(
+ depth_multiplier, pad_to_multiple)
+ preprocessed_inputs = tf.placeholder(tf.float32, (4, 320, 320, 3))
+ feature_extractor.extract_features(
+ preprocessed_inputs, unroll_length=1)
+ variables = g.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
+ for variable in variables:
+ self.assertTrue(
+ any([
+ variable.name.startswith(scope_name)
+ for scope_name in scope_names
+ ]), 'Variable name: ' + variable.name +
+ ' is not under any provided scopes: ' + ','.join(scope_names))
+
+
+if __name__ == '__main__':
+ tf.test.main()
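
A back-of-the-envelope check of the 128x128 expected shapes in the tests above, assuming the base network has an overall stride of 32 and each extra SSD feature map halves the spatial size (rounding up). This only sketches the arithmetic; it is not how the test computes shapes.

import math

size = 128 // 32        # base feature map is 4x4 for a 128x128 input
spatial = [size]
for _ in range(4):      # four additional SSD feature maps
  size = max(1, math.ceil(size / 2))
  spatial.append(size)
print(spatial)          # [4, 2, 1, 1, 1], matching the expected shapes
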
diff --git a/models/research/lstm_object_detection/models/lstm_ssd_mobilenet_v1_feature_extractor.py b/models/research/lstm_object_detection/models/lstm_ssd_mobilenet_v1_feature_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..cccf740aadd337d29bec56a7fed93fc6937fc123
--- /dev/null
+++ b/models/research/lstm_object_detection/models/lstm_ssd_mobilenet_v1_feature_extractor.py
@@ -0,0 +1,211 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""LSTMSSDFeatureExtractor for MobilenetV1 features."""
+
+import tensorflow.compat.v1 as tf
+import tf_slim as slim
+from tensorflow.python.framework import ops as tf_ops
+from lstm_object_detection.lstm import lstm_cells
+from lstm_object_detection.lstm import rnn_decoder
+from lstm_object_detection.meta_architectures import lstm_ssd_meta_arch
+from object_detection.models import feature_map_generators
+from object_detection.utils import context_manager
+from object_detection.utils import ops
+from object_detection.utils import shape_utils
+from nets import mobilenet_v1
+
+
+class LSTMSSDMobileNetV1FeatureExtractor(
+ lstm_ssd_meta_arch.LSTMSSDFeatureExtractor):
+ """LSTM Feature Extractor using MobilenetV1 features."""
+
+ def __init__(self,
+ is_training,
+ depth_multiplier,
+ min_depth,
+ pad_to_multiple,
+ conv_hyperparams_fn,
+ reuse_weights=None,
+ use_explicit_padding=False,
+ use_depthwise=True,
+ override_base_feature_extractor_hyperparams=False,
+ lstm_state_depth=256):
+ """Initializes instance of MobileNetV1 Feature Extractor for LSTMSSD Models.
+
+ Args:
+ is_training: A boolean whether the network is in training mode.
+ depth_multiplier: A float depth multiplier for feature extractor.
+ min_depth: A number representing minimum feature extractor depth.
+ pad_to_multiple: The nearest multiple to zero pad the input height and
+ width dimensions to.
+ conv_hyperparams_fn: A function to construct tf slim arg_scope for conv2d
+ and separable_conv2d ops in the layers that are added on top of the
+ base feature extractor.
+ reuse_weights: Whether to reuse variables. Default is None.
+ use_explicit_padding: Whether to use explicit padding when extracting
+ features. Default is False.
+ use_depthwise: Whether to use depthwise convolutions. Default is True.
+ override_base_feature_extractor_hyperparams: Whether to override
+ hyperparameters of the base feature extractor with the one from
+ `conv_hyperparams_fn`.
+      lstm_state_depth: An integer representing the depth of the lstm state.
+ """
+ super(LSTMSSDMobileNetV1FeatureExtractor, self).__init__(
+ is_training=is_training,
+ depth_multiplier=depth_multiplier,
+ min_depth=min_depth,
+ pad_to_multiple=pad_to_multiple,
+ conv_hyperparams_fn=conv_hyperparams_fn,
+ reuse_weights=reuse_weights,
+ use_explicit_padding=use_explicit_padding,
+ use_depthwise=use_depthwise,
+ override_base_feature_extractor_hyperparams=
+ override_base_feature_extractor_hyperparams)
+ self._feature_map_layout = {
+ 'from_layer': ['Conv2d_13_pointwise_lstm', '', '', '', ''],
+ 'layer_depth': [-1, 512, 256, 256, 128],
+ 'use_explicit_padding': self._use_explicit_padding,
+ 'use_depthwise': self._use_depthwise,
+ }
+ self._base_network_scope = 'MobilenetV1'
+ self._lstm_state_depth = lstm_state_depth
+
+ def create_lstm_cell(self, batch_size, output_size, state_saver, state_name,
+ dtype=tf.float32):
+ """Create the LSTM cell, and initialize state if necessary.
+
+ Args:
+ batch_size: input batch size.
+ output_size: output size of the lstm cell, [width, height].
+ state_saver: a state saver object with methods `state` and `save_state`.
+ state_name: string, the name to use with the state_saver.
+ dtype: dtype to initialize lstm state.
+
+ Returns:
+ lstm_cell: the lstm cell unit.
+ init_state: initial state representations.
+      step: the current step (a tensor from the state saver), or None if no
+        state saver is provided.
+ """
+ lstm_cell = lstm_cells.BottleneckConvLSTMCell(
+ filter_size=(3, 3),
+ output_size=output_size,
+ num_units=max(self._min_depth, self._lstm_state_depth),
+ activation=tf.nn.relu6,
+ visualize_gates=False)
+
+ if state_saver is None:
+ init_state = lstm_cell.init_state(state_name, batch_size, dtype)
+ step = None
+ else:
+ step = state_saver.state(state_name + '_step')
+ c = state_saver.state(state_name + '_c')
+ h = state_saver.state(state_name + '_h')
+ init_state = (c, h)
+ return lstm_cell, init_state, step
+
+ def extract_features(self,
+ preprocessed_inputs,
+ state_saver=None,
+ state_name='lstm_state',
+ unroll_length=5,
+ scope=None):
+ """Extracts features from preprocessed inputs.
+
+ The features include the base network features, lstm features and SSD
+ features, organized in the following name scope:
+
+    <parent scope>/MobilenetV1/...
+    <parent scope>/LSTM/...
+    <parent scope>/FeatureMaps/...
+
+ Args:
+ preprocessed_inputs: A [batch, height, width, channels] float tensor
+ representing a batch of consecutive frames from video clips.
+ state_saver: A state saver object with methods `state` and `save_state`.
+ state_name: A python string for the name to use with the state_saver.
+ unroll_length: The number of steps to unroll the lstm.
+ scope: The scope for the base network of the feature extractor.
+
+ Returns:
+ A list of tensors where the ith tensor has shape [batch, height_i,
+ width_i, depth_i]
+ """
+ preprocessed_inputs = shape_utils.check_min_image_dim(
+ 33, preprocessed_inputs)
+ with slim.arg_scope(
+ mobilenet_v1.mobilenet_v1_arg_scope(is_training=self._is_training)):
+ with (slim.arg_scope(self._conv_hyperparams_fn())
+ if self._override_base_feature_extractor_hyperparams else
+ context_manager.IdentityContextManager()):
+ with slim.arg_scope([slim.batch_norm], fused=False):
+ # Base network.
+ with tf.variable_scope(
+ scope, self._base_network_scope,
+ reuse=self._reuse_weights) as scope:
+ net, image_features = mobilenet_v1.mobilenet_v1_base(
+ ops.pad_to_multiple(preprocessed_inputs, self._pad_to_multiple),
+ final_endpoint='Conv2d_13_pointwise',
+ min_depth=self._min_depth,
+ depth_multiplier=self._depth_multiplier,
+ scope=scope)
+
+ with slim.arg_scope(self._conv_hyperparams_fn()):
+ with slim.arg_scope(
+ [slim.batch_norm], fused=False, is_training=self._is_training):
+ # ConvLSTM layers.
+ batch_size = net.shape[0].value // unroll_length
+ with tf.variable_scope('LSTM', reuse=self._reuse_weights) as lstm_scope:
+ lstm_cell, init_state, _ = self.create_lstm_cell(
+ batch_size,
+ (net.shape[1].value, net.shape[2].value),
+ state_saver,
+ state_name,
+ dtype=preprocessed_inputs.dtype)
+ net_seq = list(tf.split(net, unroll_length))
+
+          # Identities added for inputting state tensors externally.
+ c_ident = tf.identity(init_state[0], name='lstm_state_in_c')
+ h_ident = tf.identity(init_state[1], name='lstm_state_in_h')
+ init_state = (c_ident, h_ident)
+
+ net_seq, states_out = rnn_decoder.rnn_decoder(
+ net_seq, init_state, lstm_cell, scope=lstm_scope)
+ batcher_ops = None
+ self._states_out = states_out
+ if state_saver is not None:
+ self._step = state_saver.state('%s_step' % state_name)
+ batcher_ops = [
+ state_saver.save_state('%s_c' % state_name, states_out[-1][0]),
+ state_saver.save_state('%s_h' % state_name, states_out[-1][1]),
+ state_saver.save_state('%s_step' % state_name, self._step + 1)
+ ]
+ with tf_ops.control_dependencies(batcher_ops):
+ image_features['Conv2d_13_pointwise_lstm'] = tf.concat(net_seq, 0)
+
+ # Identities added for reading output states, to be reused externally.
+ tf.identity(states_out[-1][0], name='lstm_state_out_c')
+ tf.identity(states_out[-1][1], name='lstm_state_out_h')
+
+ # SSD layers.
+ with tf.variable_scope('FeatureMaps', reuse=self._reuse_weights):
+ feature_maps = feature_map_generators.multi_resolution_feature_maps(
+ feature_map_layout=self._feature_map_layout,
+ depth_multiplier=(self._depth_multiplier),
+ min_depth=self._min_depth,
+ insert_1x1_conv=True,
+ image_features=image_features)
+
+ return list(feature_maps.values())
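
The identity ops added above (lstm_state_in_c/h and lstm_state_out_c/h) exist so that LSTM state can be fed and fetched externally. Below is a rough sketch of how that might be used with an exported graph; the tensor names and the image placeholder name are assumptions that depend on the enclosing name scopes of the exported graph, so treat this as illustrative only.

import numpy as np

def run_clip(sess, frames, feature_tensor, init_state,
             state_in=('LSTM/lstm_state_in_c:0', 'LSTM/lstm_state_in_h:0'),
             state_out=('LSTM/lstm_state_out_c:0', 'LSTM/lstm_state_out_h:0'),
             image_tensor='image_tensor:0'):
  """Runs one frame at a time, threading LSTM state through feed_dict."""
  c, h = init_state
  features = []
  for frame in frames:
    feats, c, h = sess.run(
        [feature_tensor, state_out[0], state_out[1]],
        feed_dict={image_tensor: frame[np.newaxis],
                   state_in[0]: c,
                   state_in[1]: h})
    features.append(feats)
  return features, (c, h)
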
diff --git a/models/research/lstm_object_detection/models/lstm_ssd_mobilenet_v1_feature_extractor_test.py b/models/research/lstm_object_detection/models/lstm_ssd_mobilenet_v1_feature_extractor_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..56ad2745dae558acdb806c8f236d25754799cf49
--- /dev/null
+++ b/models/research/lstm_object_detection/models/lstm_ssd_mobilenet_v1_feature_extractor_test.py
@@ -0,0 +1,179 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for models.lstm_ssd_mobilenet_v1_feature_extractor."""
+
+import numpy as np
+import tensorflow.compat.v1 as tf
+import tf_slim as slim
+from tensorflow.contrib import training as contrib_training
+
+from lstm_object_detection.models import lstm_ssd_mobilenet_v1_feature_extractor as feature_extractor
+from object_detection.models import ssd_feature_extractor_test
+
+
+class LstmSsdMobilenetV1FeatureExtractorTest(
+ ssd_feature_extractor_test.SsdFeatureExtractorTestBase):
+
+ def _create_feature_extractor(self,
+ depth_multiplier=1.0,
+ pad_to_multiple=1,
+ is_training=True,
+ use_explicit_padding=False):
+ """Constructs a new feature extractor.
+
+ Args:
+ depth_multiplier: A float depth multiplier for feature extractor.
+ pad_to_multiple: The nearest multiple to zero pad the input height and
+ width dimensions to.
+ is_training: A boolean whether the network is in training mode.
+ use_explicit_padding: A boolean whether to use explicit padding.
+
+ Returns:
+ An lstm_ssd_meta_arch.LSTMSSDMobileNetV1FeatureExtractor object.
+ """
+ min_depth = 32
+ extractor = (
+ feature_extractor.LSTMSSDMobileNetV1FeatureExtractor(
+ is_training,
+ depth_multiplier,
+ min_depth,
+ pad_to_multiple,
+ self.conv_hyperparams_fn,
+ use_explicit_padding=use_explicit_padding))
+ extractor.lstm_state_depth = int(256 * depth_multiplier)
+ return extractor
+
+ def test_feature_extractor_construct_with_expected_params(self):
+ def conv_hyperparams_fn():
+ with (slim.arg_scope([slim.conv2d], normalizer_fn=slim.batch_norm) and
+ slim.arg_scope([slim.batch_norm], decay=0.97, epsilon=1e-3)) as sc:
+ return sc
+
+ params = {
+ 'is_training': True,
+ 'depth_multiplier': .55,
+ 'min_depth': 9,
+ 'pad_to_multiple': 3,
+ 'conv_hyperparams_fn': conv_hyperparams_fn,
+ 'reuse_weights': False,
+ 'use_explicit_padding': True,
+ 'use_depthwise': False,
+ 'override_base_feature_extractor_hyperparams': True}
+
+ extractor = (
+ feature_extractor.LSTMSSDMobileNetV1FeatureExtractor(**params))
+
+ self.assertEqual(params['is_training'],
+ extractor._is_training)
+ self.assertEqual(params['depth_multiplier'],
+ extractor._depth_multiplier)
+ self.assertEqual(params['min_depth'],
+ extractor._min_depth)
+ self.assertEqual(params['pad_to_multiple'],
+ extractor._pad_to_multiple)
+ self.assertEqual(params['conv_hyperparams_fn'],
+ extractor._conv_hyperparams_fn)
+ self.assertEqual(params['reuse_weights'],
+ extractor._reuse_weights)
+ self.assertEqual(params['use_explicit_padding'],
+ extractor._use_explicit_padding)
+ self.assertEqual(params['use_depthwise'],
+ extractor._use_depthwise)
+ self.assertEqual(params['override_base_feature_extractor_hyperparams'],
+ (extractor.
+ _override_base_feature_extractor_hyperparams))
+
+ def test_extract_features_returns_correct_shapes_256(self):
+ image_height = 256
+ image_width = 256
+ depth_multiplier = 1.0
+ pad_to_multiple = 1
+ batch_size = 5
+ expected_feature_map_shape = [(batch_size, 8, 8, 256), (batch_size, 4, 4,
+ 512),
+ (batch_size, 2, 2, 256), (batch_size, 1, 1,
+ 256)]
+ self.check_extract_features_returns_correct_shape(
+ batch_size,
+ image_height,
+ image_width,
+ depth_multiplier,
+ pad_to_multiple,
+ expected_feature_map_shape,
+ use_explicit_padding=False)
+ self.check_extract_features_returns_correct_shape(
+ batch_size,
+ image_height,
+ image_width,
+ depth_multiplier,
+ pad_to_multiple,
+ expected_feature_map_shape,
+ use_explicit_padding=True)
+
+ def test_preprocess_returns_correct_value_range(self):
+ test_image = np.random.rand(5, 128, 128, 3)
+ extractor = self._create_feature_extractor()
+ preprocessed_image = extractor.preprocess(test_image)
+ self.assertTrue(np.all(np.less_equal(np.abs(preprocessed_image), 1.0)))
+
+ def test_variables_only_created_in_scope(self):
+ scope_name = 'MobilenetV1'
+ g = tf.Graph()
+ with g.as_default():
+ preprocessed_inputs = tf.placeholder(tf.float32, (5, 256, 256, 3))
+ extractor = self._create_feature_extractor()
+ extractor.extract_features(preprocessed_inputs)
+ variables = g.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
+ find_scope = False
+ for variable in variables:
+ if scope_name in variable.name:
+ find_scope = True
+ break
+ self.assertTrue(find_scope)
+
+ def test_lstm_non_zero_state(self):
+ init_state = {
+ 'lstm_state_c': tf.zeros([8, 8, 256]),
+ 'lstm_state_h': tf.zeros([8, 8, 256]),
+ 'lstm_state_step': tf.zeros([1])
+ }
+ seq = {'test': tf.random_uniform([3, 1, 1, 1])}
+ stateful_reader = contrib_training.SequenceQueueingStateSaver(
+ batch_size=1,
+ num_unroll=1,
+ input_length=2,
+ input_key='',
+ input_sequences=seq,
+ input_context={},
+ initial_states=init_state,
+ capacity=1)
+ extractor = self._create_feature_extractor()
+ image = tf.random_uniform([5, 256, 256, 3])
+ with tf.variable_scope('zero_state'):
+ feature_map = extractor.extract_features(
+ image, stateful_reader.next_batch)
+ with tf.Session() as sess:
+ sess.run(tf.global_variables_initializer())
+ sess.run([stateful_reader.prefetch_op])
+ _ = sess.run([feature_map])
+ # Update states with the next batch.
+ state = sess.run(stateful_reader.next_batch.state('lstm_state_c'))
+ # State should no longer be zero after update.
+ self.assertTrue(state.any())
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/lstm_object_detection/models/mobilenet_defs.py b/models/research/lstm_object_detection/models/mobilenet_defs.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f984240215b818c3e8c9b5481db3319b54ef8fd
--- /dev/null
+++ b/models/research/lstm_object_detection/models/mobilenet_defs.py
@@ -0,0 +1,142 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Definitions for modified MobileNet models used in LSTD."""
+
+import tensorflow.compat.v1 as tf
+import tf_slim as slim
+from nets import mobilenet_v1
+from nets.mobilenet import conv_blocks as mobilenet_convs
+from nets.mobilenet import mobilenet
+
+
+def mobilenet_v1_lite_def(depth_multiplier, low_res=False):
+ """Conv definitions for a lite MobileNet v1 model.
+
+ Args:
+ depth_multiplier: float depth multiplier for MobileNet.
+ low_res: An option of low-res conv input for interleave model.
+
+ Returns:
+ Array of convolutions.
+
+ Raises:
+ ValueError: On invalid channels with provided depth multiplier.
+ """
+ conv = mobilenet_v1.Conv
+ sep_conv = mobilenet_v1.DepthSepConv
+
+ def _find_target_depth(original, depth_multiplier):
+ # Find the target depth such that:
+ # int(target * depth_multiplier) == original
+ pseudo_target = int(original / depth_multiplier)
+ for target in range(pseudo_target - 1, pseudo_target + 2):
+ if int(target * depth_multiplier) == original:
+ return target
+ raise ValueError('Cannot have %d channels with depth multiplier %0.2f' %
+ (original, depth_multiplier))
+
+ return [
+ conv(kernel=[3, 3], stride=2, depth=32),
+ sep_conv(kernel=[3, 3], stride=1, depth=64),
+ sep_conv(kernel=[3, 3], stride=2, depth=128),
+ sep_conv(kernel=[3, 3], stride=1, depth=128),
+ sep_conv(kernel=[3, 3], stride=2, depth=256),
+ sep_conv(kernel=[3, 3], stride=1, depth=256),
+ sep_conv(kernel=[3, 3], stride=2, depth=512),
+ sep_conv(kernel=[3, 3], stride=1, depth=512),
+ sep_conv(kernel=[3, 3], stride=1, depth=512),
+ sep_conv(kernel=[3, 3], stride=1, depth=512),
+ sep_conv(kernel=[3, 3], stride=1, depth=512),
+ sep_conv(kernel=[3, 3], stride=1, depth=512),
+ sep_conv(kernel=[3, 3], stride=1 if low_res else 2, depth=1024),
+ sep_conv(
+ kernel=[3, 3],
+ stride=1,
+ depth=int(_find_target_depth(1024, depth_multiplier)))
+ ]
+
+
+def mobilenet_v2_lite_def(reduced=False, is_quantized=False, low_res=False):
+ """Conv definitions for a lite MobileNet v2 model.
+
+ Args:
+ reduced: Determines the scaling factor for expanded conv. If True, a factor
+ of 6 is used. If False, a factor of 3 is used.
+ is_quantized: Whether the model is trained in quantized mode.
+ low_res: Whether the input to the model is of half resolution.
+
+ Returns:
+ Array of convolutions.
+ """
+ expanded_conv = mobilenet_convs.expanded_conv
+ expand_input = mobilenet_convs.expand_input_by_factor
+ op = mobilenet.op
+ return dict(
+ defaults={
+          # Note: these parameters of batch norm affect the architecture;
+          # that's why they are here and not in training_scope.
+ (slim.batch_norm,): {
+ 'center': True,
+ 'scale': True
+ },
+ (slim.conv2d, slim.fully_connected, slim.separable_conv2d): {
+ 'normalizer_fn': slim.batch_norm,
+ 'activation_fn': tf.nn.relu6
+ },
+ (expanded_conv,): {
+ 'expansion_size': expand_input(6),
+ 'split_expansion': 1,
+ 'normalizer_fn': slim.batch_norm,
+ 'residual': True
+ },
+ (slim.conv2d, slim.separable_conv2d): {
+ 'padding': 'SAME'
+ }
+ },
+ spec=[
+ op(slim.conv2d, stride=2, num_outputs=32, kernel_size=[3, 3]),
+ op(expanded_conv,
+ expansion_size=expand_input(1, divisible_by=1),
+ num_outputs=16),
+ op(expanded_conv,
+ expansion_size=(expand_input(3, divisible_by=1)
+ if reduced else expand_input(6)),
+ stride=2,
+ num_outputs=24),
+ op(expanded_conv,
+ expansion_size=(expand_input(3, divisible_by=1)
+ if reduced else expand_input(6)),
+ stride=1,
+ num_outputs=24),
+ op(expanded_conv, stride=2, num_outputs=32),
+ op(expanded_conv, stride=1, num_outputs=32),
+ op(expanded_conv, stride=1, num_outputs=32),
+ op(expanded_conv, stride=2, num_outputs=64),
+ op(expanded_conv, stride=1, num_outputs=64),
+ op(expanded_conv, stride=1, num_outputs=64),
+ op(expanded_conv, stride=1, num_outputs=64),
+ op(expanded_conv, stride=1, num_outputs=96),
+ op(expanded_conv, stride=1, num_outputs=96),
+ op(expanded_conv, stride=1, num_outputs=96),
+ op(expanded_conv, stride=1 if low_res else 2, num_outputs=160),
+ op(expanded_conv, stride=1, num_outputs=160),
+ op(expanded_conv, stride=1, num_outputs=160),
+ op(expanded_conv,
+ stride=1,
+ num_outputs=320,
+ project_activation_fn=(tf.nn.relu6
+ if is_quantized else tf.identity))
+ ],
+ )
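
To make the rounding logic in _find_target_depth concrete, here is a standalone restatement with a few worked cases; the helper mirrors the nested function above and the sample values are illustrative.

def find_target_depth(original, depth_multiplier):
  # Find target such that int(target * depth_multiplier) == original.
  pseudo_target = int(original / depth_multiplier)
  for target in range(pseudo_target - 1, pseudo_target + 2):
    if int(target * depth_multiplier) == original:
      return target
  raise ValueError('Cannot have %d channels with depth multiplier %0.2f' %
                   (original, depth_multiplier))

assert find_target_depth(1024, 1.0) == 1024
assert find_target_depth(1024, 0.5) == 2048   # int(2048 * 0.5) == 1024
assert find_target_depth(1024, 2.0) == 512    # int(512 * 2.0) == 1024
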
diff --git a/models/research/lstm_object_detection/models/mobilenet_defs_test.py b/models/research/lstm_object_detection/models/mobilenet_defs_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1b5bda504bb02ac89f55e3acd370862f513a3a3
--- /dev/null
+++ b/models/research/lstm_object_detection/models/mobilenet_defs_test.py
@@ -0,0 +1,136 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for lstm_object_detection.models.mobilenet_defs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.compat.v1 as tf
+from lstm_object_detection.models import mobilenet_defs
+from nets import mobilenet_v1
+from nets.mobilenet import mobilenet_v2
+
+
+class MobilenetV1DefsTest(tf.test.TestCase):
+
+ def test_mobilenet_v1_lite_def(self):
+ net, _ = mobilenet_v1.mobilenet_v1_base(
+ tf.placeholder(tf.float32, (10, 320, 320, 3)),
+ final_endpoint='Conv2d_13_pointwise',
+ min_depth=8,
+ depth_multiplier=1.0,
+ conv_defs=mobilenet_defs.mobilenet_v1_lite_def(1.0),
+ use_explicit_padding=True,
+ scope='MobilenetV1')
+ self.assertEqual(net.get_shape().as_list(), [10, 10, 10, 1024])
+
+ def test_mobilenet_v1_lite_def_depthmultiplier_half(self):
+ net, _ = mobilenet_v1.mobilenet_v1_base(
+ tf.placeholder(tf.float32, (10, 320, 320, 3)),
+ final_endpoint='Conv2d_13_pointwise',
+ min_depth=8,
+ depth_multiplier=0.5,
+ conv_defs=mobilenet_defs.mobilenet_v1_lite_def(0.5),
+ use_explicit_padding=True,
+ scope='MobilenetV1')
+ self.assertEqual(net.get_shape().as_list(), [10, 10, 10, 1024])
+
+ def test_mobilenet_v1_lite_def_depthmultiplier_2x(self):
+ net, _ = mobilenet_v1.mobilenet_v1_base(
+ tf.placeholder(tf.float32, (10, 320, 320, 3)),
+ final_endpoint='Conv2d_13_pointwise',
+ min_depth=8,
+ depth_multiplier=2.0,
+ conv_defs=mobilenet_defs.mobilenet_v1_lite_def(2.0),
+ use_explicit_padding=True,
+ scope='MobilenetV1')
+ self.assertEqual(net.get_shape().as_list(), [10, 10, 10, 1024])
+
+ def test_mobilenet_v1_lite_def_low_res(self):
+ net, _ = mobilenet_v1.mobilenet_v1_base(
+ tf.placeholder(tf.float32, (10, 320, 320, 3)),
+ final_endpoint='Conv2d_13_pointwise',
+ min_depth=8,
+ depth_multiplier=1.0,
+ conv_defs=mobilenet_defs.mobilenet_v1_lite_def(1.0, low_res=True),
+ use_explicit_padding=True,
+ scope='MobilenetV1')
+ self.assertEqual(net.get_shape().as_list(), [10, 20, 20, 1024])
+
+
+class MobilenetV2DefsTest(tf.test.TestCase):
+
+ def test_mobilenet_v2_lite_def(self):
+ net, features = mobilenet_v2.mobilenet_base(
+ tf.placeholder(tf.float32, (10, 320, 320, 3)),
+ min_depth=8,
+ depth_multiplier=1.0,
+ conv_defs=mobilenet_defs.mobilenet_v2_lite_def(),
+ use_explicit_padding=True,
+ scope='MobilenetV2')
+ self.assertEqual(net.get_shape().as_list(), [10, 10, 10, 320])
+ self._assert_contains_op('MobilenetV2/expanded_conv_16/project/Identity')
+ self.assertEqual(
+ features['layer_3/expansion_output'].get_shape().as_list(),
+ [10, 160, 160, 96])
+ self.assertEqual(
+ features['layer_4/expansion_output'].get_shape().as_list(),
+ [10, 80, 80, 144])
+
+ def test_mobilenet_v2_lite_def_is_quantized(self):
+ net, _ = mobilenet_v2.mobilenet_base(
+ tf.placeholder(tf.float32, (10, 320, 320, 3)),
+ min_depth=8,
+ depth_multiplier=1.0,
+ conv_defs=mobilenet_defs.mobilenet_v2_lite_def(is_quantized=True),
+ use_explicit_padding=True,
+ scope='MobilenetV2')
+ self.assertEqual(net.get_shape().as_list(), [10, 10, 10, 320])
+ self._assert_contains_op('MobilenetV2/expanded_conv_16/project/Relu6')
+
+ def test_mobilenet_v2_lite_def_low_res(self):
+ net, _ = mobilenet_v2.mobilenet_base(
+ tf.placeholder(tf.float32, (10, 320, 320, 3)),
+ min_depth=8,
+ depth_multiplier=1.0,
+ conv_defs=mobilenet_defs.mobilenet_v2_lite_def(low_res=True),
+ use_explicit_padding=True,
+ scope='MobilenetV2')
+ self.assertEqual(net.get_shape().as_list(), [10, 20, 20, 320])
+
+ def test_mobilenet_v2_lite_def_reduced(self):
+ net, features = mobilenet_v2.mobilenet_base(
+ tf.placeholder(tf.float32, (10, 320, 320, 3)),
+ min_depth=8,
+ depth_multiplier=1.0,
+ conv_defs=mobilenet_defs.mobilenet_v2_lite_def(reduced=True),
+ use_explicit_padding=True,
+ scope='MobilenetV2')
+ self.assertEqual(net.get_shape().as_list(), [10, 10, 10, 320])
+ self.assertEqual(
+ features['layer_3/expansion_output'].get_shape().as_list(),
+ [10, 160, 160, 48])
+ self.assertEqual(
+ features['layer_4/expansion_output'].get_shape().as_list(),
+ [10, 80, 80, 72])
+
+ def _assert_contains_op(self, op_name):
+ op_names = [op.name for op in tf.get_default_graph().get_operations()]
+ self.assertIn(op_name, op_names)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/models/research/lstm_object_detection/protos/__init__.py b/models/research/lstm_object_detection/protos/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/research/lstm_object_detection/protos/input_reader_google.proto b/models/research/lstm_object_detection/protos/input_reader_google.proto
new file mode 100644
index 0000000000000000000000000000000000000000..2c494a62e97321ee9206cebe28cd6601049f3293
--- /dev/null
+++ b/models/research/lstm_object_detection/protos/input_reader_google.proto
@@ -0,0 +1,32 @@
+syntax = "proto2";
+
+package lstm_object_detection.protos;
+
+import "object_detection/protos/input_reader.proto";
+
+message GoogleInputReader {
+ extend object_detection.protos.ExternalInputReader {
+ optional GoogleInputReader google_input_reader = 444;
+ }
+
+ oneof input_reader {
+ TFRecordVideoInputReader tf_record_video_input_reader = 1;
+ }
+}
+
+message TFRecordVideoInputReader {
+ // Path(s) to tfrecords of input data.
+ repeated string input_path = 1;
+
+ enum DataType {
+ UNSPECIFIED = 0;
+ TF_EXAMPLE = 1;
+ TF_SEQUENCE_EXAMPLE = 2;
+ }
+ optional DataType data_type = 2 [default=TF_SEQUENCE_EXAMPLE];
+
+  // Length of the video sequence. All input video sequences should have the
+  // same length in frames, e.g. 5 frames.
+ optional int32 video_length = 3;
+}
+
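A hedged sketch of how a video input reader defined by this proto might be built from the text format in Python; the generated module path (input_reader_google_pb2) is an assumption about how these protos get compiled, not something specified here.

from google.protobuf import text_format
# Assumed generated module path; the real import depends on the proto build.
from lstm_object_detection.protos import input_reader_google_pb2

reader = input_reader_google_pb2.TFRecordVideoInputReader()
text_format.Merge("""
  input_path: "/tmp/train_sequences.tfrecord"
  data_type: TF_SEQUENCE_EXAMPLE
  video_length: 5
""", reader)
print(reader.video_length)  # 5
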
diff --git a/models/research/lstm_object_detection/protos/pipeline.proto b/models/research/lstm_object_detection/protos/pipeline.proto
new file mode 100644
index 0000000000000000000000000000000000000000..10dd652554ad38e933acdedf8ce1479f15eed9d7
--- /dev/null
+++ b/models/research/lstm_object_detection/protos/pipeline.proto
@@ -0,0 +1,69 @@
+syntax = "proto2";
+
+package lstm_object_detection.protos;
+
+import "object_detection/protos/pipeline.proto";
+import "lstm_object_detection/protos/quant_overrides.proto";
+
+extend object_detection.protos.TrainEvalPipelineConfig {
+ optional LstmModel lstm_model = 205743444;
+ optional QuantOverrides quant_overrides = 246059837;
+}
+
+// Message for extra fields needed for configuring LSTM model.
+message LstmModel {
+ // Unroll length for training LSTMs.
+ optional int32 train_unroll_length = 1;
+
+ // Unroll length for evaluating LSTMs.
+ optional int32 eval_unroll_length = 2;
+
+ // Depth of the lstm feature map.
+ optional int32 lstm_state_depth = 3 [default = 256];
+
+ // Depth multipliers for multiple feature extractors. Used for interleaved
+ // or ensemble model.
+ repeated float depth_multipliers = 4;
+
+ // Specifies how models are interleaved when multiple feature extractors are
+ // used during training. Must be in ['RANDOM', 'RANDOM_SKIP_SMALL'].
+ optional string train_interleave_method = 5 [default = 'RANDOM'];
+
+ // Specifies how models are interleaved when multiple feature extractors are
+ // used during training. Must be in ['RANDOM', 'RANDOM_SKIP', 'SKIPK'].
+  // used during evaluation. Must be in ['RANDOM', 'RANDOM_SKIP', 'SKIPK'].
+
+ // The stride of the lstm state.
+ optional int32 lstm_state_stride = 7 [default = 32];
+
+  // Whether to flatten LSTM state and output. Note that this is typically
+ // intended only to be modified internally by export_tfmini_lstd_graph_lib
+ // to support flatten state for tfmini/tflite. Do not set this field in
+ // the pipeline config file unless necessary.
+ optional bool flatten_state = 8 [default = false];
+
+ // Whether to apply bottleneck layer before going into LSTM gates. This
+ // allows multiple feature extractors to use separate bottleneck layers
+ // instead of sharing the same one so that different base model output
+ // feature dimensions are not forced to be the same.
+ // For example:
+ // Model 1 outputs feature map f_1 of depth d_1.
+ // Model 2 outputs feature map f_2 of depth d_2.
+ // Pre-bottlenecking allows lstm input to be either:
+ // conv(concat([f_1, h])) or conv(concat([f_2, h])).
+ optional bool pre_bottleneck = 9 [default = false];
+
+ // Normalize LSTM state, default false.
+ optional bool scale_state = 10 [default = false];
+
+ // Clip LSTM state at [0, 6], default true.
+ optional bool clip_state = 11 [default = true];
+
+ // If the model is in quantized training. This field does NOT need to be set
+ // manually. Instead, it will be overridden by configs in graph_rewriter.
+ optional bool is_quantized = 12 [default = false];
+
+ // Downsample input image when using the smaller network in interleaved
+ // models, default false.
+ optional bool low_res = 13 [default = false];
+}
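
For orientation, a hypothetical example of how the lstm_model extension could appear in the text format of a pipeline config; the values are illustrative and the surrounding TrainEvalPipelineConfig fields are omitted. It is kept as a Python string purely for illustration.

LSTM_MODEL_STANZA = """
[lstm_object_detection.protos.lstm_model] {
  train_unroll_length: 4
  eval_unroll_length: 4
  lstm_state_depth: 256
  pre_bottleneck: true
}
"""
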
diff --git a/models/research/lstm_object_detection/protos/quant_overrides.proto b/models/research/lstm_object_detection/protos/quant_overrides.proto
new file mode 100644
index 0000000000000000000000000000000000000000..9dc0eaf86e5f507f87b87fe1571b4e3d82991df1
--- /dev/null
+++ b/models/research/lstm_object_detection/protos/quant_overrides.proto
@@ -0,0 +1,40 @@
+syntax = "proto2";
+
+package lstm_object_detection.protos;
+
+// Message to override default quantization behavior.
+message QuantOverrides {
+ repeated QuantConfig quant_configs = 1;
+}
+
+// Parameters to manually create fake quant ops outside of the generic
+// tensorflow/contrib/quantize/python/quantize.py script. This may be
+// used to override default behaviour or quantize ops not already supported.
+message QuantConfig {
+ // The name of the op to add a fake quant op to.
+ required string op_name = 1;
+
+ // The name of the fake quant op.
+ required string quant_op_name = 2;
+
+ // Whether the fake quant op uses fixed ranges. Otherwise, learned moving
+ // average ranges are used.
+ required bool fixed_range = 3 [default = false];
+
+  // The initial minimum value of the range.
+ optional float min = 4 [default = -6];
+
+ // The initial maximum value of the range.
+ optional float max = 5 [default = 6];
+
+ // Number of steps to delay before quantization takes effect during training.
+ optional int32 delay = 6 [default = 500000];
+
+ // Number of bits to use for quantizing weights.
+ // Only 8 bit is supported for now.
+ optional int32 weight_bits = 7 [default = 8];
+
+ // Number of bits to use for quantizing activations.
+ // Only 8 bit is supported for now.
+ optional int32 activation_bits = 8 [default = 8];
+}
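
A small arithmetic sketch of the default fixed range above: with 8-bit quantization over [-6, 6], the quantization step works out to 12 / 255 ≈ 0.047, i.e. roughly 21 levels per unit of activation. The helper below is illustrative, not part of the quantization rewriter.

def quant_step(min_value=-6.0, max_value=6.0, bits=8):
  """Step size of an affine quantizer over [min_value, max_value]."""
  return (max_value - min_value) / (2 ** bits - 1)

print(quant_step())  # ~0.0471
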
diff --git a/models/research/lstm_object_detection/test_tflite_model.py b/models/research/lstm_object_detection/test_tflite_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8b5e15e210ab6c191911d3c440cef33d936274c
--- /dev/null
+++ b/models/research/lstm_object_detection/test_tflite_model.py
@@ -0,0 +1,53 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Test a tflite model using random input data."""
+
+from __future__ import print_function
+from absl import flags
+import numpy as np
+import tensorflow.compat.v1 as tf
+
+flags.DEFINE_string('model_path', None, 'Path to model.')
+FLAGS = flags.FLAGS
+
+
+def main(_):
+
+ flags.mark_flag_as_required('model_path')
+
+ # Load TFLite model and allocate tensors.
+ interpreter = tf.lite.Interpreter(model_path=FLAGS.model_path)
+ interpreter.allocate_tensors()
+
+ # Get input and output tensors.
+ input_details = interpreter.get_input_details()
+ print('input_details:', input_details)
+ output_details = interpreter.get_output_details()
+ print('output_details:', output_details)
+
+ # Test model on random input data.
+ input_shape = input_details[0]['shape']
+  # Change the following line to feed in your own data.
+ input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
+ interpreter.set_tensor(input_details[0]['index'], input_data)
+
+ interpreter.invoke()
+ output_data = interpreter.get_tensor(output_details[0]['index'])
+ print(output_data)
+
+
+if __name__ == '__main__':
+ tf.app.run()
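
Example invocation of the script above, with the model path as a placeholder:

python test_tflite_model.py --model_path=/tmp/lstd_model.tflite
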
diff --git a/models/research/lstm_object_detection/tflite/BUILD b/models/research/lstm_object_detection/tflite/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..66068925da4fde7eb99215d907d627e0ff1d3847
--- /dev/null
+++ b/models/research/lstm_object_detection/tflite/BUILD
@@ -0,0 +1,81 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"])
+
+cc_library(
+ name = "mobile_ssd_client",
+ srcs = ["mobile_ssd_client.cc"],
+ hdrs = ["mobile_ssd_client.h"],
+ deps = [
+ "//protos:box_encodings_cc_proto",
+ "//protos:detections_cc_proto",
+ "//protos:labelmap_cc_proto",
+ "//protos:mobile_ssd_client_options_cc_proto",
+ "//utils:conversion_utils",
+ "//utils:ssd_utils",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
+ "@com_google_glog//:glog",
+ "@gemmlowp",
+ ],
+)
+
+config_setting(
+ name = "enable_edgetpu",
+ define_values = {"enable_edgetpu": "true"},
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "mobile_ssd_tflite_client",
+ srcs = ["mobile_ssd_tflite_client.cc"],
+ hdrs = ["mobile_ssd_tflite_client.h"],
+ defines = select({
+ "//conditions:default": [],
+ "enable_edgetpu": ["ENABLE_EDGETPU"],
+ }),
+ deps = [
+ ":mobile_ssd_client",
+ "@com_google_glog//:glog",
+ "@com_google_absl//absl/memory",
+ "@org_tensorflow//tensorflow/lite:arena_planner",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate",
+ "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
+ "//protos:anchor_generation_options_cc_proto",
+ "//utils:file_utils",
+ "//utils:ssd_utils",
+ ] + select({
+ "//conditions:default": [],
+ "enable_edgetpu": [
+ "@libedgetpu//libedgetpu:header",
+ ],
+ }),
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "mobile_lstd_tflite_client",
+ srcs = ["mobile_lstd_tflite_client.cc"],
+ hdrs = ["mobile_lstd_tflite_client.h"],
+ defines = select({
+ "//conditions:default": [],
+ "enable_edgetpu": ["ENABLE_EDGETPU"],
+ }),
+ deps = [
+ ":mobile_ssd_client",
+ ":mobile_ssd_tflite_client",
+ "@com_google_glog//:glog",
+ "@com_google_absl//absl/base:core_headers",
+ "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
+ ] + select({
+ "//conditions:default": [],
+ "enable_edgetpu": [
+ "@libedgetpu//libedgetpu:header",
+ ],
+ }),
+ alwayslink = 1,
+)
diff --git a/models/research/lstm_object_detection/tflite/WORKSPACE b/models/research/lstm_object_detection/tflite/WORKSPACE
new file mode 100644
index 0000000000000000000000000000000000000000..3bce3814f365ec2bcc1122d7dfc8a5ba5f7d3dcb
--- /dev/null
+++ b/models/research/lstm_object_detection/tflite/WORKSPACE
@@ -0,0 +1,133 @@
+workspace(name = "lstm_object_detection")
+
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
+
+http_archive(
+ name = "bazel_skylib",
+ sha256 = "bbccf674aa441c266df9894182d80de104cabd19be98be002f6d478aaa31574d",
+ strip_prefix = "bazel-skylib-2169ae1c374aab4a09aa90e65efe1a3aad4e279b",
+ urls = ["https://github.com/bazelbuild/bazel-skylib/archive/2169ae1c374aab4a09aa90e65efe1a3aad4e279b.tar.gz"],
+)
+load("@bazel_skylib//lib:versions.bzl", "versions")
+versions.check(minimum_bazel_version = "0.23.0")
+
+# ABSL cpp library.
+http_archive(
+ name = "com_google_absl",
+ urls = [
+ "https://github.com/abseil/abseil-cpp/archive/a02f62f456f2c4a7ecf2be3104fe0c6e16fbad9a.tar.gz",
+ ],
+ sha256 = "d437920d1434c766d22e85773b899c77c672b8b4865d5dc2cd61a29fdff3cf03",
+ strip_prefix = "abseil-cpp-a02f62f456f2c4a7ecf2be3104fe0c6e16fbad9a",
+)
+
+http_archive(
+ name = "rules_cc",
+ strip_prefix = "rules_cc-master",
+ urls = ["https://github.com/bazelbuild/rules_cc/archive/master.zip"],
+)
+
+# GoogleTest/GoogleMock framework. Used by most unit-tests.
+http_archive(
+ name = "com_google_googletest",
+ urls = ["https://github.com/google/googletest/archive/master.zip"],
+ strip_prefix = "googletest-master",
+)
+
+# gflags needed by glog
+http_archive(
+ name = "com_github_gflags_gflags",
+ sha256 = "6e16c8bc91b1310a44f3965e616383dbda48f83e8c1eaa2370a215057b00cabe",
+ strip_prefix = "gflags-77592648e3f3be87d6c7123eb81cbad75f9aef5a",
+ urls = [
+ "https://mirror.bazel.build/github.com/gflags/gflags/archive/77592648e3f3be87d6c7123eb81cbad75f9aef5a.tar.gz",
+ "https://github.com/gflags/gflags/archive/77592648e3f3be87d6c7123eb81cbad75f9aef5a.tar.gz",
+ ],
+)
+
+# glog
+http_archive(
+ name = "com_google_glog",
+ sha256 = "f28359aeba12f30d73d9e4711ef356dc842886968112162bc73002645139c39c",
+ strip_prefix = "glog-0.4.0",
+ urls = ["https://github.com/google/glog/archive/v0.4.0.tar.gz"],
+)
+
+http_archive(
+ name = "zlib",
+ build_file = "@com_google_protobuf//:third_party/zlib.BUILD",
+ sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
+ strip_prefix = "zlib-1.2.11",
+ urls = ["https://zlib.net/zlib-1.2.11.tar.gz"],
+)
+
+http_archive(
+ name = "gemmlowp",
+ sha256 = "6678b484d929f2d0d3229d8ac4e3b815a950c86bb9f17851471d143f6d4f7834",
+ strip_prefix = "gemmlowp-12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3",
+ urls = [
+ "http://mirror.tensorflow.org/github.com/google/gemmlowp/archive/12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3.zip",
+ "https://github.com/google/gemmlowp/archive/12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3.zip",
+ ],
+)
+
+#-----------------------------------------------------------------------------
+# proto
+#-----------------------------------------------------------------------------
+# proto_library, cc_proto_library and java_proto_library rules implicitly depend
+# on @com_google_protobuf//:proto, @com_google_protobuf//:cc_toolchain and
+# @com_google_protobuf//:java_toolchain, respectively.
+# This statement defines the @com_google_protobuf repo.
+http_archive(
+ name = "com_google_protobuf",
+ strip_prefix = "protobuf-3.8.0",
+ urls = ["https://github.com/google/protobuf/archive/v3.8.0.zip"],
+ sha256 = "1e622ce4b84b88b6d2cdf1db38d1a634fe2392d74f0b7b74ff98f3a51838ee53",
+)
+
+# java_lite_proto_library rules implicitly depend on
+# @com_google_protobuf_javalite//:javalite_toolchain, which is the JavaLite proto
+# runtime (base classes and common utilities).
+http_archive(
+ name = "com_google_protobuf_javalite",
+ strip_prefix = "protobuf-384989534b2246d413dbcd750744faab2607b516",
+ urls = ["https://github.com/google/protobuf/archive/384989534b2246d413dbcd750744faab2607b516.zip"],
+ sha256 = "79d102c61e2a479a0b7e5fc167bcfaa4832a0c6aad4a75fa7da0480564931bcc",
+)
+
+#
+# http_archive(
+# name = "com_google_protobuf",
+# strip_prefix = "protobuf-master",
+# urls = ["https://github.com/protocolbuffers/protobuf/archive/master.zip"],
+# )
+
+# Needed by TensorFlow
+http_archive(
+ name = "io_bazel_rules_closure",
+ sha256 = "e0a111000aeed2051f29fcc7a3f83be3ad8c6c93c186e64beb1ad313f0c7f9f9",
+ strip_prefix = "rules_closure-cf1e44edb908e9616030cc83d085989b8e6cd6df",
+ urls = [
+ "http://mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz",
+ "https://github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz", # 2019-04-04
+ ],
+)
+
+
+# TensorFlow r1.14-rc0
+http_archive(
+ name = "org_tensorflow",
+ strip_prefix = "tensorflow-1.14.0-rc0",
+ sha256 = "76404a6157a45e8d7a07e4f5690275256260130145924c2a7c73f6eda2a3de10",
+ urls = ["https://github.com/tensorflow/tensorflow/archive/v1.14.0-rc0.zip"],
+)
+
+load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
+tf_workspace(tf_repo_name = "org_tensorflow")
+
+git_repository(
+ name = "libedgetpu",
+ remote = "sso://coral.googlesource.com/edgetpu-native",
+ commit = "83e47d1bcf22686fae5150ebb99281f6134ef062",
+)
diff --git a/models/research/lstm_object_detection/tflite/mobile_lstd_tflite_client.cc b/models/research/lstm_object_detection/tflite/mobile_lstd_tflite_client.cc
new file mode 100644
index 0000000000000000000000000000000000000000..05a7bbac1b5c8a58c4f10476a2be4fb3a097a463
--- /dev/null
+++ b/models/research/lstm_object_detection/tflite/mobile_lstd_tflite_client.cc
@@ -0,0 +1,261 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mobile_lstd_tflite_client.h"
+
+#include <algorithm>
+#include <cstring>
+
+#include <glog/logging.h>
+#include "absl/memory/memory.h"
+
+namespace lstm_object_detection {
+namespace tflite {
+
+std::unique_ptr<MobileLSTDTfLiteClient> MobileLSTDTfLiteClient::Create() {
+  auto client = absl::make_unique<MobileLSTDTfLiteClient>();
+ if (!client->InitializeClient(CreateDefaultOptions())) {
+ LOG(ERROR) << "Failed to initialize client";
+ return nullptr;
+ }
+ return client;
+}
+
+protos::ClientOptions MobileLSTDTfLiteClient::CreateDefaultOptions() {
+ const int kMaxDetections = 100;
+ const int kClassesPerDetection = 1;
+ const double kScoreThreshold = -2.0;
+ const double kIouThreshold = 0.5;
+
+ protos::ClientOptions options;
+ options.set_max_detections(kMaxDetections);
+ options.set_max_categories(kClassesPerDetection);
+ options.set_score_threshold(kScoreThreshold);
+ options.set_iou_threshold(kIouThreshold);
+ options.set_agnostic_mode(false);
+ options.set_quantize(false);
+ options.set_num_keypoints(0);
+
+ return options;
+}
+
+std::unique_ptr<MobileLSTDTfLiteClient> MobileLSTDTfLiteClient::Create(
+ const protos::ClientOptions& options) {
+  auto client = absl::make_unique<MobileLSTDTfLiteClient>();
+ if (!client->InitializeClient(options)) {
+ LOG(ERROR) << "Failed to initialize client";
+ return nullptr;
+ }
+ return client;
+}
+
+bool MobileLSTDTfLiteClient::InitializeInterpreter(
+ const protos::ClientOptions& options) {
+ if (options.prefer_nnapi_delegate()) {
+ LOG(ERROR) << "NNAPI not supported.";
+ return false;
+ } else {
+ interpreter_->UseNNAPI(false);
+ }
+
+#ifdef ENABLE_EDGETPU
+ interpreter_->SetExternalContext(kTfLiteEdgeTpuContext,
+ edge_tpu_context_.get());
+#endif
+
+ // Inputs are: normalized_input_image_tensor, raw_inputs/init_lstm_c,
+ // raw_inputs/init_lstm_h
+ if (interpreter_->inputs().size() != 3) {
+ LOG(ERROR) << "Invalid number of interpreter inputs: " <<
+ interpreter_->inputs().size();
+ return false;
+ }
+
+  const std::vector<int> input_tensor_indices = interpreter_->inputs();
+ const TfLiteTensor& input_lstm_c =
+ *interpreter_->tensor(input_tensor_indices[1]);
+ if (input_lstm_c.dims->size != 4) {
+ LOG(ERROR) << "Invalid input lstm_c dimensions: " <<
+ input_lstm_c.dims->size;
+ return false;
+ }
+ if (input_lstm_c.dims->data[0] != 1) {
+ LOG(ERROR) << "Invalid input lstm_c batch size: " <<
+ input_lstm_c.dims->data[0];
+ return false;
+ }
+ lstm_state_width_ = input_lstm_c.dims->data[1];
+ lstm_state_height_ = input_lstm_c.dims->data[2];
+ lstm_state_depth_ = input_lstm_c.dims->data[3];
+ lstm_state_size_ = lstm_state_width_ * lstm_state_height_ * lstm_state_depth_;
+
+ const TfLiteTensor& input_lstm_h =
+ *interpreter_->tensor(input_tensor_indices[2]);
+ if (!ValidateStateTensor(input_lstm_h, "input lstm_h")) {
+ return false;
+ }
+
+ // Outputs are:
+ // TFLite_Detection_PostProcess,
+ // TFLite_Detection_PostProcess:1,
+ // TFLite_Detection_PostProcess:2,
+ // TFLite_Detection_PostProcess:3,
+ // raw_outputs/lstm_c, raw_outputs/lstm_h
+ if (interpreter_->outputs().size() != 6) {
+ LOG(ERROR) << "Invalid number of interpreter outputs: " <<
+ interpreter_->outputs().size();
+ return false;
+ }
+
+  const std::vector<int> output_tensor_indices = interpreter_->outputs();
+ const TfLiteTensor& output_lstm_c =
+ *interpreter_->tensor(output_tensor_indices[4]);
+ if (!ValidateStateTensor(output_lstm_c, "output lstm_c")) {
+ return false;
+ }
+ const TfLiteTensor& output_lstm_h =
+ *interpreter_->tensor(output_tensor_indices[5]);
+ if (!ValidateStateTensor(output_lstm_h, "output lstm_h")) {
+ return false;
+ }
+
+ // Initialize state with all zeroes.
+ lstm_c_data_.resize(lstm_state_size_);
+ lstm_h_data_.resize(lstm_state_size_);
+ lstm_c_data_uint8_.resize(lstm_state_size_);
+ lstm_h_data_uint8_.resize(lstm_state_size_);
+
+ if (interpreter_->AllocateTensors() != kTfLiteOk) {
+ LOG(ERROR) << "Failed to allocate tensors";
+ return false;
+ }
+
+ return true;
+}
+
+bool MobileLSTDTfLiteClient::ValidateStateTensor(const TfLiteTensor& tensor,
+ const std::string& name) {
+ if (tensor.dims->size != 4) {
+ LOG(ERROR) << "Invalid " << name << " dimensions: " << tensor.dims->size;
+ return false;
+ }
+ if (tensor.dims->data[0] != 1) {
+ LOG(ERROR) << "Invalid " << name << " batch size: " << tensor.dims->data[0];
+ return false;
+ }
+ if (tensor.dims->data[1] != lstm_state_width_ ||
+ tensor.dims->data[2] != lstm_state_height_ ||
+ tensor.dims->data[3] != lstm_state_depth_) {
+ LOG(ERROR) << "Invalid " << name << " dimensions: [" <<
+ tensor.dims->data[0] << ", " << tensor.dims->data[1] << ", " <<
+ tensor.dims->data[2] << ", " << tensor.dims->data[3] << "]";
+ return false;
+ }
+ return true;
+}
+
+bool MobileLSTDTfLiteClient::ComputeOutputLayerCount() {
+ // Outputs are: raw_outputs/box_encodings, raw_outputs/class_predictions,
+ // raw_outputs/lstm_c, raw_outputs/lstm_h
+ CHECK_EQ(interpreter_->outputs().size(), 4);
+ num_output_layers_ = 1;
+ return true;
+}
+
+bool MobileLSTDTfLiteClient::FloatInference(const uint8_t* input_data) {
+ // Inputs are: normalized_input_image_tensor, raw_inputs/init_lstm_c,
+ // raw_inputs/init_lstm_h
+ CHECK(input_data) << "Input data cannot be null.";
+  float* input = interpreter_->typed_input_tensor<float>(0);
+ CHECK(input) << "Input tensor cannot be null.";
+ // Normalize the uint8 input image with mean_value_, std_value_.
+ NormalizeInputImage(input_data, input);
+
+ // Copy input LSTM state into TFLite's input tensors.
+  float* lstm_c_input = interpreter_->typed_input_tensor<float>(1);
+ CHECK(lstm_c_input) << "Input lstm_c tensor cannot be null.";
+ std::copy(lstm_c_data_.begin(), lstm_c_data_.end(), lstm_c_input);
+
+  float* lstm_h_input = interpreter_->typed_input_tensor<float>(2);
+ CHECK(lstm_h_input) << "Input lstm_h tensor cannot be null.";
+ std::copy(lstm_h_data_.begin(), lstm_h_data_.end(), lstm_h_input);
+
+ // Run inference on inputs.
+ CHECK_EQ(interpreter_->Invoke(), kTfLiteOk) << "Invoking interpreter failed.";
+
+ // Copy LSTM state out of TFLite's output tensors.
+ // Outputs are: raw_outputs/box_encodings, raw_outputs/class_predictions,
+ // raw_outputs/lstm_c, raw_outputs/lstm_h
+  float* lstm_c_output = interpreter_->typed_output_tensor<float>(2);
+ CHECK(lstm_c_output) << "Output lstm_c tensor cannot be null.";
+ std::copy(lstm_c_output, lstm_c_output + lstm_state_size_,
+ lstm_c_data_.begin());
+
+  float* lstm_h_output = interpreter_->typed_output_tensor<float>(3);
+ CHECK(lstm_h_output) << "Output lstm_h tensor cannot be null.";
+ std::copy(lstm_h_output, lstm_h_output + lstm_state_size_,
+ lstm_h_data_.begin());
+ return true;
+}
+
+bool MobileLSTDTfLiteClient::QuantizedInference(const uint8_t* input_data) {
+ // Inputs are: normalized_input_image_tensor, raw_inputs/init_lstm_c,
+ // raw_inputs/init_lstm_h
+ CHECK(input_data) << "Input data cannot be null.";
+  uint8_t* input = interpreter_->typed_input_tensor<uint8_t>(0);
+ CHECK(input) << "Input tensor cannot be null.";
+ memcpy(input, input_data, input_size_);
+
+ // Copy input LSTM state into TFLite's input tensors.
+  uint8_t* lstm_c_input = interpreter_->typed_input_tensor<uint8_t>(1);
+ CHECK(lstm_c_input) << "Input lstm_c tensor cannot be null.";
+ std::copy(lstm_c_data_uint8_.begin(), lstm_c_data_uint8_.end(), lstm_c_input);
+
+  uint8_t* lstm_h_input = interpreter_->typed_input_tensor<uint8_t>(2);
+ CHECK(lstm_h_input) << "Input lstm_h tensor cannot be null.";
+ std::copy(lstm_h_data_uint8_.begin(), lstm_h_data_uint8_.end(), lstm_h_input);
+
+ // Run inference on inputs.
+ CHECK_EQ(interpreter_->Invoke(), kTfLiteOk) << "Invoking interpreter failed.";
+
+ // Copy LSTM state out of TFLite's output tensors.
+ // Outputs are:
+ // TFLite_Detection_PostProcess,
+ // TFLite_Detection_PostProcess:1,
+ // TFLite_Detection_PostProcess:2,
+ // TFLite_Detection_PostProcess:3,
+ // raw_outputs/lstm_c, raw_outputs/lstm_h
+  uint8_t* lstm_c_output = interpreter_->typed_output_tensor<uint8_t>(4);
+ CHECK(lstm_c_output) << "Output lstm_c tensor cannot be null.";
+ std::copy(lstm_c_output, lstm_c_output + lstm_state_size_,
+ lstm_c_data_uint8_.begin());
+
+  uint8_t* lstm_h_output = interpreter_->typed_output_tensor<uint8_t>(5);
+ CHECK(lstm_h_output) << "Output lstm_h tensor cannot be null.";
+ std::copy(lstm_h_output, lstm_h_output + lstm_state_size_,
+ lstm_h_data_uint8_.begin());
+ return true;
+}
+
+bool MobileLSTDTfLiteClient::Inference(const uint8_t* input_data) {
+ if (input_data == nullptr) {
+ LOG(ERROR) << "input_data cannot be null for inference.";
+ return false;
+ }
+ if (IsQuantizedModel())
+ return QuantizedInference(input_data);
+ else
+ return FloatInference(input_data);
+}
+
+} // namespace tflite
+} // namespace lstm_object_detection
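
Taken together, Create() wires up the client, InitializeInterpreter() validates the three input and six output tensors and sizes the zero-initialized LSTM c/h state buffers, and each subsequent inference copies that state into the interpreter and back out, so detections on one frame condition the next. Inference() itself is protected (see the header below) and is normally driven through the public MobileSSDClient interface, which is not part of this diff; the sketch that follows is therefore only a hedged illustration of the construction path visible here, with a score threshold adjusted relative to CreateDefaultOptions():

// Hypothetical construction sketch -- illustrative only, not part of the diff.
#include <memory>

#include "mobile_lstd_tflite_client.h"

namespace lstd = lstm_object_detection::tflite;

std::unique_ptr<lstd::MobileLSTDTfLiteClient> MakeClient() {
  // Start from the defaults shown above (100 detections, IoU 0.5,
  // score threshold -2.0) and tighten the score threshold.
  auto options = lstd::MobileLSTDTfLiteClient::CreateDefaultOptions();
  options.set_score_threshold(0.5);

  // Create() returns nullptr if InitializeClient() fails, e.g. when the
  // model does not expose the expected lstm_c/lstm_h state tensors.
  return lstd::MobileLSTDTfLiteClient::Create(options);
}
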
diff --git a/models/research/lstm_object_detection/tflite/mobile_lstd_tflite_client.h b/models/research/lstm_object_detection/tflite/mobile_lstd_tflite_client.h
new file mode 100644
index 0000000000000000000000000000000000000000..e4f16bc945a6725025e285885967637629d0a5fc
--- /dev/null
+++ b/models/research/lstm_object_detection/tflite/mobile_lstd_tflite_client.h
@@ -0,0 +1,74 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_MODELS_LSTM_OBJECT_DETECTION_TFLITE_MOBILE_LSTD_TFLITE_CLIENT_H_
+#define TENSORFLOW_MODELS_LSTM_OBJECT_DETECTION_TFLITE_MOBILE_LSTD_TFLITE_CLIENT_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "mobile_ssd_client.h"
+#include "mobile_ssd_tflite_client.h"
+
+namespace lstm_object_detection {
+namespace tflite {
+
+// Client for LSTD MobileNet TfLite model.
+class MobileLSTDTfLiteClient : public MobileSSDTfLiteClient {
+ public:
+ MobileLSTDTfLiteClient() = default;
+ // Create with default options.
+  static std::unique_ptr<MobileLSTDTfLiteClient> Create();
+  static std::unique_ptr<MobileLSTDTfLiteClient> Create(
+ const protos::ClientOptions& options);
+ ~MobileLSTDTfLiteClient() override = default;
+ static protos::ClientOptions CreateDefaultOptions();
+
+ protected:
+ bool InitializeInterpreter(const protos::ClientOptions& options) override;
+ bool ComputeOutputLayerCount() override;
+ bool Inference(const uint8_t* input_data) override;
+
+ private:
+ // MobileLSTDTfLiteClient is neither copyable nor movable.
+ MobileLSTDTfLiteClient(const MobileLSTDTfLiteClient&) = delete;
+ MobileLSTDTfLiteClient& operator=(const MobileLSTDTfLiteClient&) = delete;
+
+ bool ValidateStateTensor(const TfLiteTensor& tensor, const std::string& name);
+
+ // Helper functions used by Inference functions.
+ bool FloatInference(const uint8_t* input_data);
+ bool QuantizedInference(const uint8_t* input_data);
+
+ // LSTM model parameters.
+ int lstm_state_width_ = 0;
+ int lstm_state_height_ = 0;
+ int lstm_state_depth_ = 0;
+ int lstm_state_size_ = 0;
+
+ // LSTM state stored between float inference runs.
+  std::vector<float> lstm_c_data_;
+  std::vector<float> lstm_h_data_;
+
+ // LSTM state stored between uint8 inference runs.
+  std::vector<uint8_t> lstm_c_data_uint8_;
+  std::vector<uint8_t> lstm_h_data_uint8_;
+};
+
+} // namespace tflite
+} // namespace lstm_object_detection
+
+#endif // TENSORFLOW_MODELS_LSTM_OBJECT_DETECTION_TFLITE_MOBILE_LSTD_TFLITE_CLIENT_H_
diff --git a/models/research/lstm_object_detection/tflite/mobile_ssd_client.cc b/models/research/lstm_object_detection/tflite/mobile_ssd_client.cc
new file mode 100644
index 0000000000000000000000000000000000000000..27bf70109e46d2b9612480bb192f01aa3c9bfde1
--- /dev/null
+++ b/models/research/lstm_object_detection/tflite/mobile_ssd_client.cc
@@ -0,0 +1,209 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mobile_ssd_client.h"
+
+#include
+
+#include