Update src/prediction.py
Browse files- src/prediction.py +29 -13
src/prediction.py
CHANGED
@@ -22,6 +22,8 @@ generation_kwargs = {
|
|
22 |
"early_stopping": True,
|
23 |
"num_beams": 5,
|
24 |
"length_penalty": 1.5,
|
|
|
|
|
25 |
}
|
26 |
|
27 |
special_tokens = tokenizer.all_special_tokens
|
@@ -50,7 +52,7 @@ def target_postprocessing(texts, special_tokens):
|
|
50 |
|
51 |
return new_texts
|
52 |
|
53 |
-
def generation_function(texts):
|
54 |
_inputs = texts if isinstance(texts, list) else [texts]
|
55 |
inputs = [prefix + inp for inp in _inputs]
|
56 |
inputs = tokenizer(
|
@@ -58,23 +60,37 @@ def generation_function(texts):
|
|
58 |
max_length=256,
|
59 |
padding="max_length",
|
60 |
truncation=True,
|
61 |
-
return_tensors=
|
62 |
)
|
63 |
|
64 |
input_ids = inputs.input_ids
|
65 |
attention_mask = inputs.attention_mask
|
66 |
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
|
80 |
items = [
|
|
|
22 |
# Decoding settings shared by every model.generate() call.
generation_kwargs = {
    "early_stopping": True,
    "num_beams": 5,
    "length_penalty": 1.5,
    "num_return_sequences": 3,  # emit 3 candidate sequences per generate() call
    # NOTE(review): temperature only influences generation when sampling is
    # enabled; with pure beam search it is silently ignored and every call to
    # model.generate() returns the same output, which starves callers that
    # loop waiting for *distinct* results. Turn sampling on so the 0.8
    # temperature actually takes effect.
    "do_sample": True,
    "temperature": 0.8,
}

# Special tokens (pad/eos/unk/...) stripped out during post-processing.
special_tokens = tokenizer.all_special_tokens
|
|
|
52 |
|
53 |
return new_texts
|
54 |
|
55 |
+
def generation_function(texts, num_recipes=1):
    """Generate up to ``num_recipes`` distinct recipes for the given ingredients.

    Args:
        texts: A comma-separated ingredient string, or a list of such strings.
        num_recipes: How many distinct recipes to collect (default 1).

    Returns:
        A single post-processed recipe (list of decoded strings) when
        ``num_recipes == 1`` and one was produced, otherwise the list of
        collected recipes (possibly shorter than ``num_recipes`` if the
        attempt budget runs out).
    """
    _inputs = texts if isinstance(texts, list) else [texts]
    inputs = [prefix + inp for inp in _inputs]
    inputs = tokenizer(
        inputs,
        max_length=256,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    # Normalize the requested ingredient list once, outside the retry loop.
    # Use _inputs[0] (the normalized list) rather than texts[0]: when `texts`
    # is a plain string, texts[0] is its *first character*, not the string.
    # Whitespace is stripped so " salt" still matches "salt" in the output.
    ingredients = [ing.strip() for ing in _inputs[0].split(",") if ing.strip()]

    generated_recipes = []
    # Cap total attempts: if decoding is deterministic (e.g. sampling turned
    # off), repeated calls yield identical output and an uncapped
    # `while len(...) < num_recipes` loop would never terminate.
    max_attempts = max(10, num_recipes * 10)

    for _ in range(max_attempts):
        if len(generated_recipes) >= num_recipes:
            break
        output_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_kwargs,
        )
        generated = output_ids.detach().cpu().numpy()
        generated_recipe = target_postprocessing(
            tokenizer.batch_decode(generated, skip_special_tokens=False),
            special_tokens,
        )

        # Validate ingredient coverage unconditionally. In the original code
        # this check lived inside `for recipe in generated_recipes:`, so the
        # very first recipe was accepted without ever being validated (the
        # loop body never runs while the list is empty).
        if not all(ing in generated_recipe[0] for ing in ingredients):
            continue

        # Reject exact duplicates of recipes already collected.
        if any(generated_recipe == recipe for recipe in generated_recipes):
            continue

        generated_recipes.append(generated_recipe)

    # Guard the single-recipe fast path: if no recipe passed validation within
    # the attempt budget, return the (empty) list instead of raising IndexError.
    if num_recipes == 1 and generated_recipes:
        return generated_recipes[0]
    return generated_recipes
|
93 |
+
|
94 |
|
95 |
|
96 |
items = [
|