Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- coding: utf-8 -*- 

2import numpy as np 

3from theano import gof 

4import theano.tensor as tt 

5 

# Names exported by ``from <module> import *``.
__all__ = [
    "CheckBoundsOp",
    "RaiseValueErrorOp",
    "RaiseValueErrorIfOp",
]

7 

8 

class CheckBoundsOp(tt.Op):
    """Identity Op that raises ``ValueError`` if any input element is out of bounds.

    The output is the input tensor itself; the Op exists only for the
    runtime check performed in :meth:`perform`.

    Args:
        lower: Lower bound; elements strictly below it are rejected
            (default ``-inf``).
        upper: Upper bound; elements strictly above it are rejected
            (default ``+inf``).
        name: Label used in the error message (default ``"parameter"``).
    """

    def __init__(self, lower=-np.inf, upper=np.inf, name=None):
        self.lower = lower
        self.upper = upper
        self.name = "parameter" if name is None else name

    def make_node(self, *inputs):
        # Only the first input is checked and passed through.
        inputs = [tt.as_tensor_variable(inputs[0])]
        outputs = [inputs[0].type()]
        return gof.Apply(self, inputs, outputs)

    def infer_shape(self, node, shapes):
        # Identity op: the output has the same shape as the input.
        return [shapes[0]]

    def perform(self, node, inputs, outputs):
        """Forward the input unchanged, raising on the first bound violation.

        Lower-bound violations are reported in preference to upper-bound
        ones. NaN compares false against both bounds and is not flagged.

        Raises:
            ValueError: if any element is ``< lower`` or ``> upper``.
        """
        value = inputs[0]
        outputs[0][0] = value

        # Locate violations in flattened (C) order so that 0-d and
        # multi-dimensional inputs both yield a single offending scalar.
        # NOTE(review): assumes scalar bounds (the "%f" message requires it).
        below = np.flatnonzero(value < self.lower)
        above = np.flatnonzero(value > self.upper)
        if below.size:
            bad, sign, bound = np.ravel(value)[below[0]], "<", self.lower
        elif above.size:
            bad, sign, bound = np.ravel(value)[above[0]], ">", self.upper
        else:
            return
        raise ValueError(
            "%s out of bounds: %f %s %f" % (self.name, bad, sign, bound)
        )

46 

47 

class RaiseValueErrorIfOp(tt.Op):
    """Op that raises ``ValueError(message)`` at run time when its condition holds.

    The Op takes one condition tensor and outputs a zero scalar of dtype
    ``floatX``. The output can be combined into another expression so that
    the check runs whenever that expression is evaluated.

    Args:
        message: Message passed to the raised ``ValueError``.
    """

    def __init__(self, message=None):
        self.message = message

    def make_node(self, *inputs):
        # Use the single condition argument. Wrapping the whole ``inputs``
        # tuple turned a scalar condition into a length-1 vector and
        # stacked vector conditions into a matrix.
        condition = inputs[0]
        inputs = [tt.as_tensor_variable(condition)]
        outputs = [tt.TensorType(tt.config.floatX, ())()]
        return gof.Apply(self, inputs, outputs)

    def infer_shape(self, node, shapes):
        # The output is always a scalar.
        return [()]

    def perform(self, node, inputs, outputs):
        """Emit a zero scalar; raise ``ValueError`` if any condition element is true."""
        # Match the declared output dtype (floatX) instead of always float64.
        outputs[0][0] = np.array(0.0, dtype=node.outputs[0].type.dtype)
        # np.any is well-defined for scalar and non-scalar conditions alike.
        if np.any(inputs[0]):
            raise ValueError(self.message)

    def grad(self, inputs, gradients):
        # The output is constant zero, so its gradient w.r.t. the condition is zero.
        return [inputs[0] * 0.0]

72 

73 

def RaiseValueErrorOp(msg, shape=()):
    """Build a zeros tensor of ``shape`` that raises ``ValueError(msg)`` when computed.

    The zeros are multiplied by the output of a ``RaiseValueErrorIfOp``
    whose condition is always true, so evaluating the result always raises.
    """
    always_fail = RaiseValueErrorIfOp(msg)(True)
    placeholder = tt.zeros(shape)
    return placeholder * always_fail