1 import sys
, os
, re
, inspect
9 from distutils
.core
import Distribution
, Extension
10 from distutils
.command
.build_ext
import build_ext
13 from Cython
.Compiler
.Main
import Context
, CompilationOptions
, default_options
15 from Cython
.Compiler
.ParseTreeTransforms
import CythonTransform
, SkipDeclarations
, AnalyseDeclarationsTransform
16 from Cython
.Compiler
.TreeFragment
import parse_from_strings
17 from Cython
.Build
.Dependencies
import strip_string_literals
, cythonize
, cached_function
18 from Cython
.Compiler
import Pipeline
19 from Cython
.Utils
import get_cython_cache_dir
20 import cython
as cython_module
22 # A utility function to convert user-supplied ASCII strings to unicode.
23 if sys
.version_info
[0] < 3:
25 if not isinstance(s
, unicode):
26 return s
.decode('ascii')
30 to_unicode
= lambda x
: x
33 class AllSymbols(CythonTransform
, SkipDeclarations
):
35 CythonTransform
.__init
__(self
, None)
37 def visit_NameNode(self
, node
):
38 self
.names
.add(node
.name
)
41 def unbound_symbols(code
, context
=None):
42 code
= to_unicode(code
)
44 context
= Context([], default_options
)
45 from Cython
.Compiler
.ParseTreeTransforms
import AnalyseDeclarationsTransform
46 tree
= parse_from_strings('(tree fragment)', code
)
47 for phase
in Pipeline
.create_pipeline(context
, 'pyx'):
51 if isinstance(phase
, AnalyseDeclarationsTransform
):
53 symbol_collector
= AllSymbols()
54 symbol_collector(tree
)
59 import __builtin__
as builtins
60 for name
in symbol_collector
.names
:
61 if not tree
.scope
.lookup(name
) and not hasattr(builtins
, name
):
65 def unsafe_type(arg
, context
=None):
70 return safe_type(arg
, context
)
72 def safe_type(arg
, context
=None):
74 if py_type
in [list, tuple, dict, str]:
75 return py_type
.__name
__
76 elif py_type
is complex:
77 return 'double complex'
78 elif py_type
is float:
82 elif 'numpy' in sys
.modules
and isinstance(arg
, sys
.modules
['numpy'].ndarray
):
83 return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg
.dtype
.name
, arg
.ndim
)
85 for base_type
in py_type
.mro():
86 if base_type
.__module
__ in ('__builtin__', 'builtins'):
88 module
= context
.find_module(base_type
.__module
__, need_pxd
=False)
90 entry
= module
.lookup(base_type
.__name
__)
92 return '%s.%s' % (base_type
.__module
__, base_type
.__name
__)
95 def _get_build_extension():
97 # Ensure the build respects distutils configuration by parsing
98 # the configuration files
99 config_files
= dist
.find_config_files()
100 dist
.parse_config_files(config_files
)
101 build_extension
= build_ext(dist
)
102 build_extension
.finalize_options()
103 return build_extension
106 def _create_context(cython_include_dirs
):
107 return Context(list(cython_include_dirs
), default_options
)
109 def cython_inline(code
,
110 get_type
=unsafe_type
,
111 lib_dir
=os
.path
.join(get_cython_cache_dir(), 'inline'),
112 cython_include_dirs
=['.'],
119 get_type
= lambda x
: 'object'
120 code
= to_unicode(code
)
122 code
, literals
= strip_string_literals(code
)
123 code
= strip_common_indent(code
)
124 ctx
= _create_context(tuple(cython_include_dirs
))
126 locals = inspect
.currentframe().f_back
.f_back
.f_locals
128 globals = inspect
.currentframe().f_back
.f_back
.f_globals
130 for symbol
in unbound_symbols(code
):
133 elif symbol
in locals:
134 kwds
[symbol
] = locals[symbol
]
135 elif symbol
in globals:
136 kwds
[symbol
] = globals[symbol
]
138 print("Couldn't find ", symbol
)
139 except AssertionError:
141 # Parsing from strings not fully supported (e.g. cimports).
142 print("Could not parse code as a string (to extract unbound symbols).")
144 for name
, arg
in kwds
.items():
145 if arg
is cython_module
:
146 cimports
.append('\ncimport cython as %s' % name
)
148 arg_names
= kwds
.keys()
150 arg_sigs
= tuple([(get_type(kwds
[arg
], ctx
), arg
) for arg
in arg_names
])
151 key
= orig_code
, arg_sigs
, sys
.version_info
, sys
.executable
, Cython
.__version
__
152 module_name
= "_cython_inline_" + hashlib
.md5(str(key
).encode('utf-8')).hexdigest()
154 if module_name
in sys
.modules
:
155 module
= sys
.modules
[module_name
]
158 build_extension
= None
159 if cython_inline
.so_ext
is None:
160 # Figure out and cache current extension suffix
161 build_extension
= _get_build_extension()
162 cython_inline
.so_ext
= build_extension
.get_ext_filename('')
164 module_path
= os
.path
.join(lib_dir
, module_name
+ cython_inline
.so_ext
)
166 if not os
.path
.exists(lib_dir
):
168 if force
or not os
.path
.isfile(module_path
):
171 qualified
= re
.compile(r
'([.\w]+)[.]')
172 for type, _
in arg_sigs
:
173 m
= qualified
.match(type)
175 cimports
.append('\ncimport %s' % m
.groups()[0])
177 if m
.groups()[0] == 'numpy':
179 c_include_dirs
.append(numpy
.get_include())
180 # cflags.append('-Wno-unused')
181 module_body
, func_body
= extract_func_code(code
)
182 params
= ', '.join(['%s %s' % a
for a
in arg_sigs
])
186 def __invoke(%(params)s):
188 """ % {'cimports': '\n'.join(cimports
), 'module_body': module_body
, 'params': params
, 'func_body': func_body
}
189 for key
, value
in literals
.items():
190 module_code
= module_code
.replace(key
, value
)
191 pyx_file
= os
.path
.join(lib_dir
, module_name
+ '.pyx')
192 fh
= open(pyx_file
, 'w')
194 fh
.write(module_code
)
197 extension
= Extension(
199 sources
= [pyx_file
],
200 include_dirs
= c_include_dirs
,
201 extra_compile_args
= cflags
)
202 if build_extension
is None:
203 build_extension
= _get_build_extension()
204 build_extension
.extensions
= cythonize([extension
], include_path
=cython_include_dirs
, quiet
=quiet
)
205 build_extension
.build_temp
= os
.path
.dirname(pyx_file
)
206 build_extension
.build_lib
= lib_dir
207 build_extension
.run()
209 module
= imp
.load_dynamic(module_name
, module_path
)
211 arg_list
= [kwds
[arg
] for arg
in arg_names
]
212 return module
.__invoke
(*arg_list
)
214 # Cached suffix used by cython_inline above. None should get
215 # overridden with actual value upon the first cython_inline invocation
216 cython_inline
.so_ext
= None
218 non_space
= re
.compile('[^ ]')
219 def strip_common_indent(code
):
221 lines
= code
.split('\n')
223 match
= non_space
.search(line
)
226 indent
= match
.start()
227 if line
[indent
] == '#':
229 elif min_indent
is None or min_indent
> indent
:
231 for ix
, line
in enumerate(lines
):
232 match
= non_space
.search(line
)
233 if not match
or line
[indent
] == '#':
236 lines
[ix
] = line
[min_indent
:]
237 return '\n'.join(lines
)
239 module_statement
= re
.compile(r
'^((cdef +(extern|class))|cimport|(from .+ cimport)|(from .+ import +[*]))')
240 def extract_func_code(code
):
244 code
= code
.replace('\t', ' ')
245 lines
= code
.split('\n')
247 if not line
.startswith(' '):
248 if module_statement
.match(line
):
253 return '\n'.join(module
), ' ' + '\n '.join(function
)
258 from inspect
import getcallargs
260 def getcallargs(func
, *arg_values
, **kwd_values
):
262 args
, varargs
, kwds
, defaults
= inspect
.getargspec(func
)
263 if varargs
is not None:
264 all
[varargs
] = arg_values
[len(args
):]
265 for name
, value
in zip(args
, arg_values
):
267 for name
, value
in kwd_values
.items():
270 raise TypeError("Duplicate argument %s" % name
)
271 all
[name
] = kwd_values
.pop(name
)
273 all
[kwds
] = kwd_values
275 raise TypeError("Unexpected keyword arguments: %s" % kwd_values
.keys())
278 first_default
= len(args
) - len(defaults
)
279 for ix
, name
in enumerate(args
):
281 if ix
>= first_default
:
282 all
[name
] = defaults
[ix
- first_default
]
284 raise TypeError("Missing argument: %s" % name
)
287 def get_body(source
):
288 ix
= source
.index(':')
289 if source
[:5] == 'lambda':
290 return "return %s" % source
[ix
+1:]
294 # Lots to be done here... It would be especially cool if compiled functions
295 # could invoke each other quickly.
296 class RuntimeCompiledFunction(object):
298 def __init__(self
, f
):
300 self
._body
= get_body(inspect
.getsource(f
))
302 def __call__(self
, *args
, **kwds
):
303 all
= getcallargs(self
._f
, *args
, **kwds
)
304 return cython_inline(self
._body
, locals=self
._f
.func_globals
, globals=self
._f
.func_globals
, **all
)