#!/usr/bin/env python
# coding=utf-8
import inspect
import logging
from typing import Tuple

import nltk
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BloomForCausalLM,
    BloomTokenizerFast,
    CTRLLMHeadModel,
    CTRLTokenizer,
    GenerationMixin,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    GPTJForCausalLM,
    LlamaForCausalLM,
    LlamaTokenizer,
    OpenAIGPTLMHeadModel,
    OpenAIGPTTokenizer,
    OPTForCausalLM,
    TransfoXLLMHeadModel,
    TransfoXLTokenizer,
    XLMTokenizer,
    XLMWithLMHeadModel,
    XLNetLMHeadModel,
    XLNetTokenizer,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

from forbidden import FORBIDDEN_NOUN
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop

MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
    "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
    "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
    "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
    "xlm": (XLMWithLMHeadModel, XLMTokenizer),
    "gptj": (GPTJForCausalLM, AutoTokenizer),
    "bloom": (BloomForCausalLM, BloomTokenizerFast),
    "llama": (LlamaForCausalLM, LlamaTokenizer),
    "opt": (OPTForCausalLM, GPT2Tokenizer),
}

# Convert the imported list to a set for fast membership tests (the list itself lives in forbidden.py).
FORBIDDEN_NOUN = set(FORBIDDEN_NOUN)
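
# Illustrative lookup (a sketch only; this module itself never reads MODEL_CLASSES directly):
# the generation driver is expected to resolve classes by the user-supplied model type, e.g.
#   model_class, tokenizer_class = MODEL_CLASSES["gpt2"]
#   model = model_class.from_pretrained("gpt2")       # "gpt2" checkpoint is an assumed example
#   tokenizer = tokenizer_class.from_pretrained("gpt2")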


class Translator:
    """Thin wrapper around a seq2seq checkpoint used to translate prompts (e.g. zh -> en)."""

    def __init__(self, model_name):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def translate(self, text):
        inputs = self.tokenizer(text, return_tensors="pt", padding=True)
        outputs = self.model.generate(**inputs)
        translated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return translated_text

    def __call__(self, text):
        return self.translate(text)
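
# Example usage (a minimal sketch; the checkpoint name is an assumption, any Hub seq2seq
# translation model with the same interface would work):
#   zh_en_translator = Translator("Helsinki-NLP/opus-mt-zh-en")
#   print(zh_en_translator("你好，世界"))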


#
# Functions to prepare models' input
#
def prepare_ctrl_input(args, _, tokenizer, prompt_text):
    if args["temperature"] > 0.7:
        # CTRL typically works better with lower temperatures; nothing is changed here.
        pass
    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
    if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
        # The prompt does not start with a CTRL control code; generation quality may suffer.
        pass
    return prompt_text


def prepare_xlm_input(args, model, tokenizer, prompt_text):
    # kwargs = {"language": None, "mask_token_id": None}

    # Set the language
    use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
    if hasattr(model.config, "lang2id") and use_lang_emb:
        available_languages = model.config.lang2id.keys()
        if args["xlm_language"] in available_languages:
            language = args["xlm_language"]
        else:
            language = None
            while language not in available_languages:
                language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")
        model.config.lang_id = model.config.lang2id[language]
        # kwargs["language"] = tokenizer.lang2id[language]

    return prompt_text


def prepare_xlnet_input(args, _, tokenizer, prompt_text):
    prefix = args["prefix"] if args["prefix"] else args["padding_text"] if args["padding_text"] else ""
    prompt_text = prefix + prompt_text
    return prompt_text


def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
    prefix = args["prefix"] if args["prefix"] else args["padding_text"] if args["padding_text"] else ""
    prompt_text = prefix + prompt_text
    return prompt_text


PREPROCESSING_FUNCTIONS = {
    "ctrl": prepare_ctrl_input,
    "xlm": prepare_xlm_input,
    "xlnet": prepare_xlnet_input,
    "transfo-xl": prepare_transfoxl_input,
}


def adjust_length_to_model(length, max_sequence_length):
    if length < 0 and max_sequence_length > 0:
        length = max_sequence_length
    elif 0 < max_sequence_length < length:
        length = max_sequence_length  # No generation bigger than model size
    elif length < 0:
        length = MAX_LENGTH  # avoid infinite loop
    return length
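
# Illustration of the clipping behaviour (values are examples only):
#   adjust_length_to_model(-1, 1024)   -> 1024   (unset length falls back to the model size)
#   adjust_length_to_model(2048, 1024) -> 1024   (requested length is clipped to the model size)
#   adjust_length_to_model(-1, 0)      -> 10000  (no model limit known: use MAX_LENGTH)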


def sparse_model_config(model_config):
    embedding_size = None
    if hasattr(model_config, "hidden_size"):
        embedding_size = model_config.hidden_size
    elif hasattr(model_config, "n_embed"):
        embedding_size = model_config.n_embed
    elif hasattr(model_config, "n_embd"):
        embedding_size = model_config.n_embd

    num_head = None
    if hasattr(model_config, "num_attention_heads"):
        num_head = model_config.num_attention_heads
    elif hasattr(model_config, "n_head"):
        num_head = model_config.n_head

    if embedding_size is None or num_head is None or num_head == 0:
        raise ValueError("Check the model config")

    num_embedding_size_per_head = int(embedding_size / num_head)

    if hasattr(model_config, "n_layer"):
        num_layer = model_config.n_layer
    elif hasattr(model_config, "num_hidden_layers"):
        num_layer = model_config.num_hidden_layers
    else:
        raise ValueError("Number of hidden layers couldn't be determined from the model config")

    return num_layer, num_head, num_embedding_size_per_head
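
# Example (for illustration; numbers are those of the public "gpt2" base config):
#   GPT-2 base has n_layer=12, n_head=12, n_embd=768, so this returns (12, 12, 64).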


def generate_past_key_values(model, batch_size, seq_len):
    num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config)
    if model.config.model_type == "bloom":
        # BLOOM's cache layout fuses batch and head dimensions: keys are (batch*heads, head_dim, seq_len)
        # and values are (batch*heads, seq_len, head_dim).
        past_key_values = tuple(
            (
                torch.empty(int(num_attention_heads * batch_size), num_embedding_size_per_head, seq_len)
                .to(model.dtype)
                .to(model.device),
                torch.empty(int(num_attention_heads * batch_size), seq_len, num_embedding_size_per_head)
                .to(model.dtype)
                .to(model.device),
            )
            for _ in range(num_block_layers)
        )
    else:
        # Standard layout: (batch, heads, seq_len, head_dim) for both keys and values.
        past_key_values = tuple(
            (
                torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head)
                .to(model.dtype)
                .to(model.device),
                torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head)
                .to(model.dtype)
                .to(model.device),
            )
            for _ in range(num_block_layers)
        )
    return past_key_values
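
# Shape illustration (assumed GPT-2 base config, batch_size=1, seq_len=1): each of the 12 layers
# gets a (key, value) pair of empty tensors shaped (1, 12, 1, 64).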


def prepare_jit_inputs(inputs, model, tokenizer):
    batch_size = len(inputs)
    dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt")
    dummy_input = dummy_input.to(model.device)
    if model.config.use_cache:
        dummy_input["past_key_values"] = generate_past_key_values(model, batch_size, 1)
    dummy_input["attention_mask"] = torch.cat(
        [
            torch.zeros(dummy_input["attention_mask"].shape[0], 1)
            .to(dummy_input["attention_mask"].dtype)
            .to(model.device),
            dummy_input["attention_mask"],
        ],
        -1,
    )
    return dummy_input


class _ModelFallbackWrapper(GenerationMixin):
    __slots__ = ("_optimized", "_default")

    def __init__(self, optimized, default):
        self._optimized = optimized
        self._default = default

    def __call__(self, *args, **kwargs):
        if kwargs["past_key_values"] is None and self._default.config.use_cache:
            kwargs["past_key_values"] = generate_past_key_values(self._default, kwargs["input_ids"].shape[0], 0)
        kwargs.pop("position_ids", None)
        for k in list(kwargs.keys()):
            if kwargs[k] is None or isinstance(kwargs[k], bool):
                kwargs.pop(k)
        outputs = self._optimized(**kwargs)
        lm_logits = outputs[0]
        past_key_values = outputs[1]
        fixed_output = CausalLMOutputWithPast(
            loss=None,
            logits=lm_logits,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )
        return fixed_output

    def __getattr__(self, item):
        return getattr(self._default, item)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, use_cache=None, **kwargs
    ):
        return self._default.prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, **kwargs
        )

    def _reorder_cache(
        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
        [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return self._default._reorder_cache(past_key_values, beam_idx)
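
# How the wrapper is meant to be applied (a sketch mirroring the JIT branch of generate_prompt below):
#   traced = torch.jit.freeze(torch.jit.trace(model, jit_inputs, strict=False).eval())
#   model = _ModelFallbackWrapper(traced, model)  # .generate() then dispatches to the traced forward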


def remove_tokens_before_copula(text):
    # For every comma-separated clause after the first, keep only the tokens that follow the last
    # copula ("is", "are", "am"); clauses without a copula are kept unchanged.
    sentences = text.split(",")
    result = [sentences[0]]
    for sentence in sentences[1:]:
        tokens = nltk.word_tokenize(sentence)
        target_indices = [i for i, token in enumerate(tokens) if token.lower() in ["is", "are", "am"]]
        if target_indices:
            last_target_index = target_indices[-1]
            result.append(tokens[last_target_index + 1:])
        else:
            result.append(tokens)
    all_sentences = [" ".join(sen) for sen in result[1:]]
    all_sentences.insert(0, result[0])
    result_text = ",".join(all_sentences)
    return result_text
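
# Worked example (requires the NLTK "punkt" tokenizer data to be downloaded):
#   remove_tokens_before_copula("a cat, the sky is blue, dogs are loyal")
#   -> "a cat,blue,loyal"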


def generate_prompt(
    prompt_text,
    args,
    zh_en_translator,
    nlp,
    model,
    tokenizer,
    distributed_state,
):
    max_seq_length = getattr(model.config, "max_position_embeddings", 0)
    args["length"] = adjust_length_to_model(args["length"], max_sequence_length=max_seq_length)

    while True:  # single pass: the loop always returns at the end of its first iteration
        # Translate the prompt to English; only a single input is supported.
        prompt_text = zh_en_translator(prompt_text)

        # Different models need different input formatting and/or extra arguments
        requires_preprocessing = args["model_type"] in PREPROCESSING_FUNCTIONS.keys()
        if requires_preprocessing:
            prepare_input = PREPROCESSING_FUNCTIONS.get(args["model_type"])
            preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)

            if model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
                tokenizer_kwargs = {"add_space_before_punct_symbol": True}
            else:
                tokenizer_kwargs = {}

            encoded_prompt = tokenizer.encode(
                preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
            )
        else:
            prefix = args["prefix"] if args["prefix"] else args["padding_text"]
            encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt")
        encoded_prompt = encoded_prompt.to(distributed_state.device)

        if encoded_prompt.size()[-1] == 0:
            input_ids = None
        else:
            input_ids = encoded_prompt

        if args["jit"]:
            jit_input_texts = ["enable jit"]
            jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer)

            torch._C._jit_set_texpr_fuser_enabled(False)
            model.config.return_dict = False
            if hasattr(model, "forward"):
                sig = inspect.signature(model.forward)
            else:
                sig = inspect.signature(model.__call__)
            jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None)
            traced_model = torch.jit.trace(model, jit_inputs, strict=False)
            traced_model = torch.jit.freeze(traced_model.eval())
            # Two warm-up passes so the traced graph is optimized before generation.
            traced_model(*jit_inputs)
            traced_model(*jit_inputs)
            model = _ModelFallbackWrapper(traced_model, model)

        generated_sequences = []

        for generated_sequence_idx in range(args["num_return_sequences"]):
            repeat_gen_time = 0
            while True:  # currently breaks after a single attempt; repeat_gen_time counts attempts
                repeat_gen_time = repeat_gen_time + 1
                generated_sequence = model.generate(
                    input_ids=input_ids,
                    length_penalty=args["length_penalty"],
                    max_length=args["length"] + len(encoded_prompt[0]),
                    temperature=args["temperature"],
                    top_k=args["k"],
                    top_p=args["p"],
                    repetition_penalty=args["repetition_penalty"],
                    do_sample=True,
                    num_return_sequences=1,
                    pad_token_id=tokenizer.pad_token_id,
                )

                # Remove the n_sequence dimension when returning single sequence
                if len(generated_sequence.shape) > 1:
                    generated_sequence.squeeze_()
                generated_sequence = generated_sequence.tolist()

                # Decode text
                text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

                # Remove all text after the stop token
                text = text[: text.find(args["stop_token"]) if args["stop_token"] else None]

                # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
                total_sequence = (
                    prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
                )
                break

            total_sequence = remove_tokens_before_copula(total_sequence)
            generated_sequences.append(total_sequence)

        return generated_sequences


if __name__ == "__main__":
    # The original entry point called generate_prompt() with no arguments, which cannot run:
    # the function requires a prompt, an args dict, a translator, a model, a tokenizer, and a
    # distributed state. A minimal driver sketch is provided below.
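    # Minimal driver sketch. The checkpoint names ("gpt2", "Helsinki-NLP/opus-mt-zh-en"), the
    # example prompt, and all sampling settings below are illustrative assumptions, not values
    # fixed by this module.
    from accelerate import PartialState  # provides the device handle expected by generate_prompt

    nltk.download("punkt")  # tokenizer data used by nltk.word_tokenize

    distributed_state = PartialState()

    model_class, tokenizer_class = MODEL_CLASSES["gpt2"]        # assumed model type
    tokenizer = tokenizer_class.from_pretrained("gpt2")         # assumed checkpoint
    model = model_class.from_pretrained("gpt2").to(distributed_state.device)
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

    zh_en_translator = Translator("Helsinki-NLP/opus-mt-zh-en")  # assumed translation checkpoint

    args = {  # illustrative defaults; the real application is expected to supply these
        "model_type": "gpt2",
        "length": 60,
        "length_penalty": 1.0,
        "temperature": 1.0,
        "k": 50,
        "p": 0.9,
        "repetition_penalty": 1.0,
        "num_return_sequences": 1,
        "stop_token": None,
        "prefix": "",
        "padding_text": "",
        "jit": False,
    }

    sequences = generate_prompt(
        "一只在月光下奔跑的猫",  # example Chinese prompt ("a cat running in the moonlight")
        args,
        zh_en_translator,
        None,  # `nlp` is accepted but unused by generate_prompt
        model,
        tokenizer,
        distributed_state,
    )
    print(sequences)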