Coverage for starry/_core/ops/diffrot.py : 92%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- coding: utf-8 -*-
2from __future__ import division, print_function
3import numpy as np
4from theano import gof
5import theano.tensor as tt
8__all__ = ["tensordotDOp"]
class tensordotDOp(tt.Op):
    """Theano Op that evaluates a user-supplied numeric callable.

    The forward pass is ``func(*inputs)``. Gradients are delegated to a
    companion ``tensordotDGradientOp``, which calls the same ``func`` with
    the output gradient appended to the argument list and unpacks a pair
    of arrays from the result. ``func`` therefore has to support both call
    signatures.
    """

    def __init__(self, func):
        # Keep the callable and build the matching gradient Op once, up front.
        self.func = func
        self._grad_op = tensordotDGradientOp(self)

    def make_node(self, *inputs):
        """Wrap ``inputs`` as tensor variables and declare one 2-D output.

        The output dtype is taken from the first input; neither output axis
        is marked broadcastable.
        """
        tensors = [tt.as_tensor_variable(arg) for arg in inputs]
        out_type = tt.TensorType(tensors[0].dtype, (False, False))
        return gof.Apply(self, tensors, [out_type()])

    def infer_shape(self, node, shapes):
        """Output shape is (first dim of input 1, last dim of input 0)."""
        n_rows = shapes[1][0]
        n_cols = shapes[0][-1]
        return [[n_rows, n_cols]]

    def R_op(self, inputs, eval_points):
        # NOTE(review): this forwards to ``grad``, i.e. it returns
        # input-gradients rather than a forward-mode Jacobian-vector
        # product, and it echoes ``eval_points`` unchanged when the first
        # one is None. Confirm this is intended before relying on ``tt.Rop``.
        if eval_points[0] is not None:
            return self.grad(inputs, eval_points)
        return eval_points

    def perform(self, node, inputs, outputs):
        """Run ``func`` on the concrete input arrays and store the result."""
        result_cell = outputs[0]
        result_cell[0] = self.func(*inputs)

    def grad(self, inputs, gradients):
        """Return gradients w.r.t. the inputs via the companion gradient Op.

        The gradient Op receives the forward inputs followed by the output
        gradient(s).
        """
        grad_args = inputs + gradients
        return self._grad_op(*grad_args)
class tensordotDGradientOp(tt.Op):
    """Gradient companion for ``tensordotDOp``.

    Takes the forward inputs followed by one trailing extra input (the
    gradient of the forward output, as passed by ``tensordotDOp.grad``).
    It declares one output per forward input, typed like that input. The
    numeric work is done by calling the forward Op's ``func`` with all of
    these inputs.
    """

    def __init__(self, base_op):
        # The forward Op whose ``func`` also computes the gradients.
        self.base_op = base_op

    def make_node(self, *inputs):
        tensors = [tt.as_tensor_variable(arg) for arg in inputs]
        # One output per forward input (everything except the trailing
        # output gradient), each with the same Theano type as that input.
        grad_outputs = [t.type() for t in tensors[:-1]]
        return gof.Apply(self, tensors, grad_outputs)

    def infer_shape(self, node, shapes):
        # Each gradient has the shape of its forward input.
        return shapes[:-1]

    def perform(self, node, inputs, outputs):
        # NOTE(review): assumes exactly two forward inputs; func(*inputs)
        # must return a pair (grad w.r.t. inputs[0], grad w.r.t. inputs[1]).
        grad_a, grad_b = self.base_op.func(*inputs)
        # Reshape so each gradient matches the shape of its forward input.
        outputs[0][0] = np.reshape(grad_a, np.shape(inputs[0]))
        outputs[1][0] = np.reshape(grad_b, np.shape(inputs[1]))