Coverage for starry/_core/ops/polybasis.py : 100%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- coding: utf-8 -*-
2import numpy as np
3import theano
4from theano import gof
5import theano.tensor as tt
# Public API of this module: only the forward op is exported. pTGradientOp is
# not listed here; it is instantiated internally by pTOp for its gradient.
__all__ = ["pTOp"]
class pTOp(tt.Op):
    """Theano op wrapping a NumPy function that evaluates a polynomial basis.

    The forward pass simply calls ``func(*inputs)`` on the NumPy values of the
    inputs. The output is declared as a non-broadcastable 2-D tensor with the
    dtype of the first input. Theano infers its shape as
    ``(shape(first_input)[0], N)`` with ``N = (deg + 1) ** 2``. Gradients are
    delegated to a single :class:`pTGradientOp` built at construction time.

    Parameters
    ----------
    func : callable
        Called as ``func(*inputs)``; its return value becomes the op's output.
        The gradient op also calls it as ``func(x, y, z)``, so presumably
        three coordinate arrays are expected (TODO confirm with callers).
    deg : int
        Degree that fixes the number of output columns, ``(deg + 1) ** 2``.
    """

    def __init__(self, func, deg):
        self.func = func
        self.deg = deg
        # Number of basis terms, i.e. the number of output columns.
        self.N = (deg + 1) ** 2
        # Built once so every grad() call reuses the same backward op.
        self._grad_op = pTGradientOp(self)

    def make_node(self, *inputs):
        tensors = list(map(tt.as_tensor_variable, inputs))
        # Output: a matrix (no broadcastable dims) with the first input's dtype.
        matrix_type = tt.TensorType(tensors[0].dtype, (False, False))
        return gof.Apply(self, tensors, [matrix_type()])

    def infer_shape(self, node, shapes):
        n_rows = shapes[0][0]
        return [[n_rows, self.N]]

    def perform(self, node, inputs, outputs):
        storage = outputs[0]
        storage[0] = self.func(*inputs)

    def grad(self, inputs, gradients):
        # The backward op takes the forward inputs followed by the upstream
        # gradient(s) and returns one gradient per forward input.
        backward_args = list(inputs) + list(gradients)
        return self._grad_op(*backward_args)
class pTGradientOp(tt.Op):
    """Backward op for :class:`pTOp`.

    Inputs are ``(x, y, z, bpT)``: the three inputs of the forward op and the
    upstream gradient ``bpT`` of its ``(len(x), N)`` output. Outputs are the
    gradients with respect to ``x``, ``y`` and ``z``. Each has the same Theano
    type (and, at run time, the same shape) as the corresponding input.

    Parameters
    ----------
    base_op : pTOp
        The forward op; its ``func``, ``deg`` and ``N`` attributes are used.
    """

    def __init__(self, base_op):
        self.base_op = base_op

        # Pre-compute the gradient factors for x, y and z. For basis term n,
        # xf[n], yf[n] and zf[n] are presumably the exponents of x, y and z
        # in that term, so that d(term)/dx == xf[n] * term / x (this is how
        # ``perform`` uses them). Terms are enumerated over ell = 0..deg and
        # m = -ell..ell, i.e. sum(2 * ell + 1) == (deg + 1) ** 2 entries.
        size = base_op.N
        self.xf = np.zeros(size, dtype=int)
        self.yf = np.zeros(size, dtype=int)
        self.zf = np.zeros(size, dtype=int)
        n = 0
        for ell in range(base_op.deg + 1):
            for m in range(-ell, ell + 1):
                mu = ell - m  # >= 0 because |m| <= ell
                nu = ell + m  # >= 0 because |m| <= ell
                if nu % 2 == 0:
                    self.xf[n] = mu // 2
                    self.yf[n] = nu // 2
                else:
                    # mu + nu == 2 * ell, so mu is odd too: both are >= 1
                    # and the exponents below are never negative.
                    self.xf[n] = (mu - 1) // 2
                    self.yf[n] = (nu - 1) // 2
                    self.zf[n] = 1
                n += 1

    def make_node(self, *inputs):
        inputs = [tt.as_tensor_variable(i) for i in inputs]
        # One output per forward input (x, y, z), typed like that input; the
        # last input (the upstream gradient) gets no output.
        outputs = [i.type() for i in inputs[:-1]]
        return gof.Apply(self, inputs, outputs)

    def infer_shape(self, node, shapes):
        # Each gradient has the shape of the corresponding forward input.
        return shapes[:-1]

    def perform(self, node, inputs, outputs):
        x, y, z, bpT = inputs

        # TODO: When any of the coords are zero, there's a div by zero below.
        # This hack nudges near-zero values (including small negative ones)
        # to +tol. We should think of a better way of doing this!
        # The nudge is applied to *copies*: this op declares no destroy_map,
        # so Theano assumes its inputs are left untouched, and writing into
        # them would corrupt values other graph nodes may share.
        tol = 1e-8
        nudged = []
        for coord in (x, y, z):
            coord = np.array(coord, copy=True)
            coord[np.abs(coord) < tol] = tol
            nudged.append(coord)
        x, y, z = nudged

        # Upstream gradient times the basis values; each factor table then
        # turns every term into d(term)/d(coord) = factor * term / coord.
        # np.nansum skips NaN entries when summing over the basis terms.
        bpTpT = bpT * self.base_op.func(x, y, z)
        grads = (
            np.nansum(self.xf[None, :] * bpTpT / x[:, None], axis=-1),
            np.nansum(self.yf[None, :] * bpTpT / y[:, None], axis=-1),
            np.nansum(self.zf[None, :] * bpTpT / z[:, None], axis=-1),
        )

        # Restore each original input's shape and the dtype declared in
        # make_node (the output types mirror the input types).
        for k, grad in enumerate(grads):
            outputs[k][0] = np.reshape(grad, np.shape(inputs[k])).astype(
                node.outputs[k].dtype, copy=False
            )