import torch
from torch import nn
import transformers
from transformers import PreTrainedModel
class EEGViTAutoModel(PreTrainedModel):
    """Hugging Face wrapper around EEGViT_pretrained so it can be loaded through the Auto* API."""

    config_class = transformers.ViTConfig

    def __init__(self, config=None):
        if config is None:
            config = transformers.ViTConfig()
        super().__init__(config)
        self.model = EEGViT_pretrained()

    def forward(self, x):
        # Delegate to the wrapped model so the wrapper itself is callable.
        return self.model(x)
class EEGViT_pretrained(nn.Module):
    """EEG model: a temporal convolution front end followed by a pretrained ViT backbone."""

    def __init__(self):
        super().__init__()
        # Temporal convolution along the time axis: 36-sample windows with stride 36
        # over each EEG channel, producing 256 feature maps.
        self.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=256,
            kernel_size=(1, 36),
            stride=(1, 36),
            padding=(0, 2),
            bias=False
        )
        # Note: the second positional argument is eps, so this sets eps=False (i.e. 0).
        self.batchnorm1 = nn.BatchNorm2d(256, False)

        # Load a pretrained ViT and adapt its config to the 256-channel, 129 x 14 feature maps.
        model_name = "google/vit-base-patch16-224"
        config = transformers.ViTConfig.from_pretrained(model_name)
        config.update({'num_channels': 256})
        config.update({'image_size': (129, 14)})
        config.update({'patch_size': (8, 1)})

        model = transformers.ViTForImageClassification.from_pretrained(
            model_name, config=config, ignore_mismatched_sizes=True
        )
        # Replace the patch projection with a grouped convolution matching the new patch size.
        model.vit.embeddings.patch_embeddings.projection = torch.nn.Conv2d(
            256, 768, kernel_size=(8, 1), stride=(8, 1), padding=(0, 0), groups=256
        )
        # Replace the classification head with a small two-output regression head.
        model.classifier = torch.nn.Sequential(
            torch.nn.Linear(768, 1000, bias=True),
            torch.nn.Dropout(p=0.1),
            torch.nn.Linear(1000, 2, bias=True)
        )
        self.ViT = model

    def forward(self, x):
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = self.ViT(x).logits
        return x
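

# --- Usage sketch (not part of the original file) ---
# A minimal smoke test, assuming the input is a batch of EEG "images" of shape
# (batch, 1, 129, 500): 129 electrodes x 500 time samples, which the conv front end
# reduces to the 129 x 14 grid the ViT config above expects.
if __name__ == "__main__":
    model = EEGViT_pretrained()
    dummy = torch.randn(2, 1, 129, 500)  # assumed input shape
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # expected: torch.Size([2, 2])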