1 from Cython
.Compiler
import TypeSlots
2 from Cython
.Compiler
.ExprNodes
import not_a_constant
4 cython
.declare(UtilityCode
=object, EncodedString
=object, BytesLiteral
=object,
5 Nodes
=object, ExprNodes
=object, PyrexTypes
=object, Builtin
=object,
6 UtilNodes
=object, Naming
=object)
17 from Code
import UtilityCode
18 from StringEncoding
import EncodedString
, BytesLiteral
19 from Errors
import error
20 from ParseTreeTransforms
import SkipDeclarations
26 from __builtin__
import reduce
28 from functools
import reduce
31 from __builtin__
import basestring
33 basestring
= str # Python 3
def load_c_utility(name):
    """Return the cached utility code section ``name`` from "Optimize.c"."""
    utility_file = "Optimize.c"
    return UtilityCode.load_cached(name, utility_file)
38 def unwrap_coerced_node(node
, coercion_nodes
=(ExprNodes
.CoerceToPyTypeNode
, ExprNodes
.CoerceFromPyTypeNode
)):
39 if isinstance(node
, coercion_nodes
):
43 def unwrap_node(node
):
44 while isinstance(node
, UtilNodes
.ResultRefNode
):
45 node
= node
.expression
48 def is_common_value(a
, b
):
51 if isinstance(a
, ExprNodes
.NameNode
) and isinstance(b
, ExprNodes
.NameNode
):
52 return a
.name
== b
.name
53 if isinstance(a
, ExprNodes
.AttributeNode
) and isinstance(b
, ExprNodes
.AttributeNode
):
54 return not a
.is_py_attr
and is_common_value(a
.obj
, b
.obj
) and a
.attribute
== b
.attribute
57 def filter_none_node(node
):
58 if node
is not None and node
.constant_result
is None:
62 class IterationTransform(Visitor
.EnvTransform
):
63 """Transform some common for-in loop patterns into efficient C loops:
65 - for-in-dict loop becomes a while loop calling PyDict_Next()
66 - for-in-enumerate is replaced by an external counter variable
67 - for-in-range loop becomes a plain C for loop
69 def visit_PrimaryCmpNode(self
, node
):
70 if node
.is_ptr_contains():
80 result_ref
= UtilNodes
.ResultRefNode(node
)
81 if isinstance(node
.operand2
, ExprNodes
.IndexNode
):
82 base_type
= node
.operand2
.base
.type.base_type
84 base_type
= node
.operand2
.type.base_type
85 target_handle
= UtilNodes
.TempHandle(base_type
)
86 target
= target_handle
.ref(pos
)
87 cmp_node
= ExprNodes
.PrimaryCmpNode(
88 pos
, operator
=u
'==', operand1
=node
.operand1
, operand2
=target
)
89 if_body
= Nodes
.StatListNode(
91 stats
= [Nodes
.SingleAssignmentNode(pos
, lhs
=result_ref
, rhs
=ExprNodes
.BoolNode(pos
, value
=1)),
92 Nodes
.BreakStatNode(pos
)])
93 if_node
= Nodes
.IfStatNode(
95 if_clauses
=[Nodes
.IfClauseNode(pos
, condition
=cmp_node
, body
=if_body
)],
97 for_loop
= UtilNodes
.TempsBlockNode(
99 temps
= [target_handle
],
100 body
= Nodes
.ForInStatNode(
103 iterator
=ExprNodes
.IteratorNode(node
.operand2
.pos
, sequence
=node
.operand2
),
105 else_clause
=Nodes
.SingleAssignmentNode(pos
, lhs
=result_ref
, rhs
=ExprNodes
.BoolNode(pos
, value
=0))))
106 for_loop
= for_loop
.analyse_expressions(self
.current_env())
107 for_loop
= self
.visit(for_loop
)
108 new_node
= UtilNodes
.TempResultFromStatNode(result_ref
, for_loop
)
110 if node
.operator
== 'not_in':
111 new_node
= ExprNodes
.NotNode(pos
, operand
=new_node
)
115 self
.visitchildren(node
)
def visit_ForInStatNode(self, node):
    """Visit a for-in loop node and try to optimise it.

    The node's children are visited first; the node is then passed to
    _optimise_for_loop() together with the sequence of its iterator, and
    whatever that call returns is returned unchanged.
    """
    self.visitchildren(node)
    sequence = node.iterator.sequence
    return self._optimise_for_loop(node, sequence)
122 def _optimise_for_loop(self
, node
, iterator
, reversed=False):
123 if iterator
.type is Builtin
.dict_type
:
124 # like iterating over dict.keys()
126 # CPython raises an error here: not a sequence
128 return self
._transform
_dict
_iteration
(
129 node
, dict_obj
=iterator
, method
=None, keys
=True, values
=False)
131 # C array (slice) iteration?
132 if iterator
.type.is_ptr
or iterator
.type.is_array
:
133 return self
._transform
_carray
_iteration
(node
, iterator
, reversed=reversed)
134 if iterator
.type is Builtin
.bytes_type
:
135 return self
._transform
_bytes
_iteration
(node
, iterator
, reversed=reversed)
136 if iterator
.type is Builtin
.unicode_type
:
137 return self
._transform
_unicode
_iteration
(node
, iterator
, reversed=reversed)
139 # the rest is based on function calls
140 if not isinstance(iterator
, ExprNodes
.SimpleCallNode
):
143 if iterator
.args
is None:
144 arg_count
= iterator
.arg_tuple
and len(iterator
.arg_tuple
.args
) or 0
146 arg_count
= len(iterator
.args
)
147 if arg_count
and iterator
.self
is not None:
150 function
= iterator
.function
152 if function
.is_attribute
and not reversed and not arg_count
:
153 base_obj
= iterator
.self
or function
.obj
154 method
= function
.attribute
155 # in Py3, items() is equivalent to Py2's iteritems()
156 is_safe_iter
= self
.global_scope().context
.language_level
>= 3
158 if not is_safe_iter
and method
in ('keys', 'values', 'items'):
159 # try to reduce this to the corresponding .iter*() methods
160 if isinstance(base_obj
, ExprNodes
.SimpleCallNode
):
161 inner_function
= base_obj
.function
162 if (inner_function
.is_name
and inner_function
.name
== 'dict'
163 and inner_function
.entry
164 and inner_function
.entry
.is_builtin
):
165 # e.g. dict(something).items() => safe to use .iter*()
168 keys
= values
= False
169 if method
== 'iterkeys' or (is_safe_iter
and method
== 'keys'):
171 elif method
== 'itervalues' or (is_safe_iter
and method
== 'values'):
173 elif method
== 'iteritems' or (is_safe_iter
and method
== 'items'):
177 return self
._transform
_dict
_iteration
(
178 node
, base_obj
, method
, keys
, values
)
180 # enumerate/reversed ?
181 if iterator
.self
is None and function
.is_name
and \
182 function
.entry
and function
.entry
.is_builtin
:
183 if function
.name
== 'enumerate':
185 # CPython raises an error here: not a sequence
187 return self
._transform
_enumerate
_iteration
(node
, iterator
)
188 elif function
.name
== 'reversed':
190 # CPython raises an error here: not a sequence
192 return self
._transform
_reversed
_iteration
(node
, iterator
)
195 if Options
.convert_range
and node
.target
.type.is_int
:
196 if iterator
.self
is None and function
.is_name
and \
197 function
.entry
and function
.entry
.is_builtin
and \
198 function
.name
in ('range', 'xrange'):
199 return self
._transform
_range
_iteration
(node
, iterator
, reversed=reversed)
203 def _transform_reversed_iteration(self
, node
, reversed_function
):
204 args
= reversed_function
.arg_tuple
.args
206 error(reversed_function
.pos
,
207 "reversed() requires an iterable argument")
210 error(reversed_function
.pos
,
211 "reversed() takes exactly 1 argument")
215 # reversed(list/tuple) ?
216 if arg
.type in (Builtin
.tuple_type
, Builtin
.list_type
):
217 node
.iterator
.sequence
= arg
.as_none_safe_node("'NoneType' object is not iterable")
218 node
.iterator
.reversed = True
221 return self
._optimise
_for
_loop
(node
, arg
, reversed=True)
223 PyBytes_AS_STRING_func_type
= PyrexTypes
.CFuncType(
224 PyrexTypes
.c_char_ptr_type
, [
225 PyrexTypes
.CFuncTypeArg("s", Builtin
.bytes_type
, None)
228 PyBytes_GET_SIZE_func_type
= PyrexTypes
.CFuncType(
229 PyrexTypes
.c_py_ssize_t_type
, [
230 PyrexTypes
.CFuncTypeArg("s", Builtin
.bytes_type
, None)
233 def _transform_bytes_iteration(self
, node
, slice_node
, reversed=False):
234 target_type
= node
.target
.type
235 if not target_type
.is_int
and target_type
is not Builtin
.bytes_type
:
236 # bytes iteration returns bytes objects in Py2, but
240 unpack_temp_node
= UtilNodes
.LetRefNode(
241 slice_node
.as_none_safe_node("'NoneType' is not iterable"))
243 slice_base_node
= ExprNodes
.PythonCapiCallNode(
244 slice_node
.pos
, "PyBytes_AS_STRING",
245 self
.PyBytes_AS_STRING_func_type
,
246 args
= [unpack_temp_node
],
249 len_node
= ExprNodes
.PythonCapiCallNode(
250 slice_node
.pos
, "PyBytes_GET_SIZE",
251 self
.PyBytes_GET_SIZE_func_type
,
252 args
= [unpack_temp_node
],
256 return UtilNodes
.LetNode(
258 self
._transform
_carray
_iteration
(
260 ExprNodes
.SliceIndexNode(
262 base
= slice_base_node
,
266 type = slice_base_node
.type,
269 reversed = reversed))
271 PyUnicode_READ_func_type
= PyrexTypes
.CFuncType(
272 PyrexTypes
.c_py_ucs4_type
, [
273 PyrexTypes
.CFuncTypeArg("kind", PyrexTypes
.c_int_type
, None),
274 PyrexTypes
.CFuncTypeArg("data", PyrexTypes
.c_void_ptr_type
, None),
275 PyrexTypes
.CFuncTypeArg("index", PyrexTypes
.c_py_ssize_t_type
, None)
278 init_unicode_iteration_func_type
= PyrexTypes
.CFuncType(
279 PyrexTypes
.c_int_type
, [
280 PyrexTypes
.CFuncTypeArg("s", PyrexTypes
.py_object_type
, None),
281 PyrexTypes
.CFuncTypeArg("length", PyrexTypes
.c_py_ssize_t_ptr_type
, None),
282 PyrexTypes
.CFuncTypeArg("data", PyrexTypes
.c_void_ptr_ptr_type
, None),
283 PyrexTypes
.CFuncTypeArg("kind", PyrexTypes
.c_int_ptr_type
, None)
285 exception_value
= '-1')
287 def _transform_unicode_iteration(self
, node
, slice_node
, reversed=False):
288 if slice_node
.is_literal
:
289 # try to reduce to byte iteration for plain Latin-1 strings
291 bytes_value
= BytesLiteral(slice_node
.value
.encode('latin1'))
292 except UnicodeEncodeError:
295 bytes_slice
= ExprNodes
.SliceIndexNode(
297 base
=ExprNodes
.BytesNode(
298 slice_node
.pos
, value
=bytes_value
,
299 constant_result
=bytes_value
,
300 type=PyrexTypes
.c_char_ptr_type
).coerce_to(
301 PyrexTypes
.c_uchar_ptr_type
, self
.current_env()),
303 stop
=ExprNodes
.IntNode(
304 slice_node
.pos
, value
=str(len(bytes_value
)),
305 constant_result
=len(bytes_value
),
306 type=PyrexTypes
.c_py_ssize_t_type
),
307 type=Builtin
.unicode_type
, # hint for Python conversion
309 return self
._transform
_carray
_iteration
(node
, bytes_slice
, reversed)
311 unpack_temp_node
= UtilNodes
.LetRefNode(
312 slice_node
.as_none_safe_node("'NoneType' is not iterable"))
314 start_node
= ExprNodes
.IntNode(
315 node
.pos
, value
='0', constant_result
=0, type=PyrexTypes
.c_py_ssize_t_type
)
316 length_temp
= UtilNodes
.TempHandle(PyrexTypes
.c_py_ssize_t_type
)
317 end_node
= length_temp
.ref(node
.pos
)
319 relation1
, relation2
= '>', '>='
320 start_node
, end_node
= end_node
, start_node
322 relation1
, relation2
= '<=', '<'
324 kind_temp
= UtilNodes
.TempHandle(PyrexTypes
.c_int_type
)
325 data_temp
= UtilNodes
.TempHandle(PyrexTypes
.c_void_ptr_type
)
326 counter_temp
= UtilNodes
.TempHandle(PyrexTypes
.c_py_ssize_t_type
)
328 target_value
= ExprNodes
.PythonCapiCallNode(
329 slice_node
.pos
, "__Pyx_PyUnicode_READ",
330 self
.PyUnicode_READ_func_type
,
331 args
= [kind_temp
.ref(slice_node
.pos
),
332 data_temp
.ref(slice_node
.pos
),
333 counter_temp
.ref(node
.target
.pos
)],
336 if target_value
.type != node
.target
.type:
337 target_value
= target_value
.coerce_to(node
.target
.type,
339 target_assign
= Nodes
.SingleAssignmentNode(
340 pos
= node
.target
.pos
,
343 body
= Nodes
.StatListNode(
345 stats
= [target_assign
, node
.body
])
347 loop_node
= Nodes
.ForFromStatNode(
349 bound1
=start_node
, relation1
=relation1
,
350 target
=counter_temp
.ref(node
.target
.pos
),
351 relation2
=relation2
, bound2
=end_node
,
352 step
=None, body
=body
,
353 else_clause
=node
.else_clause
,
356 setup_node
= Nodes
.ExprStatNode(
358 expr
= ExprNodes
.PythonCapiCallNode(
359 slice_node
.pos
, "__Pyx_init_unicode_iteration",
360 self
.init_unicode_iteration_func_type
,
361 args
= [unpack_temp_node
,
362 ExprNodes
.AmpersandNode(slice_node
.pos
, operand
=length_temp
.ref(slice_node
.pos
),
363 type=PyrexTypes
.c_py_ssize_t_ptr_type
),
364 ExprNodes
.AmpersandNode(slice_node
.pos
, operand
=data_temp
.ref(slice_node
.pos
),
365 type=PyrexTypes
.c_void_ptr_ptr_type
),
366 ExprNodes
.AmpersandNode(slice_node
.pos
, operand
=kind_temp
.ref(slice_node
.pos
),
367 type=PyrexTypes
.c_int_ptr_type
),
370 result_is_used
= False,
371 utility_code
=UtilityCode
.load_cached("unicode_iter", "Optimize.c"),
373 return UtilNodes
.LetNode(
375 UtilNodes
.TempsBlockNode(
376 node
.pos
, temps
=[counter_temp
, length_temp
, data_temp
, kind_temp
],
377 body
=Nodes
.StatListNode(node
.pos
, stats
=[setup_node
, loop_node
])))
379 def _transform_carray_iteration(self
, node
, slice_node
, reversed=False):
381 if isinstance(slice_node
, ExprNodes
.SliceIndexNode
):
382 slice_base
= slice_node
.base
383 start
= filter_none_node(slice_node
.start
)
384 stop
= filter_none_node(slice_node
.stop
)
387 if not slice_base
.type.is_pyobject
:
388 error(slice_node
.pos
, "C array iteration requires known end index")
391 elif isinstance(slice_node
, ExprNodes
.IndexNode
):
392 assert isinstance(slice_node
.index
, ExprNodes
.SliceNode
)
393 slice_base
= slice_node
.base
394 index
= slice_node
.index
395 start
= filter_none_node(index
.start
)
396 stop
= filter_none_node(index
.stop
)
397 step
= filter_none_node(index
.step
)
399 if not isinstance(step
.constant_result
, (int,long)) \
400 or step
.constant_result
== 0 \
401 or step
.constant_result
> 0 and not stop \
402 or step
.constant_result
< 0 and not start
:
403 if not slice_base
.type.is_pyobject
:
404 error(step
.pos
, "C array iteration requires known step size and end index")
407 # step sign is handled internally by ForFromStatNode
408 step_value
= step
.constant_result
410 step_value
= -step_value
411 neg_step
= step_value
< 0
412 step
= ExprNodes
.IntNode(step
.pos
, type=PyrexTypes
.c_py_ssize_t_type
,
413 value
=str(abs(step_value
)),
414 constant_result
=abs(step_value
))
416 elif slice_node
.type.is_array
:
417 if slice_node
.type.size
is None:
418 error(slice_node
.pos
, "C array iteration requires known end index")
420 slice_base
= slice_node
422 stop
= ExprNodes
.IntNode(
423 slice_node
.pos
, value
=str(slice_node
.type.size
),
424 type=PyrexTypes
.c_py_ssize_t_type
, constant_result
=slice_node
.type.size
)
428 if not slice_node
.type.is_pyobject
:
429 error(slice_node
.pos
, "C array iteration requires known end index")
433 start
= start
.coerce_to(PyrexTypes
.c_py_ssize_t_type
, self
.current_env())
435 stop
= stop
.coerce_to(PyrexTypes
.c_py_ssize_t_type
, self
.current_env())
438 stop
= ExprNodes
.IntNode(
439 slice_node
.pos
, value
='-1', type=PyrexTypes
.c_py_ssize_t_type
, constant_result
=-1)
441 error(slice_node
.pos
, "C array iteration requires known step size and end index")
446 start
= ExprNodes
.IntNode(slice_node
.pos
, value
="0", constant_result
=0,
447 type=PyrexTypes
.c_py_ssize_t_type
)
448 # if step was provided, it was already negated above
449 start
, stop
= stop
, start
451 ptr_type
= slice_base
.type
452 if ptr_type
.is_array
:
453 ptr_type
= ptr_type
.element_ptr_type()
454 carray_ptr
= slice_base
.coerce_to_simple(self
.current_env())
456 if start
and start
.constant_result
!= 0:
457 start_ptr_node
= ExprNodes
.AddNode(
464 start_ptr_node
= carray_ptr
466 if stop
and stop
.constant_result
!= 0:
467 stop_ptr_node
= ExprNodes
.AddNode(
469 operand1
=ExprNodes
.CloneNode(carray_ptr
),
473 ).coerce_to_simple(self
.current_env())
475 stop_ptr_node
= ExprNodes
.CloneNode(carray_ptr
)
477 counter
= UtilNodes
.TempHandle(ptr_type
)
478 counter_temp
= counter
.ref(node
.target
.pos
)
480 if slice_base
.type.is_string
and node
.target
.type.is_pyobject
:
481 # special case: char* -> bytes/unicode
482 if slice_node
.type is Builtin
.unicode_type
:
483 target_value
= ExprNodes
.CastNode(
484 ExprNodes
.DereferenceNode(
485 node
.target
.pos
, operand
=counter_temp
,
486 type=ptr_type
.base_type
),
487 PyrexTypes
.c_py_ucs4_type
).coerce_to(
488 node
.target
.type, self
.current_env())
490 # char* -> bytes coercion requires slicing, not indexing
491 target_value
= ExprNodes
.SliceIndexNode(
493 start
=ExprNodes
.IntNode(node
.target
.pos
, value
='0',
495 type=PyrexTypes
.c_int_type
),
496 stop
=ExprNodes
.IntNode(node
.target
.pos
, value
='1',
498 type=PyrexTypes
.c_int_type
),
500 type=Builtin
.bytes_type
,
502 elif node
.target
.type.is_ptr
and not node
.target
.type.assignable_from(ptr_type
.base_type
):
503 # Allow iteration with pointer target to avoid copy.
504 target_value
= counter_temp
506 # TODO: can this safely be replaced with DereferenceNode() as above?
507 target_value
= ExprNodes
.IndexNode(
509 index
=ExprNodes
.IntNode(node
.target
.pos
, value
='0',
511 type=PyrexTypes
.c_int_type
),
513 is_buffer_access
=False,
514 type=ptr_type
.base_type
)
516 if target_value
.type != node
.target
.type:
517 target_value
= target_value
.coerce_to(node
.target
.type,
520 target_assign
= Nodes
.SingleAssignmentNode(
521 pos
= node
.target
.pos
,
525 body
= Nodes
.StatListNode(
527 stats
= [target_assign
, node
.body
])
529 relation1
, relation2
= self
._find
_for
_from
_node
_relations
(neg_step
, reversed)
531 for_node
= Nodes
.ForFromStatNode(
533 bound1
=start_ptr_node
, relation1
=relation1
,
535 relation2
=relation2
, bound2
=stop_ptr_node
,
536 step
=step
, body
=body
,
537 else_clause
=node
.else_clause
,
540 return UtilNodes
.TempsBlockNode(
541 node
.pos
, temps
=[counter
],
544 def _transform_enumerate_iteration(self
, node
, enumerate_function
):
545 args
= enumerate_function
.arg_tuple
.args
547 error(enumerate_function
.pos
,
548 "enumerate() requires an iterable argument")
551 error(enumerate_function
.pos
,
552 "enumerate() takes at most 2 arguments")
555 if not node
.target
.is_sequence_constructor
:
556 # leave this untouched for now
558 targets
= node
.target
.args
559 if len(targets
) != 2:
560 # leave this untouched for now
563 enumerate_target
, iterable_target
= targets
564 counter_type
= enumerate_target
.type
566 if not counter_type
.is_pyobject
and not counter_type
.is_int
:
567 # nothing we can do here, I guess
571 start
= unwrap_coerced_node(args
[1]).coerce_to(counter_type
, self
.current_env())
573 start
= ExprNodes
.IntNode(enumerate_function
.pos
,
577 temp
= UtilNodes
.LetRefNode(start
)
579 inc_expression
= ExprNodes
.AddNode(
580 enumerate_function
.pos
,
582 operand2
= ExprNodes
.IntNode(node
.pos
, value
='1',
587 #inplace = True, # not worth using in-place operation for Py ints
588 is_temp
= counter_type
.is_pyobject
592 Nodes
.SingleAssignmentNode(
593 pos
= enumerate_target
.pos
,
594 lhs
= enumerate_target
,
596 Nodes
.SingleAssignmentNode(
597 pos
= enumerate_target
.pos
,
599 rhs
= inc_expression
)
602 if isinstance(node
.body
, Nodes
.StatListNode
):
603 node
.body
.stats
= loop_body
+ node
.body
.stats
605 loop_body
.append(node
.body
)
606 node
.body
= Nodes
.StatListNode(
610 node
.target
= iterable_target
611 node
.item
= node
.item
.coerce_to(iterable_target
.type, self
.current_env())
612 node
.iterator
.sequence
= args
[0]
614 # recurse into loop to check for further optimisations
615 return UtilNodes
.LetNode(temp
, self
._optimise
_for
_loop
(node
, node
.iterator
.sequence
))
617 def _find_for_from_node_relations(self
, neg_step_value
, reversed):
629 def _transform_range_iteration(self
, node
, range_function
, reversed=False):
630 args
= range_function
.arg_tuple
.args
632 step_pos
= range_function
.pos
634 step
= ExprNodes
.IntNode(step_pos
, value
='1',
639 if not isinstance(step
.constant_result
, (int, long)):
640 # cannot determine step direction
642 step_value
= step
.constant_result
644 # will lead to an error elsewhere
646 if reversed and step_value
not in (1, -1):
647 # FIXME: currently broken - requires calculation of the correct bounds
649 if not isinstance(step
, ExprNodes
.IntNode
):
650 step
= ExprNodes
.IntNode(step_pos
, value
=str(step_value
),
651 constant_result
=step_value
)
654 bound1
= ExprNodes
.IntNode(range_function
.pos
, value
='0',
656 bound2
= args
[0].coerce_to_integer(self
.current_env())
658 bound1
= args
[0].coerce_to_integer(self
.current_env())
659 bound2
= args
[1].coerce_to_integer(self
.current_env())
661 relation1
, relation2
= self
._find
_for
_from
_node
_relations
(step_value
< 0, reversed)
664 bound1
, bound2
= bound2
, bound1
666 step_value
= -step_value
669 step_value
= -step_value
671 step
.value
= str(step_value
)
672 step
.constant_result
= step_value
673 step
= step
.coerce_to_integer(self
.current_env())
675 if not bound2
.is_literal
:
676 # stop bound must be immutable => keep it in a temp var
677 bound2_is_temp
= True
678 bound2
= UtilNodes
.LetRefNode(bound2
)
680 bound2_is_temp
= False
682 for_node
= Nodes
.ForFromStatNode(
685 bound1
=bound1
, relation1
=relation1
,
686 relation2
=relation2
, bound2
=bound2
,
687 step
=step
, body
=node
.body
,
688 else_clause
=node
.else_clause
,
692 for_node
= UtilNodes
.LetNode(bound2
, for_node
)
696 def _transform_dict_iteration(self
, node
, dict_obj
, method
, keys
, values
):
698 temp
= UtilNodes
.TempHandle(PyrexTypes
.py_object_type
)
700 dict_temp
= temp
.ref(dict_obj
.pos
)
701 temp
= UtilNodes
.TempHandle(PyrexTypes
.c_py_ssize_t_type
)
703 pos_temp
= temp
.ref(node
.pos
)
705 key_target
= value_target
= tuple_target
= None
707 if node
.target
.is_sequence_constructor
:
708 if len(node
.target
.args
) == 2:
709 key_target
, value_target
= node
.target
.args
711 # unusual case that may or may not lead to an error
714 tuple_target
= node
.target
716 key_target
= node
.target
718 value_target
= node
.target
720 if isinstance(node
.body
, Nodes
.StatListNode
):
723 body
= Nodes
.StatListNode(pos
= node
.body
.pos
,
726 # keep original length to guard against dict modification
727 dict_len_temp
= UtilNodes
.TempHandle(PyrexTypes
.c_py_ssize_t_type
)
728 temps
.append(dict_len_temp
)
729 dict_len_temp_addr
= ExprNodes
.AmpersandNode(
730 node
.pos
, operand
=dict_len_temp
.ref(dict_obj
.pos
),
731 type=PyrexTypes
.c_ptr_type(dict_len_temp
.type))
732 temp
= UtilNodes
.TempHandle(PyrexTypes
.c_int_type
)
734 is_dict_temp
= temp
.ref(node
.pos
)
735 is_dict_temp_addr
= ExprNodes
.AmpersandNode(
736 node
.pos
, operand
=is_dict_temp
,
737 type=PyrexTypes
.c_ptr_type(temp
.type))
739 iter_next_node
= Nodes
.DictIterationNextNode(
740 dict_temp
, dict_len_temp
.ref(dict_obj
.pos
), pos_temp
,
741 key_target
, value_target
, tuple_target
,
743 iter_next_node
= iter_next_node
.analyse_expressions(self
.current_env())
744 body
.stats
[0:0] = [iter_next_node
]
747 method_node
= ExprNodes
.StringNode(
748 dict_obj
.pos
, is_identifier
=True, value
=method
)
749 dict_obj
= dict_obj
.as_none_safe_node(
750 "'NoneType' object has no attribute '%s'",
751 error
= "PyExc_AttributeError",
752 format_args
= [method
])
754 method_node
= ExprNodes
.NullNode(dict_obj
.pos
)
755 dict_obj
= dict_obj
.as_none_safe_node("'NoneType' object is not iterable")
757 def flag_node(value
):
758 value
= value
and 1 or 0
759 return ExprNodes
.IntNode(node
.pos
, value
=str(value
), constant_result
=value
)
762 Nodes
.SingleAssignmentNode(
765 rhs
= ExprNodes
.IntNode(node
.pos
, value
='0',
767 Nodes
.SingleAssignmentNode(
770 rhs
= ExprNodes
.PythonCapiCallNode(
772 "__Pyx_dict_iterator",
773 self
.PyDict_Iterator_func_type
,
774 utility_code
= UtilityCode
.load_cached("dict_iter", "Optimize.c"),
775 args
= [dict_obj
, flag_node(dict_obj
.type is Builtin
.dict_type
),
776 method_node
, dict_len_temp_addr
, is_dict_temp_addr
,
784 else_clause
= node
.else_clause
788 return UtilNodes
.TempsBlockNode(
789 node
.pos
, temps
=temps
,
790 body
=Nodes
.StatListNode(
795 PyDict_Iterator_func_type
= PyrexTypes
.CFuncType(
796 PyrexTypes
.py_object_type
, [
797 PyrexTypes
.CFuncTypeArg("dict", PyrexTypes
.py_object_type
, None),
798 PyrexTypes
.CFuncTypeArg("is_dict", PyrexTypes
.c_int_type
, None),
799 PyrexTypes
.CFuncTypeArg("method_name", PyrexTypes
.py_object_type
, None),
800 PyrexTypes
.CFuncTypeArg("p_orig_length", PyrexTypes
.c_py_ssize_t_ptr_type
, None),
801 PyrexTypes
.CFuncTypeArg("p_is_dict", PyrexTypes
.c_int_ptr_type
, None),
805 class SwitchTransform(Visitor
.VisitorTransform
):
807 This transformation tries to turn long if statements into C switch statements.
808 The requirement is that every clause be an (or of) var == value, where the var
809 is common among all clauses and both var and value are ints.
811 NO_MATCH
= (None, None, None)
813 def extract_conditions(self
, cond
, allow_not_in
):
815 if isinstance(cond
, (ExprNodes
.CoerceToTempNode
,
816 ExprNodes
.CoerceToBooleanNode
)):
818 elif isinstance(cond
, UtilNodes
.EvalWithTempExprNode
):
819 # this is what we get from the FlattenInListTransform
820 cond
= cond
.subexpression
821 elif isinstance(cond
, ExprNodes
.TypecastNode
):
826 if isinstance(cond
, ExprNodes
.PrimaryCmpNode
):
827 if cond
.cascade
is not None:
829 elif cond
.is_c_string_contains() and \
830 isinstance(cond
.operand2
, (ExprNodes
.UnicodeNode
, ExprNodes
.BytesNode
)):
831 not_in
= cond
.operator
== 'not_in'
832 if not_in
and not allow_not_in
:
834 if isinstance(cond
.operand2
, ExprNodes
.UnicodeNode
) and \
835 cond
.operand2
.contains_surrogates():
836 # dealing with surrogates leads to different
837 # behaviour on wide and narrow Unicode
838 # platforms => refuse to optimise this case
840 return not_in
, cond
.operand1
, self
.extract_in_string_conditions(cond
.operand2
)
841 elif not cond
.is_python_comparison():
842 if cond
.operator
== '==':
844 elif allow_not_in
and cond
.operator
== '!=':
848 # this looks somewhat silly, but it does the right
849 # checks for NameNode and AttributeNode
850 if is_common_value(cond
.operand1
, cond
.operand1
):
851 if cond
.operand2
.is_literal
:
852 return not_in
, cond
.operand1
, [cond
.operand2
]
853 elif getattr(cond
.operand2
, 'entry', None) \
854 and cond
.operand2
.entry
.is_const
:
855 return not_in
, cond
.operand1
, [cond
.operand2
]
856 if is_common_value(cond
.operand2
, cond
.operand2
):
857 if cond
.operand1
.is_literal
:
858 return not_in
, cond
.operand2
, [cond
.operand1
]
859 elif getattr(cond
.operand1
, 'entry', None) \
860 and cond
.operand1
.entry
.is_const
:
861 return not_in
, cond
.operand2
, [cond
.operand1
]
862 elif isinstance(cond
, ExprNodes
.BoolBinopNode
):
863 if cond
.operator
== 'or' or (allow_not_in
and cond
.operator
== 'and'):
864 allow_not_in
= (cond
.operator
== 'and')
865 not_in_1
, t1
, c1
= self
.extract_conditions(cond
.operand1
, allow_not_in
)
866 not_in_2
, t2
, c2
= self
.extract_conditions(cond
.operand2
, allow_not_in
)
867 if t1
is not None and not_in_1
== not_in_2
and is_common_value(t1
, t2
):
868 if (not not_in_1
) or allow_not_in
:
869 return not_in_1
, t1
, c1
+c2
872 def extract_in_string_conditions(self
, string_literal
):
873 if isinstance(string_literal
, ExprNodes
.UnicodeNode
):
874 charvals
= list(map(ord, set(string_literal
.value
)))
876 return [ ExprNodes
.IntNode(string_literal
.pos
, value
=str(charval
),
877 constant_result
=charval
)
878 for charval
in charvals
]
880 # this is a bit tricky as Py3's bytes type returns
881 # integers on iteration, whereas Py2 returns 1-char byte
883 characters
= string_literal
.value
884 characters
= list(set([ characters
[i
:i
+1] for i
in range(len(characters
)) ]))
886 return [ ExprNodes
.CharNode(string_literal
.pos
, value
=charval
,
887 constant_result
=charval
)
888 for charval
in characters
]
890 def extract_common_conditions(self
, common_var
, condition
, allow_not_in
):
891 not_in
, var
, conditions
= self
.extract_conditions(condition
, allow_not_in
)
894 elif common_var
is not None and not is_common_value(var
, common_var
):
896 elif not (var
.type.is_int
or var
.type.is_enum
) or sum([not (cond
.type.is_int
or cond
.type.is_enum
) for cond
in conditions
]):
898 return not_in
, var
, conditions
900 def has_duplicate_values(self
, condition_values
):
901 # duplicated values don't work in a switch statement
903 for value
in condition_values
:
904 if value
.has_constant_result():
905 if value
.constant_result
in seen
:
907 seen
.add(value
.constant_result
)
909 # this isn't completely safe as we don't know the
910 # final C value, but this is about the best we can do
912 if value
.entry
.cname
in seen
:
914 except AttributeError:
915 return True # play safe
916 seen
.add(value
.entry
.cname
)
919 def visit_IfStatNode(self
, node
):
922 for if_clause
in node
.if_clauses
:
923 _
, common_var
, conditions
= self
.extract_common_conditions(
924 common_var
, if_clause
.condition
, False)
925 if common_var
is None:
926 self
.visitchildren(node
)
928 cases
.append(Nodes
.SwitchCaseNode(pos
= if_clause
.pos
,
929 conditions
= conditions
,
930 body
= if_clause
.body
))
933 cond
for case
in cases
for cond
in case
.conditions
]
934 if len(condition_values
) < 2:
935 self
.visitchildren(node
)
937 if self
.has_duplicate_values(condition_values
):
938 self
.visitchildren(node
)
941 common_var
= unwrap_node(common_var
)
942 switch_node
= Nodes
.SwitchStatNode(pos
= node
.pos
,
945 else_clause
= node
.else_clause
)
948 def visit_CondExprNode(self
, node
):
949 not_in
, common_var
, conditions
= self
.extract_common_conditions(
950 None, node
.test
, True)
951 if common_var
is None \
952 or len(conditions
) < 2 \
953 or self
.has_duplicate_values(conditions
):
954 self
.visitchildren(node
)
956 return self
.build_simple_switch_statement(
957 node
, common_var
, conditions
, not_in
,
958 node
.true_val
, node
.false_val
)
960 def visit_BoolBinopNode(self
, node
):
961 not_in
, common_var
, conditions
= self
.extract_common_conditions(
963 if common_var
is None \
964 or len(conditions
) < 2 \
965 or self
.has_duplicate_values(conditions
):
966 self
.visitchildren(node
)
969 return self
.build_simple_switch_statement(
970 node
, common_var
, conditions
, not_in
,
971 ExprNodes
.BoolNode(node
.pos
, value
=True, constant_result
=True),
972 ExprNodes
.BoolNode(node
.pos
, value
=False, constant_result
=False))
974 def visit_PrimaryCmpNode(self
, node
):
975 not_in
, common_var
, conditions
= self
.extract_common_conditions(
977 if common_var
is None \
978 or len(conditions
) < 2 \
979 or self
.has_duplicate_values(conditions
):
980 self
.visitchildren(node
)
983 return self
.build_simple_switch_statement(
984 node
, common_var
, conditions
, not_in
,
985 ExprNodes
.BoolNode(node
.pos
, value
=True, constant_result
=True),
986 ExprNodes
.BoolNode(node
.pos
, value
=False, constant_result
=False))
988 def build_simple_switch_statement(self
, node
, common_var
, conditions
,
989 not_in
, true_val
, false_val
):
990 result_ref
= UtilNodes
.ResultRefNode(node
)
991 true_body
= Nodes
.SingleAssignmentNode(
996 false_body
= Nodes
.SingleAssignmentNode(
1003 true_body
, false_body
= false_body
, true_body
1005 cases
= [Nodes
.SwitchCaseNode(pos
= node
.pos
,
1006 conditions
= conditions
,
1009 common_var
= unwrap_node(common_var
)
1010 switch_node
= Nodes
.SwitchStatNode(pos
= node
.pos
,
1013 else_clause
= false_body
)
1014 replacement
= UtilNodes
.TempResultFromStatNode(result_ref
, switch_node
)
1017 def visit_EvalWithTempExprNode(self
, node
):
1018 # drop unused expression temp from FlattenInListTransform
1019 orig_expr
= node
.subexpression
1020 temp_ref
= node
.lazy_temp
1021 self
.visitchildren(node
)
1022 if node
.subexpression
is not orig_expr
:
1023 # node was restructured => check if temp is still used
1024 if not Visitor
.tree_contains(node
.subexpression
, temp_ref
):
1025 return node
.subexpression
1028 visit_Node
= Visitor
.VisitorTransform
.recurse_to_children
1031 class FlattenInListTransform(Visitor
.VisitorTransform
, SkipDeclarations
):
1033 This transformation flattens "x in [val1, ..., valn]" into a sequential list
1037 def visit_PrimaryCmpNode(self
, node
):
1038 self
.visitchildren(node
)
1039 if node
.cascade
is not None:
1041 elif node
.operator
== 'in':
1044 elif node
.operator
== 'not_in':
1050 if not isinstance(node
.operand2
, (ExprNodes
.TupleNode
,
1052 ExprNodes
.SetNode
)):
1055 args
= node
.operand2
.args
1057 # note: lhs may have side effects
1060 lhs
= UtilNodes
.ResultRefNode(node
.operand1
)
1066 # Trial optimisation to avoid redundant temp
1067 # assignments. However, since is_simple() is meant to
1068 # be called after type analysis, we ignore any errors
1069 # and just play safe in that case.
1070 is_simple_arg
= arg
.is_simple()
1072 is_simple_arg
= False
1073 if not is_simple_arg
:
1074 # must evaluate all non-simple RHS before doing the comparisons
1075 arg
= UtilNodes
.LetRefNode(arg
)
1077 cond
= ExprNodes
.PrimaryCmpNode(
1080 operator
= eq_or_neq
,
1083 conds
.append(ExprNodes
.TypecastNode(
1086 type = PyrexTypes
.c_bint_type
))
1087 def concat(left
, right
):
1088 return ExprNodes
.BoolBinopNode(
1090 operator
= conjunction
,
1094 condition
= reduce(concat
, conds
)
1095 new_node
= UtilNodes
.EvalWithTempExprNode(lhs
, condition
)
1096 for temp
in temps
[::-1]:
1097 new_node
= UtilNodes
.EvalWithTempExprNode(temp
, new_node
)
1100 visit_Node
= Visitor
.VisitorTransform
.recurse_to_children
1103 class DropRefcountingTransform(Visitor
.VisitorTransform
):
1104 """Drop ref-counting in safe places.
1106 visit_Node
= Visitor
.VisitorTransform
.recurse_to_children
1108 def visit_ParallelAssignmentNode(self
, node
):
1110 Parallel swap assignments like 'a,b = b,a' are safe.
1112 left_names
, right_names
= [], []
1113 left_indices
, right_indices
= [], []
1116 for stat
in node
.stats
:
1117 if isinstance(stat
, Nodes
.SingleAssignmentNode
):
1118 if not self
._extract
_operand
(stat
.lhs
, left_names
,
1119 left_indices
, temps
):
1121 if not self
._extract
_operand
(stat
.rhs
, right_names
,
1122 right_indices
, temps
):
1124 elif isinstance(stat
, Nodes
.CascadedAssignmentNode
):
1130 if left_names
or right_names
:
1131 # lhs/rhs names must be a non-redundant permutation
1132 lnames
= [ path
for path
, n
in left_names
]
1133 rnames
= [ path
for path
, n
in right_names
]
1134 if set(lnames
) != set(rnames
):
1136 if len(set(lnames
)) != len(right_names
):
1139 if left_indices
or right_indices
:
1140 # base name and index of index nodes must be a
1141 # non-redundant permutation
1143 for lhs_node
in left_indices
:
1144 index_id
= self
._extract
_index
_id
(lhs_node
)
1147 lindices
.append(index_id
)
1149 for rhs_node
in right_indices
:
1150 index_id
= self
._extract
_index
_id
(rhs_node
)
1153 rindices
.append(index_id
)
1155 if set(lindices
) != set(rindices
):
1157 if len(set(lindices
)) != len(right_indices
):
1160 # really supporting IndexNode requires support in
1161 # __Pyx_GetItemInt(), so let's stop short for now
1164 temp_args
= [t
.arg
for t
in temps
]
1166 temp
.use_managed_ref
= False
1168 for _
, name_node
in left_names
+ right_names
:
1169 if name_node
not in temp_args
:
1170 name_node
.use_managed_ref
= False
1172 for index_node
in left_indices
+ right_indices
:
1173 index_node
.use_managed_ref
= False
1177 def _extract_operand(self
, node
, names
, indices
, temps
):
1178 node
= unwrap_node(node
)
1179 if not node
.type.is_pyobject
:
1181 if isinstance(node
, ExprNodes
.CoerceToTempNode
):
1186 while isinstance(obj_node
, ExprNodes
.AttributeNode
):
1187 if obj_node
.is_py_attr
:
1189 name_path
.append(obj_node
.member
)
1190 obj_node
= obj_node
.obj
1191 if isinstance(obj_node
, ExprNodes
.NameNode
):
1192 name_path
.append(obj_node
.name
)
1193 names
.append( ('.'.join(name_path
[::-1]), node
) )
1194 elif isinstance(node
, ExprNodes
.IndexNode
):
1195 if node
.base
.type != Builtin
.list_type
:
1197 if not node
.index
.type.is_int
:
1199 if not isinstance(node
.base
, ExprNodes
.NameNode
):
1201 indices
.append(node
)
1206 def _extract_index_id(self
, index_node
):
1207 base
= index_node
.base
1208 index
= index_node
.index
1209 if isinstance(index
, ExprNodes
.NameNode
):
1210 index_val
= index
.name
1211 elif isinstance(index
, ExprNodes
.ConstNode
):
1216 return (base
.name
, index_val
)
1219 class EarlyReplaceBuiltinCalls(Visitor
.EnvTransform
):
1220 """Optimize some common calls to builtin types *before* the type
1221 analysis phase and *after* the declarations analysis phase.
1223 This transform cannot make use of any argument types, but it can
1224 restructure the tree in a way that the type analysis phase can
1227 Introducing C function calls here may not be a good idea. Move
1228 them to the OptimizeBuiltinCalls transform instead, which runs
1229 after type analysis.
1231 # only intercept on call nodes
1232 visit_Node
= Visitor
.VisitorTransform
.recurse_to_children
1234 def visit_SimpleCallNode(self
, node
):
1235 self
.visitchildren(node
)
1236 function
= node
.function
1237 if not self
._function
_is
_builtin
_name
(function
):
1239 return self
._dispatch
_to
_handler
(node
, function
, node
.args
)
1241 def visit_GeneralCallNode(self
, node
):
1242 self
.visitchildren(node
)
1243 function
= node
.function
1244 if not self
._function
_is
_builtin
_name
(function
):
1246 arg_tuple
= node
.positional_args
1247 if not isinstance(arg_tuple
, ExprNodes
.TupleNode
):
1249 args
= arg_tuple
.args
1250 return self
._dispatch
_to
_handler
(
1251 node
, function
, args
, node
.keyword_args
)
1253 def _function_is_builtin_name(self
, function
):
1254 if not function
.is_name
:
1256 env
= self
.current_env()
1257 entry
= env
.lookup(function
.name
)
1258 if entry
is not env
.builtin_scope().lookup_here(function
.name
):
1260 # if entry is None, it's at least an undeclared name, so likely builtin
1263 def _dispatch_to_handler(self
, node
, function
, args
, kwargs
=None):
1265 handler_name
= '_handle_simple_function_%s' % function
.name
1267 handler_name
= '_handle_general_function_%s' % function
.name
1268 handle_call
= getattr(self
, handler_name
, None)
1269 if handle_call
is not None:
1271 return handle_call(node
, args
)
1273 return handle_call(node
, args
, kwargs
)
def _inject_capi_function(self, node, cname, func_type, utility_code=None):
    """Swap the call target of ``node`` for a C-API function node.

    A ``PythonCapiFunctionNode`` is created from the position and name of
    the current ``node.function`` plus the given ``cname``, ``func_type``
    and optional ``utility_code``, and is stored back on ``node.function``.
    The call node is modified in place; nothing is returned.
    """
    py_function = node.function
    capi_function = ExprNodes.PythonCapiFunctionNode(
        py_function.pos, py_function.name, cname, func_type,
        utility_code=utility_code)
    node.function = capi_function
1281 def _error_wrong_arg_count(self
, function_name
, node
, args
, expected
=None):
1282 if not expected
: # None or 0
1284 elif isinstance(expected
, basestring
) or expected
> 1:
1290 if expected
is not None:
1291 expected_str
= 'expected %s, ' % expected
1294 error(node
.pos
, "%s(%s) called with wrong number of args, %sfound %d" % (
1295 function_name
, arg_str
, expected_str
, len(args
)))
1297 # specific handlers for simple call nodes
1299 def _handle_simple_function_float(self
, node
, pos_args
):
1301 return ExprNodes
.FloatNode(node
.pos
, value
='0.0')
1302 if len(pos_args
) > 1:
1303 self
._error
_wrong
_arg
_count
('float', node
, pos_args
, 1)
1304 arg_type
= getattr(pos_args
[0], 'type', None)
1305 if arg_type
in (PyrexTypes
.c_double_type
, Builtin
.float_type
):
1309 class YieldNodeCollector(Visitor
.TreeVisitor
):
1311 Visitor
.TreeVisitor
.__init
__(self
)
1312 self
.yield_stat_nodes
= {}
1313 self
.yield_nodes
= []
1315 visit_Node
= Visitor
.TreeVisitor
.visitchildren
1316 # XXX: disable inlining while it's not back supported
1317 def __visit_YieldExprNode(self
, node
):
1318 self
.yield_nodes
.append(node
)
1319 self
.visitchildren(node
)
1321 def __visit_ExprStatNode(self
, node
):
1322 self
.visitchildren(node
)
1323 if node
.expr
in self
.yield_nodes
:
1324 self
.yield_stat_nodes
[node
.expr
] = node
1326 def __visit_GeneratorExpressionNode(self
, node
):
1327 # enable when we support generic generator expressions
1329 # everything below this node is out of scope
1332 def _find_single_yield_expression(self
, node
):
1333 collector
= self
.YieldNodeCollector()
1334 collector
.visitchildren(node
)
1335 if len(collector
.yield_nodes
) != 1:
1337 yield_node
= collector
.yield_nodes
[0]
1339 return (yield_node
.arg
, collector
.yield_stat_nodes
[yield_node
])
1343 def _handle_simple_function_all(self
, node
, pos_args
):
1346 _result = all(x for L in LL for x in L)
1361 return self
._transform
_any
_all
(node
, pos_args
, False)
1363 def _handle_simple_function_any(self
, node
, pos_args
):
1366 _result = any(x for L in LL for x in L)
1381 return self
._transform
_any
_all
(node
, pos_args
, True)
1383 def _transform_any_all(self
, node
, pos_args
, is_any
):
1384 if len(pos_args
) != 1:
1386 if not isinstance(pos_args
[0], ExprNodes
.GeneratorExpressionNode
):
1388 gen_expr_node
= pos_args
[0]
1389 loop_node
= gen_expr_node
.loop
1390 yield_expression
, yield_stat_node
= self
._find
_single
_yield
_expression
(loop_node
)
1391 if yield_expression
is None:
1395 condition
= yield_expression
1397 condition
= ExprNodes
.NotNode(yield_expression
.pos
, operand
= yield_expression
)
1399 result_ref
= UtilNodes
.ResultRefNode(pos
=node
.pos
, type=PyrexTypes
.c_bint_type
)
1400 test_node
= Nodes
.IfStatNode(
1401 yield_expression
.pos
,
1403 if_clauses
= [ Nodes
.IfClauseNode(
1404 yield_expression
.pos
,
1405 condition
= condition
,
1406 body
= Nodes
.StatListNode(
1409 Nodes
.SingleAssignmentNode(
1412 rhs
= ExprNodes
.BoolNode(yield_expression
.pos
, value
= is_any
,
1413 constant_result
= is_any
)),
1414 Nodes
.BreakStatNode(node
.pos
)
1418 while isinstance(loop
.body
, Nodes
.LoopNode
):
1419 next_loop
= loop
.body
1420 loop
.body
= Nodes
.StatListNode(loop
.body
.pos
, stats
= [
1422 Nodes
.BreakStatNode(yield_expression
.pos
)
1424 next_loop
.else_clause
= Nodes
.ContinueStatNode(yield_expression
.pos
)
1426 loop_node
.else_clause
= Nodes
.SingleAssignmentNode(
1429 rhs
= ExprNodes
.BoolNode(yield_expression
.pos
, value
= not is_any
,
1430 constant_result
= not is_any
))
1432 Visitor
.recursively_replace_node(loop_node
, yield_stat_node
, test_node
)
1434 return ExprNodes
.InlinedGeneratorExpressionNode(
1435 gen_expr_node
.pos
, loop
= loop_node
, result_node
= result_ref
,
1436 expr_scope
= gen_expr_node
.expr_scope
, orig_func
= is_any
and 'any' or 'all')
1438 def _handle_simple_function_sorted(self
, node
, pos_args
):
1439 """Transform sorted(genexpr) and sorted([listcomp]) into
1440 [listcomp].sort(). CPython just reads the iterable into a
1441 list and calls .sort() on it. Expanding the iterable in a
1442 listcomp is still faster and the result can be sorted in
1445 if len(pos_args
) != 1:
1447 if isinstance(pos_args
[0], ExprNodes
.ComprehensionNode
) \
1448 and pos_args
[0].type is Builtin
.list_type
:
1449 listcomp_node
= pos_args
[0]
1450 loop_node
= listcomp_node
.loop
1451 elif isinstance(pos_args
[0], ExprNodes
.GeneratorExpressionNode
):
1452 gen_expr_node
= pos_args
[0]
1453 loop_node
= gen_expr_node
.loop
1454 yield_expression
, yield_stat_node
= self
._find
_single
_yield
_expression
(loop_node
)
1455 if yield_expression
is None:
1458 append_node
= ExprNodes
.ComprehensionAppendNode(
1459 yield_expression
.pos
, expr
= yield_expression
)
1461 Visitor
.recursively_replace_node(loop_node
, yield_stat_node
, append_node
)
1463 listcomp_node
= ExprNodes
.ComprehensionNode(
1464 gen_expr_node
.pos
, loop
= loop_node
,
1465 append
= append_node
, type = Builtin
.list_type
,
1466 expr_scope
= gen_expr_node
.expr_scope
,
1467 has_local_scope
= True)
1468 append_node
.target
= listcomp_node
1472 result_node
= UtilNodes
.ResultRefNode(
1473 pos
= loop_node
.pos
, type = Builtin
.list_type
, may_hold_none
=False)
1474 listcomp_assign_node
= Nodes
.SingleAssignmentNode(
1475 node
.pos
, lhs
= result_node
, rhs
= listcomp_node
, first
= True)
1477 sort_method
= ExprNodes
.AttributeNode(
1478 node
.pos
, obj
= result_node
, attribute
= EncodedString('sort'),
1480 needs_none_check
= False)
1481 sort_node
= Nodes
.ExprStatNode(
1482 node
.pos
, expr
= ExprNodes
.SimpleCallNode(
1483 node
.pos
, function
= sort_method
, args
= []))
1485 sort_node
.analyse_declarations(self
.current_env())
1487 return UtilNodes
.TempResultFromStatNode(
1489 Nodes
.StatListNode(node
.pos
, stats
= [ listcomp_assign_node
, sort_node
]))
1491 def _handle_simple_function_sum(self
, node
, pos_args
):
1492 """Transform sum(genexpr) into an equivalent inlined aggregation loop.
1494 if len(pos_args
) not in (1,2):
1496 if not isinstance(pos_args
[0], (ExprNodes
.GeneratorExpressionNode
,
1497 ExprNodes
.ComprehensionNode
)):
1499 gen_expr_node
= pos_args
[0]
1500 loop_node
= gen_expr_node
.loop
1502 if isinstance(gen_expr_node
, ExprNodes
.GeneratorExpressionNode
):
1503 yield_expression
, yield_stat_node
= self
._find
_single
_yield
_expression
(loop_node
)
1504 if yield_expression
is None:
1506 else: # ComprehensionNode
1507 yield_stat_node
= gen_expr_node
.append
1508 yield_expression
= yield_stat_node
.expr
1510 if not yield_expression
.is_literal
or not yield_expression
.type.is_int
:
1512 except AttributeError:
1513 return node
# in case we don't have a type yet
1514 # special case: old Py2 backwards compatible "sum([int_const for ...])"
1515 # can safely be unpacked into a genexpr
1517 if len(pos_args
) == 1:
1518 start
= ExprNodes
.IntNode(node
.pos
, value
='0', constant_result
=0)
1522 result_ref
= UtilNodes
.ResultRefNode(pos
=node
.pos
, type=PyrexTypes
.py_object_type
)
1523 add_node
= Nodes
.SingleAssignmentNode(
1524 yield_expression
.pos
,
1526 rhs
= ExprNodes
.binop_node(node
.pos
, '+', result_ref
, yield_expression
)
1529 Visitor
.recursively_replace_node(loop_node
, yield_stat_node
, add_node
)
1531 exec_code
= Nodes
.StatListNode(
1534 Nodes
.SingleAssignmentNode(
1536 lhs
= UtilNodes
.ResultRefNode(pos
=node
.pos
, expression
=result_ref
),
1542 return ExprNodes
.InlinedGeneratorExpressionNode(
1543 gen_expr_node
.pos
, loop
= exec_code
, result_node
= result_ref
,
1544 expr_scope
= gen_expr_node
.expr_scope
, orig_func
= 'sum',
1545 has_local_scope
= gen_expr_node
.has_local_scope
)
1547 def _handle_simple_function_min(self
, node
, pos_args
):
1548 return self
._optimise
_min
_max
(node
, pos_args
, '<')
1550 def _handle_simple_function_max(self
, node
, pos_args
):
1551 return self
._optimise
_min
_max
(node
, pos_args
, '>')
1553 def _optimise_min_max(self
, node
, args
, operator
):
1554 """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
1557 if len(args
) == 1 and args
[0].is_sequence_constructor
:
1560 # leave this to Python
1563 cascaded_nodes
= list(map(UtilNodes
.ResultRefNode
, args
[1:]))
1565 last_result
= args
[0]
1566 for arg_node
in cascaded_nodes
:
1567 result_ref
= UtilNodes
.ResultRefNode(last_result
)
1568 last_result
= ExprNodes
.CondExprNode(
1570 true_val
= arg_node
,
1571 false_val
= result_ref
,
1572 test
= ExprNodes
.PrimaryCmpNode(
1574 operand1
= arg_node
,
1575 operator
= operator
,
1576 operand2
= result_ref
,
1579 last_result
= UtilNodes
.EvalWithTempExprNode(result_ref
, last_result
)
1581 for ref_node
in cascaded_nodes
[::-1]:
1582 last_result
= UtilNodes
.EvalWithTempExprNode(ref_node
, last_result
)
1586 def _DISABLED_handle_simple_function_tuple(self
, node
, pos_args
):
1588 return ExprNodes
.TupleNode(node
.pos
, args
=[], constant_result
=())
1589 # This is a bit special - for iterables (including genexps),
1590 # Python actually overallocates and resizes a newly created
1591 # tuple incrementally while reading items, which we can't
1592 # easily do without explicit node support. Instead, we read
1593 # the items into a list and then copy them into a tuple of the
1594 # final size. This takes up to twice as much memory, but will
1595 # have to do until we have real support for genexps.
1596 result
= self
._transform
_list
_set
_genexpr
(node
, pos_args
, Builtin
.list_type
)
1597 if result
is not node
:
1598 return ExprNodes
.AsTupleNode(node
.pos
, arg
=result
)
1601 def _handle_simple_function_frozenset(self
, node
, pos_args
):
1602 """Replace frozenset([...]) by frozenset((...)) as tuples are more efficient.
1604 if len(pos_args
) != 1:
1606 if pos_args
[0].is_sequence_constructor
and not pos_args
[0].args
:
1608 elif isinstance(pos_args
[0], ExprNodes
.ListNode
):
1609 pos_args
[0] = pos_args
[0].as_tuple()
1612 def _handle_simple_function_list(self
, node
, pos_args
):
1614 return ExprNodes
.ListNode(node
.pos
, args
=[], constant_result
=[])
1615 return self
._transform
_list
_set
_genexpr
(node
, pos_args
, Builtin
.list_type
)
1617 def _handle_simple_function_set(self
, node
, pos_args
):
1619 return ExprNodes
.SetNode(node
.pos
, args
=[], constant_result
=set())
1620 return self
._transform
_list
_set
_genexpr
(node
, pos_args
, Builtin
.set_type
)
1622 def _transform_list_set_genexpr(self
, node
, pos_args
, target_type
):
1623 """Replace set(genexpr) and list(genexpr) by a literal comprehension.
1625 if len(pos_args
) > 1:
1627 if not isinstance(pos_args
[0], ExprNodes
.GeneratorExpressionNode
):
1629 gen_expr_node
= pos_args
[0]
1630 loop_node
= gen_expr_node
.loop
1632 yield_expression
, yield_stat_node
= self
._find
_single
_yield
_expression
(loop_node
)
1633 if yield_expression
is None:
1636 append_node
= ExprNodes
.ComprehensionAppendNode(
1637 yield_expression
.pos
,
1638 expr
= yield_expression
)
1640 Visitor
.recursively_replace_node(loop_node
, yield_stat_node
, append_node
)
1642 comp
= ExprNodes
.ComprehensionNode(
1644 has_local_scope
= True,
1645 expr_scope
= gen_expr_node
.expr_scope
,
1647 append
= append_node
,
1649 append_node
.target
= comp
1652 def _handle_simple_function_dict(self
, node
, pos_args
):
1653 """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.
1655 if len(pos_args
) == 0:
1656 return ExprNodes
.DictNode(node
.pos
, key_value_pairs
=[], constant_result
={})
1657 if len(pos_args
) > 1:
1659 if not isinstance(pos_args
[0], ExprNodes
.GeneratorExpressionNode
):
1661 gen_expr_node
= pos_args
[0]
1662 loop_node
= gen_expr_node
.loop
1664 yield_expression
, yield_stat_node
= self
._find
_single
_yield
_expression
(loop_node
)
1665 if yield_expression
is None:
1668 if not isinstance(yield_expression
, ExprNodes
.TupleNode
):
1670 if len(yield_expression
.args
) != 2:
1673 append_node
= ExprNodes
.DictComprehensionAppendNode(
1674 yield_expression
.pos
,
1675 key_expr
= yield_expression
.args
[0],
1676 value_expr
= yield_expression
.args
[1])
1678 Visitor
.recursively_replace_node(loop_node
, yield_stat_node
, append_node
)
1680 dictcomp
= ExprNodes
.ComprehensionNode(
1682 has_local_scope
= True,
1683 expr_scope
= gen_expr_node
.expr_scope
,
1685 append
= append_node
,
1686 type = Builtin
.dict_type
)
1687 append_node
.target
= dictcomp
1690 # specific handlers for general call nodes
1692 def _handle_general_function_dict(self
, node
, pos_args
, kwargs
):
1693 """Replace dict(a=b,c=d,...) by the underlying keyword dict
1694 construction which is done anyway.
1696 if len(pos_args
) > 0:
1698 if not isinstance(kwargs
, ExprNodes
.DictNode
):
1703 class InlineDefNodeCalls(Visitor
.NodeRefCleanupMixin
, Visitor
.EnvTransform
):
1704 visit_Node
= Visitor
.VisitorTransform
.recurse_to_children
1706 def get_constant_value_node(self
, name_node
):
1707 if name_node
.cf_state
is None:
1709 if name_node
.cf_state
.cf_is_null
:
1711 entry
= self
.current_env().lookup(name_node
.name
)
1712 if not entry
or (not entry
.cf_assignments
1713 or len(entry
.cf_assignments
) != 1):
1714 # not just a single assignment in all closures
1716 return entry
.cf_assignments
[0].rhs
1718 def visit_SimpleCallNode(self
, node
):
1719 self
.visitchildren(node
)
1720 if not self
.current_directives
.get('optimize.inline_defnode_calls'):
1722 function_name
= node
.function
1723 if not function_name
.is_name
:
1725 function
= self
.get_constant_value_node(function_name
)
1726 if not isinstance(function
, ExprNodes
.PyCFunctionNode
):
1728 inlined
= ExprNodes
.InlinedDefNodeCallNode(
1729 node
.pos
, function_name
=function_name
,
1730 function
=function
, args
=node
.args
)
1731 if inlined
.can_be_inlined():
1732 return self
.replace(node
, inlined
)
1736 class OptimizeBuiltinCalls(Visitor
.MethodDispatcherTransform
):
1737 """Optimize some common method calls and instantiation patterns
1738 for builtin types *after* the type analysis phase.
1740 Running after type analysis, this transform can only perform
1741 function replacements that do not alter the function return type
1742 in a way that was not anticipated by the type analysis.
1744 ### cleanup to avoid redundant coercions to/from Python types
1746 def _visit_PyTypeTestNode(self
, node
):
1747 # disabled - appears to break assignments in some cases, and
1748 # also drops a None check, which might still be required
1749 """Flatten redundant type checks after tree changes.
1752 self
.visitchildren(node
)
1753 if old_arg
is node
.arg
or node
.arg
.type != node
.type:
1757 def _visit_TypecastNode(self
, node
):
1758 # disabled - the user may have had a reason to put a type
1759 # cast, even if it looks redundant to Cython
1761 Drop redundant type casts.
1763 self
.visitchildren(node
)
1764 if node
.type == node
.operand
.type:
1768 def visit_ExprStatNode(self
, node
):
1770 Drop useless coercions.
1772 self
.visitchildren(node
)
1773 if isinstance(node
.expr
, ExprNodes
.CoerceToPyTypeNode
):
1774 node
.expr
= node
.expr
.arg
1777 def visit_CoerceToBooleanNode(self
, node
):
1778 """Drop redundant conversion nodes after tree changes.
1780 self
.visitchildren(node
)
1782 if isinstance(arg
, ExprNodes
.PyTypeTestNode
):
1784 if isinstance(arg
, ExprNodes
.CoerceToPyTypeNode
):
1785 if arg
.type in (PyrexTypes
.py_object_type
, Builtin
.bool_type
):
1786 return arg
.arg
.coerce_to_boolean(self
.current_env())
1789 def visit_CoerceFromPyTypeNode(self
, node
):
1790 """Drop redundant conversion nodes after tree changes.
1792 Also, optimise away calls to Python's builtin int() and
1793 float() if the result is going to be coerced back into a C
1796 self
.visitchildren(node
)
1798 if not arg
.type.is_pyobject
:
1799 # no Python conversion left at all, just do a C coercion instead
1800 if node
.type == arg
.type:
1803 return arg
.coerce_to(node
.type, self
.current_env())
1804 if isinstance(arg
, ExprNodes
.PyTypeTestNode
):
1807 if (node
.type.is_int
and isinstance(arg
, ExprNodes
.IntNode
) or
1808 node
.type.is_float
and isinstance(arg
, ExprNodes
.FloatNode
) or
1809 node
.type.is_int
and isinstance(arg
, ExprNodes
.BoolNode
)):
1810 return arg
.coerce_to(node
.type, self
.current_env())
1811 elif isinstance(arg
, ExprNodes
.CoerceToPyTypeNode
):
1812 if arg
.type is PyrexTypes
.py_object_type
:
1813 if node
.type.assignable_from(arg
.arg
.type):
1814 # completely redundant C->Py->C coercion
1815 return arg
.arg
.coerce_to(node
.type, self
.current_env())
1816 elif isinstance(arg
, ExprNodes
.SimpleCallNode
):
1817 if node
.type.is_int
or node
.type.is_float
:
1818 return self
._optimise
_numeric
_cast
_call
(node
, arg
)
1819 elif isinstance(arg
, ExprNodes
.IndexNode
) and not arg
.is_buffer_access
:
1820 index_node
= arg
.index
1821 if isinstance(index_node
, ExprNodes
.CoerceToPyTypeNode
):
1822 index_node
= index_node
.arg
1823 if index_node
.type.is_int
:
1824 return self
._optimise
_int
_indexing
(node
, arg
, index_node
)
1827 PyBytes_GetItemInt_func_type
= PyrexTypes
.CFuncType(
1828 PyrexTypes
.c_char_type
, [
1829 PyrexTypes
.CFuncTypeArg("bytes", Builtin
.bytes_type
, None),
1830 PyrexTypes
.CFuncTypeArg("index", PyrexTypes
.c_py_ssize_t_type
, None),
1831 PyrexTypes
.CFuncTypeArg("check_bounds", PyrexTypes
.c_int_type
, None),
1833 exception_value
= "((char)-1)",
1834 exception_check
= True)
1836 def _optimise_int_indexing(self
, coerce_node
, arg
, index_node
):
1837 env
= self
.current_env()
1838 bound_check_bool
= env
.directives
['boundscheck'] and 1 or 0
1839 if arg
.base
.type is Builtin
.bytes_type
:
1840 if coerce_node
.type in (PyrexTypes
.c_char_type
, PyrexTypes
.c_uchar_type
):
1841 # bytes[index] -> char
1842 bound_check_node
= ExprNodes
.IntNode(
1843 coerce_node
.pos
, value
=str(bound_check_bool
),
1844 constant_result
=bound_check_bool
)
1845 node
= ExprNodes
.PythonCapiCallNode(
1846 coerce_node
.pos
, "__Pyx_PyBytes_GetItemInt",
1847 self
.PyBytes_GetItemInt_func_type
,
1849 arg
.base
.as_none_safe_node("'NoneType' object is not subscriptable"),
1850 index_node
.coerce_to(PyrexTypes
.c_py_ssize_t_type
, env
),
1854 utility_code
=UtilityCode
.load_cached(
1855 'bytes_index', 'StringTools.c'))
1856 if coerce_node
.type is not PyrexTypes
.c_char_type
:
1857 node
= node
.coerce_to(coerce_node
.type, env
)
1861 def _optimise_numeric_cast_call(self
, node
, arg
):
1862 function
= arg
.function
1863 if not isinstance(function
, ExprNodes
.NameNode
) \
1864 or not function
.type.is_builtin_type \
1865 or not isinstance(arg
.arg_tuple
, ExprNodes
.TupleNode
):
1867 args
= arg
.arg_tuple
.args
1871 if isinstance(func_arg
, ExprNodes
.CoerceToPyTypeNode
):
1872 func_arg
= func_arg
.arg
1873 elif func_arg
.type.is_pyobject
:
1874 # play safe: Python conversion might work on all sorts of things
1876 if function
.name
== 'int':
1877 if func_arg
.type.is_int
or node
.type.is_int
:
1878 if func_arg
.type == node
.type:
1880 elif node
.type.assignable_from(func_arg
.type) or func_arg
.type.is_float
:
1881 return ExprNodes
.TypecastNode(
1882 node
.pos
, operand
=func_arg
, type=node
.type)
1883 elif function
.name
== 'float':
1884 if func_arg
.type.is_float
or node
.type.is_float
:
1885 if func_arg
.type == node
.type:
1887 elif node
.type.assignable_from(func_arg
.type) or func_arg
.type.is_float
:
1888 return ExprNodes
.TypecastNode(
1889 node
.pos
, operand
=func_arg
, type=node
.type)
1892 def _error_wrong_arg_count(self
, function_name
, node
, args
, expected
=None):
1893 if not expected
: # None or 0
1895 elif isinstance(expected
, basestring
) or expected
> 1:
1901 if expected
is not None:
1902 expected_str
= 'expected %s, ' % expected
1905 error(node
.pos
, "%s(%s) called with wrong number of args, %sfound %d" % (
1906 function_name
, arg_str
, expected_str
, len(args
)))
1908 ### generic fallbacks
1910 def _handle_function(self
, node
, function_name
, function
, arg_list
, kwargs
):
1913 def _handle_method(self
, node
, type_name
, attr_name
, function
,
1914 arg_list
, is_unbound_method
, kwargs
):
1916 Try to inject C-API calls for unbound method calls to builtin types.
1917 While the method declarations in Builtin.py already handle this, we
1918 can additionally resolve bound and unbound methods here that were
1919 assigned to variables ahead of time.
1923 if not function
or not function
.is_attribute
or not function
.obj
.is_name
:
1924 # cannot track unbound method calls over more than one indirection as
1925 # the names might have been reassigned in the meantime
1927 type_entry
= self
.current_env().lookup(type_name
)
1930 method
= ExprNodes
.AttributeNode(
1932 obj
=ExprNodes
.NameNode(
1936 type=type_entry
.type),
1937 attribute
=attr_name
,
1938 is_called
=True).analyse_as_unbound_cmethod_node(self
.current_env())
1942 if args
is None and node
.arg_tuple
:
1943 args
= node
.arg_tuple
.args
1944 call_node
= ExprNodes
.SimpleCallNode(
1948 if not is_unbound_method
:
1949 call_node
.self
= function
.obj
1950 call_node
.analyse_c_function_call(self
.current_env())
1951 call_node
.analysed
= True
1952 return call_node
.coerce_to(node
.type, self
.current_env())
1956 PyDict_Copy_func_type
= PyrexTypes
.CFuncType(
1957 Builtin
.dict_type
, [
1958 PyrexTypes
.CFuncTypeArg("dict", Builtin
.dict_type
, None)
1961 def _handle_simple_function_dict(self
, node
, function
, pos_args
):
1962 """Replace dict(some_dict) by PyDict_Copy(some_dict).
1964 if len(pos_args
) != 1:
1967 if arg
.type is Builtin
.dict_type
:
1968 arg
= arg
.as_none_safe_node("'NoneType' is not iterable")
1969 return ExprNodes
.PythonCapiCallNode(
1970 node
.pos
, "PyDict_Copy", self
.PyDict_Copy_func_type
,
1972 is_temp
= node
.is_temp
1976 PyList_AsTuple_func_type
= PyrexTypes
.CFuncType(
1977 Builtin
.tuple_type
, [
1978 PyrexTypes
.CFuncTypeArg("list", Builtin
.list_type
, None)
1981 def _handle_simple_function_tuple(self
, node
, function
, pos_args
):
1982 """Replace tuple([...]) by a call to PyList_AsTuple.
1984 if len(pos_args
) != 1:
1987 if arg
.type is Builtin
.tuple_type
and not arg
.may_be_none():
1989 if arg
.type is not Builtin
.list_type
:
1991 pos_args
[0] = arg
.as_none_safe_node(
1992 "'NoneType' object is not iterable")
1994 return ExprNodes
.PythonCapiCallNode(
1995 node
.pos
, "PyList_AsTuple", self
.PyList_AsTuple_func_type
,
1997 is_temp
= node
.is_temp
2000 PySet_New_func_type
= PyrexTypes
.CFuncType(
2002 PyrexTypes
.CFuncTypeArg("it", PyrexTypes
.py_object_type
, None)
2005 def _handle_simple_function_set(self
, node
, function
, pos_args
):
2006 if len(pos_args
) != 1:
2008 if pos_args
[0].is_sequence_constructor
:
2009 # We can optimise set([x,y,z]) safely into a set literal,
2010 # but only if we create all items before adding them -
2011 # adding an item may raise an exception if it is not
2012 # hashable, but creating the later items may have
2016 for arg
in pos_args
[0].args
:
2017 if not arg
.is_simple():
2018 arg
= UtilNodes
.LetRefNode(arg
)
2021 result
= ExprNodes
.SetNode(node
.pos
, is_temp
=1, args
=args
)
2022 for temp
in temps
[::-1]:
2023 result
= UtilNodes
.EvalWithTempExprNode(temp
, result
)
2026 # PySet_New(it) is better than a generic Python call to set(it)
2027 return ExprNodes
.PythonCapiCallNode(
2028 node
.pos
, "PySet_New",
2029 self
.PySet_New_func_type
,
2031 is_temp
=node
.is_temp
,
2032 utility_code
=UtilityCode
.load_cached('pyset_compat', 'Builtins.c'),
2035 PyFrozenSet_New_func_type
= PyrexTypes
.CFuncType(
2036 Builtin
.frozenset_type
, [
2037 PyrexTypes
.CFuncTypeArg("it", PyrexTypes
.py_object_type
, None)
2040 def _handle_simple_function_frozenset(self
, node
, function
, pos_args
):
2042 pos_args
= [ExprNodes
.NullNode(node
.pos
)]
2043 elif len(pos_args
) > 1:
2045 elif pos_args
[0].type is Builtin
.frozenset_type
and not pos_args
[0].may_be_none():
2047 # PyFrozenSet_New(it) is better than a generic Python call to frozenset(it)
2048 return ExprNodes
.PythonCapiCallNode(
2049 node
.pos
, "__Pyx_PyFrozenSet_New",
2050 self
.PyFrozenSet_New_func_type
,
2052 is_temp
=node
.is_temp
,
2053 utility_code
=UtilityCode
.load_cached('pyfrozenset_new', 'Builtins.c'),
2054 py_name
="frozenset")
2056 PyObject_AsDouble_func_type
= PyrexTypes
.CFuncType(
2057 PyrexTypes
.c_double_type
, [
2058 PyrexTypes
.CFuncTypeArg("obj", PyrexTypes
.py_object_type
, None),
2060 exception_value
= "((double)-1)",
2061 exception_check
= True)
2063 def _handle_simple_function_float(self
, node
, function
, pos_args
):
2064 """Transform float() into either a C type cast or a faster C
2067 # Note: this requires the float() function to be typed as
2068 # returning a C 'double'
2069 if len(pos_args
) == 0:
2070 return ExprNodes
.FloatNode(
2071 node
, value
="0.0", constant_result
=0.0
2072 ).coerce_to(Builtin
.float_type
, self
.current_env())
2073 elif len(pos_args
) != 1:
2074 self
._error
_wrong
_arg
_count
('float', node
, pos_args
, '0 or 1')
2076 func_arg
= pos_args
[0]
2077 if isinstance(func_arg
, ExprNodes
.CoerceToPyTypeNode
):
2078 func_arg
= func_arg
.arg
2079 if func_arg
.type is PyrexTypes
.c_double_type
:
2081 elif node
.type.assignable_from(func_arg
.type) or func_arg
.type.is_numeric
:
2082 return ExprNodes
.TypecastNode(
2083 node
.pos
, operand
=func_arg
, type=node
.type)
2084 return ExprNodes
.PythonCapiCallNode(
2085 node
.pos
, "__Pyx_PyObject_AsDouble",
2086 self
.PyObject_AsDouble_func_type
,
2088 is_temp
= node
.is_temp
,
2089 utility_code
= load_c_utility('pyobject_as_double'),
2092 PyNumber_Int_func_type
= PyrexTypes
.CFuncType(
2093 PyrexTypes
.py_object_type
, [
2094 PyrexTypes
.CFuncTypeArg("o", PyrexTypes
.py_object_type
, None)
2097 def _handle_simple_function_int(self
, node
, function
, pos_args
):
2098 """Transform int() into a faster C function call.
2100 if len(pos_args
) == 0:
2101 return ExprNodes
.IntNode(node
, value
="0", constant_result
=0,
2102 type=PyrexTypes
.py_object_type
)
2103 elif len(pos_args
) != 1:
2104 return node
# int(x, base)
2105 func_arg
= pos_args
[0]
2106 if isinstance(func_arg
, ExprNodes
.CoerceToPyTypeNode
):
2107 return node
# handled in visit_CoerceFromPyTypeNode()
2108 if func_arg
.type.is_pyobject
and node
.type.is_pyobject
:
2109 return ExprNodes
.PythonCapiCallNode(
2110 node
.pos
, "PyNumber_Int", self
.PyNumber_Int_func_type
,
2111 args
=pos_args
, is_temp
=True)
def _handle_simple_function_bool(self, node, function, pos_args):
    """Transform bool(x) into a type coercion to a boolean.

    ``bool()`` becomes a constant False node coerced to the builtin bool
    type.  ``bool(x)`` becomes a double negation of the boolean coercion
    of ``x``, coerced back to a Python object.  Any other argument count
    is reported as an error and the call is left unchanged.
    """
    arg_count = len(pos_args)
    if arg_count == 0:
        false_node = ExprNodes.BoolNode(
            node.pos, value=False, constant_result=False)
        return false_node.coerce_to(Builtin.bool_type, self.current_env())
    if arg_count != 1:
        self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
        return node
    # => !!<bint>(x)  to make sure it's exactly 0 or 1
    truth_value = pos_args[0].coerce_to_boolean(self.current_env())
    negated = ExprNodes.NotNode(node.pos, operand=truth_value)
    normalised = ExprNodes.NotNode(node.pos, operand=negated)
    # coerce back to Python object as that's the result we are expecting
    return normalised.coerce_to_pyobject(self.current_env())
### builtin functions

# Signature for the strlen() call: C char pointer in, size_t out.
Pyx_strlen_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_size_t_type,
    [PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)])

# Signature for the __Pyx_Py_UNICODE_strlen() call: Py_UNICODE pointer in,
# size_t out.
Pyx_Py_UNICODE_strlen_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_size_t_type,
    [PyrexTypes.CFuncTypeArg("unicode", PyrexTypes.c_py_unicode_ptr_type, None)])

# Signature shared by the object length functions: Python object in,
# Py_ssize_t out, with -1 as the exception value.
PyObject_Size_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_py_ssize_t_type,
    [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
    exception_value="-1")

# Lookup from builtin type to the name of its length function; called
# like a function and yields None for types that have no entry.
_map_to_capi_len_function = {
    Builtin.unicode_type:   "__Pyx_PyUnicode_GET_LENGTH",
    Builtin.bytes_type:     "PyBytes_GET_SIZE",
    Builtin.list_type:      "PyList_GET_SIZE",
    Builtin.tuple_type:     "PyTuple_GET_SIZE",
    Builtin.dict_type:      "PyDict_Size",
    Builtin.set_type:       "PySet_Size",
    Builtin.frozenset_type: "PySet_Size",
}.get

# Qualified names of types whose length is read with Py_SIZE().
_ext_types_with_pysize = set(["cpython.array.array"])
def _handle_simple_function_len(self, node, function, pos_args):
    """Replace len(char*) by the equivalent call to strlen(),
    len(Py_UNICODE) by the equivalent Py_UNICODE_strlen() and
    len(known_builtin_type) by an equivalent C-API call.

    A single unicode character argument folds to the constant 1.
    Arguments of any other type leave the call unchanged.
    """
    if len(pos_args) != 1:
        self._error_wrong_arg_count('len', node, pos_args, 1)
        return node

    arg = pos_args[0]
    if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
        # look through the coercion at the underlying value
        arg = arg.arg
    arg_type = arg.type

    if arg_type.is_string:
        new_node = ExprNodes.PythonCapiCallNode(
            node.pos, "strlen", self.Pyx_strlen_func_type,
            args=[arg],
            is_temp=node.is_temp,
            utility_code=UtilityCode.load_cached("IncludeStringH", "StringTools.c"))
    elif arg_type.is_pyunicode_ptr:
        new_node = ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_Py_UNICODE_strlen", self.Pyx_Py_UNICODE_strlen_func_type,
            args=[arg],
            is_temp=node.is_temp)
    elif arg_type.is_pyobject:
        cfunc_name = self._map_to_capi_len_function(arg_type)
        if cfunc_name is None:
            uses_py_size = (
                (arg_type.is_extension_type or arg_type.is_builtin_type)
                and arg_type.entry.qualified_name in self._ext_types_with_pysize)
            if not uses_py_size:
                return node
            cfunc_name = 'Py_SIZE'
        arg = arg.as_none_safe_node(
            "object of type 'NoneType' has no len()")
        new_node = ExprNodes.PythonCapiCallNode(
            node.pos, cfunc_name, self.PyObject_Size_func_type,
            args=[arg],
            is_temp=node.is_temp)
    elif arg_type.is_unicode_char:
        return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
                                 type=node.type)
    else:
        return node

    # the C calls return size_t/Py_ssize_t; convert if the call site
    # expects something else
    if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
        new_node = new_node.coerce_to(node.type, self.current_env())
    return new_node
# Signature for the Py_TYPE() call: Python object in, builtin 'type' out.
Pyx_Type_func_type = PyrexTypes.CFuncType(
    Builtin.type_type,
    [PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)])
def _handle_simple_function_type(self, node, function, pos_args):
    """Replace type(o) by a macro call to Py_TYPE(o).

    Only the one-argument form is handled; other call shapes are
    returned unchanged.  The macro result is cast to a plain Python
    object.
    """
    if len(pos_args) != 1:
        return node
    type_call = ExprNodes.PythonCapiCallNode(
        node.pos, "Py_TYPE", self.Pyx_Type_func_type,
        args=pos_args,
        is_temp=False)
    return ExprNodes.CastNode(type_call, PyrexTypes.py_object_type)
# Signature used for the type check calls built by the isinstance()
# handler: Python object in, C boolean (bint) out.
Py_type_check_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_bint_type,
    [PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)])
def _handle_simple_function_isinstance(self, node, function, pos_args):
    """Replace isinstance() checks against builtin types by the
    corresponding C-API call.

    Accepts a single type or a tuple of types as second argument.  Each
    type is turned into one C-level check; the checks are joined with
    'or'.  If any type cannot be handled, the call is left unchanged.
    """
    if len(pos_args) != 2:
        return node
    arg, types = pos_args
    temp = None
    if isinstance(types, ExprNodes.TupleNode):
        types = types.args
        # the argument is used once per check, so wrap anything that is
        # not a simple value in a result reference
        if arg.is_attribute or not arg.is_simple():
            arg = temp = UtilNodes.ResultRefNode(arg)
    elif types.type is Builtin.type_type:
        types = [types]
    else:
        return node

    env = self.current_env()
    seen_check_functions = []
    check_nodes = []
    for type_node in types:
        known_type = None
        if type_node.is_name and type_node.entry:
            entry = env.lookup(type_node.entry.name)
            if entry and entry.type and entry.type.is_builtin_type:
                known_type = entry.type
        if known_type is Builtin.type_type:
            # all types have type "type", but there's only one 'type'
            is_builtin_type_name = (
                entry.name == 'type'
                and entry.scope and entry.scope.is_builtin_scope)
            if not is_builtin_type_name:
                known_type = None

        if known_type is not None:
            check_function = entry.type.type_check_function(exact=False)
            if check_function in seen_check_functions:
                # the same check was already emitted for an earlier type
                continue
            seen_check_functions.append(check_function)
            check_args = [arg]
        elif type_node.type is Builtin.type_type:
            check_function = '__Pyx_TypeCheck'
            check_args = [arg, type_node]
        else:
            return node
        check_nodes.append(
            ExprNodes.PythonCapiCallNode(
                type_node.pos, check_function, self.Py_type_check_func_type,
                args=check_args,
                is_temp=True,
            ))

    def join_with_or(a, b, make_binop_node=ExprNodes.binop_node):
        or_node = make_binop_node(node.pos, 'or', a, b)
        or_node.type = PyrexTypes.c_bint_type
        or_node.is_temp = True
        return or_node

    combined = reduce(join_with_or, check_nodes).coerce_to(node.type, env)
    if temp is not None:
        combined = UtilNodes.EvalWithTempExprNode(temp, combined)
    return combined
def _handle_simple_function_ord(self, node, function, pos_args):
    """Unpack ord(Py_UNICODE) and ord('X').

    A coerced unicode character is cast to a C int; a one-character
    unicode literal, or a one-character string literal with a code point
    of at most 255, is folded into an integer constant.  Everything else
    is returned unchanged.
    """
    if len(pos_args) != 1:
        return node
    arg = pos_args[0]
    env = self.current_env

    if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
        if arg.arg.type.is_unicode_char:
            cast = ExprNodes.TypecastNode(
                arg.pos, operand=arg.arg, type=PyrexTypes.c_int_type)
            return cast.coerce_to(node.type, env())
    elif isinstance(arg, ExprNodes.UnicodeNode):
        if len(arg.value) == 1:
            code_point = ord(arg.value)
            constant = ExprNodes.IntNode(
                arg.pos, type=PyrexTypes.c_int_type,
                value=str(code_point),
                constant_result=code_point)
            return constant.coerce_to(node.type, env())
    elif isinstance(arg, ExprNodes.StringNode):
        text = arg.unicode_value
        # Py2/3 portability: only fold characters that fit into a byte
        if text and len(text) == 1 and ord(text) <= 255:
            code_point = ord(text)
            constant = ExprNodes.IntNode(
                arg.pos, type=PyrexTypes.c_int_type,
                value=str(code_point),
                constant_result=code_point)
            return constant.coerce_to(node.type, env())
    return node
# Signature for the __Pyx_tp_new() call: (type object, args tuple) -> object
Pyx_tp_new_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [
        PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None),
    ])

# Signature for the __Pyx_tp_new_kwargs() call:
# (type object, args tuple, kwargs dict) -> object
Pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [
        PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None),
        PyrexTypes.CFuncTypeArg("kwargs", Builtin.dict_type, None),
    ])
def _handle_any_slot__new__(self, node, function, args,
                            is_unbound_method, kwargs=None):
    """Replace 'exttype.__new__(exttype, ...)' by a call to exttype->tp_new()

    Only applies to unbound calls where both the object the method is
    looked up on and the first argument are names known to be types and
    to refer to the same type.  For an extension type of the current
    module that has a tp_new slot function, that function is called
    directly; otherwise a generic tp_new helper is used.
    """
    obj = function.obj
    if not is_unbound_method or len(args) < 1:
        return node
    type_arg = args[0]
    if not (obj.is_name and type_arg.is_name):
        # play safe
        return node
    if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
        # not a known type, play safe
        return node
    if not type_arg.type_entry or not obj.type_entry:
        if obj.name != type_arg.name:
            return node
        # otherwise, we know it's a type and we know it's the same
        # type for both - that should do
    elif type_arg.type_entry != obj.type_entry:
        # different types - may or may not lead to an error at runtime
        return node

    env = self.current_env()
    args_tuple = ExprNodes.TupleNode(node.pos, args=args[1:])
    args_tuple = args_tuple.analyse_types(env, skip_children=True)

    if type_arg.type_entry:
        ext_type = type_arg.type_entry.type
        defined_in_this_module = (
            ext_type.is_extension_type and ext_type.typeobj_cname and
            ext_type.scope.global_scope() == self.current_env().global_scope())
        if defined_in_this_module:
            # known type in current module
            tp_slot = TypeSlots.ConstructorSlot("tp_new", '__new__')
            slot_func_cname = TypeSlots.get_slot_function(ext_type.scope, tp_slot)
            if slot_func_cname:
                cython_scope = self.context.cython_scope
                PyTypeObjectPtr = PyrexTypes.CPtrType(
                    cython_scope.lookup('PyTypeObject').type)
                pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
                    PyrexTypes.py_object_type,
                    [
                        PyrexTypes.CFuncTypeArg("type", PyTypeObjectPtr, None),
                        PyrexTypes.CFuncTypeArg("args", PyrexTypes.py_object_type, None),
                        PyrexTypes.CFuncTypeArg("kwargs", PyrexTypes.py_object_type, None),
                    ])

                type_arg = ExprNodes.CastNode(type_arg, PyTypeObjectPtr)
                if not kwargs:
                    kwargs = ExprNodes.NullNode(node.pos, type=PyrexTypes.py_object_type)  # hack?
                return ExprNodes.PythonCapiCallNode(
                    node.pos, slot_func_cname,
                    pyx_tp_new_kwargs_func_type,
                    args=[type_arg, args_tuple, kwargs],
                    is_temp=True)
    else:
        # arbitrary variable, needs a None check for safety
        type_arg = type_arg.as_none_safe_node(
            "object.__new__(X): X is not a type object (NoneType)")

    utility_code = UtilityCode.load_cached('tp_new', 'ObjectHandling.c')
    if kwargs:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_tp_new_kwargs", self.Pyx_tp_new_kwargs_func_type,
            args=[type_arg, args_tuple, kwargs],
            utility_code=utility_code,
            is_temp=node.is_temp
        )
    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
        args=[type_arg, args_tuple],
        utility_code=utility_code,
        is_temp=node.is_temp
    )
### methods of builtin types

# Signature for the __Pyx_PyObject_Append() call: (list, item) -> C return
# code, with -1 as the exception value.
PyObject_Append_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_returncode_type,
    [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
    ],
    exception_value="-1")
def _handle_simple_method_object_append(self, node, function, args, is_unbound_method):
    """Optimistic optimisation as X.append() is almost always
    referring to a list.

    Only rewrites two-argument calls (object plus item) whose result is
    not used; the replacement is a call to __Pyx_PyObject_Append().
    """
    if node.result_is_used or len(args) != 2:
        return node

    append_utility = load_c_utility('append')
    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
        args=args,
        may_return_none=False,
        is_temp=node.is_temp,
        result_is_used=False,
        utility_code=append_utility,
    )
# Signature for __Pyx_PyByteArray_Append(): (bytearray, C int value) ->
# C return code, with -1 as the exception value.
PyByteArray_Append_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_returncode_type,
    [
        PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_int_type, None),
    ],
    exception_value="-1")

# Signature for __Pyx_PyByteArray_AppendObject(): same as above, but the
# value is passed as a Python object.
PyByteArray_AppendObject_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_returncode_type,
    [
        PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("value", PyrexTypes.py_object_type, None),
    ],
    exception_value="-1")
def _handle_simple_method_bytearray_append(self, node, function, args, is_unbound_method):
    """Replace bytearray.append(value) by a C helper call.

    Integer values and string literals that can be coerced to a char
    literal use __Pyx_PyByteArray_Append(); other Python objects use
    __Pyx_PyByteArray_AppendObject().  Anything else is left unchanged.
    """
    if len(args) != 2:
        return node

    env = self.current_env
    value = unwrap_coerced_node(args[1])
    if value.type.is_int or isinstance(value, ExprNodes.IntNode):
        func_name = "__Pyx_PyByteArray_Append"
        func_type = self.PyByteArray_Append_func_type
        value = value.coerce_to(PyrexTypes.c_int_type, env())
        utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
    elif value.is_string_literal:
        if not value.can_coerce_to_char_literal():
            return node
        func_name = "__Pyx_PyByteArray_Append"
        func_type = self.PyByteArray_Append_func_type
        value = value.coerce_to(PyrexTypes.c_char_type, env())
        utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
    elif value.type.is_pyobject:
        func_name = "__Pyx_PyByteArray_AppendObject"
        func_type = self.PyByteArray_AppendObject_func_type
        utility_code = UtilityCode.load_cached("ByteArrayAppendObject", "StringTools.c")
    else:
        return node

    append_call = ExprNodes.PythonCapiCallNode(
        node.pos, func_name, func_type,
        args=[args[0], value],
        may_return_none=False,
        is_temp=node.is_temp,
        utility_code=utility_code,
    )
    if node.result_is_used:
        append_call = append_call.coerce_to(node.type, env())
    return append_call
# Signature for the __Pyx_Py*_Pop() calls: (list) -> object
PyObject_Pop_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None)])

# Signature for the __Pyx_Py*_PopIndex() calls: (list, C long index) -> object
PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None),
    ])
def _handle_simple_method_list_pop(self, node, function, args, is_unbound_method):
    """Optimise list.pop() via the generic pop handler with is_list=True."""
    handle_pop = self._handle_simple_method_object_pop
    return handle_pop(node, function, args, is_unbound_method, is_list=True)
def _handle_simple_method_object_pop(self, node, function, args, is_unbound_method, is_list=False):
    """Optimistic optimisation as X.pop([n]) is almost always
    referring to a list.

    ``X.pop()`` becomes a __Pyx_Py{List,Object}_Pop() call.  ``X.pop(n)``
    becomes a __Pyx_Py{List,Object}_PopIndex() call when the index is an
    integer no wider than Py_ssize_t.  Other calls are left unchanged.
    """
    if not args:
        return node
    args = args[:]  # work on a copy, the list is modified below
    if is_list:
        type_name = 'List'
        args[0] = args[0].as_none_safe_node(
            "'NoneType' object has no attribute '%s'",
            error="PyExc_AttributeError",
            format_args=['pop'])
    else:
        type_name = 'Object'

    arg_count = len(args)
    if arg_count == 1:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_Py%s_Pop" % type_name,
            self.PyObject_Pop_func_type,
            args=args,
            may_return_none=True,
            is_temp=node.is_temp,
            utility_code=load_c_utility('pop'),
        )
    if arg_count == 2:
        ssize_t = PyrexTypes.c_py_ssize_t_type
        index = unwrap_coerced_node(args[1])
        if is_list or isinstance(index, ExprNodes.IntNode):
            index = index.coerce_to(ssize_t, self.current_env())
        if index.type.is_int:
            widest = PyrexTypes.widest_numeric_type(index.type, ssize_t)
            if widest == ssize_t:
                args[1] = index
                return ExprNodes.PythonCapiCallNode(
                    node.pos, "__Pyx_Py%s_PopIndex" % type_name,
                    self.PyObject_PopIndex_func_type,
                    args=args,
                    may_return_none=True,
                    is_temp=node.is_temp,
                    utility_code=load_c_utility("pop_index"),
                )

    return node
# Generic signature: one Python object in, C return code out, with -1 as
# the exception value (used for the PyList_Sort() call).
single_param_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_returncode_type,
    [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
    exception_value="-1")
def _handle_simple_method_list_sort(self, node, function, args, is_unbound_method):
    """Call PyList_Sort() instead of the 0-argument l.sort().

    Calls that pass anything besides the list itself are returned
    unchanged.  The C call result is coerced to the type of the
    original call node.
    """
    if len(args) != 1:
        return node
    sort_call = self._substitute_method_call(
        node, function, "PyList_Sort", self.single_param_func_type,
        'sort', is_unbound_method, args)
    # coerce_to() needs the environment itself, so current_env must be
    # called rather than passed as a bound method.
    return sort_call.coerce_to(node.type, self.current_env())
# Signature for the __Pyx_PyDict_GetItemDefault() call:
# (dict, key, default) -> object
Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [
        PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
    ])
def _handle_simple_method_dict_get(self, node, function, args, is_unbound_method):
    """Replace dict.get() by a call to PyDict_GetItem().

    A missing default argument is filled in with None.  Note that
    ``args`` is extended in place in that case.
    """
    arg_count = len(args)
    if arg_count == 2:
        args.append(ExprNodes.NoneNode(node.pos))
    elif arg_count != 3:
        self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
        return node

    getitem_utility = load_c_utility("dict_getitem_default")
    return self._substitute_method_call(
        node, function,
        "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
        'get', is_unbound_method, args,
        may_return_none=True,
        utility_code=getitem_utility)
# Signature for the __Pyx_PyDict_SetDefault() call:
# (dict, key, default, C int is_safe_type) -> object
Pyx_PyDict_SetDefault_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [
        PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("is_safe_type", PyrexTypes.c_int_type, None),
    ])
def _handle_simple_method_dict_setdefault(self, node, function, args, is_unbound_method):
    """Replace dict.setdefault() by calls to PyDict_GetItem() and PyDict_SetItem().

    A missing default argument is filled in with None, and an extra
    ``is_safe_type`` integer argument is appended for the C helper:
    1 if the key has one of the listed builtin types, 0 if it has any
    other known type, -1 if the key is a generic Python object.
    Note that ``args`` is extended in place.
    """
    if len(args) == 2:
        args.append(ExprNodes.NoneNode(node.pos))
    elif len(args) != 3:
        self._error_wrong_arg_count('dict.setdefault', node, args, "2 or 3")
        return node
    key_type = args[1].type
    if key_type.is_builtin_type:
        # Compare whole type names; a substring test against one string
        # would also accept fragments such as 'in' or 'loat'.
        is_safe_type = int(key_type.name in (
            'str', 'bytes', 'unicode', 'float', 'int', 'long', 'bool'))
    elif key_type is PyrexTypes.py_object_type:
        is_safe_type = -1  # don't know
    else:
        is_safe_type = 0   # definitely not
    args.append(ExprNodes.IntNode(
        node.pos, value=str(is_safe_type), constant_result=is_safe_type))

    return self._substitute_method_call(
        node, function,
        "__Pyx_PyDict_SetDefault", self.Pyx_PyDict_SetDefault_func_type,
        'setdefault', is_unbound_method, args,
        may_return_none=True,
        utility_code=load_c_utility('dict_setdefault'))
### unicode type methods

# Signature for the single-character predicate calls: Py_UCS4 in, bint out.
PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_bint_type,
    [PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None)])
def _inject_unicode_predicate(self, node, function, args, is_unbound_method):
    """Replace a unicode predicate method on a single character
    (isalnum(), isdigit(), ...) by the corresponding C macro call.

    Only bound calls on a unicode character that was coerced to a Python
    object are handled; everything else is returned unchanged.  The
    macro name is derived from the method name, except for istitle(),
    which uses a helper from the utility code.
    """
    if is_unbound_method or len(args) != 1:
        return node
    ustring = args[0]
    if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
            not ustring.arg.type.is_unicode_char:
        return node
    uchar = ustring.arg
    method_name = function.attribute
    if method_name == 'istitle':
        # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
        utility_code = UtilityCode.load_cached(
            "py_unicode_istitle", "StringTools.c")
        function_name = '__Pyx_Py_UNICODE_ISTITLE'
    else:
        utility_code = None
        function_name = 'Py_UNICODE_%s' % method_name.upper()
    func_call = self._substitute_method_call(
        node, function,
        function_name, self.PyUnicode_uchar_predicate_func_type,
        method_name, is_unbound_method, [uchar],
        utility_code=utility_code)
    if node.type.is_pyobject:
        # pass the environment, not the bound current_env method
        func_call = func_call.coerce_to_pyobject(self.current_env())
    return func_call
# All single-character unicode predicates share one implementation; the
# method name is read from the call node at optimisation time.
_handle_simple_method_unicode_isalnum   = _inject_unicode_predicate
_handle_simple_method_unicode_isalpha   = _inject_unicode_predicate
_handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
_handle_simple_method_unicode_isdigit   = _inject_unicode_predicate
_handle_simple_method_unicode_islower   = _inject_unicode_predicate
_handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
_handle_simple_method_unicode_isspace   = _inject_unicode_predicate
_handle_simple_method_unicode_istitle   = _inject_unicode_predicate
_handle_simple_method_unicode_isupper   = _inject_unicode_predicate
# Signature for the single-character conversion calls: Py_UCS4 in, Py_UCS4 out.
PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_py_ucs4_type,
    [PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None)])
def _inject_unicode_character_conversion(self, node, function, args, is_unbound_method):
    """Replace lower()/upper()/title() on a single unicode character by
    the corresponding Py_UNICODE_TO*() C macro call.

    Only bound calls on a unicode character that was coerced to a Python
    object are handled; everything else is returned unchanged.
    """
    if is_unbound_method or len(args) != 1:
        return node
    ustring = args[0]
    if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
            not ustring.arg.type.is_unicode_char:
        return node
    uchar = ustring.arg
    method_name = function.attribute
    function_name = 'Py_UNICODE_TO%s' % method_name.upper()
    func_call = self._substitute_method_call(
        node, function,
        function_name, self.PyUnicode_uchar_conversion_func_type,
        method_name, is_unbound_method, [uchar])
    if node.type.is_pyobject:
        # pass the environment, not the bound current_env method
        func_call = func_call.coerce_to_pyobject(self.current_env())
    return func_call
# The single-character case conversions share one implementation.
_handle_simple_method_unicode_lower = _inject_unicode_character_conversion
_handle_simple_method_unicode_upper = _inject_unicode_character_conversion
_handle_simple_method_unicode_title = _inject_unicode_character_conversion
# Signature for the PyUnicode_Splitlines() call:
# (unicode str, bint keepends) -> list
PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
    Builtin.list_type,
    [
        PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
    ])
def _handle_simple_method_unicode_splitlines(self, node, function, args, is_unbound_method):
    """Replace unicode.splitlines(...) by a direct call to the
    corresponding C-API function.

    A missing ``keepends`` argument defaults to False.
    """
    if not 1 <= len(args) <= 2:
        self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
        return node
    self._inject_bint_default_argument(node, args, 1, False)

    splitlines_type = self.PyUnicode_Splitlines_func_type
    return self._substitute_method_call(
        node, function,
        "PyUnicode_Splitlines", splitlines_type,
        'splitlines', is_unbound_method, args)
# Signature for the PyUnicode_Split() call:
# (unicode str, sep object, Py_ssize_t maxsplit) -> list
PyUnicode_Split_func_type = PyrexTypes.CFuncType(
    Builtin.list_type,
    [
        PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
    ])
def _handle_simple_method_unicode_split(self, node, function, args, is_unbound_method):
    """Replace unicode.split(...) by a direct call to the
    corresponding C-API function.

    A missing separator is passed as NULL and a missing ``maxsplit`` as
    -1.  Note that ``args`` is extended in place.
    """
    if not 1 <= len(args) <= 3:
        self._error_wrong_arg_count('unicode.split', node, args, "1-3")
        return node
    if len(args) < 2:
        args.append(ExprNodes.NullNode(node.pos))
    self._inject_int_default_argument(
        node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")

    split_type = self.PyUnicode_Split_func_type
    return self._substitute_method_call(
        node, function,
        "PyUnicode_Split", split_type,
        'split', is_unbound_method, args)
# Signature for the __Pyx_Py*_Tailmatch() calls:
# (str, substring, Py_ssize_t start, Py_ssize_t end, int direction) -> bint,
# with -1 as the exception value.
PyString_Tailmatch_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_bint_type,
    [
        PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None),  # bytes/str/unicode
        PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
    ],
    exception_value='-1')
def _handle_simple_method_unicode_endswith(self, node, function, args, is_unbound_method):
    """Optimise unicode.endswith() as a tail match with direction +1."""
    direction = +1
    return self._inject_tailmatch(
        node, function, args, is_unbound_method, 'unicode', 'endswith',
        unicode_tailmatch_utility_code, direction)
def _handle_simple_method_unicode_startswith(self, node, function, args, is_unbound_method):
    """Optimise unicode.startswith() as a tail match with direction -1."""
    direction = -1
    return self._inject_tailmatch(
        node, function, args, is_unbound_method, 'unicode', 'startswith',
        unicode_tailmatch_utility_code, direction)
def _inject_tailmatch(self, node, function, args, is_unbound_method, type_name,
                      method_name, utility_code, direction):
    """Replace unicode.startswith(...) and unicode.endswith(...)
    by a direct call to the corresponding C-API function.

    Missing ``start``/``end`` arguments default to 0 and PY_SSIZE_T_MAX,
    and ``direction`` is appended as a C int argument.  The C result is
    coerced to the builtin bool type.  Note that ``args`` is extended
    in place.
    """
    if len(args) not in (2, 3, 4):
        self._error_wrong_arg_count(
            '%s.%s' % (type_name, method_name), node, args, "2-4")
        return node

    ssize_t = PyrexTypes.c_py_ssize_t_type
    self._inject_int_default_argument(node, args, 2, ssize_t, "0")
    self._inject_int_default_argument(node, args, 3, ssize_t, "PY_SSIZE_T_MAX")
    direction_node = ExprNodes.IntNode(
        node.pos, value=str(direction), type=PyrexTypes.c_int_type)
    args.append(direction_node)

    cfunc_name = "__Pyx_Py%s_Tailmatch" % type_name.capitalize()
    method_call = self._substitute_method_call(
        node, function,
        cfunc_name,
        self.PyString_Tailmatch_func_type,
        method_name, is_unbound_method, args,
        utility_code=utility_code)
    return method_call.coerce_to(Builtin.bool_type, self.current_env())
# Signature for the PyUnicode_Find() call:
# (unicode str, substring, Py_ssize_t start, Py_ssize_t end, int direction)
# -> Py_ssize_t, with -2 as the exception value.
PyUnicode_Find_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_py_ssize_t_type,
    [
        PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
    ],
    exception_value='-2')
def _handle_simple_method_unicode_find(self, node, function, args, is_unbound_method):
    """Optimise unicode.find() as a find call with direction +1."""
    direction = +1
    return self._inject_unicode_find(
        node, function, args, is_unbound_method, 'find', direction)
def _handle_simple_method_unicode_rfind(self, node, function, args, is_unbound_method):
    """Optimise unicode.rfind() as a find call with direction -1."""
    direction = -1
    return self._inject_unicode_find(
        node, function, args, is_unbound_method, 'rfind', direction)
def _inject_unicode_find(self, node, function, args, is_unbound_method,
                         method_name, direction):
    """Replace unicode.find(...) and unicode.rfind(...) by a
    direct call to the corresponding C-API function.

    Missing ``start``/``end`` arguments default to 0 and PY_SSIZE_T_MAX,
    and ``direction`` is appended as a C int argument.  The C result is
    coerced to a Python object.  Note that ``args`` is extended in place.
    """
    if len(args) not in (2, 3, 4):
        self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
        return node

    ssize_t = PyrexTypes.c_py_ssize_t_type
    self._inject_int_default_argument(node, args, 2, ssize_t, "0")
    self._inject_int_default_argument(node, args, 3, ssize_t, "PY_SSIZE_T_MAX")
    direction_node = ExprNodes.IntNode(
        node.pos, value=str(direction), type=PyrexTypes.c_int_type)
    args.append(direction_node)

    find_call = self._substitute_method_call(
        node, function, "PyUnicode_Find", self.PyUnicode_Find_func_type,
        method_name, is_unbound_method, args)
    return find_call.coerce_to_pyobject(self.current_env())
# Signature for the PyUnicode_Count() call:
# (unicode str, substring, Py_ssize_t start, Py_ssize_t end) -> Py_ssize_t,
# with -1 as the exception value.
PyUnicode_Count_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_py_ssize_t_type,
    [
        PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
    ],
    exception_value='-1')
def _handle_simple_method_unicode_count(self, node, function, args, is_unbound_method):
    """Replace unicode.count(...) by a direct call to the
    corresponding C-API function.

    Missing ``start``/``end`` arguments default to 0 and PY_SSIZE_T_MAX.
    The C result is coerced to a Python object.
    """
    if len(args) not in (2, 3, 4):
        self._error_wrong_arg_count('unicode.count', node, args, "2-4")
        return node

    ssize_t = PyrexTypes.c_py_ssize_t_type
    self._inject_int_default_argument(node, args, 2, ssize_t, "0")
    self._inject_int_default_argument(node, args, 3, ssize_t, "PY_SSIZE_T_MAX")

    count_call = self._substitute_method_call(
        node, function, "PyUnicode_Count", self.PyUnicode_Count_func_type,
        'count', is_unbound_method, args)
    return count_call.coerce_to_pyobject(self.current_env())
# Signature for the PyUnicode_Replace() call:
# (unicode str, substring, replstr, Py_ssize_t maxcount) -> unicode
PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type,
    [
        PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
    ])
def _handle_simple_method_unicode_replace(self, node, function, args, is_unbound_method):
    """Replace unicode.replace(...) by a direct call to the
    corresponding C-API function.

    A missing ``maxcount`` argument defaults to -1.
    """
    if len(args) not in (3, 4):
        self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
        return node

    self._inject_int_default_argument(
        node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")

    replace_type = self.PyUnicode_Replace_func_type
    return self._substitute_method_call(
        node, function, "PyUnicode_Replace", replace_type,
        'replace', is_unbound_method, args)
# Signature for the PyUnicode_AsEncodedString() call:
# (unicode obj, char* encoding, char* errors) -> bytes
PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
    Builtin.bytes_type,
    [
        PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
    ])

# Signature for the codec specific PyUnicode_As<Codec>String() calls:
# (unicode obj) -> bytes
PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
    Builtin.bytes_type,
    [PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None)])
# Codec names that are looked up when searching for a codec specific
# encode/decode function.
_special_encodings = [
    'UTF8', 'UTF16', 'Latin1', 'ASCII',
    'unicode_escape', 'raw_unicode_escape',
]

# (name, encoder function) pairs for the codecs listed above.
_special_codecs = [
    (name, codecs.getencoder(name))
    for name in _special_encodings
]
def _handle_simple_method_unicode_encode(self, node, function, args, is_unbound_method):
    """Replace unicode.encode(...) by a direct C-API call to the
    corresponding codec.

    - Without encoding arguments, PyUnicode_AsEncodedString() is called
      with NULL for both encoding and error mode.
    - A unicode literal with a known encoding is encoded at compile
      time if that succeeds.
    - With a known encoding and 'strict' error handling, a codec
      specific PyUnicode_As<Codec>String() function is used if one is
      found.
    - Otherwise, PyUnicode_AsEncodedString() is called with the given
      encoding and error mode.
    """
    if len(args) < 1 or len(args) > 3:
        self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
        return node

    string_node = args[0]

    if len(args) == 1:
        null_node = ExprNodes.NullNode(node.pos)
        return self._substitute_method_call(
            node, function, "PyUnicode_AsEncodedString",
            self.PyUnicode_AsEncodedString_func_type,
            'encode', is_unbound_method, [string_node, null_node, null_node])

    parameters = self._unpack_encoding_and_error_mode(node.pos, args)
    if parameters is None:
        return node
    encoding, encoding_node, error_handling, error_handling_node = parameters

    if encoding and isinstance(string_node, ExprNodes.UnicodeNode):
        # constant, so try to do the encoding at compile time
        try:
            value = string_node.value.encode(encoding, error_handling)
        except Exception:
            # well, looks like we can't
            # (Exception only, so that KeyboardInterrupt and SystemExit
            # still propagate; fall through to the runtime call)
            pass
        else:
            value = BytesLiteral(value)
            value.encoding = encoding
            return ExprNodes.BytesNode(
                string_node.pos, value=value, type=Builtin.bytes_type)

    if encoding and error_handling == 'strict':
        # try to find a specific encoder function
        codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            encode_function = "PyUnicode_As%sString" % codec_name
            return self._substitute_method_call(
                node, function, encode_function,
                self.PyUnicode_AsXyzString_func_type,
                'encode', is_unbound_method, [string_node])

    return self._substitute_method_call(
        node, function, "PyUnicode_AsEncodedString",
        self.PyUnicode_AsEncodedString_func_type,
        'encode', is_unbound_method,
        [string_node, encoding_node, error_handling_node])
# Pointer to a codec specific decode function:
# (char* string, Py_ssize_t size, char* errors) -> unicode
PyUnicode_DecodeXyz_func_ptr_type = PyrexTypes.CPtrType(PyrexTypes.CFuncType(
    Builtin.unicode_type,
    [
        PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
    ]))

# Decode helper signature for a C string input:
# (char* string, start, stop, char* encoding, char* errors, decode_func)
# -> unicode
_decode_c_string_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type,
    [
        PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
    ])

# Same decode helper signature, but the string is passed as a Python object.
_decode_bytes_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type,
    [
        PyrexTypes.CFuncTypeArg("string", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
    ])

_decode_cpp_string_func_type = None  # lazy init
2953 def _handle_simple_method_bytes_decode(self
, node
, function
, args
, is_unbound_method
):
2954 """Replace char*.decode() by a direct C-API call to the
2955 corresponding codec, possibly resolving a slice on the char*.
2957 if not (1 <= len(args
) <= 3):
2958 self
._error
_wrong
_arg
_count
('bytes.decode', node
, args
, '1-3')
2961 # normalise input nodes
2962 string_node
= args
[0]
2964 if isinstance(string_node
, ExprNodes
.SliceIndexNode
):
2965 index_node
= string_node
2966 string_node
= index_node
.base
2967 start
, stop
= index_node
.start
, index_node
.stop
2968 if not start
or start
.constant_result
== 0:
2970 if isinstance(string_node
, ExprNodes
.CoerceToPyTypeNode
):
2971 string_node
= string_node
.arg
2973 string_type
= string_node
.type
2974 if string_type
in (Builtin
.bytes_type
, Builtin
.bytearray_type
):
2975 if is_unbound_method
:
2976 string_node
= string_node
.as_none_safe_node(
2977 "descriptor '%s' requires a '%s' object but received a 'NoneType'",
2978 format_args
=['decode', string_type
.name
])
2980 string_node
= string_node
.as_none_safe_node(
2981 "'NoneType' object has no attribute '%s'",
2982 error
="PyExc_AttributeError",
2983 format_args
=['decode'])
2984 elif not string_type
.is_string
and not string_type
.is_cpp_string
:
2985 # nothing to optimise here
2988 parameters
= self
._unpack
_encoding
_and
_error
_mode
(node
.pos
, args
)
2989 if parameters
is None:
2991 encoding
, encoding_node
, error_handling
, error_handling_node
= parameters
2994 start
= ExprNodes
.IntNode(node
.pos
, value
='0', constant_result
=0)
2995 elif not start
.type.is_int
:
2996 start
= start
.coerce_to(PyrexTypes
.c_py_ssize_t_type
, self
.current_env())
2997 if stop
and not stop
.type.is_int
:
2998 stop
= stop
.coerce_to(PyrexTypes
.c_py_ssize_t_type
, self
.current_env())
3000 # try to find a specific encoder function
3002 if encoding
is not None:
3003 codec_name
= self
._find
_special
_codec
_name
(encoding
)
3004 if codec_name
is not None:
3005 decode_function
= ExprNodes
.RawCNameExprNode(
3006 node
.pos
, type=self
.PyUnicode_DecodeXyz_func_ptr_type
,
3007 cname
="PyUnicode_Decode%s" % codec_name
)
3008 encoding_node
= ExprNodes
.NullNode(node
.pos
)
3010 decode_function
= ExprNodes
.NullNode(node
.pos
)
3012 # build the helper function call
3014 if string_type
.is_string
:
3017 # use strlen() to find the string length, just as CPython would
3018 if not string_node
.is_name
:
3019 string_node
= UtilNodes
.LetRefNode(string_node
) # used twice
3020 temps
.append(string_node
)
3021 stop
= ExprNodes
.PythonCapiCallNode(
3022 string_node
.pos
, "strlen", self
.Pyx_strlen_func_type
,
3025 utility_code
=UtilityCode
.load_cached("IncludeStringH", "StringTools.c"),
3026 ).coerce_to(PyrexTypes
.c_py_ssize_t_type
, self
.current_env())
3027 helper_func_type
= self
._decode
_c
_string
_func
_type
3028 utility_code_name
= 'decode_c_string'
3029 elif string_type
.is_cpp_string
:
3032 stop
= ExprNodes
.IntNode(node
.pos
, value
='PY_SSIZE_T_MAX',
3033 constant_result
=ExprNodes
.not_a_constant
)
3034 if self
._decode
_cpp
_string
_func
_type
is None:
3035 # lazy init to reuse the C++ string type
3036 self
._decode
_cpp
_string
_func
_type
= PyrexTypes
.CFuncType(
3037 Builtin
.unicode_type
, [
3038 PyrexTypes
.CFuncTypeArg("string", string_type
, None),
3039 PyrexTypes
.CFuncTypeArg("start", PyrexTypes
.c_py_ssize_t_type
, None),
3040 PyrexTypes
.CFuncTypeArg("stop", PyrexTypes
.c_py_ssize_t_type
, None),
3041 PyrexTypes
.CFuncTypeArg("encoding", PyrexTypes
.c_char_ptr_type
, None),
3042 PyrexTypes
.CFuncTypeArg("errors", PyrexTypes
.c_char_ptr_type
, None),
3043 PyrexTypes
.CFuncTypeArg("decode_func", self
.PyUnicode_DecodeXyz_func_ptr_type
, None),
3045 helper_func_type
= self
._decode
_cpp
_string
_func
_type
3046 utility_code_name
= 'decode_cpp_string'
3048 # Python bytes/bytearray object
3050 stop
= ExprNodes
.IntNode(node
.pos
, value
='PY_SSIZE_T_MAX',
3051 constant_result
=ExprNodes
.not_a_constant
)
3052 helper_func_type
= self
._decode
_bytes
_func
_type
3053 if string_type
is Builtin
.bytes_type
:
3054 utility_code_name
= 'decode_bytes'
3056 utility_code_name
= 'decode_bytearray'
3058 node
= ExprNodes
.PythonCapiCallNode(
3059 node
.pos
, '__Pyx_%s' % utility_code_name
, helper_func_type
,
3060 args
=[string_node
, start
, stop
, encoding_node
, error_handling_node
, decode_function
],
3061 is_temp
=node
.is_temp
,
3062 utility_code
=UtilityCode
.load_cached(utility_code_name
, 'StringTools.c'),
3065 for temp
in temps
[::-1]:
3066 node
= UtilNodes
.EvalWithTempExprNode(temp
, node
)
# bytearray.decode() shares the bytes.decode() handler, which accepts both
# the bytes and the bytearray type for its first argument
_handle_simple_method_bytearray_decode = _handle_simple_method_bytes_decode
3071 def _find_special_codec_name(self
, encoding
):
3073 requested_codec
= codecs
.getencoder(encoding
)
3076 for name
, codec
in self
._special
_codecs
:
3077 if codec
== requested_codec
:
3079 name
= ''.join([s
.capitalize()
3080 for s
in name
.split('_')])
3084 def _unpack_encoding_and_error_mode(self
, pos
, args
):
3085 null_node
= ExprNodes
.NullNode(pos
)
3088 encoding
, encoding_node
= self
._unpack
_string
_and
_cstring
_node
(args
[1])
3089 if encoding_node
is None:
3093 encoding_node
= null_node
3096 error_handling
, error_handling_node
= self
._unpack
_string
_and
_cstring
_node
(args
[2])
3097 if error_handling_node
is None:
3099 if error_handling
== 'strict':
3100 error_handling_node
= null_node
3102 error_handling
= 'strict'
3103 error_handling_node
= null_node
3105 return (encoding
, encoding_node
, error_handling
, error_handling_node
)
3107 def _unpack_string_and_cstring_node(self
, node
):
3108 if isinstance(node
, ExprNodes
.CoerceToPyTypeNode
):
3110 if isinstance(node
, ExprNodes
.UnicodeNode
):
3111 encoding
= node
.value
3112 node
= ExprNodes
.BytesNode(
3113 node
.pos
, value
=BytesLiteral(encoding
.utf8encode()),
3114 type=PyrexTypes
.c_char_ptr_type
)
3115 elif isinstance(node
, (ExprNodes
.StringNode
, ExprNodes
.BytesNode
)):
3116 encoding
= node
.value
.decode('ISO-8859-1')
3117 node
= ExprNodes
.BytesNode(
3118 node
.pos
, value
=node
.value
, type=PyrexTypes
.c_char_ptr_type
)
3119 elif node
.type is Builtin
.bytes_type
:
3121 node
= node
.coerce_to(PyrexTypes
.c_char_ptr_type
, self
.current_env())
3122 elif node
.type.is_string
:
3125 encoding
= node
= None
3126 return encoding
, node
def _handle_simple_method_str_endswith(self, node, function, args, is_unbound_method):
    """Handle ``str.endswith()`` calls.

    Delegates to ``_inject_tailmatch()`` with the 'str' type name, the
    'endswith' method name, the str tailmatch utility code and a
    direction of +1, and returns its result.
    """
    type_name, method_name, direction = 'str', 'endswith', +1
    return self._inject_tailmatch(
        node, function, args, is_unbound_method, type_name, method_name,
        str_tailmatch_utility_code, direction)
def _handle_simple_method_str_startswith(self, node, function, args, is_unbound_method):
    """Handle ``str.startswith()`` calls.

    Delegates to ``_inject_tailmatch()`` with the 'str' type name, the
    'startswith' method name, the str tailmatch utility code and a
    direction of -1, and returns its result.
    """
    type_name, method_name, direction = 'str', 'startswith', -1
    return self._inject_tailmatch(
        node, function, args, is_unbound_method, type_name, method_name,
        str_tailmatch_utility_code, direction)
def _handle_simple_method_bytes_endswith(self, node, function, args, is_unbound_method):
    """Handle ``bytes.endswith()`` calls.

    Delegates to ``_inject_tailmatch()`` with the 'bytes' type name, the
    'endswith' method name, the bytes tailmatch utility code and a
    direction of +1, and returns its result.
    """
    type_name, method_name, direction = 'bytes', 'endswith', +1
    return self._inject_tailmatch(
        node, function, args, is_unbound_method, type_name, method_name,
        bytes_tailmatch_utility_code, direction)
def _handle_simple_method_bytes_startswith(self, node, function, args, is_unbound_method):
    """Handle ``bytes.startswith()`` calls.

    Delegates to ``_inject_tailmatch()`` with the 'bytes' type name, the
    'startswith' method name, the bytes tailmatch utility code and a
    direction of -1, and returns its result.
    """
    type_name, method_name, direction = 'bytes', 'startswith', -1
    return self._inject_tailmatch(
        node, function, args, is_unbound_method, type_name, method_name,
        bytes_tailmatch_utility_code, direction)
3148 ''' # disabled for now, enable when we consider it worth it (see StringTools.c)
3149 def _handle_simple_method_bytearray_endswith(self, node, function, args, is_unbound_method):
3150 return self._inject_tailmatch(
3151 node, function, args, is_unbound_method, 'bytearray', 'endswith',
3152 bytes_tailmatch_utility_code, +1)
3154 def _handle_simple_method_bytearray_startswith(self, node, function, args, is_unbound_method):
3155 return self._inject_tailmatch(
3156 node, function, args, is_unbound_method, 'bytearray', 'startswith',
3157 bytes_tailmatch_utility_code, -1)
3162 def _substitute_method_call(self
, node
, function
, name
, func_type
,
3163 attr_name
, is_unbound_method
, args
=(),
3164 utility_code
=None, is_temp
=None,
3165 may_return_none
=ExprNodes
.PythonCapiCallNode
.may_return_none
):
3167 if args
and not args
[0].is_literal
:
3169 if is_unbound_method
:
3170 self_arg
= self_arg
.as_none_safe_node(
3171 "descriptor '%s' requires a '%s' object but received a 'NoneType'",
3172 format_args
=[attr_name
, function
.obj
.name
])
3174 self_arg
= self_arg
.as_none_safe_node(
3175 "'NoneType' object has no attribute '%s'",
3176 error
= "PyExc_AttributeError",
3177 format_args
= [attr_name
])
3180 is_temp
= node
.is_temp
3181 return ExprNodes
.PythonCapiCallNode(
3182 node
.pos
, name
, func_type
,
3185 utility_code
= utility_code
,
3186 may_return_none
= may_return_none
,
3187 result_is_used
= node
.result_is_used
,
def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
    """Make sure that ``args[arg_index]`` exists and has the given type.

    ``args`` is modified in place and must hold at least ``arg_index``
    items.  If it holds exactly that many, an IntNode for
    ``default_value`` (positioned at ``node.pos``) is appended.
    Otherwise, the argument that is already there gets coerced to
    ``type``.  Nothing is returned.
    """
    assert len(args) >= arg_index
    if len(args) != arg_index:
        # argument was passed explicitly => only adapt its type
        args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
        return
    default_node = ExprNodes.IntNode(
        node.pos, value=str(default_value),
        type=type, constant_result=default_value)
    args.append(default_node)
def _inject_bint_default_argument(self, node, args, arg_index, default_value):
    """Make sure that ``args[arg_index]`` exists and is a boolean.

    ``args`` is modified in place and must hold at least ``arg_index``
    items.  If it holds exactly that many, a BoolNode for the truth
    value of ``default_value`` (positioned at ``node.pos``) is appended.
    Otherwise, the argument that is already there gets coerced to a
    boolean.  Nothing is returned.
    """
    assert len(args) >= arg_index
    if len(args) != arg_index:
        # argument was passed explicitly => only adapt its type
        args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
        return
    flag = bool(default_value)
    args.append(ExprNodes.BoolNode(node.pos, value=flag, constant_result=flag))
# Cached utility code sections from StringTools.c, one per string type.
# The str and bytes variants are passed to _inject_tailmatch() by the
# str/bytes startswith()/endswith() handlers.
unicode_tailmatch_utility_code = UtilityCode.load_cached('unicode_tailmatch', 'StringTools.c')
bytes_tailmatch_utility_code = UtilityCode.load_cached('bytes_tailmatch', 'StringTools.c')
str_tailmatch_utility_code = UtilityCode.load_cached('str_tailmatch', 'StringTools.c')
3213 class ConstantFolding(Visitor
.VisitorTransform
, SkipDeclarations
):
3214 """Calculate the result of constant expressions to store it in
3215 ``expr_node.constant_result``, and replace trivial cases by their
3220 - We calculate float constants to make them available to the
3221 compiler, but we do not aggregate them into a single literal
3222 node to prevent any loss of precision.
3224 - We recursively calculate constants from non-literal nodes to
3225 make them available to the compiler, but we only aggregate
3226 literal nodes at each step. Non-literal nodes are never merged
def __init__(self, reevaluate=False):
    """
    Set up the transform.

    The reevaluate argument specifies whether constant values that were
    previously computed should be recomputed.  It is stored on the
    instance and checked by _calculate_const(), which otherwise leaves
    nodes alone whose constant_result has already been set.
    """
    super(ConstantFolding, self).__init__()
    self.reevaluate = reevaluate
3238 def _calculate_const(self
, node
):
3239 if (not self
.reevaluate
and
3240 node
.constant_result
is not ExprNodes
.constant_value_not_set
):
3243 # make sure we always set the value
3244 not_a_constant
= ExprNodes
.not_a_constant
3245 node
.constant_result
= not_a_constant
3247 # check if all children are constant
3248 children
= self
.visitchildren(node
)
3249 for child_result
in children
.values():
3250 if type(child_result
) is list:
3251 for child
in child_result
:
3252 if getattr(child
, 'constant_result', not_a_constant
) is not_a_constant
:
3254 elif getattr(child_result
, 'constant_result', not_a_constant
) is not_a_constant
:
3257 # now try to calculate the real constant value
3259 node
.calculate_constant_result()
3260 # if node.constant_result is not ExprNodes.not_a_constant:
3261 # print node.__class__.__name__, node.constant_result
3262 except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
3263 # ignore all 'normal' errors here => no constant result
3266 # this looks like a real error
3267 import traceback
, sys
3268 traceback
.print_exc(file=sys
.stdout
)
# Literal node classes that constant folding can merge.  The position in
# this list is the ranking used by _widest_node_class(): of several operand
# classes, the one with the highest index is selected.
NODE_TYPE_ORDER = [ExprNodes.BoolNode, ExprNodes.CharNode,
                   ExprNodes.IntNode, ExprNodes.FloatNode]
3273 def _widest_node_class(self
, *nodes
):
3275 return self
.NODE_TYPE_ORDER
[
3276 max(map(self
.NODE_TYPE_ORDER
.index
, map(type, nodes
)))]
3280 def _bool_node(self
, node
, value
):
3282 return ExprNodes
.BoolNode(node
.pos
, value
=value
, constant_result
=value
)
def visit_ExprNode(self, node):
    """Calculate the constant result of an expression node (which also
    visits its children) and keep the node itself in the tree.
    """
    self._calculate_const(node)
    return node
def visit_UnopNode(self, node):
    """Fold a unary operation on a literal operand into a new literal node.

    Without a constant result, only '!' is passed on to _handle_NotNode();
    other operators leave the node as it is.  With a constant result, the
    node is only replaced if its operand is a literal: '!' becomes a bool
    node, any other operator on a BoolNode operand becomes a C int IntNode,
    and '+' / '-' are left to their specific handlers.
    """
    self._calculate_const(node)
    if not node.has_constant_result():
        return self._handle_NotNode(node) if node.operator == '!' else node
    if not node.operand.is_literal:
        return node

    if node.operator == '!':
        return self._bool_node(node, node.constant_result)
    if isinstance(node.operand, ExprNodes.BoolNode):
        int_value = int(node.constant_result)
        return ExprNodes.IntNode(
            node.pos, value=str(int_value),
            type=PyrexTypes.c_int_type, constant_result=int_value)
    if node.operator == '+':
        return self._handle_UnaryPlusNode(node)
    if node.operator == '-':
        return self._handle_UnaryMinusNode(node)
    return node
3308 _negate_operator
= {
3315 def _handle_NotNode(self
, node
):
3316 operand
= node
.operand
3317 if isinstance(operand
, ExprNodes
.PrimaryCmpNode
):
3318 operator
= self
._negate
_operator
(operand
.operator
)
3320 node
= copy
.copy(operand
)
3321 node
.operator
= operator
3322 node
= self
.visit_PrimaryCmpNode(node
)
3325 def _handle_UnaryMinusNode(self
, node
):
3327 if value
.startswith('-'):
3333 node_type
= node
.operand
.type
3334 if isinstance(node
.operand
, ExprNodes
.FloatNode
):
3335 # this is a safe operation
3336 return ExprNodes
.FloatNode(node
.pos
, value
=_negate(node
.operand
.value
),
3338 constant_result
=node
.constant_result
)
3339 if node_type
.is_int
and node_type
.signed
or \
3340 isinstance(node
.operand
, ExprNodes
.IntNode
) and node_type
.is_pyobject
:
3341 return ExprNodes
.IntNode(node
.pos
, value
=_negate(node
.operand
.value
),
3343 longness
=node
.operand
.longness
,
3344 constant_result
=node
.constant_result
)
3347 def _handle_UnaryPlusNode(self
, node
):
3348 if (node
.operand
.has_constant_result() and
3349 node
.constant_result
== node
.operand
.constant_result
):
def visit_BoolBinopNode(self, node):
    """Replace an 'and'/'or' expression by one of its operands if the
    first operand has a constant result.

    Returns the node itself if the first operand is not constant.
    """
    self._calculate_const(node)
    first = node.operand1
    if not first.has_constant_result():
        return node
    # 'and' moves on to the second operand if the first one is true,
    # 'or' if it is false; in the other two cases, the first operand
    # already is the result
    is_and = node.operator == 'and'
    if bool(first.constant_result) == is_and:
        return node.operand2
    return first
3368 def visit_BinopNode(self
, node
):
3369 self
._calculate
_const
(node
)
3370 if node
.constant_result
is ExprNodes
.not_a_constant
:
3372 if isinstance(node
.constant_result
, float):
3374 operand1
, operand2
= node
.operand1
, node
.operand2
3375 if not operand1
.is_literal
or not operand2
.is_literal
:
3378 # now inject a new constant node with the calculated value
3380 type1
, type2
= operand1
.type, operand2
.type
3381 if type1
is None or type2
is None:
3383 except AttributeError:
3386 if type1
.is_numeric
and type2
.is_numeric
:
3387 widest_type
= PyrexTypes
.widest_numeric_type(type1
, type2
)
3389 widest_type
= PyrexTypes
.py_object_type
3391 target_class
= self
._widest
_node
_class
(operand1
, operand2
)
3392 if target_class
is None:
3394 elif target_class
is ExprNodes
.BoolNode
and node
.operator
in '+-//<<%**>>':
3395 # C arithmetic results in at least an int type
3396 target_class
= ExprNodes
.IntNode
3397 elif target_class
is ExprNodes
.CharNode
and node
.operator
in '+-//<<%**>>&|^':
3398 # C arithmetic results in at least an int type
3399 target_class
= ExprNodes
.IntNode
3401 if target_class
is ExprNodes
.IntNode
:
3402 unsigned
= getattr(operand1
, 'unsigned', '') and \
3403 getattr(operand2
, 'unsigned', '')
3404 longness
= "LL"[:max(len(getattr(operand1
, 'longness', '')),
3405 len(getattr(operand2
, 'longness', '')))]
3406 new_node
= ExprNodes
.IntNode(pos
=node
.pos
,
3407 unsigned
=unsigned
, longness
=longness
,
3408 value
=str(int(node
.constant_result
)),
3409 constant_result
=int(node
.constant_result
))
3410 # IntNode is smart about the type it chooses, so we just
3411 # make sure we were not smarter this time
3412 if widest_type
.is_pyobject
or new_node
.type.is_pyobject
:
3413 new_node
.type = PyrexTypes
.py_object_type
3415 new_node
.type = PyrexTypes
.widest_numeric_type(widest_type
, new_node
.type)
3417 if target_class
is ExprNodes
.BoolNode
:
3418 node_value
= node
.constant_result
3420 node_value
= str(node
.constant_result
)
3421 new_node
= target_class(pos
=node
.pos
, type = widest_type
,
3423 constant_result
= node
.constant_result
)
def visit_MulNode(self, node):
    """Handle multiplications, with a special case for sequences.

    A sequence constructor as first operand, or an IntNode first operand
    combined with a sequence constructor as second operand, is passed to
    _calculate_constant_seq() together with the other operand as factor.
    Everything else is handled like any other binary operation.
    """
    self._calculate_const(node)
    left, right = node.operand1, node.operand2
    if left.is_sequence_constructor:
        return self._calculate_constant_seq(node, left, right)
    if isinstance(left, ExprNodes.IntNode) and right.is_sequence_constructor:
        return self._calculate_constant_seq(node, right, left)
    return self.visit_BinopNode(node)
def _calculate_constant_seq(self, node, sequence_node, factor):
    """Merge the multiplication factor of a sequence into the sequence node.

    Returns ``sequence_node`` as replacement for the multiplication
    ``node``, except when the sequence already has a factor that cannot
    be combined with the new one; then the multiplication is handled by
    visit_BinopNode() instead.  A factor of 1 and empty sequences leave
    the sequence node untouched.
    """
    repeat = factor.constant_result
    if repeat != 1 and sequence_node.args:
        int_types = (int, long)
        previous = sequence_node.mult_factor
        if isinstance(repeat, int_types) and repeat <= 0:
            # non-positive integer factor => drop all items and the factor
            del sequence_node.args[:]
            sequence_node.mult_factor = None
        elif previous is None:
            sequence_node.mult_factor = factor
        elif (isinstance(repeat, int_types) and
                isinstance(previous.constant_result, int_types)):
            # two integer factors => multiply them into a single one
            combined = previous.constant_result * repeat
            sequence_node.mult_factor = ExprNodes.IntNode(
                previous.pos, value=str(combined), constant_result=combined)
        else:
            # don't know if we can combine the factors, so don't
            return self.visit_BinopNode(node)
    return sequence_node
3454 def visit_PrimaryCmpNode(self
, node
):
3455 # calculate constant partial results in the comparison cascade
3456 self
.visitchildren(node
, ['operand1'])
3457 left_node
= node
.operand1
3459 while cmp_node
is not None:
3460 self
.visitchildren(cmp_node
, ['operand2'])
3461 right_node
= cmp_node
.operand2
3462 cmp_node
.constant_result
= not_a_constant
3463 if left_node
.has_constant_result() and right_node
.has_constant_result():
3465 cmp_node
.calculate_cascaded_constant_result(left_node
.constant_result
)
3466 except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
3467 pass # ignore all 'normal' errors here => no constant result
3468 left_node
= right_node
3469 cmp_node
= cmp_node
.cascade
3471 if not node
.cascade
:
3472 if node
.has_constant_result():
3473 return self
._bool
_node
(node
, node
.constant_result
)
3476 # collect partial cascades: [[value, CmpNode...], [value, CmpNode, ...], ...]
3477 cascades
= [[node
.operand1
]]
3478 final_false_result
= []
3480 def split_cascades(cmp_node
):
3481 if cmp_node
.has_constant_result():
3482 if not cmp_node
.constant_result
:
3483 # False => short-circuit
3484 final_false_result
.append(self
._bool
_node
(cmp_node
, False))
3487 # True => discard and start new cascade
3488 cascades
.append([cmp_node
.operand2
])
3490 # not constant => append to current cascade
3491 cascades
[-1].append(cmp_node
)
3492 if cmp_node
.cascade
:
3493 split_cascades(cmp_node
.cascade
)
3495 split_cascades(node
)
3498 for cascade
in cascades
:
3499 if len(cascade
) < 2:
3501 cmp_node
= cascade
[1]
3502 pcmp_node
= ExprNodes
.PrimaryCmpNode(
3504 operand1
=cascade
[0],
3505 operator
=cmp_node
.operator
,
3506 operand2
=cmp_node
.operand2
,
3507 constant_result
=not_a_constant
)
3508 cmp_nodes
.append(pcmp_node
)
3510 last_cmp_node
= pcmp_node
3511 for cmp_node
in cascade
[2:]:
3512 last_cmp_node
.cascade
= cmp_node
3513 last_cmp_node
= cmp_node
3514 last_cmp_node
.cascade
= None
3516 if final_false_result
:
3517 # last cascade was constant False
3518 cmp_nodes
.append(final_false_result
[0])
3520 # only constants, but no False result
3521 return self
._bool
_node
(node
, True)
3523 if len(cmp_nodes
) == 1:
3524 if node
.has_constant_result():
3525 return self
._bool
_node
(node
, node
.constant_result
)
3527 for cmp_node
in cmp_nodes
[1:]:
3528 node
= ExprNodes
.BoolBinopNode(
3533 constant_result
=not_a_constant
)
def visit_CondExprNode(self, node):
    """Replace a conditional expression by the selected branch if its
    test has a constant result; otherwise, return the node unchanged.
    """
    self._calculate_const(node)
    test = node.test
    if test.has_constant_result():
        return node.true_val if test.constant_result else node.false_val
    return node
3545 def visit_IfStatNode(self
, node
):
3546 self
.visitchildren(node
)
3547 # eliminate dead code based on constant condition results
3549 for if_clause
in node
.if_clauses
:
3550 condition
= if_clause
.condition
3551 if condition
.has_constant_result():
3552 if condition
.constant_result
:
3553 # always true => subsequent clauses can safely be dropped
3554 node
.else_clause
= if_clause
.body
3556 # else: false => drop clause
3558 # unknown result => normal runtime evaluation
3559 if_clauses
.append(if_clause
)
3561 node
.if_clauses
= if_clauses
3563 elif node
.else_clause
:
3564 return node
.else_clause
3566 return Nodes
.StatListNode(node
.pos
, stats
=[])
3568 def visit_SliceIndexNode(self
, node
):
3569 self
._calculate
_const
(node
)
3570 # normalise start/stop values
3571 if node
.start
is None or node
.start
.constant_result
is None:
3572 start
= node
.start
= None
3574 start
= node
.start
.constant_result
3575 if node
.stop
is None or node
.stop
.constant_result
is None:
3576 stop
= node
.stop
= None
3578 stop
= node
.stop
.constant_result
3579 # cut down sliced constant sequences
3580 if node
.constant_result
is not not_a_constant
:
3582 if base
.is_sequence_constructor
and base
.mult_factor
is None:
3583 base
.args
= base
.args
[start
:stop
]
3585 elif base
.is_string_literal
:
3586 base
= base
.as_sliced_node(start
, stop
)
3587 if base
is not None:
def visit_ComprehensionNode(self, node):
    """Turn a comprehension with an empty loop into an empty literal.

    Only applies if the loop is a StatListNode without statements and
    the comprehension type is the builtin list, set or dict type.  All
    other comprehensions are returned unchanged.
    """
    self.visitchildren(node)
    loop = node.loop
    if not isinstance(loop, Nodes.StatListNode) or loop.stats:
        return node

    # loop was pruned already => transform into literal
    result_type = node.type
    if result_type is Builtin.list_type:
        return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
    if result_type is Builtin.set_type:
        return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
    if result_type is Builtin.dict_type:
        return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
    return node
def visit_ForInStatNode(self, node):
    """Simplify for-in loops over sequence literals.

    A loop over an empty sequence literal is replaced by its else clause
    or, without one, by an empty statement list.  A list literal that is
    iterated over is replaced by a tuple.
    """
    self.visitchildren(node)
    iterable = node.iterator.sequence
    if not isinstance(iterable, ExprNodes.SequenceNode):
        return node
    if not iterable.args:
        if node.else_clause:
            return node.else_clause
        # don't break list comprehensions
        return Nodes.StatListNode(node.pos, stats=[])
    # iterating over a list literal? => tuples are more efficient
    if isinstance(iterable, ExprNodes.ListNode):
        node.iterator.sequence = iterable.as_tuple()
    return node
def visit_WhileStatNode(self, node):
    """Simplify while loops that have a constant condition.

    A constant true condition is removed from the loop together with
    the else clause.  A constant false condition replaces the whole
    loop by its else clause (which may be None).  Loops without a
    condition or with a non-constant one are returned unchanged.
    """
    self.visitchildren(node)
    condition = node.condition
    if not condition or not condition.has_constant_result():
        return node
    if not condition.constant_result:
        return node.else_clause
    node.condition = None
    node.else_clause = None
    return node
3631 def visit_ExprStatNode(self
, node
):
3632 self
.visitchildren(node
)
3633 if not isinstance(node
.expr
, ExprNodes
.ExprNode
):
3634 # ParallelRangeTransform does this ...
3636 # drop unused constant expressions
3637 if node
.expr
.has_constant_result():
# in the future, other nodes can have their own handler method here
# that can replace them with a constant result node

# Nodes handled as plain 'Node' use the transform's generic
# recurse_to_children() implementation (presumably the fallback for
# everything without a more specific visit_*() method - confirm in Visitor).
visit_Node = Visitor.VisitorTransform.recurse_to_children
3647 class FinalOptimizePhase(Visitor
.CythonTransform
):
3649 This visitor handles several commuting optimizations, and is run
3650 just before the C code generation phase.
3652 The optimizations currently implemented in this class are:
3653 - eliminate None assignment and refcounting for first assignment.
3654 - isinstance -> typecheck for cdef types
3655 - eliminate checks for None and/or types that became redundant after tree changes
3657 def visit_SingleAssignmentNode(self
, node
):
3658 """Avoid redundant initialisation of local variables before their
3661 self
.visitchildren(node
)
3664 lhs
.lhs_of_first_assignment
= True
def visit_SimpleCallNode(self, node):
    """Replace generic calls to isinstance(x, type) by a more efficient
    type check.

    Applies to two-argument calls of a C function named 'isinstance'
    whose second argument has the builtin type 'type'.  The function is
    redirected to the 'PyObject_TypeCheck' entry of the Cython scope and
    the second argument is cast to a 'PyTypeObject' pointer.  The call
    node is always returned.
    """
    self.visitchildren(node)
    function = node.function
    if not (function.type.is_cfunction and isinstance(function, ExprNodes.NameNode)):
        return node
    if function.name != 'isinstance' or len(node.args) != 2:
        return node
    type_arg = node.args[1]
    if not (type_arg.type.is_builtin_type and type_arg.type.name == 'type'):
        return node

    cython_scope = self.context.cython_scope
    function.entry = cython_scope.lookup('PyObject_TypeCheck')
    function.type = function.entry.type
    type_object_ptr = PyrexTypes.CPtrType(cython_scope.lookup('PyTypeObject').type)
    node.args[1] = ExprNodes.CastNode(type_arg, type_object_ptr)
    return node
3683 def visit_PyTypeTestNode(self
, node
):
3684 """Remove tests for alternatively allowed None values from
3685 type tests when we know that the argument cannot be None
3688 self
.visitchildren(node
)
3689 if not node
.notnone
:
3690 if not node
.arg
.may_be_none():
3694 def visit_NoneCheckNode(self
, node
):
3695 """Remove None checks from expressions that definitely do not
3698 self
.visitchildren(node
)
3699 if not node
.arg
.may_be_none():
3703 class ConsolidateOverflowCheck(Visitor
.CythonTransform
):
    This class facilitates the sharing of overflow checking among all nodes
    of a nested arithmetic expression.  For example, given the expression
    a*b + c, where a, b, and c are all possibly overflowing ints, the entire
    sequence will be evaluated and the overflow bit checked only at the end.
3710 overflow_bit_node
= None
def visit_Node(self, node):
    """Visit the children of a node outside of any overflow sharing.

    If an overflow bit node is currently set, it is cleared while the
    children are visited and put back afterwards.  The node itself is
    returned.
    """
    outer_bit_node = self.overflow_bit_node
    if outer_bit_node is None:
        self.visitchildren(node)
    else:
        self.overflow_bit_node = None
        self.visitchildren(node)
        self.overflow_bit_node = outer_bit_node
    return node
3722 def visit_NumBinopNode(self
, node
):
3723 if node
.overflow_check
and node
.overflow_fold
:
3724 top_level_overflow
= self
.overflow_bit_node
is None
3725 if top_level_overflow
:
3726 self
.overflow_bit_node
= node
3728 node
.overflow_bit_node
= self
.overflow_bit_node
3729 node
.overflow_check
= False
3730 self
.visitchildren(node
)
3731 if top_level_overflow
:
3732 self
.overflow_bit_node
= None
3734 self
.visitchildren(node
)