Pin Chrome's shortcut to the Win10 Start menu on install and OS upgrade.
[chromium-blink-merge.git] / third_party / cython / src / Cython / Build / Inline.py
blobfcbb6c1282a3e9768ac21e25c692d51213498300
1 import sys, os, re, inspect
2 import imp
4 try:
5 import hashlib
6 except ImportError:
7 import md5 as hashlib
9 from distutils.core import Distribution, Extension
10 from distutils.command.build_ext import build_ext
12 import Cython
13 from Cython.Compiler.Main import Context, CompilationOptions, default_options
15 from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform
16 from Cython.Compiler.TreeFragment import parse_from_strings
17 from Cython.Build.Dependencies import strip_string_literals, cythonize, cached_function
18 from Cython.Compiler import Pipeline
19 from Cython.Utils import get_cython_cache_dir
20 import cython as cython_module
# Helper that turns user-supplied ASCII byte strings into unicode text.
# On Python 3 every str is already text, so the helper is the identity.
if sys.version_info[0] >= 3:
    to_unicode = lambda x: x
else:
    def to_unicode(s):
        """Return ``s`` as unicode, decoding it as ASCII when it is not unicode yet."""
        if isinstance(s, unicode):
            return s
        return s.decode('ascii')
class AllSymbols(CythonTransform, SkipDeclarations):
    """Tree transform that records the name of every NameNode it visits.

    The collected names are available afterwards in the ``names`` set.
    """

    def __init__(self):
        # The transform is constructed without a compilation context.
        CythonTransform.__init__(self, None)
        self.names = set()

    def visit_NameNode(self, node):
        """Remember ``node.name``; the node's children are not visited here."""
        seen = self.names
        seen.add(node.name)
@cached_function
def unbound_symbols(code, context=None):
    """Return the names used in ``code`` that are neither declared nor builtins.

    ``code`` is parsed as a tree fragment and run through the 'pyx' pipeline
    up to and including the AnalyseDeclarationsTransform phase.  Every name
    node found in the resulting tree is then checked against the tree's scope
    and against the builtins module; names known to neither are returned.

    A fresh ``Context([], default_options)`` is used when ``context`` is None.
    """
    code = to_unicode(code)
    if context is None:
        context = Context([], default_options)
    from Cython.Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform
    tree = parse_from_strings('(tree fragment)', code)
    for phase in Pipeline.create_pipeline(context, 'pyx'):
        if phase is not None:
            tree = phase(tree)
            # Declarations are analysed; later phases are not needed.
            if isinstance(phase, AnalyseDeclarationsTransform):
                break
    collector = AllSymbols()
    collector(tree)
    try:
        import builtins
    except ImportError:
        # Python 2 spelling of the builtins module.
        import __builtin__ as builtins
    return [name for name in collector.names
            if not tree.scope.lookup(name) and not hasattr(builtins, name)]
def unsafe_type(arg, context=None):
    """Return a Cython type name for ``arg``, typing plain ints as C ``long``.

    Values whose exact type is ``int`` map to 'long'; everything else is
    delegated to ``safe_type`` with the same ``context``.
    """
    if type(arg) is int:
        return 'long'
    return safe_type(arg, context)
def safe_type(arg, context=None):
    """Return the name of a Cython type that can safely hold ``arg``.

    Exact ``list``/``tuple``/``dict``/``str`` instances map to their own type
    name, ``complex`` to 'double complex', ``float`` to 'double' and ``bool``
    to 'bint'.  If numpy has already been imported and ``arg`` is an ndarray,
    a typed ``numpy.ndarray[...]`` declaration built from its dtype name and
    ndim is returned.

    For any other value the MRO of its type is walked: the first base that
    ``context`` can resolve to a module declaring it as a type is returned as
    'module.Name'.  Reaching a builtin base, running out of bases, or having
    no ``context`` at all yields 'object'.

    ``context`` must provide ``find_module(name, need_pxd=False)`` when given.
    """
    py_type = type(arg)
    if py_type in (list, tuple, dict, str):
        return py_type.__name__
    elif py_type is complex:
        return 'double complex'
    elif py_type is float:
        return 'double'
    elif py_type is bool:
        return 'bint'
    elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray):
        return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim)
    else:
        # Read __mro__ rather than calling mro(): when ``arg`` is itself a
        # class, ``py_type`` is a metaclass and the unbound ``type.mro()``
        # call raises TypeError.
        for base_type in py_type.__mro__:
            if base_type.__module__ in ('__builtin__', 'builtins'):
                return 'object'
            if context is None:
                # Nothing to resolve modules with; keep walking until a
                # builtin base is reached.
                continue
            module = context.find_module(base_type.__module__, need_pxd=False)
            if module:
                entry = module.lookup(base_type.__name__)
                # lookup() yields a falsy value for names the module does
                # not declare, so guard before reading the attribute.
                if entry and entry.is_type:
                    return '%s.%s' % (base_type.__module__, base_type.__name__)
        return 'object'
def _get_build_extension():
    """Create a finalized distutils ``build_ext`` command.

    The distutils configuration files are located and parsed first so that
    the build honours the user's distutils settings.
    """
    distribution = Distribution()
    distribution.parse_config_files(distribution.find_config_files())
    command = build_ext(distribution)
    command.finalize_options()
    return command
@cached_function
def _create_context(cython_include_dirs):
    """Build a compilation Context for the given include directories.

    The result is memoized by ``cached_function``; the caller in this file
    passes the directories as a tuple.
    """
    include_dirs = list(cython_include_dirs)
    return Context(include_dirs, default_options)
def cython_inline(code,
                  get_type=unsafe_type,
                  lib_dir=os.path.join(get_cython_cache_dir(), 'inline'),
                  cython_include_dirs=['.'],
                  force=False,
                  quiet=False,
                  locals=None,
                  globals=None,
                  **kwds):
    """Compile ``code`` as the body of a Cython function and call it.

    The code is wrapped in a generated module exposing ``__invoke``.  That
    module is named after an MD5 hash of the code, the argument signature,
    the Python version and executable, and the Cython version.  It is built
    into ``lib_dir`` unless it is already loaded or already on disk.

    Parameters:
        code: source text of the function body (ASCII bytes or unicode).
        get_type: callable ``(value, context) -> str`` giving the Cython type
            used for each argument; ``None`` types every argument 'object'.
        lib_dir: directory receiving the generated .pyx file and the built
            extension module; created if missing.
        cython_include_dirs: include path for the Cython context and for
            ``cythonize``.  The default list is only read, never mutated.
        force: rebuild even when the extension file already exists.
        quiet: suppress this function's diagnostics; also passed on to
            ``cythonize``.
        locals, globals: mappings used to resolve names the code leaves
            unbound.  Each defaults to the corresponding namespace of the
            frame two levels up the call stack.
            NOTE(review): presumably this function is normally reached
            through a one-level wrapper -- confirm against callers.
        **kwds: explicit name -> value bindings; they take precedence over
            ``locals`` and ``globals``.

    Returns:
        Whatever the generated ``__invoke`` function returns.
    """
    if get_type is None:
        # get_type is always called with (value, context), so the fallback
        # must accept both arguments.
        get_type = lambda x, context=None: 'object'
    code = to_unicode(code)
    orig_code = code
    code, literals = strip_string_literals(code)
    code = strip_common_indent(code)
    ctx = _create_context(tuple(cython_include_dirs))
    if locals is None:
        locals = inspect.currentframe().f_back.f_back.f_locals
    if globals is None:
        globals = inspect.currentframe().f_back.f_back.f_globals
    try:
        for symbol in unbound_symbols(code):
            if symbol in kwds:
                continue
            elif symbol in locals:
                kwds[symbol] = locals[symbol]
            elif symbol in globals:
                kwds[symbol] = globals[symbol]
            elif not quiet:
                print("Couldn't find ", symbol)
    except AssertionError:
        if not quiet:
            # Parsing from strings not fully supported (e.g. cimports).
            print("Could not parse code as a string (to extract unbound symbols).")
    cimports = []
    # Iterate over a snapshot: entries are deleted inside the loop, which is
    # an error on a live dict view in Python 3.
    for name, arg in list(kwds.items()):
        if arg is cython_module:
            cimports.append('\ncimport cython as %s' % name)
            del kwds[name]
    # sorted() works on both Python 2 key lists and Python 3 key views.
    arg_names = sorted(kwds)
    arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
    key = orig_code, arg_sigs, sys.version_info, sys.executable, Cython.__version__
    module_name = "_cython_inline_" + hashlib.md5(str(key).encode('utf-8')).hexdigest()

    if module_name in sys.modules:
        module = sys.modules[module_name]

    else:
        build_extension = None
        if cython_inline.so_ext is None:
            # Figure out and cache current extension suffix
            build_extension = _get_build_extension()
            cython_inline.so_ext = build_extension.get_ext_filename('')

        module_path = os.path.join(lib_dir, module_name + cython_inline.so_ext)

        if not os.path.exists(lib_dir):
            os.makedirs(lib_dir)
        if force or not os.path.isfile(module_path):
            cflags = []
            c_include_dirs = []
            qualified = re.compile(r'([.\w]+)[.]')
            for arg_type, _ in arg_sigs:
                m = qualified.match(arg_type)
                if m:
                    # A dotted type name needs its package cimported.
                    cimports.append('\ncimport %s' % m.groups()[0])
                    # one special case
                    if m.groups()[0] == 'numpy':
                        import numpy
                        c_include_dirs.append(numpy.get_include())
                        # cflags.append('-Wno-unused')
            module_body, func_body = extract_func_code(code)
            params = ', '.join(['%s %s' % a for a in arg_sigs])
            module_code = """
%(module_body)s
%(cimports)s
def __invoke(%(params)s):
%(func_body)s
""" % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body }
            # Put the string literals stripped earlier back in place.
            for label, literal in literals.items():
                module_code = module_code.replace(label, literal)
            pyx_file = os.path.join(lib_dir, module_name + '.pyx')
            fh = open(pyx_file, 'w')
            try:
                fh.write(module_code)
            finally:
                fh.close()
            extension = Extension(
                name = module_name,
                sources = [pyx_file],
                include_dirs = c_include_dirs,
                extra_compile_args = cflags)
            if build_extension is None:
                build_extension = _get_build_extension()
            build_extension.extensions = cythonize([extension], include_path=cython_include_dirs, quiet=quiet)
            build_extension.build_temp = os.path.dirname(pyx_file)
            build_extension.build_lib  = lib_dir
            build_extension.run()

        module = imp.load_dynamic(module_name, module_path)

    arg_list = [kwds[arg] for arg in arg_names]
    return module.__invoke(*arg_list)


# Cached suffix used by cython_inline above.  None should get
# overridden with actual value upon the first cython_inline invocation
cython_inline.so_ext = None
non_space = re.compile('[^ ]')


def strip_common_indent(code):
    """Remove the indentation shared by all code lines of ``code``.

    The common indent is the smallest number of leading spaces found on any
    line that is neither blank nor a comment (first non-space character is
    '#').  That many characters are cut from the front of every code line.
    Blank lines and comment lines are returned unchanged.

    Only spaces count as indentation; lines are split on '\\n'.
    """
    min_indent = None
    lines = code.split('\n')
    for line in lines:
        match = non_space.search(line)
        if not match:
            continue  # blank
        indent = match.start()
        if line[indent] == '#':
            continue  # comment
        if min_indent is None or indent < min_indent:
            min_indent = indent
    for ix, line in enumerate(lines):
        match = non_space.search(line)
        # Classify each line by its own first non-space character; reusing
        # the index left over from the loop above misreads comment lines and
        # can index past the end of short lines.
        if not match or line[match.start()] == '#':
            continue
        lines[ix] = line[min_indent:]
    return '\n'.join(lines)
module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimport)|(from .+ import +[*]))')


def extract_func_code(code):
    """Split ``code`` into module-level text and function-body text.

    Tabs are replaced by a space first.  Each line that does not start with a
    space chooses the destination: module level when it matches
    ``module_statement``, function body otherwise.  Lines starting with a
    space follow the most recent choice (function body initially).

    Returns a ``(module_text, function_text)`` pair; every function line is
    prefixed with one space so it can sit inside a ``def``.
    """
    module_lines = []
    function_lines = []
    target = function_lines
    for line in code.replace('\t', ' ').split('\n'):
        if not line.startswith(' '):
            target = module_lines if module_statement.match(line) else function_lines
        target.append(line)
    module_text = '\n'.join(module_lines)
    function_text = ' ' + '\n '.join(function_lines)
    return module_text, function_text
try:
    from inspect import getcallargs
except ImportError:
    def getcallargs(func, *arg_values, **kwd_values):
        """Fallback for interpreters whose ``inspect`` lacks ``getcallargs``.

        Maps the given positional and keyword values onto the parameter names
        of ``func`` and returns the resulting name -> value dict.  Raises
        TypeError for duplicate, unexpected or missing arguments.
        """
        bound = {}
        args, varargs, kwds, defaults = inspect.getargspec(func)
        if varargs is not None:
            # Positional values beyond the named parameters.
            bound[varargs] = arg_values[len(args):]
        bound.update(zip(args, arg_values))
        for name, value in kwd_values.items():
            if name not in args:
                continue
            if name in bound:
                raise TypeError("Duplicate argument %s" % name)
            bound[name] = kwd_values.pop(name)
        if kwds is not None:
            bound[kwds] = kwd_values
        elif kwd_values:
            raise TypeError("Unexpected keyword arguments: %s" % kwd_values.keys())
        if defaults is None:
            defaults = ()
        # Defaults belong to the trailing parameters.
        first_default = len(args) - len(defaults)
        for ix, name in enumerate(args):
            if name in bound:
                continue
            if ix < first_default:
                raise TypeError("Missing argument: %s" % name)
            bound[name] = defaults[ix - first_default]
        return bound
def get_body(source):
    """Return the part of a function's source that follows its first ':'.

    When ``source`` starts with ``lambda`` the remainder is an expression, so
    it is prefixed with ``return`` to make it usable as a function body.

    Raises ValueError if ``source`` contains no ':'.
    """
    ix = source.index(':')
    body = source[ix + 1:]
    # startswith() instead of source[:5]: a five-character slice can never
    # equal the six-character word 'lambda'.
    if source.startswith('lambda'):
        return "return %s" % body
    return body
294 # Lots to be done here... It would be especially cool if compiled functions
295 # could invoke each other quickly.
class RuntimeCompiledFunction(object):
    """Callable wrapper that runs a Python function's body through cython_inline.

    The body text is extracted once, at construction time, from the source of
    the wrapped function.  Each call binds the call arguments to the wrapped
    function's parameter names and hands them, together with the function's
    global namespace, to ``cython_inline``.
    """

    def __init__(self, f):
        self._f = f
        self._body = get_body(inspect.getsource(f))

    def __call__(self, *args, **kwds):
        all = getcallargs(self._f, *args, **kwds)
        # ``__globals__`` is the portable attribute name; ``func_globals``
        # exists only on Python 2 and is kept as a fallback.
        fn_globals = getattr(self._f, '__globals__', None)
        if fn_globals is None:
            fn_globals = self._f.func_globals
        return cython_inline(self._body, locals=fn_globals, globals=fn_globals, **all)