Zai committed on
Commit d2d5f50 · 1 Parent(s): c4b84ea

Poetry initialized

Files changed (7)
  1. pyproject.toml +14 -0
  2. setup.py +0 -21
  3. version.py +0 -1
  4. yume/config.py +2 -1
  5. yume/dataset.py +4 -5
  6. yume/models.py +100 -68
  7. yume/yume.py +2 -3
pyproject.toml ADDED
@@ -0,0 +1,14 @@
+[tool.poetry]
+name = "yume"
+version = "0.1.0"
+description = "GPT from scratch trained with Japanese dataset"
+authors = ["Zai <130903099+zaibutcooler@users.noreply.github.com>"]
+readme = "README.md"
+
+[tool.poetry.dependencies]
+python = "^3.11"
+
+
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
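Not part of this commit, but a minimal sketch of how the new metadata can be inspected, assuming the script is run from the repository root on Python 3.11+ (which the `python = "^3.11"` constraint already implies):

```python
# Minimal sketch (not in the repo): inspect the Poetry metadata added above.
# Assumes execution from the repository root on Python 3.11+ (tomllib is stdlib there).
import tomllib

with open("pyproject.toml", "rb") as f:  # tomllib requires a binary file handle
    meta = tomllib.load(f)

poetry = meta["tool"]["poetry"]
print(poetry["name"], poetry["version"])       # yume 0.1.0
print(poetry["dependencies"])                  # {'python': '^3.11'}
```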
setup.py DELETED
@@ -1,21 +0,0 @@
-from setuptools import setup, find_packages
-
-with open("requirements.txt") as f:
-    requirements = f.read().splitlines()
-
-setup(
-    name="yume",
-    version="0.1",
-    packages=find_packages(),
-    install_requires=requirements,
-    author="Zai",
-    author_email="zaiyellyintaung@gmail.com",
-    description="LLM trained with Animanga dataset",
-    long_description="Inspired by Andrej Karpathy trained with japanese animanga dataset",
-    url="https://github.com/zaibutcooler/yume",
-    # classifiers=[
-    #     'Programming Language :: Python :: 3',
-    #     'License :: OSI Approved :: MIT License',
-    #     'Operating System :: OS Independent',
-    # ],
-)
version.py DELETED
@@ -1 +0,0 @@
-__version__ = "20231117"
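With version.py gone, the version string now lives only in the `[tool.poetry]` table. Not part of this commit, but a minimal sketch, assuming the `yume` distribution has been installed into the current environment (e.g. with `poetry install`), of recovering it at runtime:

```python
# Minimal sketch (not in the repo): read the package version from installed
# metadata instead of the deleted version.py. Assumes the "yume" distribution
# is installed; falls back to a placeholder when running from a bare checkout.
from importlib.metadata import PackageNotFoundError, version

try:
    __version__ = version("yume")    # "0.1.0" once this commit is packaged
except PackageNotFoundError:         # running from an uninstalled source tree
    __version__ = "0.0.0.dev0"

print(__version__)
```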
 
 
yume/config.py CHANGED
@@ -21,6 +21,7 @@ class Config:
         self.bias = bias
         self.lr = lr
 
+
 # Small Yume model (around 100M parameters)
 yume_small = Config(
     num_epoch=10,
@@ -58,4 +59,4 @@ yume_large = Config(
     dropout=0.1,
     bias=True,
     lr=0.001,
-)
+)
yume/dataset.py CHANGED
@@ -23,7 +23,6 @@ class Trainset(Dataset):
         loaded_dataset = load_dataset(url)
         self.texts = loaded_dataset["animanga"]["texts"]
         dummy_logger("Successfully loaded the dataset")
-
 
     def _tokenize(self, tiktoken=True):
         if tiktoken:
@@ -36,13 +35,13 @@ class Trainset(Dataset):
         self.tokenizer = Tokenizer()
         self.tokenizer.load_pretrained()
         self.tokenizer.encode(self.texts)
-
+
     def _prep_bin(self):
         pass
-
+
     def get_batch(self):
         pass
-
+
     # from loading to installing in one function
     def build_dataset(self):
-        pass
+        pass
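`_tokenize` defaults to a tiktoken path whose body is not shown in this hunk. Not part of this commit, but a minimal standalone sketch of a tiktoken round-trip; the `"gpt2"` encoding name is an assumption, not something the diff specifies:

```python
# Minimal sketch (not in the repo): tiktoken BPE round-trip on Japanese text.
# The "gpt2" encoding is an assumed choice; the commit does not pin one.
import tiktoken

enc = tiktoken.get_encoding("gpt2")
ids = enc.encode("夢を見る")           # text -> list of token ids
print(ids)
print(enc.decode(ids) == "夢を見る")   # True: byte-level BPE decoding inverts encoding
```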
yume/models.py CHANGED
@@ -6,9 +6,10 @@ from .utils import encode, decode
 import math
 from huggingface_hub import PyTorchModelHubMixin
 
+
 # took from karpthy's
 class LayerNorm(nn.Module):
-    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
+    """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
 
     def __init__(self, ndim, bias):
         super().__init__()
@@ -23,39 +24,49 @@ class LayerNorm(nn.Module):
 class SelfAttention(nn.Module, PyTorchModelHubMixin):
     def __init__(self, config: Config) -> None:
         super().__init__()
-        self.attn = nn.Linear(config.n_embd,3*config.n_embd,bias=config.bias)
-        self.proj = nn.Linear(config.n_embd,config.n_embd,bias=config.bias)
+        self.attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
+        self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
         self.attn_dropout = nn.Dropout(config.dropout)
         self.resid_dropout = nn.Dropout(config.dropout)
         self.config = config
 
-        self.flash = hasattr(torch.nn.functional,'scaled_dot_product_attention')
+        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
         if not self.flash:
             print("Using Slow Attention. Use PyTorch >= 2.0")
-
-        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
-                             .view(1, 1, config.block_size, config.block_size))
+
+        self.register_buffer(
+            "bias",
+            torch.tril(torch.ones(config.block_size, config.block_size)).view(
+                1, 1, config.block_size, config.block_size
+            ),
+        )
 
     def forward(self, x):
-        B,T,C = x.size()
-        q,k,v = self.attn(x).split(self.config.n_embd,dim=2)
+        B, T, C = x.size()
+        q, k, v = self.attn(x).split(self.config.n_embd, dim=2)
         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
-        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
-
+        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+
         if self.flash:
             # efficient attention using Flash Attention CUDA kernels
-            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
+            y = torch.nn.functional.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=None,
+                dropout_p=self.dropout if self.training else 0,
+                is_causal=True,
+            )
         else:
             # manual implementation of attention
             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
+            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
             att = F.softmax(att, dim=-1)
             att = self.attn_dropout(att)
             y = att @ v
-
-        y = y.transpose(1, 2).contiguous().view(B, T, C)
+
+        y = y.transpose(1, 2).contiguous().view(B, T, C)
 
         # output projection
         y = self.resid_dropout(self.c_proj(y))
@@ -65,11 +76,13 @@ class SelfAttention(nn.Module, PyTorchModelHubMixin):
 class MLP(nn.Module, PyTorchModelHubMixin):
     def __init__(self, config: Config) -> None:
         super().__init__()
-        self.fully_connected = nn.Linear(config.n_embd,4*config.n_embd,bias=config.bias)
+        self.fully_connected = nn.Linear(
+            config.n_embd, 4 * config.n_embd, bias=config.bias
+        )
         self.gelu = nn.GELU()
-        self.proj = nn.Linear(4*config.n_embd,config.n_embd,bias=config.bias)
+        self.proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
         self.dropout = nn.Dropout(config.dropout)
-
+
     def forward(self, x):
         x = self.fully_connected(x)
         x = self.gelu(x)
@@ -81,17 +94,17 @@ class MLP(nn.Module, PyTorchModelHubMixin):
 class Block(nn.Module, PyTorchModelHubMixin):
     def __init__(self, config: Config) -> None:
         super().__init__()
-        self.ln_1 = LayerNorm(config.n_embd,bias=config.bias)
+        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
         self.attn = SelfAttention(config)
-        self.ln_2 = LayerNorm(config.n_embd,bias=config.bias)
+        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
         self.mlp = MLP(config)
-
 
     def forward(self, x):
-        x = x+ self.attn(self.ln_1(x))
-        x = x+ self.mlp(self.ln_2(x))
+        x = x + self.attn(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
         return x
 
+
 class GPT(nn.Module, PyTorchModelHubMixin):
     def __init__(self, config: Config):
         super().__init__()
@@ -99,17 +112,18 @@ class GPT(nn.Module, PyTorchModelHubMixin):
         assert config.block_size is not None
         self.config = config
         self.device = config.device
-
-        self.transformer= nn.ModuleDict(dict(
-            wte = nn.Embedding(config.vocab_size,config.n_embd),
-            wpe = nn.Embedding(config.block_size,config.n_embd),
-            drop = nn.Dropout(config.dropout),
-            blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
-            ln_f = LayerNorm(config.n_embd,config.bias)
-        ))
-        self.lm_head = nn.Linear(config.n_embd,config.vocab_size,bias=False)
-
-
+
+        self.transformer = nn.ModuleDict(
+            dict(
+                wte=nn.Embedding(config.vocab_size, config.n_embd),
+                wpe=nn.Embedding(config.block_size, config.n_embd),
+                drop=nn.Dropout(config.dropout),
+                blocks=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+                ln_f=LayerNorm(config.n_embd, config.bias),
+            )
+        )
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
     def get_num_params(self, non_embedding=True):
         """
         Return the number of parameters in the model.
@@ -121,28 +135,34 @@ class GPT(nn.Module, PyTorchModelHubMixin):
         if non_embedding:
             n_params -= self.transformer.wpe.weight.numel()
         return n_params
-
-    def forward(self, idx,targets=None):
-        b,t = x.size()
-        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
-        pos = torch.arange(0, t, dtype=torch.long, device=self.device) # shape (t)
+
+    def forward(self, idx, targets=None):
+        b, t = x.size()
+        assert (
+            t <= self.config.block_size
+        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+        pos = torch.arange(0, t, dtype=torch.long, device=self.device)  # shape (t)
 
         tok_emb = self.transformer.wte(idx)
         pos_emb = self.transformer.wpe(idx)
-
-        x = self.transformer.drop(tok_emb+pos_emb)
-
+
+        x = self.transformer.drop(tok_emb + pos_emb)
+
         for block in self.transformer.blocks:
             x = block(x)
         x = self.transformer.ln_f(x)
-
+
         if targets is not None:
             # if we are given some desired targets also calculate the loss
             logits = self.lm_head(x)
-            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+            loss = F.cross_entropy(
+                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
+            )
         else:
             # inference-time mini-optimization: only forward the lm_head on the very last position
-            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
+            logits = self.lm_head(
+                x[:, [-1], :]
+            )  # note: using list [-1] to preserve the time dim
             loss = None
 
         return logits, loss
@@ -153,10 +173,12 @@ class GPT(nn.Module, PyTorchModelHubMixin):
         # but want to use a smaller block size for some smaller, simpler model
         assert block_size <= self.config.block_size
        self.config.block_size = block_size
-        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
+        self.transformer.wpe.weight = nn.Parameter(
+            self.transformer.wpe.weight[:block_size]
+        )
         for block in self.transformer.h:
-            if hasattr(block.attn, 'bias'):
-                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
+            if hasattr(block.attn, "bias"):
+                block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]
 
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
@@ -168,17 +190,21 @@ class GPT(nn.Module, PyTorchModelHubMixin):
 
     def configure_optimizer(self):
         pass
-
+
     @torch.no_grad()
-    def generate(self,idx,max_token,temperature=1.0,top_k=None):
-
+    def generate(self, idx, max_token, temperature=1.0, top_k=None):
+
         for _ in range(max_token):
-            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:,-self.config.block_size:]
-            logits,_ = self(idx_cond)
-
+            idx_cond = (
+                idx
+                if idx.size(1) <= self.config.block_size
+                else idx[:, -self.config.block_size :]
+            )
+            logits, _ = self(idx_cond)
+
             if top_k is not None:
                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-                logits[logits < v[:, [-1]]] = -float('Inf')
+                logits[logits < v[:, [-1]]] = -float("Inf")
             # apply softmax to convert logits to (normalized) probabilities
             probs = F.softmax(logits, dim=-1)
             # sample from the distribution
@@ -198,34 +224,40 @@ class GPT(nn.Module, PyTorchModelHubMixin):
         decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
         nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
         optim_groups = [
-            {'params': decay_params, 'weight_decay': weight_decay},
-            {'params': nodecay_params, 'weight_decay': 0.0}
+            {"params": decay_params, "weight_decay": weight_decay},
+            {"params": nodecay_params, "weight_decay": 0.0},
         ]
         num_decay_params = sum(p.numel() for p in decay_params)
         num_nodecay_params = sum(p.numel() for p in nodecay_params)
-        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
-        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
+        print(
+            f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters"
+        )
+        print(
+            f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters"
+        )
         # Create AdamW optimizer and use the fused version if it is available
-        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
-        use_fused = fused_available and device_type == 'cuda'
+        fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
+        use_fused = fused_available and device_type == "cuda"
         extra_args = dict(fused=True) if use_fused else dict()
-        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
+        optimizer = torch.optim.AdamW(
+            optim_groups, lr=learning_rate, betas=betas, **extra_args
+        )
         print(f"using fused AdamW: {use_fused}")
 
         return optimizer
 
     def estimate_mfu(self, fwdbwd_per_iter, dt):
-        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
+        """estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS"""
         # first estimate the number of flops we do per iteration.
         # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
         N = self.get_num_params()
         cfg = self.config
-        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
-        flops_per_token = 6*N + 12*L*H*Q*T
+        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd // cfg.n_head, cfg.block_size
+        flops_per_token = 6 * N + 12 * L * H * Q * T
         flops_per_fwdbwd = flops_per_token * T
         flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
         # express our flops throughput as ratio of A100 bfloat16 peak flops
-        flops_achieved = flops_per_iter * (1.0/dt) # per second
-        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
+        flops_achieved = flops_per_iter * (1.0 / dt)  # per second
+        flops_promised = 312e12  # A100 GPU bfloat16 peak flops is 312 TFLOPS
         mfu = flops_achieved / flops_promised
-        return mfu
+        return mfu
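The reformatted `SelfAttention.forward` keeps two equivalent paths: the fused `torch.nn.functional.scaled_dot_product_attention` kernel when available, and a manual masked-softmax fallback. Not part of this commit, but a minimal standalone sketch comparing the two paths with made-up tensor shapes:

```python
# Minimal sketch (not in the repo): fused causal attention vs. the manual
# masked-softmax fallback. Shapes are arbitrary illustration values,
# not taken from the yume configs.
import math
import torch
import torch.nn.functional as F

B, n_head, T, head_dim = 2, 4, 8, 16
q, k, v = (torch.randn(B, n_head, T, head_dim) for _ in range(3))

# Fast path (PyTorch >= 2.0): fused kernel with an implicit causal mask.
y_fast = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)

# Slow path: explicit lower-triangular mask, softmax, then weighted sum over values.
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T)
att = att.masked_fill(mask == 0, float("-inf"))
y_slow = F.softmax(att, dim=-1) @ v

print((y_fast - y_slow).abs().max())  # expected to be ~0 (floating-point noise)
```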
yume/yume.py CHANGED
@@ -3,7 +3,7 @@ from torch import nn
 import torch.nn.functional as F
 from huggingface_hub import login
 
-from .config import Config,yume_small
+from .config import Config, yume_small
 from .models import GPT
 from .utils import dummy_logger, training_logger
 from .dataset import Trainset
@@ -23,13 +23,12 @@ class Yume:
     def sample(self):
         pass
 
-    def pretrain(self, dataset:Trainset):
+    def pretrain(self, dataset: Trainset):
         lr = self.config.lr
         dataset = Trainset()
         for epoch in range(self.config.num_epoch):
             # real trainset
             pass
-
 
     def fine_tune(self):
         pass