Junjie96 commited on
Commit
a72c7d9
·
verified ·
1 Parent(s): 2538f2a

Update src/util.py

Browse files
Files changed (1) hide show
  1. src/util.py +22 -3
src/util.py CHANGED
@@ -1,7 +1,9 @@
1
  import concurrent.futures
2
  import io
3
  import os
 
4
 
 
5
  import oss2
6
  import requests
7
  from PIL import Image
@@ -18,6 +20,19 @@ bucket = oss2.Bucket(oss2.Auth(access_key_id, access_key_secret), endpoint, buck
18
  oss_path = os.getenv("OSS_PATH")
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def download_img_pil(index, img_url):
22
  r = requests.get(img_url, stream=True)
23
  if r.status_code == 200:
@@ -43,16 +58,20 @@ def download_images(img_urls, batch_size):
43
  return imgs_pil
44
 
45
 
46
- def upload_np_2_oss(input_image, name="cache.png"):
47
  assert name.lower().endswith((".png", ".jpg")), name
 
 
48
  imgByteArr = io.BytesIO()
49
  if name.lower().endswith(".png"):
50
- Image.fromarray(input_image).save(imgByteArr, format="PNG")
51
  else:
52
- Image.fromarray(input_image).save(imgByteArr, format="JPEG", quality=95)
53
  imgByteArr = imgByteArr.getvalue()
54
 
 
55
  bucket.put_object(oss_path + "/" + name, imgByteArr)
56
  ret = bucket.sign_url('GET', oss_path + "/" + name, 60 * 60 * 24)
 
57
  del imgByteArr
58
  return ret
 
1
  import concurrent.futures
2
  import io
3
  import os
4
+ import time
5
 
6
+ import cv2
7
  import oss2
8
  import requests
9
  from PIL import Image
 
20
  oss_path = os.getenv("OSS_PATH")
21
 
22
 
23
+ def resize(img, short_side_length=512):
24
+ height, width, _ = img.shape
25
+ aspect_ratio = width / height
26
+ if width > height:
27
+ new_width = short_side_length
28
+ new_height = int(new_width / aspect_ratio)
29
+ else:
30
+ new_height = short_side_length
31
+ new_width = int(new_height * aspect_ratio)
32
+ resized_img = cv2.resize(img, (new_width, new_height))
33
+ return resized_img
34
+
35
+
36
  def download_img_pil(index, img_url):
37
  r = requests.get(img_url, stream=True)
38
  if r.status_code == 200:
 
58
  return imgs_pil
59
 
60
 
61
+ def upload_np_2_oss(input_image, name="cache.jpg"):
62
  assert name.lower().endswith((".png", ".jpg")), name
63
+ if name.lower().endswith(".png"):
64
+ name = name[:-4] + ".jpg"
65
  imgByteArr = io.BytesIO()
66
  if name.lower().endswith(".png"):
67
+ Image.fromarray(resize(input_image)).save(imgByteArr, format="PNG")
68
  else:
69
+ Image.fromarray(resize(input_image)).save(imgByteArr, format="JPEG", quality=95)
70
  imgByteArr = imgByteArr.getvalue()
71
 
72
+ start_time = time.perf_counter()
73
  bucket.put_object(oss_path + "/" + name, imgByteArr)
74
  ret = bucket.sign_url('GET', oss_path + "/" + name, 60 * 60 * 24)
75
+ logger.info(f"upload cost: {time.perf_counter() - start_time} s.")
76
  del imgByteArr
77
  return ret