# EEGViT / eegvit_model.py
import torch
from torch import nn
import transformers
from transformers import PreTrainedModel
class EEGViTAutoModel(PreTrainedModel):
    """Hugging Face wrapper so EEGViT can be loaded through the PreTrainedModel API."""

    config_class = transformers.ViTConfig

    def __init__(self, config=None):
        if config is None:
            config = transformers.ViTConfig()
        super().__init__(config)
        self.model = EEGViT_pretrained()
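
    # Assumption: the original file defines no forward() on this wrapper, so
    # calling it would raise NotImplementedError. A minimal delegating forward
    # is sketched here so the wrapper is usable end to end.
    def forward(self, x):
        return self.model(x)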
class EEGViT_pretrained(nn.Module):
    """EEGViT: a temporal convolution front end followed by a pretrained ViT backbone."""

    def __init__(self):
        super().__init__()
        # Temporal convolution over the time axis: each stride-36 window of the
        # raw EEG is mapped to 256 feature channels.
        self.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=256,
            kernel_size=(1, 36),
            stride=(1, 36),
            padding=(0, 2),
            bias=False,
        )
        # Note: the second positional argument of BatchNorm2d is eps, so this
        # sets eps=False (i.e. 0.0), not affine=False; kept as-is to preserve
        # the original behavior.
        self.batchnorm1 = nn.BatchNorm2d(256, False)

        # Load a pretrained ViT and adapt it to the 256-channel, 129x14 feature
        # map produced by conv1, using 8x1 patches along the electrode axis.
        model_name = "google/vit-base-patch16-224"
        config = transformers.ViTConfig.from_pretrained(model_name)
        config.update({'num_channels': 256})
        config.update({'image_size': (129, 14)})
        config.update({'patch_size': (8, 1)})
        model = transformers.ViTForImageClassification.from_pretrained(
            model_name, config=config, ignore_mismatched_sizes=True
        )
        # Replace the patch-embedding projection with a grouped convolution
        # (groups=256) matching the new channel count and patch size.
        model.vit.embeddings.patch_embeddings.projection = nn.Conv2d(
            256, 768, kernel_size=(8, 1), stride=(8, 1), padding=(0, 0), groups=256
        )
        # Regression head with two outputs (e.g. 2-D gaze coordinates).
        model.classifier = nn.Sequential(
            nn.Linear(768, 1000, bias=True),
            nn.Dropout(p=0.1),
            nn.Linear(1000, 2, bias=True),
        )
        self.ViT = model
    def forward(self, x):
        # x is expected as (batch, 1, electrodes, time) raw EEG.
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = self.ViT(x).logits
        return x
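

# A minimal usage sketch (not part of the original file). The input shape is
# an assumption inferred from the layer arithmetic: 129 electrodes and 500
# time samples give a 129x14 feature map after conv1 ((500 + 2*2 - 36)/36 + 1
# = 14), matching the image_size set above.
if __name__ == "__main__":
    model = EEGViT_pretrained()
    dummy = torch.randn(2, 1, 129, 500)  # (batch, in_channels, electrodes, time)
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # expected: torch.Size([2, 2])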