frutiemax commited on
Commit
a60f6bb
·
1 Parent(s): 9bde8da

Normalize images

Browse files
Files changed (1) hide show
  1. train_model.py +8 -1
train_model.py CHANGED
@@ -31,8 +31,9 @@ def save_and_test(pipeline, epoch):
31
  pipeline.save_pretrained(model_file)
32
 
33
  def transform_images(image):
34
- res = torch.Tensor(SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE)
35
  pil_to_tensor = T.PILToTensor()
 
36
 
37
  res_index = 0
38
  scale_factor = np.minimum(SAMPLE_SIZE / image.width, SAMPLE_SIZE / image.height)
@@ -40,7 +41,13 @@ def transform_images(image):
40
 
41
  new_image = PIL.Image.new('RGB', (SAMPLE_SIZE, SAMPLE_SIZE))
42
  new_image.paste(image, box=(int((SAMPLE_SIZE - image.width)/2), int((SAMPLE_SIZE - image.height)/2)))
 
 
 
 
43
  res = pil_to_tensor(new_image)
 
 
44
  return res
45
 
46
  def convert_images(dataset):
 
31
  pipeline.save_pretrained(model_file)
32
 
33
  def transform_images(image):
34
+ res = torch.Tensor((SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE))
35
  pil_to_tensor = T.PILToTensor()
36
+ tensor_to_pil = T.ToPILImage()
37
 
38
  res_index = 0
39
  scale_factor = np.minimum(SAMPLE_SIZE / image.width, SAMPLE_SIZE / image.height)
 
41
 
42
  new_image = PIL.Image.new('RGB', (SAMPLE_SIZE, SAMPLE_SIZE))
43
  new_image.paste(image, box=(int((SAMPLE_SIZE - image.width)/2), int((SAMPLE_SIZE - image.height)/2)))
44
+
45
+ #data = np.array(new_image, dtype=np.float32)
46
+ #data = (data / 128.0 - 1.0)
47
+ #res = torch.from_numpy(data)
48
  res = pil_to_tensor(new_image)
49
+ res.to(dtype=torch.float32)
50
+ res = res / torch.Tensor([128.0]) - torch.Tensor([1.0])
51
  return res
52
 
53
  def convert_images(dataset):