Coverage for starry/_core/ops/limbdark/limbdark.py : 84%

# -*- coding: utf-8 -*-

__all__ = ["LimbDarkOp"]

import theano
import theano.tensor as tt
from theano import gof
from .base_op import LimbDarkBaseOp


class LimbDarkOp(LimbDarkBaseOp):

    __props__ = ()
    func_file = "./limbdark.cc"
    func_name = "APPLY_SPECIFIC(limbdark)"

    def make_node(self, c, b, r, los):
        # Coerce the inputs to tensor variables where possible, upcasting
        # the output dtype to accommodate them.
        in_args = []
        dtype = theano.config.floatX
        for a in [c, b, r, los]:
            try:
                a = tt.as_tensor_variable(a)
            except tt.AsTensorError:
                pass
            else:
                dtype = theano.scalar.upcast(dtype, a.dtype)
            in_args.append(a)

        # Outputs: the flux (same type as b), its gradient with respect to
        # the limb darkening coefficients (one extra dimension), and its
        # gradients with respect to b and r.
        out_args = [
            in_args[1].type(),
            tt.TensorType(
                dtype=dtype, broadcastable=[False] * (in_args[1].ndim + 1)
            )(),
            in_args[1].type(),
            in_args[2].type(),
        ]
        return gof.Apply(self, in_args, out_args)

    def infer_shape(self, node, shapes):
        # f and dfdb share the shape of b; dfdr shares the shape of r;
        # dfdcl has shape shape(c) + shape(b).
        return (
            shapes[1],
            list(shapes[0]) + list(shapes[1]),
            shapes[1],
            shapes[2],
        )

    def grad(self, inputs, gradients):
        c, b, r, los = inputs
        f, dfdcl, dfdb, dfdr = self(*inputs)
        bf = gradients[0]

        # Only the flux output may receive a gradient; the remaining
        # outputs are the partials used below.
        for i, g in enumerate(gradients[1:]):
            if not isinstance(g.type, theano.gradient.DisconnectedType):
                raise ValueError(
                    "can't propagate gradients wrt parameter {0}".format(i + 1)
                )

        # Chain rule: contract the output gradient with dfdcl over the flux
        # dimension to get the gradient with respect to c; the gradients
        # with respect to b and r are element-wise products.
        bc = tt.sum(
            tt.reshape(bf, (1, bf.size))
            * tt.reshape(dfdcl, (c.size, bf.size)),
            axis=-1,
        )
        bb = bf * dfdb
        br = bf * dfdr

        # The gradient with respect to the line-of-sight input is defined
        # to be zero.
        return bc, bb, br, tt.zeros_like(los)

    def R_op(self, inputs, eval_points):
        # The R-operator is evaluated by reusing the gradient implementation.
        if eval_points[0] is None:
            return eval_points
        return self.grad(inputs, eval_points)
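
For context, a minimal sketch of how this op might be wired into a Theano graph (not part of the covered module). The import path follows the coverage header; the vector shapes chosen for c, b, r, and los, and the no-argument construction of LimbDarkOp, are illustrative assumptions, and compiling the function requires limbdark.cc to be buildable by Theano's C machinery:

import theano
import theano.tensor as tt

from starry._core.ops.limbdark.limbdark import LimbDarkOp

# Symbolic inputs matching make_node's (c, b, r, los) signature; all
# vectors here, which is an assumption about the expected shapes.
c = tt.dvector("c")      # limb darkening coefficients
b = tt.dvector("b")      # impact parameters
r = tt.dvector("r")      # occultor radii
los = tt.dvector("los")  # line-of-sight values

# Calling the op yields the flux and its partials (see make_node above).
op = LimbDarkOp()  # assumes no constructor arguments (__props__ is empty)
f, dfdcl, dfdb, dfdr = op(c, b, r, los)

# Gradients of a scalar loss flow back through grad() defined above; the
# gradient with respect to los comes back as zeros by construction.
loss = tt.sum(f ** 2)
g_c, g_b, g_r, g_los = theano.grad(loss, [c, b, r, los])

flux_fn = theano.function([c, b, r, los], f)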