Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- coding: utf-8 -*- 

2 

3__all__ = ["LimbDarkOp"] 

4 

5import theano 

6import theano.tensor as tt 

7from theano import gof 

8from .base_op import LimbDarkBaseOp 

9 

10 

class LimbDarkOp(LimbDarkBaseOp):
    """Theano op wrapping the compiled ``limbdark`` C implementation.

    Inputs are ``(c, b, r, los)``. Outputs are ``(f, dfdcl, dfdb, dfdr)``:
    a value ``f`` plus its partial derivatives with respect to ``c``, ``b``
    and ``r``. This is how :meth:`grad` and :meth:`R_op` consume them.

    ``func_file`` / ``func_name`` name the C source and entry point;
    presumably ``LimbDarkBaseOp`` loads and compiles them (confirm there).

    NOTE(review): from the naming, ``c`` looks like limb-darkening
    coefficients, ``b`` an impact parameter, ``r`` a radius ratio and
    ``los`` a line-of-sight coordinate. This is inferred, not shown here;
    confirm against ``limbdark.cc``.
    """

    __props__ = ()
    func_file = "./limbdark.cc"
    func_name = "APPLY_SPECIFIC(limbdark)"

    def make_node(self, c, b, r, los):
        """Build the Apply node for ``self(c, b, r, los)``.

        Every input is converted with ``tt.as_tensor_variable``. An input
        that cannot be converted raises ``tt.AsTensorError`` (a
        ``TypeError``) right here. It is no longer passed through raw to
        ``gof.Apply``, which would reject it anyway with a less helpful
        error.

        Output types:
            0. ``f``     -- same type as ``b``.
            1. ``dfdcl`` -- dtype upcast from ``floatX`` and all input
               dtypes; rank ``b.ndim + 1``. :meth:`infer_shape` gives it
               shape ``c.shape + b.shape``, which only agrees with that
               rank when ``c`` is 1-D.
            2. ``dfdb``  -- same type as ``b``.
            3. ``dfdr``  -- same type as ``r``.
        """
        in_args = []
        dtype = theano.config.floatX
        for arg in (c, b, r, los):
            arg = tt.as_tensor_variable(arg)
            dtype = theano.scalar.upcast(dtype, arg.dtype)
            in_args.append(arg)

        out_args = [
            in_args[1].type(),
            tt.TensorType(
                dtype=dtype, broadcastable=[False] * (in_args[1].ndim + 1)
            )(),
            in_args[1].type(),
            in_args[2].type(),
        ]
        return gof.Apply(self, in_args, out_args)

    def infer_shape(self, node, shapes):
        """Return output shapes from the input shapes ``(c, b, r, los)``.

        ``f`` and ``dfdb`` share ``b``'s shape, ``dfdcl`` is ``c``'s shape
        followed by ``b``'s, and ``dfdr`` shares ``r``'s shape.
        """
        return (
            shapes[1],
            list(shapes[0]) + list(shapes[1]),
            shapes[1],
            shapes[2],
        )

    def grad(self, inputs, gradients):
        """Vector-Jacobian product; only output 0 (``f``) is differentiable.

        Raises:
            ValueError: if a gradient flows into any output other than ``f``.
                This op does not provide second derivatives.
        """
        c, b, r, los = inputs
        f, dfdcl, dfdb, dfdr = self(*inputs)
        bf = gradients[0]
        for i, g in enumerate(gradients[1:]):
            if not isinstance(g.type, theano.gradient.DisconnectedType):
                raise ValueError(
                    "can't propagate gradients wrt parameter {0}".format(i + 1)
                )
        # Flatten dfdcl to (c.size, f.size) and contract it with the output
        # gradient over f's elements. The result has shape (c.size,), which
        # matches c only when c is 1-D.
        bc = tt.sum(
            tt.reshape(bf, (1, bf.size))
            * tt.reshape(dfdcl, (c.size, bf.size)),
            axis=-1,
        )
        bb = bf * dfdb
        br = bf * dfdr
        # los is treated as non-differentiable: it gets a zero gradient.
        return bc, bb, br, tt.zeros_like(los)

    def R_op(self, inputs, eval_points):
        """Jacobian-vector product; returns one entry per output.

        ``eval_points`` holds one tangent per *input*, with ``None`` where
        no tangent is given. Only the R-op of ``f`` is computed::

            df = sum_k dfdcl[k] * dc[k] + dfdb * db + dfdr * dr

        The derivative outputs would need second derivatives, which this op
        does not provide, so their entries are ``None``. The ``los`` tangent
        is ignored, consistent with :meth:`grad` giving ``los`` a zero
        gradient.
        """
        c, b, r, los = inputs
        dc, db, dr, _ = eval_points
        if dc is None and db is None and dr is None:
            return [None, None, None, None]

        f, dfdcl, dfdb, dfdr = self(*inputs)
        terms = []
        if dc is not None:
            # Contract the c-tangent with dfdcl, flattened to
            # (c.size, f.size), over c's axis; then restore f's shape.
            jc = tt.sum(
                tt.reshape(dc, (c.size, 1))
                * tt.reshape(dfdcl, (c.size, f.size)),
                axis=0,
            )
            terms.append(tt.reshape(jc, f.shape, ndim=f.ndim))
        if db is not None:
            terms.append(dfdb * db)
        if dr is not None:
            terms.append(dfdr * dr)

        df = terms[0]
        for term in terms[1:]:
            df = df + term
        return [df, None, None, None]