Coverage for starry/_core/utils.py : 92%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- coding: utf-8 -*-
2from .. import config
3import theano
4import theano.tensor as tt
5import numpy as np
6from theano.configparser import change_flags
7from inspect import getmro
8from functools import wraps
9import logging
# Logger used by the helpers in this module (see ``CompileLogMessage``).
logger = logging.getLogger("starry.ops")

# Names exported by a star-import of this module.
__all__ = ["logger", "autocompile"]
# Exact scalar types that are treated as integers. Callers test membership
# with ``type(arg) in integers`` (exact type match, no subclass check).
# ``np.int`` is intentionally not listed: it was only an alias for the builtin
# ``int`` (already present) and no longer exists in recent NumPy releases,
# where referencing it raises ``AttributeError`` at import time.
integers = (int, np.int16, np.int32, np.int64)
def is_theano(*objs):
    """Return ``True`` if any of ``objs`` is a ``Theano`` object.

    An object counts as a Theano object when ``theano.gof.graph.Node``
    appears (by identity) in the method resolution order of its type.
    Calling the function with no arguments returns ``False``.
    """
    # ``any`` short-circuits on the first match, exactly like an early return
    # from nested loops would.
    return any(
        base is theano.gof.graph.Node
        for obj in objs
        for base in getmro(type(obj))
    )
class CompileLogMessage:
    """
    Context manager that logs which method is currently being compiled.

    On entry the terminator of ``config.rootHandler`` is set to an empty
    string and ``Compiling `<name>`... `` is logged; on exit the terminator
    is set back to a newline and ``Done.`` is logged.

    ``__exit__`` returns ``None``, so an exception raised inside the ``with``
    body propagates to the caller; note that ``Done.`` is still logged in
    that case.

    """

    def __init__(self, name):
        # Name of the method being compiled; only used in the log message.
        self.name = name

    def __enter__(self):
        # Drop the line terminator before logging -- presumably so that the
        # ``Done.`` message ends up on the same output line.
        config.rootHandler.terminator = ""
        logger.info("Compiling `{0}`... ".format(self.name))

    def __exit__(self, type, value, traceback):
        # Put the newline terminator back before emitting the closing message.
        config.rootHandler.terminator = "\n"
        logger.info("Done.")
def _get_type(arg):
    """
    Get the theano tensor type corresponding to `arg`.

    Note that arg must be one of the following:

        - a theano tensor
        - an integer whose exact type is listed in the module-level
          `integers` tuple
        - a numpy boolean (`np.array(True)`, `np.array(False)`)
        - a numpy float array with ndim equal to 0, 1, 2, or 3

    Returns:
        ``type(arg)`` when `arg` is a theano object; otherwise one of
        ``tt.iscalar``, ``tt.bscalar``, ``tt.dscalar``, ``tt.dvector``,
        ``tt.dmatrix`` or ``tt.dtensor3``. For objects exposing ``ndim``
        only the boolean dtype is special-cased (0-d only); no other dtype
        check is performed.

    Raises:
        NotImplementedError: if `arg` has an ``ndim`` outside 0-3, or is
            neither a theano object, a listed integer type, nor an object
            with an ``ndim`` attribute.

    # TODO: Cast lists to arrays and floats to np.array(float)

    """
    ttype = type(arg)

    if is_theano(arg):
        return ttype

    if ttype in integers:
        return tt.iscalar

    if not hasattr(arg, "ndim"):
        raise NotImplementedError(
            "Invalid argument type passed to @autocompile: {}.".format(ttype)
        )

    if arg.ndim == 0:
        # Compare dtypes by value rather than by object identity, and avoid
        # allocating a throwaway array just to obtain the boolean dtype.
        if arg.dtype == np.bool_:
            return tt.bscalar
        return tt.dscalar

    if arg.ndim == 1:
        return tt.dvector
    if arg.ndim == 2:
        return tt.dmatrix
    if arg.ndim == 3:
        return tt.dtensor3

    raise NotImplementedError(
        "Invalid array dimension passed to @autocompile: {}.".format(arg.ndim)
    )
def autocompile(func):
    """
    Wrap the method `func` and return a compiled version
    if none of the arguments are tensors.

    The returned wrapper accepts ``(instance, *args)`` (positional arguments
    only). If any of ``args`` is a theano object, `func` is called directly
    and its result returned. Otherwise the arguments are mapped to theano
    types via ``_get_type``, `func` is compiled once with ``theano.function``
    for that combination of types, and the compiled function is cached on
    ``instance`` under the attribute ``"<func name>_<hex hash of types>"``
    before being called with ``args``.

    """

    @wraps(func)  # inherit docstring
    def wrapper(instance, *args):
        if is_theano(*args):
            # Symbolic input: just return the function as is
            return func(instance, *args)

        # Determine the argument types
        arg_types = tuple(_get_type(arg) for arg in args)

        # Get a unique name for the compiled function
        cname = "{}_{}".format(func.__name__, hex(hash(arg_types)))

        # Compile the function if needed & cache it on the instance
        compiled_func = getattr(instance, cname, None)
        if compiled_func is None:
            # Instantiate one symbolic placeholder per argument type
            dummy_args = [arg_type() for arg_type in arg_types]

            with CompileLogMessage(func.__name__):
                with change_flags(compute_test_value="off"):
                    compiled_func = theano.function(
                        dummy_args,
                        func(instance, *dummy_args),
                        on_unused_input="ignore",
                        profile=config.profile,
                    )
            setattr(instance, cname, compiled_func)

        # Return the result of the compiled version
        return compiled_func(*args)

    return wrapper