Coverage for starry/_core/ops/spot.py : 100%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- coding: utf-8 -*-
2import numpy as np
3import theano
4from theano import gof
5import theano.tensor as tt
# Public API of this module: only the forward op is exported; the gradient
# op is an implementation detail created by ``spotYlmOp`` itself.
__all__ = ["spotYlmOp"]
class spotYlmOp(tt.Op):
    """Theano op wrapping a numerical callable that produces Ylm coefficients.

    The callable ``func`` is evaluated on the numpy values of the op's inputs
    and its result becomes the op's single output. When ``nw`` is ``None``
    the result is flattened to a vector of length ``Ny``; otherwise the op
    declares a matrix output of shape ``(Ny, nw)``, where
    ``Ny = (ydeg + 1) ** 2``.

    Args:
        func: Callable invoked in ``perform`` with the evaluated inputs. The
            gradient op also calls it, passing the forward inputs followed by
            the output gradient, and unpacks four gradients from the result.
        ydeg: Degree used to size the output (``Ny = (ydeg + 1) ** 2``).
        nw: Number of output columns, or ``None`` for a flat vector output.
    """

    def __init__(self, func, ydeg, nw):
        self.func = func
        self.nw = nw
        self.Ny = (ydeg + 1) ** 2
        # The gradient op keeps a back-reference to this op so it can reuse
        # ``self.func`` when computing gradients.
        self._grad_op = spotYlmGradientOp(self)

    def make_node(self, *inputs):
        """Build the Apply node; the output dtype follows the first input."""
        tensors = [tt.as_tensor_variable(arg) for arg in inputs]
        # One non-broadcastable axis for the vector case, two for the matrix.
        pattern = (False,) if self.nw is None else (False, False)
        out_type = tt.TensorType(tensors[0].dtype, pattern)
        return gof.Apply(self, tensors, [out_type()])

    def infer_shape(self, node, shapes):
        """Return the static output shape; it does not depend on the inputs."""
        if self.nw is not None:
            return [(self.Ny, self.nw)]
        return [(self.Ny,)]

    def perform(self, node, inputs, outputs):
        """Evaluate ``func`` and store its (possibly flattened) result."""
        result = self.func(*inputs)
        if self.nw is None:
            # Vector mode: collapse whatever ``func`` returned to 1-D.
            result = np.reshape(result, -1)
        outputs[0][0] = result

    def grad(self, inputs, gradients):
        """Delegate to the gradient op: forward inputs, then output grads."""
        grad_args = list(inputs) + list(gradients)
        return self._grad_op(*grad_args)
class spotYlmGradientOp(tt.Op):
    """Theano op computing the input gradients of a ``spotYlmOp``.

    The inputs are the forward op's inputs followed by the gradient of the
    forward output (always the last input). One gradient output is declared
    per forward input, with the same type and shape as that input.
    ``perform`` fills exactly four outputs, so it assumes four forward
    inputs (five inputs in total).

    Args:
        base_op: The forward ``spotYlmOp``; its ``func`` is called in
            ``perform`` with all inputs, including the trailing gradient.
    """

    def __init__(self, base_op):
        self.base_op = base_op

    def make_node(self, *inputs):
        """Build the Apply node with one output per non-gradient input."""
        tensors = [tt.as_tensor_variable(arg) for arg in inputs]
        # Skip the trailing output gradient: it gets no gradient of its own.
        grad_outputs = [tensor.type() for tensor in tensors[:-1]]
        return gof.Apply(self, tensors, grad_outputs)

    def infer_shape(self, node, shapes):
        """Each gradient has the shape of its matching forward input."""
        return shapes[:-1]

    def perform(self, node, inputs, outputs):
        """Call the base op's ``func`` and reshape each gradient to its input."""
        # Unpacking enforces that ``func`` returns exactly four gradients.
        bamp, bsigma, blat, blon = self.base_op.func(*inputs)
        for idx, grad_value in enumerate((bamp, bsigma, blat, blon)):
            outputs[idx][0] = np.reshape(grad_value, np.shape(inputs[idx]))