# File size: 2,209 bytes (revisions d56f8e4, 547399c).
import torch
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
from transformers import pipeline
from pprint import pprint
import re
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
# model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
def skip_special_tokens_and_prettify(text, tokenizer):
    """Remove special tokens from decoded model output and parse it into a recipe.

    Decoded text marks sections with ``<section>`` and separates items within a
    section with ``<sep>``. Special tokens are stripped first, then
    ``<section>`` becomes a newline and ``<sep>`` becomes ``--``; each resulting
    line is matched against the ``title:``, ``ingredients:`` and
    ``directions:`` prefixes. Lines with any other prefix are ignored.

    Args:
        text: One decoded string, produced with special tokens kept.
        tokenizer: Object exposing ``all_special_tokens`` (an iterable of
            str), such as a Hugging Face tokenizer.

    Returns:
        dict with ``"title"`` (str), ``"ingredients"`` (list of str) and
        ``"directions"`` (list of str). A section absent from ``text`` keeps
        its empty default; if a section appears more than once, the last one
        wins.
    """
    recipe_maps = {"<sep>": "--", "<section>": "\n"}
    recipe_map_pattern = "|".join(map(re.escape, recipe_maps))

    # Escape every special token: tokens such as "[CLS]" or "<|endoftext|>"
    # contain regex metacharacters and would otherwise delete unrelated text.
    # Empty tokens are dropped (an empty alternative matches everywhere and
    # would shadow the real tokens), and longer tokens are tried first so a
    # token that is a prefix of another cannot cut the longer one short.
    special_tokens = sorted(
        (tok for tok in tokenizer.all_special_tokens if tok),
        key=len,
        reverse=True,
    )
    if special_tokens:
        text = re.sub("|".join(map(re.escape, special_tokens)), "", text)

    # NOTE: special tokens are removed before the recipe markers are mapped,
    # so this relies on "<sep>"/"<section>" not being registered as special
    # tokens of the tokenizer.
    text = re.sub(recipe_map_pattern, lambda m: recipe_maps[m.group()], text)

    data = {"title": "", "ingredients": [], "directions": []}
    for section in text.split("\n"):
        section = section.strip()
        # Slice off only the leading label; str.replace would also remove any
        # later occurrence of the label inside the section's content.
        if section.startswith("title:"):
            data["title"] = section[len("title:"):].strip()
        elif section.startswith("ingredients:"):
            items = section[len("ingredients:"):].split("--")
            data["ingredients"] = [item.strip() for item in items]
        elif section.startswith("directions:"):
            items = section[len("directions:"):].split("--")
            data["directions"] = [item.strip() for item in items]
    return data
def post_generator(output_tensors, tokenizer):
    """Decode generation outputs and parse each one into a recipe dict.

    Args:
        output_tensors: Iterable of generation outputs, e.g. the list returned
            by a ``text2text-generation`` pipeline called with
            ``return_tensors=True``; each element must hold its token ids
            under the ``"generated_token_ids"`` key.
        tokenizer: The tokenizer used for generation; it must provide
            ``batch_decode`` and ``all_special_tokens``.

    Returns:
        A list of recipe dicts produced by
        ``skip_special_tokens_and_prettify``, in the same order as
        ``output_tensors``.
    """
    token_ids = [output["generated_token_ids"] for output in output_tensors]
    # Decode with special tokens kept: skip_special_tokens_and_prettify strips
    # them itself and maps the recipe markers ("<sep>", "<section>").
    texts = tokenizer.batch_decode(token_ids, skip_special_tokens=False)
    return [skip_special_tokens_and_prettify(text, tokenizer) for text in texts]
# Example generation settings: sampled decoding restricted by top-k and
# nucleus (top-p) filtering, returning three sequences per input.
generate_kwargs = dict(
    max_length=512,
    min_length=64,
    no_repeat_ngram_size=3,
    do_sample=True,
    top_k=60,
    top_p=0.95,
    num_return_sequences=3,
)
# Example usage (needs the model/tokenizer setup commented out near the top):
# items = "potato, cheese"
# generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
# generated = generator(items, return_tensors=True, return_text=False, **generate_kwargs)
# outputs = post_generator(generated, tokenizer)
# pprint(outputs)