Spaces:
Running
Running
import concurrent.futures | |
import io | |
import os | |
import time | |
import oss2 | |
import requests | |
from PIL import Image | |
from .log import logger | |
# Aliyun OSS client settings, read from environment variables when the module
# is imported. Missing variables come through as None.
oss_path = os.environ.get("OSS_PATH")
access_key_id = os.environ.get("ACCESS_KEY_ID")
access_key_secret = os.environ.get("ACCESS_KEY_SECRET")
bucket_name = os.environ.get("BUCKET_NAME")
endpoint = os.environ.get("ENDPOINT")
# Module-wide bucket handle shared by the upload helpers below.
bucket = oss2.Bucket(
    oss2.Auth(access_key_id, access_key_secret),
    endpoint,
    bucket_name,
)
def resize(image, short_side_length=512):
    """Scale *image* so that its shorter side is *short_side_length* pixels.

    The aspect ratio is preserved. The longer side is truncated (rounded down)
    to a whole number of pixels.

    Args:
        image: A PIL ``Image``, or any object exposing ``.size`` as
            ``(width, height)`` and a ``.resize((width, height))`` method.
        short_side_length: Target length, in pixels, of the shorter side.

    Returns:
        Whatever ``image.resize`` returns for the new size.

    Raises:
        ZeroDivisionError: If either dimension of *image* is 0.
    """
    width, height = image.size
    short_side = min(width, height)
    # Exact integer arithmetic instead of ``int(side * (target / short_side))``.
    # The float ratio can land just below the true value
    # (512 / 49 * 49 == 511.99999999999994), which truncated the short side to
    # 511 pixels.
    new_width = int(width * short_side_length // short_side)
    new_height = int(height * short_side_length // short_side)
    return image.resize((new_width, new_height))
def download_img_pil(index, img_url, timeout=30):
    """Download one image over HTTP and open it with PIL.

    Args:
        index: Caller-supplied position, echoed back so that results gathered
            from a thread pool can be put back in order.
        img_url: URL fetched with an HTTP GET.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            stalled server cannot block the calling thread forever.

    Returns:
        ``(index, image)`` on success. ``(index, None)`` when the request
        raises a ``requests`` error or the status code is not 200. A pair is
        always returned so that callers can unpack it unconditionally.

    Raises:
        Any error raised by ``Image.open`` for content it cannot decode.
    """
    try:
        # No ``stream=True``: the whole body is needed anyway, and a streamed
        # response that is never read keeps its connection checked out.
        r = requests.get(img_url, timeout=timeout)
    except requests.RequestException as err:
        logger.error(f"Fail to download: {img_url} ({err})")
        return (index, None)
    if r.status_code != 200:
        logger.error(f"Fail to download: {img_url}")
        return (index, None)
    return (index, Image.open(io.BytesIO(r.content)))
def download_images(img_urls, batch_size, max_workers=4):
    """Download several images concurrently, preserving input order.

    Args:
        img_urls: Iterable of image URLs, each handed to ``download_img_pil``.
        batch_size: Minimum length of the returned list. Slots beyond the
            number of URLs stay ``None``.
        max_workers: Number of download threads.

    Returns:
        A list of length ``max(batch_size, number of URLs)``. Entry ``i`` holds
        the image for the ``i``-th URL, or ``None`` if that download failed.
    """
    urls = list(img_urls)
    # Grow past batch_size instead of raising IndexError when more URLs arrive.
    imgs_pil = [None] * max(batch_size, len(urls))
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(download_img_pil, i, url) for i, url in enumerate(urls)
        ]
        for future in concurrent.futures.as_completed(futures):
            ret = future.result()
            # Tolerate a bare ``None`` from a failed download; its slot simply
            # stays ``None`` rather than crashing on tuple unpacking.
            if ret is None:
                continue
            index, img_pil = ret
            imgs_pil[index] = img_pil
    return imgs_pil
def upload_np_2_oss(input_image, name="cache.jpg"):
    """Encode an image as JPEG, upload it to OSS and return a signed URL.

    Args:
        input_image: Data accepted by ``PIL.Image.fromarray``. This is
            presumably a ``uint8`` NumPy array; confirm against callers.
        name: Object name, stored under ``oss_path``. It must end in ``.png``
            or ``.jpg`` (case-insensitive). A ``.png`` suffix is replaced by
            ``.jpg`` because the data is always uploaded as JPEG.

    Returns:
        A signed ``GET`` URL for the uploaded object, valid for 24 hours.

    Raises:
        ValueError: If *name* does not end in ``.png`` or ``.jpg``.
    """
    # Explicit check: an ``assert`` would vanish under ``python -O``.
    if not name.lower().endswith((".png", ".jpg")):
        raise ValueError(name)
    # Rename case-insensitively. Previously ``x.PNG`` escaped the rename and
    # was uploaded as PNG, while ``x.png`` became JPEG.
    if name.lower().endswith(".png"):
        name = name[:-4] + ".jpg"

    image = Image.fromarray(input_image)
    # JPEG has no alpha channel. Drop it so RGBA/LA arrays don't fail in save().
    opaque_mode = {"RGBA": "RGB", "LA": "L"}.get(image.mode)
    if opaque_mode is not None:
        image = image.convert(opaque_mode)
    with io.BytesIO() as buf:
        image.save(buf, format="JPEG", quality=95)
        data = buf.getvalue()

    key = oss_path + "/" + name
    start_time = time.perf_counter()
    bucket.put_object(key, data)
    ret = bucket.sign_url('GET', key, 60 * 60 * 24)
    logger.info(f"upload cost: {time.perf_counter() - start_time} s.")
    return ret