Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- coding: utf-8 -*- 

2from .. import config 

3import theano 

4import theano.tensor as tt 

5import numpy as np 

6from theano.configparser import change_flags 

7from inspect import getmro 

8from functools import wraps 

9import logging 

10 

# Logger used by this module (e.g. for the compilation progress messages
# emitted around theano compilation).
logger = logging.getLogger(name="starry.ops")

# Public names exported by ``from ... import *``.
__all__ = [
    "logger",
    "autocompile",
]

14 

15 

# Exact Python / NumPy integer types that ``_get_type`` treats as integer
# scalars (membership is tested on ``type(arg)``, so subclasses such as
# ``bool`` are not included). ``np.int`` used to be listed here, but it was
# merely an alias of the builtin ``int`` (already present); it was
# deprecated in NumPy 1.20 and removed in NumPy 1.24, where referencing it
# raises ``AttributeError`` at import time.
integers = (int, np.int16, np.int32, np.int64)

17 

18 

def is_theano(*objs):
    """
    Check whether at least one of ``objs`` is a ``Theano`` graph object.

    An object qualifies when ``theano.gof.graph.Node`` appears (by identity)
    in the method resolution order of its type.

    Returns:
        bool: ``True`` as soon as a matching object is found, ``False``
        otherwise (including when called with no arguments).
    """
    return any(
        any(klass is theano.gof.graph.Node for klass in getmro(type(obj)))
        for obj in objs
    )

26 

27 

class CompileLogMessage:
    """
    Log a brief message saying what method is currently being compiled,
    followed by a status message when the ``with`` block exits.

    On entry, the ``terminator`` of ``config.rootHandler`` is cleared so the
    "Compiling ..." message is not newline-terminated. On exit, the previous
    terminator is restored, then ``Done.`` is logged on success or
    ``Failed.`` if an exception escaped the ``with`` body. Exceptions are
    never suppressed.

    NOTE(review): assumes ``config.rootHandler`` is the handler that emits
    this module's log records, so both messages end up on one line — confirm.

    Args:
        name (str): Name of the method being compiled, shown in the message.

    """

    def __init__(self, name):
        self.name = name
        # Terminator in effect before ``__enter__``; restored by ``__exit__``.
        self._prev_terminator = "\n"

    def __enter__(self):
        handler = config.rootHandler
        self._prev_terminator = getattr(handler, "terminator", "\n")
        # Drop the newline so the status message continues this line.
        handler.terminator = ""
        logger.info("Compiling `{0}`... ".format(self.name))
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the terminator *before* logging, so the status message
        # closes the "Compiling ..." line.
        config.rootHandler.terminator = self._prev_terminator
        if exc_type is None:
            logger.info("Done.")
        else:
            # Do not report success when compilation raised.
            logger.info("Failed.")
        # Returning False lets any exception propagate unchanged.
        return False

45 

46 

def _get_type(arg):
    """
    Get the theano tensor type corresponding to `arg`.

    Note that arg must be one of the following:
        - a theano object, in which case its Python class (``type(arg)``)
          is returned unchanged
        - an integer (`int`, `np.int16`, `np.int32`, `np.int64`),
          mapped to `tt.iscalar`
        - a Python `float`, mapped to `tt.dscalar`
        - a 0-d numpy boolean (`np.array(True)`, `np.array(False)`,
          `np.True_`), mapped to `tt.bscalar`
        - any other object with ``ndim`` equal to 0, 1, 2, or 3 (e.g. a
          numpy float array), mapped to `tt.dscalar`, `tt.dvector`,
          `tt.dmatrix` or `tt.dtensor3` respectively

    # TODO: Cast lists to arrays

    Raises:
        NotImplementedError: If ``arg`` is of an unsupported type, or has
            an ``ndim`` outside of 0-3.

    """
    ttype = type(arg)
    if is_theano(arg):
        return ttype
    if ttype in integers:
        return tt.iscalar
    if ttype is float:
        # A bare Python float has no ``ndim``; treat it like a 0-d float
        # array. (``np.float64`` subclasses ``float`` but its exact type is
        # not ``float``, so it still goes through the ``ndim`` branch.)
        return tt.dscalar
    if not hasattr(arg, "ndim"):
        raise NotImplementedError(
            "Invalid argument type passed to @autocompile: {}.".format(ttype)
        )
    if arg.ndim == 0:
        # Compare dtypes with ``==`` rather than ``is``: dtype identity is a
        # NumPy caching detail, not a documented guarantee.
        if arg.dtype == np.dtype(bool):
            return tt.bscalar
        return tt.dscalar
    if arg.ndim == 1:
        return tt.dvector
    if arg.ndim == 2:
        return tt.dmatrix
    if arg.ndim == 3:
        return tt.dtensor3
    raise NotImplementedError(
        "Invalid array dimension passed to @autocompile: {}.".format(arg.ndim)
    )

90 

91 

def autocompile(func):
    """
    Wrap the method `func` and return a compiled version
    if none of the arguments are tensors.

    If any positional argument is a theano object, ``func`` is called
    directly and its symbolic result is returned. Otherwise the argument
    types are inferred with ``_get_type``, a theano function is compiled from
    dummy variables of those types, cached as an attribute on the instance
    (keyed by the method's qualified name and a hash of the argument types),
    and then called with the actual numeric arguments.

    Only positional arguments are supported by the wrapper.
    """

    @wraps(func)  # inherit docstring
    def wrapper(instance, *args):
        if is_theano(*args):
            # Symbolic call: just build and return the graph.
            return func(instance, *args)

        # Determine the argument types
        arg_types = tuple(_get_type(arg) for arg in args)

        # Unique cache attribute name. The qualified name keeps same-named
        # methods on different classes (e.g. an override and the ``super()``
        # method it calls) from sharing a compiled function.
        qualname = getattr(func, "__qualname__", func.__name__)
        cname = "{}_{}".format(qualname, hex(hash(arg_types)))

        # Compile the function if needed & cache it on the instance
        compiled_func = getattr(instance, cname, None)
        if compiled_func is None:
            dummy_args = [arg_type() for arg_type in arg_types]
            with CompileLogMessage(func.__name__):
                with change_flags(compute_test_value="off"):
                    compiled_func = theano.function(
                        dummy_args,
                        func(instance, *dummy_args),
                        on_unused_input="ignore",
                        profile=config.profile,
                    )
            setattr(instance, cname, compiled_func)

        # Return the result of the compiled version
        return compiled_func(*args)

    return wrapper