whackthejacker commited on
Commit
895715d
·
verified ·
1 Parent(s): 2917bca

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +34 -2
model.py CHANGED
@@ -1,9 +1,41 @@
 
1
  import torch.nn as nn
 
2
 
3
  class CodeGenerator(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
4
  def __init__(self, model_name):
 
 
 
 
 
 
 
 
5
  super().__init__()
6
- self.model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
7
 
8
  def forward(self, input_ids):
9
- return self.model(input_ids)[0]
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ```python
2
  import torch.nn as nn
3
+ from transformers import AutoModelForCausalLM
4
 
5
  class CodeGenerator(nn.Module):
6
+ """
7
+ A PyTorch module that generates code using a pre-trained language model.
8
+
9
+ This class inherits from `nn.Module` and encapsulates a pre-trained language model
10
+ from the Hugging Face Transformers library. The model is used to generate code
11
+ based on the input sequence.
12
+
13
+ Attributes:
14
+ - model (transformers.AutoModelForCausalLM): The pre-trained language model
15
+ used for code generation.
16
+ """
17
  def __init__(self, model_name):
18
+ """
19
+ Initializes a new instance of the `CodeGenerator` class.
20
+
21
+ Parameters:
22
+ - model_name (str): The name of the pre-trained language model to use.
23
+ This should be a valid model name from the Hugging Face
24
+ Transformers library.
25
+ """
26
  super().__init__()
27
+ self.model = AutoModelForCausalLM.from_pretrained(model_name)
28
 
29
  def forward(self, input_ids):
30
+ """
31
+ Generates code based on the input sequence.
32
+
33
+ Parameters:
34
+ - input_ids (torch.Tensor): A tensor of token IDs representing the input
35
+ sequence for the language model.
36
+
37
+ Returns:
38
+ torch.Tensor: The output tensor containing the generated code.
39
+ """
40
+ return self.model(input_ids)[0]
41
+ ```