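"""Training entry point for ChessModel on the Lichess puzzle dataset.

The model consumes a 903-float board encoding and returns an embedding, a
predicted puzzle popularity, a predicted position evaluation, and a
reconstruction of the input. Metrics are printed each epoch and, when the
wandb package is installed, mirrored to Weights & Biases.

Usage:
    python <this script> --train
    python <this script> --infer <FEN>
"""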
import os
import sys

import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

import data
from model import ChessModel


# Experiment parameters:
RUN_CONFIGURATION = {
    "learning_rate": 0.0004,
    "dataset_cap": 100000,
    "epochs": 1000,
    "latent_size": 256,
}

# Logging: `wandb` stays None when the package is missing, so the
# `if wandb:` checks below quietly no-op.
wandb = None
try:
    import wandb
    wandb.init(project="assembly_ai_hackathon_2022", config=RUN_CONFIGURATION)
except ImportError:
    print("Weights and Biases not found in packages.")


def train():
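    """Train ChessModel on capped Lichess puzzle data.

    Each batch yields (board_vec, popularity, evaluation); the model returns
    (embedding, popularity, evaluation, reconstruction). Note that only the
    popularity loss is currently backpropagated; the other losses are
    computed for reporting only.
    """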
    learning_rate = RUN_CONFIGURATION["learning_rate"]
    latent_size = RUN_CONFIGURATION["latent_size"]
    data_cap = RUN_CONFIGURATION["dataset_cap"]
    num_epochs = RUN_CONFIGURATION["epochs"]

    device_string = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_string)
    model = ChessModel(latent_size).to(torch.float32).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Loss modules here hold no parameters, so no .to() moves are needed.
    # CrossEntropyLoss with a float target of the same shape treats board_vec
    # as class probabilities (supported in PyTorch >= 1.10).
    reconstruction_loss_fn = nn.CrossEntropyLoss()
    popularity_loss_fn = nn.L1Loss()
    evaluation_loss_fn = nn.L1Loss()
    data_loader = DataLoader(data.LichessPuzzleDataset(cap_data=data_cap), batch_size=64, num_workers=1)  # 1 to avoid threading madness.
    save_every_nth_epoch = 50
    upload_logs_every_nth_epoch = 1

    for epoch in range(num_epochs):
        model.train()
        total_reconstruction_loss = 0.0
        total_popularity_loss = 0.0
        total_evaluation_loss = 0.0
        total_batch_loss = 0.0
        num_batches = 0
        for board_vec, popularity, evaluation in tqdm(data_loader):  # tqdm infers the total from len(data_loader).
            board_vec = board_vec.to(torch.float32).to(device)  # [batch_size x 903]
            popularity = popularity.to(torch.float32).to(device).unsqueeze(1)  # enforce [batch_size, 1]
            evaluation = evaluation.to(torch.float32).to(device).unsqueeze(1)

            _embedding, predicted_popularity, predicted_evaluation, predicted_board_vec = model(board_vec)

            reconstruction_loss = reconstruction_loss_fn(predicted_board_vec, board_vec)
            popularity_loss = popularity_loss_fn(predicted_popularity, popularity)
            evaluation_loss = evaluation_loss_fn(predicted_evaluation, evaluation)
            # For now, optimize popularity only; the other losses are tracked for reporting.
            # total_loss = reconstruction_loss + popularity_loss + evaluation_loss
            total_loss = popularity_loss

            opt.zero_grad()
            total_loss.backward()
            opt.step()

            total_reconstruction_loss += reconstruction_loss.item()  # .item() already syncs to host; .cpu() was redundant.
            total_popularity_loss += popularity_loss.item()
            total_evaluation_loss += evaluation_loss.item()
            total_batch_loss += total_loss.item()
            num_batches += 1
        print(f"Average reconstruction loss: {total_reconstruction_loss/num_batches}")
        print(f"Average popularity loss: {total_popularity_loss/num_batches}")
        print(f"Average evaluation loss: {total_evaluation_loss/num_batches}")
        print(f"Average batch loss: {total_batch_loss/num_batches}")

        if save_every_nth_epoch > 0 and (epoch % save_every_nth_epoch) == 0:
            os.makedirs("checkpoints", exist_ok=True)  # torch.save fails if the directory is missing.
            torch.save(model, f"checkpoints/epoch_{epoch}.pth")

        if wandb:
            wandb.log(
                # For now, just log the per-epoch average popularity loss,
                # matching the printed metric.
                {"popularity_loss": total_popularity_loss / num_batches},
                commit=(epoch + 1) % upload_logs_every_nth_epoch == 0,
            )

    torch.save(model, "checkpoints/final.pth")
    if wandb:
        wandb.finish()


def infer(fen):
    """Run a single forward pass on one FEN-encoded position (sketch below)."""
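    # A minimal sketch, NOT from the original repo: it assumes `data` exposes
    # a hypothetical fen_to_vector() helper producing the same 903-float
    # encoding used in training, and that a trained checkpoint exists at
    # checkpoints/final.pth. Adjust both to match the actual codebase.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The checkpoint was written with torch.save(model, ...); on newer PyTorch
    # this may additionally need weights_only=False.
    model = torch.load("checkpoints/final.pth", map_location=device)
    model.eval()
    board_vec = torch.as_tensor(
        data.fen_to_vector(fen), dtype=torch.float32, device=device
    ).unsqueeze(0)  # [1 x 903]
    with torch.no_grad():
        embedding, popularity, evaluation, _reconstruction = model(board_vec)
    print(f"popularity={popularity.item():.4f} evaluation={evaluation.item():.4f}")
    return embedding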


def test():
    pass


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("Usage: python {sys.argv[0]} --train|infer")
    elif sys.argv[1] == "--train":
        train()
    elif sys.argv[2] == "--infer":
        infer(sys.argv[3])