from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
from dataclasses import dataclass
from typing import List, Optional
from utils import (
    get_preprocess_function,
    get_utterance_processing_functions,
    byt5_decode_batch,
    consistent,
    PROGRAM_SPECIAL_TOKEN,
    UTTERANCES_SPECIAL_TOKEN,
    GT_PROGRAM_SPECIAL_TOKEN,
)
from greenery import parse
from greenery.parse import NoMatch
import numpy as np
import torch

class Agent:
    """Wraps a seq2seq model, its tokenizer, and a shared generation config."""

    def __init__(
        self,
        model_path: str,
        gen_config: dict,
        inference_batch_size: int = 1,
        device=None,
    ):
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.gen_config = GenerationConfig(**gen_config)
        self.inference_batch_size = inference_batch_size
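
# Illustrative gen_config (assumed values, not defined in this file): any
# kwargs accepted by transformers.GenerationConfig can be passed, e.g.
#
#   gen_config = {
#       "max_new_tokens": 128,
#       "num_beams": 10,
#       "num_return_sequences": 10,
#   }
#   agent = Agent("path/to/checkpoint", gen_config, device="cuda")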

@dataclass
class ListenerOutput:
    programs: List[List[str]]
    idx: Optional[List[List[int]]] = None
    decoded: Optional[List[List[str]]] = None
    decoded_scores: Optional[List[List[float]]] = None
    pruned: Optional[List[List[str]]] = None
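
# Illustrative: for a batch of two contexts, a populated output might look like
#
#   ListenerOutput(
#       programs=[["ab*"], []],           # candidates kept per context
#       idx=[[0], []],                    # their positions within `decoded`
#       decoded=[["ab*"], ["a|b"]],       # every decoded candidate per context
#       decoded_scores=[[-0.7], [-2.1]],  # summed token log-probabilities
#   )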

class Listener(Agent):
    """Synthesizes candidate programs from batches of utterance contexts."""

    def __init__(
        self,
        model_path,
        gen_config,
        inference_batch_size=4,
        label_pos="suffix",
        idx: bool = True,
        program_special_token=PROGRAM_SPECIAL_TOKEN,
        utterances_special_token=UTTERANCES_SPECIAL_TOKEN,
        device=None,
    ):
        super().__init__(model_path, gen_config, inference_batch_size, device)
        self.label_pos = label_pos
        self.idx = idx
        self.program_special_token = program_special_token
        self.utterances_special_token = utterances_special_token
        self.utterances_to_string, self.string_to_utterances = (
            get_utterance_processing_functions(
                label_pos, idx, separator=utterances_special_token
            )
        )
        self.device = self.model.device
    def synthesize(self, context, return_scores=False, enforce_consistency=True):
        """Generate candidate programs for each context, optionally keeping
        only those consistent with the context's utterances."""
        # If each context is a list of utterances, flatten it into one string.
        if isinstance(context[0], list):
            context_str = list(map(self.utterances_to_string, context))
        else:
            context_str = context

        # Prepend the utterances special token where it is missing.
        context_tokens = self.tokenizer(
            [
                (
                    f"{self.utterances_special_token}{c}"
                    if not c.startswith(self.utterances_special_token)
                    else c
                )
                for c in context_str
            ],
            return_tensors="pt",
            padding=True,
        ).to(self.device)

        # Seed the decoder with the program special token so generation
        # starts directly on program tokens.
        decoder_inputs = self.tokenizer(
            [self.program_special_token for _ in context],
            return_tensors="pt",
            add_special_tokens=False,
        ).to(self.device)

        outputs = self.model.generate(
            **context_tokens,
            decoder_input_ids=decoder_inputs.input_ids,
            generation_config=self.gen_config,
            return_dict_in_generate=True,
            output_scores=True,
        )

        # Regroup the flat (batch * num_return_sequences) outputs per context.
        decoded_batch = byt5_decode_batch(
            outputs.sequences.reshape(
                (len(context), -1, outputs.sequences.shape[-1])
            ).tolist(),
            skip_position_token=True,
            skip_special_tokens=True,
        )
        # Keep each decoded program (and its index) unless consistency is
        # enforced and the program contradicts its context.
        consistent_programs = []
        idxs = []
        for decoded, ctx in zip(decoded_batch, context):
            cp = []
            idx = []
            for i, p in enumerate(decoded):
                if not enforce_consistency or consistent(p, ctx):
                    cp.append(p)
                    idx.append(i)
            consistent_programs.append(cp)
            idxs.append(idx)
        # Per-token log-probabilities of the generated tokens; sequences[:, 1:]
        # skips the program-special-token prompt so it aligns with the scores.
        logprobs = torch.stack(outputs.scores, dim=1).log_softmax(dim=-1)
        gen_probs = torch.gather(
            logprobs, 2, outputs.sequences[:, 1:, None]
        ).squeeze(-1)
        # Padded positions carry -inf log-probability; zero them so they do
        # not affect the sequence scores.
        gen_probs.masked_fill_(gen_probs.isinf(), 0)
        scores = gen_probs.sum(-1)

        # Reshape flat scores back to (num_contexts, sequences_per_context).
        n_decoded = scores.shape[0]
        n_seq = n_decoded // len(context)
        scores = scores.reshape((len(context), n_seq))
        scores_list = scores.tolist()

        if return_scores:
            return ListenerOutput(
                consistent_programs, idxs, decoded_batch, scores_list
            )
        else:
            return ListenerOutput(consistent_programs)
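

# Minimal usage sketch (illustrative; the checkpoint path and generation
# settings are placeholders, and the exact utterance format depends on
# get_utterance_processing_functions):
#
#   listener = Listener(
#       "path/to/listener-checkpoint",
#       {"max_new_tokens": 128, "num_beams": 5, "num_return_sequences": 5},
#       device="cuda",
#   )
#   out = listener.synthesize([["ab", "abb"]], return_scores=True)
#   print(out.programs, out.decoded_scores)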