File size: 729 Bytes
28de1fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from datasets import load_dataset, DatasetDict, Dataset
import pandas as pd
import ast

import genai_SDK
from genai_SDK.Seq2Seq import GPT2

# Run configuration, hoisted out of the call sites so it can be tweaked in one place.
DATASET_NAME = "erwanlc/cocktails_recipe_no_brand"
MODEL_NAME = "distilgpt2"
GPU_INDEX = 0
BATCH_SIZE = 8
NUM_EPOCHS = 2
# NOTE(review): presumably fed to the model as a cocktail title, since the
# training frame pairs titles with ingredients; confirm genai_SDK's prompt format.
PROMPT = "Annual Planning"


def _ingredient_names(raw_ingredients):
    """Return the ingredient names from one serialized ``raw_ingredients`` value.

    The value is a string holding a Python literal sequence. It is parsed with
    ``ast.literal_eval`` (not ``eval``) because it comes from a downloaded
    dataset. Element ``[1]`` of each entry is kept as the ingredient name, and
    the names are joined with ``", "``.
    """
    return ", ".join(entry[1] for entry in ast.literal_eval(raw_ingredients))


# Load the dataset from the Hugging Face hub.
dataset = load_dataset(DATASET_NAME)

# Keep only the two columns the model is trained on.
data = [
    {"title": item["title"], "raw_ingredients": item["raw_ingredients"]}
    for item in dataset["train"]
]
df = pd.DataFrame(data)

# Reduce each ingredient list to just the ingredient names. Bracket
# assignment is used because attribute-style column assignment only works for
# existing columns and silently creates an attribute otherwise.
df["raw_ingredients"] = df["raw_ingredients"].map(_ingredient_names)

model = GPT2(gpu=GPU_INDEX, model_name=MODEL_NAME)
model.load_data(df=df, batch_size=BATCH_SIZE)

model.train(num_epochs=NUM_EPOCHS)

print(model.generate_text(PROMPT))