from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
from dataclasses import dataclass
from typing import List, Optional

from utils import (
    get_preprocess_function,
    get_utterance_processing_functions,
    byt5_decode_batch,
    consistent,
)
from utils import (
    PROGRAM_SPECIAL_TOKEN,
    UTTERANCES_SPECIAL_TOKEN,
    GT_PROGRAM_SPECIAL_TOKEN,
)
from greenery import parse
from greenery.parse import NoMatch
import numpy as np
import torch


class Agent:
    def __init__(
        self,
        model_path: str,
        gen_config: dict,
        inference_batch_size: int = 1,
        device=None,
    ):
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.gen_config = GenerationConfig(**gen_config)
        self.inference_batch_size = inference_batch_size


@dataclass
class ListenerOutput:
    programs: List[List[str]]
    idx: Optional[List[List[int]]] = None
    decoded: Optional[List[List[str]]] = None
    decoded_scores: Optional[List[List[float]]] = None
    pruned: Optional[List[List[str]]] = None


class Listener(Agent):
    def __init__(
        self,
        model_path,
        gen_config,
        inference_batch_size=4,
        label_pos="suffix",
        idx: bool = True,
        program_special_token=PROGRAM_SPECIAL_TOKEN,
        utterances_special_token=UTTERANCES_SPECIAL_TOKEN,
        device=None,
    ):
        super().__init__(model_path, gen_config, inference_batch_size, device)
        self.label_pos = label_pos
        self.idx = idx
        self.program_special_token = program_special_token
        self.utterances_special_token = utterances_special_token
        self.utterances_to_string, self.string_to_utterances = (
            get_utterance_processing_functions(
                label_pos, idx, separator=utterances_special_token
            )
        )
        self.device = self.model.device

    def synthesize(self, context, return_scores=False, enforce_consistency=True):
        # If each context is given as a list of utterances, join it into a
        # single string; otherwise assume it is already a formatted string.
        if isinstance(context[0], list):
            context_str = list(map(self.utterances_to_string, context))
        else:
            context_str = context

        # Encode the contexts, prepending the utterances special token when
        # it is not already present.
        context_tokens = self.tokenizer(
            [
                (
                    f"{self.utterances_special_token}{c}"
                    if not c.startswith(self.utterances_special_token)
                    else c
                )
                for c in context_str
            ],
            return_tensors="pt",
            padding=True,
        ).to(self.device)

        # Force decoding to start from the program special token.
        decoder_inputs = self.tokenizer(
            [self.program_special_token for _ in context],
            return_tensors="pt",
            add_special_tokens=False,
        ).to(self.device)

        outputs = self.model.generate(
            **context_tokens,
            decoder_input_ids=decoder_inputs.input_ids,
            generation_config=self.gen_config,
            return_dict_in_generate=True,
            output_scores=True,
        )

        # Group the generated sequences per context and decode them back to
        # program strings.
        decoded_batch = byt5_decode_batch(
            outputs.sequences.reshape(
                (len(context), -1, outputs.sequences.shape[-1])
            ).tolist(),
            skip_position_token=True,
            skip_special_tokens=True,
        )

        # Keep only the candidate programs that are consistent with their
        # context (unless consistency enforcement is disabled).
        consistent_programs = []
        idxs = []
        for decoded, ctx in zip(decoded_batch, context):
            cp = []
            idx = []
            for i, p in enumerate(decoded):
                if not enforce_consistency or consistent(p, ctx):
                    cp.append(p)
                    idx.append(i)

            consistent_programs.append(cp)
            idxs.append(idx)

        # Score each generated sequence as the sum of its token log-probabilities.
        # sequences[:, 1:] aligns with outputs.scores because decoding starts
        # from the single-token program prompt; -inf entries (padding) are
        # zeroed out before summing.
        logprobs = torch.stack(outputs.scores, dim=1).log_softmax(dim=-1)
        gen_probs = torch.gather(
            logprobs, 2, outputs.sequences[:, 1:, None]
        ).squeeze(-1)
        gen_probs.masked_fill_(gen_probs.isinf(), 0)
        scores = gen_probs.sum(-1)

        # Reshape the flat scores into (num contexts, num sequences per context).
        n_decoded = scores.shape[0]
        n_seq = n_decoded // len(context)
        scores = scores.reshape((len(context), n_seq))
        scores_list = scores.tolist()

        if return_scores:
            return ListenerOutput(
                consistent_programs, idxs, decoded_batch, scores_list
            )
        else:
            return ListenerOutput(consistent_programs)
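

# A minimal usage sketch (not part of the original module): the checkpoint
# path, generation settings, and example utterance below are assumptions for
# illustration only; substitute a seq2seq checkpoint fine-tuned with the
# special-token conventions expected by utils.
if __name__ == "__main__":
    gen_config = {
        "num_beams": 5,  # assumed beam-search settings
        "num_return_sequences": 5,
        "max_new_tokens": 64,
    }
    listener = Listener(
        "path/to/listener-checkpoint",  # hypothetical model path
        gen_config,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    # context can be a list of pre-formatted utterance strings, or a list of
    # lists of utterances (converted internally via utterances_to_string).
    out = listener.synthesize(["ab -> b"], return_scores=True)  # example utterance format is hypothetical
    print(out.programs)
    print(out.decoded_scores)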