Spaces:
Running
Running
File size: 3,409 Bytes
b865169 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import numpy as np
class Oracle:
oracle_d = {'exp': np.exp, 'sqrt': np.sqrt, 'pi': np.pi, 'cos': np.cos, 'sin': np.sin, 'tan': np.tan,
'tanh': np.tanh, 'ln': np.log, 'arcsin': np.arcsin}
def __init__(self, nvariables, f=None, form=None, variable_names=None, range_restriction={}, id=None):
"""
nvariables: is the number of variables the function takes in
f: takes in an X of shape (n, nvariables) and returns f(X) of shape (n,)
form: String Def of the function
variable_names: variable names used in form
Range_restrictions: Dictionary of form {variable_index: (low, high)}
"""
self.nvariables = nvariables
if f is None and form is None:
raise ValueError("f and form are both none in Oracle initialization. Specify at least one")
if f is not None and form is not None:
raise ValueError("f and form are both not none, pick only one")
if form is not None and variable_names is None:
raise ValueError("If form is provided then variable_names must also be provided")
if form is not None:
self.form = form
self.variable_names = variable_names
self.use_func = False
self.d = Oracle.oracle_d.copy()
for var_name in variable_names:
self.d[var_name] = None
else:
# f is not None
self.func = f
self.use_func = True
self.ranges = []
for i in range(nvariables):
if i in range_restriction:
self.ranges.append(range_restriction[i])
else:
self.ranges.append(None)
if id is not None:
self.id = id
return
def f(self, X):
"""
X is of shape (n, nvariables)
"""
if self.invalid_input(X):
raise ValueError("Invalid input to Oracle")
if self.use_func:
return self.func(X)
else:
return self.form_f(X)
def form_f(self, X):
"""
Returns the function output using form
"""
for i, var in enumerate(self.variable_names):
self.d[var] = X[:, i]
return eval(self.form, self.d)
def invalid_input(self, X):
"""
Returns true if any of the following are true
X has more or less variables than nvariables
X has a value in a restricted range variable outside said range
"""
if X.shape[1] != self.nvariables:
return True
for i, r in enumerate(self.ranges):
if r is None:
continue
else:
low = r[0]
high = r[1]
low_check = all(low <= X[:, i])
high_check = all(X[:, i] <= high)
if not low_check or not high_check:
return True
def __str__(self):
if self.id:
return str(self.id)
elif self.form:
return str(self.form)
else:
return "<Un named Oracle>"
def from_problem(problem):
"""
Static function to return an oracle when given an instance of class problem.
"""
return Oracle(nvariables=problem.n_vars, f=None, form=problem.form, variable_names=problem.var_names,
id=problem.eq_id) |