File size: 3,409 Bytes
b865169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import numpy as np


class Oracle:
    """Wrap a function of ``nvariables`` inputs, given either as a Python
    callable or as a string expression, and evaluate it on batches of inputs.

    Optional per-variable range restrictions cause out-of-range inputs to be
    rejected with a ValueError.
    """

    # Names available inside a string ``form`` in addition to the variable names.
    oracle_d = {'exp': np.exp, 'sqrt': np.sqrt, 'pi': np.pi, 'cos': np.cos, 'sin': np.sin, 'tan': np.tan,
                'tanh': np.tanh, 'ln': np.log, 'arcsin': np.arcsin}

    def __init__(self, nvariables, f=None, form=None, variable_names=None, range_restriction=None, id=None):
        """
        nvariables: the number of variables (columns) the function takes in.
        f: callable taking an X of shape (n, nvariables) and returning f(X) of shape (n,).
        form: string definition of the function, evaluated with ``eval``.
        variable_names: names used in ``form``; the i-th name is bound to column i of X.
        range_restriction: optional dict of the form {variable_index: (low, high)};
            inputs whose column value falls outside [low, high] are rejected.
        id: optional identifier, used by ``__str__``.

        Raises ValueError if neither or both of ``f``/``form`` are given, or if
        ``form`` is given without ``variable_names``.
        """
        if f is None and form is None:
            raise ValueError("f and form are both none in Oracle initialization. Specify at least one")
        if f is not None and form is not None:
            raise ValueError("f and form are both not none, pick only one")
        if form is not None and variable_names is None:
            raise ValueError("If form is provided then variable_names must also be provided")

        self.nvariables = nvariables
        # Always define these attributes (possibly as None) so that __str__ and
        # other readers never hit AttributeError.
        self.id = id
        self.form = form
        self.variable_names = variable_names
        self.func = f
        self.use_func = f is not None

        if form is not None:
            # Per-instance evaluation namespace: math helpers plus one slot per variable.
            self.d = Oracle.oracle_d.copy()
            for var_name in variable_names:
                self.d[var_name] = None

        # Avoid a shared mutable default; None means "no restrictions".
        if range_restriction is None:
            range_restriction = {}
        self.ranges = [range_restriction.get(i) for i in range(nvariables)]

    def f(self, X):
        """
        Evaluate the oracle on X, of shape (n, nvariables).

        Raises ValueError if X has the wrong shape or violates a range restriction.
        """
        if self.invalid_input(X):
            raise ValueError("Invalid input to Oracle")
        if self.use_func:
            return self.func(X)
        return self.form_f(X)

    def form_f(self, X):
        """
        Return the function output computed by evaluating the string ``form``.

        Column i of X is bound to ``variable_names[i]`` before evaluation.
        """
        for i, var in enumerate(self.variable_names):
            self.d[var] = X[:, i]
        # SECURITY NOTE(review): ``form`` is passed to eval(). Only construct
        # Oracles from trusted expressions; untrusted input can execute code.
        return eval(self.form, self.d)

    def invalid_input(self, X):
        """
        Return True if any of the following hold, otherwise False:
            X is not 2-D, or has more or fewer columns than nvariables
            X has a value outside the allowed range of a restricted variable
        """
        if X.ndim != 2 or X.shape[1] != self.nvariables:
            return True
        for i, bounds in enumerate(self.ranges):
            if bounds is None:
                continue
            low, high = bounds[0], bounds[1]
            column = X[:, i]
            # Phrased as "all within bounds" (rather than "any outside") so that
            # NaN values are treated as invalid, matching comparison semantics.
            if not (np.all(low <= column) and np.all(column <= high)):
                return True
        return False

    def __str__(self):
        if self.id is not None:
            return str(self.id)
        if self.form is not None:
            return str(self.form)
        return "<Un named Oracle>"

    @staticmethod
    def from_problem(problem):
        """
        Return an Oracle built from ``problem``.

        Reads ``problem.n_vars``, ``problem.form``, ``problem.var_names`` and
        ``problem.eq_id``.
        """
        return Oracle(nvariables=problem.n_vars, f=None, form=problem.form, variable_names=problem.var_names,
                      id=problem.eq_id)