Pin Chrome's shortcut to the Win10 Start menu on install and OS upgrade.
[chromium-blink-merge.git] / third_party / cython / src / Cython / Compiler / Optimize.py
blobcc5a8c94885273a38ea6e1e8ce9c1335abc8434e
1 from Cython.Compiler import TypeSlots
2 from Cython.Compiler.ExprNodes import not_a_constant
3 import cython
4 cython.declare(UtilityCode=object, EncodedString=object, BytesLiteral=object,
5 Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
6 UtilNodes=object, Naming=object)
8 import Nodes
9 import ExprNodes
10 import PyrexTypes
11 import Visitor
12 import Builtin
13 import UtilNodes
14 import Options
15 import Naming
17 from Code import UtilityCode
18 from StringEncoding import EncodedString, BytesLiteral
19 from Errors import error
20 from ParseTreeTransforms import SkipDeclarations
22 import copy
23 import codecs
25 try:
26 from __builtin__ import reduce
27 except ImportError:
28 from functools import reduce
30 try:
31 from __builtin__ import basestring
32 except ImportError:
33 basestring = str # Python 3
def load_c_utility(name):
    """Return the cached utility code called ``name`` from "Optimize.c".

    Thin wrapper around ``UtilityCode.load_cached`` that fixes the source
    file name to "Optimize.c".
    """
    return UtilityCode.load_cached(name, "Optimize.c")
def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
    """Strip a single coercion wrapper from ``node``, if there is one.

    If ``node`` is an instance of one of ``coercion_nodes``, its ``arg`` is
    returned; otherwise ``node`` is returned unchanged.  Only one level of
    wrapping is removed.
    """
    if isinstance(node, coercion_nodes):
        return node.arg
    return node
def unwrap_node(node):
    """Follow ``ResultRefNode`` wrappers down to the wrapped expression.

    Repeatedly replaces ``node`` by ``node.expression`` for as long as it is
    a ``UtilNodes.ResultRefNode`` and returns the first node that is not.
    """
    while isinstance(node, UtilNodes.ResultRefNode):
        node = node.expression
    return node
def is_common_value(a, b):
    """Tell whether two expression nodes refer to the same name/attribute.

    Both nodes are first unwrapped with ``unwrap_node()``.  They match when

    - both are ``NameNode`` instances with equal ``name``, or
    - both are ``AttributeNode`` instances, ``a`` is not a Python attribute
      (``a.is_py_attr`` is false), their ``obj`` nodes match recursively and
      their ``attribute`` names are equal.

    Any other combination yields False.
    """
    a = unwrap_node(a)
    b = unwrap_node(b)
    if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
        return a.name == b.name
    if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
        return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
    return False
def filter_none_node(node):
    """Map a node whose ``constant_result`` is None to plain ``None``.

    Returns ``None`` when ``node`` is ``None`` or when its
    ``constant_result`` attribute is ``None``; otherwise returns ``node``
    itself.
    """
    if node is None or node.constant_result is None:
        return None
    return node
62 class IterationTransform(Visitor.EnvTransform):
63 """Transform some common for-in loop patterns into efficient C loops:
65 - for-in-dict loop becomes a while loop calling PyDict_Next()
66 - for-in-enumerate is replaced by an external counter variable
67 - for-in-range loop becomes a plain C for loop
68 """
def visit_PrimaryCmpNode(self, node):
    """Expand a pointer containment test into an explicit search loop.

    Comparisons for which ``node.is_ptr_contains()`` is false are only
    traversed (``visitchildren``) and returned unchanged.  Otherwise the
    comparison is replaced by the equivalent of::

        for t in operand2:
            if operand1 == t:
                res = True
                break
        else:
            res = False

    The generated loop is analysed and then visited again by this transform
    so that it can be optimised further.  For the 'not_in' operator the
    result is wrapped in a ``NotNode``.
    """
    if not node.is_ptr_contains():
        self.visitchildren(node)
        return node

    pos = node.pos
    result_ref = UtilNodes.ResultRefNode(node)

    # The loop variable gets the item type of the searched operand.
    if isinstance(node.operand2, ExprNodes.IndexNode):
        base_type = node.operand2.base.type.base_type
    else:
        base_type = node.operand2.type.base_type
    target_handle = UtilNodes.TempHandle(base_type)
    target = target_handle.ref(pos)

    cmp_node = ExprNodes.PrimaryCmpNode(
        pos, operator=u'==', operand1=node.operand1, operand2=target)

    # "res = True; break"
    if_body = Nodes.StatListNode(
        pos,
        stats=[Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
               Nodes.BreakStatNode(pos)])
    if_node = Nodes.IfStatNode(
        pos,
        if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
        else_clause=None)

    # The for-else clause assigns "res = False" when nothing matched.
    for_loop = UtilNodes.TempsBlockNode(
        pos,
        temps=[target_handle],
        body=Nodes.ForInStatNode(
            pos,
            target=target,
            iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
            body=if_node,
            else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
    for_loop = for_loop.analyse_expressions(self.current_env())
    for_loop = self.visit(for_loop)
    new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)

    if node.operator == 'not_in':
        new_node = ExprNodes.NotNode(pos, operand=new_node)
    return new_node
def visit_ForInStatNode(self, node):
    """Visit the children of a for-in loop, then try to optimise the loop.

    Returns whatever ``_optimise_for_loop()`` produces for the loop's
    iterated sequence (``node.iterator.sequence``).
    """
    self.visitchildren(node)
    return self._optimise_for_loop(node, node.iterator.sequence)
def _optimise_for_loop(self, node, iterator, reversed=False):
    """Dispatch a for-in loop to the matching specialised transformation.

    ``node`` is the loop statement, ``iterator`` the expression being
    iterated over, and ``reversed`` is true when the iteration was wrapped
    in a call to the builtin ``reversed()``.

    The checks are tried in this order: dict-typed iterable, C pointer or
    array, bytes, unicode, ``.keys()/.values()/.items()`` style method
    calls (and their ``iter*`` variants), the ``enumerate()`` and
    ``reversed()`` builtins, and finally ``range()``/``xrange()`` with an
    integer loop target.  If nothing applies, ``node`` is returned
    unchanged.
    """
    if iterator.type is Builtin.dict_type:
        # like iterating over dict.keys()
        if reversed:
            # CPython raises an error here: not a sequence
            return node
        return self._transform_dict_iteration(
            node, dict_obj=iterator, method=None, keys=True, values=False)

    # C array (slice) iteration?
    if iterator.type.is_ptr or iterator.type.is_array:
        return self._transform_carray_iteration(node, iterator, reversed=reversed)
    if iterator.type is Builtin.bytes_type:
        return self._transform_bytes_iteration(node, iterator, reversed=reversed)
    if iterator.type is Builtin.unicode_type:
        return self._transform_unicode_iteration(node, iterator, reversed=reversed)

    # the rest is based on function calls
    if not isinstance(iterator, ExprNodes.SimpleCallNode):
        return node

    if iterator.args is None:
        arg_count = len(iterator.arg_tuple.args) if iterator.arg_tuple else 0
    else:
        arg_count = len(iterator.args)
        if arg_count and iterator.self is not None:
            # do not count the bound 'self' argument
            arg_count -= 1

    function = iterator.function
    # dict iteration?
    if function.is_attribute and not reversed and not arg_count:
        base_obj = iterator.self or function.obj
        method = function.attribute
        # in Py3, items() is equivalent to Py2's iteritems()
        is_safe_iter = self.global_scope().context.language_level >= 3

        if not is_safe_iter and method in ('keys', 'values', 'items'):
            # try to reduce this to the corresponding .iter*() methods
            if isinstance(base_obj, ExprNodes.SimpleCallNode):
                inner_function = base_obj.function
                if (inner_function.is_name and inner_function.name == 'dict'
                        and inner_function.entry
                        and inner_function.entry.is_builtin):
                    # e.g. dict(something).items() => safe to use .iter*()
                    is_safe_iter = True

        keys = values = False
        if method == 'iterkeys' or (is_safe_iter and method == 'keys'):
            keys = True
        elif method == 'itervalues' or (is_safe_iter and method == 'values'):
            values = True
        elif method == 'iteritems' or (is_safe_iter and method == 'items'):
            keys = values = True

        if keys or values:
            return self._transform_dict_iteration(
                node, base_obj, method, keys, values)

    # enumerate/reversed ?
    if (iterator.self is None and function.is_name
            and function.entry and function.entry.is_builtin):
        if function.name == 'enumerate':
            if reversed:
                # CPython raises an error here: not a sequence
                return node
            return self._transform_enumerate_iteration(node, iterator)
        elif function.name == 'reversed':
            if reversed:
                # CPython raises an error here: not a sequence
                return node
            return self._transform_reversed_iteration(node, iterator)

    # range() iteration?
    if Options.convert_range and node.target.type.is_int:
        if (iterator.self is None and function.is_name
                and function.entry and function.entry.is_builtin
                and function.name in ('range', 'xrange')):
            return self._transform_range_iteration(node, iterator, reversed=reversed)

    return node
def _transform_reversed_iteration(self, node, reversed_function):
    """Optimise ``for ... in reversed(x)``.

    Reports a compile error and returns ``node`` unchanged unless
    ``reversed()`` received exactly one argument.  For list/tuple typed
    arguments, the loop iterates over the None-checked argument directly
    with the iterator flagged as reversed.  Anything else is passed back to
    ``_optimise_for_loop()`` with ``reversed=True``.
    """
    args = reversed_function.arg_tuple.args
    if len(args) == 0:
        error(reversed_function.pos,
              "reversed() requires an iterable argument")
        return node
    elif len(args) > 1:
        error(reversed_function.pos,
              "reversed() takes exactly 1 argument")
        return node
    arg = args[0]

    # reversed(list/tuple) ?
    if arg.type in (Builtin.tuple_type, Builtin.list_type):
        node.iterator.sequence = arg.as_none_safe_node("'NoneType' object is not iterable")
        node.iterator.reversed = True
        return node

    return self._optimise_for_loop(node, arg, reversed=True)
# C signature used for calls to PyBytes_AS_STRING(): takes a bytes object
# argument "s", returns a char pointer.
PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_char_ptr_type, [
        PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
        ])

# C signature used for calls to PyBytes_GET_SIZE(): takes a bytes object
# argument "s", returns a Py_ssize_t.
PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_py_ssize_t_type, [
        PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
        ])
def _transform_bytes_iteration(self, node, slice_node, reversed=False):
    """Iterate over a bytes object as a C char array slice.

    Only applies when the loop target is a C integer or typed as bytes;
    otherwise ``node`` is returned unchanged.  The bytes object is
    None-checked and held in a temporary, then the loop is rewritten as a
    C array iteration over ``PyBytes_AS_STRING(s)[:PyBytes_GET_SIZE(s)]``.
    """
    target_type = node.target.type
    if not target_type.is_int and target_type is not Builtin.bytes_type:
        # bytes iteration returns bytes objects in Py2, but
        # integers in Py3
        return node

    unpack_temp_node = UtilNodes.LetRefNode(
        slice_node.as_none_safe_node("'NoneType' is not iterable"))

    slice_base_node = ExprNodes.PythonCapiCallNode(
        slice_node.pos, "PyBytes_AS_STRING",
        self.PyBytes_AS_STRING_func_type,
        args=[unpack_temp_node],
        is_temp=0,
        )
    len_node = ExprNodes.PythonCapiCallNode(
        slice_node.pos, "PyBytes_GET_SIZE",
        self.PyBytes_GET_SIZE_func_type,
        args=[unpack_temp_node],
        is_temp=0,
        )

    return UtilNodes.LetNode(
        unpack_temp_node,
        self._transform_carray_iteration(
            node,
            ExprNodes.SliceIndexNode(
                slice_node.pos,
                base=slice_base_node,
                start=None,
                step=None,
                stop=len_node,
                type=slice_base_node.type,
                is_temp=1,
                ),
            reversed=reversed))
# C signature used for calls to __Pyx_PyUnicode_READ(kind, data, index),
# which returns a Py_UCS4 value.
PyUnicode_READ_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_py_ucs4_type, [
        PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_type, None),
        PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_type, None),
        PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None)
        ])

# C signature used for calls to __Pyx_init_unicode_iteration(), which
# receives the string plus pointers for length, data and kind.  It returns
# a C int; the exception value is declared as '-1'.
init_unicode_iteration_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_int_type, [
        PyrexTypes.CFuncTypeArg("s", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("length", PyrexTypes.c_py_ssize_t_ptr_type, None),
        PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_ptr_type, None),
        PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_ptr_type, None)
        ],
    exception_value='-1')
def _transform_unicode_iteration(self, node, slice_node, reversed=False):
    """Turn iteration over a unicode string into an index based C loop.

    String literals that can be encoded as Latin-1 are iterated as a C
    unsigned char array instead.  All other strings are None-checked, set
    up through ``__Pyx_init_unicode_iteration()`` (which fills the length,
    data and kind temporaries) and read character by character with
    ``__Pyx_PyUnicode_READ()`` inside a ``ForFromStatNode``.  With
    ``reversed`` set, the loop bounds and relations are flipped.
    """
    if slice_node.is_literal:
        # try to reduce to byte iteration for plain Latin-1 strings
        try:
            bytes_value = BytesLiteral(slice_node.value.encode('latin1'))
        except UnicodeEncodeError:
            pass
        else:
            bytes_slice = ExprNodes.SliceIndexNode(
                slice_node.pos,
                base=ExprNodes.BytesNode(
                    slice_node.pos, value=bytes_value,
                    constant_result=bytes_value,
                    type=PyrexTypes.c_char_ptr_type).coerce_to(
                        PyrexTypes.c_uchar_ptr_type, self.current_env()),
                start=None,
                stop=ExprNodes.IntNode(
                    slice_node.pos, value=str(len(bytes_value)),
                    constant_result=len(bytes_value),
                    type=PyrexTypes.c_py_ssize_t_type),
                type=Builtin.unicode_type,  # hint for Python conversion
                )
            return self._transform_carray_iteration(node, bytes_slice, reversed)

    unpack_temp_node = UtilNodes.LetRefNode(
        slice_node.as_none_safe_node("'NoneType' is not iterable"))

    start_node = ExprNodes.IntNode(
        node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t_type)
    length_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
    end_node = length_temp.ref(node.pos)
    if reversed:
        relation1, relation2 = '>', '>='
        start_node, end_node = end_node, start_node
    else:
        relation1, relation2 = '<=', '<'

    kind_temp = UtilNodes.TempHandle(PyrexTypes.c_int_type)
    data_temp = UtilNodes.TempHandle(PyrexTypes.c_void_ptr_type)
    counter_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)

    # the character at the current index, coerced to the target type
    target_value = ExprNodes.PythonCapiCallNode(
        slice_node.pos, "__Pyx_PyUnicode_READ",
        self.PyUnicode_READ_func_type,
        args=[kind_temp.ref(slice_node.pos),
              data_temp.ref(slice_node.pos),
              counter_temp.ref(node.target.pos)],
        is_temp=False,
        )
    if target_value.type != node.target.type:
        target_value = target_value.coerce_to(node.target.type,
                                              self.current_env())
    target_assign = Nodes.SingleAssignmentNode(
        pos=node.target.pos,
        lhs=node.target,
        rhs=target_value)
    body = Nodes.StatListNode(
        node.pos,
        stats=[target_assign, node.body])

    loop_node = Nodes.ForFromStatNode(
        node.pos,
        bound1=start_node, relation1=relation1,
        target=counter_temp.ref(node.target.pos),
        relation2=relation2, bound2=end_node,
        step=None, body=body,
        else_clause=node.else_clause,
        from_range=True)

    setup_node = Nodes.ExprStatNode(
        node.pos,
        expr=ExprNodes.PythonCapiCallNode(
            slice_node.pos, "__Pyx_init_unicode_iteration",
            self.init_unicode_iteration_func_type,
            args=[unpack_temp_node,
                  ExprNodes.AmpersandNode(slice_node.pos, operand=length_temp.ref(slice_node.pos),
                                          type=PyrexTypes.c_py_ssize_t_ptr_type),
                  ExprNodes.AmpersandNode(slice_node.pos, operand=data_temp.ref(slice_node.pos),
                                          type=PyrexTypes.c_void_ptr_ptr_type),
                  ExprNodes.AmpersandNode(slice_node.pos, operand=kind_temp.ref(slice_node.pos),
                                          type=PyrexTypes.c_int_ptr_type),
                  ],
            is_temp=True,
            result_is_used=False,
            utility_code=UtilityCode.load_cached("unicode_iter", "Optimize.c"),
            ))
    return UtilNodes.LetNode(
        unpack_temp_node,
        UtilNodes.TempsBlockNode(
            node.pos, temps=[counter_temp, length_temp, data_temp, kind_temp],
            body=Nodes.StatListNode(node.pos, stats=[setup_node, loop_node])))
def _transform_carray_iteration(self, node, slice_node, reversed=False):
    """Rewrite iteration over a C array or pointer slice as a pointer loop.

    ``slice_node`` may be a ``SliceIndexNode`` (``a[start:stop]``), an
    ``IndexNode`` whose index is a ``SliceNode`` (``a[start:stop:step]``)
    or a plain C array with a known size.  When the bounds or step cannot
    be determined, ``node`` is returned unchanged; a compile error is
    reported first unless the base is a Python object.

    The resulting loop is a ``ForFromStatNode`` over a temporary pointer
    that runs from ``base + start`` to ``base + stop``.  Each iteration
    assigns the current item (or, for pointer targets, the pointer itself)
    to the original loop target.
    """
    try:
        integer_types = (int, long)
    except NameError:
        # Python 3 has no separate 'long' type
        integer_types = (int,)

    neg_step = False
    if isinstance(slice_node, ExprNodes.SliceIndexNode):
        slice_base = slice_node.base
        start = filter_none_node(slice_node.start)
        stop = filter_none_node(slice_node.stop)
        step = None
        if not stop:
            if not slice_base.type.is_pyobject:
                error(slice_node.pos, "C array iteration requires known end index")
            return node

    elif isinstance(slice_node, ExprNodes.IndexNode):
        assert isinstance(slice_node.index, ExprNodes.SliceNode)
        slice_base = slice_node.base
        index = slice_node.index
        start = filter_none_node(index.start)
        stop = filter_none_node(index.stop)
        step = filter_none_node(index.step)
        if step:
            if (not isinstance(step.constant_result, integer_types)
                    or step.constant_result == 0
                    or step.constant_result > 0 and not stop
                    or step.constant_result < 0 and not start):
                if not slice_base.type.is_pyobject:
                    error(step.pos, "C array iteration requires known step size and end index")
                return node
            else:
                # step sign is handled internally by ForFromStatNode
                step_value = step.constant_result
                if reversed:
                    step_value = -step_value
                neg_step = step_value < 0
                step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
                                         value=str(abs(step_value)),
                                         constant_result=abs(step_value))

    elif slice_node.type.is_array:
        if slice_node.type.size is None:
            error(slice_node.pos, "C array iteration requires known end index")
            return node
        slice_base = slice_node
        start = None
        stop = ExprNodes.IntNode(
            slice_node.pos, value=str(slice_node.type.size),
            type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
        step = None

    else:
        if not slice_node.type.is_pyobject:
            error(slice_node.pos, "C array iteration requires known end index")
        return node

    if start:
        start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
    if stop:
        stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
    if stop is None:
        if neg_step:
            stop = ExprNodes.IntNode(
                slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
        else:
            error(slice_node.pos, "C array iteration requires known step size and end index")
            return node

    if reversed:
        if not start:
            start = ExprNodes.IntNode(slice_node.pos, value="0", constant_result=0,
                                      type=PyrexTypes.c_py_ssize_t_type)
        # if step was provided, it was already negated above
        start, stop = stop, start

    ptr_type = slice_base.type
    if ptr_type.is_array:
        ptr_type = ptr_type.element_ptr_type()
    carray_ptr = slice_base.coerce_to_simple(self.current_env())

    # first pointer value of the loop: base (+ start)
    if start and start.constant_result != 0:
        start_ptr_node = ExprNodes.AddNode(
            start.pos,
            operand1=carray_ptr,
            operator='+',
            operand2=start,
            type=ptr_type)
    else:
        start_ptr_node = carray_ptr

    # end pointer value of the loop: base (+ stop)
    if stop and stop.constant_result != 0:
        stop_ptr_node = ExprNodes.AddNode(
            stop.pos,
            operand1=ExprNodes.CloneNode(carray_ptr),
            operator='+',
            operand2=stop,
            type=ptr_type
            ).coerce_to_simple(self.current_env())
    else:
        stop_ptr_node = ExprNodes.CloneNode(carray_ptr)

    counter = UtilNodes.TempHandle(ptr_type)
    counter_temp = counter.ref(node.target.pos)

    if slice_base.type.is_string and node.target.type.is_pyobject:
        # special case: char* -> bytes/unicode
        if slice_node.type is Builtin.unicode_type:
            target_value = ExprNodes.CastNode(
                ExprNodes.DereferenceNode(
                    node.target.pos, operand=counter_temp,
                    type=ptr_type.base_type),
                PyrexTypes.c_py_ucs4_type).coerce_to(
                    node.target.type, self.current_env())
        else:
            # char* -> bytes coercion requires slicing, not indexing
            target_value = ExprNodes.SliceIndexNode(
                node.target.pos,
                start=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                stop=ExprNodes.IntNode(node.target.pos, value='1',
                                       constant_result=1,
                                       type=PyrexTypes.c_int_type),
                base=counter_temp,
                type=Builtin.bytes_type,
                is_temp=1)
    elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type):
        # Allow iteration with pointer target to avoid copy.
        target_value = counter_temp
    else:
        # TODO: can this safely be replaced with DereferenceNode() as above?
        target_value = ExprNodes.IndexNode(
            node.target.pos,
            index=ExprNodes.IntNode(node.target.pos, value='0',
                                    constant_result=0,
                                    type=PyrexTypes.c_int_type),
            base=counter_temp,
            is_buffer_access=False,
            type=ptr_type.base_type)

    if target_value.type != node.target.type:
        target_value = target_value.coerce_to(node.target.type,
                                              self.current_env())

    target_assign = Nodes.SingleAssignmentNode(
        pos=node.target.pos,
        lhs=node.target,
        rhs=target_value)

    body = Nodes.StatListNode(
        node.pos,
        stats=[target_assign, node.body])

    relation1, relation2 = self._find_for_from_node_relations(neg_step, reversed)

    for_node = Nodes.ForFromStatNode(
        node.pos,
        bound1=start_ptr_node, relation1=relation1,
        target=counter_temp,
        relation2=relation2, bound2=stop_ptr_node,
        step=step, body=body,
        else_clause=node.else_clause,
        from_range=True)

    return UtilNodes.TempsBlockNode(
        node.pos, temps=[counter],
        body=for_node)
def _transform_enumerate_iteration(self, node, enumerate_function):
    """Replace ``for i, x in enumerate(seq[, start])`` by an external counter.

    Reports a compile error for zero or more than two arguments.  The loop
    is left untouched when its target is not a two-item sequence or when
    the counter target is neither a Python object nor a C integer.

    Otherwise the counter lives in a temporary: two statements are put at
    the top of the loop body (assign the counter to its target, then
    increment the temporary), the loop itself iterates over the plain
    sequence, and the result is handed to ``_optimise_for_loop()`` again
    for further optimisation.
    """
    args = enumerate_function.arg_tuple.args
    if len(args) == 0:
        error(enumerate_function.pos,
              "enumerate() requires an iterable argument")
        return node
    elif len(args) > 2:
        error(enumerate_function.pos,
              "enumerate() takes at most 2 arguments")
        return node

    if not node.target.is_sequence_constructor:
        # leave this untouched for now
        return node
    targets = node.target.args
    if len(targets) != 2:
        # leave this untouched for now
        return node

    enumerate_target, iterable_target = targets
    counter_type = enumerate_target.type

    if not counter_type.is_pyobject and not counter_type.is_int:
        # nothing we can do here, I guess
        return node

    if len(args) == 2:
        start = unwrap_coerced_node(args[1]).coerce_to(counter_type, self.current_env())
    else:
        start = ExprNodes.IntNode(enumerate_function.pos,
                                  value='0',
                                  type=counter_type,
                                  constant_result=0)
    temp = UtilNodes.LetRefNode(start)

    inc_expression = ExprNodes.AddNode(
        enumerate_function.pos,
        operand1=temp,
        operand2=ExprNodes.IntNode(node.pos, value='1',
                                   type=counter_type,
                                   constant_result=1),
        operator='+',
        type=counter_type,
        # NOTE: 'inplace=True' is not worth using for Py ints
        is_temp=counter_type.is_pyobject
        )

    loop_body = [
        Nodes.SingleAssignmentNode(
            pos=enumerate_target.pos,
            lhs=enumerate_target,
            rhs=temp),
        Nodes.SingleAssignmentNode(
            pos=enumerate_target.pos,
            lhs=temp,
            rhs=inc_expression)
        ]

    if isinstance(node.body, Nodes.StatListNode):
        node.body.stats = loop_body + node.body.stats
    else:
        loop_body.append(node.body)
        node.body = Nodes.StatListNode(
            node.body.pos,
            stats=loop_body)

    node.target = iterable_target
    node.item = node.item.coerce_to(iterable_target.type, self.current_env())
    node.iterator.sequence = args[0]

    # recurse into loop to check for further optimisations
    return UtilNodes.LetNode(temp, self._optimise_for_loop(node, node.iterator.sequence))
def _find_for_from_node_relations(self, neg_step_value, reversed):
    """Return the ``(relation1, relation2)`` operator pair for a for-from loop.

    The pair depends on whether the step is negative and on whether the
    iteration is reversed:

    ========  ========  ==============
    reversed  negative  result
    ========  ========  ==============
    yes       yes       ('<', '<=')
    yes       no        ('>', '>=')
    no        yes       ('>=', '>')
    no        no        ('<=', '<')
    ========  ========  ==============
    """
    if reversed:
        return ('<', '<=') if neg_step_value else ('>', '>=')
    return ('>=', '>') if neg_step_value else ('<=', '<')
def _transform_range_iteration(self, node, range_function, reversed=False):
    """Turn ``for i in range(...)`` into a ``ForFromStatNode``.

    The loop is returned unchanged when an explicit step is not an integer
    constant, is zero, or, for reversed iteration, is anything other than
    1 or -1.  Otherwise the range bounds are coerced to integers, swapped
    for reversed iteration, and the step is stored as its absolute value;
    the direction is expressed through the loop relations.  A non-literal
    stop bound is evaluated once into a temporary.
    """
    try:
        integer_types = (int, long)
    except NameError:
        # Python 3 has no separate 'long' type
        integer_types = (int,)

    args = range_function.arg_tuple.args
    if len(args) < 3:
        step_pos = range_function.pos
        step_value = 1
        step = ExprNodes.IntNode(step_pos, value='1',
                                 constant_result=1)
    else:
        step = args[2]
        step_pos = step.pos
        if not isinstance(step.constant_result, integer_types):
            # cannot determine step direction
            return node
        step_value = step.constant_result
        if step_value == 0:
            # will lead to an error elsewhere
            return node
        if reversed and step_value not in (1, -1):
            # FIXME: currently broken - requires calculation of the correct bounds
            return node
        if not isinstance(step, ExprNodes.IntNode):
            step = ExprNodes.IntNode(step_pos, value=str(step_value),
                                     constant_result=step_value)

    if len(args) == 1:
        bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                   constant_result=0)
        bound2 = args[0].coerce_to_integer(self.current_env())
    else:
        bound1 = args[0].coerce_to_integer(self.current_env())
        bound2 = args[1].coerce_to_integer(self.current_env())

    # the relations must be chosen from the signed step, before abs()
    relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed)

    if reversed:
        bound1, bound2 = bound2, bound1
    step_value = abs(step_value)

    step.value = str(step_value)
    step.constant_result = step_value
    step = step.coerce_to_integer(self.current_env())

    # stop bound must be immutable => keep it in a temp var
    bound2_is_temp = not bound2.is_literal
    if bound2_is_temp:
        bound2 = UtilNodes.LetRefNode(bound2)

    for_node = Nodes.ForFromStatNode(
        node.pos,
        target=node.target,
        bound1=bound1, relation1=relation1,
        relation2=relation2, bound2=bound2,
        step=step, body=node.body,
        else_clause=node.else_clause,
        from_range=True)

    if bound2_is_temp:
        for_node = UtilNodes.LetNode(bound2, for_node)

    return for_node
def _transform_dict_iteration(self, node, dict_obj, method, keys, values):
    """Rewrite dict iteration as a while loop around a dict-next node.

    ``dict_obj`` is the iterated object and ``method`` the name of the
    called iteration method, or None when the object is iterated directly.
    ``keys`` and ``values`` select what each step assigns: keys, values, or
    both (either unpacked into two targets or as a single tuple target).
    A sequence target that does not have exactly two items leaves the loop
    unchanged.

    The generated code resets a position temporary to 0, calls
    ``__Pyx_dict_iterator()`` once, then runs a ``WhileStatNode`` whose
    body starts with a ``DictIterationNextNode``.
    """
    temps = []
    temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
    temps.append(temp)
    dict_temp = temp.ref(dict_obj.pos)
    temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
    temps.append(temp)
    pos_temp = temp.ref(node.pos)

    key_target = value_target = tuple_target = None
    if keys and values:
        if node.target.is_sequence_constructor:
            if len(node.target.args) == 2:
                key_target, value_target = node.target.args
            else:
                # unusual case that may or may not lead to an error
                return node
        else:
            tuple_target = node.target
    elif keys:
        key_target = node.target
    else:
        value_target = node.target

    if isinstance(node.body, Nodes.StatListNode):
        body = node.body
    else:
        body = Nodes.StatListNode(pos=node.body.pos,
                                  stats=[node.body])

    # keep original length to guard against dict modification
    dict_len_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
    temps.append(dict_len_temp)
    dict_len_temp_addr = ExprNodes.AmpersandNode(
        node.pos, operand=dict_len_temp.ref(dict_obj.pos),
        type=PyrexTypes.c_ptr_type(dict_len_temp.type))
    temp = UtilNodes.TempHandle(PyrexTypes.c_int_type)
    temps.append(temp)
    is_dict_temp = temp.ref(node.pos)
    is_dict_temp_addr = ExprNodes.AmpersandNode(
        node.pos, operand=is_dict_temp,
        type=PyrexTypes.c_ptr_type(temp.type))

    iter_next_node = Nodes.DictIterationNextNode(
        dict_temp, dict_len_temp.ref(dict_obj.pos), pos_temp,
        key_target, value_target, tuple_target,
        is_dict_temp)
    iter_next_node = iter_next_node.analyse_expressions(self.current_env())
    # the "next item" step runs first in every loop iteration
    body.stats[0:0] = [iter_next_node]

    if method:
        method_node = ExprNodes.StringNode(
            dict_obj.pos, is_identifier=True, value=method)
        dict_obj = dict_obj.as_none_safe_node(
            "'NoneType' object has no attribute '%s'",
            error="PyExc_AttributeError",
            format_args=[method])
    else:
        method_node = ExprNodes.NullNode(dict_obj.pos)
        dict_obj = dict_obj.as_none_safe_node("'NoneType' object is not iterable")

    def flag_node(value):
        # turn a truth value into a 0/1 integer literal node
        value = 1 if value else 0
        return ExprNodes.IntNode(node.pos, value=str(value), constant_result=value)

    result_code = [
        Nodes.SingleAssignmentNode(
            node.pos,
            lhs=pos_temp,
            rhs=ExprNodes.IntNode(node.pos, value='0',
                                  constant_result=0)),
        Nodes.SingleAssignmentNode(
            dict_obj.pos,
            lhs=dict_temp,
            rhs=ExprNodes.PythonCapiCallNode(
                dict_obj.pos,
                "__Pyx_dict_iterator",
                self.PyDict_Iterator_func_type,
                utility_code=UtilityCode.load_cached("dict_iter", "Optimize.c"),
                args=[dict_obj, flag_node(dict_obj.type is Builtin.dict_type),
                      method_node, dict_len_temp_addr, is_dict_temp_addr,
                      ],
                is_temp=True,
                )),
        Nodes.WhileStatNode(
            node.pos,
            condition=None,
            body=body,
            else_clause=node.else_clause
            )
        ]

    return UtilNodes.TempsBlockNode(
        node.pos, temps=temps,
        body=Nodes.StatListNode(
            node.pos,
            stats=result_code
            ))
# C signature used for calls to __Pyx_dict_iterator(): receives the
# iterated object, an "is_dict" flag, the method name and two output
# pointers (original length, is-dict flag); returns a Python object.
PyDict_Iterator_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("is_dict", PyrexTypes.c_int_type, None),
        PyrexTypes.CFuncTypeArg("method_name", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("p_orig_length", PyrexTypes.c_py_ssize_t_ptr_type, None),
        PyrexTypes.CFuncTypeArg("p_is_dict", PyrexTypes.c_int_ptr_type, None),
        ])
805 class SwitchTransform(Visitor.VisitorTransform):
807 This transformation tries to turn long if statements into C switch statements.
808 The requirement is that every clause be an (or of) var == value, where the var
809 is common among all clauses and both var and value are ints.
# Result triple (not_in, variable, conditions) returned by the extract_*
# methods when a condition cannot be turned into switch cases.
NO_MATCH = (None, None, None)
def extract_conditions(self, cond, allow_not_in):
    """Split a condition into ``(not_in, variable, [values])``.

    Temp/boolean coercions, ``EvalWithTempExprNode`` wrappers and typecasts
    around ``cond`` are stripped first.  Three condition shapes are
    recognised:

    - ``x in "literal"`` / ``x not in "literal"`` C string containment
      tests against a bytes or unicode literal (no surrogates),
    - non-Python ``==`` comparisons (and ``!=`` if ``allow_not_in``) between
      a name/attribute and a literal or const entry, in either operand
      order,
    - ``or`` combinations of such conditions on a common variable (and
      ``and`` combinations of negated ones if ``allow_not_in``).

    Returns ``self.NO_MATCH`` for everything else.
    """
    while True:
        if isinstance(cond, (ExprNodes.CoerceToTempNode,
                             ExprNodes.CoerceToBooleanNode)):
            cond = cond.arg
        elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
            # this is what we get from the FlattenInListTransform
            cond = cond.subexpression
        elif isinstance(cond, ExprNodes.TypecastNode):
            cond = cond.operand
        else:
            break

    if isinstance(cond, ExprNodes.PrimaryCmpNode):
        if cond.cascade is not None:
            return self.NO_MATCH
        elif (cond.is_c_string_contains()
                and isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode))):
            not_in = cond.operator == 'not_in'
            if not_in and not allow_not_in:
                return self.NO_MATCH
            if (isinstance(cond.operand2, ExprNodes.UnicodeNode)
                    and cond.operand2.contains_surrogates()):
                # dealing with surrogates leads to different
                # behaviour on wide and narrow Unicode
                # platforms => refuse to optimise this case
                return self.NO_MATCH
            return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
        elif not cond.is_python_comparison():
            if cond.operator == '==':
                not_in = False
            elif allow_not_in and cond.operator == '!=':
                not_in = True
            else:
                return self.NO_MATCH
            # this looks somewhat silly, but it does the right
            # checks for NameNode and AttributeNode
            if is_common_value(cond.operand1, cond.operand1):
                if cond.operand2.is_literal:
                    return not_in, cond.operand1, [cond.operand2]
                elif (getattr(cond.operand2, 'entry', None)
                        and cond.operand2.entry.is_const):
                    return not_in, cond.operand1, [cond.operand2]
            if is_common_value(cond.operand2, cond.operand2):
                if cond.operand1.is_literal:
                    return not_in, cond.operand2, [cond.operand1]
                elif (getattr(cond.operand1, 'entry', None)
                        and cond.operand1.entry.is_const):
                    return not_in, cond.operand2, [cond.operand1]
    elif isinstance(cond, ExprNodes.BoolBinopNode):
        if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
            # negated tests may only be chained with 'and'
            allow_not_in = (cond.operator == 'and')
            not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
            not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
            if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
                if (not not_in_1) or allow_not_in:
                    return not_in_1, t1, c1+c2
    return self.NO_MATCH
def extract_in_string_conditions(self, string_literal):
    """Build one constant node per distinct character of a string literal.

    For a ``UnicodeNode`` the result is a list of ``IntNode`` objects
    holding the sorted, de-duplicated code points.  For any other literal
    the result is a list of ``CharNode`` objects holding the sorted,
    de-duplicated one-character substrings.
    """
    if isinstance(string_literal, ExprNodes.UnicodeNode):
        charvals = sorted(map(ord, set(string_literal.value)))
        return [ExprNodes.IntNode(string_literal.pos, value=str(charval),
                                  constant_result=charval)
                for charval in charvals]
    else:
        # this is a bit tricky as Py3's bytes type returns
        # integers on iteration, whereas Py2 returns 1-char byte
        # strings => slice instead of iterating
        characters = string_literal.value
        characters = sorted(set([characters[i:i+1] for i in range(len(characters))]))
        return [ExprNodes.CharNode(string_literal.pos, value=charval,
                                   constant_result=charval)
                for charval in characters]
def extract_common_conditions(self, common_var, condition, allow_not_in):
    """Extract switch conditions and validate them.

    Calls ``extract_conditions()`` and returns its
    ``(not_in, var, conditions)`` result only if

    - a variable was found,
    - it matches ``common_var`` (when ``common_var`` is not None), and
    - the variable and all condition values have an int or enum type.

    Otherwise ``self.NO_MATCH`` is returned.
    """
    not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
    if var is None:
        return self.NO_MATCH
    elif common_var is not None and not is_common_value(var, common_var):
        return self.NO_MATCH
    elif (not (var.type.is_int or var.type.is_enum)
            or any(not (cond.type.is_int or cond.type.is_enum) for cond in conditions)):
        return self.NO_MATCH
    return not_in, var, conditions
def has_duplicate_values(self, condition_values):
    """Return True if two condition values might be the same.

    Duplicated values don't work in a switch statement.  Values with a
    constant result are compared by that result; all others are compared
    by the ``cname`` of their entry.  A value that has neither counts as a
    potential duplicate.
    """
    seen = set()
    for value in condition_values:
        if value.has_constant_result():
            key = value.constant_result
        else:
            # this isn't completely safe as we don't know the
            # final C value, but this is about the best we can do
            try:
                key = value.entry.cname
            except AttributeError:
                return True  # play safe
        if key in seen:
            return True
        seen.add(key)
    return False
def visit_IfStatNode(self, node):
    """Replace an if/elif chain on one variable by a ``SwitchStatNode``.

    Every clause must yield switch conditions on the same variable, with
    ``not in`` style tests disallowed.  The chain is only traversed and
    returned unchanged if any clause fails to match, if fewer than two
    values are tested overall, or if values may be duplicated.
    """
    common_var = None
    cases = []
    for if_clause in node.if_clauses:
        _, common_var, conditions = self.extract_common_conditions(
            common_var, if_clause.condition, False)
        if common_var is None:
            self.visitchildren(node)
            return node
        cases.append(Nodes.SwitchCaseNode(pos=if_clause.pos,
                                          conditions=conditions,
                                          body=if_clause.body))

    condition_values = [
        cond for case in cases for cond in case.conditions]
    if len(condition_values) < 2 or self.has_duplicate_values(condition_values):
        self.visitchildren(node)
        return node

    common_var = unwrap_node(common_var)
    switch_node = Nodes.SwitchStatNode(pos=node.pos,
                                       test=common_var,
                                       cases=cases,
                                       else_clause=node.else_clause)
    return switch_node
def visit_CondExprNode(self, node):
    """Try to turn the test of a conditional expression into a switch.

    If the test yields at least two distinct switch values on a common
    variable, the expression is replaced by the result of
    ``build_simple_switch_statement()`` with the node's true and false
    values.  Otherwise the node is only traversed and returned unchanged.
    """
    not_in, common_var, conditions = self.extract_common_conditions(
        None, node.test, True)
    if (common_var is None
            or len(conditions) < 2
            or self.has_duplicate_values(conditions)):
        self.visitchildren(node)
        return node
    return self.build_simple_switch_statement(
        node, common_var, conditions, not_in,
        node.true_val, node.false_val)
def visit_BoolBinopNode(self, node):
    """Try to replace a boolean combination of comparisons by a switch.

    The node itself is used as the condition to analyse ('not in' tests
    are allowed).  When it provides at least two non-duplicate conditions
    on a common variable, a switch statement is built that evaluates to a
    true/false BoolNode.  Otherwise only the children are visited.
    """
    not_in, common_var, conditions = self.extract_common_conditions(
        None, node, True)
    switchable = (common_var is not None
                  and len(conditions) >= 2
                  and not self.has_duplicate_values(conditions))
    if not switchable:
        self.visitchildren(node)
        return node

    true_node = ExprNodes.BoolNode(node.pos, value=True, constant_result=True)
    false_node = ExprNodes.BoolNode(node.pos, value=False, constant_result=False)
    return self.build_simple_switch_statement(
        node, common_var, conditions, not_in, true_node, false_node)
def visit_PrimaryCmpNode(self, node):
    """Try to replace a comparison (e.g. an 'in' test) by a switch.

    The node itself is used as the condition to analyse ('not in' tests
    are allowed).  When it provides at least two non-duplicate conditions
    on a common variable, a switch statement is built that evaluates to a
    true/false BoolNode.  Otherwise only the children are visited.
    """
    not_in, common_var, conditions = self.extract_common_conditions(
        None, node, True)
    switchable = (common_var is not None
                  and len(conditions) >= 2
                  and not self.has_duplicate_values(conditions))
    if not switchable:
        self.visitchildren(node)
        return node

    true_node = ExprNodes.BoolNode(node.pos, value=True, constant_result=True)
    false_node = ExprNodes.BoolNode(node.pos, value=False, constant_result=False)
    return self.build_simple_switch_statement(
        node, common_var, conditions, not_in, true_node, false_node)
def build_simple_switch_statement(self, node, common_var, conditions,
                                  not_in, true_val, false_val):
    """Build a single-case switch statement that yields one of two values.

    A result reference is created for ``node``.  The switch on
    ``common_var`` assigns ``true_val`` to it in the case that lists all
    ``conditions`` and ``false_val`` in the else clause; for ``not_in``
    the two assignments are swapped.  The statement is wrapped into a
    ``TempResultFromStatNode`` so that it can be used as an expression.
    """
    result_ref = UtilNodes.ResultRefNode(node)

    def assign_result(value):
        return Nodes.SingleAssignmentNode(
            node.pos, lhs=result_ref, rhs=value, first=True)

    match_body = assign_result(true_val)
    default_body = assign_result(false_val)
    if not_in:
        # 'not in': a matching case must produce the false value
        match_body, default_body = default_body, match_body

    case = Nodes.SwitchCaseNode(
        pos=node.pos, conditions=conditions, body=match_body)
    switch_node = Nodes.SwitchStatNode(
        pos=node.pos,
        test=unwrap_node(common_var),
        cases=[case],
        else_clause=default_body)
    return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
def visit_EvalWithTempExprNode(self, node):
    """Strip the temp wrapper if its temp is no longer referenced.

    After visiting the children, the subexpression is returned in place
    of ``node`` if it was replaced and the new subexpression does not
    contain the lazy temp any more.  Otherwise ``node`` is returned.
    """
    # drop unused expression temp from FlattenInListTransform
    previous_expr = node.subexpression
    temp = node.lazy_temp
    self.visitchildren(node)
    if node.subexpression is previous_expr:
        return node
    # node was restructured => check if temp is still used
    if Visitor.tree_contains(node.subexpression, temp):
        return node
    return node.subexpression
# Generic handler: reuse the base transform's recurse_to_children()
# (presumably picked for node types without a dedicated visit_*() method
# -- confirm against the dispatch logic in Visitor.py).
visit_Node = Visitor.VisitorTransform.recurse_to_children
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        """Rewrite 'in' / 'not in' tests against literal containers.

        Only non-cascaded comparisons whose right operand is a non-empty
        tuple, list or set node are rewritten: 'in' becomes an 'or' chain
        of '==' comparisons, 'not_in' an 'and' chain of '!=' comparisons.
        Everything else is returned unchanged.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        if node.operator == 'in':
            conjunction, eq_or_neq = 'or', '=='
        elif node.operator == 'not_in':
            conjunction, eq_or_neq = 'and', '!='
        else:
            return node

        container = node.operand2
        literal_types = (ExprNodes.TupleNode,
                         ExprNodes.ListNode,
                         ExprNodes.SetNode)
        if not isinstance(container, literal_types):
            return node

        items = container.args
        if len(items) == 0:
            # note: lhs may have side effects
            return node

        lhs = UtilNodes.ResultRefNode(node.operand1)

        comparisons = []
        rhs_temps = []
        for item in items:
            try:
                # Trial optimisation to avoid redundant temp
                # assignments.  However, since is_simple() is meant to
                # be called after type analysis, we ignore any errors
                # and just play safe in that case.
                simple = item.is_simple()
            except Exception:
                simple = False
            if not simple:
                # must evaluate all non-simple RHS before doing the comparisons
                item = UtilNodes.LetRefNode(item)
                rhs_temps.append(item)
            comparison = ExprNodes.PrimaryCmpNode(
                pos=node.pos,
                operand1=lhs,
                operator=eq_or_neq,
                operand2=item,
                cascade=None)
            comparisons.append(ExprNodes.TypecastNode(
                pos=node.pos,
                operand=comparison,
                type=PyrexTypes.c_bint_type))

        def combine(left, right):
            return ExprNodes.BoolBinopNode(
                pos=node.pos,
                operator=conjunction,
                operand1=left,
                operand2=right)

        result = UtilNodes.EvalWithTempExprNode(lhs, reduce(combine, comparisons))
        # wrap the temps in reverse order, so the first one ends up outermost
        for temp in reversed(rhs_temps):
            result = UtilNodes.EvalWithTempExprNode(temp, result)
        return result

    visit_Node = Visitor.VisitorTransform.recurse_to_children
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Drop ref-counting in safe places.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """
        Parallel swap assignments like 'a,b = b,a' are safe.

        For those, ``use_managed_ref`` is switched off on the involved
        temps and name nodes.  The node itself is always returned.
        """
        lhs_names, rhs_names = [], []
        lhs_indices, rhs_indices = [], []
        temps = []

        for stat in node.stats:
            # only plain single assignments are handled
            # (FIXME: cascaded assignments are not supported)
            if not isinstance(stat, Nodes.SingleAssignmentNode):
                return node
            if not self._extract_operand(stat.lhs, lhs_names,
                                         lhs_indices, temps):
                return node
            if not self._extract_operand(stat.rhs, rhs_names,
                                         rhs_indices, temps):
                return node

        if lhs_names or rhs_names:
            # lhs/rhs names must be a non-redundant permutation
            lhs_paths = set([path for path, _ in lhs_names])
            rhs_paths = set([path for path, _ in rhs_names])
            if lhs_paths != rhs_paths:
                return node
            if len(lhs_paths) != len(rhs_names):
                return node

        if lhs_indices or rhs_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            def collect_ids(index_nodes):
                ids = []
                for index_node in index_nodes:
                    index_id = self._extract_index_id(index_node)
                    if not index_id:
                        return None
                    ids.append(index_id)
                return ids

            lhs_ids = collect_ids(lhs_indices)
            if lhs_ids is None:
                return node
            rhs_ids = collect_ids(rhs_indices)
            if rhs_ids is None:
                return node

            if set(lhs_ids) != set(rhs_ids):
                return node
            if len(set(lhs_ids)) != len(rhs_indices):
                return node

            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        temp_args = [temp.arg for temp in temps]
        for temp in temps:
            temp.use_managed_ref = False

        for _, name_node in lhs_names + rhs_names:
            if name_node not in temp_args:
                name_node.use_managed_ref = False

        for index_node in lhs_indices + rhs_indices:
            index_node.use_managed_ref = False

        return node

    def _extract_operand(self, node, names, indices, temps):
        """Classify one assignment operand; return False if unsupported.

        Python object operands that are names or chains of non-Python
        attributes on a name are appended to ``names`` as
        ``(dotted_path, node)``.  Integer-indexed items of a list-typed
        name are appended to ``indices``.  A ``CoerceToTempNode`` wrapper
        is recorded in ``temps`` and its argument is inspected instead.
        """
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg

        # walk down the attribute chain, collecting the path innermost last
        reversed_path = []
        obj_node = node
        while isinstance(obj_node, ExprNodes.AttributeNode):
            if obj_node.is_py_attr:
                return False
            reversed_path.append(obj_node.member)
            obj_node = obj_node.obj

        if isinstance(obj_node, ExprNodes.NameNode):
            reversed_path.append(obj_node.name)
            reversed_path.reverse()
            names.append(('.'.join(reversed_path), node))
            return True

        if not isinstance(node, ExprNodes.IndexNode):
            return False
        if node.base.type != Builtin.list_type:
            return False
        if not node.index.type.is_int:
            return False
        if not isinstance(node.base, ExprNodes.NameNode):
            return False
        indices.append(node)
        return True

    def _extract_index_id(self, index_node):
        """Return ``(base name, index name)`` for a name-indexed node.

        None is returned for any other kind of index.
        """
        base = index_node.base
        index = index_node.index
        if not isinstance(index, ExprNodes.NameNode):
            # FIXME: constant indices (ExprNodes.ConstNode) are not
            # supported either
            return None
        return (base.name, index.name)
class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
    """Optimize some common calls to builtin types *before* the type
    analysis phase and *after* the declarations analysis phase.

    This transform cannot make use of any argument types, but it can
    restructure the tree in a way that the type analysis phase can
    respond to.

    Introducing C function calls here may not be a good idea.  Move
    them to the OptimizeBuiltinCalls transform instead, which runs
    after type analysis.

    Handlers are looked up by name in _dispatch_to_handler():
    ``_handle_simple_function_<name>(node, args)`` when no keyword
    argument node is passed, ``_handle_general_function_<name>(node,
    args, kwargs)`` otherwise.
    """
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children
def visit_SimpleCallNode(self, node):
    """Dispatch simple calls of builtin names to their handler.

    Calls to anything else are returned unchanged (after their children
    were visited).
    """
    self.visitchildren(node)
    callee = node.function
    if self._function_is_builtin_name(callee):
        return self._dispatch_to_handler(node, callee, node.args)
    return node
def visit_GeneralCallNode(self, node):
    """Dispatch general calls of builtin names to their handler.

    The call is only dispatched if the function is a builtin name and the
    positional arguments are available as a literal tuple node.  The
    keyword argument node is passed on to the dispatcher.
    """
    self.visitchildren(node)
    callee = node.function
    if not self._function_is_builtin_name(callee):
        return node
    positional = node.positional_args
    if isinstance(positional, ExprNodes.TupleNode):
        return self._dispatch_to_handler(
            node, callee, positional.args, node.keyword_args)
    return node
1253 def _function_is_builtin_name(self, function):
1254 if not function.is_name:
1255 return False
1256 env = self.current_env()
1257 entry = env.lookup(function.name)
1258 if entry is not env.builtin_scope().lookup_here(function.name):
1259 return False
1260 # if entry is None, it's at least an undeclared name, so likely builtin
1261 return True
1263 def _dispatch_to_handler(self, node, function, args, kwargs=None):
1264 if kwargs is None:
1265 handler_name = '_handle_simple_function_%s' % function.name
1266 else:
1267 handler_name = '_handle_general_function_%s' % function.name
1268 handle_call = getattr(self, handler_name, None)
1269 if handle_call is not None:
1270 if kwargs is None:
1271 return handle_call(node, args)
1272 else:
1273 return handle_call(node, args, kwargs)
1274 return node
def _inject_capi_function(self, node, cname, func_type, utility_code=None):
    """Replace the function of a call node by a C-API function node.

    The new ``PythonCapiFunctionNode`` takes over the position and name
    of the old function node.
    """
    replaced = node.function
    node.function = ExprNodes.PythonCapiFunctionNode(
        replaced.pos, replaced.name, cname, func_type,
        utility_code=utility_code)
def _error_wrong_arg_count(self, function_name, node, args, expected=None):
    """Report a call with a wrong number of arguments at ``node.pos``.

    ``expected`` may be a number or a string that describes the expected
    argument count; it is only included in the message if it is not None.
    """
    arg_str = ''
    if expected:  # neither None nor 0
        if isinstance(expected, basestring) or expected > 1:
            arg_str = '...'
        elif expected == 1:
            arg_str = 'x'

    expected_str = ''
    if expected is not None:
        expected_str = 'expected %s, ' % expected

    error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
        function_name, arg_str, expected_str, len(args)))
1297 # specific handlers for simple call nodes
def _handle_simple_function_float(self, node, pos_args):
    """Simplify float() calls.

    ``float()`` becomes the literal 0.0.  A first argument that is
    already typed as C double or Python float replaces the call.  More
    than one argument is reported as an error.
    """
    if not pos_args:
        return ExprNodes.FloatNode(node.pos, value='0.0')
    if len(pos_args) > 1:
        self._error_wrong_arg_count('float', node, pos_args, 1)
    arg = pos_args[0]
    # the argument may not have a type at this point
    if getattr(arg, 'type', None) in (PyrexTypes.c_double_type, Builtin.float_type):
        return arg
    return node
class YieldNodeCollector(Visitor.TreeVisitor):
    """Tree visitor for collecting yield expressions.

    ``yield_nodes`` lists the yield expression nodes and
    ``yield_stat_nodes`` maps a yield expression to the ExprStatNode that
    holds it.
    """
    def __init__(self):
        Visitor.TreeVisitor.__init__(self)
        self.yield_stat_nodes = {}
        self.yield_nodes = []

    visit_Node = Visitor.TreeVisitor.visitchildren
    # XXX: disable inlining while it's not back supported
    # NOTE(review): the handlers below carry a double underscore prefix,
    # so Python mangles their names inside this class; presumably this
    # is what keeps the visitor from finding them -- confirm against the
    # dispatch logic in Visitor.py before renaming anything here.
    def __visit_YieldExprNode(self, node):
        self.yield_nodes.append(node)
        self.visitchildren(node)

    def __visit_ExprStatNode(self, node):
        self.visitchildren(node)
        # remember the statement that wraps a collected yield expression
        if node.expr in self.yield_nodes:
            self.yield_stat_nodes[node.expr] = node

    def __visit_GeneratorExpressionNode(self, node):
        # enable when we support generic generator expressions
        #
        # everything below this node is out of scope
        pass
1332 def _find_single_yield_expression(self, node):
1333 collector = self.YieldNodeCollector()
1334 collector.visitchildren(node)
1335 if len(collector.yield_nodes) != 1:
1336 return None, None
1337 yield_node = collector.yield_nodes[0]
1338 try:
1339 return (yield_node.arg, collector.yield_stat_nodes[yield_node])
1340 except KeyError:
1341 return None, None
1343 def _handle_simple_function_all(self, node, pos_args):
1344 """Transform
1346 _result = all(x for L in LL for x in L)
1348 into
1350 for L in LL:
1351 for x in L:
1352 if not x:
1353 _result = False
1354 break
1355 else:
1356 continue
1357 break
1358 else:
1359 _result = True
1361 return self._transform_any_all(node, pos_args, False)
1363 def _handle_simple_function_any(self, node, pos_args):
1364 """Transform
1366 _result = any(x for L in LL for x in L)
1368 into
1370 for L in LL:
1371 for x in L:
1372 if x:
1373 _result = True
1374 break
1375 else:
1376 continue
1377 break
1378 else:
1379 _result = False
1381 return self._transform_any_all(node, pos_args, True)
1383 def _transform_any_all(self, node, pos_args, is_any):
1384 if len(pos_args) != 1:
1385 return node
1386 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1387 return node
1388 gen_expr_node = pos_args[0]
1389 loop_node = gen_expr_node.loop
1390 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1391 if yield_expression is None:
1392 return node
1394 if is_any:
1395 condition = yield_expression
1396 else:
1397 condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)
1399 result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
1400 test_node = Nodes.IfStatNode(
1401 yield_expression.pos,
1402 else_clause = None,
1403 if_clauses = [ Nodes.IfClauseNode(
1404 yield_expression.pos,
1405 condition = condition,
1406 body = Nodes.StatListNode(
1407 node.pos,
1408 stats = [
1409 Nodes.SingleAssignmentNode(
1410 node.pos,
1411 lhs = result_ref,
1412 rhs = ExprNodes.BoolNode(yield_expression.pos, value = is_any,
1413 constant_result = is_any)),
1414 Nodes.BreakStatNode(node.pos)
1415 ])) ]
1417 loop = loop_node
1418 while isinstance(loop.body, Nodes.LoopNode):
1419 next_loop = loop.body
1420 loop.body = Nodes.StatListNode(loop.body.pos, stats = [
1421 loop.body,
1422 Nodes.BreakStatNode(yield_expression.pos)
1424 next_loop.else_clause = Nodes.ContinueStatNode(yield_expression.pos)
1425 loop = next_loop
1426 loop_node.else_clause = Nodes.SingleAssignmentNode(
1427 node.pos,
1428 lhs = result_ref,
1429 rhs = ExprNodes.BoolNode(yield_expression.pos, value = not is_any,
1430 constant_result = not is_any))
1432 Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)
1434 return ExprNodes.InlinedGeneratorExpressionNode(
1435 gen_expr_node.pos, loop = loop_node, result_node = result_ref,
1436 expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
1438 def _handle_simple_function_sorted(self, node, pos_args):
1439 """Transform sorted(genexpr) and sorted([listcomp]) into
1440 [listcomp].sort(). CPython just reads the iterable into a
1441 list and calls .sort() on it. Expanding the iterable in a
1442 listcomp is still faster and the result can be sorted in
1443 place.
1445 if len(pos_args) != 1:
1446 return node
1447 if isinstance(pos_args[0], ExprNodes.ComprehensionNode) \
1448 and pos_args[0].type is Builtin.list_type:
1449 listcomp_node = pos_args[0]
1450 loop_node = listcomp_node.loop
1451 elif isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1452 gen_expr_node = pos_args[0]
1453 loop_node = gen_expr_node.loop
1454 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1455 if yield_expression is None:
1456 return node
1458 append_node = ExprNodes.ComprehensionAppendNode(
1459 yield_expression.pos, expr = yield_expression)
1461 Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1463 listcomp_node = ExprNodes.ComprehensionNode(
1464 gen_expr_node.pos, loop = loop_node,
1465 append = append_node, type = Builtin.list_type,
1466 expr_scope = gen_expr_node.expr_scope,
1467 has_local_scope = True)
1468 append_node.target = listcomp_node
1469 else:
1470 return node
1472 result_node = UtilNodes.ResultRefNode(
1473 pos = loop_node.pos, type = Builtin.list_type, may_hold_none=False)
1474 listcomp_assign_node = Nodes.SingleAssignmentNode(
1475 node.pos, lhs = result_node, rhs = listcomp_node, first = True)
1477 sort_method = ExprNodes.AttributeNode(
1478 node.pos, obj = result_node, attribute = EncodedString('sort'),
1479 # entry ? type ?
1480 needs_none_check = False)
1481 sort_node = Nodes.ExprStatNode(
1482 node.pos, expr = ExprNodes.SimpleCallNode(
1483 node.pos, function = sort_method, args = []))
1485 sort_node.analyse_declarations(self.current_env())
1487 return UtilNodes.TempResultFromStatNode(
1488 result_node,
1489 Nodes.StatListNode(node.pos, stats = [ listcomp_assign_node, sort_node ]))
1491 def _handle_simple_function_sum(self, node, pos_args):
1492 """Transform sum(genexpr) into an equivalent inlined aggregation loop.
1494 if len(pos_args) not in (1,2):
1495 return node
1496 if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
1497 ExprNodes.ComprehensionNode)):
1498 return node
1499 gen_expr_node = pos_args[0]
1500 loop_node = gen_expr_node.loop
1502 if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
1503 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1504 if yield_expression is None:
1505 return node
1506 else: # ComprehensionNode
1507 yield_stat_node = gen_expr_node.append
1508 yield_expression = yield_stat_node.expr
1509 try:
1510 if not yield_expression.is_literal or not yield_expression.type.is_int:
1511 return node
1512 except AttributeError:
1513 return node # in case we don't have a type yet
1514 # special case: old Py2 backwards compatible "sum([int_const for ...])"
1515 # can safely be unpacked into a genexpr
1517 if len(pos_args) == 1:
1518 start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
1519 else:
1520 start = pos_args[1]
1522 result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
1523 add_node = Nodes.SingleAssignmentNode(
1524 yield_expression.pos,
1525 lhs = result_ref,
1526 rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
1529 Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node)
1531 exec_code = Nodes.StatListNode(
1532 node.pos,
1533 stats = [
1534 Nodes.SingleAssignmentNode(
1535 start.pos,
1536 lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
1537 rhs = start,
1538 first = True),
1539 loop_node
1542 return ExprNodes.InlinedGeneratorExpressionNode(
1543 gen_expr_node.pos, loop = exec_code, result_node = result_ref,
1544 expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
1545 has_local_scope = gen_expr_node.has_local_scope)
1547 def _handle_simple_function_min(self, node, pos_args):
1548 return self._optimise_min_max(node, pos_args, '<')
1550 def _handle_simple_function_max(self, node, pos_args):
1551 return self._optimise_min_max(node, pos_args, '>')
1553 def _optimise_min_max(self, node, args, operator):
1554 """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
1556 if len(args) <= 1:
1557 if len(args) == 1 and args[0].is_sequence_constructor:
1558 args = args[0].args
1559 else:
1560 # leave this to Python
1561 return node
1563 cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:]))
1565 last_result = args[0]
1566 for arg_node in cascaded_nodes:
1567 result_ref = UtilNodes.ResultRefNode(last_result)
1568 last_result = ExprNodes.CondExprNode(
1569 arg_node.pos,
1570 true_val = arg_node,
1571 false_val = result_ref,
1572 test = ExprNodes.PrimaryCmpNode(
1573 arg_node.pos,
1574 operand1 = arg_node,
1575 operator = operator,
1576 operand2 = result_ref,
1579 last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)
1581 for ref_node in cascaded_nodes[::-1]:
1582 last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
1584 return last_result
def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
    """Replace tuple() by a literal and tuple(genexpr) by a converted list.

    Note that the name of this method does not follow the
    ``_handle_simple_function_<name>`` pattern of the other handlers.
    """
    if not pos_args:
        return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
    # This is a bit special - for iterables (including genexps),
    # Python actually overallocates and resizes a newly created
    # tuple incrementally while reading items, which we can't
    # easily do without explicit node support. Instead, we read
    # the items into a list and then copy them into a tuple of the
    # final size.  This takes up to twice as much memory, but will
    # have to do until we have real support for genexps.
    as_list = self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
    if as_list is node:
        # nothing was transformed
        return node
    return ExprNodes.AsTupleNode(node.pos, arg=as_list)
1601 def _handle_simple_function_frozenset(self, node, pos_args):
1602 """Replace frozenset([...]) by frozenset((...)) as tuples are more efficient.
1604 if len(pos_args) != 1:
1605 return node
1606 if pos_args[0].is_sequence_constructor and not pos_args[0].args:
1607 del pos_args[0]
1608 elif isinstance(pos_args[0], ExprNodes.ListNode):
1609 pos_args[0] = pos_args[0].as_tuple()
1610 return node
def _handle_simple_function_list(self, node, pos_args):
    """Replace list() by an empty list literal and list(genexpr) by a
    list comprehension.
    """
    if pos_args:
        return self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
    return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
def _handle_simple_function_set(self, node, pos_args):
    """Replace set() by an empty set node and set(genexpr) by a set
    comprehension.
    """
    if pos_args:
        return self._transform_list_set_genexpr(node, pos_args, Builtin.set_type)
    return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
1622 def _transform_list_set_genexpr(self, node, pos_args, target_type):
1623 """Replace set(genexpr) and list(genexpr) by a literal comprehension.
1625 if len(pos_args) > 1:
1626 return node
1627 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1628 return node
1629 gen_expr_node = pos_args[0]
1630 loop_node = gen_expr_node.loop
1632 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1633 if yield_expression is None:
1634 return node
1636 append_node = ExprNodes.ComprehensionAppendNode(
1637 yield_expression.pos,
1638 expr = yield_expression)
1640 Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1642 comp = ExprNodes.ComprehensionNode(
1643 node.pos,
1644 has_local_scope = True,
1645 expr_scope = gen_expr_node.expr_scope,
1646 loop = loop_node,
1647 append = append_node,
1648 type = target_type)
1649 append_node.target = comp
1650 return comp
1652 def _handle_simple_function_dict(self, node, pos_args):
1653 """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.
1655 if len(pos_args) == 0:
1656 return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
1657 if len(pos_args) > 1:
1658 return node
1659 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1660 return node
1661 gen_expr_node = pos_args[0]
1662 loop_node = gen_expr_node.loop
1664 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1665 if yield_expression is None:
1666 return node
1668 if not isinstance(yield_expression, ExprNodes.TupleNode):
1669 return node
1670 if len(yield_expression.args) != 2:
1671 return node
1673 append_node = ExprNodes.DictComprehensionAppendNode(
1674 yield_expression.pos,
1675 key_expr = yield_expression.args[0],
1676 value_expr = yield_expression.args[1])
1678 Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1680 dictcomp = ExprNodes.ComprehensionNode(
1681 node.pos,
1682 has_local_scope = True,
1683 expr_scope = gen_expr_node.expr_scope,
1684 loop = loop_node,
1685 append = append_node,
1686 type = Builtin.dict_type)
1687 append_node.target = dictcomp
1688 return dictcomp
1690 # specific handlers for general call nodes
1692 def _handle_general_function_dict(self, node, pos_args, kwargs):
1693 """Replace dict(a=b,c=d,...) by the underlying keyword dict
1694 construction which is done anyway.
1696 if len(pos_args) > 0:
1697 return node
1698 if not isinstance(kwargs, ExprNodes.DictNode):
1699 return node
1700 return kwargs
class InlineDefNodeCalls(Visitor.NodeRefCleanupMixin, Visitor.EnvTransform):
    """Replace simple calls to known Python functions by inlined call nodes.

    The transformation is only applied if the directive
    'optimize.inline_defnode_calls' is set.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def get_constant_value_node(self, name_node):
        """Return the value node assigned to a name that is assigned once.

        None is returned if the name has no control flow state, if that
        state marks it as null, or if its entry does not have exactly one
        recorded assignment.
        """
        state = name_node.cf_state
        if state is None or state.cf_is_null:
            return None
        entry = self.current_env().lookup(name_node.name)
        if not entry:
            return None
        assignments = entry.cf_assignments
        if not assignments or len(assignments) != 1:
            # not just a single assignment in all closures
            return None
        return assignments[0].rhs

    def visit_SimpleCallNode(self, node):
        """Inline the call if it targets a name bound to a PyCFunctionNode
        and the inlined call node reports that it can be inlined.
        """
        self.visitchildren(node)
        if not self.current_directives.get('optimize.inline_defnode_calls'):
            return node
        function_name = node.function
        if not function_name.is_name:
            return node
        function = self.get_constant_value_node(function_name)
        if not isinstance(function, ExprNodes.PyCFunctionNode):
            return node
        inlined = ExprNodes.InlinedDefNodeCallNode(
            node.pos, function_name=function_name,
            function=function, args=node.args)
        if not inlined.can_be_inlined():
            return node
        return self.replace(node, inlined)
1736 class OptimizeBuiltinCalls(Visitor.MethodDispatcherTransform):
1737 """Optimize some common methods calls and instantiation patterns
1738 for builtin types *after* the type analysis phase.
1740 Running after type analysis, this transform can only perform
1741 function replacements that do not alter the function return type
1742 in a way that was not anticipated by the type analysis.
1744 ### cleanup to avoid redundant coercions to/from Python types
1746 def _visit_PyTypeTestNode(self, node):
1747 # disabled - appears to break assignments in some cases, and
1748 # also drops a None check, which might still be required
1749 """Flatten redundant type checks after tree changes.
1751 old_arg = node.arg
1752 self.visitchildren(node)
1753 if old_arg is node.arg or node.arg.type != node.type:
1754 return node
1755 return node.arg
1757 def _visit_TypecastNode(self, node):
1758 # disabled - the user may have had a reason to put a type
1759 # cast, even if it looks redundant to Cython
1761 Drop redundant type casts.
1763 self.visitchildren(node)
1764 if node.type == node.operand.type:
1765 return node.operand
1766 return node
def visit_ExprStatNode(self, node):
    """
    Drop useless coercions.

    The result of an expression statement is not used, so a coercion of
    the expression to a Python type is replaced by its argument.
    """
    self.visitchildren(node)
    expr = node.expr
    if isinstance(expr, ExprNodes.CoerceToPyTypeNode):
        node.expr = expr.arg
    return node
def visit_CoerceToBooleanNode(self, node):
    """Drop redundant conversion nodes after tree changes.

    If the argument (ignoring a type test around it) is a coercion to a
    generic Python object or to bool, the coerced value is converted to a
    boolean directly.
    """
    self.visitchildren(node)
    inner = node.arg
    if isinstance(inner, ExprNodes.PyTypeTestNode):
        inner = inner.arg
    if (isinstance(inner, ExprNodes.CoerceToPyTypeNode)
            and inner.type in (PyrexTypes.py_object_type, Builtin.bool_type)):
        return inner.arg.coerce_to_boolean(self.current_env())
    return node
def visit_CoerceFromPyTypeNode(self, node):
    """Drop redundant conversion nodes after tree changes.

    Also, optimise away calls to Python's builtin int() and
    float() if the result is going to be coerced back into a C
    type anyway.

    Handled cases (after ignoring a type test around the argument):
    non-Python arguments, int/float/bool literals, C->Py->C coercion
    round trips, simple calls coerced to a C number and integer indexing
    without buffer access.
    """
    self.visitchildren(node)
    arg = node.arg
    if not arg.type.is_pyobject:
        # no Python conversion left at all, just do a C coercion instead
        if node.type == arg.type:
            return arg
        return arg.coerce_to(node.type, self.current_env())

    if isinstance(arg, ExprNodes.PyTypeTestNode):
        arg = arg.arg

    if arg.is_literal:
        target = node.type
        if (target.is_int and isinstance(arg, ExprNodes.IntNode) or
                target.is_float and isinstance(arg, ExprNodes.FloatNode) or
                target.is_int and isinstance(arg, ExprNodes.BoolNode)):
            return arg.coerce_to(target, self.current_env())
        return node

    if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
        if arg.type is PyrexTypes.py_object_type:
            if node.type.assignable_from(arg.arg.type):
                # completely redundant C->Py->C coercion
                return arg.arg.coerce_to(node.type, self.current_env())
        return node

    if isinstance(arg, ExprNodes.SimpleCallNode):
        if node.type.is_int or node.type.is_float:
            return self._optimise_numeric_cast_call(node, arg)
        return node

    if isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access:
        index_node = arg.index
        if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
            index_node = index_node.arg
        if index_node.type.is_int:
            return self._optimise_int_indexing(node, arg, index_node)
    return node
# C function signature used for the "__Pyx_PyBytes_GetItemInt" call that
# _optimise_int_indexing() generates: it takes a bytes object, a
# Py_ssize_t index and an int 'check_bounds' flag and returns a C char,
# with "((char)-1)" declared as the exception value.
PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_char_type, [
        PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
        PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
        ],
    exception_value = "((char)-1)",
    exception_check = True)
def _optimise_int_indexing(self, coerce_node, arg, index_node):
    """Optimise integer indexing whose result is coerced to a C type.

    Currently, only ``bytes[index]`` coerced to a C char or unsigned char
    is replaced, by a call to ``__Pyx_PyBytes_GetItemInt()``.  Bounds
    checking follows the 'boundscheck' directive.  In all other cases,
    ``coerce_node`` is returned unchanged.
    """
    env = self.current_env()
    if env.directives['boundscheck']:
        bound_check_bool = 1
    else:
        bound_check_bool = 0

    if arg.base.type is not Builtin.bytes_type:
        return coerce_node
    if coerce_node.type not in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
        return coerce_node

    # bytes[index] -> char
    bound_check_node = ExprNodes.IntNode(
        coerce_node.pos, value=str(bound_check_bool),
        constant_result=bound_check_bool)
    call_args = [
        arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
        index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
        bound_check_node,
        ]
    node = ExprNodes.PythonCapiCallNode(
        coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
        self.PyBytes_GetItemInt_func_type,
        args=call_args,
        is_temp=True,
        utility_code=UtilityCode.load_cached(
            'bytes_index', 'StringTools.c'))
    if coerce_node.type is not PyrexTypes.c_char_type:
        # the C function returns a plain char
        node = node.coerce_to(coerce_node.type, env)
    return node
def _optimise_numeric_cast_call(self, node, arg):
    """Replace int(x)/float(x) calls that get coerced to a C type.

    ``node`` is the coercion node and ``arg`` the call node below it.
    Only calls of a builtin-typed name with a single argument that is not
    a Python object (or only a coerced C value) are considered.  The call
    is then replaced by its argument if the types are the same, or by a
    type cast of the argument.  Otherwise ``node`` is returned.
    """
    function = arg.function
    if not isinstance(function, ExprNodes.NameNode):
        return node
    if not function.type.is_builtin_type:
        return node
    if not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
        return node
    args = arg.arg_tuple.args
    if len(args) != 1:
        return node

    func_arg = args[0]
    if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
        func_arg = func_arg.arg
    elif func_arg.type.is_pyobject:
        # play safe: Python conversion might work on all sorts of things
        return node

    if function.name == 'int':
        involves_target_kind = func_arg.type.is_int or node.type.is_int
    elif function.name == 'float':
        involves_target_kind = func_arg.type.is_float or node.type.is_float
    else:
        return node
    if not involves_target_kind:
        return node

    if func_arg.type == node.type:
        return func_arg
    if node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
        return ExprNodes.TypecastNode(
            node.pos, operand=func_arg, type=node.type)
    return node
def _error_wrong_arg_count(self, function_name, node, args, expected=None):
    """Report a call with a wrong number of arguments at ``node.pos``.

    ``expected`` may be a number or a string that describes the expected
    argument count; it is only included in the message if it is not None.
    """
    arg_str = ''
    if expected:  # neither None nor 0
        if isinstance(expected, basestring) or expected > 1:
            arg_str = '...'
        elif expected == 1:
            arg_str = 'x'

    expected_str = ''
    if expected is not None:
        expected_str = 'expected %s, ' % expected

    error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
        function_name, arg_str, expected_str, len(args)))
1908 ### generic fallbacks
1910 def _handle_function(self, node, function_name, function, arg_list, kwargs):
1911 return node
1913 def _handle_method(self, node, type_name, attr_name, function,
1914 arg_list, is_unbound_method, kwargs):
1916 Try to inject C-API calls for unbound method calls to builtin types.
1917 While the method declarations in Builtin.py already handle this, we
1918 can additionally resolve bound and unbound methods here that were
1919 assigned to variables ahead of time.
1921 if kwargs:
1922 return node
1923 if not function or not function.is_attribute or not function.obj.is_name:
1924 # cannot track unbound method calls over more than one indirection as
1925 # the names might have been reassigned in the meantime
1926 return node
1927 type_entry = self.current_env().lookup(type_name)
1928 if not type_entry:
1929 return node
1930 method = ExprNodes.AttributeNode(
1931 node.function.pos,
1932 obj=ExprNodes.NameNode(
1933 function.pos,
1934 name=type_name,
1935 entry=type_entry,
1936 type=type_entry.type),
1937 attribute=attr_name,
1938 is_called=True).analyse_as_unbound_cmethod_node(self.current_env())
1939 if method is None:
1940 return node
1941 args = node.args
1942 if args is None and node.arg_tuple:
1943 args = node.arg_tuple.args
1944 call_node = ExprNodes.SimpleCallNode(
1945 node.pos,
1946 function=method,
1947 args=args)
1948 if not is_unbound_method:
1949 call_node.self = function.obj
1950 call_node.analyse_c_function_call(self.current_env())
1951 call_node.analysed = True
1952 return call_node.coerce_to(node.type, self.current_env())
1954 ### builtin types
# C signature used for PyDict_Copy(): dict -> dict
PyDict_Copy_func_type = PyrexTypes.CFuncType(
    Builtin.dict_type,
    [PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)])
1961 def _handle_simple_function_dict(self, node, function, pos_args):
1962 """Replace dict(some_dict) by PyDict_Copy(some_dict).
1964 if len(pos_args) != 1:
1965 return node
1966 arg = pos_args[0]
1967 if arg.type is Builtin.dict_type:
1968 arg = arg.as_none_safe_node("'NoneType' is not iterable")
1969 return ExprNodes.PythonCapiCallNode(
1970 node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
1971 args = [arg],
1972 is_temp = node.is_temp
1974 return node
# C signature used for PyList_AsTuple(): list -> tuple
PyList_AsTuple_func_type = PyrexTypes.CFuncType(
    Builtin.tuple_type,
    [PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)])
1981 def _handle_simple_function_tuple(self, node, function, pos_args):
1982 """Replace tuple([...]) by a call to PyList_AsTuple.
1984 if len(pos_args) != 1:
1985 return node
1986 arg = pos_args[0]
1987 if arg.type is Builtin.tuple_type and not arg.may_be_none():
1988 return arg
1989 if arg.type is not Builtin.list_type:
1990 return node
1991 pos_args[0] = arg.as_none_safe_node(
1992 "'NoneType' object is not iterable")
1994 return ExprNodes.PythonCapiCallNode(
1995 node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
1996 args = pos_args,
1997 is_temp = node.is_temp
# C signature used for PySet_New(): iterable object -> set
PySet_New_func_type = PyrexTypes.CFuncType(
    Builtin.set_type,
    [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)])
2005 def _handle_simple_function_set(self, node, function, pos_args):
2006 if len(pos_args) != 1:
2007 return node
2008 if pos_args[0].is_sequence_constructor:
2009 # We can optimise set([x,y,z]) safely into a set literal,
2010 # but only if we create all items before adding them -
2011 # adding an item may raise an exception if it is not
2012 # hashable, but creating the later items may have
2013 # side-effects.
2014 args = []
2015 temps = []
2016 for arg in pos_args[0].args:
2017 if not arg.is_simple():
2018 arg = UtilNodes.LetRefNode(arg)
2019 temps.append(arg)
2020 args.append(arg)
2021 result = ExprNodes.SetNode(node.pos, is_temp=1, args=args)
2022 for temp in temps[::-1]:
2023 result = UtilNodes.EvalWithTempExprNode(temp, result)
2024 return result
2025 else:
2026 # PySet_New(it) is better than a generic Python call to set(it)
2027 return ExprNodes.PythonCapiCallNode(
2028 node.pos, "PySet_New",
2029 self.PySet_New_func_type,
2030 args=pos_args,
2031 is_temp=node.is_temp,
2032 utility_code=UtilityCode.load_cached('pyset_compat', 'Builtins.c'),
2033 py_name="set")
# C signature used for __Pyx_PyFrozenSet_New(): iterable object -> frozenset
PyFrozenSet_New_func_type = PyrexTypes.CFuncType(
    Builtin.frozenset_type,
    [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)])
2040 def _handle_simple_function_frozenset(self, node, function, pos_args):
2041 if not pos_args:
2042 pos_args = [ExprNodes.NullNode(node.pos)]
2043 elif len(pos_args) > 1:
2044 return node
2045 elif pos_args[0].type is Builtin.frozenset_type and not pos_args[0].may_be_none():
2046 return pos_args[0]
2047 # PyFrozenSet_New(it) is better than a generic Python call to frozenset(it)
2048 return ExprNodes.PythonCapiCallNode(
2049 node.pos, "__Pyx_PyFrozenSet_New",
2050 self.PyFrozenSet_New_func_type,
2051 args=pos_args,
2052 is_temp=node.is_temp,
2053 utility_code=UtilityCode.load_cached('pyfrozenset_new', 'Builtins.c'),
2054 py_name="frozenset")
# C signature used for __Pyx_PyObject_AsDouble(): object -> double,
# with ((double)-1) as the exception value plus an explicit exception check
PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_double_type,
    [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
    exception_value="((double)-1)",
    exception_check=True)
2063 def _handle_simple_function_float(self, node, function, pos_args):
2064 """Transform float() into either a C type cast or a faster C
2065 function call.
2067 # Note: this requires the float() function to be typed as
2068 # returning a C 'double'
2069 if len(pos_args) == 0:
2070 return ExprNodes.FloatNode(
2071 node, value="0.0", constant_result=0.0
2072 ).coerce_to(Builtin.float_type, self.current_env())
2073 elif len(pos_args) != 1:
2074 self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
2075 return node
2076 func_arg = pos_args[0]
2077 if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
2078 func_arg = func_arg.arg
2079 if func_arg.type is PyrexTypes.c_double_type:
2080 return func_arg
2081 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
2082 return ExprNodes.TypecastNode(
2083 node.pos, operand=func_arg, type=node.type)
2084 return ExprNodes.PythonCapiCallNode(
2085 node.pos, "__Pyx_PyObject_AsDouble",
2086 self.PyObject_AsDouble_func_type,
2087 args = pos_args,
2088 is_temp = node.is_temp,
2089 utility_code = load_c_utility('pyobject_as_double'),
2090 py_name = "float")
# C signature used for PyNumber_Int(): object -> object
PyNumber_Int_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None)])
2097 def _handle_simple_function_int(self, node, function, pos_args):
2098 """Transform int() into a faster C function call.
2100 if len(pos_args) == 0:
2101 return ExprNodes.IntNode(node, value="0", constant_result=0,
2102 type=PyrexTypes.py_object_type)
2103 elif len(pos_args) != 1:
2104 return node # int(x, base)
2105 func_arg = pos_args[0]
2106 if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
2107 return node # handled in visit_CoerceFromPyTypeNode()
2108 if func_arg.type.is_pyobject and node.type.is_pyobject:
2109 return ExprNodes.PythonCapiCallNode(
2110 node.pos, "PyNumber_Int", self.PyNumber_Int_func_type,
2111 args=pos_args, is_temp=True)
2112 return node
2114 def _handle_simple_function_bool(self, node, function, pos_args):
2115 """Transform bool(x) into a type coercion to a boolean.
2117 if len(pos_args) == 0:
2118 return ExprNodes.BoolNode(
2119 node.pos, value=False, constant_result=False
2120 ).coerce_to(Builtin.bool_type, self.current_env())
2121 elif len(pos_args) != 1:
2122 self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
2123 return node
2124 else:
2125 # => !!<bint>(x) to make sure it's exactly 0 or 1
2126 operand = pos_args[0].coerce_to_boolean(self.current_env())
2127 operand = ExprNodes.NotNode(node.pos, operand = operand)
2128 operand = ExprNodes.NotNode(node.pos, operand = operand)
2129 # coerce back to Python object as that's the result we are expecting
2130 return operand.coerce_to_pyobject(self.current_env())
2132 ### builtin functions
# C signature used for strlen(): char* -> size_t
Pyx_strlen_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_size_t_type,
    [PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)])

# C signature used for __Pyx_Py_UNICODE_strlen(): Py_UNICODE* -> size_t
Pyx_Py_UNICODE_strlen_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_size_t_type,
    [PyrexTypes.CFuncTypeArg("unicode", PyrexTypes.c_py_unicode_ptr_type, None)])

# C signature shared by the object size functions: object -> Py_ssize_t
PyObject_Size_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_py_ssize_t_type,
    [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
    exception_value="-1")

# lookup: builtin type -> name of the C-API size function/macro (or None)
_map_to_capi_len_function = {
    Builtin.unicode_type:   "__Pyx_PyUnicode_GET_LENGTH",
    Builtin.bytes_type:     "PyBytes_GET_SIZE",
    Builtin.list_type:      "PyList_GET_SIZE",
    Builtin.tuple_type:     "PyTuple_GET_SIZE",
    Builtin.dict_type:      "PyDict_Size",
    Builtin.set_type:       "PySet_Size",
    Builtin.frozenset_type: "PySet_Size",
}.get

# qualified names of extension types whose len() is mapped to Py_SIZE()
_ext_types_with_pysize = set(["cpython.array.array"])
2162 def _handle_simple_function_len(self, node, function, pos_args):
2163 """Replace len(char*) by the equivalent call to strlen(),
2164 len(Py_UNICODE) by the equivalent Py_UNICODE_strlen() and
2165 len(known_builtin_type) by an equivalent C-API call.
2167 if len(pos_args) != 1:
2168 self._error_wrong_arg_count('len', node, pos_args, 1)
2169 return node
2170 arg = pos_args[0]
2171 if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
2172 arg = arg.arg
2173 if arg.type.is_string:
2174 new_node = ExprNodes.PythonCapiCallNode(
2175 node.pos, "strlen", self.Pyx_strlen_func_type,
2176 args = [arg],
2177 is_temp = node.is_temp,
2178 utility_code = UtilityCode.load_cached("IncludeStringH", "StringTools.c"))
2179 elif arg.type.is_pyunicode_ptr:
2180 new_node = ExprNodes.PythonCapiCallNode(
2181 node.pos, "__Pyx_Py_UNICODE_strlen", self.Pyx_Py_UNICODE_strlen_func_type,
2182 args = [arg],
2183 is_temp = node.is_temp)
2184 elif arg.type.is_pyobject:
2185 cfunc_name = self._map_to_capi_len_function(arg.type)
2186 if cfunc_name is None:
2187 arg_type = arg.type
2188 if ((arg_type.is_extension_type or arg_type.is_builtin_type)
2189 and arg_type.entry.qualified_name in self._ext_types_with_pysize):
2190 cfunc_name = 'Py_SIZE'
2191 else:
2192 return node
2193 arg = arg.as_none_safe_node(
2194 "object of type 'NoneType' has no len()")
2195 new_node = ExprNodes.PythonCapiCallNode(
2196 node.pos, cfunc_name, self.PyObject_Size_func_type,
2197 args = [arg],
2198 is_temp = node.is_temp)
2199 elif arg.type.is_unicode_char:
2200 return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
2201 type=node.type)
2202 else:
2203 return node
2204 if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
2205 new_node = new_node.coerce_to(node.type, self.current_env())
2206 return new_node
# C signature used for Py_TYPE(): object -> type
Pyx_Type_func_type = PyrexTypes.CFuncType(
    Builtin.type_type,
    [PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)])
2213 def _handle_simple_function_type(self, node, function, pos_args):
2214 """Replace type(o) by a macro call to Py_TYPE(o).
2216 if len(pos_args) != 1:
2217 return node
2218 node = ExprNodes.PythonCapiCallNode(
2219 node.pos, "Py_TYPE", self.Pyx_Type_func_type,
2220 args = pos_args,
2221 is_temp = False)
2222 return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
# C signature declared for the type check calls: object -> bint
Py_type_check_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_bint_type,
    [PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)])
2229 def _handle_simple_function_isinstance(self, node, function, pos_args):
2230 """Replace isinstance() checks against builtin types by the
2231 corresponding C-API call.
2233 if len(pos_args) != 2:
2234 return node
2235 arg, types = pos_args
2236 temp = None
2237 if isinstance(types, ExprNodes.TupleNode):
2238 types = types.args
2239 if arg.is_attribute or not arg.is_simple():
2240 arg = temp = UtilNodes.ResultRefNode(arg)
2241 elif types.type is Builtin.type_type:
2242 types = [types]
2243 else:
2244 return node
2246 tests = []
2247 test_nodes = []
2248 env = self.current_env()
2249 for test_type_node in types:
2250 builtin_type = None
2251 if test_type_node.is_name:
2252 if test_type_node.entry:
2253 entry = env.lookup(test_type_node.entry.name)
2254 if entry and entry.type and entry.type.is_builtin_type:
2255 builtin_type = entry.type
2256 if builtin_type is Builtin.type_type:
2257 # all types have type "type", but there's only one 'type'
2258 if entry.name != 'type' or not (
2259 entry.scope and entry.scope.is_builtin_scope):
2260 builtin_type = None
2261 if builtin_type is not None:
2262 type_check_function = entry.type.type_check_function(exact=False)
2263 if type_check_function in tests:
2264 continue
2265 tests.append(type_check_function)
2266 type_check_args = [arg]
2267 elif test_type_node.type is Builtin.type_type:
2268 type_check_function = '__Pyx_TypeCheck'
2269 type_check_args = [arg, test_type_node]
2270 else:
2271 return node
2272 test_nodes.append(
2273 ExprNodes.PythonCapiCallNode(
2274 test_type_node.pos, type_check_function, self.Py_type_check_func_type,
2275 args = type_check_args,
2276 is_temp = True,
2279 def join_with_or(a,b, make_binop_node=ExprNodes.binop_node):
2280 or_node = make_binop_node(node.pos, 'or', a, b)
2281 or_node.type = PyrexTypes.c_bint_type
2282 or_node.is_temp = True
2283 return or_node
2285 test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
2286 if temp is not None:
2287 test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
2288 return test_node
2290 def _handle_simple_function_ord(self, node, function, pos_args):
2291 """Unpack ord(Py_UNICODE) and ord('X').
2293 if len(pos_args) != 1:
2294 return node
2295 arg = pos_args[0]
2296 if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
2297 if arg.arg.type.is_unicode_char:
2298 return ExprNodes.TypecastNode(
2299 arg.pos, operand=arg.arg, type=PyrexTypes.c_int_type
2300 ).coerce_to(node.type, self.current_env())
2301 elif isinstance(arg, ExprNodes.UnicodeNode):
2302 if len(arg.value) == 1:
2303 return ExprNodes.IntNode(
2304 arg.pos, type=PyrexTypes.c_int_type,
2305 value=str(ord(arg.value)),
2306 constant_result=ord(arg.value)
2307 ).coerce_to(node.type, self.current_env())
2308 elif isinstance(arg, ExprNodes.StringNode):
2309 if arg.unicode_value and len(arg.unicode_value) == 1 \
2310 and ord(arg.unicode_value) <= 255: # Py2/3 portability
2311 return ExprNodes.IntNode(
2312 arg.pos, type=PyrexTypes.c_int_type,
2313 value=str(ord(arg.unicode_value)),
2314 constant_result=ord(arg.unicode_value)
2315 ).coerce_to(node.type, self.current_env())
2316 return node
2318 ### special methods
# C signature used for __Pyx_tp_new(): (type, args tuple) -> object
Pyx_tp_new_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None)])

# C signature used for __Pyx_tp_new_kwargs(): (type, args tuple, kwargs dict) -> object
Pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None),
     PyrexTypes.CFuncTypeArg("kwargs", Builtin.dict_type, None)])
2333 def _handle_any_slot__new__(self, node, function, args,
2334 is_unbound_method, kwargs=None):
2335 """Replace 'exttype.__new__(exttype, ...)' by a call to exttype->tp_new()
2337 obj = function.obj
2338 if not is_unbound_method or len(args) < 1:
2339 return node
2340 type_arg = args[0]
2341 if not obj.is_name or not type_arg.is_name:
2342 # play safe
2343 return node
2344 if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
2345 # not a known type, play safe
2346 return node
2347 if not type_arg.type_entry or not obj.type_entry:
2348 if obj.name != type_arg.name:
2349 return node
2350 # otherwise, we know it's a type and we know it's the same
2351 # type for both - that should do
2352 elif type_arg.type_entry != obj.type_entry:
2353 # different types - may or may not lead to an error at runtime
2354 return node
2356 args_tuple = ExprNodes.TupleNode(node.pos, args=args[1:])
2357 args_tuple = args_tuple.analyse_types(
2358 self.current_env(), skip_children=True)
2360 if type_arg.type_entry:
2361 ext_type = type_arg.type_entry.type
2362 if (ext_type.is_extension_type and ext_type.typeobj_cname and
2363 ext_type.scope.global_scope() == self.current_env().global_scope()):
2364 # known type in current module
2365 tp_slot = TypeSlots.ConstructorSlot("tp_new", '__new__')
2366 slot_func_cname = TypeSlots.get_slot_function(ext_type.scope, tp_slot)
2367 if slot_func_cname:
2368 cython_scope = self.context.cython_scope
2369 PyTypeObjectPtr = PyrexTypes.CPtrType(
2370 cython_scope.lookup('PyTypeObject').type)
2371 pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
2372 PyrexTypes.py_object_type, [
2373 PyrexTypes.CFuncTypeArg("type", PyTypeObjectPtr, None),
2374 PyrexTypes.CFuncTypeArg("args", PyrexTypes.py_object_type, None),
2375 PyrexTypes.CFuncTypeArg("kwargs", PyrexTypes.py_object_type, None),
2378 type_arg = ExprNodes.CastNode(type_arg, PyTypeObjectPtr)
2379 if not kwargs:
2380 kwargs = ExprNodes.NullNode(node.pos, type=PyrexTypes.py_object_type) # hack?
2381 return ExprNodes.PythonCapiCallNode(
2382 node.pos, slot_func_cname,
2383 pyx_tp_new_kwargs_func_type,
2384 args=[type_arg, args_tuple, kwargs],
2385 is_temp=True)
2386 else:
2387 # arbitrary variable, needs a None check for safety
2388 type_arg = type_arg.as_none_safe_node(
2389 "object.__new__(X): X is not a type object (NoneType)")
2391 utility_code = UtilityCode.load_cached('tp_new', 'ObjectHandling.c')
2392 if kwargs:
2393 return ExprNodes.PythonCapiCallNode(
2394 node.pos, "__Pyx_tp_new_kwargs", self.Pyx_tp_new_kwargs_func_type,
2395 args=[type_arg, args_tuple, kwargs],
2396 utility_code=utility_code,
2397 is_temp=node.is_temp
2399 else:
2400 return ExprNodes.PythonCapiCallNode(
2401 node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
2402 args=[type_arg, args_tuple],
2403 utility_code=utility_code,
2404 is_temp=node.is_temp
2407 ### methods of builtin types
# C signature used for __Pyx_PyObject_Append(): (list, item) -> return code,
# with -1 as the exception value
PyObject_Append_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_returncode_type,
    [PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None)],
    exception_value="-1")
2416 def _handle_simple_method_object_append(self, node, function, args, is_unbound_method):
2417 """Optimistic optimisation as X.append() is almost always
2418 referring to a list.
2420 if len(args) != 2 or node.result_is_used:
2421 return node
2423 return ExprNodes.PythonCapiCallNode(
2424 node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
2425 args=args,
2426 may_return_none=False,
2427 is_temp=node.is_temp,
2428 result_is_used=False,
2429 utility_code=load_c_utility('append')
# C signature used for __Pyx_PyByteArray_Append(): (bytearray, C int) -> return code
PyByteArray_Append_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_returncode_type,
    [PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_int_type, None)],
    exception_value="-1")

# C signature used for __Pyx_PyByteArray_AppendObject(): (bytearray, object) -> return code
PyByteArray_AppendObject_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_returncode_type,
    [PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("value", PyrexTypes.py_object_type, None)],
    exception_value="-1")
2446 def _handle_simple_method_bytearray_append(self, node, function, args, is_unbound_method):
2447 if len(args) != 2:
2448 return node
2449 func_name = "__Pyx_PyByteArray_Append"
2450 func_type = self.PyByteArray_Append_func_type
2452 value = unwrap_coerced_node(args[1])
2453 if value.type.is_int or isinstance(value, ExprNodes.IntNode):
2454 value = value.coerce_to(PyrexTypes.c_int_type, self.current_env())
2455 utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
2456 elif value.is_string_literal:
2457 if not value.can_coerce_to_char_literal():
2458 return node
2459 value = value.coerce_to(PyrexTypes.c_char_type, self.current_env())
2460 utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
2461 elif value.type.is_pyobject:
2462 func_name = "__Pyx_PyByteArray_AppendObject"
2463 func_type = self.PyByteArray_AppendObject_func_type
2464 utility_code = UtilityCode.load_cached("ByteArrayAppendObject", "StringTools.c")
2465 else:
2466 return node
2468 new_node = ExprNodes.PythonCapiCallNode(
2469 node.pos, func_name, func_type,
2470 args=[args[0], value],
2471 may_return_none=False,
2472 is_temp=node.is_temp,
2473 utility_code=utility_code,
2475 if node.result_is_used:
2476 new_node = new_node.coerce_to(node.type, self.current_env())
2477 return new_node
# C signature used for the __Pyx_Py*_Pop() helpers: list -> object
PyObject_Pop_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None)])

# C signature used for the __Pyx_Py*_PopIndex() helpers: (list, C long index) -> object
PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None)])
2490 def _handle_simple_method_list_pop(self, node, function, args, is_unbound_method):
2491 return self._handle_simple_method_object_pop(
2492 node, function, args, is_unbound_method, is_list=True)
2494 def _handle_simple_method_object_pop(self, node, function, args, is_unbound_method, is_list=False):
2495 """Optimistic optimisation as X.pop([n]) is almost always
2496 referring to a list.
2498 if not args:
2499 return node
2500 args = args[:]
2501 if is_list:
2502 type_name = 'List'
2503 args[0] = args[0].as_none_safe_node(
2504 "'NoneType' object has no attribute '%s'",
2505 error="PyExc_AttributeError",
2506 format_args=['pop'])
2507 else:
2508 type_name = 'Object'
2509 if len(args) == 1:
2510 return ExprNodes.PythonCapiCallNode(
2511 node.pos, "__Pyx_Py%s_Pop" % type_name,
2512 self.PyObject_Pop_func_type,
2513 args=args,
2514 may_return_none=True,
2515 is_temp=node.is_temp,
2516 utility_code=load_c_utility('pop'),
2518 elif len(args) == 2:
2519 index = unwrap_coerced_node(args[1])
2520 if is_list or isinstance(index, ExprNodes.IntNode):
2521 index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
2522 if index.type.is_int:
2523 widest = PyrexTypes.widest_numeric_type(
2524 index.type, PyrexTypes.c_py_ssize_t_type)
2525 if widest == PyrexTypes.c_py_ssize_t_type:
2526 args[1] = index
2527 return ExprNodes.PythonCapiCallNode(
2528 node.pos, "__Pyx_Py%s_PopIndex" % type_name,
2529 self.PyObject_PopIndex_func_type,
2530 args=args,
2531 may_return_none=True,
2532 is_temp=node.is_temp,
2533 utility_code=load_c_utility("pop_index"),
2536 return node
# C signature for one-argument C-API calls: object -> return code,
# with -1 as the exception value
single_param_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_returncode_type,
    [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
    exception_value="-1")
2544 def _handle_simple_method_list_sort(self, node, function, args, is_unbound_method):
2545 """Call PyList_Sort() instead of the 0-argument l.sort().
2547 if len(args) != 1:
2548 return node
2549 return self._substitute_method_call(
2550 node, function, "PyList_Sort", self.single_param_func_type,
2551 'sort', is_unbound_method, args).coerce_to(node.type, self.current_env)
# C signature used for __Pyx_PyDict_GetItemDefault(): (dict, key, default) -> object
Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None)])
2560 def _handle_simple_method_dict_get(self, node, function, args, is_unbound_method):
2561 """Replace dict.get() by a call to PyDict_GetItem().
2563 if len(args) == 2:
2564 args.append(ExprNodes.NoneNode(node.pos))
2565 elif len(args) != 3:
2566 self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
2567 return node
2569 return self._substitute_method_call(
2570 node, function,
2571 "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
2572 'get', is_unbound_method, args,
2573 may_return_none = True,
2574 utility_code = load_c_utility("dict_getitem_default"))
# C signature used for __Pyx_PyDict_SetDefault():
# (dict, key, default, is_safe_type flag) -> object
Pyx_PyDict_SetDefault_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("is_safe_type", PyrexTypes.c_int_type, None)])
2584 def _handle_simple_method_dict_setdefault(self, node, function, args, is_unbound_method):
2585 """Replace dict.setdefault() by calls to PyDict_GetItem() and PyDict_SetItem().
2587 if len(args) == 2:
2588 args.append(ExprNodes.NoneNode(node.pos))
2589 elif len(args) != 3:
2590 self._error_wrong_arg_count('dict.setdefault', node, args, "2 or 3")
2591 return node
2592 key_type = args[1].type
2593 if key_type.is_builtin_type:
2594 is_safe_type = int(key_type.name in
2595 'str bytes unicode float int long bool')
2596 elif key_type is PyrexTypes.py_object_type:
2597 is_safe_type = -1 # don't know
2598 else:
2599 is_safe_type = 0 # definitely not
2600 args.append(ExprNodes.IntNode(
2601 node.pos, value=str(is_safe_type), constant_result=is_safe_type))
2603 return self._substitute_method_call(
2604 node, function,
2605 "__Pyx_PyDict_SetDefault", self.Pyx_PyDict_SetDefault_func_type,
2606 'setdefault', is_unbound_method, args,
2607 may_return_none=True,
2608 utility_code=load_c_utility('dict_setdefault'))
2611 ### unicode type methods
# C signature used for the Py_UNICODE_IS*() predicates: Py_UCS4 -> bint
PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_bint_type,
    [PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None)])
2618 def _inject_unicode_predicate(self, node, function, args, is_unbound_method):
2619 if is_unbound_method or len(args) != 1:
2620 return node
2621 ustring = args[0]
2622 if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2623 not ustring.arg.type.is_unicode_char:
2624 return node
2625 uchar = ustring.arg
2626 method_name = function.attribute
2627 if method_name == 'istitle':
2628 # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
2629 utility_code = UtilityCode.load_cached(
2630 "py_unicode_istitle", "StringTools.c")
2631 function_name = '__Pyx_Py_UNICODE_ISTITLE'
2632 else:
2633 utility_code = None
2634 function_name = 'Py_UNICODE_%s' % method_name.upper()
2635 func_call = self._substitute_method_call(
2636 node, function,
2637 function_name, self.PyUnicode_uchar_predicate_func_type,
2638 method_name, is_unbound_method, [uchar],
2639 utility_code = utility_code)
2640 if node.type.is_pyobject:
2641 func_call = func_call.coerce_to_pyobject(self.current_env)
2642 return func_call
# every unicode.isXYZ() predicate is handled by the same injection method
_handle_simple_method_unicode_isalnum   = _inject_unicode_predicate
_handle_simple_method_unicode_isalpha   = _inject_unicode_predicate
_handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
_handle_simple_method_unicode_isdigit   = _inject_unicode_predicate
_handle_simple_method_unicode_islower   = _inject_unicode_predicate
_handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
_handle_simple_method_unicode_isspace   = _inject_unicode_predicate
_handle_simple_method_unicode_istitle   = _inject_unicode_predicate
_handle_simple_method_unicode_isupper   = _inject_unicode_predicate

# C signature used for the Py_UNICODE_TO*() conversions: Py_UCS4 -> Py_UCS4
PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_py_ucs4_type,
    [PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None)])
2659 def _inject_unicode_character_conversion(self, node, function, args, is_unbound_method):
2660 if is_unbound_method or len(args) != 1:
2661 return node
2662 ustring = args[0]
2663 if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2664 not ustring.arg.type.is_unicode_char:
2665 return node
2666 uchar = ustring.arg
2667 method_name = function.attribute
2668 function_name = 'Py_UNICODE_TO%s' % method_name.upper()
2669 func_call = self._substitute_method_call(
2670 node, function,
2671 function_name, self.PyUnicode_uchar_conversion_func_type,
2672 method_name, is_unbound_method, [uchar])
2673 if node.type.is_pyobject:
2674 func_call = func_call.coerce_to_pyobject(self.current_env)
2675 return func_call
# the single-character case conversions share one injection method
_handle_simple_method_unicode_lower = _inject_unicode_character_conversion
_handle_simple_method_unicode_upper = _inject_unicode_character_conversion
_handle_simple_method_unicode_title = _inject_unicode_character_conversion

# C signature used for PyUnicode_Splitlines(): (unicode, bint keepends) -> list
PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
    Builtin.list_type,
    [PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
     PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None)])
2687 def _handle_simple_method_unicode_splitlines(self, node, function, args, is_unbound_method):
2688 """Replace unicode.splitlines(...) by a direct call to the
2689 corresponding C-API function.
2691 if len(args) not in (1,2):
2692 self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
2693 return node
2694 self._inject_bint_default_argument(node, args, 1, False)
2696 return self._substitute_method_call(
2697 node, function,
2698 "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
2699 'splitlines', is_unbound_method, args)
# C signature used for PyUnicode_Split(): (unicode, sep, Py_ssize_t maxsplit) -> list
PyUnicode_Split_func_type = PyrexTypes.CFuncType(
    Builtin.list_type,
    [PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
     PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None)])
2709 def _handle_simple_method_unicode_split(self, node, function, args, is_unbound_method):
2710 """Replace unicode.split(...) by a direct call to the
2711 corresponding C-API function.
2713 if len(args) not in (1,2,3):
2714 self._error_wrong_arg_count('unicode.split', node, args, "1-3")
2715 return node
2716 if len(args) < 2:
2717 args.append(ExprNodes.NullNode(node.pos))
2718 self._inject_int_default_argument(
2719 node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
2721 return self._substitute_method_call(
2722 node, function,
2723 "PyUnicode_Split", self.PyUnicode_Split_func_type,
2724 'split', is_unbound_method, args)
# C signature used for the __Pyx_Py*_Tailmatch() helpers:
# (string, substring, start, end, direction) -> bint, with -1 as exception value
PyString_Tailmatch_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_bint_type,
    [PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None),  # bytes/str/unicode
     PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
     PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
     PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None)],
    exception_value='-1')
def _handle_simple_method_unicode_endswith(self, node, function, args, is_unbound_method):
    """Optimise unicode.endswith() through the shared tail-match injection (direction +1)."""
    return self._inject_tailmatch(
        node, function, args, is_unbound_method,
        'unicode', 'endswith',
        unicode_tailmatch_utility_code, +1)
def _handle_simple_method_unicode_startswith(self, node, function, args, is_unbound_method):
    """Optimise unicode.startswith() through the shared tail-match injection (direction -1)."""
    return self._inject_tailmatch(
        node, function, args, is_unbound_method,
        'unicode', 'startswith',
        unicode_tailmatch_utility_code, -1)
def _inject_tailmatch(self, node, function, args, is_unbound_method, type_name,
                      method_name, utility_code, direction):
    """Rewrite a ``startswith()``/``endswith()`` call into a call to the
    C helper ``__Pyx_Py<Type>_Tailmatch``.

    ``args`` is modified in place: missing start/end arguments are filled
    in with ``0`` and ``PY_SSIZE_T_MAX``, and ``direction`` is appended as
    a C int literal.  On a wrong argument count an error is reported and
    the original ``node`` is returned unchanged.
    """
    if not (2 <= len(args) <= 4):
        qualified_name = '%s.%s' % (type_name, method_name)
        self._error_wrong_arg_count(qualified_name, node, args, "2-4")
        return node

    index_type = PyrexTypes.c_py_ssize_t_type
    self._inject_int_default_argument(node, args, 2, index_type, "0")
    self._inject_int_default_argument(node, args, 3, index_type, "PY_SSIZE_T_MAX")
    direction_node = ExprNodes.IntNode(
        node.pos, value=str(direction), type=PyrexTypes.c_int_type)
    args.append(direction_node)

    c_function_name = "__Pyx_Py%s_Tailmatch" % type_name.capitalize()
    call_node = self._substitute_method_call(
        node, function, c_function_name, self.PyString_Tailmatch_func_type,
        method_name, is_unbound_method, args,
        utility_code=utility_code)
    # the helper returns a C bint, callers expect a Python bool
    return call_node.coerce_to(Builtin.bool_type, self.current_env())
# C signature used for the "PyUnicode_Find" call generated by
# _inject_unicode_find(): (str, substring, start, end, direction) -> Py_ssize_t,
# declared with an exception value of -2.
PyUnicode_Find_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_py_ssize_t_type, [
        PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
        ],
    exception_value = '-2')
def _handle_simple_method_unicode_find(self, node, function, args, is_unbound_method):
    """Optimise ``unicode.find(...)``; delegates to the shared find helper
    with direction +1.
    """
    method_name, direction = 'find', +1
    return self._inject_unicode_find(
        node, function, args, is_unbound_method, method_name, direction)
def _handle_simple_method_unicode_rfind(self, node, function, args, is_unbound_method):
    """Optimise ``unicode.rfind(...)``; delegates to the shared find helper
    with direction -1.
    """
    method_name, direction = 'rfind', -1
    return self._inject_unicode_find(
        node, function, args, is_unbound_method, method_name, direction)
def _inject_unicode_find(self, node, function, args, is_unbound_method,
                         method_name, direction):
    """Rewrite ``unicode.find(...)`` / ``unicode.rfind(...)`` into a direct
    ``PyUnicode_Find`` call.

    ``args`` is modified in place: missing start/end arguments default to
    ``0`` and ``PY_SSIZE_T_MAX`` and ``direction`` is appended as a C int
    literal.  On a wrong argument count an error is reported and ``node``
    is returned unchanged.
    """
    arg_count = len(args)
    if arg_count < 2 or arg_count > 4:
        self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
        return node

    index_type = PyrexTypes.c_py_ssize_t_type
    self._inject_int_default_argument(node, args, 2, index_type, "0")
    self._inject_int_default_argument(node, args, 3, index_type, "PY_SSIZE_T_MAX")
    direction_node = ExprNodes.IntNode(
        node.pos, value=str(direction), type=PyrexTypes.c_int_type)
    args.append(direction_node)

    find_call = self._substitute_method_call(
        node, function, "PyUnicode_Find", self.PyUnicode_Find_func_type,
        method_name, is_unbound_method, args)
    return find_call.coerce_to_pyobject(self.current_env())
# C signature used for the "PyUnicode_Count" call generated for
# unicode.count(): (str, substring, start, end) -> Py_ssize_t,
# declared with an exception value of -1.
PyUnicode_Count_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_py_ssize_t_type, [
        PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
        ],
    exception_value = '-1')
def _handle_simple_method_unicode_count(self, node, function, args, is_unbound_method):
    """Rewrite ``unicode.count(...)`` into a direct ``PyUnicode_Count`` call.

    Missing start/end arguments are appended to ``args`` as ``0`` and
    ``PY_SSIZE_T_MAX``.  On a wrong argument count an error is reported
    and ``node`` is returned unchanged.
    """
    arg_count = len(args)
    if arg_count < 2 or arg_count > 4:
        self._error_wrong_arg_count('unicode.count', node, args, "2-4")
        return node

    index_type = PyrexTypes.c_py_ssize_t_type
    self._inject_int_default_argument(node, args, 2, index_type, "0")
    self._inject_int_default_argument(node, args, 3, index_type, "PY_SSIZE_T_MAX")

    count_call = self._substitute_method_call(
        node, function, "PyUnicode_Count", self.PyUnicode_Count_func_type,
        'count', is_unbound_method, args)
    return count_call.coerce_to_pyobject(self.current_env())
# C signature used for the "PyUnicode_Replace" call generated for
# unicode.replace(): (str, substring, replstr, maxcount) -> unicode object.
PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type, [
        PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
        ])
def _handle_simple_method_unicode_replace(self, node, function, args, is_unbound_method):
    """Rewrite ``unicode.replace(...)`` into a direct ``PyUnicode_Replace``
    call.

    A missing ``maxcount`` argument is appended to ``args`` as ``-1``.  On
    a wrong argument count an error is reported and ``node`` is returned
    unchanged.
    """
    arg_count = len(args)
    if arg_count != 3 and arg_count != 4:
        self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
        return node

    self._inject_int_default_argument(
        node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")

    replace_call = self._substitute_method_call(
        node, function, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
        'replace', is_unbound_method, args)
    return replace_call
# C signature used for the generic "PyUnicode_AsEncodedString" call:
# (obj, encoding, errors) -> bytes object.
PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
    Builtin.bytes_type, [
        PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        ])

# C signature used for the codec specific "PyUnicode_As<Codec>String" calls:
# (obj) -> bytes object.
PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
    Builtin.bytes_type, [
        PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
        ])

# Codec names that get inserted into the generated C function names
# ("PyUnicode_As%sString" / "PyUnicode_Decode%s"); names containing '_' are
# converted to CamelCase by _find_special_codec_name().
_special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
                      'unicode_escape', 'raw_unicode_escape']

# (name, encoder function) pairs; a requested encoding is matched by
# comparing its encoder function against these.
_special_codecs = [ (name, codecs.getencoder(name))
                    for name in _special_encodings ]
def _handle_simple_method_unicode_encode(self, node, function, args, is_unbound_method):
    """Replace unicode.encode(...) by a direct C-API call to the
    corresponding codec.

    Strategy, in order:
    - no encoding given: call PyUnicode_AsEncodedString(s, NULL, NULL);
    - literal string and known encoding: encode at compile time and
      return a bytes literal node;
    - known special codec with 'strict' error handling: call the codec
      specific PyUnicode_As<Codec>String(s);
    - otherwise: call PyUnicode_AsEncodedString(s, encoding, errors).

    Returns the unchanged ``node`` if the argument count is wrong (after
    reporting an error) or if encoding/errors cannot be unpacked.
    """
    if len(args) < 1 or len(args) > 3:
        self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
        return node

    string_node = args[0]

    if len(args) == 1:
        null_node = ExprNodes.NullNode(node.pos)
        return self._substitute_method_call(
            node, function, "PyUnicode_AsEncodedString",
            self.PyUnicode_AsEncodedString_func_type,
            'encode', is_unbound_method, [string_node, null_node, null_node])

    parameters = self._unpack_encoding_and_error_mode(node.pos, args)
    if parameters is None:
        return node
    encoding, encoding_node, error_handling, error_handling_node = parameters

    if encoding and isinstance(string_node, ExprNodes.UnicodeNode):
        # constant, so try to do the encoding at compile time
        try:
            value = string_node.value.encode(encoding, error_handling)
        except Exception:
            # well, looks like we can't - fall through to a runtime call.
            # Deliberately not a bare 'except:', so that KeyboardInterrupt
            # and SystemExit still propagate out of the compiler.
            pass
        else:
            value = BytesLiteral(value)
            value.encoding = encoding
            return ExprNodes.BytesNode(
                string_node.pos, value=value, type=Builtin.bytes_type)

    if encoding and error_handling == 'strict':
        # try to find a specific encoder function
        codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            encode_function = "PyUnicode_As%sString" % codec_name
            return self._substitute_method_call(
                node, function, encode_function,
                self.PyUnicode_AsXyzString_func_type,
                'encode', is_unbound_method, [string_node])

    return self._substitute_method_call(
        node, function, "PyUnicode_AsEncodedString",
        self.PyUnicode_AsEncodedString_func_type,
        'encode', is_unbound_method,
        [string_node, encoding_node, error_handling_node])
# Pointer type for the codec specific "PyUnicode_Decode<Codec>" functions:
# (string, size, errors) -> unicode object.
PyUnicode_DecodeXyz_func_ptr_type = PyrexTypes.CPtrType(PyrexTypes.CFuncType(
    Builtin.unicode_type, [
        PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        ]))

# Signature of the "__Pyx_decode_c_string" helper (char* input).
_decode_c_string_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type, [
        PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
        ])

# Signature of the "__Pyx_decode_bytes" / "__Pyx_decode_bytearray" helpers
# (Python object input).
_decode_bytes_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type, [
        PyrexTypes.CFuncTypeArg("string", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
        ])

# Signature of the "__Pyx_decode_cpp_string" helper; built on first use by
# _handle_simple_method_bytes_decode() because it needs the C++ string type.
_decode_cpp_string_func_type = None # lazy init
def _handle_simple_method_bytes_decode(self, node, function, args, is_unbound_method):
    """Replace char*.decode() by a direct C-API call to the
    corresponding codec, possibly resolving a slice on the char*.

    Handles C strings, C++ strings and Python bytes/bytearray objects by
    calling the matching "__Pyx_decode_*" helper from StringTools.c.
    Returns the unchanged ``node`` when the argument count is wrong (after
    reporting an error), when the input type is none of the above, or when
    encoding/errors cannot be unpacked.
    """
    if not (1 <= len(args) <= 3):
        self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
        return node

    # normalise input nodes
    string_node = args[0]
    start = stop = None
    if isinstance(string_node, ExprNodes.SliceIndexNode):
        # decode the sliced base directly and pass the bounds to the helper
        index_node = string_node
        string_node = index_node.base
        start, stop = index_node.start, index_node.stop
        if not start or start.constant_result == 0:
            start = None
    if isinstance(string_node, ExprNodes.CoerceToPyTypeNode):
        # look through the coercion to the underlying node
        string_node = string_node.arg

    string_type = string_node.type
    if string_type in (Builtin.bytes_type, Builtin.bytearray_type):
        # Python objects may be None => add a None check with the matching
        # error message for bound/unbound method calls
        if is_unbound_method:
            string_node = string_node.as_none_safe_node(
                "descriptor '%s' requires a '%s' object but received a 'NoneType'",
                format_args=['decode', string_type.name])
        else:
            string_node = string_node.as_none_safe_node(
                "'NoneType' object has no attribute '%s'",
                error="PyExc_AttributeError",
                format_args=['decode'])
    elif not string_type.is_string and not string_type.is_cpp_string:
        # nothing to optimise here
        return node

    parameters = self._unpack_encoding_and_error_mode(node.pos, args)
    if parameters is None:
        return node
    encoding, encoding_node, error_handling, error_handling_node = parameters

    # the helpers take Py_ssize_t bounds; 'start' defaults to 0
    if not start:
        start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
    elif not start.type.is_int:
        start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
    if stop and not stop.type.is_int:
        stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())

    # try to find a specific encoder function
    codec_name = None
    if encoding is not None:
        codec_name = self._find_special_codec_name(encoding)
    if codec_name is not None:
        # pass the codec specific function and NULL as the encoding name
        decode_function = ExprNodes.RawCNameExprNode(
            node.pos, type=self.PyUnicode_DecodeXyz_func_ptr_type,
            cname="PyUnicode_Decode%s" % codec_name)
        encoding_node = ExprNodes.NullNode(node.pos)
    else:
        decode_function = ExprNodes.NullNode(node.pos)

    # build the helper function call
    temps = []
    if string_type.is_string:
        # C string
        if not stop:
            # use strlen() to find the string length, just as CPython would
            if not string_node.is_name:
                string_node = UtilNodes.LetRefNode(string_node) # used twice
                temps.append(string_node)
            stop = ExprNodes.PythonCapiCallNode(
                string_node.pos, "strlen", self.Pyx_strlen_func_type,
                args=[string_node],
                is_temp=False,
                utility_code=UtilityCode.load_cached("IncludeStringH", "StringTools.c"),
                ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        helper_func_type = self._decode_c_string_func_type
        utility_code_name = 'decode_c_string'
    elif string_type.is_cpp_string:
        # C++ std::string
        if not stop:
            stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
                                     constant_result=ExprNodes.not_a_constant)
        if self._decode_cpp_string_func_type is None:
            # lazy init to reuse the C++ string type
            self._decode_cpp_string_func_type = PyrexTypes.CFuncType(
                Builtin.unicode_type, [
                    PyrexTypes.CFuncTypeArg("string", string_type, None),
                    PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
                    PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
                    PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
                    PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
                    PyrexTypes.CFuncTypeArg("decode_func", self.PyUnicode_DecodeXyz_func_ptr_type, None),
                    ])
        helper_func_type = self._decode_cpp_string_func_type
        utility_code_name = 'decode_cpp_string'
    else:
        # Python bytes/bytearray object
        if not stop:
            stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
                                     constant_result=ExprNodes.not_a_constant)
        helper_func_type = self._decode_bytes_func_type
        if string_type is Builtin.bytes_type:
            utility_code_name = 'decode_bytes'
        else:
            utility_code_name = 'decode_bytearray'

    node = ExprNodes.PythonCapiCallNode(
        node.pos, '__Pyx_%s' % utility_code_name, helper_func_type,
        args=[string_node, start, stop, encoding_node, error_handling_node, decode_function],
        is_temp=node.is_temp,
        utility_code=UtilityCode.load_cached(utility_code_name, 'StringTools.c'),
        )

    # wrap the call in the temp evaluation nodes collected above
    for temp in temps[::-1]:
        node = UtilNodes.EvalWithTempExprNode(temp, node)
    return node

# bytearray.decode() is handled by the same transformation
_handle_simple_method_bytearray_decode = _handle_simple_method_bytes_decode
def _find_special_codec_name(self, encoding):
    """Map an encoding name to the matching entry of ``_special_codecs``.

    The match is made by comparing encoder functions, so aliases of the
    same codec are recognised.  Names containing underscores are returned
    in CamelCase (e.g. 'unicode_escape' -> 'UnicodeEscape').  Returns None
    for unknown encodings and for codecs that are not in the table.
    """
    try:
        wanted_encoder = codecs.getencoder(encoding)
    except LookupError:
        return None

    for codec_name, encoder in self._special_codecs:
        if encoder == wanted_encoder:
            if '_' not in codec_name:
                return codec_name
            parts = codec_name.split('_')
            return ''.join([part.capitalize() for part in parts])
    return None
def _unpack_encoding_and_error_mode(self, pos, args):
    """Extract the encoding and error handling arguments of an
    encode()/decode() call.

    ``args[1]`` is the encoding and ``args[2]`` the error mode, if given.
    Returns ``(encoding, encoding_node, error_handling, error_handling_node)``
    or None if one of the arguments cannot be unpacked.  A missing
    encoding yields ``(None, NULL node)``; a missing or 'strict' error
    mode yields a NULL node for the error handling.
    """
    null_node = ExprNodes.NullNode(pos)

    encoding, encoding_node = None, null_node
    if len(args) >= 2:
        encoding, encoding_node = self._unpack_string_and_cstring_node(args[1])
        if encoding_node is None:
            return None

    error_handling, error_handling_node = 'strict', null_node
    if len(args) == 3:
        error_handling, error_handling_node = self._unpack_string_and_cstring_node(args[2])
        if error_handling_node is None:
            return None
        if error_handling == 'strict':
            error_handling_node = null_node

    return (encoding, encoding_node, error_handling, error_handling_node)
def _unpack_string_and_cstring_node(self, node):
    """Return ``(string_value, char_ptr_node)`` for a string argument node.

    - unicode literal: its value plus a char* bytes node of the UTF-8
      encoded value;
    - str/bytes literal: its value decoded as ISO-8859-1 plus a char*
      bytes node of the original value;
    - non-literal bytes object: ``None`` plus the node coerced to char*;
    - other C string typed node: ``None`` plus the node itself;
    - anything else: ``(None, None)``.
    """
    char_ptr_type = PyrexTypes.c_char_ptr_type
    if isinstance(node, ExprNodes.CoerceToPyTypeNode):
        node = node.arg

    if isinstance(node, ExprNodes.UnicodeNode):
        text = node.value
        cstring_node = ExprNodes.BytesNode(
            node.pos, value=BytesLiteral(text.utf8encode()), type=char_ptr_type)
        return text, cstring_node

    if isinstance(node, (ExprNodes.StringNode, ExprNodes.BytesNode)):
        text = node.value.decode('ISO-8859-1')
        cstring_node = ExprNodes.BytesNode(
            node.pos, value=node.value, type=char_ptr_type)
        return text, cstring_node

    if node.type is Builtin.bytes_type:
        return None, node.coerce_to(char_ptr_type, self.current_env())

    if node.type.is_string:
        return None, node

    return None, None
def _handle_simple_method_str_endswith(self, node, function, args, is_unbound_method):
    """Optimise ``str.endswith(...)`` through the shared tailmatch helper,
    passing +1 as the match direction.
    """
    direction = +1
    return self._inject_tailmatch(
        node, function, args, is_unbound_method,
        'str', 'endswith', str_tailmatch_utility_code, direction)
def _handle_simple_method_str_startswith(self, node, function, args, is_unbound_method):
    """Optimise ``str.startswith(...)`` through the shared tailmatch helper,
    passing -1 as the match direction.
    """
    direction = -1
    return self._inject_tailmatch(
        node, function, args, is_unbound_method,
        'str', 'startswith', str_tailmatch_utility_code, direction)
def _handle_simple_method_bytes_endswith(self, node, function, args, is_unbound_method):
    """Optimise ``bytes.endswith(...)`` through the shared tailmatch helper,
    passing +1 as the match direction.
    """
    direction = +1
    return self._inject_tailmatch(
        node, function, args, is_unbound_method,
        'bytes', 'endswith', bytes_tailmatch_utility_code, direction)
def _handle_simple_method_bytes_startswith(self, node, function, args, is_unbound_method):
    """Optimise ``bytes.startswith(...)`` through the shared tailmatch
    helper, passing -1 as the match direction.
    """
    direction = -1
    return self._inject_tailmatch(
        node, function, args, is_unbound_method,
        'bytes', 'startswith', bytes_tailmatch_utility_code, direction)
3148 ''' # disabled for now, enable when we consider it worth it (see StringTools.c)
3149 def _handle_simple_method_bytearray_endswith(self, node, function, args, is_unbound_method):
3150 return self._inject_tailmatch(
3151 node, function, args, is_unbound_method, 'bytearray', 'endswith',
3152 bytes_tailmatch_utility_code, +1)
3154 def _handle_simple_method_bytearray_startswith(self, node, function, args, is_unbound_method):
3155 return self._inject_tailmatch(
3156 node, function, args, is_unbound_method, 'bytearray', 'startswith',
3157 bytes_tailmatch_utility_code, -1)
3160 ### helpers
def _substitute_method_call(self, node, function, name, func_type,
                            attr_name, is_unbound_method, args=(),
                            utility_code=None, is_temp=None,
                            may_return_none=ExprNodes.PythonCapiCallNode.may_return_none):
    """Build the C-API call node that replaces a Python method call.

    ``name``/``func_type`` describe the C function to call.  Unless the
    first argument (the object the method is called on) is a literal, it
    is wrapped in a None check that raises the error a Python level call
    would give; ``attr_name`` is used in that message.  ``is_temp``
    defaults to ``node.is_temp`` and ``result_is_used`` is copied from
    ``node``.  The passed ``args`` sequence itself is not modified.
    """
    call_args = list(args)
    if call_args and not call_args[0].is_literal:
        self_arg = call_args[0]
        if is_unbound_method:
            call_args[0] = self_arg.as_none_safe_node(
                "descriptor '%s' requires a '%s' object but received a 'NoneType'",
                format_args=[attr_name, function.obj.name])
        else:
            call_args[0] = self_arg.as_none_safe_node(
                "'NoneType' object has no attribute '%s'",
                error="PyExc_AttributeError",
                format_args=[attr_name])

    use_temp = node.is_temp if is_temp is None else is_temp
    return ExprNodes.PythonCapiCallNode(
        node.pos, name, func_type,
        args=call_args,
        is_temp=use_temp,
        utility_code=utility_code,
        may_return_none=may_return_none,
        result_is_used=node.result_is_used,
        )
def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
    """Ensure ``args[arg_index]`` is an integer argument of C type ``type``.

    If the argument is missing (``args`` ends right before ``arg_index``),
    an integer literal with ``default_value`` is appended; otherwise the
    existing argument is coerced to ``type``.  ``args`` is modified in
    place.
    """
    arg_count = len(args)
    assert arg_count >= arg_index
    if arg_count != arg_index:
        args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
        return
    default_node = ExprNodes.IntNode(
        node.pos, value=str(default_value),
        type=type, constant_result=default_value)
    args.append(default_node)
def _inject_bint_default_argument(self, node, args, arg_index, default_value):
    """Ensure ``args[arg_index]`` is a boolean argument.

    If the argument is missing (``args`` ends right before ``arg_index``),
    a boolean literal with ``bool(default_value)`` is appended; otherwise
    the existing argument is coerced to a boolean.  ``args`` is modified
    in place.
    """
    arg_count = len(args)
    assert arg_count >= arg_index
    if arg_count != arg_index:
        args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
        return
    flag = bool(default_value)
    args.append(ExprNodes.BoolNode(node.pos, value=flag, constant_result=flag))
# Utility code (loaded from StringTools.c) that is passed along by the
# startswith()/endswith() handlers for unicode, bytes and str respectively.
unicode_tailmatch_utility_code = UtilityCode.load_cached('unicode_tailmatch', 'StringTools.c')
bytes_tailmatch_utility_code = UtilityCode.load_cached('bytes_tailmatch', 'StringTools.c')
str_tailmatch_utility_code = UtilityCode.load_cached('str_tailmatch', 'StringTools.c')
3213 class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
3214 """Calculate the result of constant expressions to store it in
3215 ``expr_node.constant_result``, and replace trivial cases by their
3216 constant result.
3218 General rules:
3220 - We calculate float constants to make them available to the
3221 compiler, but we do not aggregate them into a single literal
3222 node to prevent any loss of precision.
3224 - We recursively calculate constants from non-literal nodes to
3225 make them available to the compiler, but we only aggregate
3226 literal nodes at each step. Non-literal nodes are never merged
3227 into a single node.
def __init__(self, reevaluate=False):
    """
    The reevaluate argument specifies whether constant values that were
    previously computed should be recomputed.
    """
    super(ConstantFolding, self).__init__()
    # read by _calculate_const() to decide whether nodes that already
    # have a constant_result are skipped
    self.reevaluate = reevaluate
def _calculate_const(self, node):
    """Visit the children of ``node`` and try to compute its
    ``constant_result``.

    Nodes that already carry a result are skipped unless ``reevaluate`` is
    set.  The result is left as ``not_a_constant`` if any child is not
    constant or if the calculation raises one of the 'normal' errors;
    unexpected exceptions are printed to stdout and otherwise ignored.
    """
    already_set = node.constant_result is not ExprNodes.constant_value_not_set
    if already_set and not self.reevaluate:
        return

    # make sure we always set the value
    marker = ExprNodes.not_a_constant
    node.constant_result = marker

    # check if all children are constant
    for visited in self.visitchildren(node).values():
        candidates = visited if type(visited) is list else [visited]
        for child in candidates:
            if getattr(child, 'constant_result', marker) is marker:
                return

    # now try to calculate the real constant value
    try:
        node.calculate_constant_result()
    except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
        # ignore all 'normal' errors here => no constant result
        pass
    except Exception:
        # this looks like a real error
        import traceback, sys
        traceback.print_exc(file=sys.stdout)
# Literal node classes ordered from narrowest to widest; the list position
# is what _widest_node_class() compares.
NODE_TYPE_ORDER = [ExprNodes.BoolNode, ExprNodes.CharNode,
                   ExprNodes.IntNode, ExprNodes.FloatNode]
def _widest_node_class(self, *nodes):
    """Return the class among ``nodes`` that comes last in
    ``NODE_TYPE_ORDER``.

    Returns None if no nodes are given or if the exact type of any node is
    not listed in ``NODE_TYPE_ORDER``.
    """
    order = self.NODE_TYPE_ORDER
    try:
        widest_rank = max([order.index(type(candidate)) for candidate in nodes])
    except ValueError:
        # unknown node type, or nothing to compare
        return None
    return order[widest_rank]
def _bool_node(self, node, value):
    """Create a boolean literal node at the position of ``node`` whose
    value and constant result are ``bool(value)``.
    """
    flag = bool(value)
    return ExprNodes.BoolNode(node.pos, value=flag, constant_result=flag)
def visit_ExprNode(self, node):
    """Default handler for expression nodes: calculate the constant
    result and keep the node itself.
    """
    self._calculate_const(node)
    return node
def visit_UnopNode(self, node):
    """Fold unary operators.

    Without a constant result, only '!' is transformed (through
    _handle_NotNode).  With a constant result on a literal operand, '!' is
    replaced by a bool literal, any other operator on a bool literal by a
    C int literal, and '+'/'-' are passed on to their specific handlers.
    Everything else is returned unchanged.
    """
    self._calculate_const(node)
    operator = node.operator
    if not node.has_constant_result():
        return self._handle_NotNode(node) if operator == '!' else node

    operand = node.operand
    if not operand.is_literal:
        return node
    if operator == '!':
        return self._bool_node(node, node.constant_result)
    if isinstance(operand, ExprNodes.BoolNode):
        int_value = int(node.constant_result)
        return ExprNodes.IntNode(node.pos, value=str(int_value),
                                 type=PyrexTypes.c_int_type,
                                 constant_result=int_value)
    if operator == '+':
        return self._handle_UnaryPlusNode(node)
    if operator == '-':
        return self._handle_UnaryMinusNode(node)
    return node
# Lookup function (bound dict.get) that maps a comparison operator to its
# negated counterpart; returns None for operators that are not listed.
_negate_operator = {
    'in': 'not_in',
    'not_in': 'in',
    'is': 'is_not',
    'is_not': 'is'
}.get
def _handle_NotNode(self, node):
    """Turn ``not (a in b)`` / ``not (a is b)`` style expressions into the
    negated comparison.

    If the operand is a comparison whose operator has a negated form, a
    shallow copy of the comparison with the negated operator is visited
    and returned.  Otherwise ``node`` is returned unchanged.
    """
    comparison = node.operand
    if not isinstance(comparison, ExprNodes.PrimaryCmpNode):
        return node
    negated_operator = self._negate_operator(comparison.operator)
    if not negated_operator:
        return node
    negated = copy.copy(comparison)
    negated.operator = negated_operator
    return self.visit_PrimaryCmpNode(negated)
def _handle_UnaryMinusNode(self, node):
    """Fold a unary minus into its literal operand.

    Float literals, signed C integers and Python object typed int literals
    are replaced by a new literal node with the sign of the value string
    flipped; anything else is returned unchanged.
    """
    def _flip_sign(literal):
        # toggle a leading '-' on the literal's string value
        if literal.startswith('-'):
            return literal[1:]
        return '-' + literal

    operand = node.operand
    operand_type = operand.type
    if isinstance(operand, ExprNodes.FloatNode):
        # this is a safe operation
        return ExprNodes.FloatNode(node.pos, value=_flip_sign(operand.value),
                                   type=operand_type,
                                   constant_result=node.constant_result)

    is_signed_c_int = operand_type.is_int and operand_type.signed
    if is_signed_c_int or (
            isinstance(operand, ExprNodes.IntNode) and operand_type.is_pyobject):
        return ExprNodes.IntNode(node.pos, value=_flip_sign(operand.value),
                                 type=operand_type,
                                 longness=operand.longness,
                                 constant_result=node.constant_result)
    return node
def _handle_UnaryPlusNode(self, node):
    """Drop a unary plus: return the operand if it has a constant result
    equal to the node's own constant result, else the node itself.
    """
    operand = node.operand
    if not operand.has_constant_result():
        return node
    if node.constant_result == operand.constant_result:
        return operand
    return node
def visit_BoolBinopNode(self, node):
    """Resolve ``and``/``or`` when the first operand is constant.

    A true first operand makes 'and' yield the second operand and any
    other operator the first; a false first operand does the opposite.
    The node is kept if the first operand is not constant.
    """
    self._calculate_const(node)
    first = node.operand1
    if not first.has_constant_result():
        return node
    first_is_true = bool(first.constant_result)
    is_and = node.operator == 'and'
    # the second operand decides exactly when the first one does not
    # short-circuit the expression
    if is_and == first_is_true:
        return node.operand2
    return node.operand1
def visit_BinopNode(self, node):
    """Fold a binary operation on two literal operands into one literal.

    The node is kept unchanged if it has no constant result, if the result
    is a float, if either operand is not a literal or has no type, or if
    the operand node classes are not listed in NODE_TYPE_ORDER.
    """
    self._calculate_const(node)
    if node.constant_result is ExprNodes.not_a_constant:
        return node
    if isinstance(node.constant_result, float):
        # float results are never merged into a new literal here
        return node
    operand1, operand2 = node.operand1, node.operand2
    if not operand1.is_literal or not operand2.is_literal:
        return node

    # now inject a new constant node with the calculated value
    try:
        type1, type2 = operand1.type, operand2.type
        if type1 is None or type2 is None:
            return node
    except AttributeError:
        return node

    if type1.is_numeric and type2.is_numeric:
        widest_type = PyrexTypes.widest_numeric_type(type1, type2)
    else:
        widest_type = PyrexTypes.py_object_type

    target_class = self._widest_node_class(operand1, operand2)
    if target_class is None:
        return node
    elif target_class is ExprNodes.BoolNode and node.operator in '+-//<<%**>>':
        # C arithmetic results in at least an int type
        # (note: 'in' is a substring test against the operator string)
        target_class = ExprNodes.IntNode
    elif target_class is ExprNodes.CharNode and node.operator in '+-//<<%**>>&|^':
        # C arithmetic results in at least an int type
        target_class = ExprNodes.IntNode

    if target_class is ExprNodes.IntNode:
        # only unsigned if both operands are; longness is the longer of the two
        unsigned = getattr(operand1, 'unsigned', '') and \
                   getattr(operand2, 'unsigned', '')
        longness = "LL"[:max(len(getattr(operand1, 'longness', '')),
                             len(getattr(operand2, 'longness', '')))]
        new_node = ExprNodes.IntNode(pos=node.pos,
                                     unsigned=unsigned, longness=longness,
                                     value=str(int(node.constant_result)),
                                     constant_result=int(node.constant_result))
        # IntNode is smart about the type it chooses, so we just
        # make sure we were not smarter this time
        if widest_type.is_pyobject or new_node.type.is_pyobject:
            new_node.type = PyrexTypes.py_object_type
        else:
            new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
    else:
        # BoolNode keeps the value itself, other node classes get its string form
        if target_class is ExprNodes.BoolNode:
            node_value = node.constant_result
        else:
            node_value = str(node.constant_result)
        new_node = target_class(pos=node.pos, type = widest_type,
                                value = node_value,
                                constant_result = node.constant_result)
    return new_node
def visit_MulNode(self, node):
    """Fold multiplications.

    ``sequence * factor`` and ``int literal * sequence`` are handled by
    _calculate_constant_seq(); everything else is treated like any other
    binary operation.
    """
    self._calculate_const(node)
    left, right = node.operand1, node.operand2
    if left.is_sequence_constructor:
        return self._calculate_constant_seq(node, left, right)
    if isinstance(left, ExprNodes.IntNode) and right.is_sequence_constructor:
        return self._calculate_constant_seq(node, right, left)
    return self.visit_BinopNode(node)
def _calculate_constant_seq(self, node, sequence_node, factor):
    """Fold ``sequence * factor`` into the sequence node's ``mult_factor``.

    - factor 1 or an empty sequence: the sequence node is returned as is;
    - integer factor <= 0: the sequence is emptied and any existing
      ``mult_factor`` is dropped;
    - sequence already has an integer ``mult_factor`` and the new factor
      is an integer too: both are multiplied into a new int literal;
    - sequence has a ``mult_factor`` that cannot be combined: the
      multiplication is handled as a plain binary operation instead;
    - otherwise ``factor`` becomes the ``mult_factor``.

    ``sequence_node`` is modified in place and returned (except in the
    plain binary operation case).
    """
    # 'long' only exists on Python 2; without this guard the isinstance()
    # checks below raise NameError on Python 3.
    try:
        integer_types = (int, long)
    except NameError:
        integer_types = (int,)

    if factor.constant_result != 1 and sequence_node.args:
        if isinstance(factor.constant_result, integer_types) and factor.constant_result <= 0:
            del sequence_node.args[:]
            sequence_node.mult_factor = None
        elif sequence_node.mult_factor is not None:
            if (isinstance(factor.constant_result, integer_types) and
                    isinstance(sequence_node.mult_factor.constant_result, integer_types)):
                value = sequence_node.mult_factor.constant_result * factor.constant_result
                sequence_node.mult_factor = ExprNodes.IntNode(
                    sequence_node.mult_factor.pos,
                    value=str(value), constant_result=value)
            else:
                # don't know if we can combine the factors, so don't
                return self.visit_BinopNode(node)
        else:
            sequence_node.mult_factor = factor
    return sequence_node
def visit_PrimaryCmpNode(self, node):
    """Fold constant parts of a (possibly cascaded) comparison.

    A single comparison with a constant result becomes a bool literal.
    In a cascade, links that are constant True are dropped, a constant
    False link ends the cascade, and the remaining non-constant parts are
    rebuilt as separate comparisons joined by 'and'.
    """
    # calculate constant partial results in the comparison cascade
    self.visitchildren(node, ['operand1'])
    left_node = node.operand1
    cmp_node = node
    while cmp_node is not None:
        self.visitchildren(cmp_node, ['operand2'])
        right_node = cmp_node.operand2
        cmp_node.constant_result = not_a_constant
        if left_node.has_constant_result() and right_node.has_constant_result():
            try:
                cmp_node.calculate_cascaded_constant_result(left_node.constant_result)
            except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
                pass  # ignore all 'normal' errors here => no constant result
        # the right operand is the left operand of the next cascade link
        left_node = right_node
        cmp_node = cmp_node.cascade

    if not node.cascade:
        if node.has_constant_result():
            return self._bool_node(node, node.constant_result)
        return node

    # collect partial cascades: [[value, CmpNode...], [value, CmpNode, ...], ...]
    cascades = [[node.operand1]]
    final_false_result = []

    def split_cascades(cmp_node):
        if cmp_node.has_constant_result():
            if not cmp_node.constant_result:
                # False => short-circuit
                final_false_result.append(self._bool_node(cmp_node, False))
                return
            else:
                # True => discard and start new cascade
                cascades.append([cmp_node.operand2])
        else:
            # not constant => append to current cascade
            cascades[-1].append(cmp_node)
        if cmp_node.cascade:
            split_cascades(cmp_node.cascade)

    split_cascades(node)

    cmp_nodes = []
    for cascade in cascades:
        if len(cascade) < 2:
            # only a start value, no comparison left in this part
            continue
        cmp_node = cascade[1]
        pcmp_node = ExprNodes.PrimaryCmpNode(
            cmp_node.pos,
            operand1=cascade[0],
            operator=cmp_node.operator,
            operand2=cmp_node.operand2,
            constant_result=not_a_constant)
        cmp_nodes.append(pcmp_node)

        # re-link the remaining comparisons of this part and terminate the chain
        last_cmp_node = pcmp_node
        for cmp_node in cascade[2:]:
            last_cmp_node.cascade = cmp_node
            last_cmp_node = cmp_node
        last_cmp_node.cascade = None

    if final_false_result:
        # last cascade was constant False
        cmp_nodes.append(final_false_result[0])
    elif not cmp_nodes:
        # only constants, but no False result
        return self._bool_node(node, True)
    node = cmp_nodes[0]
    if len(cmp_nodes) == 1:
        if node.has_constant_result():
            return self._bool_node(node, node.constant_result)
    else:
        # join the partial comparisons as "a and b and ..."
        for cmp_node in cmp_nodes[1:]:
            node = ExprNodes.BoolBinopNode(
                node.pos,
                operand1=node,
                operator='and',
                operand2=cmp_node,
                constant_result=not_a_constant)
    return node
def visit_CondExprNode(self, node):
    """Resolve ``a if test else b`` when the test is constant: return the
    selected branch, or the node itself if the test is not constant.
    """
    self._calculate_const(node)
    test = node.test
    if not test.has_constant_result():
        return node
    return node.true_val if test.constant_result else node.false_val
def visit_IfStatNode(self, node):
    """Eliminate dead branches of an if statement based on constant
    conditions.

    Clauses with a constant false condition are dropped.  The first clause
    with a constant true condition replaces the else clause and hides all
    later clauses.  Returns the node if any runtime clause is left,
    otherwise the else clause, otherwise an empty statement list.
    """
    self.visitchildren(node)
    runtime_clauses = []
    for clause in node.if_clauses:
        condition = clause.condition
        if not condition.has_constant_result():
            # unknown result => normal runtime evaluation
            runtime_clauses.append(clause)
        elif condition.constant_result:
            # always true => subsequent clauses can safely be dropped
            node.else_clause = clause.body
            break
        # constant false => drop clause

    if runtime_clauses:
        node.if_clauses = runtime_clauses
        return node
    if node.else_clause:
        return node.else_clause
    return Nodes.StatListNode(node.pos, stats=[])
def visit_SliceIndexNode(self, node):
    """Normalise slice bounds and slice constant sequences/strings at
    compile time.

    A start/stop node that is missing or has a constant result of None is
    reset to None.  If the slice has a constant result, a sequence
    constructor without ``mult_factor`` is cut down in place and returned,
    and a string literal is replaced by its sliced node when one is
    available.  Otherwise the node is returned.
    """
    self._calculate_const(node)

    # normalise start/stop values
    bounds = []
    for attr_name in ('start', 'stop'):
        bound_node = getattr(node, attr_name)
        if bound_node is None or bound_node.constant_result is None:
            setattr(node, attr_name, None)
            bounds.append(None)
        else:
            bounds.append(bound_node.constant_result)
    start, stop = bounds

    if node.constant_result is not_a_constant:
        return node

    # cut down sliced constant sequences
    base = node.base
    if base.is_sequence_constructor and base.mult_factor is None:
        base.args = base.args[start:stop]
        return base
    if base.is_string_literal:
        sliced = base.as_sliced_node(start, stop)
        if sliced is not None:
            return sliced
    return node
def visit_ComprehensionNode(self, node):
    """Replace a comprehension whose loop was reduced to an empty
    statement list by an empty list/set/dict literal; comprehensions of
    any other type, or with a remaining loop, are returned unchanged.
    """
    self.visitchildren(node)
    loop = node.loop
    loop_was_pruned = isinstance(loop, Nodes.StatListNode) and not loop.stats
    if not loop_was_pruned:
        return node

    # loop was pruned already => transform into literal
    result_type = node.type
    if result_type is Builtin.list_type:
        return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
    if result_type is Builtin.set_type:
        return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
    if result_type is Builtin.dict_type:
        return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
    return node
def visit_ForInStatNode(self, node):
    """Simplify for-in loops over literal sequences.

    A loop over an empty sequence literal is replaced by its else clause,
    or by an empty statement list if there is none.  A list literal that
    is iterated over is converted to a tuple.
    """
    self.visitchildren(node)
    iterable = node.iterator.sequence
    if not isinstance(iterable, ExprNodes.SequenceNode):
        return node

    if not iterable.args:
        if node.else_clause:
            return node.else_clause
        # don't break list comprehensions
        return Nodes.StatListNode(node.pos, stats=[])

    # iterating over a list literal? => tuples are more efficient
    if isinstance(iterable, ExprNodes.ListNode):
        node.iterator.sequence = iterable.as_tuple()
    return node
def visit_WhileStatNode(self, node):
    """Simplify while loops whose condition has a constant result.

    A constant true condition is removed from the loop (set to None)
    together with the else clause.  A constant false condition replaces
    the whole loop by its else clause, which may be None.  Loops without
    a condition, or with a non-constant one, are returned unchanged.
    """
    self.visitchildren(node)

    condition = node.condition
    if not condition or not condition.has_constant_result():
        return node

    if not condition.constant_result:
        # never entered => only the else clause (if any) remains
        return node.else_clause

    # always true => no condition to test and no else clause to keep
    node.condition = None
    node.else_clause = None
    return node
def visit_ExprStatNode(self, node):
    """Drop expression statements that only evaluate a constant.

    Returns None (removing the statement) when the expression has a
    constant result, and the unchanged statement node otherwise.
    """
    self.visitchildren(node)

    expr = node.expr
    # The expression may not be an ExprNode at all
    # (ParallelRangeTransform does this ...) => leave those alone.
    if isinstance(expr, ExprNodes.ExprNode) and expr.has_constant_result():
        # drop unused constant expressions
        return None
    return node
# In the future, other nodes can have their own handler method here
# that can replace them with a constant result node.

# Generic handler, bound to the visitor's recurse_to_children
# (presumably the fallback for node types without their own visit_*
# method above -- confirm against Visitor.VisitorTransform).
visit_Node = Visitor.VisitorTransform.recurse_to_children
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    Late tree transform that applies several commuting optimizations.
    It is run just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
        - eliminate checks for None and/or types that became redundant
          after tree changes
    """

    def visit_SingleAssignmentNode(self, node):
        """Flag the target of a first assignment so that the redundant
        initialisation of the local variable can be avoided.
        """
        self.visitchildren(node)
        if node.first:
            node.lhs.lhs_of_first_assignment = True
        return node

    def visit_SimpleCallNode(self, node):
        """Turn generic two-argument isinstance(x, type) calls into a
        more efficient PyObject_TypeCheck() call.
        """
        self.visitchildren(node)

        function = node.function
        if not (function.type.is_cfunction
                and isinstance(function, ExprNodes.NameNode)):
            return node
        if function.name != 'isinstance' or len(node.args) != 2:
            return node

        type_arg = node.args[1]
        if not (type_arg.type.is_builtin_type and type_arg.type.name == 'type'):
            return node

        # redirect the call to the C-level type check ...
        cython_scope = self.context.cython_scope
        function.entry = cython_scope.lookup('PyObject_TypeCheck')
        function.type = function.entry.type
        # ... which takes the type argument as a PyTypeObject pointer
        type_object_ptr = PyrexTypes.CPtrType(
            cython_scope.lookup('PyTypeObject').type)
        node.args[1] = ExprNodes.CastNode(type_arg, type_object_ptr)
        return node

    def visit_PyTypeTestNode(self, node):
        """Mark type tests as not accepting None when the tested
        argument is known to never be None anyway.
        """
        self.visitchildren(node)
        if not node.notnone and not node.arg.may_be_none():
            node.notnone = True
        return node

    def visit_NoneCheckNode(self, node):
        """Replace None checks by their plain argument when that
        argument definitely does not carry a None value.
        """
        self.visitchildren(node)
        checked = node.arg
        if checked.may_be_none():
            return node
        return checked
class ConsolidateOverflowCheck(Visitor.CythonTransform):
    """
    Share the overflow check among all nodes of a nested arithmetic
    expression.  For example, given the expression a*b + c, where a, b,
    and c are all possibly overflowing ints, the entire sequence will be
    evaluated and the overflow bit checked only at the end.
    """
    # outermost overflow-checked binop of the expression currently
    # being visited, or None outside of such an expression
    overflow_bit_node = None

    def visit_Node(self, node):
        """Visit the children of a generic node outside of any enclosing
        overflow-checked expression, restoring the previous state afterwards.
        """
        enclosing = self.overflow_bit_node
        if enclosing is None:
            self.visitchildren(node)
            return node

        # a non-arithmetic node interrupts the sharing for its subtree
        self.overflow_bit_node = None
        self.visitchildren(node)
        self.overflow_bit_node = enclosing
        return node

    def visit_NumBinopNode(self, node):
        """Attach a foldable, overflow-checked binop to the enclosing
        overflow-checked binop, or make it the one that nested binops
        attach to if there is no enclosing one.
        """
        if not (node.overflow_check and node.overflow_fold):
            self.visitchildren(node)
            return node

        enclosing = self.overflow_bit_node
        is_top_level = enclosing is None
        if is_top_level:
            # nested operations will report their overflow to this node
            self.overflow_bit_node = node
        else:
            # report to the enclosing node instead of checking here
            node.overflow_bit_node = enclosing
            node.overflow_check = False

        self.visitchildren(node)

        if is_top_level:
            self.overflow_bit_node = None
        return node