Zai committed
Commit 36cbecb · Parent(s): bd93ebf

added get_num_params and update test cases
.github/workflows/hugging-face.yaml CHANGED
@@ -12,6 +12,7 @@ jobs:
       with:
         fetch-depth: 0
         lfs: true
+
       - name: Push to hub
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
README.md CHANGED
@@ -65,18 +65,6 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
 
 Mention any contributors or libraries that you used or were inspired by.
 
----
-
-title: Yume
-emoji: ✨
-colorFrom: blue
-colorTo: green
-sdk: streamlit
-sdk_version: 1.29.0
-app_file: interface.py
-pinned: false
-license: openrail
----
 
 ## Contact
 
examples/__init__.py ADDED
File without changes
sampler.py DELETED
@@ -1,7 +0,0 @@
-from yume import Yume
-
-yume = Yume()
-
-yume.load_pretrained()
-
-yume.generate()
sampling.py ADDED
@@ -0,0 +1,12 @@
+from yume import Yume, Config
+
+config = Config()
+
+yume = Yume(config=config)
+
+# Check sample quality before loading the pretrained weights
+yume.sample()
+
+yume.load_pretrained()
+
+yume.sample()
tests/test_datasets.py CHANGED
@@ -1,14 +1,23 @@
 import unittest
-import yume
+from yume.dataset import Trainset
 
 
 class TestDatasets(unittest.TestCase):
     def test_download(self):
-        pass
+        trainset = Trainset()
+        trainset._load_dataset()
+        assert trainset.texts is not None
+        trainset._tokenize()
+        assert len(trainset.data) > 1
 
     def test_encode(self):
-        pass
-
+        trainset = Trainset()
+        dummy_text = "Hello Human World"
+        trainset.texts = dummy_text
+        trainset._tokenize()
+        assert len(trainset.data) > 1
+        encoded_text = trainset.tokenizer.encode(dummy_text)
+        assert trainset.tokenizer.decode(encoded_text) == dummy_text
 
 if __name__ == "__main__":
     unittest.main()
tests/test_pretrained.py CHANGED
@@ -1,11 +1,19 @@
 import unittest
+from yume import Yume, Config
 
 
 class TestPretrained(unittest.TestCase):
+    def __init__(self, methodName: str = "runTest") -> None:
+        super().__init__(methodName)
+        self.config = Config()
+        self.yume = Yume(config=self.config)
+
     def test_download(self):
+        self.yume.load_pretrained()
         pass
 
     def test_generation(self):
+        self.yume.sample()
         pass
 
 
tests/test_tokenizer.py CHANGED
@@ -1,14 +1,19 @@
 import unittest
-
+from yume import Tokenizer
 
 class TestTokenizer(unittest.TestCase):
+    def __init__(self, methodName: str = "runTest") -> None:
+        super().__init__(methodName)
+        self.tokenizer = Tokenizer()
+        self.dummy_text = "馬鹿なこと言わないでよ"
+
     def test_encode(self):
         pass
 
     def test_decode(self):
         pass
 
-    def test_equal_result(self):
+    def test_train_encoder(self):
         pass
 
 
training.py CHANGED
@@ -1,5 +1,21 @@
-from .yume import Yume
+from yume import Yume, Trainset, Config
 
-yume = Yume()
+config = Config()
 
-yume.train()
+dataset = Trainset()
+
+dataset._load_dataset()
+
+dataset._tokenize(use_tiktoken=True)
+
+yume = Yume(config)
+
+assert len(dataset.data) > 0
+
+yume.pretrain(dataset.data)
+
+yume.sample()
+
+# optional
+# yume.huggingface_login("your hf tokens")
+# yume.save_pretrained("yume")
yume/__init__.py CHANGED
@@ -0,0 +1,4 @@
+from .yume import Yume
+from .dataset import Trainset
+from .tokenizer import Tokenizer
+from .config import Config
yume/config.py CHANGED
@@ -9,6 +9,7 @@ class Config:
         n_embd=768,
         dropout=0.0,
         bias=True,
+        lr=0.001,
     ) -> None:
         self.num_epoch = num_epoch
         self.block_sized = 1024
@@ -18,3 +19,4 @@ class Config:
         self.n_embdd = 768
         self.dropout = 0.0
         self.bias = True
+        self.lr = lr
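With lr stored on Config, a training run can take its learning rate straight from the config object. A minimal sketch of how that would typically be consumed (the Adam optimizer and the yume.model.parameters() access are assumptions for illustration; the commit itself leaves pretrain() as a stub):

import torch
from yume import Yume, Config

# Hypothetical usage: the new lr field travels with the Config object.
config = Config(lr=3e-4)
yume = Yume(config=config)

# Inside pretrain(), an optimizer would typically be built from config.lr
# (the optimizer choice here is an assumption, not part of this commit).
optimizer = torch.optim.Adam(yume.model.parameters(), lr=yume.config.lr)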
yume/dataset.py CHANGED
@@ -1,19 +1,39 @@
 from torch.utils.data import Dataset
 from datasets import load_dataset
 from .tokenizer import Tokenizer
+from .utils import dummy_logger
+
+import tiktoken
 
 
 # TODO setup dataset
 class Trainset(Dataset):
     def __init__(self, batch_size=48):
-        self.loaded_data = load_dataset("zaibutcooler/animanga-vault")
-        self.texts = self.loaded_data["train"]["raw"]
-        self.data = self.loaded_data["train"]["data"]
-        self.tokenizer = Tokenizer()
-        self.tokenizer.load_pretrained()
+        self.texts = None
+        self.data = []
 
     def __len__(self):
         return len(self.data)
-
+
     def __getitem__(self, index):
+        assert len(self.data) > 10
         return []
+
+    def _load_dataset(self, url="zaibutcooler/animanga-vault"):
+        loaded_dataset = load_dataset(url)
+        self.texts = loaded_dataset["train"]["raw"]
+        self.data = loaded_dataset["train"]["data"]
+        dummy_logger("Successfully loaded the dataset")
+
+    def _tokenize(self, use_tiktoken=True):
+        if use_tiktoken:
+            enc = tiktoken.get_encoding("cl100k_base")
+            assert enc.decode(enc.encode("hello world")) == "hello world"
+
+            enc = tiktoken.encoding_for_model("gpt-4")
+            self.tokenizer = enc
+        else:
+            self.tokenizer = Tokenizer()
+            self.tokenizer.load_pretrained()
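Both branches of _tokenize are meant to leave an object with encode/decode methods on self.tokenizer, which is what the round trip in the updated tests/test_datasets.py relies on. A minimal sketch of that contract using the tiktoken path (the custom Tokenizer is assumed to expose the same interface):

import tiktoken

# Sketch of the encode/decode contract Trainset._tokenize sets up.
enc = tiktoken.get_encoding("cl100k_base")
ids = enc.encode("Hello Human World")
assert enc.decode(ids) == "Hello Human World"  # lossless round trip
print(len(ids))  # number of tokens, not characters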
yume/yume.py CHANGED
@@ -1,9 +1,10 @@
 import torch
 from torch import nn
 import torch.nn.functional as F
+from huggingface_hub import login
+
 from .config import Config
 from .models import GPT
-from huggingface_hub import login
 from .utils import dummy_logger, training_logger
 
 
@@ -15,9 +16,29 @@ class Yume:
         self.model = GPT(config=config)
         self.config = config
 
-    def train(self):
+    def generate(self):
+        pass
+
+    def sample(self):
         pass
 
+    def pretrain(self, tokens):
+        lr = self.config.lr
+        num_epochs = self.config.num_epoch
+
+
+        pass
+
+    def fine_tune(self):
+        pass
+
+    def get_num_params(self, non_embedding=True):
+        n_params = sum(p.numel() for p in self.model.parameters())
+        if non_embedding:
+            n_params -= self.model.transformer.wpe.weight.numel()
+        dummy_logger(f"parameter count -> {n_params}")
+        return n_params
+
     def save_pretrained(self, name="yume"):
         self.model.save_pretrained(name)
         self.model.push_to_hub(name)
@@ -30,7 +51,4 @@ class Yume:
     def huggingface_login(self, token):
         assert token is not None
         login(token=token)
-        dummy_logger("Logged in successfully")
-
-    def generate(self):
-        pass
+        dummy_logger("Logged in successfully")
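get_num_params follows the common GPT convention of reporting a non-embedding count: the position-embedding table (wpe) is subtracted, while token embeddings are kept since they are typically weight-tied to the output head. A toy illustration of that accounting (the module names below are invented for the sketch; the real GPT internals are not part of this commit):

import torch
from torch import nn

class TinyGPT(nn.Module):
    def __init__(self, vocab=100, block=16, d=8):
        super().__init__()
        self.wte = nn.Embedding(vocab, d)   # token embeddings, kept in the count
        self.wpe = nn.Embedding(block, d)   # position embeddings, subtracted
        self.head = nn.Linear(d, vocab, bias=False)

model = TinyGPT()
total = sum(p.numel() for p in model.parameters())
non_embedding = total - model.wpe.weight.numel()
print(f"total={total}, non-embedding={non_embedding}")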