BOO-999
[boo.git] / src / Boo.Lang.Compiler / Steps / OptimizeIterationStatements.cs
blobc9c591e72d46cf839442f5b39b3cbad944aad623
1 #region license
2 // Copyright (c) 2003, 2004, 2005 Rodrigo B. de Oliveira (rbo@acm.org)
3 // All rights reserved.
4 //
5 // Redistribution and use in source and binary forms, with or without modification,
6 // are permitted provided that the following conditions are met:
7 //
8 // * Redistributions of source code must retain the above copyright notice,
9 // this list of conditions and the following disclaimer.
10 // * Redistributions in binary form must reproduce the above copyright notice,
11 // this list of conditions and the following disclaimer in the documentation
12 // and/or other materials provided with the distribution.
13 // * Neither the name of Rodrigo B. de Oliveira nor the names of its
14 // contributors may be used to endorse or promote products derived from this
15 // software without specific prior written permission.
16 //
17 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
18 // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
19 // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
20 // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
21 // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22 // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
23 // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
24 // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
25 // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
26 // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 #endregion
29 //Authored by Cameron Kenneth Knight: http://jira.codehaus.org/browse/BOO-137
31 namespace Boo.Lang.Compiler.Steps
33 using System;
34 using Boo.Lang.Compiler.Ast;
35 using Boo.Lang.Compiler.TypeSystem;
37 /// <summary>
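38 /// Optimizes <c>for</c> loops over <c>range(...)</c> calls and over one-dimensional arrays by rewriting them into equivalent <c>while</c> loops.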
39 /// </summary>
40 public class OptimizeIterationStatements : AbstractTransformerCompilerStep
// Reflection handles used when emitting the optimized loops.
private static readonly System.Reflection.MethodInfo Array_get_Length =
    Types.Array.GetProperty("Length").GetGetMethod();

private static readonly System.Reflection.MethodInfo System_Math_Ceiling =
    typeof(System.Math).GetMethod("Ceiling", new System.Type[] { typeof(double) });

private static readonly System.Reflection.ConstructorInfo System_ArgumentOutOfRangeException_ctor =
    typeof(System.ArgumentOutOfRangeException).GetConstructor(new System.Type[] { typeof(string) });

// Compiler entities for the three Builtins.range overloads; resolved in Initialize.
private IMethod _range_End;
private IMethod _range_Begin_End;
private IMethod _range_Begin_End_Step;

// Method whose body is currently being visited; temp locals are declared on it.
private Method _currentMethod;
/// <summary>
/// Creates the step. Nothing is set up here; the range() entities are
/// resolved in Initialize.
/// </summary>
public OptimizeIterationStatements()
{
}
/// <summary>
/// Initializes the step for <paramref name="context"/> and resolves the
/// <c>range(int)</c>, <c>range(int, int)</c> and <c>range(int, int, int)</c>
/// overloads of <c>Boo.Lang.Builtins</c> into compiler entities.
/// </summary>
/// <param name="context">the compiler context passed to the base step</param>
public override void Initialize(CompilerContext context)
{
    base.Initialize(context);

    Type builtinsType = typeof(Boo.Lang.Builtins);
    System.Reflection.MethodInfo overload;

    overload = builtinsType.GetMethod("range", new Type[] { Types.Int });
    _range_End = Map(overload);

    overload = builtinsType.GetMethod("range", new Type[] { Types.Int, Types.Int });
    _range_Begin_End = Map(overload);

    overload = builtinsType.GetMethod("range", new Type[] { Types.Int, Types.Int, Types.Int });
    _range_Begin_End_Step = Map(overload);
}
/// <summary>
/// Converts a reflection <paramref name="method"/> into the compiler's
/// IMethod entity via TypeSystemServices.
/// </summary>
private IMethod Map(System.Reflection.MethodInfo method)
{
    IMethod entity = TypeSystemServices.Map(method);
    return entity;
}
/// <summary>
/// Entry point of the step: walks the whole compile unit.
/// </summary>
public override void Run()
{
    Visit(CompileUnit);
}
/// <summary>
/// Remembers <paramref name="node"/> as the owner of any temp locals declared
/// while its body is visited, then visits the body.
/// </summary>
public override void OnMethod(Method node)
{
    // Must be set before visiting: the rewrites declare temps on _currentMethod.
    _currentMethod = node;
    Visit(node.Body);
}
/// <summary>
/// Constructors are processed exactly like ordinary methods.
/// </summary>
public override void OnConstructor(Constructor node)
{
    this.OnMethod(node);
}
/// <summary>
/// Destructors are processed exactly like ordinary methods.
/// </summary>
public override void OnDestructor(Destructor node)
{
    this.OnMethod(node);
}
/// <summary>
/// Intentionally does nothing, so the base visitor does not descend into the
/// closure body here.
/// </summary>
public override void OnBlockExpression(BlockExpression node)
{
    // The closure's body is reached later through the method generated for
    // the closure, so skipping it here avoids processing it twice.
}
/// <summary>
/// After a for statement's children have been visited, tries the range()
/// rewrite and then the array rewrite. Each check returns without changes
/// unless its own pattern matches.
/// </summary>
public override void LeaveForStatement(ForStatement node)
{
    CheckForItemInRangeLoop(node);
    CheckForItemInArrayLoop(node);
}
/// <summary>
/// Returns true when the invocation target of <paramref name="mi"/> is bound
/// to one of the range() overloads resolved in Initialize.
/// </summary>
private bool IsRangeInvocation(MethodInvocationExpression mi)
{
    IEntity target = mi.Target.Entity;
    if (target == _range_End) return true;
    if (target == _range_Begin_End) return true;
    return target == _range_Begin_End_Step;
}
/// <summary>
/// Optimize the <c>for item in range()</c> construct
/// </summary>
/// <remarks>
/// Applies only to a single-variable <c>for</c> whose iterator is bound to
/// <c>range(max)</c>, <c>range(min, max)</c> or <c>range(min, max, step)</c>.
/// The loop is replaced by a block that evaluates the arguments into int temps
/// (<c>__num</c>, <c>__end</c>, <c>__step</c>; integer literals are used in
/// place), emits the ArgumentOutOfRangeException checks for "max" (one
/// argument) and "step" (three arguments), and then iterates with a
/// <c>while</c> loop. The checks are folded at compile time when their
/// operands are integer literals.
/// </remarks>
/// <param name="node">the for statement to check</param>
private void CheckForItemInRangeLoop(ForStatement node)
{
    MethodInvocationExpression mi = node.Iterator as MethodInvocationExpression;
    if (null == mi) return;
    if (!IsRangeInvocation(mi)) return;

    DeclarationCollection declarations = node.Declarations;
    if (declarations.Count != 1) return;

    ExpressionCollection args = mi.Arguments;
    Block body = new Block(node.LexicalInfo);

    Expression min;
    Expression max;
    Expression step;

    // Normalize the three overloads to (min, max, step).
    if (args.Count == 1)
    {
        min = CodeBuilder.CreateIntegerLiteral(0);
        max = args[0];
        step = CodeBuilder.CreateIntegerLiteral(1);
    }
    else if (args.Count == 2)
    {
        min = args[0];
        max = args[1];
        step = CodeBuilder.CreateIntegerLiteral(1);
    }
    else
    {
        min = args[0];
        max = args[1];
        step = args[2];
    }

    InternalLocal numVar = CodeBuilder.DeclareTempLocal(
        _currentMethod,
        TypeSystemServices.IntType);
    Expression numRef = CodeBuilder.CreateReference(numVar);

    // __num = <min>
    body.Add(
        CodeBuilder.CreateAssignment(
            numRef,
            min));

    // A literal upper bound is used as-is; anything else is evaluated once into __end.
    Expression endRef;

    if (max.NodeType == NodeType.IntegerLiteralExpression)
    {
        endRef = max;
    }
    else
    {
        InternalLocal endVar = CodeBuilder.DeclareTempLocal(
            _currentMethod,
            TypeSystemServices.IntType);
        endRef = CodeBuilder.CreateReference(endVar);

        // __end = <end>
        body.Add(
            CodeBuilder.CreateAssignment(
                endRef,
                max));
    }

    if (args.Count == 1)
    {
        if (max.NodeType == NodeType.IntegerLiteralExpression)
        {
            if (((IntegerLiteralExpression)max).Value < 0)
            {
                // raise ArgumentOutOfRangeException("max") (if <max> < 0)
                Statement statement = CodeBuilder.RaiseException(
                    body.LexicalInfo,
                    TypeSystemServices.Map(System_ArgumentOutOfRangeException_ctor),
                    CodeBuilder.CreateStringLiteral("max"));

                body.Add(statement);
            }
        }
        else
        {
            IfStatement ifStatement = new IfStatement(body.LexicalInfo);
            ifStatement.TrueBlock = new Block();

            // raise ArgumentOutOfRangeException("max") if __end < 0
            Statement statement = CodeBuilder.RaiseException(
                body.LexicalInfo,
                TypeSystemServices.Map(System_ArgumentOutOfRangeException_ctor),
                CodeBuilder.CreateStringLiteral("max"));

            ifStatement.Condition = CodeBuilder.CreateBoundBinaryExpression(
                TypeSystemServices.BoolType,
                BinaryOperatorType.LessThan,
                endRef,
                CodeBuilder.CreateIntegerLiteral(0));

            ifStatement.TrueBlock.Add(statement);

            body.Add(ifStatement);
        }
    }

    // A literal step is used as-is. Otherwise __step holds the step value, or,
    // for the two-argument overload, the direction (+1 / -1) chosen at runtime.
    Expression stepRef;

    switch (args.Count)
    {
        case 1:
            stepRef = CodeBuilder.CreateIntegerLiteral(1);
            break;
        case 2:
            if ((min.NodeType == NodeType.IntegerLiteralExpression) &&
                (max.NodeType == NodeType.IntegerLiteralExpression) &&
                (((IntegerLiteralExpression)max).Value < ((IntegerLiteralExpression)min).Value))
            {
                // __step = -1
                stepRef = CodeBuilder.CreateIntegerLiteral(-1);
            }
            else if ((min.NodeType == NodeType.IntegerLiteralExpression) &&
                (max.NodeType == NodeType.IntegerLiteralExpression))
            {
                // __step = 1
                stepRef = CodeBuilder.CreateIntegerLiteral(1);
            }
            else
            {
                InternalLocal stepVar = CodeBuilder.DeclareTempLocal(
                    _currentMethod,
                    TypeSystemServices.IntType);
                stepRef = CodeBuilder.CreateReference(stepVar);

                // __step = 1
                body.Add(
                    CodeBuilder.CreateAssignment(
                        stepRef,
                        CodeBuilder.CreateIntegerLiteral(1)));

                // __step = -1 if __end < __num
                IfStatement ifStatement = new IfStatement(node.LexicalInfo);

                ifStatement.Condition = CodeBuilder.CreateBoundBinaryExpression(
                    TypeSystemServices.BoolType,
                    BinaryOperatorType.LessThan,
                    endRef,
                    numRef);

                ifStatement.TrueBlock = new Block();

                ifStatement.TrueBlock.Add(
                    CodeBuilder.CreateAssignment(
                        stepRef,
                        CodeBuilder.CreateIntegerLiteral(-1)));

                body.Add(ifStatement);
            }
            break;
        default:
            if (step.NodeType == NodeType.IntegerLiteralExpression)
            {
                stepRef = step;
            }
            else
            {
                InternalLocal stepVar = CodeBuilder.DeclareTempLocal(
                    _currentMethod,
                    TypeSystemServices.IntType);
                stepRef = CodeBuilder.CreateReference(stepVar);

                // __step = <step>
                body.Add(
                    CodeBuilder.CreateAssignment(
                        stepRef,
                        step));
            }
            break;
    }

    if (args.Count == 3)
    {
        // Either a runtime condition for the "step" check (condition), or, when
        // every operand is a literal, whether the check fails unconditionally (run).
        Expression condition = null;
        bool run = false;

        if (step.NodeType == NodeType.IntegerLiteralExpression)
        {
            if (((IntegerLiteralExpression)step).Value < 0)
            {
                if ((max.NodeType == NodeType.IntegerLiteralExpression) &&
                    (min.NodeType == NodeType.IntegerLiteralExpression))
                {
                    run = (((IntegerLiteralExpression)max).Value > ((IntegerLiteralExpression)min).Value);
                }
                else
                {
                    condition = CodeBuilder.CreateBoundBinaryExpression(
                        TypeSystemServices.BoolType,
                        BinaryOperatorType.GreaterThan,
                        endRef,
                        numRef);
                }
            }
            else
            {
                if ((max.NodeType == NodeType.IntegerLiteralExpression) &&
                    (min.NodeType == NodeType.IntegerLiteralExpression))
                {
                    run = (((IntegerLiteralExpression)max).Value < ((IntegerLiteralExpression)min).Value);
                }
                else
                {
                    condition = CodeBuilder.CreateBoundBinaryExpression(
                        TypeSystemServices.BoolType,
                        BinaryOperatorType.LessThan,
                        endRef,
                        numRef);
                }
            }
        }
        else
        {
            if ((max.NodeType == NodeType.IntegerLiteralExpression) &&
                (min.NodeType == NodeType.IntegerLiteralExpression))
            {
                if (((IntegerLiteralExpression)max).Value < ((IntegerLiteralExpression)min).Value)
                {
                    // descending bounds: raise if __step > 0
                    condition = CodeBuilder.CreateBoundBinaryExpression(
                        TypeSystemServices.BoolType,
                        BinaryOperatorType.GreaterThan,
                        stepRef,
                        CodeBuilder.CreateIntegerLiteral(0));
                }
                else if (((IntegerLiteralExpression)max).Value > ((IntegerLiteralExpression)min).Value)
                {
                    // ascending bounds: raise if __step < 0
                    condition = CodeBuilder.CreateBoundBinaryExpression(
                        TypeSystemServices.BoolType,
                        BinaryOperatorType.LessThan,
                        stepRef,
                        CodeBuilder.CreateIntegerLiteral(0));
                }
                // min == max is an empty range. The literal-step branch above and
                // the dynamic check below both accept any step for it, so no
                // check is emitted here either (previously a negative step raised).
            }
            else
            {
                condition = CodeBuilder.CreateBoundBinaryExpression(
                    TypeSystemServices.BoolType,
                    BinaryOperatorType.Or,
                    CodeBuilder.CreateBoundBinaryExpression(
                        TypeSystemServices.BoolType,
                        BinaryOperatorType.And,
                        CodeBuilder.CreateBoundBinaryExpression(
                            TypeSystemServices.BoolType,
                            BinaryOperatorType.LessThan,
                            stepRef,
                            CodeBuilder.CreateIntegerLiteral(0)),
                        CodeBuilder.CreateBoundBinaryExpression(
                            TypeSystemServices.BoolType,
                            BinaryOperatorType.GreaterThan,
                            endRef,
                            numRef)),
                    CodeBuilder.CreateBoundBinaryExpression(
                        TypeSystemServices.BoolType,
                        BinaryOperatorType.And,
                        CodeBuilder.CreateBoundBinaryExpression(
                            TypeSystemServices.BoolType,
                            BinaryOperatorType.GreaterThan,
                            stepRef,
                            CodeBuilder.CreateIntegerLiteral(0)),
                        CodeBuilder.CreateBoundBinaryExpression(
                            TypeSystemServices.BoolType,
                            BinaryOperatorType.LessThan,
                            endRef,
                            numRef)));
            }
        }

        // raise ArgumentOutOfRangeException("step") if (__step < 0 and __end > __begin) or (__step > 0 and __end < __begin)
        Statement statement = CodeBuilder.RaiseException(
            body.LexicalInfo,
            TypeSystemServices.Map(System_ArgumentOutOfRangeException_ctor),
            CodeBuilder.CreateStringLiteral("step"));

        if (condition != null)
        {
            IfStatement ifStatement = new IfStatement(body.LexicalInfo);
            ifStatement.TrueBlock = new Block();

            ifStatement.Condition = condition;

            ifStatement.TrueBlock.Add(statement);

            body.Add(ifStatement);
        }
        else if (run)
        {
            body.Add(statement);
        }

        // Snap the bound onto the sequence min, min + step, ... so the loop test
        // below also terminates when it has to use != (non-literal step):
        // __end = __num + __step * cast(int, Math.Ceiling((__end - __num)/cast(double, __step)))
        if ((step.NodeType == NodeType.IntegerLiteralExpression) &&
            (max.NodeType == NodeType.IntegerLiteralExpression) &&
            (min.NodeType == NodeType.IntegerLiteralExpression))
        {
            int stepVal = (int)((IntegerLiteralExpression)step).Value;
            int maxVal = (int)((IntegerLiteralExpression)max).Value;
            int minVal = (int)((IntegerLiteralExpression)min).Value;
            endRef = CodeBuilder.CreateIntegerLiteral(
                minVal + stepVal * (int)System.Math.Ceiling((maxVal - minVal) / ((double)stepVal)));
        }
        else
        {
            Expression endBak = endRef;
            if (max.NodeType == NodeType.IntegerLiteralExpression)
            {
                // The literal bound cannot be assigned to, so introduce __end now.
                InternalLocal endVar = CodeBuilder.DeclareTempLocal(
                    _currentMethod,
                    TypeSystemServices.IntType);
                endRef = CodeBuilder.CreateReference(endVar);
            }

            body.Add(
                CodeBuilder.CreateAssignment(
                    endRef,
                    CodeBuilder.CreateBoundBinaryExpression(
                        TypeSystemServices.IntType,
                        BinaryOperatorType.Addition,
                        numRef,
                        CodeBuilder.CreateBoundBinaryExpression(
                            TypeSystemServices.IntType,
                            BinaryOperatorType.Multiply,
                            stepRef,
                            CodeBuilder.CreateCast(
                                TypeSystemServices.IntType,
                                CodeBuilder.CreateMethodInvocation(
                                    TypeSystemServices.Map(System_Math_Ceiling),
                                    CodeBuilder.CreateBoundBinaryExpression(
                                        TypeSystemServices.DoubleType,
                                        BinaryOperatorType.Division,
                                        CodeBuilder.CreateBoundBinaryExpression(
                                            TypeSystemServices.IntType,
                                            BinaryOperatorType.Subtraction,
                                            endBak,
                                            numRef),
                                        CodeBuilder.CreateCast(
                                            TypeSystemServices.DoubleType,
                                            stepRef))))))));
        }
    }

    // while __num <op> __end:  (op is < / > for a literal step, != otherwise)
    WhileStatement ws = new WhileStatement(node.LexicalInfo);

    BinaryOperatorType op = BinaryOperatorType.Inequality;

    if (stepRef.NodeType == NodeType.IntegerLiteralExpression)
    {
        if (((IntegerLiteralExpression)stepRef).Value > 0)
        {
            op = BinaryOperatorType.LessThan;
        }
        else
        {
            op = BinaryOperatorType.GreaterThan;
        }
    }

    ws.Condition = CodeBuilder.CreateBoundBinaryExpression(
        TypeSystemServices.BoolType,
        op,
        numRef,
        endRef);
    ws.Condition.LexicalInfo = node.LexicalInfo;

    // item = __num
    ws.Block.Add(
        CodeBuilder.CreateAssignment(
            CodeBuilder.CreateReference((InternalLocal)declarations[0].Entity),
            numRef));

    // __num += __step, inside a block annotated checked = false. It is added
    // before <block>, so a continue in the body cannot skip the increment.
    Block rawBlock = new Block();
    rawBlock["checked"] = false;

    rawBlock.Add(
        CodeBuilder.CreateAssignment(
            numRef,
            CodeBuilder.CreateBoundBinaryExpression(
                TypeSystemServices.IntType,
                BinaryOperatorType.Addition,
                numRef,
                stepRef)));

    // NOTE(review): the cast selects the Statement overload of Block.Add;
    // presumably this keeps rawBlock (and its annotation) nested rather than
    // merged - confirm before simplifying.
    ws.Block.Add(rawBlock as Statement);

    // <block>
    ws.Block.Add(node.Block);

    ws.OrBlock = node.OrBlock;
    ws.ThenBlock = node.ThenBlock;

    body.Add(ws);

    ReplaceCurrentNode(body);
}
/// <summary>
/// Node predicate that matches nodes whose optional entity is a given entity.
/// </summary>
private class EntityPredicate
{
    private readonly IEntity _entity;

    public EntityPredicate(IEntity entity)
    {
        _entity = entity;
    }

    /// <summary>True when <paramref name="node"/> is bound to the stored entity.</summary>
    public bool Matches(Node node)
    {
        IEntity candidate = TypeSystemServices.GetOptionalEntity(node);
        return candidate == _entity;
    }
}
/// <summary>
/// Optimize the <c>for item in array</c> construct
/// </summary>
/// <remarks>
/// Applies to one-dimensional arrays whose element type is not an
/// InternalCallableType. The loop is replaced by:
/// <code>
/// __num = 0
/// __arr = &lt;arr&gt;
/// __end = __arr.Length
/// while __num &lt; __end:
///     &lt;block&gt;       (see below for how the loop variable(s) are bound)
///     :update            (only if the block contained a continue that was redirected)
///     __num += 1         (marked unchecked)
/// </code>
/// With one declaration, every node bound to the loop variable inside the
/// block is replaced by a raw <c>__arr[__num]</c> slicing. With several
/// declarations, the element is unpacked into them at the top of the body.
/// NOTE(review): because of that replacement, an assignment to the single
/// loop variable inside the block (if Boo permits one) would become a store
/// into the array - confirm this is acceptable.
/// </remarks>
/// <param name="node">the for statement to check</param>
private void CheckForItemInArrayLoop(ForStatement node)
{
    ArrayType arrayType = GetExpressionType(node.Iterator) as ArrayType;
    if (null == arrayType) return;
    if (arrayType.GetArrayRank() > 1) return;

    IType elementType = arrayType.GetElementType();
    if (elementType is InternalCallableType) return;

    Block replacement = new Block(node.LexicalInfo);

    // __num = 0
    InternalLocal indexLocal = DeclareTempLocal(TypeSystemServices.IntType);
    Expression index = CodeBuilder.CreateReference(indexLocal);
    replacement.Add(
        CodeBuilder.CreateAssignment(index, CodeBuilder.CreateIntegerLiteral(0)));

    // __arr = <arr>  (the iterator expression is evaluated exactly once)
    InternalLocal arrayLocal = DeclareTempLocal(node.Iterator.ExpressionType);
    ReferenceExpression array = CodeBuilder.CreateReference(arrayLocal);
    replacement.Add(
        CodeBuilder.CreateAssignment(array, node.Iterator));

    // __end = __arr.Length
    InternalLocal lengthLocal = DeclareTempLocal(TypeSystemServices.IntType);
    ReferenceExpression length = CodeBuilder.CreateReference(lengthLocal);
    replacement.Add(
        CodeBuilder.CreateAssignment(
            node.Iterator.LexicalInfo,
            length,
            CodeBuilder.CreateMethodInvocation(array, Array_get_Length)));

    // while __num < __end:
    WhileStatement loop = new WhileStatement(node.LexicalInfo);
    loop.Condition = CodeBuilder.CreateBoundBinaryExpression(
        TypeSystemServices.BoolType,
        BinaryOperatorType.LessThan,
        index,
        length);

    if (node.Declarations.Count == 1)
    {
        // Read the element straight from the array wherever the loop variable is used.
        IEntity loopVariable = node.Declarations[0].Entity;
        NodePredicate isLoopVariable = new NodePredicate(new EntityPredicate(loopVariable).Matches);
        node.Block.ReplaceNodes(
            isLoopVariable,
            CreateRawArraySlicing(array, index, elementType));
    }
    else
    {
        // alpha, bravo, charlie = arr[__num]
        UnpackExpression(
            loop.Block,
            CreateRawArraySlicing(array, index, elementType),
            node.Declarations);
    }

    // <block>
    loop.Block.Add(node.Block);

    // Runs after the block is in place and before the increment is appended,
    // so an :update label (if needed) lands between the two.
    FixContinueStatements(node, loop);

    // __num += 1
    BinaryExpression increment = CodeBuilder.CreateAssignment(
        index,
        CodeBuilder.CreateBoundBinaryExpression(
            TypeSystemServices.IntType,
            BinaryOperatorType.Addition,
            index,
            CodeBuilder.CreateIntegerLiteral(1)));
    AstAnnotations.MarkUnchecked(increment);
    loop.Block.Add(increment);

    loop.OrBlock = node.OrBlock;
    loop.ThenBlock = node.ThenBlock;

    replacement.Add(loop);
    ReplaceCurrentNode(replacement);
}
/// <summary>
/// Runs a GotoOnTopLevelContinue fixup over the for body against a fresh label
/// (presumably turning top-level <c>continue</c>s into jumps to it). The label
/// is appended to the while body only when the fixup reports a use.
/// </summary>
private void FixContinueStatements(ForStatement node, WhileStatement ws)
{
    // :update
    LabelStatement updateLabel = CreateUpdateLabel(node);
    GotoOnTopLevelContinue fixup = new GotoOnTopLevelContinue(updateLabel);
    node.Block.Accept(fixup);

    if (fixup.UsageCount <= 0) return;
    ws.Block.Add(updateLabel);
}
/// <summary>
/// Creates a label named "$label$" followed by an index allocated from the
/// compiler context, with empty lexical info.
/// </summary>
/// <param name="node">not used by the current implementation</param>
private LabelStatement CreateUpdateLabel(ForStatement node)
{
    string labelName = "$label$" + _context.AllocIndex();
    return new LabelStatement(LexicalInfo.Empty, labelName);
}
/// <summary>
/// Builds <c>arrayRef[numRef]</c> from clones of the two nodes, typed as
/// <paramref name="elementType"/> and annotated for raw array indexing.
/// </summary>
private static SlicingExpression CreateRawArraySlicing(ReferenceExpression arrayRef, Expression numRef, IType elementType)
{
    // Clone so each use site gets its own nodes.
    SlicingExpression slicing = new SlicingExpression(arrayRef.CloneNode(), numRef.CloneNode());
    slicing.ExpressionType = elementType;
    AstAnnotations.MarkRawArrayIndexing(slicing);
    return slicing;
}
/// <summary>
/// Declares a new temp local of <paramref name="type"/> on the method that is
/// currently being visited.
/// </summary>
private InternalLocal DeclareTempLocal(IType type)
{
    InternalLocal local = CodeBuilder.DeclareTempLocal(_currentMethod, type);
    return local;
}
/// <summary>
/// Unpacks an expression onto a list of declarations by delegating to
/// NormalizeIterationStatements.UnpackExpression with this step's code
/// builder and the current method.
/// </summary>
/// <param name="block">Block the unpacking statements are added to</param>
/// <param name="expression">expression to explode</param>
/// <param name="declarations">list of declarations to set</param>
private void UnpackExpression(Block block, Expression expression, DeclarationCollection declarations)
{
    NormalizeIterationStatements.UnpackExpression(
        CodeBuilder,
        _currentMethod,
        block,
        expression,
        declarations);
}