new blit instruction
[qbe.git] / fold.c
blobf992b3ae9249344b8e4087373a46188344a0208d
1 #include "all.h"
/* Special lattice values for the SCCP solver.  Any value >= 0 is
 * an index into fn->con naming the constant a temporary holds. */
enum {
	Top = -1, /* lattice top: nothing known yet */
	Bot = -2, /* lattice bottom: not a constant */
};
typedef struct Edge Edge;

/* One outgoing cfg edge tracked by the solver. */
struct Edge {
	int dest;   /* id of the destination block, -1 when absent */
	int dead;   /* nonzero until the edge is found executable */
	Edge *work; /* next edge in the flow worklist */
};
16 static int *val;
17 static Edge *flowrk, (*edge)[2];
18 static Use **usewrk;
19 static uint nuse;
21 static int
22 iscon(Con *c, int w, uint64_t k)
24 if (c->type != CBits)
25 return 0;
26 if (w)
27 return (uint64_t)c->bits.i == k;
28 else
29 return (uint32_t)c->bits.i == (uint32_t)k;
32 static int
33 latval(Ref r)
35 switch (rtype(r)) {
36 case RTmp:
37 return val[r.val];
38 case RCon:
39 return r.val;
40 default:
41 die("unreachable");
45 static int
46 latmerge(int v, int m)
48 return m == Top ? v : (v == Top || v == m) ? m : Bot;
51 static void
52 update(int t, int m, Fn *fn)
54 Tmp *tmp;
55 uint u;
57 m = latmerge(val[t], m);
58 if (m != val[t]) {
59 tmp = &fn->tmp[t];
60 for (u=0; u<tmp->nuse; u++) {
61 vgrow(&usewrk, ++nuse);
62 usewrk[nuse-1] = &tmp->use[u];
64 val[t] = m;
68 static int
69 deadedge(int s, int d)
71 Edge *e;
73 e = edge[s];
74 if (e[0].dest == d && !e[0].dead)
75 return 0;
76 if (e[1].dest == d && !e[1].dead)
77 return 0;
78 return 1;
81 static void
82 visitphi(Phi *p, int n, Fn *fn)
84 int v;
85 uint a;
87 v = Top;
88 for (a=0; a<p->narg; a++)
89 if (!deadedge(p->blk[a]->id, n))
90 v = latmerge(v, latval(p->arg[a]));
91 update(p->to.val, v, fn);
94 static int opfold(int, int, Con *, Con *, Fn *);
96 static void
97 visitins(Ins *i, Fn *fn)
99 int v, l, r;
101 if (rtype(i->to) != RTmp)
102 return;
103 if (optab[i->op].canfold) {
104 l = latval(i->arg[0]);
105 if (!req(i->arg[1], R))
106 r = latval(i->arg[1]);
107 else
108 r = CON_Z.val;
109 if (l == Bot || r == Bot)
110 v = Bot;
111 else if (l == Top || r == Top)
112 v = Top;
113 else
114 v = opfold(i->op, i->cls, &fn->con[l], &fn->con[r], fn);
115 } else
116 v = Bot;
117 /* fprintf(stderr, "\nvisiting %s (%p)", optab[i->op].name, (void *)i); */
118 update(i->to.val, v, fn);
121 static void
122 visitjmp(Blk *b, int n, Fn *fn)
124 int l;
126 switch (b->jmp.type) {
127 case Jjnz:
128 l = latval(b->jmp.arg);
129 assert(l != Top && "ssa invariant broken");
130 if (l == Bot) {
131 edge[n][1].work = flowrk;
132 edge[n][0].work = &edge[n][1];
133 flowrk = &edge[n][0];
135 else if (iscon(&fn->con[l], 0, 0)) {
136 assert(edge[n][0].dead);
137 edge[n][1].work = flowrk;
138 flowrk = &edge[n][1];
140 else {
141 assert(edge[n][1].dead);
142 edge[n][0].work = flowrk;
143 flowrk = &edge[n][0];
145 break;
146 case Jjmp:
147 edge[n][0].work = flowrk;
148 flowrk = &edge[n][0];
149 break;
150 case Jhlt:
151 break;
152 default:
153 if (isret(b->jmp.type))
154 break;
155 die("unreachable");
159 static void
160 initedge(Edge *e, Blk *s)
162 if (s)
163 e->dest = s->id;
164 else
165 e->dest = -1;
166 e->dead = 1;
167 e->work = 0;
170 static int
171 renref(Ref *r)
173 int l;
175 if (rtype(*r) == RTmp)
176 if ((l=val[r->val]) != Bot) {
177 assert(l != Top && "ssa invariant broken");
178 *r = CON(l);
179 return 1;
181 return 0;
184 /* require rpo, use, pred */
185 void
186 fold(Fn *fn)
188 Edge *e, start;
189 Use *u;
190 Blk *b, **pb;
191 Phi *p, **pp;
192 Ins *i;
193 int t, d;
194 uint n, a;
196 val = emalloc(fn->ntmp * sizeof val[0]);
197 edge = emalloc(fn->nblk * sizeof edge[0]);
198 usewrk = vnew(0, sizeof usewrk[0], PHeap);
200 for (t=0; t<fn->ntmp; t++)
201 val[t] = Top;
202 for (n=0; n<fn->nblk; n++) {
203 b = fn->rpo[n];
204 b->visit = 0;
205 initedge(&edge[n][0], b->s1);
206 initedge(&edge[n][1], b->s2);
208 initedge(&start, fn->start);
209 flowrk = &start;
210 nuse = 0;
212 /* 1. find out constants and dead cfg edges */
213 for (;;) {
214 e = flowrk;
215 if (e) {
216 flowrk = e->work;
217 e->work = 0;
218 if (e->dest == -1 || !e->dead)
219 continue;
220 e->dead = 0;
221 n = e->dest;
222 b = fn->rpo[n];
223 for (p=b->phi; p; p=p->link)
224 visitphi(p, n, fn);
225 if (b->visit == 0) {
226 for (i=b->ins; i<&b->ins[b->nins]; i++)
227 visitins(i, fn);
228 visitjmp(b, n, fn);
230 b->visit++;
231 assert(b->jmp.type != Jjmp
232 || !edge[n][0].dead
233 || flowrk == &edge[n][0]);
235 else if (nuse) {
236 u = usewrk[--nuse];
237 n = u->bid;
238 b = fn->rpo[n];
239 if (b->visit == 0)
240 continue;
241 switch (u->type) {
242 case UPhi:
243 visitphi(u->u.phi, u->bid, fn);
244 break;
245 case UIns:
246 visitins(u->u.ins, fn);
247 break;
248 case UJmp:
249 visitjmp(b, n, fn);
250 break;
251 default:
252 die("unreachable");
255 else
256 break;
259 if (debug['F']) {
260 fprintf(stderr, "\n> SCCP findings:");
261 for (t=Tmp0; t<fn->ntmp; t++) {
262 if (val[t] == Bot)
263 continue;
264 fprintf(stderr, "\n%10s: ", fn->tmp[t].name);
265 if (val[t] == Top)
266 fprintf(stderr, "Top");
267 else
268 printref(CON(val[t]), fn, stderr);
270 fprintf(stderr, "\n dead code: ");
273 /* 2. trim dead code, replace constants */
274 d = 0;
275 for (pb=&fn->start; (b=*pb);) {
276 if (b->visit == 0) {
277 d = 1;
278 if (debug['F'])
279 fprintf(stderr, "%s ", b->name);
280 edgedel(b, &b->s1);
281 edgedel(b, &b->s2);
282 *pb = b->link;
283 continue;
285 for (pp=&b->phi; (p=*pp);)
286 if (val[p->to.val] != Bot)
287 *pp = p->link;
288 else {
289 for (a=0; a<p->narg; a++)
290 if (!deadedge(p->blk[a]->id, b->id))
291 renref(&p->arg[a]);
292 pp = &p->link;
294 for (i=b->ins; i<&b->ins[b->nins]; i++)
295 if (renref(&i->to))
296 *i = (Ins){.op = Onop};
297 else
298 for (n=0; n<2; n++)
299 renref(&i->arg[n]);
300 renref(&b->jmp.arg);
301 if (b->jmp.type == Jjnz && rtype(b->jmp.arg) == RCon) {
302 if (iscon(&fn->con[b->jmp.arg.val], 0, 0)) {
303 edgedel(b, &b->s1);
304 b->s1 = b->s2;
305 b->s2 = 0;
306 } else
307 edgedel(b, &b->s2);
308 b->jmp.type = Jjmp;
309 b->jmp.arg = R;
311 pb = &b->link;
314 if (debug['F']) {
315 if (!d)
316 fprintf(stderr, "(none)");
317 fprintf(stderr, "\n\n> After constant folding:\n");
318 printfn(fn, stderr);
321 free(val);
322 free(edge);
323 vfree(usewrk);
326 /* boring folding code */
328 static int
329 foldint(Con *res, int op, int w, Con *cl, Con *cr)
331 union {
332 int64_t s;
333 uint64_t u;
334 float fs;
335 double fd;
336 } l, r;
337 uint64_t x;
338 Sym sym;
339 int typ;
341 memset(&sym, 0, sizeof sym);
342 typ = CBits;
343 l.s = cl->bits.i;
344 r.s = cr->bits.i;
345 if (op == Oadd) {
346 if (cl->type == CAddr) {
347 if (cr->type == CAddr)
348 return 1;
349 typ = CAddr;
350 sym = cl->sym;
352 else if (cr->type == CAddr) {
353 typ = CAddr;
354 sym = cr->sym;
357 else if (op == Osub) {
358 if (cl->type == CAddr) {
359 if (cr->type != CAddr) {
360 typ = CAddr;
361 sym = cl->sym;
362 } else if (!symeq(cl->sym, cr->sym))
363 return 1;
365 else if (cr->type == CAddr)
366 return 1;
368 else if (cl->type == CAddr || cr->type == CAddr)
369 return 1;
370 if (op == Odiv || op == Orem || op == Oudiv || op == Ourem) {
371 if (iscon(cr, w, 0))
372 return 1;
373 if (op == Odiv || op == Orem) {
374 x = w ? INT64_MIN : INT32_MIN;
375 if (iscon(cr, w, -1))
376 if (iscon(cl, w, x))
377 return 1;
380 switch (op) {
381 case Oadd: x = l.u + r.u; break;
382 case Osub: x = l.u - r.u; break;
383 case Oneg: x = -l.u; break;
384 case Odiv: x = w ? l.s / r.s : (int32_t)l.s / (int32_t)r.s; break;
385 case Orem: x = w ? l.s % r.s : (int32_t)l.s % (int32_t)r.s; break;
386 case Oudiv: x = w ? l.u / r.u : (uint32_t)l.u / (uint32_t)r.u; break;
387 case Ourem: x = w ? l.u % r.u : (uint32_t)l.u % (uint32_t)r.u; break;
388 case Omul: x = l.u * r.u; break;
389 case Oand: x = l.u & r.u; break;
390 case Oor: x = l.u | r.u; break;
391 case Oxor: x = l.u ^ r.u; break;
392 case Osar: x = (w ? l.s : (int32_t)l.s) >> (r.u & (31|w<<5)); break;
393 case Oshr: x = (w ? l.u : (uint32_t)l.u) >> (r.u & (31|w<<5)); break;
394 case Oshl: x = l.u << (r.u & (31|w<<5)); break;
395 case Oextsb: x = (int8_t)l.u; break;
396 case Oextub: x = (uint8_t)l.u; break;
397 case Oextsh: x = (int16_t)l.u; break;
398 case Oextuh: x = (uint16_t)l.u; break;
399 case Oextsw: x = (int32_t)l.u; break;
400 case Oextuw: x = (uint32_t)l.u; break;
401 case Ostosi: x = w ? (int64_t)cl->bits.s : (int32_t)cl->bits.s; break;
402 case Ostoui: x = w ? (uint64_t)cl->bits.s : (uint32_t)cl->bits.s; break;
403 case Odtosi: x = w ? (int64_t)cl->bits.d : (int32_t)cl->bits.d; break;
404 case Odtoui: x = w ? (uint64_t)cl->bits.d : (uint32_t)cl->bits.d; break;
405 case Ocast:
406 x = l.u;
407 if (cl->type == CAddr) {
408 typ = CAddr;
409 sym = cl->sym;
411 break;
412 default:
413 if (Ocmpw <= op && op <= Ocmpl1) {
414 if (op <= Ocmpw1) {
415 l.u = (int32_t)l.u;
416 r.u = (int32_t)r.u;
417 } else
418 op -= Ocmpl - Ocmpw;
419 switch (op - Ocmpw) {
420 case Ciule: x = l.u <= r.u; break;
421 case Ciult: x = l.u < r.u; break;
422 case Cisle: x = l.s <= r.s; break;
423 case Cislt: x = l.s < r.s; break;
424 case Cisgt: x = l.s > r.s; break;
425 case Cisge: x = l.s >= r.s; break;
426 case Ciugt: x = l.u > r.u; break;
427 case Ciuge: x = l.u >= r.u; break;
428 case Cieq: x = l.u == r.u; break;
429 case Cine: x = l.u != r.u; break;
430 default: die("unreachable");
433 else if (Ocmps <= op && op <= Ocmps1) {
434 switch (op - Ocmps) {
435 case Cfle: x = l.fs <= r.fs; break;
436 case Cflt: x = l.fs < r.fs; break;
437 case Cfgt: x = l.fs > r.fs; break;
438 case Cfge: x = l.fs >= r.fs; break;
439 case Cfne: x = l.fs != r.fs; break;
440 case Cfeq: x = l.fs == r.fs; break;
441 case Cfo: x = l.fs < r.fs || l.fs >= r.fs; break;
442 case Cfuo: x = !(l.fs < r.fs || l.fs >= r.fs); break;
443 default: die("unreachable");
446 else if (Ocmpd <= op && op <= Ocmpd1) {
447 switch (op - Ocmpd) {
448 case Cfle: x = l.fd <= r.fd; break;
449 case Cflt: x = l.fd < r.fd; break;
450 case Cfgt: x = l.fd > r.fd; break;
451 case Cfge: x = l.fd >= r.fd; break;
452 case Cfne: x = l.fd != r.fd; break;
453 case Cfeq: x = l.fd == r.fd; break;
454 case Cfo: x = l.fd < r.fd || l.fd >= r.fd; break;
455 case Cfuo: x = !(l.fd < r.fd || l.fd >= r.fd); break;
456 default: die("unreachable");
459 else
460 die("unreachable");
462 *res = (Con){.type=typ, .sym=sym, .bits={.i=x}};
463 return 0;
466 static void
467 foldflt(Con *res, int op, int w, Con *cl, Con *cr)
469 float xs, ls, rs;
470 double xd, ld, rd;
472 if (cl->type != CBits || cr->type != CBits)
473 err("invalid address operand for '%s'", optab[op].name);
474 *res = (Con){.type = CBits};
475 memset(&res->bits, 0, sizeof(res->bits));
476 if (w) {
477 ld = cl->bits.d;
478 rd = cr->bits.d;
479 switch (op) {
480 case Oadd: xd = ld + rd; break;
481 case Osub: xd = ld - rd; break;
482 case Oneg: xd = -ld; break;
483 case Odiv: xd = ld / rd; break;
484 case Omul: xd = ld * rd; break;
485 case Oswtof: xd = (int32_t)cl->bits.i; break;
486 case Ouwtof: xd = (uint32_t)cl->bits.i; break;
487 case Osltof: xd = (int64_t)cl->bits.i; break;
488 case Oultof: xd = (uint64_t)cl->bits.i; break;
489 case Oexts: xd = cl->bits.s; break;
490 case Ocast: xd = ld; break;
491 default: die("unreachable");
493 res->bits.d = xd;
494 res->flt = 2;
495 } else {
496 ls = cl->bits.s;
497 rs = cr->bits.s;
498 switch (op) {
499 case Oadd: xs = ls + rs; break;
500 case Osub: xs = ls - rs; break;
501 case Oneg: xs = -ls; break;
502 case Odiv: xs = ls / rs; break;
503 case Omul: xs = ls * rs; break;
504 case Oswtof: xs = (int32_t)cl->bits.i; break;
505 case Ouwtof: xs = (uint32_t)cl->bits.i; break;
506 case Osltof: xs = (int64_t)cl->bits.i; break;
507 case Oultof: xs = (uint64_t)cl->bits.i; break;
508 case Otruncd: xs = cl->bits.d; break;
509 case Ocast: xs = ls; break;
510 default: die("unreachable");
512 res->bits.s = xs;
513 res->flt = 1;
517 static int
518 opfold(int op, int cls, Con *cl, Con *cr, Fn *fn)
520 Ref r;
521 Con c;
523 if (cls == Kw || cls == Kl) {
524 if (foldint(&c, op, cls == Kl, cl, cr))
525 return Bot;
526 } else
527 foldflt(&c, op, cls == Kd, cl, cr);
528 if (!KWIDE(cls))
529 c.bits.i &= 0xffffffff;
530 r = newcon(&c, fn);
531 assert(!(cls == Ks || cls == Kd) || c.flt);
532 return r.val;