Spaces:
Paused
Paused
# | |
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# | |
import base64 | |
import datetime | |
import io | |
import json | |
import os | |
import pickle | |
import socket | |
import time | |
import uuid | |
import requests | |
from enum import Enum, IntEnum | |
import importlib | |
from Cryptodome.PublicKey import RSA | |
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 | |
from filelock import FileLock | |
from . import file_utils | |
# Default configuration file name; conf_realpath() resolves it under the
# project's conf/ directory.
SERVICE_CONF = "service_conf.yaml"
def conf_realpath(conf_name):
    """Build the path of a configuration file inside the project.

    Args:
        conf_name: File name of the configuration file.

    Returns:
        The project base directory joined with ``conf/<conf_name>``.
    """
    relative_path = f"conf/{conf_name}"
    base_directory = file_utils.get_project_base_directory()
    return os.path.join(base_directory, relative_path)
def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict:
    """Read a value (or the whole mapping) from the service configuration.

    Two YAML files are consulted: ``local.<conf_name>`` (optional) and
    ``<conf_name>``, both resolved through ``conf_realpath``. A key present
    in the local file wins and is returned without loading the main file.

    Args:
        key: Top-level configuration key. Pass ``None`` to get the main
            configuration updated with the local overrides.
        default: Value returned when ``key`` is in neither file. When it is
            ``None`` and ``key`` is given, the environment variable named
            ``key.upper()`` is used as the default instead.
        conf_name: Name of the main configuration file.

    Returns:
        The configured value for ``key`` (whatever type the YAML holds),
        ``default`` when the key is absent, or the merged mapping when
        ``key`` is ``None``.

    Raises:
        ValueError: If either configuration file does not load as a dict.
    """
    # Only consult the environment for a real key: calling ``.upper()`` on
    # ``None`` would raise AttributeError when the caller asks for the
    # whole configuration.
    if default is None and key is not None:
        default = os.environ.get(key.upper())

    local_config = {}
    local_path = conf_realpath(f'local.{conf_name}')
    if os.path.exists(local_path):
        local_config = file_utils.load_yaml_conf(local_path)
        if not isinstance(local_config, dict):
            raise ValueError(f'Invalid config file: "{local_path}".')
        # Local overrides short-circuit the lookup.
        if key is not None and key in local_config:
            return local_config[key]

    config_path = conf_realpath(conf_name)
    config = file_utils.load_yaml_conf(config_path)
    if not isinstance(config, dict):
        raise ValueError(f'Invalid config file: "{config_path}".')

    config.update(local_config)
    if key is None:
        return config
    return config.get(key, default)
# Evaluated once at import time. When truthy, deserialize_b64() goes through
# restricted_loads() instead of plain pickle.loads().
use_deserialize_safe_module = get_base_config(
    "use_deserialize_safe_module", default=False)
class CoordinationCommunicationProtocol:
    """String constants naming the available communication protocols."""

    HTTP = "http"
    GRPC = "grpc"
class BaseType:
    """Base class that exports instance attributes as plain Python data."""

    def to_dict(self):
        """Return the instance attributes as a dict.

        Leading underscores are stripped from attribute names. Values are
        returned untouched (no recursion into nested objects).
        """
        return {
            attr_name.lstrip("_"): value
            for attr_name, value in self.__dict__.items()
        }

    def to_dict_with_type(self):
        """Return a recursive, type-annotated description of the instance.

        Every value is wrapped as ``{"type": ..., "data": ..., "module": ...}``.
        ``module`` is only filled in for ``BaseType`` instances; it stays
        ``None`` for lists, tuples, dicts and all other values.
        """

        def _describe(obj):
            origin_module = None
            if isinstance(obj, BaseType):
                payload = {
                    attr_name.lstrip("_"): _describe(value)
                    for attr_name, value in obj.__dict__.items()
                }
                origin_module = obj.__module__
            elif isinstance(obj, (list, tuple)):
                # Tuples are flattened into lists, like lists themselves.
                payload = [_describe(item) for item in obj]
            elif isinstance(obj, dict):
                payload = {
                    item_key: _describe(item)
                    for item_key, item in obj.items()
                }
            else:
                payload = obj
            return {
                "type": obj.__class__.__name__,
                "data": payload,
                "module": origin_module,
            }

        return _describe(self)
class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that also handles dates, enums, sets, types and BaseType.

    Accepts an extra ``with_type`` keyword (default ``False``) that selects
    ``BaseType.to_dict_with_type()`` over ``BaseType.to_dict()``.
    """

    def __init__(self, **kwargs):
        # ``with_type`` is ours; json.JSONEncoder would reject it.
        self._with_type = kwargs.pop("with_type", False)
        super().__init__(**kwargs)

    def default(self, obj):
        """Convert objects the stock encoder cannot serialize."""
        # datetime must be tested before date: datetime subclasses date.
        if isinstance(obj, datetime.datetime):
            return obj.strftime('%Y-%m-%d %H:%M:%S')
        if isinstance(obj, datetime.date):
            return obj.strftime('%Y-%m-%d')
        if isinstance(obj, datetime.timedelta):
            return str(obj)
        # IntEnum is itself an Enum subclass, so one check covers both.
        if isinstance(obj, Enum):
            return obj.value
        if isinstance(obj, set):
            return list(obj)
        if isinstance(obj, BaseType):
            if self._with_type:
                return obj.to_dict_with_type()
            return obj.to_dict()
        if isinstance(obj, type):
            return obj.__name__
        # Nothing matched: let the base class raise TypeError.
        return super().default(obj)
def rag_uuid():
    """Return a freshly generated UUID1 as a 32-character hex string."""
    generated = uuid.uuid1()
    return generated.hex
def string_to_bytes(string):
    """Return ``string`` as bytes, UTF-8 encoding it unless it is already bytes."""
    if isinstance(string, bytes):
        return string
    return string.encode(encoding="utf-8")
def bytes_to_string(byte):
    """Decode a UTF-8 byte string into ``str``."""
    decoded = byte.decode(encoding="utf-8")
    return decoded
def json_dumps(src, byte=False, indent=None, with_type=False):
    """Serialize ``src`` to JSON using ``CustomJSONEncoder``.

    Args:
        src: Object to serialize.
        byte: When true, return the result as bytes instead of ``str``.
        indent: Forwarded to ``json.dumps``.
        with_type: Forwarded to ``CustomJSONEncoder``.

    Returns:
        The JSON document as ``str``, or as bytes when ``byte`` is true.
    """
    encoded = json.dumps(
        src, indent=indent, cls=CustomJSONEncoder, with_type=with_type)
    return string_to_bytes(encoded) if byte else encoded
def json_loads(src, object_hook=None, object_pairs_hook=None):
    """Parse a JSON document given as ``str`` or bytes.

    Args:
        src: JSON text; bytes are decoded with ``bytes_to_string`` first.
        object_hook: Forwarded to ``json.loads``.
        object_pairs_hook: Forwarded to ``json.loads``.

    Returns:
        The decoded Python object.
    """
    text = bytes_to_string(src) if isinstance(src, bytes) else src
    return json.loads(
        text,
        object_hook=object_hook,
        object_pairs_hook=object_pairs_hook,
    )
def current_timestamp():
    """Return the current time as integer milliseconds since the epoch."""
    milliseconds = time.time() * 1000
    return int(milliseconds)
def timestamp_to_date(timestamp, format_string="%Y-%m-%d %H:%M:%S"):
    """Format a millisecond timestamp as a local-time string.

    Args:
        timestamp: Milliseconds since the epoch (anything ``int()`` accepts).
            A falsy value (``None``, ``0``, ``""``) means "now".
        format_string: ``time.strftime`` format.

    Returns:
        The formatted local time.
    """
    if timestamp:
        seconds = int(timestamp) / 1000
    else:
        # time.time() is already in seconds; it must not go through the
        # milliseconds-to-seconds division, or "now" turns into January 1970.
        seconds = time.time()
    return time.strftime(format_string, time.localtime(seconds))
def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
    """Parse a local-time string into integer milliseconds since the epoch.

    Args:
        time_str: Date/time text matching ``format_string``.
        format_string: ``time.strptime`` format.

    Returns:
        Milliseconds since the epoch, interpreting ``time_str`` as local time.
    """
    parsed = time.strptime(time_str, format_string)
    epoch_seconds = time.mktime(parsed)
    return int(epoch_seconds * 1000)
def serialize_b64(src, to_str=False):
    """Pickle ``src`` and base64-encode the result.

    Args:
        src: Any picklable object.
        to_str: When true, return ``str`` instead of bytes.

    Returns:
        The base64 text of the pickle, as bytes or ``str``.
    """
    encoded = base64.b64encode(pickle.dumps(src))
    return bytes_to_string(encoded) if to_str else encoded
def deserialize_b64(src):
    """Decode base64 text and unpickle the payload.

    SECURITY: unless ``use_deserialize_safe_module`` is truthy, the payload
    goes straight to ``pickle.loads``, which can execute arbitrary code.
    Only call this on data from a trusted source.

    Args:
        src: Base64 text as ``str`` or a bytes-like object.

    Returns:
        The unpickled object.
    """
    raw = string_to_bytes(src) if isinstance(src, str) else src
    payload = base64.b64decode(raw)
    loader = restricted_loads if use_deserialize_safe_module else pickle.loads
    return loader(payload)
# Top-level package names that RestrictedUnpickler.find_class may resolve.
safe_module = {"numpy", "rag_flow"}
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves globals from packages in ``safe_module``."""

    def find_class(self, module, name):
        """Resolve ``module.name`` if its top-level package is allowed.

        Raises:
            pickle.UnpicklingError: If the top-level package of ``module``
                is not listed in ``safe_module``.
        """
        top_level_package = module.split('.')[0]
        if top_level_package not in safe_module:
            # Forbid everything else.
            raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
                                         (module, name))
        resolved_module = importlib.import_module(module)
        return getattr(resolved_module, name)
def restricted_loads(src):
    """Unpickle ``src`` through ``RestrictedUnpickler``, like ``pickle.loads()``."""
    stream = io.BytesIO(src)
    unpickler = RestrictedUnpickler(stream)
    return unpickler.load()
def get_lan_ip():
    """Return an IPv4 address for this host, preferring a non-loopback one.

    The host's fully qualified name is resolved first. If that yields a
    ``127.*`` address on a non-Windows system, a fixed list of interface
    names is probed and the first one that answers is used.

    Returns:
        The address as a dotted string, or ``''`` if none was found.
    """
    on_posix = os.name != "nt"
    if on_posix:
        # fcntl does not exist on Windows, so import it only where usable.
        import fcntl
        import struct

        # Linux ioctl request SIOCGIFADDR: fetch an interface's address.
        siocgifaddr = 0x8915

        def get_interface_ip(ifname):
            request = struct.pack('256s', string_to_bytes(ifname[:15]))
            # Close the probing socket deterministically; one descriptor
            # would otherwise leak for every interface tried.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                reply = fcntl.ioctl(s.fileno(), siocgifaddr, request)
            return socket.inet_ntoa(reply[20:24])

    ip = socket.gethostbyname(socket.getfqdn())
    if ip.startswith("127.") and on_posix:
        interfaces = [
            "bond1",
            "eth0",
            "eth1",
            "eth2",
            "wlan0",
            "wlan1",
            "wifi0",
            "ath0",
            "ath1",
            "ppp0",
        ]
        for ifname in interfaces:
            try:
                ip = get_interface_ip(ifname)
                break
            except IOError:
                # Best effort: the ioctl failed for this interface name,
                # so try the next candidate.
                continue
    return ip or ''
def from_dict_hook(in_dict: dict):
    """Rebuild a value from the ``{"type", "data", "module"}`` wrapper.

    Written as a ``json.loads`` ``object_hook`` for documents produced by
    ``BaseType.to_dict_with_type``. A dict that is not such a wrapper is
    returned unchanged.

    NOTE(review): ``module`` and ``type`` are imported and called as given,
    so this must not be used on untrusted JSON.

    Args:
        in_dict: A decoded JSON object.

    Returns:
        ``in_dict["data"]`` when ``module`` is ``None``; otherwise an
        instance of ``module.type`` built with ``data`` as keyword
        arguments; or ``in_dict`` itself when it is not a wrapper.
    """
    # All three keys are always written together, so require all three.
    # An ordinary payload that merely has "type" and "data" fields must
    # pass through instead of raising KeyError on the "module" lookup.
    if not all(field in in_dict for field in ("type", "data", "module")):
        return in_dict
    if in_dict["module"] is None:
        return in_dict["data"]
    target_class = getattr(
        importlib.import_module(in_dict["module"]), in_dict["type"])
    return target_class(**in_dict["data"])
def decrypt_database_password(password):
    """Decrypt a database password when encryption is enabled in the config.

    Reads ``encrypt_password``, ``encrypt_module`` and ``private_key`` via
    ``get_base_config``. ``encrypt_module`` has the form
    ``"<module path>#<function name>"``; that function is called as
    ``function(private_key, password)``.

    Args:
        password: The stored password.

    Returns:
        ``password`` unchanged when it is empty or ``encrypt_password`` is
        falsy; otherwise the result of the configured decrypt function.

    Raises:
        ValueError: If encryption is enabled but no private key is configured.
    """
    encrypt_password = get_base_config("encrypt_password", False)
    encrypt_module = get_base_config("encrypt_module", False)
    private_key = get_base_config("private_key", None)

    if not (password and encrypt_password):
        return password
    if not private_key:
        raise ValueError("No private key")

    module_and_function = encrypt_module.split("#")
    decrypt_module = importlib.import_module(module_and_function[0])
    decrypt_function = getattr(decrypt_module, module_and_function[1])
    return decrypt_function(private_key, password)
def decrypt_database_config(
        database=None, passwd_key="password", name="database"):
    """Return a database config whose password entry has been decrypted.

    Args:
        database: Config mapping to use. When falsy, the ``name`` section is
            read with ``get_base_config``.
        passwd_key: Key of the password entry in the mapping.
        name: Configuration section to load when ``database`` is falsy.

    Returns:
        The same mapping object, with ``passwd_key`` replaced in place by
        the result of ``decrypt_database_password``.
    """
    config = database if database else get_base_config(name, {})
    stored_password = config[passwd_key]
    config[passwd_key] = decrypt_database_password(stored_password)
    return config
def update_config(key, value, conf_name=SERVICE_CONF):
    """Set ``key`` to ``value`` in a YAML configuration file.

    The read-modify-write runs while holding a ``FileLock`` on a ``.lock``
    file placed next to the configuration file.

    Args:
        key: Top-level key to set.
        value: New value for the key.
        conf_name: Name of the configuration file to rewrite.
    """
    conf_path = conf_realpath(conf_name=conf_name)
    if not os.path.isabs(conf_path):
        base_directory = file_utils.get_project_base_directory()
        conf_path = os.path.join(base_directory, conf_path)

    lock_path = os.path.join(os.path.dirname(conf_path), ".lock")
    with FileLock(lock_path):
        # A missing or empty file loads as a falsy value; start from {}.
        config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
        config[key] = value
        file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
def get_uuid():
    """Return a freshly generated UUID1 as a 32-character hex string."""
    new_id = uuid.uuid1()
    return new_id.hex
def datetime_format(date_time: datetime.datetime) -> datetime.datetime:
    """Return a new naive datetime truncated to whole seconds.

    Only year through second are copied, so microseconds and any tzinfo on
    the input are dropped.
    """
    kept_fields = ("year", "month", "day", "hour", "minute", "second")
    return datetime.datetime(
        *(getattr(date_time, field) for field in kept_fields))
def get_format_time() -> datetime.datetime:
    """Return the current local time passed through ``datetime_format``."""
    now = datetime.datetime.now()
    return datetime_format(now)
def str2date(date_time: str):
    """Parse a ``YYYY-MM-DD`` string into a ``datetime.datetime`` at midnight."""
    date_format = '%Y-%m-%d'
    return datetime.datetime.strptime(date_time, date_format)
def elapsed2time(elapsed):
    """Format a duration given in milliseconds as ``HH:MM:SS``.

    Fractions of a second are truncated by the ``%d`` formatting.
    """
    total_seconds = elapsed / 1000
    total_minutes, secs = divmod(total_seconds, 60)
    hours, mins = divmod(total_minutes, 60)
    return '%02d:%02d:%02d' % (hours, mins, secs)
def decrypt(line):
    """Decrypt a base64-encoded, RSA PKCS#1 v1.5 encrypted string.

    The private key is read from ``conf/private.pem`` under the project
    base directory.

    NOTE(review): the key passphrase is hard-coded below; consider moving
    it to configuration.

    Args:
        line: Base64 text of the ciphertext.

    Returns:
        The decrypted text, or ``"Fail to decrypt password!"`` when the
        cipher reports a decryption failure.
    """
    file_path = os.path.join(
        file_utils.get_project_base_directory(),
        "conf",
        "private.pem")
    # Read the key inside a context manager so the file handle is closed.
    with open(file_path) as key_file:
        rsa_key = RSA.importKey(key_file.read(), "Welcome")
    cipher = Cipher_pkcs1_v1_5.new(rsa_key)
    # decrypt() hands the sentinel back on failure. It must be bytes, or
    # the .decode() below raises AttributeError instead of returning the
    # failure message.
    plaintext = cipher.decrypt(
        base64.b64decode(line), b"Fail to decrypt password!")
    return plaintext.decode('utf-8')
def download_img(url):
    """Fetch ``url`` and return its body as a base64 ``data:`` URI.

    Args:
        url: Address to download; a falsy value short-circuits.

    Returns:
        ``""`` when ``url`` is falsy, otherwise
        ``data:<content type>;base64,<payload>``. The content type comes
        from the response's ``Content-Type`` header and falls back to
        ``image/jpg``.
    """
    if not url:
        return ""
    response = requests.get(url)
    content_type = response.headers.get('Content-Type', 'image/jpg')
    payload = base64.b64encode(response.content).decode("utf-8")
    return f"data:{content_type};base64,{payload}"