Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- coding: utf-8 -*- 

2import numpy as np 

3import theano 

4from theano import gof 

5import theano.tensor as tt 

6 

# Public API of this module: only the forward Op is exported. The gradient Op
# (spotYlmGradientOp) is instantiated internally by spotYlmOp.__init__.
__all__ = ["spotYlmOp"]

8 

9 

class spotYlmOp(tt.Op):
    """Theano Op wrapping a numerical callable that produces Ylm coefficients.

    The wrapped ``func`` is used in two ways:

    * ``func(*inputs)`` in :meth:`perform`, returning the forward value;
    * ``func(*inputs, output_gradient)`` from :class:`spotYlmGradientOp`,
      returning one gradient per forward input.

    Parameters
    ----------
    func : callable
        Numerical kernel, called as described above.
    ydeg : int
        Degree used to size the output: the output has
        ``Ny = (ydeg + 1) ** 2`` rows (presumably the number of Ylm
        coefficients up to degree ``ydeg``).
    nw : int or None
        If ``None`` the output is a 1-d vector of length ``Ny``; otherwise it
        is an ``(Ny, nw)`` matrix.
    """

    def __init__(self, func, ydeg, nw):
        self.func = func
        self._grad_op = spotYlmGradientOp(self)
        self.Ny = (ydeg + 1) ** 2
        self.nw = nw

    def make_node(self, *inputs):
        """Build the Apply node; the output dtype follows the first input."""
        inputs = [tt.as_tensor_variable(i) for i in inputs]
        # 1-d output when nw is None, otherwise a 2-d (Ny, nw) output.
        broadcastable = (False,) if self.nw is None else (False, False)
        outputs = [tt.TensorType(inputs[0].dtype, broadcastable)()]
        return gof.Apply(self, inputs, outputs)

    def infer_shape(self, node, shapes):
        """Return the static output shape without running ``func``."""
        if self.nw is None:
            return [(self.Ny,)]
        return [(self.Ny, self.nw)]

    def perform(self, node, inputs, outputs):
        """Evaluate ``func`` numerically and store the result.

        The result is cast to the dtype declared in :meth:`make_node` so the
        stored value always matches the output variable's type (``func`` may
        e.g. return float64 even when the inputs are float32).
        """
        result = np.asarray(self.func(*inputs), dtype=node.outputs[0].dtype)
        if self.nw is None:
            # The declared output is 1-d; flatten whatever func returned.
            result = np.reshape(result, -1)
        outputs[0][0] = result

    def grad(self, inputs, gradients):
        """Return symbolic gradients w.r.t. every input.

        The gradient Op receives the forward inputs followed by the gradient
        of the cost with respect to this Op's output.
        """
        return self._grad_op(*(list(inputs) + list(gradients)))

38 

39 

class spotYlmGradientOp(tt.Op):
    """Theano Op computing the gradients of a :class:`spotYlmOp`.

    Inputs are the forward Op's inputs followed by one trailing input, the
    gradient of the cost with respect to the forward output. There is one
    output per forward input, with the same type and shape as that input.

    Parameters
    ----------
    base_op : spotYlmOp
        The forward Op; its ``func`` is called with all of this Op's inputs
        and must return one gradient per forward input.
    """

    def __init__(self, base_op):
        self.base_op = base_op

    def make_node(self, *inputs):
        """Build the Apply node: one output per forward input."""
        inputs = [tt.as_tensor_variable(i) for i in inputs]
        # The last input is the upstream (output) gradient, which gets no
        # gradient of its own.
        outputs = [i.type() for i in inputs[:-1]]
        return gof.Apply(self, inputs, outputs)

    def infer_shape(self, node, shapes):
        """Each gradient has the shape of its forward input."""
        return shapes[:-1]

    def perform(self, node, inputs, outputs):
        """Evaluate the gradients and store them shaped like their inputs.

        Works for any number of forward inputs (matching ``make_node`` and
        ``infer_shape``), and casts each gradient to its declared dtype.

        Raises
        ------
        ValueError
            If ``func`` returns a different number of gradients than there
            are forward inputs.
        """
        grads = tuple(self.base_op.func(*inputs))
        forward_inputs = inputs[:-1]
        if len(grads) != len(forward_inputs):
            raise ValueError(
                "expected %d gradients from func, got %d"
                % (len(forward_inputs), len(grads))
            )
        for storage, var, grad, value in zip(
            outputs, node.outputs, grads, forward_inputs
        ):
            storage[0] = np.reshape(
                np.asarray(grad, dtype=var.dtype), np.shape(value)
            )