reenable and fix a bug in memopt
[qbe.git] / fold.c
blob6129421091f048fe3050bf432c5537bd2b2030d9
1 #include "all.h"
/* Sentinel lattice values used by SCCP; any value >= 0 is an index
 * into fn->con (a known constant). */
enum {
	Top = -1, /* lattice top: nothing known yet */
	Bot = -2, /* lattice bottom: not a constant */
};
/* One outgoing CFG edge of a block, as tracked by the SCCP worklist. */
typedef struct Edge {
	int dest;          /* rpo index of the target block, or -1 if none */
	int dead;          /* nonzero until the edge is found executable */
	struct Edge *work; /* next edge in the flowrk worklist */
} Edge;
16 static int *val;
17 static Edge *flowrk, (*edge)[2];
18 static Use **usewrk;
19 static uint nuse;
21 static int
22 czero(Con *c, int w)
24 if (c->type != CBits)
25 return 0;
26 if (w)
27 return c->bits.i == 0;
28 else
29 return (uint32_t)c->bits.i == 0;
32 static int
33 latval(Ref r)
35 switch (rtype(r)) {
36 case RTmp:
37 return val[r.val];
38 case RCon:
39 return r.val;
40 default:
41 die("unreachable");
45 static int
46 latmerge(int v, int m)
48 return m == Top ? v : (v == Top || v == m) ? m : Bot;
51 static void
52 update(int t, int m, Fn *fn)
54 Tmp *tmp;
55 uint u;
57 m = latmerge(val[t], m);
58 if (m != val[t]) {
59 tmp = &fn->tmp[t];
60 for (u=0; u<tmp->nuse; u++) {
61 vgrow(&usewrk, ++nuse);
62 usewrk[nuse-1] = &tmp->use[u];
64 val[t] = m;
68 static int
69 deadedge(int s, int d)
71 Edge *e;
73 e = edge[s];
74 if (e[0].dest == d && !e[0].dead)
75 return 0;
76 if (e[1].dest == d && !e[1].dead)
77 return 0;
78 return 1;
81 static void
82 visitphi(Phi *p, int n, Fn *fn)
84 int v;
85 uint a;
87 v = Top;
88 for (a=0; a<p->narg; a++)
89 if (!deadedge(p->blk[a]->id, n))
90 v = latmerge(v, latval(p->arg[a]));
91 update(p->to.val, v, fn);
94 static int opfold(int, int, Con *, Con *, Fn *);
96 static void
97 visitins(Ins *i, Fn *fn)
99 int v, l, r;
101 if (rtype(i->to) != RTmp)
102 return;
103 if (opdesc[i->op].cfold) {
104 l = latval(i->arg[0]);
105 if (!req(i->arg[1], R))
106 r = latval(i->arg[1]);
107 else
108 r = CON_Z.val;
109 if (l == Bot || r == Bot)
110 v = Bot;
111 else if (l == Top || r == Top)
112 v = Top;
113 else
114 v = opfold(i->op, i->cls, &fn->con[l], &fn->con[r], fn);
115 } else
116 v = Bot;
117 /* fprintf(stderr, "\nvisiting %s (%p)", opdesc[i->op].name, (void *)i); */
118 update(i->to.val, v, fn);
121 static void
122 visitjmp(Blk *b, int n, Fn *fn)
124 int l;
126 switch (b->jmp.type) {
127 case Jjnz:
128 l = latval(b->jmp.arg);
129 assert(l != Top && "ssa invariant broken");
130 if (l == Bot) {
131 edge[n][1].work = flowrk;
132 edge[n][0].work = &edge[n][1];
133 flowrk = &edge[n][0];
135 else if (czero(&fn->con[l], 0)) {
136 assert(edge[n][0].dead);
137 edge[n][1].work = flowrk;
138 flowrk = &edge[n][1];
140 else {
141 assert(edge[n][1].dead);
142 edge[n][0].work = flowrk;
143 flowrk = &edge[n][0];
145 break;
146 case Jjmp:
147 edge[n][0].work = flowrk;
148 flowrk = &edge[n][0];
149 break;
150 default:
151 if (isret(b->jmp.type))
152 break;
153 die("unreachable");
157 static void
158 initedge(Edge *e, Blk *s)
160 if (s)
161 e->dest = s->id;
162 else
163 e->dest = -1;
164 e->dead = 1;
165 e->work = 0;
168 static int
169 renref(Ref *r)
171 int l;
173 if (rtype(*r) == RTmp)
174 if ((l=val[r->val]) != Bot) {
175 assert(l != Top && "ssa invariant broken");
176 *r = CON(l);
177 return 1;
179 return 0;
182 /* require rpo, use, pred */
183 void
184 fold(Fn *fn)
186 Edge *e, start;
187 Use *u;
188 Blk *b, **pb;
189 Phi *p, **pp;
190 Ins *i;
191 int t, d;
192 uint n, a;
194 val = emalloc(fn->ntmp * sizeof val[0]);
195 edge = emalloc(fn->nblk * sizeof edge[0]);
196 usewrk = vnew(0, sizeof usewrk[0], Pheap);
198 for (t=0; t<fn->ntmp; t++)
199 val[t] = Top;
200 for (n=0; n<fn->nblk; n++) {
201 b = fn->rpo[n];
202 b->visit = 0;
203 initedge(&edge[n][0], b->s1);
204 initedge(&edge[n][1], b->s2);
206 initedge(&start, fn->start);
207 flowrk = &start;
208 nuse = 0;
210 /* 1. find out constants and dead cfg edges */
211 for (;;) {
212 e = flowrk;
213 if (e) {
214 flowrk = e->work;
215 e->work = 0;
216 if (e->dest == -1 || !e->dead)
217 continue;
218 e->dead = 0;
219 n = e->dest;
220 b = fn->rpo[n];
221 for (p=b->phi; p; p=p->link)
222 visitphi(p, n, fn);
223 if (b->visit == 0) {
224 for (i=b->ins; i-b->ins < b->nins; i++)
225 visitins(i, fn);
226 visitjmp(b, n, fn);
228 b->visit++;
229 assert(b->jmp.type != Jjmp
230 || !edge[n][0].dead
231 || flowrk == &edge[n][0]);
233 else if (nuse) {
234 u = usewrk[--nuse];
235 n = u->bid;
236 b = fn->rpo[n];
237 if (b->visit == 0)
238 continue;
239 switch (u->type) {
240 case UPhi:
241 visitphi(u->u.phi, u->bid, fn);
242 break;
243 case UIns:
244 visitins(u->u.ins, fn);
245 break;
246 case UJmp:
247 visitjmp(b, n, fn);
248 break;
249 default:
250 die("unreachable");
253 else
254 break;
257 if (debug['F']) {
258 fprintf(stderr, "\n> SCCP findings:");
259 for (t=Tmp0; t<fn->ntmp; t++) {
260 if (val[t] == Bot)
261 continue;
262 fprintf(stderr, "\n%10s: ", fn->tmp[t].name);
263 if (val[t] == Top)
264 fprintf(stderr, "Top");
265 else
266 printref(CON(val[t]), fn, stderr);
268 fprintf(stderr, "\n dead code: ");
271 /* 2. trim dead code, replace constants */
272 d = 0;
273 for (pb=&fn->start; (b=*pb);) {
274 if (b->visit == 0) {
275 d = 1;
276 if (debug['F'])
277 fprintf(stderr, "%s ", b->name);
278 edgedel(b, &b->s1);
279 edgedel(b, &b->s2);
280 *pb = b->link;
281 continue;
283 for (pp=&b->phi; (p=*pp);)
284 if (val[p->to.val] != Bot)
285 *pp = p->link;
286 else {
287 for (a=0; a<p->narg; a++)
288 if (!deadedge(p->blk[a]->id, b->id))
289 renref(&p->arg[a]);
290 pp = &p->link;
292 for (i=b->ins; i-b->ins < b->nins; i++)
293 if (renref(&i->to))
294 *i = (Ins){.op = Onop};
295 else
296 for (n=0; n<2; n++)
297 renref(&i->arg[n]);
298 renref(&b->jmp.arg);
299 if (b->jmp.type == Jjnz && rtype(b->jmp.arg) == RCon) {
300 if (czero(&fn->con[b->jmp.arg.val], 0)) {
301 edgedel(b, &b->s1);
302 b->s1 = b->s2;
303 b->s2 = 0;
304 } else
305 edgedel(b, &b->s2);
306 b->jmp.type = Jjmp;
307 b->jmp.arg = R;
309 pb = &b->link;
312 if (debug['F']) {
313 if (!d)
314 fprintf(stderr, "(none)");
315 fprintf(stderr, "\n\n> After constant folding:\n");
316 printfn(fn, stderr);
319 free(val);
320 free(edge);
321 vfree(usewrk);
324 /* boring folding code */
326 static int
327 foldint(Con *res, int op, int w, Con *cl, Con *cr)
329 union {
330 int64_t s;
331 uint64_t u;
332 float fs;
333 double fd;
334 } l, r;
335 uint64_t x;
336 char *lab;
338 lab = 0;
339 l.s = cl->bits.i;
340 r.s = cr->bits.i;
341 if (op == Oadd) {
342 if (cl->type == CAddr) {
343 if (cr->type == CAddr)
344 err("undefined addition (addr + addr)");
345 lab = cl->label;
347 else if (cr->type == CAddr)
348 lab = cr->label;
350 else if (op == Osub) {
351 if (cl->type == CAddr) {
352 if (cr->type != CAddr)
353 lab = cl->label;
354 else if (strcmp(cl->label, cr->label) != 0)
355 err("undefined substraction (addr1 - addr2)");
357 else if (cr->type == CAddr)
358 err("undefined substraction (num - addr)");
360 else if (cl->type == CAddr || cr->type == CAddr) {
361 if (Ocmpl <= op && op <= Ocmpl1)
362 return 1;
363 err("invalid address operand for '%s'", opdesc[op].name);
365 switch (op) {
366 case Oadd: x = l.u + r.u; break;
367 case Osub: x = l.u - r.u; break;
368 case Odiv: x = l.s / r.s; break;
369 case Orem: x = l.s % r.s; break;
370 case Oudiv: x = l.u / r.u; break;
371 case Ourem: x = l.u % r.u; break;
372 case Omul: x = l.u * r.u; break;
373 case Oand: x = l.u & r.u; break;
374 case Oor: x = l.u | r.u; break;
375 case Oxor: x = l.u ^ r.u; break;
376 case Osar: x = l.s >> (r.u & 63); break;
377 case Oshr: x = l.u >> (r.u & 63); break;
378 case Oshl: x = l.u << (r.u & 63); break;
379 case Oextsb: x = (int8_t)l.u; break;
380 case Oextub: x = (uint8_t)l.u; break;
381 case Oextsh: x = (int16_t)l.u; break;
382 case Oextuh: x = (uint16_t)l.u; break;
383 case Oextsw: x = (int32_t)l.u; break;
384 case Oextuw: x = (uint32_t)l.u; break;
385 case Ostosi: x = w ? (int64_t)cl->bits.s : (int32_t)cl->bits.s; break;
386 case Odtosi: x = w ? (int64_t)cl->bits.d : (int32_t)cl->bits.d; break;
387 case Ocast:
388 x = l.u;
389 if (cl->type == CAddr)
390 lab = cl->label;
391 break;
392 default:
393 if (Ocmpw <= op && op <= Ocmpl1) {
394 if (op <= Ocmpw1) {
395 l.u = (int32_t)l.u;
396 r.u = (int32_t)r.u;
397 } else
398 op -= Ocmpl - Ocmpw;
399 switch (op - Ocmpw) {
400 case ICule: x = l.u <= r.u; break;
401 case ICult: x = l.u < r.u; break;
402 case ICsle: x = l.s <= r.s; break;
403 case ICslt: x = l.s < r.s; break;
404 case ICsgt: x = l.s > r.s; break;
405 case ICsge: x = l.s >= r.s; break;
406 case ICugt: x = l.u > r.u; break;
407 case ICuge: x = l.u >= r.u; break;
408 case ICeq: x = l.u == r.u; break;
409 case ICne: x = l.u != r.u; break;
410 default: die("unreachable");
413 else if (Ocmps <= op && op <= Ocmps1) {
414 switch (op - Ocmps) {
415 case FCle: x = l.fs <= r.fs; break;
416 case FClt: x = l.fs < r.fs; break;
417 case FCgt: x = l.fs > r.fs; break;
418 case FCge: x = l.fs >= r.fs; break;
419 case FCne: x = l.fs != r.fs; break;
420 case FCeq: x = l.fs == r.fs; break;
421 case FCo: x = l.fs < r.fs || l.fs >= r.fs; break;
422 case FCuo: x = !(l.fs < r.fs || l.fs >= r.fs); break;
423 default: die("unreachable");
426 else if (Ocmpd <= op && op <= Ocmpd1) {
427 switch (op - Ocmpd) {
428 case FCle: x = l.fd <= r.fd; break;
429 case FClt: x = l.fd < r.fd; break;
430 case FCgt: x = l.fd > r.fd; break;
431 case FCge: x = l.fd >= r.fd; break;
432 case FCne: x = l.fd != r.fd; break;
433 case FCeq: x = l.fd == r.fd; break;
434 case FCo: x = l.fd < r.fd || l.fd >= r.fd; break;
435 case FCuo: x = !(l.fd < r.fd || l.fd >= r.fd); break;
436 default: die("unreachable");
439 else
440 die("unreachable");
442 *res = (Con){lab ? CAddr : CBits, .bits={.i=x}};
443 res->bits.i = x;
444 if (lab)
445 strcpy(res->label, lab);
446 return 0;
449 static void
450 foldflt(Con *res, int op, int w, Con *cl, Con *cr)
452 float xs, ls, rs;
453 double xd, ld, rd;
455 if (cl->type != CBits || cr->type != CBits)
456 err("invalid address operand for '%s'", opdesc[op].name);
457 if (w) {
458 ld = cl->bits.d;
459 rd = cr->bits.d;
460 switch (op) {
461 case Oadd: xd = ld + rd; break;
462 case Osub: xd = ld - rd; break;
463 case Odiv: xd = ld / rd; break;
464 case Omul: xd = ld * rd; break;
465 case Oswtof: xd = (int32_t)cl->bits.i; break;
466 case Osltof: xd = (int64_t)cl->bits.i; break;
467 case Oexts: xd = cl->bits.s; break;
468 case Ocast: xd = ld; break;
469 default: die("unreachable");
471 *res = (Con){CBits, .bits={.d=xd}, .flt=2};
472 } else {
473 ls = cl->bits.s;
474 rs = cr->bits.s;
475 switch (op) {
476 case Oadd: xs = ls + rs; break;
477 case Osub: xs = ls - rs; break;
478 case Odiv: xs = ls / rs; break;
479 case Omul: xs = ls * rs; break;
480 case Oswtof: xs = (int32_t)cl->bits.i; break;
481 case Osltof: xs = (int64_t)cl->bits.i; break;
482 case Otruncd: xs = cl->bits.d; break;
483 case Ocast: xs = ls; break;
484 default: die("unreachable");
486 *res = (Con){CBits, .bits={.s=xs}, .flt=1};
490 static int
491 opfold(int op, int cls, Con *cl, Con *cr, Fn *fn)
493 int nc;
494 Con c;
496 if ((op == Odiv || op == Oudiv
497 || op == Orem || op == Ourem) && czero(cr, KWIDE(cls)))
498 err("null divisor in '%s'", opdesc[op].name);
499 if (cls == Kw || cls == Kl) {
500 if (foldint(&c, op, cls == Kl, cl, cr))
501 return Bot;
502 } else
503 foldflt(&c, op, cls == Kd, cl, cr);
504 if (c.type == CBits)
505 nc = getcon(c.bits.i, fn).val;
506 else {
507 nc = fn->ncon;
508 vgrow(&fn->con, ++fn->ncon);
510 assert(!(cls == Ks || cls == Kd) || c.flt);
511 fn->con[nc] = c;
512 return nc;