Coverage for starry/_core/ops/exceptions.py : 52%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- coding: utf-8 -*-
2import numpy as np
3from theano import gof
4import theano.tensor as tt
# Public API of this module: a pass-through bounds-checking op, and two
# helpers for raising ValueError from inside a Theano graph at evaluation time.
__all__ = ["CheckBoundsOp", "RaiseValueErrorOp", "RaiseValueErrorIfOp"]
class CheckBoundsOp(tt.Op):
    """Pass a tensor through unchanged, raising if any element is out of bounds.

    At evaluation time the input array is returned as the output. If any
    element is strictly below ``lower`` or strictly above ``upper``, a
    ``ValueError`` describing the first offending element is raised. Elements
    below ``lower`` are reported in preference to elements above ``upper``.

    Args:
        lower: Lower bound (default ``-inf``). Values equal to it pass.
            Expected to be a scalar, since it is formatted with ``%f``.
        upper: Upper bound (default ``+inf``). Values equal to it pass.
            Expected to be a scalar, since it is formatted with ``%f``.
        name: Label used in the error message (default ``"parameter"``).
    """

    # ``perform`` stores the input array itself as the output, so declare the
    # output a view of input 0. Otherwise Theano may assume they are distinct
    # buffers and allow a downstream in-place op to clobber the input.
    view_map = {0: [0]}

    def __init__(self, lower=-np.inf, upper=np.inf, name=None):
        self.lower = lower
        self.upper = upper
        self.name = "parameter" if name is None else name

    def make_node(self, *inputs):
        """Build the Apply node; only the first input is used."""
        inputs = [tt.as_tensor_variable(inputs[0])]
        outputs = [inputs[0].type()]
        return gof.Apply(self, inputs, outputs)

    def infer_shape(self, node, shapes):
        # The output has exactly the shape of the input.
        return [shapes[0]]

    def perform(self, node, inputs, outputs):
        """Return the input unchanged, or raise ValueError if out of bounds."""
        x = inputs[0]
        outputs[0][0] = x
        # Work with flat indices so that 0-d (scalar) and N-d inputs both
        # yield a single offending element. Indexing the raw input with
        # ``np.where(...)[0]`` fails for 0-d arrays and returns whole rows for
        # N-d arrays.
        low = np.flatnonzero(x < self.lower)
        high = np.flatnonzero(x > self.upper)
        if low.size:
            value, sign, bound = np.ravel(x)[low[0]], "<", self.lower
        elif high.size:
            value, sign, bound = np.ravel(x)[high[0]], ">", self.upper
        else:
            return
        raise ValueError(
            "%s out of bounds: %f %s %f" % (self.name, value, sign, bound)
        )

    def grad(self, inputs, gradients):
        # The op is the identity on its input, so the output gradient passes
        # straight through.
        return [gradients[0]]
class RaiseValueErrorIfOp(tt.Op):
    """Raise ``ValueError(message)`` at evaluation time if a condition holds.

    The op takes a single condition tensor and produces a scalar zero whose
    dtype is ``floatX``. Multiplying that zero into another expression
    attaches the check to the graph without changing the expression's value.

    Args:
        message: Message passed to the raised ``ValueError``.
    """

    def __init__(self, message=None):
        self.message = message

    def make_node(self, *inputs):
        """Build the Apply node from the condition (the first input)."""
        # Convert the condition itself. Converting the whole ``inputs`` tuple
        # would turn a scalar condition into a length-1 vector and a vector
        # condition into a matrix.
        condition = tt.as_tensor_variable(inputs[0])
        outputs = [tt.TensorType(tt.config.floatX, ())()]
        return gof.Apply(self, [condition], outputs)

    def infer_shape(self, node, shapes):
        # The output is always a scalar.
        return [()]

    def perform(self, node, inputs, outputs):
        """Emit a scalar zero, raising ValueError if the condition is true."""
        # Use the declared output dtype (floatX). A bare ``np.array(0.0)`` is
        # float64 and mismatches the output type when floatX is float32.
        outputs[0][0] = np.zeros((), dtype=node.outputs[0].dtype)
        # ``np.any`` lets an array-valued condition trigger when any element
        # is true, instead of failing on an ambiguous truth value.
        if np.any(inputs[0]):
            raise ValueError(self.message)

    def grad(self, inputs, gradients):
        # The output is a constant zero, so the condition gets a zero
        # gradient of matching shape.
        return [inputs[0] * 0.0]
def RaiseValueErrorOp(msg, shape=()):
    """Return a zeros tensor of ``shape`` whose evaluation raises ``ValueError(msg)``.

    The condition is the constant ``True``, so evaluating the returned
    expression always triggers the error.
    """
    zeros = tt.zeros(shape)
    always_raise = RaiseValueErrorIfOp(msg)
    return zeros * always_raise(True)