# Sheng Lei
# Add application file
# 28de1fd
# raw
# history blame
# 729 Bytes
from datasets import load_dataset, DatasetDict, Dataset
import pandas as pd
import ast
import genai_SDK
from genai_SDK.Seq2Seq import GPT2
# Fetch the cocktail recipe dataset from the Hugging Face Hub.
dataset = load_dataset("erwanlc/cocktails_recipe_no_brand")

# Build a two-column pandas DataFrame (title, raw_ingredients) from the
# train split, one row per recipe.
data = [
    {column: item[column] for column in ('title', 'raw_ingredients')}
    for item in dataset['train']
]
df = pd.DataFrame(data)


def _ingredient_names(raw):
    """Return the ingredient names found in *raw* as one ", "-joined string.

    *raw* is a string containing a Python literal; it is parsed with
    ``ast.literal_eval`` and element ``[1]`` of every parsed entry is kept.
    NOTE(review): index 1 is taken to be the ingredient name (index 0 is
    presumably the quantity) -- confirm against the dataset card.
    """
    entries = ast.literal_eval(raw)
    return ', '.join(entry[1] for entry in entries)


# Replace each raw ingredient literal with just the ingredient names.
df['raw_ingredients'] = df['raw_ingredients'].apply(_ingredient_names)
# Notebook-only preview of the result: display(df.head())
# Instantiate the genai_SDK GPT2 wrapper with the "distilgpt2" model name.
# NOTE(review): gpu=0 presumably selects GPU device index 0 -- confirm
# against the genai_SDK documentation.
model = GPT2(
    gpu=0,
    model_name="distilgpt2",
)

# Hand the prepared DataFrame to the model, using batches of 8.
model.load_data(
    df=df,
    batch_size=8,
)

# Train for two epochs.
model.train(
    num_epochs=2,
)

# Generate text for a sample prompt and print whatever the model returns.
generated = model.generate_text("Annual Planning")
print(generated)