Spaces:
Runtime error
Runtime error
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import sys
import tarfile

import requests
from tqdm import tqdm

from ppocr.utils.logging import get_logger
# Directory where maybe_download_params stores files it fetches by URL.
# "~" is expanded to the current user's home directory once, at import time.
MODELS_DIR = os.path.expanduser("~/.paddleocr/models/")
def download_with_progressbar(url, save_path):
    """Download ``url`` to ``save_path`` while showing a tqdm progress bar.

    The response body is streamed in 1 KiB chunks into a temporary
    ``<save_path>.part`` file. That file is renamed onto ``save_path`` only
    after the whole body has been written, so an interrupted download never
    leaves a truncated file at ``save_path``.

    Args:
        url: HTTP(S) address of the file to fetch.
        save_path: Local destination path; its parent directory must exist.

    Raises:
        SystemExit: with exit status 1, after logging an error, when the
            server answers with a status code other than 200.
        Exceptions from ``requests`` or file I/O propagate unchanged after
        the partial ``.part`` file has been removed.
    """
    logger = get_logger()
    # Using the response as a context manager releases the connection even
    # if writing to disk fails part-way through.
    with requests.get(url, stream=True) as response:
        if response.status_code != 200:
            logger.error("Something went wrong while downloading models")
            # Non-zero status: exit code 0 would tell shells/CI that the run
            # succeeded even though no model was downloaded.
            sys.exit(1)

        content_length = response.headers.get('content-length')
        # Without a Content-Length header the size is unknown; total=None
        # makes tqdm show a plain byte counter instead of a bogus total of 1.
        total_size_in_bytes = int(content_length) if content_length else None
        block_size = 1024  # 1 Kibibyte

        part_path = os.fspath(save_path) + '.part'
        try:
            with open(part_path, 'wb') as file, tqdm(
                    total=total_size_in_bytes, unit='iB',
                    unit_scale=True) as progress_bar:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    file.write(data)
            # Replaces any existing file at save_path in a single rename.
            os.replace(part_path, save_path)
        except BaseException:
            # Also covers KeyboardInterrupt; the exception is re-raised after
            # the incomplete file is cleaned up.
            if os.path.exists(part_path):
                os.remove(part_path)
            raise
def maybe_download(model_storage_directory, url):
    """Make sure ``model_storage_directory`` holds the inference model files.

    If either ``inference.pdiparams`` or ``inference.pdmodel`` is missing from
    ``model_storage_directory``, the ``.tar`` archive at ``url`` is downloaded
    into that directory. Every archive member whose name ends in
    ``.pdiparams``, ``.pdiparams.info`` or ``.pdmodel`` is written out as
    ``inference<suffix>``, dropping any directory prefix it had inside the
    archive. Other members are ignored, and the archive is deleted afterwards.
    Output names are fixed, so member paths are never used as destinations.

    Args:
        model_storage_directory: Directory that should contain the model
            files; it is created if it does not exist.
        url: Address of a ``.tar`` archive containing the model files.

    Raises:
        AssertionError: if a download is needed and ``url`` does not end in
            ``.tar``.
    """
    tar_file_name_list = ['.pdiparams', '.pdiparams.info', '.pdmodel']
    params_path = os.path.join(model_storage_directory, 'inference.pdiparams')
    model_path = os.path.join(model_storage_directory, 'inference.pdmodel')
    if os.path.exists(params_path) and os.path.exists(model_path):
        return

    if not url.endswith('.tar'):
        # Raised explicitly rather than via ``assert`` so the check survives
        # ``python -O``; the exception type is unchanged for callers.
        raise AssertionError('Only supports tar compressed package')

    tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
    print('download {} to {}'.format(url, tmp_path))
    os.makedirs(model_storage_directory, exist_ok=True)
    download_with_progressbar(url, tmp_path)

    with tarfile.open(tmp_path, 'r') as tar_obj:
        for member in tar_obj.getmembers():
            suffix = next(
                (s for s in tar_file_name_list if member.name.endswith(s)),
                None)
            if suffix is None:
                continue
            src = tar_obj.extractfile(member)
            # extractfile() returns None for members that are not files
            # (e.g. directories); there is nothing to copy out of those.
            if src is None:
                continue
            dst_path = os.path.join(model_storage_directory,
                                    'inference' + suffix)
            # Copy in chunks instead of read() so large parameter files are
            # not loaded into memory all at once.
            with src, open(dst_path, 'wb') as dst:
                shutil.copyfileobj(src, dst)
    os.remove(tmp_path)
def maybe_download_params(model_path):
    """Map ``model_path`` to a local file, fetching it first if it is a URL.

    An existing local path, or any value that ``is_link`` does not treat as
    a URL, is returned untouched. Otherwise the file is downloaded into
    ``MODELS_DIR`` under the last ``/``-separated component of the URL, and
    that local path is returned. This function does not look for an earlier
    downloaded copy before fetching.
    """
    needs_download = not os.path.exists(model_path) and is_link(model_path)
    if not needs_download:
        return model_path

    url = model_path
    local_file = os.path.join(MODELS_DIR, url.rsplit('/', 1)[-1])
    print('download {} to {}'.format(url, local_file))
    os.makedirs(MODELS_DIR, exist_ok=True)
    download_with_progressbar(url, local_file)
    return local_file
def is_link(s):
    """Return True if ``s`` is an HTTP(S) URL rather than a local path.

    Only an explicit ``http://`` or ``https://`` prefix counts. A local path
    that merely begins with the letters "http" (e.g. ``http_models/det``) is
    treated as a path. ``None`` yields False.
    """
    return s is not None and s.startswith(('http://', 'https://'))
def confirm_model_dir_url(model_dir, default_model_dir, default_url):
    """Resolve the directory a model should live in and the URL to fetch it.

    If ``model_dir`` is a local path (neither None nor a link), it is
    returned unchanged together with ``default_url``. Otherwise the URL is
    ``model_dir`` itself when it is a link, else ``default_url``. The
    directory is then ``default_model_dir`` joined with the URL's last
    ``/``-separated component minus its final four characters (presumably a
    ``.tar`` extension; it is not checked here).

    Returns:
        A ``(model_dir, url)`` tuple.
    """
    if model_dir is not None and not is_link(model_dir):
        # A concrete local directory was supplied: keep it as-is.
        return model_dir, default_url

    url = model_dir if is_link(model_dir) else default_url
    archive_stem = url.split('/')[-1][:-4]
    return os.path.join(default_model_dir, archive_stem), url