from dataclasses import dataclass
from typing import List, Optional

import numpy as np
import torch
from greenery import parse
from greenery.parse import NoMatch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

from utils import (
    GT_PROGRAM_SPECIAL_TOKEN,
    PROGRAM_SPECIAL_TOKEN,
    UTTERANCES_SPECIAL_TOKEN,
    byt5_decode_batch,
    consistent,
    get_preprocess_function,
    get_utterance_processing_functions,
)


class Agent:
    """Thin wrapper around a seq2seq model, its tokenizer, and a generation config."""

    def __init__(
        self,
        model_path: str,
        gen_config: dict,
        inference_batch_size: int = 1,
        device=None,
    ):
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.gen_config = GenerationConfig(**gen_config)
        self.inference_batch_size = inference_batch_size


@dataclass
class ListenerOutput:
    """Result of Listener.synthesize: consistent programs plus optional diagnostics."""

    programs: List[List[str]]
    idx: Optional[List[List[int]]] = None
    decoded: Optional[List[List[str]]] = None
    decoded_scores: Optional[List[List[float]]] = None
    pruned: Optional[List[List[str]]] = None


class Listener(Agent):
    """Maps a context of utterances to candidate programs via a seq2seq model."""

    def __init__(
        self,
        model_path,
        gen_config,
        inference_batch_size=4,
        label_pos="suffix",
        idx: bool = True,
        program_special_token=PROGRAM_SPECIAL_TOKEN,
        utterances_special_token=UTTERANCES_SPECIAL_TOKEN,
        device=None,
    ):
        super().__init__(model_path, gen_config, inference_batch_size, device)
        self.label_pos = label_pos
        self.idx = idx
        self.program_special_token = program_special_token
        self.utterances_special_token = utterances_special_token
        self.utterances_to_string, self.string_to_utterances = (
            get_utterance_processing_functions(
                label_pos, idx, separator=utterances_special_token
            )
        )
        self.device = self.model.device

    def synthesize(self, context, return_scores=False, enforce_consistency=True):
        # If each context is a list of utterances, flatten it to a single string.
        if isinstance(context[0], list):
            context_str = list(map(self.utterances_to_string, context))
        else:
            context_str = context

        # Ensure every encoder input starts with the utterances special token.
        context_tokens = self.tokenizer(
            [
                (
                    f"{self.utterances_special_token}{c}"
                    if not c.startswith(self.utterances_special_token)
                    else c
                )
                for c in context_str
            ],
            return_tensors="pt",
            padding=True,
        ).to(self.device)

        # Force decoding to begin with the program special token.
        decoder_inputs = self.tokenizer(
            [self.program_special_token for _ in context],
            return_tensors="pt",
            add_special_tokens=False,
        ).to(self.device)

        outputs = self.model.generate(
            **context_tokens,
            decoder_input_ids=decoder_inputs.input_ids,
            generation_config=self.gen_config,
            return_dict_in_generate=True,
            output_scores=True,
        )

        # Group the flat (batch * num_return_sequences) outputs by context.
        decoded_batch = byt5_decode_batch(
            outputs.sequences.reshape(
                (len(context), -1, outputs.sequences.shape[-1])
            ).tolist(),
            skip_position_token=True,
            skip_special_tokens=True,
        )

        # Keep only programs consistent with their context (unless disabled).
        consistent_programs = []
        idxs = []
        for decoded, ctx in zip(decoded_batch, context):
            cp = []
            idx = []
            for i, p in enumerate(decoded):
                if not enforce_consistency or consistent(p, ctx):
                    cp.append(p)
                    idx.append(i)
            consistent_programs.append(cp)
            idxs.append(idx)

        # Sum per-token log-probabilities of each generated sequence; positions
        # past end-of-sequence come out as -inf and are zeroed before summing.
        logprobs = torch.stack(outputs.scores, dim=1).log_softmax(dim=-1)
        gen_probs = torch.gather(
            logprobs, 2, outputs.sequences[:, 1:, None]
        ).squeeze(-1)
        gen_probs.masked_fill_(gen_probs.isinf(), 0)
        scores = gen_probs.sum(-1)

        # Reshape the flat scores back to one row of candidates per context.
        n_decoded = scores.shape[0]
        n_seq = n_decoded // len(context)
        scores = scores.reshape((len(context), n_seq))
        scores_list = scores.tolist()

        if return_scores:
            return ListenerOutput(
                consistent_programs, idxs, decoded_batch, scores_list
            )
        else:
            return ListenerOutput(consistent_programs)
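

# A minimal usage sketch, not part of the original module: the checkpoint path,
# generation settings, and utterance strings below are hypothetical placeholders,
# since the expected formats are defined by this repo's utils module.
if __name__ == "__main__":
    listener = Listener(
        model_path="google/byt5-small",  # assumption: any seq2seq checkpoint path
        gen_config={
            "max_new_tokens": 128,
            "do_sample": True,
            "num_return_sequences": 4,
        },
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    # Each context is a list of utterances; their exact format is determined by
    # get_utterance_processing_functions and is only sketched here.
    contexts = [["abc", "abbc"]]
    out = listener.synthesize(contexts, return_scores=True)
    print(out.programs)       # programs consistent with each context
    print(out.decoded_scores) # summed log-probabilities per candidate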