Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- coding: utf-8 -*- 

2import numpy as np 

3import theano 

4from theano import gof 

5import theano.tensor as tt 

6 

7__all__ = ["pTOp"] 

8 

9 

class pTOp(tt.Op):
    """Theano op evaluating the polynomial basis matrix via ``func``.

    Wraps a callable ``func(x, y, z)`` that returns a 2-D array with one
    row per input point and ``(deg + 1) ** 2`` columns (one per spherical
    harmonic term up to degree ``deg``). Gradients are delegated to a
    companion ``pTGradientOp``.
    """

    def __init__(self, func, deg):
        # Callable evaluated in perform(); its output defines this op's value.
        self.func = func
        # Maximum spherical harmonic degree.
        self.deg = deg
        # Number of terms (columns) for degree `deg`.
        self.N = (deg + 1) ** 2
        # Companion op used by grad().
        self._grad_op = pTGradientOp(self)

    def make_node(self, *inputs):
        # Promote all arguments to symbolic tensors; the output is a
        # matrix (2-D) sharing the dtype of the first input.
        tensor_args = [tt.as_tensor_variable(arg) for arg in inputs]
        out_type = tt.TensorType(tensor_args[0].dtype, (False, False))
        return gof.Apply(self, tensor_args, [out_type()])

    def infer_shape(self, node, shapes):
        # One row per input point, one column per basis term.
        n_rows = shapes[0][0]
        return [[n_rows, self.N]]

    def perform(self, node, inputs, outputs):
        # Numeric evaluation: just call through to the wrapped function.
        result = self.func(*inputs)
        outputs[0][0] = result

    def grad(self, inputs, gradients):
        # Hand both the original inputs and the output gradients to the
        # dedicated gradient op.
        args = list(inputs) + list(gradients)
        return self._grad_op(*args)

30 

31 

class pTGradientOp(tt.Op):
    """Gradient op for :class:`pTOp`.

    Given inputs ``(x, y, z)`` and the gradient ``bpT`` of the loss with
    respect to the basis matrix, computes the gradients with respect to
    ``x``, ``y``, and ``z``. It exploits the fact that each basis term is
    (up to the parity factor handled by ``zf``) proportional to
    ``x**xf * y**yf * z**zf``, so d(term)/dx = xf * term / x, etc.
    """

    def __init__(self, base_op):
        # The forward op; we reuse its `func`, `deg`, and `N`.
        self.base_op = base_op

        # Pre-compute the per-term exponent factors for x, y, and z.
        # Terms are enumerated in the standard (l, m) order.
        n = 0
        self.xf = np.zeros(self.base_op.N, dtype=int)
        self.yf = np.zeros(self.base_op.N, dtype=int)
        self.zf = np.zeros(self.base_op.N, dtype=int)
        for l in range(self.base_op.deg + 1):
            for m in range(-l, l + 1):
                mu = l - m
                nu = l + m
                if nu % 2 == 0:
                    # Even nu: term ~ x**(mu/2) * y**(nu/2)
                    if mu > 0:
                        self.xf[n] = mu // 2
                    if nu > 0:
                        self.yf[n] = nu // 2
                else:
                    # Odd nu: term carries one power of z
                    if mu > 1:
                        self.xf[n] = (mu - 1) // 2
                    if nu > 1:
                        self.yf[n] = (nu - 1) // 2
                    self.zf[n] = 1
                n += 1

    def make_node(self, *inputs):
        # Inputs are (x, y, z, bpT); outputs mirror the types of (x, y, z).
        inputs = [tt.as_tensor_variable(i) for i in inputs]
        outputs = [i.type() for i in inputs[:-1]]
        return gof.Apply(self, inputs, outputs)

    def infer_shape(self, node, shapes):
        # Each gradient has the shape of its corresponding coordinate input.
        return shapes[:-1]

    def perform(self, node, inputs, outputs):
        # NOTE(review): the indexing below (x[:, None]) assumes x, y, z
        # are 1-D arrays — confirm against callers.
        x, y, z, bpT = inputs

        # TODO: When any of the coords are zero, there's a div
        # by zero below. This hack fixes the issue. We should
        # think of a better way of doing this!
        #
        # BUGFIX: the original clamped the *input* buffers in place
        # (x[np.abs(x) < tol] = tol). Theano forbids an Op from mutating
        # its inputs in perform() unless it declares a destroy_map; doing
        # so can corrupt values shared with other graph nodes. Clamp on
        # fresh arrays instead (same semantics: any |v| < tol becomes +tol).
        tol = 1e-8
        x = np.where(np.abs(x) < tol, tol, x)
        y = np.where(np.abs(y) < tol, tol, y)
        z = np.where(np.abs(z) < tol, tol, z)

        # Chain rule: d(term)/dcoord = factor * term / coord, summed over
        # all basis terms weighted by the incoming gradient bpT.
        bpTpT = bpT * self.base_op.func(x, y, z)
        bx = np.nansum(self.xf[None, :] * bpTpT / x[:, None], axis=-1)
        by = np.nansum(self.yf[None, :] * bpTpT / y[:, None], axis=-1)
        bz = np.nansum(self.zf[None, :] * bpTpT / z[:, None], axis=-1)
        outputs[0][0] = np.reshape(bx, np.shape(inputs[0]))
        outputs[1][0] = np.reshape(by, np.shape(inputs[1]))
        outputs[2][0] = np.reshape(bz, np.shape(inputs[2]))