JosephCatrambone committed on
Commit
6b8f356
1 Parent(s): 1cf8f2a

Changing a few parameters and training for much longer. Should have better outputs now.

Browse files
Files changed (3) hide show
  1. .gitattributes +0 -33
  2. main.py +42 -6
  3. model.pth +2 -2
.gitattributes CHANGED
@@ -1,34 +1 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tflite filter=lfs diff=lfs merge=lfs -text
29
- *.tgz filter=lfs diff=lfs merge=lfs -text
30
- *.wasm filter=lfs diff=lfs merge=lfs -text
31
- *.xz filter=lfs diff=lfs merge=lfs -text
32
- *.zip filter=lfs diff=lfs merge=lfs -text
33
- *.zst filter=lfs diff=lfs merge=lfs -text
34
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.pth filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
main.py CHANGED
@@ -10,16 +10,39 @@ import data
10
  from model import ChessModel
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def train():
 
 
 
 
 
14
  device_string = "cuda" if torch.cuda.is_available() else "cpu"
15
  device = torch.device(device_string)
16
- model = ChessModel(256).to(torch.float32).to(device)
17
- opt = torch.optim.Adam(model.parameters())
18
  reconstruction_loss_fn = nn.CrossEntropyLoss().to(torch.float32).to(device)
19
  popularity_loss_fn = nn.L1Loss().to(torch.float32).to(device)
20
  evaluation_loss_fn = nn.L1Loss().to(torch.float32).to(device)
21
- data_loader = DataLoader(data.LichessPuzzleDataset(cap_data=65536), batch_size=64, num_workers=1) # 1 to avoid threading madness.
22
- num_epochs = 100
 
23
 
24
  for epoch in range(num_epochs):
25
  model.train()
@@ -38,7 +61,8 @@ def train():
38
  reconstruction_loss = reconstruction_loss_fn(predicted_board_vec, board_vec)
39
  popularity_loss = popularity_loss_fn(predicted_popularity, popularity)
40
  evaluation_loss = evaluation_loss_fn(predicted_evaluation, evaluation)
41
- total_loss = reconstruction_loss + popularity_loss + evaluation_loss
 
42
 
43
  opt.zero_grad()
44
  total_loss.backward()
@@ -54,7 +78,19 @@ def train():
54
  print(f"Average evaluation loss: {total_evaluation_loss/num_batches}")
55
  print(f"Average batch loss: {total_batch_loss/num_batches}")
56
 
57
- torch.save(model, f"checkpoints/epoch_{epoch}.pth")
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
 
60
  def infer(fen):
 
10
  from model import ChessModel
11
 
12
 
13
+ # Experiment parameters:
14
+ RUN_CONFIGURATION = {
15
+ "learning_rate": 0.0004,
16
+ "dataset_cap": 100000,
17
+ "epochs": 1000,
18
+ "latent_size": 256,
19
+ }
20
+
21
+ # Logging:
22
+ wandb = None
23
+ try:
24
+ import wandb
25
+ wandb.init("assembly_ai_hackathon_2022", config=RUN_CONFIGURATION)
26
+ except ImportError:
27
+ print("Weights and Biases not found in packages.")
28
+
29
+
30
  def train():
31
+ learning_rate = RUN_CONFIGURATION["learning_rate"]
32
+ latent_size = RUN_CONFIGURATION["latent_size"]
33
+ data_cap = RUN_CONFIGURATION["dataset_cap"]
34
+ num_epochs = RUN_CONFIGURATION["epochs"]
35
+
36
  device_string = "cuda" if torch.cuda.is_available() else "cpu"
37
  device = torch.device(device_string)
38
+ model = ChessModel(latent_size).to(torch.float32).to(device)
39
+ opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
40
  reconstruction_loss_fn = nn.CrossEntropyLoss().to(torch.float32).to(device)
41
  popularity_loss_fn = nn.L1Loss().to(torch.float32).to(device)
42
  evaluation_loss_fn = nn.L1Loss().to(torch.float32).to(device)
43
+ data_loader = DataLoader(data.LichessPuzzleDataset(cap_data=data_cap), batch_size=64, num_workers=1) # 1 to avoid threading madness.
44
+ save_every_nth_epoch = 50
45
+ upload_logs_every_nth_epoch = 1
46
 
47
  for epoch in range(num_epochs):
48
  model.train()
 
61
  reconstruction_loss = reconstruction_loss_fn(predicted_board_vec, board_vec)
62
  popularity_loss = popularity_loss_fn(predicted_popularity, popularity)
63
  evaluation_loss = evaluation_loss_fn(predicted_evaluation, evaluation)
64
+ #total_loss = reconstruction_loss + popularity_loss + evaluation_loss
65
+ total_loss = popularity_loss
66
 
67
  opt.zero_grad()
68
  total_loss.backward()
 
78
  print(f"Average evaluation loss: {total_evaluation_loss/num_batches}")
79
  print(f"Average batch loss: {total_batch_loss/num_batches}")
80
 
81
+ if save_every_nth_epoch > 0 and (epoch % save_every_nth_epoch) == 0:
82
+ torch.save(model, f"checkpoints/epoch_{epoch}.pth")
83
+
84
+ if wandb:
85
+ wandb.log(
86
+ # For now, just log popularity.
87
+ {"popularity_loss": total_popularity_loss},
88
+ commit=(epoch+1) % upload_logs_every_nth_epoch == 0
89
+ )
90
+
91
+ torch.save(model, "checkpoints/final.pth")
92
+ if wandb:
93
+ wandb.finish()
94
 
95
 
96
  def infer(fen):
model.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:903a3dc9af8a83bc128b0e6581693a0cf8e74dd2127eb669704420463115e18a
3
- size 15268009
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed7dc6d33fb3ac545f78b7be413b0bebd565fcca89e6662ed617a2640d99715b
3
+ size 12118255