Spaces:
Runtime error
Runtime error
import logging
import os
from typing import Any, Dict, List, Union

import numpy as np

# Module-wide logger named after this module. Its own threshold is DEBUG, so
# this logger does not filter out any record at DEBUG level or above.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)
class RandomSearch:
    """Factories for zero-argument samplers that describe a random search space.

    Each factory returns a ``lambda`` that draws a fresh value from numpy's
    global random state (``np.random``) every time it is called. The
    factories are static methods, so they work both through the class
    (``RandomSearch.random_uniform(0, 1)``) and through an instance.
    """

    @staticmethod
    def random_choice(args: List[Any], n: int = 1):
        """
        Return a sampler that picks random element(s) from ``args``.

        Args:
            args: candidate values. They are copied into a new list when the
                sampler is created, so later changes to the caller's sequence
                do not affect the sampler.
            n: how many distinct elements to draw per call. With ``n == 1``
                the sampler returns a single element; otherwise it returns a
                numpy array of ``n`` elements drawn without replacement
                (numpy raises ValueError at call time if ``n > len(args)``).

        Example:
            >> sampler = RandomSearch.random_choice([1, 2, 3])
            >> sampler()
            2
        """
        choices = list(args)
        if n == 1:
            return lambda: np.random.choice(choices, replace=False)
        return lambda: np.random.choice(choices, n, replace=False)

    @staticmethod
    def random_integer(low: Union[int, float], high: Union[int, float]):
        """
        Return a sampler that draws a Python ``int`` from ``[low, high)``.

        ``high`` is exclusive, following ``np.random.randint``.

        Example:
            >> sampler = RandomSearch.random_integer(1, 10)
            >> sampler()
            9
        """
        return lambda: int(np.random.randint(low, high))

    @staticmethod
    def random_loguniform(low: Union[float, int], high: Union[float, int]):
        """
        Return a sampler drawing floats whose natural log is uniform on
        ``[log(low), log(high))``.

        Both bounds must be positive, since their logarithm is taken.

        Example:
            >> sampler = RandomSearch.random_loguniform(1e-5, 1e-2)
            >> sampler()
            0.0004
        """
        return lambda: np.exp(np.random.uniform(np.log(low), np.log(high)))

    @staticmethod
    def random_uniform(low: Union[float, int], high: Union[float, int]):
        """
        Return a sampler drawing floats uniformly from ``[low, high)``.

        Example:
            >> sampler = RandomSearch.random_uniform(0, 1)
            >> sampler()
            0.01
        """
        return lambda: np.random.uniform(low, high)
class HyperparameterSearch:
    """Hold a search space and draw concrete hyperparameter samples from it.

    Every keyword argument given to the constructor becomes one entry of the
    search space. An entry is either a constant (int, float, str, list, numpy
    array, or None) or a zero-argument callable, such as the samplers built by
    ``RandomSearch``; callables are invoked once per ``sample()``.
    """

    # Values of these types are treated as already concrete and never called.
    _CONSTANT_TYPES = (int, np.integer, float, np.floating, np.ndarray, list, str)

    def __init__(self, **kwargs):
        self.search_space = dict(kwargs)
        # Kept for backward compatibility; not read anywhere in this class.
        self.lambda_ = lambda: 0

    def parse(self, val: Any):
        """
        Turn one search-space entry into a concrete value.

        Constants and None are used as-is; anything else is called with no
        arguments first. The resulting value is then normalized:

        * Python/numpy integers become a Python ``int``;
        * lists and numpy arrays become one space-separated string, with
          every element converted via ``str``;
        * everything else (floats, strings, None, ...) is returned unchanged.

        Raises:
            TypeError: if ``val`` is neither a recognized constant nor
                callable (or the callable itself raises it).
        """
        # np.integer / np.floating replace the np.int / np.float aliases,
        # which were removed in NumPy 1.24 and raised AttributeError here.
        if val is not None and not isinstance(val, self._CONSTANT_TYPES):
            val = val()
        if isinstance(val, (int, np.integer)):
            return int(val)
        if isinstance(val, (np.ndarray, list)):
            # str() each element so numeric multi-choice samples join too.
            return " ".join(str(item) for item in val)
        return val

    def sample(self) -> Dict:
        """
        Draw one concrete value for every key of the search space.

        Entries whose parsing raises TypeError or ValueError are logged and
        omitted from the result instead of aborting the whole sample.
        """
        res = {}
        for key, val in self.search_space.items():
            try:
                res[key] = self.parse(val)
            except (TypeError, ValueError) as error:
                logger.error("Could not parse key %s with value %s. %s", key, val, error)
        return res

    def update_environment(self, sample) -> None:
        """Export every entry of ``sample`` into ``os.environ`` as ``str(val)``.

        Note that a None value is exported as the string ``"None"``.
        """
        for key, val in sample.items():
            os.environ[key] = str(val)
# Default search space. Constant entries are passed through as-is; sampler
# entries are redrawn on every HyperparameterSearch.sample(). The keys look
# like scikit-learn LogisticRegression arguments plus text-featurization
# options — presumably; confirm against the consumer of this mapping.
SEARCH_SPACE = dict(
    penalty=RandomSearch.random_choice(["l1", "l2"]),
    C=RandomSearch.random_uniform(0, 1),
    solver="liblinear",
    multi_class="auto",
    # NOTE(review): 10e-5 == 1e-4 and 10e-3 == 1e-2; confirm these bounds
    # are the intended ones (1e-5 / 1e-3 may have been meant).
    tol=RandomSearch.random_loguniform(10e-5, 10e-3),
    stopwords=RandomSearch.random_choice([0, 1]),
    weight=RandomSearch.random_choice(["hash"]),
    ngram_range=RandomSearch.random_choice(["1 2", "2 3", "1 3"]),
    random_state=RandomSearch.random_integer(0, 100000),
)