Is Triton attention slower?

Discussion #21, opened by doguaraci

I've been comparing the model with Triton attention on and off using the code below, and the Triton version is almost 3 times slower in my environment (A10G GPU). Do you have any guidance on this?

import time
import torch
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer

config = AutoConfig.from_pretrained(
    "replit/replit-code-v1-3b",
    trust_remote_code=True
)

config.attn_config['attn_impl'] = 'triton'  # comment out this line to run with the 'torch' implementation instead

model = AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', config=config, trust_remote_code=True)
model.to(device='cuda:0', dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)
x = tokenizer.encode('def hello():\n  print("hello world")\n', return_tensors='pt').to('cuda:0')  # same device as the model

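torch.cuda.synchronize()  # make sure no earlier GPU work leaks into the measurement
start = time.perf_counter()
y = model.generate(x, max_new_tokens=64)
torch.cuda.synchronize()  # wait for all queued GPU work before stopping the clock
end = time.perf_counter()

print(f"generate took {end - start:.2f} s")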
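A single timed `generate` call may also include one-time costs such as CUDA initialization and, on the Triton path, just-in-time kernel compilation (Triton compiles a kernel the first time it is launched for a given configuration). Whether that accounts for the gap reported here is not established in this thread. The helper below is a minimal sketch, not part of transformers or the Replit repository; `benchmark`, `warmup`, `repeats`, and `sync` are hypothetical names. It runs a few untimed warm-up calls, then reports the median of several timed runs, calling an optional barrier such as `torch.cuda.synchronize` before each clock read.

```python
import statistics
import time
from typing import Callable, Optional


def benchmark(fn: Callable[[], object],
              warmup: int = 2,
              repeats: int = 5,
              sync: Optional[Callable[[], None]] = None) -> dict:
    """Hypothetical helper: time fn() after untimed warm-up calls.

    warmup:  number of untimed calls, so one-time costs (e.g. JIT
             compilation) are excluded from the reported numbers.
    repeats: number of timed calls.
    sync:    optional barrier called before each clock read, e.g.
             torch.cuda.synchronize, so queued GPU work is counted.
    Returns median/min/max wall-clock time in seconds plus all runs.
    """
    # Untimed warm-up runs.
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # drain any work queued by the warm-up runs

    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for this run's work to finish
        times.append(time.perf_counter() - start)

    return {
        "median": statistics.median(times),
        "min": min(times),
        "max": max(times),
        "runs": times,
    }
```

Hypothetical usage with the script's `model` and `x`: `stats = benchmark(lambda: model.generate(x, max_new_tokens=64), sync=torch.cuda.synchronize)`, run once per `attn_impl` setting, then compare `stats["median"]`. Reporting the median rather than a single run makes the comparison less sensitive to outliers.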
doguaraci changed discussion status to closed
