How I fine-tuned T5 on the WikiSQL dataset
Thanks to Shivanandroy: his notebook https://github.com/Shivanandroy/T5-Finetuning-PyTorch/blob/main/notebook/T5_Fine_tuning_with_PyTorch.ipynb helped me understand T5 fine-tuning and adapt it to text-to-SQL.
The only major change is the Dataloader: I converted the WikiSQL data into a two-column dataset, with the inputs in a 'query' column and the targets in a 'sql' column (a conversion sketch follows the sample data below).
Sample Data
+--------------------------------------------------------------------+--------------------------------------------------------------------+
| source_text                                                        | target_text                                                        |
+--------------------------------------------------------------------+--------------------------------------------------------------------+
| What is the season year where the rank is 39?                      | SELECT tv season WHERE rank EQL 39                                 |
| What is the number of season premieres were 10.17 people watched?  | SELECT count(season premiere) WHERE viewers (millions) EQL 10.17   |
+--------------------------------------------------------------------+--------------------------------------------------------------------+
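Here is a minimal sketch of how such a two-column DataFrame can be built. It assumes the Hugging Face `datasets` copy of WikiSQL (`load_dataset("wikisql")`) and uses its `question` field plus the human-readable SQL string; the column names `query`/`sql` match what the Dataset class below expects, while the exact operator spelling in my targets (e.g. `EQL`) came from my own preprocessing, so adjust as needed.

import pandas as pd
from datasets import load_dataset

def wikisql_to_dataframe(split="train"):
    # Load one WikiSQL split from the Hugging Face Hub
    raw = load_dataset("wikisql", split=split)
    rows = []
    for example in raw:
        rows.append({
            "query": example["question"],              # plain-English question
            "sql": example["sql"]["human_readable"],   # human-readable SQL target
        })
    return pd.DataFrame(rows)

train_df = wikisql_to_dataframe("train")
print(train_df.head())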
Dataset class
import torch
from torch.utils.data import Dataset

class CSQLSetClass(Dataset):
    """
    Wraps the two-column WikiSQL dataframe so it can be loaded into a
    DataLoader and passed to the neural network for fine-tuning the model.
    """
    def __init__(self, dataframe, tokenizer, source_len, target_len, source_text, target_text):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.source_len = source_len
        self.summ_len = target_len
        # Add the task prefix to the inputs and wrap the targets first,
        # then keep references to the two columns
        self.data["query"] = "English to SQL: " + self.data["query"]
        self.data["sql"] = "<pad>" + self.data["sql"] + "</s>"
        self.source_text = self.data[source_text]
        self.target_text = self.data[target_text]

    def __len__(self):
        return len(self.target_text)

    def __getitem__(self, index):
        # Ensure the data is string-typed and collapse extra whitespace
        source_text = ' '.join(str(self.source_text[index]).split())
        target_text = ' '.join(str(self.target_text[index]).split())
        source = self.tokenizer.batch_encode_plus([source_text], max_length=self.source_len,
                                                  truncation=True, padding="max_length", return_tensors='pt')
        target = self.tokenizer.batch_encode_plus([target_text], max_length=self.summ_len,
                                                  truncation=True, padding="max_length", return_tensors='pt')
        source_ids = source['input_ids'].squeeze()
        source_mask = source['attention_mask'].squeeze()
        target_ids = target['input_ids'].squeeze()
        target_mask = target['attention_mask'].squeeze()
        return {
            'source_ids': source_ids.to(dtype=torch.long),
            'source_mask': source_mask.to(dtype=torch.long),
            'target_ids': target_ids.to(dtype=torch.long),
            'target_ids_y': target_ids.to(dtype=torch.long)
        }
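For completeness, a minimal sketch of wiring the dataset into a DataLoader and running one training pass, following the pattern of the referenced notebook. The checkpoint name (t5-small), batch size, and learning rate are illustrative assumptions, and `train_df` is the DataFrame built in the sketch above.

from torch.utils.data import DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")            # assumed base checkpoint
model = T5ForConditionalGeneration.from_pretrained("t5-small")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

train_set = CSQLSetClass(train_df, tokenizer, source_len=128, target_len=128,
                         source_text="query", target_text="sql")
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

model.train()
for batch in train_loader:
    ids = batch['source_ids'].to(device)
    mask = batch['source_mask'].to(device)
    y = batch['target_ids'].to(device)
    # The targets were built with a leading <pad>, so shift them:
    # decoder inputs drop the last token, labels drop the first
    y_ids = y[:, :-1].contiguous()
    labels = y[:, 1:].clone()
    labels[y[:, 1:] == tokenizer.pad_token_id] = -100   # ignore padding in the loss
    outputs = model(input_ids=ids, attention_mask=mask,
                    decoder_input_ids=y_ids, labels=labels)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()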
The prediction function takes a plain-English question and returns the generated SQL:
#Predict function
def get_sql(query, tokenizer, model):
    # Build the model input with the same task prefix used at training time
    source_text = "English to SQL: " + query
    source_text = ' '.join(source_text.split())
    source = tokenizer.batch_encode_plus([source_text], max_length=128,
                                         truncation=True, padding="max_length", return_tensors='pt')
    source_ids = source['input_ids']
    source_mask = source['attention_mask']
    # Beam-search decoding; inputs must sit on the same device as the model
    generated_ids = model.generate(
        input_ids=source_ids.to(dtype=torch.long),
        attention_mask=source_mask.to(dtype=torch.long),
        max_length=150,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True
    )
    preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
    return preds
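A quick usage example, assuming the fine-tuned model and tokenizer from above; the question string is just an illustration taken from the sample data.

question = "What is the season year where the rank is 39?"
model = model.to("cpu")   # get_sql does not move tensors, so keep model and inputs on CPU
print(get_sql(question, tokenizer, model))
# expected output along the lines of: ['SELECT tv season WHERE rank EQL 39']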