sachin commited on
Commit
a8c8fe0
1 Parent(s): 16d5d78

Change where images weres stored

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. src/config.py +2 -2
  3. src/data.py +3 -2
.gitignore CHANGED
@@ -1,3 +1,4 @@
1
  .DS_Store
2
  .vscode/
3
  pyrightconfig.json
 
 
1
  .DS_Store
2
  .vscode/
3
  pyrightconfig.json
4
+ *.jpg
src/config.py CHANGED
@@ -4,7 +4,7 @@ import pydantic
4
 
5
  MAX_DOWNLOAD_TIME = 0.2
6
 
7
- IMAGE_DOWNLOAD_PATH = pathlib.Path("/tmp/images")
8
 
9
 
10
  class DataConfig(pydantic.BaseModel):
@@ -30,7 +30,7 @@ class ModelConfig(pydantic.BaseModel):
30
 
31
  class TrainerConfig(pydantic.BaseModel):
32
  epochs: int = 20
33
- batch_size: int = 256
34
  learning_rate: float = 5e-4
35
  accumulate_grad_batches: int = 1
36
  temperature: float = 1.0
 
4
 
5
  MAX_DOWNLOAD_TIME = 0.2
6
 
7
+ IMAGE_DOWNLOAD_PATH = pathlib.Path("./data/images")
8
 
9
 
10
  class DataConfig(pydantic.BaseModel):
 
30
 
31
  class TrainerConfig(pydantic.BaseModel):
32
  epochs: int = 20
33
+ batch_size: int = 64
34
  learning_rate: float = 5e-4
35
  accumulate_grad_batches: int = 1
36
  temperature: float = 1.0
src/data.py CHANGED
@@ -102,7 +102,8 @@ if __name__ == "__main__":
102
  )
103
  train_dl, valid_dl = get_dataset(transform, tokenizer, hyper_parameters)
104
 
 
 
 
105
  for batch in tqdm(train_dl):
106
  continue
107
-
108
- print("hellow")
 
102
  )
103
  train_dl, valid_dl = get_dataset(transform, tokenizer, hyper_parameters)
104
 
105
+ batch = next(iter(train_dl))
106
+ print({k: v.shape for k, v in batch.items()}) # torch.Size([1, 3, 128, 128])
107
+
108
  for batch in tqdm(train_dl):
109
  continue