acecalisto3 commited on
Commit
1a6b654
1 Parent(s): e26d86d

Create gemmacode/model.py

Browse files
Files changed (1) hide show
  1. germmacode/gemmacode/model.py +9 -0
germmacode/gemmacode/model.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ class CodeGenerator(nn.Module):
4
+ def __init__(self, model_name):
5
+ super().__init__()
6
+ self.model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
7
+
8
+ def forward(self, input_ids):
9
+ return self.model(input_ids)[0]