solve memory issue in conv1D attention module of gpt2

#94
by rariwa - opened

I'm trying to use a GPT-2 model to predict sequences of long vectors, but I run into a memory issue at this line in the Conv1D module:
https://github.com/huggingface/transformers/blob/main/src/transformers/pytorch_utils.py
self.weight = nn.Parameter(torch.empty(nx, nf))
My nx and nf are huge: nx = 516224 and nf = 3 * nx.
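For a sense of scale, a rough back-of-the-envelope calculation (assuming fp32, 4 bytes per element, and the nx/nf values above) shows why a single Conv1D weight of that shape cannot fit in memory:

```python
# Rough memory estimate for one Conv1D weight of shape (nx, nf).
# Values taken from the question; GPT-2's c_attn uses nf = 3 * nx.
nx = 516224
nf = 3 * nx
bytes_fp32 = nx * nf * 4  # fp32 = 4 bytes per element

print(f"weight shape: ({nx}, {nf})")
print(f"fp32 size: {bytes_fp32 / 1e12:.1f} TB")  # ~3.2 TB for this one matrix
```

So the allocation fails not because of a bug in transformers but because the requested tensor is on the order of terabytes; reducing the effective embedding size (or factorizing the projection) is what actually changes this number.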
Does anyone have an idea or trick for working around this memory issue?

Thank you,
Regards
