Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle the display of run, missing, excluded, and partial lines

j k   jump to the next/previous highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- coding: utf-8 -*- 

2from __future__ import division, print_function 

3import numpy as np 

4from theano import gof 

5import theano.tensor as tt 

6 

7 

# Public API: only the forward op is exported; its gradient op is an
# implementation detail created by the forward op itself.
__all__ = [
    "tensordotDOp",
]

9 

10 

class tensordotDOp(tt.Op):
    """Theano op wrapping an external callable that produces a 2-D result.

    The numerical work is delegated to ``func``. This class only declares the
    symbolic output type and shape, and wires up the gradient op.

    Parameters
    ----------
    func : callable
        Called as ``func(*inputs)`` in :meth:`perform` to produce the output.
        The companion :class:`tensordotDGradientOp` also calls it as
        ``func(*inputs, output_gradient)`` and expects a pair of gradients,
        one per forward input. NOTE(review): this implies the forward op takes
        exactly two inputs (see ``infer_shape``); confirm against callers.
    """

    def __init__(self, func):
        self.func = func
        # The gradient op reuses ``self.func``, so build it once per op.
        self._grad_op = tensordotDGradientOp(self)

    def make_node(self, *inputs):
        """Build the Apply node.

        The single output is a non-broadcastable 2-D tensor with the dtype
        of the first input.
        """
        inputs = [tt.as_tensor_variable(i) for i in inputs]
        outputs = [tt.TensorType(inputs[0].dtype, (False, False))()]
        return gof.Apply(self, inputs, outputs)

    def infer_shape(self, node, shapes):
        """Return the output shape ``(shapes[1][0], shapes[0][-1])``.

        This is the leading dimension of the second input by the trailing
        dimension of the first.
        """
        return [[shapes[1][0], shapes[0][-1]]]

    def R_op(self, inputs, eval_points):
        # NOTE(review): this delegates to ``grad``, which yields a
        # vector-Jacobian product with one entry per *input*. Theano's R_op
        # contract is a Jacobian-vector product with one entry per *output*.
        # Only the ``None`` short-circuit is clearly meaningful; verify before
        # relying on ``tt.Rop`` through this op.
        if eval_points[0] is None:
            return eval_points
        return self.grad(inputs, eval_points)

    def perform(self, node, inputs, outputs):
        """Evaluate ``func`` and store the result with the declared dtype."""
        # Theano does not cast what ``perform`` stores. Coerce to the dtype
        # declared in ``make_node`` so that a mismatched result (e.g. float64
        # under floatX=float32, or a plain list) cannot leak into the graph.
        # ``np.asarray`` does not copy when the dtype already matches.
        outputs[0][0] = np.asarray(
            self.func(*inputs), dtype=node.outputs[0].dtype
        )

    def grad(self, inputs, gradients):
        """Return gradients w.r.t. each input via ``tensordotDGradientOp``."""
        # ``inputs`` and ``gradients`` are lists. The gradient op receives the
        # forward inputs followed by the gradient of the cost w.r.t. the output.
        return self._grad_op(*(inputs + gradients))

34 

35 

class tensordotDGradientOp(tt.Op):
    """Theano op computing the input gradients of a :class:`tensordotDOp`.

    Inputs are the forward op's inputs followed by the gradient of the cost
    w.r.t. the forward output. Outputs are the gradients w.r.t. each forward
    input, with the same types and shapes as those inputs.

    Parameters
    ----------
    base_op : tensordotDOp
        The forward op. Its ``func`` is called here with all inputs and must
        return a pair ``(grad_wrt_input0, grad_wrt_input1)``.
    """

    def __init__(self, base_op):
        self.base_op = base_op

    def make_node(self, *inputs):
        """Build the Apply node.

        There is one output per forward input (every input except the
        trailing output gradient), each typed like that input.
        """
        inputs = [tt.as_tensor_variable(i) for i in inputs]
        outputs = [i.type() for i in inputs[:-1]]
        return gof.Apply(self, inputs, outputs)

    def infer_shape(self, node, shapes):
        """Each gradient has the shape of its corresponding forward input."""
        return shapes[:-1]

    def perform(self, node, inputs, outputs):
        """Evaluate ``func`` and store shape- and dtype-matched gradients."""
        bM, bwta = self.base_op.func(*inputs)
        # ``func``'s results are reshaped to the input shapes, so it may return
        # e.g. flattened arrays. They are also cast to the dtypes declared in
        # ``make_node``, because Theano does not cast values stored by
        # ``perform``. No copy is made when the dtype already matches.
        outputs[0][0] = np.asarray(
            np.reshape(bM, np.shape(inputs[0])), dtype=node.outputs[0].dtype
        )
        outputs[1][0] = np.asarray(
            np.reshape(bwta, np.shape(inputs[1])), dtype=node.outputs[1].dtype
        )